test_mnist.py

#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad import Tensor, Device
from tinygrad.helpers import CI
from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim, BatchNorm2d
from extra.training import train, evaluate
from extra.datasets import fetch_mnist

# load the mnist dataset
X_train, Y_train, X_test, Y_test = fetch_mnist()

# create a model
class TinyBobNet:
  def __init__(self):
    self.l1 = Tensor.scaled_uniform(784, 128)
    self.l2 = Tensor.scaled_uniform(128, 10)

  def parameters(self):
    return get_parameters(self)

  def forward(self, x):
    return x.dot(self.l1).relu().dot(self.l2)

# create a model with a conv layer
class TinyConvNet:
  def __init__(self, has_batchnorm=False):
    # https://keras.io/examples/vision/mnist_convnet/
    conv = 3
    #inter_chan, out_chan = 32, 64
    inter_chan, out_chan = 8, 16  # for speed
    self.c1 = Tensor.scaled_uniform(inter_chan, 1, conv, conv)
    self.c2 = Tensor.scaled_uniform(out_chan, inter_chan, conv, conv)
    self.l1 = Tensor.scaled_uniform(out_chan*5*5, 10)
    if has_batchnorm:
      self.bn1 = BatchNorm2d(inter_chan)
      self.bn2 = BatchNorm2d(out_chan)
    else:
      self.bn1, self.bn2 = lambda x: x, lambda x: x

  def parameters(self):
    return get_parameters(self)

  def forward(self, x: Tensor):
    x = x.reshape(shape=(-1, 1, 28, 28))  # hacks
    x = self.bn1(x.conv2d(self.c1)).relu().max_pool2d()
    x = self.bn2(x.conv2d(self.c2)).relu().max_pool2d()
    x = x.reshape(shape=[x.shape[0], -1])
    return x.dot(self.l1)

@unittest.skipIf(CI and Device.DEFAULT == "CLANG", "slow")
class TestMNIST(unittest.TestCase):
  def test_sgd_onestep(self):
    np.random.seed(1337)
    model = TinyBobNet()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    train(model, X_train, Y_train, optimizer, BS=69, steps=1)
    for p in model.parameters(): p.realize()

  def test_sgd_threestep(self):
    np.random.seed(1337)
    model = TinyBobNet()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    train(model, X_train, Y_train, optimizer, BS=69, steps=3)

  def test_sgd_sixstep(self):
    np.random.seed(1337)
    model = TinyBobNet()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    train(model, X_train, Y_train, optimizer, BS=69, steps=6, noloss=True)

  def test_adam_onestep(self):
    np.random.seed(1337)
    model = TinyBobNet()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    train(model, X_train, Y_train, optimizer, BS=69, steps=1)
    for p in model.parameters(): p.realize()

  def test_adam_threestep(self):
    np.random.seed(1337)
    model = TinyBobNet()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    train(model, X_train, Y_train, optimizer, BS=69, steps=3)

  def test_conv_onestep(self):
    np.random.seed(1337)
    model = TinyConvNet()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    train(model, X_train, Y_train, optimizer, BS=69, steps=1, noloss=True)
    for p in model.parameters(): p.realize()

  def test_conv(self):
    np.random.seed(1337)
    model = TinyConvNet()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    train(model, X_train, Y_train, optimizer, steps=100)
    assert evaluate(model, X_test, Y_test) > 0.93  # torch gets 0.9415 sometimes

  def test_conv_with_bn(self):
    np.random.seed(1337)
    model = TinyConvNet(has_batchnorm=True)
    optimizer = optim.AdamW(model.parameters(), lr=0.003)
    train(model, X_train, Y_train, optimizer, steps=200)
    assert evaluate(model, X_test, Y_test) > 0.94

  def test_sgd(self):
    np.random.seed(1337)
    model = TinyBobNet()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    train(model, X_train, Y_train, optimizer, steps=600)
    assert evaluate(model, X_test, Y_test) > 0.94  # CPU gets 0.9494 sometimes

if __name__ == '__main__':
  unittest.main()
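Usage note (not part of the original file): because the module ends with unittest.main(), the tests can be run directly, assuming the tinygrad repo root is on the import path so the extra.training and extra.datasets helpers resolve:

python test_mnist.py
python test_mnist.py TestMNIST.test_sgd_onestep   # run a single test case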