amx.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. #!/usr/bin/env python3
  2. import numpy as np
  3. import time
  4. import sys
  5. np.set_printoptions(linewidth=160)
  6. np.set_printoptions(linewidth=1000, threshold=10000000000, suppress=False)
  7. from tinygrad.runtime.ops_llvm import LLVMDevice, LLVMProgram, LLVMCompiler
  8. from llvmlite import ir # type: ignore
  9. from tinygrad.helpers import flat_mv
  10. from tinygrad.device import MallocAllocator
  11. # https://github.com/corsix/amx/blob/main/Instructions.md
  12. # 12 lines for AMX support
  13. from functools import partialmethod
  14. class AMX:
  15. @staticmethod
  16. def nop_op_imm5(op, imm5, builder): builder.asm(ir.FunctionType(ir.VoidType(), []), f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}", "", tuple(), True)
  17. @staticmethod
  18. def op_gpr(op, builder, gpr): builder.asm(ir.FunctionType(ir.VoidType(), [ir.IntType(64)]), f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0", "r", (gpr,), True)
  19. set, clr = partialmethod(nop_op_imm5, 17, 0), partialmethod(nop_op_imm5, 17, 1)
  20. ldx, ldy, stx, sty = partialmethod(op_gpr, 0), partialmethod(op_gpr, 1), partialmethod(op_gpr, 2), partialmethod(op_gpr, 3)
  21. ldz, stz, ldzi, stzi = partialmethod(op_gpr, 4), partialmethod(op_gpr, 5), partialmethod(op_gpr, 6), partialmethod(op_gpr, 7)
  22. extrx, extry = partialmethod(op_gpr, 8), partialmethod(op_gpr, 9)
  23. fma64, fms64, fma32, fms32 = partialmethod(op_gpr, 10), partialmethod(op_gpr, 11), partialmethod(op_gpr, 12), partialmethod(op_gpr, 13)
  24. mac16, fma16, fms16 = partialmethod(op_gpr, 14), partialmethod(op_gpr, 15), partialmethod(op_gpr, 16)
  25. vecint, vecfp, matint, matfp, genlut = partialmethod(op_gpr, 18), partialmethod(op_gpr, 19), partialmethod(op_gpr, 20), partialmethod(op_gpr, 21), partialmethod(op_gpr, 22)
  26. def int_const(x): return ir.Constant(ir.IntType(64), x)
  27. N = 4096
  28. # N = 1024
  29. # N = 64
  30. BW = N*N*4
  31. # matrix is 64M, max load bandwidth is 57 GB/s
  32. # cache line looks like 256 bytes (64 floats)
  33. na = np.zeros((256), dtype=np.float32)
  34. # na = np.zeros((N, N), dtype=np.float32)
  35. nb = np.random.randn(N, N).astype(np.float32)
  36. nc = np.random.randn(N, N).astype(np.float32)
  37. ns = nb.reshape(-1, 32).sum(axis=0)
  38. a = MallocAllocator.alloc(na.size * np.dtype(np.float32).itemsize)
  39. b = MallocAllocator.alloc(nb.size * np.dtype(np.float32).itemsize)
  40. c = MallocAllocator.alloc(nc.size * np.dtype(np.float32).itemsize)
  41. MallocAllocator.copyin(b, flat_mv(nb.data))
  42. MallocAllocator.copyin(c, flat_mv(nc.data))
  43. module = ir.Module(name=__file__)
  44. func = ir.Function(module, ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()]*3), name='exec')
  45. # load all
  46. entry = ir.IRBuilder(func.append_basic_block(name="entry"))
  47. zm, xm, ym = [entry.ptrtoint(func.args[i], ir.IntType(64)) for i in range(3)]
  48. loop_1 = ir.IRBuilder(func.append_basic_block(name="loop_y"))
  49. loop_1_exit = ir.IRBuilder(func.append_basic_block(name="loop_y_exit"))
  50. exit = ir.IRBuilder(func.append_basic_block(name="exit"))
  51. y = loop_1.phi(ir.IntType(64), name="y")
  52. y.add_incoming(int_const(0), entry._block)
  53. yp = loop_1_exit.add(y, int_const(32*2))
  54. y.add_incoming(yp, loop_1_exit._block)
  55. prefetch_function = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.PointerType(ir.FloatType()), ir.IntType(32), ir.IntType(32), ir.IntType(32)]), name="llvm.prefetch")
  56. xptr = y
  57. addr = loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
  58. #prefetch_ptr = loop_1_exit.inttoptr(loop_1_exit.add(addr, int_const(128)), ir.PointerType(ir.FloatType()))
  59. #loop_1_exit.call(prefetch_function, [prefetch_ptr, ir.IntType(32)(0), ir.IntType(32)(2), ir.IntType(32)(1)])
  60. AMX.ldx(loop_1_exit, loop_1_exit.add(int_const(1<<62), addr))
  61. xptr = loop_1_exit.add(xptr, int_const(32))
  62. AMX.ldy(loop_1_exit, loop_1_exit.add(int_const(1<<62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))))
  63. AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28))
  64. AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28 | 1 << 20 | (16*4)<<10))
  65. AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29))
  66. AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29 | 1 << 20 | (16*4)))
  67. AMX.set(entry)
  68. AMX.stz(exit, exit.add(zm, int_const(1 << 62 | (0 << 56) | 0)))
  69. AMX.clr(exit)
  70. entry.branch(loop_1._block)
  71. loop_1.branch(loop_1_exit._block)
  72. loop_1_exit.cbranch(loop_1_exit.icmp_unsigned("==", yp, int_const(N*N)), exit._block, loop_1._block)
  73. exit.ret(int_const(0))
  74. device = LLVMDevice("llvm")
  75. prog = LLVMProgram(device, "exec", LLVMCompiler(device).compile(str(module)))
  76. """
  77. loop_1 = ir.IRBuilder(func.append_basic_block(name="loop_y"))
  78. loop_2 = ir.IRBuilder(func.append_basic_block(name="loop_x"))
  79. loop_3 = ir.IRBuilder(func.append_basic_block(name="loop_k"))
  80. loop_3_exit = ir.IRBuilder(func.append_basic_block(name="loop_k_exit"))
  81. loop_2_exit = ir.IRBuilder(func.append_basic_block(name="loop_x_exit"))
  82. loop_1_exit = ir.IRBuilder(func.append_basic_block(name="loop_y_exit"))
  83. y = loop_1.phi(ir.IntType(64), name="y")
  84. x = loop_2.phi(ir.IntType(64), name="x")
  85. k = loop_3.phi(ir.IntType(64), name="k")
  86. exit = ir.IRBuilder(func.append_basic_block(name="exit"))
  87. AMX.set(loop_2)
  88. # stride
  89. xptr = loop_3_exit.add(x, loop_3_exit.mul(k, int_const(N)))
  90. yptr = loop_3_exit.add(y, loop_3_exit.mul(k, int_const(N)))
  91. # if you are okay with the wrong answer, this is faster
  92. #xptr = loop_3_exit.add(x, loop_3_exit.mul(k, int_const(32)))
  93. #yptr = loop_3_exit.add(y, loop_3_exit.mul(k, int_const(32)))
  94. # double loads load 32 floats
  95. AMX.ldx(loop_3_exit, loop_3_exit.add(int_const(1<<62), loop_3_exit.add(xm, loop_3_exit.mul(int_const(4), xptr))))
  96. AMX.ldy(loop_3_exit, loop_3_exit.add(int_const(1<<62), loop_3_exit.add(ym, loop_3_exit.mul(int_const(4), yptr))))
  97. # <Z row> <X offset> <Y offset>
  98. AMX.fma32(loop_3_exit, int_const(0<<20 | (0*16*4)<<10 | (0*16*4)))
  99. AMX.fma32(loop_3_exit, int_const(1<<20 | (1*16*4)<<10 | (0*16*4)))
  100. AMX.fma32(loop_3_exit, int_const(2<<20 | (0*16*4)<<10 | (1*16*4)))
  101. AMX.fma32(loop_3_exit, int_const(3<<20 | (1*16*4)<<10 | (1*16*4)))
  102. # store
  103. gptr = loop_2_exit.mul(loop_2_exit.add(loop_2.mul(y, int_const(N)), x), int_const(4))
  104. zmp = loop_2_exit.add(zm, gptr)
  105. for j in range(2):
  106. for r in range(16):
  107. z_row = j*2
  108. ptr = ((j*16)+r)*N
  109. AMX.stz(loop_2_exit, loop_2_exit.add(zmp, int_const(1 << 62 | ((r*4+z_row) << 56) | ptr*4)))
  110. AMX.clr(loop_2_exit)
  111. yp = loop_1_exit.add(y, int_const(32))
  112. xp = loop_2_exit.add(x, int_const(32))
  113. kp = loop_3_exit.add(k, int_const(1))
  114. y.add_incoming(int_const(0), entry._block)
  115. x.add_incoming(int_const(0), loop_1._block)
  116. k.add_incoming(int_const(0), loop_2._block)
  117. y.add_incoming(yp, loop_1_exit._block)
  118. x.add_incoming(xp, loop_2_exit._block)
  119. k.add_incoming(kp, loop_3_exit._block)
  120. entry.branch(loop_1._block)
  121. loop_1.branch(loop_2._block)
  122. loop_2.branch(loop_3._block)
  123. loop_3.branch(loop_3_exit._block)
  124. loop_3_exit.cbranch(loop_3_exit.icmp_unsigned("==", kp, int_const(N)), loop_2_exit._block, loop_3._block)
  125. loop_2_exit.cbranch(loop_2_exit.icmp_unsigned("==", xp, int_const(N)), loop_1_exit._block, loop_2._block)
  126. loop_1_exit.cbranch(loop_1_exit.icmp_unsigned("==", yp, int_const(N)), exit._block, loop_1._block)
  127. exit.ret(int_const(0))
  128. device = LLVMDevice("llvm")
  129. prog = LLVMProgram(device, "exec", LLVMCompiler(device).compile(str(module)))
  130. """
  131. def timeit(fxn):
  132. st = time.perf_counter()
  133. et = fxn()
  134. return time.perf_counter() - st
  135. tm = min([timeit(lambda: prog(a, b, c, N**2)) for _ in range(20)])
  136. MallocAllocator.copyout(flat_mv(na.data), a)
  137. print(f"{N*N:10d} {tm*1e6:9.2f} us, {BW*1e-9/tm:.2f} GB/s")
  138. np.testing.assert_allclose(na[:ns.shape[0]], ns, atol=1e-4, rtol=1e-4)
  139. # comp = (nb.T @ nc).T
  140. # np.testing.assert_allclose(na, comp, atol=1e-4, rtol=1e-5)