# extra/gemm/hip_matmul.py — tinygrad AMD RDNA3 WMMA matmul benchmark
  1. import time
  2. import numpy as np
  3. from tinygrad.helpers import getenv, prod, flat_mv
  4. from tinygrad.runtime.ops_amd import AMDAllocator, AMDDevice, AMDProgram
  5. # AMD_LOG_LEVEL=3 ./MIOpenDriver gemm --iter 1000 --time 1 --a_w 2048 --a_h 2048 --b_w 2048
  6. # 5.5: Cijk_Ailk_Bljk_HHS_BH_MT128x128x16_MI16x16x16x1_SN_1LDSB0_APM1_ABV0_ACED0_AF0EM1_AF1EM1_AMAS3_ASE_ASGT_ASAE01_ASCE01_ASEM1_AAC0_BL1_BS1_DTL0_DTVA0_DVO0_ETSP_EPS1_FL0_GRVW8_GSU1_GSUASB_GLS0_ISA1100_IU1_K1_KLA_LBSPP128_LPA0_LPB8_LDL1_LRVW16_LWPMn1_LDW0_FMA_MIAV1_MDA2_NTA0_NTB0_NTC0_NTD0_NEPBS0_NLCA1_NLCB1_ONLL1_OPLV0_PK0_PAP0_PGR1_PLR1_RK0_SIA1_SS1_SU32_SUM0_SUS128_SCIUI1_SPO0_SRVW0_SSO0_SVW4_SNLL0_TT4_64_TLDS1_USFGROn1_VAW2_VSn1_VW4_WSGRA1_WSGRB1_WS32_WG32_4_1_WGM4
  7. # 5.6: Cijk_Ailk_Bljk_HHS_BH_MT128x128x16_MI16x16x16x1_SN_1LDSB0_APM1_ABV0_ACED0_AF0EM1_AF1EM1_AMAS3_ASE_ASGT_ASLT_ASAE01_ASCE01_ASEM1_AAC0_BL1_BS1_DTL0_DTVA0_DVO0_ETSP_EPS1_FL0_GRPM1_GRVW8_GSU1_GSUASB_GLS0_ISA1100_IU1_K1_KLA_LBSPP128_LPA0_LPB8_LDL1_LRVW16_LWPMn1_LDW0_FMA_MIAV1_MDA2_MO40_NTA0_NTB0_NTC0_NTD0_NEPBS0_NLCA1_NLCB1_ONLL1_OPLV0_PK0_PAP0_PGR1_PLR1_RK0_SIA1_SS1_SU32_SUM0_SUS128_SCIUI1_SPO0_SRVW0_SSO0_SVW4_SNLL0_TT4_64_TLDS1_USFGROn1_VAW2_VSn1_VW4_WSGRA1_WSGRB1_WS32_WG32_4_1_WGM4
  8. # gets ~100
  9. # hipExtModuleLaunchKernel ( 0x0x16ccde0, 2048, 16, 1, 128, 1, 1,
  10. # 161.60 us = 106.31 TFLOPS
  11. # with --batch_count 8 / 1.258128 ms / (8*2048*2048*2048*2)/(1.258128)*1e-9 / 109.24 TFLOPS
  12. # we only get ~53
  13. # KY=2 KX=2 N=2048 python3 extra/gemm/hip_matmul.py
  14. # 4194304 324.76 us, would be 52899.88 GFLOPS matmul, 154.98 GB/s
