serious_mnist.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. #!/usr/bin/env python
  2. #inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
  3. import sys
  4. import numpy as np
  5. from tinygrad.nn.state import get_parameters
  6. from tinygrad.tensor import Tensor
  7. from tinygrad.nn import BatchNorm2d, optim
  8. from tinygrad.helpers import getenv
  9. from extra.datasets import fetch_mnist
  10. from extra.augment import augment_img
  11. from extra.training import train, evaluate
  12. GPU = getenv("GPU")
  13. QUICK = getenv("QUICK")
  14. DEBUG = getenv("DEBUG")
  15. class SqueezeExciteBlock2D:
  16. def __init__(self, filters):
  17. self.filters = filters
  18. self.weight1 = Tensor.scaled_uniform(self.filters, self.filters//32)
  19. self.bias1 = Tensor.scaled_uniform(1,self.filters//32)
  20. self.weight2 = Tensor.scaled_uniform(self.filters//32, self.filters)
  21. self.bias2 = Tensor.scaled_uniform(1, self.filters)
  22. def __call__(self, input):
  23. se = input.avg_pool2d(kernel_size=(input.shape[2], input.shape[3])) #GlobalAveragePool2D
  24. se = se.reshape(shape=(-1, self.filters))
  25. se = se.dot(self.weight1) + self.bias1
  26. se = se.relu()
  27. se = se.dot(self.weight2) + self.bias2
  28. se = se.sigmoid().reshape(shape=(-1,self.filters,1,1)) #for broadcasting
  29. se = input.mul(se)
  30. return se
  31. class ConvBlock:
  32. def __init__(self, h, w, inp, filters=128, conv=3):
  33. self.h, self.w = h, w
  34. self.inp = inp
  35. #init weights
  36. self.cweights = [Tensor.scaled_uniform(filters, inp if i==0 else filters, conv, conv) for i in range(3)]
  37. self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)]
  38. #init layers
  39. self._bn = BatchNorm2d(128)
  40. self._seb = SqueezeExciteBlock2D(filters)
  41. def __call__(self, input):
  42. x = input.reshape(shape=(-1, self.inp, self.w, self.h))
  43. for cweight, cbias in zip(self.cweights, self.cbiases):
  44. x = x.pad2d(padding=[1,1,1,1]).conv2d(cweight).add(cbias).relu()
  45. x = self._bn(x)
  46. x = self._seb(x)
  47. return x
  48. class BigConvNet:
  49. def __init__(self):
  50. self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)]
  51. self.weight1 = Tensor.scaled_uniform(128,10)
  52. self.weight2 = Tensor.scaled_uniform(128,10)
  53. def parameters(self):
  54. if DEBUG: #keeping this for a moment
  55. pars = [par for par in get_parameters(self) if par.requires_grad]
  56. no_pars = 0
  57. for par in pars:
  58. print(par.shape)
  59. no_pars += np.prod(par.shape)
  60. print('no of parameters', no_pars)
  61. return pars
  62. else:
  63. return get_parameters(self)
  64. def save(self, filename):
  65. with open(filename+'.npy', 'wb') as f:
  66. for par in get_parameters(self):
  67. #if par.requires_grad:
  68. np.save(f, par.numpy())
  69. def load(self, filename):
  70. with open(filename+'.npy', 'rb') as f:
  71. for par in get_parameters(self):
  72. #if par.requires_grad:
  73. try:
  74. par.numpy()[:] = np.load(f)
  75. if GPU:
  76. par.gpu()
  77. except:
  78. print('Could not load parameter')
  79. def forward(self, x):
  80. x = self.conv[0](x)
  81. x = self.conv[1](x)
  82. x = x.avg_pool2d(kernel_size=(2,2))
  83. x = self.conv[2](x)
  84. x1 = x.avg_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
  85. x2 = x.max_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
  86. xo = x1.dot(self.weight1) + x2.dot(self.weight2)
  87. return xo
  88. if __name__ == "__main__":
  89. lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5]
  90. epochss = [2, 1] if QUICK else [13, 3, 3, 1]
  91. BS = 32
  92. lmbd = 0.00025
  93. lossfn = lambda out,y: out.sparse_categorical_crossentropy(y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
  94. X_train, Y_train, X_test, Y_test = fetch_mnist()
  95. X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
  96. X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
  97. steps = len(X_train)//BS
  98. np.random.seed(1337)
  99. if QUICK:
  100. steps = 1
  101. X_test, Y_test = X_test[:BS], Y_test[:BS]
  102. model = BigConvNet()
  103. if len(sys.argv) > 1:
  104. try:
  105. model.load(sys.argv[1])
  106. print('Loaded weights "'+sys.argv[1]+'", evaluating...')
  107. evaluate(model, X_test, Y_test, BS=BS)
  108. except:
  109. print('could not load weights "'+sys.argv[1]+'".')
  110. if GPU:
  111. params = get_parameters(model)
  112. [x.gpu_() for x in params]
  113. for lr, epochs in zip(lrs, epochss):
  114. optimizer = optim.Adam(model.parameters(), lr=lr)
  115. for epoch in range(1,epochs+1):
  116. #first epoch without augmentation
  117. X_aug = X_train if epoch == 1 else augment_img(X_train)
  118. train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
  119. accuracy = evaluate(model, X_test, Y_test, BS=BS)
  120. model.save(f'examples/checkpoint{accuracy * 1e6:.0f}')