test_speed_v_torch.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
# Benchmarks tinygrad ops against torch. The env knobs below must be set
# before the heavy imports so the native libraries pick them up at load time.
import os
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"  # NOTE(review): presumably forces full fp32 (no TF32) in cuBLAS — confirm
os.environ["MKL_NUM_THREADS"] = "1"       # single-threaded MKL/numexpr/OpenMP for stable, comparable timings
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
import unittest
import torch
torch.set_num_threads(1)  # keep torch itself single-threaded too
import time
import numpy as np
np.set_printoptions(linewidth=160)
from tinygrad import Tensor, Device, GlobalCounters, TinyJit
from tinygrad.nn import Conv2d
from tinygrad.helpers import colorize_float, getenv, CI
# channel counts swept by test_conv2d; override with IN_CHANS=comma,separated,ints
IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]
# torch dtype/device under test: HALF=1 -> fp16; MPS=1 or TORCHCUDA=1 select an accelerator
torch_dt = torch.float16 if getenv("HALF", 0) else torch.float32
torch_device = torch.device('mps' if getenv("MPS", 0) else ('cuda' if getenv("TORCHCUDA", 0) else 'cpu'))
# pick a device-appropriate sync() so timings include queued async GPU work
if str(torch_device) == "mps":
  import torch.mps
  def sync(): torch.mps.synchronize()
elif str(torch_device) == "cuda":
  import torch.cuda
  def sync(): torch.cuda.synchronize()
else:
  def sync(): pass  # CPU path is synchronous; nothing to wait on
# last measured tinygrad op/memory counters, written by helper_test_speed
save_ops, save_mem = 0, 0
CNT = getenv("CNT", 8)  # timing iterations; iteration 0 is treated as warmup
def helper_test_speed(f1, *args):
  """Time CNT runs of f1(*args) and return (result as numpy, best per-run ms).

  Iteration 0 is excluded from timing (warmup / JIT capture). When the call
  is a tinygrad op, the global op/memory counters are saved into the
  module-level save_ops/save_mem for the caller to report.
  """
  global save_ops, save_mem
  ets = []
  ret = None
  cache_defeat = np.zeros((2048,2048))
  for i in range(CNT):
    del ret
    # operation cache defeats: bump every input by 1 each iteration so cached
    # results can't be reused (None placeholders pass through untouched)
    args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args]
    # force syncing the +1 work above before the timed region starts
    [x.numpy() if isinstance(x, Tensor) or str(torch_device) == "cpu" else x.cpu().numpy() for x in args if x is not None]
    # clear 32MB global memory cache (CPU and global memory only)
    cache_defeat += 1
    # manual pre sync so the timed region starts from an idle device
    if isinstance(args[0], Tensor): Device[args[0].device].synchronize()
    else: sync()
    GlobalCounters.global_ops = 0
    GlobalCounters.global_mem = 0
    st = time.perf_counter()
    ret = f1(*args)
    if isinstance(ret, Tensor): Device[ret.device].synchronize()
    else: sync()
    et = (time.perf_counter() - st) * 1000
    if i >= 1: ets.append(et)  # drop iteration 0 (warmup)
    if GlobalCounters.global_ops:
      save_ops, save_mem = GlobalCounters.global_ops, GlobalCounters.global_mem
  # NOTE(review): CNT=1 leaves ets empty and np.min raises — assumes CNT >= 2
  return ret.numpy() if isinstance(ret, Tensor) else ret.cpu().numpy(), np.min(ets)
  55. def helper_test_generic_square(name, N, f1, f2, onearg=False):
  56. torch.manual_seed(0)
  57. torch_a = (torch.rand(N, N, dtype=torch_dt) - 0.5).to(torch_device)
  58. torch_b = (torch.rand(N, N, dtype=torch_dt) - 0.5).to(torch_device) if not onearg else None
  59. tiny_a = Tensor(torch_a.cpu().numpy())
  60. tiny_b = Tensor(torch_b.cpu().numpy()) if not onearg else None
  61. helper_test_generic(f"{name:30s} {N:5d}x{N:5d}", f1, (torch_a, torch_b), TinyJit(lambda a,b:f2(a,b).realize()), (tiny_a, tiny_b))
  62. def helper_test_matvec(name, N, M):
  63. torch.manual_seed(0)
  64. torch_a = (torch.rand(N, dtype=torch_dt) - 0.5).to(torch_device)
  65. torch_b = (torch.rand(N, M, dtype=torch_dt) - 0.5).to(torch_device)
  66. tiny_a = Tensor(torch_a.cpu().numpy())
  67. tiny_b = Tensor(torch_b.cpu().numpy())
  68. helper_test_generic(f"{name:30s} {N:5d}x{M:5d}", lambda a,b: a@b, (torch_a, torch_b), TinyJit(lambda a,b:(a@b).realize()), (tiny_a, tiny_b))
# NOTE(review): prefix is declared global in helper_test_generic but never read
# or written there — looks vestigial; confirm before removing
prefix = None
def helper_test_generic(name, f1, f1_args, f2, f2_args):
  """Benchmark torch f1 vs tinygrad f2, print a comparison line, assert results match."""
  global prefix
  with torch.no_grad():
    val_torch, et_torch = helper_test_speed(f1, *f1_args)
  val_tinygrad, et_tinygrad = helper_test_speed(f2, *f2_args)
  desc = "faster" if et_torch > et_tinygrad else "slower"
  # save_ops/save_mem were populated by the tinygrad run inside helper_test_speed;
  # MOPS divided by ms yields GFLOPS (and MB/ms yields GB/s) in the report below
  flops = save_ops*1e-6
  mem = save_mem*1e-6
  print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB") # noqa: E501
  np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-3, rtol=1e-3)
  80. def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x):
  81. torch.manual_seed(0)
  82. torch_dat = torch.rand(bs, in_chans, img_size_y, img_size_x, dtype=torch_dt).to(torch_device)
  83. torch_conv = torch.nn.Conv2d(in_chans, out_chans, kernel_size, bias=None, dtype=torch_dt).to(torch_device)
  84. tiny_dat = Tensor(torch_dat.cpu().numpy())
  85. tiny_conv = Conv2d(in_chans, out_chans, kernel_size, bias=None)
  86. tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())
  87. def f1(torch_dat): return torch_conv(torch_dat)
  88. def f2(tiny_dat): return tiny_conv(tiny_dat).realize()
  89. helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:{kernel_size}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))
  90. @unittest.skipIf(getenv("BIG") == 0, "no big tests")
  91. @unittest.skipIf(getenv("CUDACPU") or getenv("MOCKGPU"), "no CUDACPU or MOCKGPUs")
  92. class TestBigSpeed(unittest.TestCase):
  93. def test_add(self):
  94. def f(a, b): return a+b
  95. helper_test_generic_square('add', 8192, f, f)
  96. def test_exp(self):
  97. def f(a, b): return a.exp()
  98. helper_test_generic_square('exp', 8192, f, f, onearg=True)
  99. def test_gemm_2048(self):
  100. def f(a, b): return a @ b
  101. helper_test_generic_square('gemm', 2048, f, f)
  102. def test_gemm_4096(self):
  103. def f(a, b): return a @ b
  104. helper_test_generic_square('gemm', 4096, f, f)
  105. def test_large_conv_1x1(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=1, img_size_y=128, img_size_x=128)
  106. def test_large_conv_3x3(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=3, img_size_y=130, img_size_x=130)
  107. def test_large_conv_5x5(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=5, img_size_y=132, img_size_x=132)
  108. def test_matvec_4096_16384(self): helper_test_matvec('matvec_4096_16384', 4096, 16384)
  109. def test_matvec_16384_4096(self): helper_test_matvec('matvec_16384_4096', 16384, 4096)
@unittest.skipIf(getenv("BIG") == 1, "only big tests")
@unittest.skipIf(getenv("CUDACPU") or getenv("MOCKGPU"), "no CUDACPU or MOCKGPUs")
class TestSpeed(unittest.TestCase):
  """Standard-size benchmarks comparing tinygrad against torch op-by-op.

  Each test hands the same callable (or torch/tinygrad equivalents) to
  helper_test_generic_square / helper_test_matvec / helper_test_conv, which
  time both frameworks and assert the numerical results agree.
  """
  def test_sub(self):
    def f(a, b): return a-b
    helper_test_generic_square('sub', 4096, f, f)
  @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "breaking on webgpu CI")
  def test_pow(self):
    def f(a, b): return a.pow(b)
    helper_test_generic_square('pow', 2048, f, f)
  def test_sum(self):
    # full reduction at two sizes
    def f(a, b): return a.sum()
    helper_test_generic_square('sum', 2048, f, f, onearg=True)
    helper_test_generic_square('sum', 4096, f, f, onearg=True)
  def test_partial_sum(self):
    # reduce only along axis 1 after reshaping to (4096//R, 4096*R)
    R = 256
    def f(a, b): return a.reshape(int(4096//R), int(4096*R)).sum(axis=1)
    helper_test_generic_square('partial_sum', 4096, f, f, onearg=True)
  @unittest.skip("not really used in models")
  def test_cumsum(self):
    def f0(a, b): return a.cumsum(axis=0)
    def f1(a, b): return a.cumsum(axis=1)
    helper_test_generic_square('cumsum_0', 256, f0, f0, onearg=True)
    helper_test_generic_square('cumsum_1', 256, f1, f1, onearg=True)
  def test_cat(self):
    # concatenation along each axis; torch and tinygrad spellings differ
    helper_test_generic_square('cat_0', 256, lambda x,y: torch.cat((x,y),dim=0), lambda x,y: x.cat(y,dim=0))
    helper_test_generic_square('cat_1', 256, lambda x,y: torch.cat((x,y),dim=1), lambda x,y: x.cat(y,dim=1))
  def test_array_packing(self):
    N = 2048
    def f(a, b): return a.reshape(N, N // 32, 32).permute(1,0,2).contiguous()
    helper_test_generic_square('array_packing', N, f, f, onearg=True)
  def test_permute(self):
    for N in [1024, 4096]:
      # this is a 64MB tensor, M1 L1 cache is 128kB
      # to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size
      def f(a, b): return a.permute(1,0).contiguous()
      helper_test_generic_square('permute', N, f, f, onearg=True)
  def test_double_permute(self):
    # 4D transpose of both inner and outer axis pairs; builds its own operands
    N = 64
    torch.manual_seed(0)
    torch_a = (torch.rand(N, N, N, N, dtype=torch_dt) - 0.5).to(torch_device)
    tiny_a = Tensor(torch_a.cpu().numpy())
    def f(a): return a.permute(1,0,3,2).contiguous()
    helper_test_generic(f"double_permute {tiny_a.shape}", f, (torch_a,), TinyJit(lambda a: f(a).realize()), (tiny_a,))
  def test_neg(self):
    def f(a, b): return -a
    helper_test_generic_square('neg', 4096, f, f, onearg=True)
  def test_exp(self):
    def f(a, b): return a.exp()
    helper_test_generic_square('exp', 2048, f, f, onearg=True)
  def test_relu(self):
    def f(a, b): return a.relu()
    helper_test_generic_square('relu', 4096, f, f, onearg=True)
  def test_max(self):
    def f(a, b): return a.max()
    helper_test_generic_square('max', 4096, f, f, onearg=True)
  def test_mul_sum(self):
    def f(a, b): return (a*b).sum()
    helper_test_generic_square('mul_sum', 4096, f, f)
  def test_add(self):
    for N in [1, 1024, 4096]:
      def f(a, b): return a + b
      helper_test_generic_square('add', N, f, f)
  def test_add_constant(self):
    def f(a, b): return a+2.0
    helper_test_generic_square('add_constant', 4096, f, f, onearg=True)
  def test_add_sq(self):
    def f(a, b): return a*a + b*b
    helper_test_generic_square('add_sq', 4096, f, f)
  def test_gemm(self):
    def f(a, b): return a @ b
    helper_test_generic_square('gemm', 1024, f, f)
  def test_gemm_small(self):
    def f(a, b): return a @ b
    helper_test_generic_square('gemm', 256, f, f)
  def test_gemm_unrolled(self):
    # torch gets a plain matmul (f1); tinygrad computes the same product as a
    # broadcasted multiply + reduce (f2), exercising the expand/sum path
    N = 512
    def f1(a, b): return a@b.T
    def f2(a, b): return (a.reshape(N, 1, N).expand(N, N, N) * b.reshape(1, N, N).expand(N, N, N)).sum(axis=2)
    helper_test_generic_square('gemm_unrolled', N, f1, f2)
  def test_gemm_unrolled_permute_l(self):
    # as test_gemm_unrolled, with the left operand transposed
    N = 512
    def f1(a, b): return a.T@b.T
    def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.reshape(1, N, N).expand(N, N, N)).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_l', N, f1, f2)
  def test_gemm_unrolled_permute_r(self):
    # as test_gemm_unrolled, with the right operand transposed
    N = 512
    def f1(a, b): return a@b
    def f2(a, b): return (a.reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_r', N, f1, f2)
  def test_gemm_unrolled_permute_lr(self):
    # as test_gemm_unrolled, with both operands transposed
    N = 512
    def f1(a, b): return a.T@b
    def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_lr', N, f1, f2)
  def test_matvec_1024_1024(self): helper_test_matvec('matvec_1024_1024', 1024, 1024)
  def test_matvec_1024_4096(self): helper_test_matvec('matvec_1024_4096', 1024, 4096)
  def test_matvec_4096_1024(self): helper_test_matvec('matvec_4096_1024', 4096, 1024)
  def test_matvec_4096_4096(self): helper_test_matvec('matvec_4096_4096', 4096, 4096)
  def test_openpilot_conv2d(self):
    # NHWC input permuted to NCHW inside the timed function, mirroring the
    # openpilot model's first conv layer
    bs, in_chans, out_chans = 1,12,32
    torch.manual_seed(0)
    torch_dat = torch.rand(bs, 64, 128, 12, dtype=torch_dt).to(torch_device)
    torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None, padding=1, dtype=torch_dt).to(torch_device)
    tiny_dat = Tensor(torch_dat.cpu().numpy())
    tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1)
    # share the torch weights so both frameworks compute the same convolution
    tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())
    def f1(torch_dat): return torch_conv(torch_dat.permute(0,3,1,2))
    def f2(tiny_dat): return tiny_conv(tiny_dat.permute(0,3,1,2)).realize()
    helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:3", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))
  def test_conv2d(self):
    # sweep input channel counts (IN_CHANS env knob) at fixed bs/out_chans
    for bs in [32]:
      for in_chans in IN_CHANS:
        for out_chans in [32]:
          helper_test_conv(bs, in_chans, out_chans, 3, 34, 34)
# standard unittest entry point
if __name__ == '__main__':
  unittest.main()