# ---- benchmark configuration (all overridable via environment variables) ----
DEBUG = getenv("DEBUG", 0)  # >0: per-run times, >1: kernel source, >2: mismatch dumps
RAND = getenv("RAND", 0)  # nonzero: regenerate random inputs after every timed run
CNT = getenv("CNT", 128)  # number of timed kernel launches; the minimum time is reported
N = getenv("N", 4096)  # square matrix dimension (N x N)
KX = getenv("KX", 4)  # 16x16 WMMA tiles per workgroup along x
KY = getenv("KY", 4)  # 16x16 WMMA tiles per workgroup along y
assert N%(16*KX) == 0, f"N must be multiple of {16*KX}"
assert N%(16*KY) == 0, f"N must be multiple of {16*KY}"
FLOPS = N*N*N*2  # one multiply + one add per inner-product element
# NOTE(review): counts 3 float32-sized matrices (N*N*3*4 bytes), but the two
# inputs are fp16 (2 bytes each), so the printed GB/s overstates actual
# traffic — confirm whether this is intentional (kept from an all-f32 version).
BW = N*N*3*4
local_size = [32, 1, 1]  # one wave32 wavefront per workgroup (kernel uses *_w32 wmma)
global_size = [N//(KX*16), N//(KY*16), 1]  # one workgroup per (KX*16 x KY*16) output tile
num_threads = prod(local_size)
# TODO: can AMDAllocator default to device=0?
device = AMDDevice()
hipallocator = AMDAllocator(device)
a = hipallocator.alloc(N*N*4)  # float32 device buffer: kernel output (bound to kernel arg "c")
b = hipallocator.alloc(N*N*2)  # fp16 device buffer: first input (bound to kernel arg "a")
c = hipallocator.alloc(N*N*2)  # fp16 device buffer: second input (bound to kernel arg "b")
na = np.empty(N*N, np.float32)  # host-side destination for the result copyout
# random fp16 inputs: generated in float32, then narrowed to half
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
hipallocator.copyin(b, memoryview(bytearray(nb)))
hipallocator.copyin(c, memoryview(bytearray(nc)))
# HIP C++ source for the WMMA GEMM kernel, built as an f-string so the tile
# counts (KX/KY) and matrix stride (N) are baked in as compile-time constants.
# Each workgroup is one wave32 wavefront computing a (KX*16) x (KY*16) output
# tile of f16 x f16 -> f32 via __builtin_amdgcn_wmma_* intrinsics (RDNA3).
# With F32 defined each lane holds a float8 accumulator and stores 8 values
# spaced 2 rows apart (ele*2*N); without it, a half16 accumulator is used and
# every other element (ele*2) is stored, matching the wmma f16 result layout.
# NOTE(review): the fence+barrier pair inside the k-loop orders no LDS
# traffic (the kernel uses none) — presumably it only keeps wavefronts in
# lockstep for memory-access regularity; confirm it is needed.
# NOTE(review): gx/gy add __ockl_get_local_id(2)/(3), which are always 0 for
# local_size [32,1,1] — effectively gx = group_id(0), gy = group_id(1).
prog_str = f"""
#define F32
typedef long unsigned int size_t;
#define half _Float16
typedef float float8 __attribute__((ext_vector_type(8)));
typedef _Float16 half4 __attribute__((ext_vector_type(4)));
typedef _Float16 half8 __attribute__((ext_vector_type(8)));
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
extern "C" __attribute__((global))void __attribute__((amdgpu_flat_work_group_size(1, {num_threads}))) test(float* c, half* a, half* b) {{
const int gx = __ockl_get_group_id(0) + __ockl_get_local_id(2);
const int gy = __ockl_get_group_id(1) + __ockl_get_local_id(3);
const int lIdx = __ockl_get_local_id(0);
const int lane = lIdx%16;
c += gx*{KX*16}*{N} + gy*{KY*16} + (lIdx/16)*{N} + lane;
a += gx*{KX*16}*{N};
b += gy*{KY*16};
half16 a_frag[{KX}];
half16 b_frag[{KY}];
#ifdef F32
float8 c_frag[{KY}][{KX}] = {{}};
#else
half16 c_frag[{KY}][{KX}] = {{}};
#endif
for (int k = 0; k < {N}; k += 16) {{
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");
for (int ele = 0; ele < 16; ++ele) {{
for (int x = 0; x < {KX}; x++) {{
a_frag[x][ele] = a[(k+ele) + x*{16*N} + {N}*lane];
}}
}}
for (int ele = 0; ele < 16; ++ele) {{
for (int y = 0; y < {KY}; y++) {{
b_frag[y][ele] = b[(k+ele)*{N} + y*16 + lane];
}}
}}
for (int y = 0; y < {KY}; y++) {{
for (int x = 0; x < {KX}; x++) {{
#ifdef F32
c_frag[y][x] = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag[x], b_frag[y], c_frag[y][x]);
#else
c_frag[y][x] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a_frag[x], b_frag[y], c_frag[y][x], false);
#endif
}}
}}
}}
for (int ele = 0; ele < 8; ++ele) {{
for (int y = 0; y < {KY}; y++) {{
for (int x = 0; x < {KX}; x++) {{
#ifdef F32
c[ele*{2*N} + y*16 + x*{16*N}] = c_frag[y][x][ele];
#else
c[ele*{2*N} + y*16 + x*{16*N}] = c_frag[y][x][ele*2];
#endif
}}
}}
}}
}}"""
# dump the generated kernel source at DEBUG>1, then compile and load it
if DEBUG > 1: print(prog_str)
lib = device.compiler.compile(prog_str)
prog = AMDProgram(device, "test", lib)  # "test" is the kernel's entry-point name
  104. def timeit(fxn):
  105. st = time.perf_counter()
  106. et = fxn()
  107. ret = time.perf_counter() - st # NOTE: et doesn't contain the launch overhead
  108. if DEBUG > 0: print(f"{ret*1e6:.2f} us")
  109. # rerun rand
  110. if RAND:
  111. nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
  112. nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
  113. hipallocator.copyin(b, memoryview(bytearray(nb)))
  114. hipallocator.copyin(c, memoryview(bytearray(nc)))
  115. return et
# Report launch geometry, take the fastest of CNT timed launches, then copy
# the result back and check it against a float32 CPU reference matmul.
print("global/local size", global_size, local_size, f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}")
tm = min([timeit(lambda: prog(a, b, c, global_size=global_size, local_size=local_size, wait=True)) for _ in range(CNT)])
# buffer `a` is the kernel's float32 output (its first parameter, named "c")
hipallocator.copyout(flat_mv(na.data),a)
na = na.reshape(N,N)
# NOTE(review): this relies on module-level nb/nc matching the data last
# copied into the GPU buffers — verify that holds when RAND=1.
comp = nb.astype(np.float32) @ nc.astype(np.float32)
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
# DEBUG>2: dump indices of NaNs, large mismatches, and near-zero outputs
if DEBUG > 2: print(f"which nan={np.where(np.isnan(na))} len={len(np.where(np.isnan(na))[0])}")
if DEBUG > 2: print(f"which diff={np.where(abs(na-comp) > 2e-2)} len={len(np.where(abs(na-comp) > 2e-2)[0])}")
if DEBUG > 2: print(f"which zero={np.where(abs(na) < 2e-2)} len={len(np.where(abs(na) < 2e-2)[0])}")
np.testing.assert_allclose(na, comp, atol=1e-2, rtol=1e-2)