1
0

external_test_jit_on_models.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. #!/usr/bin/env python
  2. import unittest
  3. import numpy as np
  4. from tinygrad import Tensor, dtypes
  5. from tinygrad.engine.jit import TinyJit
  6. from tinygrad.helpers import CI
  7. from test.helpers import derandomize_model
  8. from examples.llama import Transformer
  9. def helper_test_jitted_correctness(gen, train, train_jit):
  10. nojit = train(*gen()).numpy()
  11. for _ in range(5): jit = train_jit(*gen()).numpy()
  12. np.testing.assert_allclose(nojit, jit, rtol=1e-3, atol=1e-5)
  13. class TestJittedModels(unittest.TestCase):
  14. def test_jitted_tiny_llama(self):
  15. old_float = dtypes.default_float
  16. dtypes.default_float = dtypes.float16
  17. args_tiny = {"dim": 1024, "hidden_dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
  18. model = Transformer(**args_tiny)
  19. derandomize_model(model)
  20. def test(t): return model(t, 0).realize()
  21. @TinyJit
  22. def test_jit(t): return model(t, 0).realize()
  23. helper_test_jitted_correctness(lambda: (Tensor([[1,]]),), test, test_jit)
  24. dtypes.default_float = old_float
  25. @unittest.skipUnless(not CI, "huge for CI")
  26. def test_jitted_stable_diffusion(self):
  27. from examples.stable_diffusion import UNetModel, unet_params
  28. model = UNetModel(**unet_params)
  29. derandomize_model(model)
  30. def test(t, t2): return model(t, 801, t2).realize()
  31. @TinyJit
  32. def test_jit(t, t2): return model(t, 801, t2).realize()
  33. helper_test_jitted_correctness(lambda: (Tensor.randn(1, 4, 16, 16),Tensor.randn(1, 77, 768)), test, test_jit)
  34. if __name__ == "__main__":
  35. unittest.main()