# cuda_matmul.py

import os
import numpy as np
os.environ["CUDA"] = "1"
from tinygrad.runtime.ops_cuda import CUDAAllocator, CUDADevice, CUDAProgram, CUDACompiler
from tinygrad.helpers import flat_mv
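
# Hand-written WMMA (tensor core) GEMM benchmark driven through tinygrad's CUDA runtime.
# FLOAT16 selects half inputs (16x16x16 tiles); otherwise fp32 inputs are rounded to tf32
# (16x16x8 tiles). The output matrix c is always float32.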
FLOAT16 = True
ACC_FLOAT16 = False
N = 4096

na = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32)
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32)
nc = np.empty(N*N, np.float32)
if FLOAT16:
  na = na.astype(np.float16)
  nb = nb.astype(np.float16)

device = CUDADevice("cuda:0")
cudaalloc = CUDAAllocator(device)
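
# half inputs are 2 bytes per element, float inputs 4; c is always float32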
a = cudaalloc.alloc(N*N*2 if FLOAT16 else N*N*4)
b = cudaalloc.alloc(N*N*2 if FLOAT16 else N*N*4)
c = cudaalloc.alloc(N*N*4)

cudaalloc.copyin(a, bytearray(na))
cudaalloc.copyin(b, bytearray(nb))
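
# 2 ops (multiply + add) per MAC; BW assumes three N*N float32 matrices,
# which overstates a/b traffic when FLOAT16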
FLOPS = N*N*N*2
BW = N*N*3*4

print(device.arch)
compiler = CUDACompiler(device.arch)
prog = CUDAProgram(device, "wmma_example", compiler.compile(f"""
#include <mma.h>
using namespace nvcuda;
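
// WMMA tile shape: 16x16x16 for half inputs, 16x16x8 for tf32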
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = {'16' if FLOAT16 else '8'};

extern "C" __global__ void wmma_example({'half' if FLOAT16 else 'float'} *a, {'half' if FLOAT16 else 'float'} *b, float *c)
{{
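  // one warp per block (see launch config below); each warp owns a 4x4 grid
  // of 16x16 WMMA tiles, i.e. a 64x64 block of c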
  int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
  int warpN = (blockIdx.y * blockDim.y + threadIdx.y);
  warpM *= 4;
  warpN *= 4;
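
  // four row tiles of a, four column tiles of b, and a 4x4 block of accumulators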
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, {'half' if FLOAT16 else 'wmma::precision::tf32'}, wmma::col_major> a_frag[4];
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, {'half' if FLOAT16 else 'wmma::precision::tf32'}, wmma::col_major> b_frag[4];
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, {'half' if ACC_FLOAT16 else 'float'}> acc_frag[4][4];

  for (int j = 0; j < 4; j++) {{
    for (int i = 0; i < 4; i++) {{
      wmma::fill_fragment(acc_frag[i][j], 0.0f);
    }}
  }}

  for (int k = 0; k < {N}; k += WMMA_K) {{
    int aRow = warpM * WMMA_M;
    int aCol = k;
    int bRow = k;
    int bCol = warpN * WMMA_N;
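
    // both operands are col_major with leading dimension N: element (r,c) sits at r + c*N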
    wmma::load_matrix_sync(a_frag[0], a + aRow + 0 * WMMA_M + aCol * {N}, {N});
    wmma::load_matrix_sync(a_frag[1], a + aRow + 1 * WMMA_M + aCol * {N}, {N});
    wmma::load_matrix_sync(a_frag[2], a + aRow + 2 * WMMA_M + aCol * {N}, {N});
    wmma::load_matrix_sync(a_frag[3], a + aRow + 3 * WMMA_M + aCol * {N}, {N});
    wmma::load_matrix_sync(b_frag[0], b + bRow + (0 * WMMA_N + bCol) * {N}, {N});
    wmma::load_matrix_sync(b_frag[1], b + bRow + (1 * WMMA_N + bCol) * {N}, {N});
    wmma::load_matrix_sync(b_frag[2], b + bRow + (2 * WMMA_N + bCol) * {N}, {N});
    wmma::load_matrix_sync(b_frag[3], b + bRow + (3 * WMMA_N + bCol) * {N}, {N});
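
    // tf32 path only: round the loaded fp32 values to tf32 before mma
    // (the loop bound is 0 when FLOAT16, so this compiles away)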
    #pragma unroll
    for (int i = 0; i < {'0' if FLOAT16 else '4'}; i++) {{
      #pragma unroll
      for (int t = 0; t < a_frag[i].num_elements; t++) {{ a_frag[i].x[t] = wmma::__float_to_tf32(a_frag[i].x[t]); }}
      #pragma unroll
      for (int t = 0; t < b_frag[i].num_elements; t++) {{ b_frag[i].x[t] = wmma::__float_to_tf32(b_frag[i].x[t]); }}
    }}
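
    // 16 tensor core MACs per k step: every a_frag[i] against every b_frag[j]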
    #pragma unroll
    for (int j = 0; j < 4; j++) {{
      #pragma unroll
      for (int i = 0; i < 4; i++) {{
        wmma::mma_sync(acc_frag[i][j], a_frag[i], b_frag[j], acc_frag[i][j]);
      }}
    }}
  }}
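
  // stage each accumulator through a float fragment (converts half when ACC_FLOAT16)
  // and store col_major into c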
  for (int j = 0; j < 4; j++) {{
    for (int i = 0; i < 4; i++) {{
      wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_store;
      for (int t = 0; t < acc_frag[i][j].num_elements; t++) acc_store.x[t] = acc_frag[i][j].x[t];
      int cRow = (warpM + i) * WMMA_M;
      int cCol = (warpN + j) * WMMA_N;
      wmma::store_matrix_sync(c + cRow + cCol * {N}, acc_store, {N}, wmma::mem_col_major);
    }}
  }}
}}
"""))
global_size, local_size = [(N//16)//4, (N//16)//4, 1], [32, 1, 1]
tm = min([prog(a, b, c, global_size=global_size, local_size=local_size, wait=True) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
cudaalloc.copyout(flat_mv(nc.data), c)
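# the kernel reads the row-major numpy buffers as col_major (a transpose), so the
# reference is na.T @ nb.T, compared against c read back col_major (reshape + .T)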
np.testing.assert_allclose(na.T.astype(np.float32) @ nb.T.astype(np.float32), nc.reshape(N,N).T, atol=1e-2)