# metal_matvec.py — benchmark a 16384x4096 float32 matrix-vector multiply on
# Apple Metal three ways: torch (MPS), a hand-written Metal kernel, and tinygrad.
import numpy as np
import time, torch, torch.mps
from tinygrad import Tensor, TinyJit, Device
from tinygrad.helpers import flat_mv
from tinygrad.runtime.ops_metal import MetalAllocator, MetalDevice, MetalProgram, MetalCompiler

# problem size: (N,) vector times (N, M) matrix -> (M,) result
N = 16384
M = 4096
FLOPS = N*M*2  # one multiply + one add per matrix element

# random float32 inputs, generated once so all three backends see identical data
nb = np.random.default_rng().standard_normal(size=(N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
nc = np.random.default_rng().standard_normal(size=(N,M), dtype=np.float32) #.astype(np.int32).astype(np.float32)

# torch copies on the Metal (MPS) device
b = torch.from_numpy(nb).to('mps')
c = torch.from_numpy(nc).to('mps')
  13. def torch_prog(b, c):
  14. st = time.perf_counter()
  15. a = b@c
  16. torch.mps.synchronize()
  17. return time.perf_counter() - st
  18. tm = min([torch_prog(b, c) for _ in range(200)])
  19. print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch")
  20. torch_a = (b@c).cpu()
# raw Metal device + allocator via tinygrad's runtime (bypassing the tinygrad graph)
device = MetalDevice("METAL")
metalalloc = MetalAllocator(device)
# launch geometry: threadgroup is [32, 1, 16] = 512 threads
WORKSIZE_ROW = 16
WORKSIZE_COL = 1
LOCAL_SIZE = [32, WORKSIZE_COL, WORKSIZE_ROW]
# each threadgroup covers LOCAL_SIZE[0]*LOCAL_SIZE[1]*4 output columns -> M//128 = 32 groups
GLOBAL_SIZE = [M//(LOCAL_SIZE[0]*LOCAL_SIZE[1]*4), 1, 1]
# Hand-written Metal matvec kernel, specialized at compile time via the f-string:
# data0 = output (M floats), data1 = input vector (N floats), data2 = matrix (N x M).
# Each thread accumulates a float4 (4 output columns); lid.z splits the reduction over
# N into LOCAL_SIZE[2] partial sums held in threadgroup memory, combined at the end
# by the lidx3 == 0 threads. NOTE(review): assumes N and M are divisible by the tile
# sizes baked in below — true for the N=16384, M=4096 configured above.
prog = MetalProgram(device, "test", MetalCompiler(device).compile(f"""
#include <metal_stdlib>
using namespace metal;
kernel void test(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
int gidx0 = gid.x; /* {GLOBAL_SIZE[0]} */
int lidx1 = lid.x; /* {LOCAL_SIZE[0]} */
int lidx2 = lid.y; /* {LOCAL_SIZE[1]} */
int lidx3 = lid.z; /* {LOCAL_SIZE[2]} */
// 4 rows per thread
threadgroup float4 acc0[{LOCAL_SIZE[0]*LOCAL_SIZE[1]*LOCAL_SIZE[2]}];
int acc0_index = ((lidx1*{LOCAL_SIZE[1]})+lidx2)+({LOCAL_SIZE[0]*LOCAL_SIZE[1]}*lidx3);
acc0[acc0_index] = float4(0.0f,0.0f,0.0f,0.0f);
threadgroup float4 val1[{LOCAL_SIZE[0]*LOCAL_SIZE[1]*LOCAL_SIZE[2]}];
// iterate over the columns
for (int ridx2 = 0; ridx2 < {N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))}; ++ridx2) {{
// load 4*threadgroup_size columns into shared memory
int col_1 = (((lidx3*{N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))})+ridx2)*{LOCAL_SIZE[0]*LOCAL_SIZE[1]})+(lidx1*{LOCAL_SIZE[1]})+lidx2;
val1[(lidx3*{LOCAL_SIZE[1]*LOCAL_SIZE[0]})+((lidx1*{LOCAL_SIZE[1]})+lidx2)] = *((device float4*)(data1+(col_1*4)));
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int ridx3 = 0; ridx3 < {LOCAL_SIZE[0]*LOCAL_SIZE[1]}; ++ridx3) {{
int col = ((((lidx3*{N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))})+ridx2)*{LOCAL_SIZE[0]*LOCAL_SIZE[1]})+ridx3);
float4 val1_0 = val1[(lidx3*{LOCAL_SIZE[1]*LOCAL_SIZE[0]})+ridx3];
float4 val2_0 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*0})));
float4 val2_1 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*1})));
float4 val2_2 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*2})));
float4 val2_3 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*3})));
acc0[acc0_index] = ((val1_0.x*val2_0)+acc0[acc0_index]);
acc0[acc0_index] = ((val1_0.y*val2_1)+acc0[acc0_index]);
acc0[acc0_index] = ((val1_0.z*val2_2)+acc0[acc0_index]);
acc0[acc0_index] = ((val1_0.w*val2_3)+acc0[acc0_index]);
}}
threadgroup_barrier(mem_flags::mem_threadgroup);
}} /* reduce */
if (lidx3 == 0) {{
float4 out = float4(0.0f,0.0f,0.0f,0.0f);
for (int n = 0; n < {LOCAL_SIZE[2]}; n++) {{
out += acc0[((lidx1*{LOCAL_SIZE[1]})+lidx2)+({LOCAL_SIZE[0]*LOCAL_SIZE[1]}*n)];
}}
*( (device float4 *) (data0 + (gidx0*{M//GLOBAL_SIZE[0]}) + ( ( (lidx1*{LOCAL_SIZE[1]})+lidx2 ) * 4 ) ) ) = out;
}}
}}
"""))
# device buffers: sizes are element counts * 4 bytes (float32)
a = metalalloc.alloc(M*4)    # output vector; also rebinds the torch tensor name `a` usage below
b = metalalloc.alloc(N*4)    # input vector (rebinds the torch `b`)
c = metalalloc.alloc(N*M*4)  # input matrix (rebinds the torch `c`)
# upload the same host data the torch run used
metalalloc.copyin(b,nb.tobytes())
metalalloc.copyin(c,nc.tobytes())
def metalrun():
  # launch the hand-written kernel; wait=True blocks until the GPU finishes,
  # so the caller's wall-clock timing covers the full execution
  prog(a, b, c, global_size=GLOBAL_SIZE, local_size=LOCAL_SIZE, wait=True)
  return a
  77. def timeit(fxn):
  78. st = time.perf_counter()
  79. et = fxn()
  80. # NOTE: et doesn't contain the launch overhead
  81. return time.perf_counter() - st
# best-of-200 timing for the hand-written kernel
tm = min([timeit(metalrun) for _ in range(200)])
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal")
# copy the result back and verify it against the torch reference
metal_a = np.zeros(M, dtype=np.float32)
metalalloc.copyout(flat_mv(metal_a.data), a)
np.testing.assert_allclose(metal_a, torch_a, atol=5e-3)
# tinygrad tensors from the same host data (rebinds `b` and `c` once more)
b = Tensor(nb)
c = Tensor(nc)
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
@TinyJit
def tiny_jit(b, c):
  # realize() forces execution so the timing below measures actual work
  return (b@c).realize()
  93. def tiny_prog(b, c):
  94. st = time.perf_counter()
  95. a = tiny_jit(b, c)
  96. Device["METAL"].synchronize()
  97. return time.perf_counter() - st
  98. tm = min([tiny_prog(b, c) for _ in range(200)])
  99. print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad")
  100. tiny_a = tiny_jit(b, c).numpy()
  101. np.testing.assert_allclose(tiny_a, torch_a, atol=5e-3)