# test_train.py (2.4 KB)

  1. import unittest
  2. import time
  3. import numpy as np
  4. from tinygrad.nn.state import get_parameters
  5. from tinygrad.nn import optim
  6. from tinygrad.tensor import Device
  7. from tinygrad.helpers import getenv, CI
  8. from extra.training import train
  9. from extra.models.convnext import ConvNeXt
  10. from extra.models.efficientnet import EfficientNet
  11. from extra.models.transformer import Transformer
  12. from extra.models.vit import ViT
  13. from extra.models.resnet import ResNet18
  14. BS = getenv("BS", 2)
  15. def train_one_step(model,X,Y):
  16. params = get_parameters(model)
  17. pcount = 0
  18. for p in params:
  19. pcount += np.prod(p.shape)
  20. optimizer = optim.SGD(params, lr=0.001)
  21. print("stepping %r with %.1fM params bs %d" % (type(model), pcount/1e6, BS))
  22. st = time.time()
  23. train(model, X, Y, optimizer, steps=1, BS=BS)
  24. et = time.time()-st
  25. print("done in %.2f ms" % (et*1000.))
  26. def check_gc():
  27. if Device.DEFAULT == "GPU":
  28. from extra.introspection import print_objects
  29. assert print_objects() == 0
  30. class TestTrain(unittest.TestCase):
  31. def test_convnext(self):
  32. model = ConvNeXt(depths=[1], dims=[16])
  33. X = np.zeros((BS,3,224,224), dtype=np.float32)
  34. Y = np.zeros((BS), dtype=np.int32)
  35. train_one_step(model,X,Y)
  36. check_gc()
  37. @unittest.skipIf(CI, "slow")
  38. def test_efficientnet(self):
  39. model = EfficientNet(0)
  40. X = np.zeros((BS,3,224,224), dtype=np.float32)
  41. Y = np.zeros((BS), dtype=np.int32)
  42. train_one_step(model,X,Y)
  43. check_gc()
  44. @unittest.skipIf(CI, "slow")
  45. @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
  46. def test_vit(self):
  47. model = ViT()
  48. X = np.zeros((BS,3,224,224), dtype=np.float32)
  49. Y = np.zeros((BS,), dtype=np.int32)
  50. train_one_step(model,X,Y)
  51. check_gc()
  52. def test_transformer(self):
  53. # this should be small GPT-2, but the param count is wrong
  54. # (real ff_dim is 768*4)
  55. model = Transformer(syms=10, maxlen=6, layers=12, embed_dim=768, num_heads=12, ff_dim=768//4)
  56. X = np.zeros((BS,6), dtype=np.float32)
  57. Y = np.zeros((BS,6), dtype=np.int32)
  58. train_one_step(model,X,Y)
  59. check_gc()
  60. @unittest.skipIf(CI, "slow")
  61. def test_resnet(self):
  62. X = np.zeros((BS, 3, 224, 224), dtype=np.float32)
  63. Y = np.zeros((BS), dtype=np.int32)
  64. for resnet_v in [ResNet18]:
  65. model = resnet_v()
  66. model.load_from_pretrained()
  67. train_one_step(model, X, Y)
  68. check_gc()
  69. def test_bert(self):
  70. # TODO: write this
  71. pass
  72. if __name__ == '__main__':
  73. unittest.main()