# test_real_world.py
import unittest, time, gc
import numpy as np
from tinygrad import Tensor, Device, GlobalCounters, dtypes
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from tinygrad.engine.jit import TinyJit
from tinygrad.helpers import CI, Context
from tinygrad.shape.symbolic import Variable
from extra.lr_scheduler import OneCycleLR
from extra.models.unet import ResBlock
from test.helpers import derandomize_model, is_dtype_supported
from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS
from examples.hlb_cifar10 import SpeedyResNet, hyp
from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS
from examples.stable_diffusion import UNetModel, unet_params
  16. global_mem_used = 0
  17. def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jitted=False):
  18. tms = []
  19. for _ in range(4):
  20. early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()]
  21. GlobalCounters.reset()
  22. Device[Device.DEFAULT].synchronize()
  23. st = time.perf_counter_ns()
  24. model(*early_gen)
  25. Device[Device.DEFAULT].synchronize()
  26. tms.append(time.perf_counter_ns() - st)
  27. mem_used = GlobalCounters.mem_used - global_mem_used
  28. # TODO: jit should expose this correctly with graph
  29. kernels_used = len(model.jit_cache) if hasattr(model, "jit_cache") else None
  30. print(f"{nm}: used {mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
  31. assert mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
  32. assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"
  33. if all_jitted:
  34. assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used <= GlobalCounters.kernel_count and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted" # noqa: E501
  35. class TestRealWorld(unittest.TestCase):
  36. def setUp(self):
  37. gc.collect()
  38. global global_mem_used
  39. global_mem_used = GlobalCounters.mem_used
  40. self.old_float = dtypes.default_float
  41. np.random.seed(2002)
  42. def tearDown(self):
  43. dtypes.default_float = self.old_float
  44. @unittest.skipIf(CI and Device.DEFAULT == "CLANG", "slow, covered by METAL")
  45. @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
  46. def test_stable_diffusion(self):
  47. params = unet_params
  48. if CI:
  49. params["model_ch"] = 16
  50. params["ctx_dim"] = 16
  51. params["num_res_blocks"] = 1
  52. params["n_heads"] = 2
  53. model = UNetModel(**params)
  54. derandomize_model(model)
  55. @TinyJit
  56. def test(t, t2): return model(t, Tensor([801]), t2).realize()
  57. helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, params["ctx_dim"])), test, 18.0, 513 if CI else 839)
  58. def test_unet_resblock(self):
  59. model = [ResBlock(16, 24, 16) for _ in range(4)]
  60. derandomize_model(model)
  61. @TinyJit
  62. def test(t, t2):
  63. for l in model: t = l(t, t2)
  64. return t.realize()
  65. helper_test("test_unet_resblock", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, 0.01, 43)
  66. @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
  67. def test_llama(self):
  68. dtypes.default_float = dtypes.float16
  69. args_tiny = {"dim": 1024, "hidden_dim": 2048, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
  70. model = LLaMaTransformer(**(args_tiny if CI else LLAMA_MODEL_PARAMS["1"]["7B"]["args"]))
  71. derandomize_model(model)
  72. @TinyJit
  73. def test(t): return model(t, 0).realize()
  74. # TODO: test first token vs rest properly
  75. helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 192 if CI else 719, all_jitted=True)
  76. @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
  77. def test_gpt2(self):
  78. dtypes.default_float = dtypes.float16
  79. args_tiny = {"dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-5, "vocab_size": 1000}
  80. model = GPT2Transformer(**(args_tiny if CI else GPT2_MODEL_PARAMS["gpt2-medium"]))
  81. derandomize_model(model)
  82. @TinyJit
  83. def test(t, v):
  84. with Context(JIT=0): return model(t, v).realize()
  85. helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23 if CI else 0.9, 164 if CI else 468, all_jitted=True)
  86. @unittest.skipIf(CI and Device.DEFAULT == "CLANG", "slow")
  87. def test_train_mnist(self):
  88. from examples.beautiful_mnist import Model
  89. with Tensor.train():
  90. model = Model()
  91. optimizer = optim.Adam(get_parameters(model))
  92. BS = 32
  93. @TinyJit
  94. def train(X):
  95. out = model(X)
  96. loss = out.mean()
  97. optimizer.zero_grad()
  98. loss.backward()
  99. optimizer.step()
  100. helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 127)
  101. @unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow")
  102. def test_train_cifar(self):
  103. with Tensor.train():
  104. model = SpeedyResNet(Tensor.ones((12,3,2,2)))
  105. optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
  106. BS = 32 if CI else 512
  107. @TinyJit
  108. def train(X):
  109. out = model(X)
  110. loss = out.mean()
  111. optimizer.zero_grad()
  112. loss.backward()
  113. optimizer.step()
  114. helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 142 if CI else 154) # it's 154 on metal
  115. @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
  116. def test_train_cifar_hyp(self):
  117. dtypes.default_float = dtypes.float16
  118. with Tensor.train():
  119. model = SpeedyResNet(Tensor.ones((12,3,2,2)))
  120. optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
  121. initial_div_factor = hyp['opt']['initial_div_factor']
  122. final_lr_ratio = hyp['opt']['final_lr_ratio']
  123. pct_start = hyp['opt']['percent_start']
  124. lr_scheduler = OneCycleLR(optimizer, max_lr=hyp['opt']['bias_lr'], pct_start=pct_start, div_factor=initial_div_factor,
  125. final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=4)
  126. assert not np.isnan(lr_scheduler.min_lr), "lr too small or initial_div_facotr too big for half"
  127. if __name__ == '__main__':
  128. unittest.main()