metal_matmul.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import os
  2. os.environ["METAL"] = "1"
  3. import time
  4. import numpy as np
  5. from tinygrad import Device, dtypes
  6. from tinygrad.helpers import getenv, flat_mv
  7. from tinygrad.runtime.ops_metal import MetalAllocator, MetalDevice, MetalProgram, MetalCompiler
  8. N = getenv("N", 2048)
  9. LID = 2
  10. device = MetalDevice("METAL")
  11. metalalloc = MetalAllocator(device)
  12. a = metalalloc.alloc(N*N*4)
  13. b = metalalloc.alloc(N*N*4)
  14. c = metalalloc.alloc(N*N*4)
  15. na = np.zeros((N,N),dtype=np.float32)
  16. nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)N
  17. nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
  18. metalalloc.copyin(b,nb.tobytes())
  19. metalalloc.copyin(c,nc.tobytes())
  20. FLOPS = N*N*N*2
  21. BW = N*N*3*4
# Hand-written Metal matmul kernel: a = data1 @ data2, where each pointer is a
# row-major N x N float32 matrix (row stride {N} in every simdgroup_load/store).
# It is launched below as prog(a, b, c), i.e. a = nb @ nc.
# N and LID are baked into the kernel text by the f-string ({N}, {LID}, {8*N}, ...),
# so the compiled program is specific to this N. Literal C braces are doubled ({{ }}).
#
# Work decomposition (from the pointer arithmetic here and the launch geometry):
#   - Threadgroup (gid.x, gid.y) with local_size [32, LID, 1]. The 32 threads that
#     share lid.y compute one 32x32 output tile at rows gid.x*32.. and columns
#     (gid.y*LID + lid.y)*32.. .
#   - The tile is a 4x4 grid of 8x8 simdgroup matrices. acc[j][i] accumulates
#     A[i] * B[j], covering tile rows i*8.. and tile columns j*8..; this matches the
#     a + j*8 + i*8*N offsets used by the stores.
#   - Each k step (k += 8) loads four 8x8 slices of data1 (tile rows, columns
#     k..k+7) and four of data2 (rows k..k+7, tile columns), then issues 16
#     multiply-accumulates.
# There are no bounds checks: N must be a multiple of 32*LID for full, in-bounds coverage.
# NOTE(review): the threadgroup_barrier inside the k-loop orders threadgroup memory,
# but this kernel uses none. It may be deliberate (e.g. to keep the simdgroups of a
# threadgroup in step); confirm with a benchmark before removing it.
prog = MetalProgram(device, "test", MetalCompiler(device).compile(f"""
#include <metal_stdlib>
#include <metal_simdgroup_matrix> // Available from Metal version 2.3 released with OS X 11.0+
using namespace metal;
kernel void test(device float *a, device const float *data1, device const float *data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
a += gid.x * 32 * {N} + (gid.y * {LID} + lid.y) * 32;
data1 += gid.x * 32 * {N};
data2 += (gid.y * {LID} + lid.y) * 32;
simdgroup_float8x8 acc[4][4];
for (uint i = 0; i < 4; i++) {{
for (uint j = 0; j < 4; j++) {{
acc[i][j] = simdgroup_float8x8(0);
}}
}}
simdgroup_float8x8 A[4];
simdgroup_float8x8 B[4];
for (uint k = 0; k < {N}; k+=8) {{
threadgroup_barrier(mem_flags::mem_threadgroup);
simdgroup_load(A[0], data1+k+{0*N}, {N}, ulong2(0, 0));
simdgroup_load(A[1], data1+k+{8*N}, {N}, ulong2(0, 0));
simdgroup_load(A[2], data1+k+{16*N}, {N}, ulong2(0, 0));
simdgroup_load(A[3], data1+k+{24*N}, {N}, ulong2(0, 0));
simdgroup_load(B[0], data2+0+k*{N}, {N}, ulong2(0, 0));
simdgroup_load(B[1], data2+8+k*{N}, {N}, ulong2(0, 0));
simdgroup_load(B[2], data2+16+k*{N}, {N}, ulong2(0, 0));
simdgroup_load(B[3], data2+24+k*{N}, {N}, ulong2(0, 0));
simdgroup_multiply_accumulate(acc[0][0], A[0], B[0], acc[0][0]);
simdgroup_multiply_accumulate(acc[0][1], A[1], B[0], acc[0][1]);
simdgroup_multiply_accumulate(acc[0][2], A[2], B[0], acc[0][2]);
simdgroup_multiply_accumulate(acc[0][3], A[3], B[0], acc[0][3]);
simdgroup_multiply_accumulate(acc[1][0], A[0], B[1], acc[1][0]);
simdgroup_multiply_accumulate(acc[1][1], A[1], B[1], acc[1][1]);
simdgroup_multiply_accumulate(acc[1][2], A[2], B[1], acc[1][2]);
simdgroup_multiply_accumulate(acc[1][3], A[3], B[1], acc[1][3]);
simdgroup_multiply_accumulate(acc[2][0], A[0], B[2], acc[2][0]);
simdgroup_multiply_accumulate(acc[2][1], A[1], B[2], acc[2][1]);
simdgroup_multiply_accumulate(acc[2][2], A[2], B[2], acc[2][2]);
simdgroup_multiply_accumulate(acc[2][3], A[3], B[2], acc[2][3]);
simdgroup_multiply_accumulate(acc[3][0], A[0], B[3], acc[3][0]);
simdgroup_multiply_accumulate(acc[3][1], A[1], B[3], acc[3][1]);
simdgroup_multiply_accumulate(acc[3][2], A[2], B[3], acc[3][2]);
simdgroup_multiply_accumulate(acc[3][3], A[3], B[3], acc[3][3]);
}}
simdgroup_store(acc[0][0], a+{0+0*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[1][0], a+{8+0*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[2][0], a+{16+0*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[3][0], a+{24+0*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[0][1], a+{0+8*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[1][1], a+{8+8*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[2][1], a+{16+8*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[3][1], a+{24+8*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[0][2], a+{0+16*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[1][2], a+{8+16*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[2][2], a+{16+16*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[3][2], a+{24+16*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[0][3], a+{0+24*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[1][3], a+{8+24*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0));
}}"""))
  82. def timeit(fxn):
  83. st = time.perf_counter()
  84. et = fxn()
  85. # NOTE: et doesn't contain the launch overhead
  86. return time.perf_counter() - st
  87. tm = min([timeit(lambda: prog(a, b, c, global_size=[N//(8*4), N//(8*4*LID), 1], local_size=[32, LID, 1], wait=True)) for _ in range(20)])
  88. comp = nb@nc
  89. metalalloc.copyout(flat_mv(na.data), a)
  90. if N <= 32:
  91. print(na)
  92. print(comp)
  93. print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
  94. np.testing.assert_allclose(na, comp, atol=1e-3)
  95. import torch, torch.mps
  96. b = torch.from_numpy(nb).to('mps')
  97. c = torch.from_numpy(nc).to('mps')
  98. def torch_prog(b, c):
  99. st = time.perf_counter()
  100. a = b@c
  101. torch.mps.synchronize()
  102. return time.perf_counter() - st
  103. tm = min([torch_prog(b, c) for _ in range(20)])
  104. print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch")
  105. from tinygrad.tensor import Tensor
  106. from tinygrad.engine.jit import TinyJit
  107. b = Tensor(nb)
  108. c = Tensor(nc)
  109. # TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
  110. @TinyJit
  111. def tiny_jit(b, c):
  112. return (b@c).realize()
  113. def tiny_prog(b, c):
  114. st = time.perf_counter()
  115. a = tiny_jit(b, c)
  116. Device["METAL"].synchronize()
  117. return time.perf_counter() - st
  118. tm = min([tiny_prog(b, c) for _ in range(20)])
  119. print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad")