# external_test_speed_llama.py
  1. # NOTE: this only tests the speed of the LLaMA codegen, it doesn't actually run the net
  2. import unittest, time
  3. from examples.llama import Transformer, MODEL_PARAMS
  4. from tinygrad.tensor import Tensor
  5. from tinygrad import Device
  6. from tinygrad.nn.state import get_state_dict
  7. from tinygrad.device import Allocator
  8. from tinygrad.engine.realize import method_cache
  9. from tinygrad.helpers import Profiling
  10. class FakeProgram:
  11. def __init__(self, name:str, prg:bytes): pass
  12. def __call__(self, *bufs, global_size, local_size, vals=(), wait=False): pass
  13. class FakeAllocator(Allocator):
  14. def _alloc(self, sz, options): return None
  15. def copyin(self, dest, src:memoryview): pass
  16. class TestLLaMASpeed(unittest.TestCase):
  17. def test_llama_compile(self):
  18. backup_program = Device[Device.DEFAULT].runtime
  19. backup_allocator = Device[Device.DEFAULT].allocator
  20. backup_compiler = Device[Device.DEFAULT].compiler
  21. Device[Device.DEFAULT].runtime = FakeProgram
  22. Device[Device.DEFAULT].allocator = FakeAllocator()
  23. print("testing llama python run time")
  24. model = Transformer(**MODEL_PARAMS["1"]["7B"]["args"])
  25. print("built model")
  26. # assign fake tensors to the values
  27. for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))
  28. print("assigned empty tensors, doing warmup")
  29. def run_llama(st, empty_method_cache=True):
  30. if empty_method_cache: method_cache.clear()
  31. tms = [time.perf_counter()]
  32. for i in range(5):
  33. model(Tensor([[1,2,3,4]]), i).realize()
  34. tms.append(time.perf_counter())
  35. timings = [(tms[i+1]-tms[i])*1000 for i in range(len(tms)-1)]
  36. print(f"{st:15s} mean runtime: {sum(timings)/len(timings):7.2f}ms, runs: ", ", ".join(f'{x:7.2f}' for x in timings))
  37. run_llama("codegen(0)")
  38. run_llama("codegen(1)")
  39. # test no compiler use for this
  40. Device[Device.DEFAULT].compiler = None
  41. run_llama("methodcache", False)
  42. with Profiling(sort='time', frac=0.1, fn="/tmp/llama.prof", ts=5):
  43. run_llama("profile", False)
  44. Device[Device.DEFAULT].runtime = backup_program
  45. Device[Device.DEFAULT].allocator = backup_allocator
  46. Device[Device.DEFAULT].compiler = backup_compiler
  47. if __name__ == '__main__':
  48. TestLLaMASpeed().test_llama_compile()
  49. #unittest.main()