# test_method_cache.py — verifies the per-device compiled-kernel method cache.
  1. import unittest
  2. from tinygrad import Tensor, Device, Variable
  3. from examples.gpt2 import Transformer
  4. from tinygrad.nn.state import get_state_dict
  5. class TestMethodCache(unittest.TestCase):
  6. def setUp(self):
  7. self.backup_compiler = Device[Device.DEFAULT].compiler
  8. def tearDown(self):
  9. Device[Device.DEFAULT].compiler = self.backup_compiler
  10. def test_simple_methodcache(self):
  11. a = Tensor([1])
  12. b = Tensor([2])
  13. c = Tensor([3])
  14. d = Tensor([4])
  15. (a+b).realize()
  16. Device[Device.DEFAULT].compiler = None
  17. (c+d).realize()
  18. def test_nested_methodcache(self):
  19. a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
  20. ((a+b)+(a+b)).realize()
  21. Device[Device.DEFAULT].compiler = None
  22. ((c+d)+(c+d)).realize()
  23. def test_nested_methodcache_swap(self):
  24. a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
  25. ((a+b)+(c+d)).realize()
  26. Device[Device.DEFAULT].compiler = None
  27. ((c+d)+(a+b)).realize()
  28. @unittest.skip("incorrect use of transformer")
  29. def test_small_transformer(self):
  30. args_tiny = {"dim": 16, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 10}
  31. model = Transformer(**args_tiny)
  32. for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype).realize())
  33. # NOTE: you have to do this twice due to the k-v cache
  34. for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
  35. for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
  36. Device[Device.DEFAULT].compiler = None
  37. for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
# Allow running this test module directly (e.g. `python test_method_cache.py`).
if __name__ == '__main__':
  unittest.main()