| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- #!/usr/bin/env python3
- import numpy as np
- import time
- import sys
# Show whole arrays (no "..." summarisation) on wide lines when debugging output.
# sys.maxsize is NumPy's documented way to disable summarisation entirely.
np.set_printoptions(linewidth=1000, threshold=sys.maxsize, suppress=False)
- from tinygrad.runtime.ops_llvm import LLVMDevice, LLVMProgram, LLVMCompiler
- from llvmlite import ir # type: ignore
- from tinygrad.helpers import flat_mv
- from tinygrad.device import MallocAllocator
- # https://github.com/corsix/amx/blob/main/Instructions.md
- # 12 lines for AMX support
- from functools import partialmethod
class AMX:
  """Emit raw Apple AMX coprocessor instructions into an llvmlite IRBuilder.

  Each instruction is a single 32-bit ``.word`` with value
  ``0x201000 + (op << 5) + operand``, emitted as side-effecting inline asm.
  Every opcode attribute below is a partial of one of the two emitters, so
  ``AMX.ldx(builder, gpr)`` calls ``op_gpr(0, builder, gpr)`` and
  ``AMX.set(builder)`` calls ``nop_op_imm5(17, 0, builder)``.
  """

  @staticmethod
  def nop_op_imm5(op, imm5, builder):
    """Emit AMX instruction ``op`` whose 5-bit operand field is the immediate ``imm5``."""
    void_fn = ir.FunctionType(ir.VoidType(), [])
    word = f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}"
    builder.asm(void_fn, word, "", (), True)

  @staticmethod
  def op_gpr(op, builder, gpr):
    """Emit AMX instruction ``op`` whose operand field names the register holding ``gpr``."""
    void_fn_i64 = ir.FunctionType(ir.VoidType(), [ir.IntType(64)])
    # "r" places gpr in a general-purpose register and $0 expands to its name
    # (presumably "xN" on AArch64). "0$0" then reads as a hex literal, e.g. x12 -> 0x12 == 18,
    # and subtracting (18 >> 4) * 6 == 6 recovers the decimal register number 12.
    word = f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0"
    builder.asm(void_fn_i64, word, "r", (gpr,), True)

  # op 17 takes an immediate: 0 and 1 are bound here as set and clr.
  set = partialmethod(nop_op_imm5, 17, 0)
  clr = partialmethod(nop_op_imm5, 17, 1)
  # Register loads/stores; the operand register holds an address plus flag bits.
  ldx = partialmethod(op_gpr, 0)
  ldy = partialmethod(op_gpr, 1)
  stx = partialmethod(op_gpr, 2)
  sty = partialmethod(op_gpr, 3)
  ldz = partialmethod(op_gpr, 4)
  stz = partialmethod(op_gpr, 5)
  ldzi = partialmethod(op_gpr, 6)
  stzi = partialmethod(op_gpr, 7)
  extrx = partialmethod(op_gpr, 8)
  extry = partialmethod(op_gpr, 9)
  # Multiply-accumulate family.
  fma64 = partialmethod(op_gpr, 10)
  fms64 = partialmethod(op_gpr, 11)
  fma32 = partialmethod(op_gpr, 12)
  fms32 = partialmethod(op_gpr, 13)
  mac16 = partialmethod(op_gpr, 14)
  fma16 = partialmethod(op_gpr, 15)
  fms16 = partialmethod(op_gpr, 16)
  # Generic vector/matrix ops and lookup-table generation.
  vecint = partialmethod(op_gpr, 18)
  vecfp = partialmethod(op_gpr, 19)
  matint = partialmethod(op_gpr, 20)
  matfp = partialmethod(op_gpr, 21)
  genlut = partialmethod(op_gpr, 22)
def int_const(x):
  """Wrap the Python int ``x`` as an LLVM ``i64`` constant for use as an IR operand."""
  i64 = ir.IntType(64)
  return ir.Constant(i64, x)
# Problem size: N x N float32 matrices (1024 or 64 make for quicker runs).
N = 4096
# Bytes of input the kernel streams per run; used below to report GB/s.
BW = N*N*4
# Author's notes: the matrix is 64M and max load bandwidth is 57 GB/s;
# the cache line looks like 256 bytes (64 floats).

# Output buffer. Only its first len(ns) entries are compared at the end.
na = np.zeros(256, dtype=np.float32)
# na = np.zeros((N, N), dtype=np.float32)  # output size for the full matmul variant
nb = np.random.randn(N, N).astype(np.float32)
nc = np.random.randn(N, N).astype(np.float32)
# Reference result: nb viewed as rows of 32 floats, summed column-wise -> 32 values.
ns = nb.reshape(-1, 32).sum(axis=0)

# Host allocations in the order a, b, c, sized from each array's byte count.
a, b, c = (MallocAllocator.alloc(arr.nbytes) for arr in (na, nb, nc))
# Upload the two inputs; a is write-only for the kernel.
for buf, arr in ((b, nb), (c, nc)):
  MallocAllocator.copyin(buf, flat_mv(arr.data))
# ---- Kernel IR: exec(z, x, y) -> i64 -------------------------------------------
# The single loop below advances 64 floats per iteration and exits only on exact
# equality with N*N (it is a do-while, so it always runs at least once). Any other
# N would never terminate and would read past the buffer, so reject it up front.
if N*N <= 0 or (N*N) % 64 != 0:
  raise ValueError(f"N*N must be a positive multiple of 64, got N={N}")

module = ir.Module(name=__file__)
# All three arguments are float*; the function returns an i64 that is always 0.
func = ir.Function(module, ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()]*3), name='exec')
# entry: turn each buffer pointer into an i64 address usable in an AMX operand.
entry = ir.IRBuilder(func.append_basic_block(name="entry"))
zm, xm, ym = [entry.ptrtoint(func.args[i], ir.IntType(64)) for i in range(3)]
loop_1 = ir.IRBuilder(func.append_basic_block(name="loop_y"))
loop_1_exit = ir.IRBuilder(func.append_basic_block(name="loop_y_exit"))
# Named exit_block rather than `exit` so the builtin is not shadowed.
exit_block = ir.IRBuilder(func.append_basic_block(name="exit"))

# Loop counter y (measured in floats): 0 on entry, +64 per iteration.
y = loop_1.phi(ir.IntType(64), name="y")
y.add_incoming(int_const(0), entry._block)
yp = loop_1_exit.add(y, int_const(32*2))
y.add_incoming(yp, loop_1_exit._block)

# Declared for the software-prefetch experiment; its call below is commented out.
prefetch_function = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.PointerType(ir.FloatType()), ir.IntType(32), ir.IntType(32), ir.IntType(32)]), name="llvm.prefetch")
xptr = y
addr = loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
#prefetch_ptr = loop_1_exit.inttoptr(loop_1_exit.add(addr, int_const(128)), ir.PointerType(ir.FloatType()))
#loop_1_exit.call(prefetch_function, [prefetch_ptr, ir.IntType(32)(0), ir.IntType(32)(2), ir.IntType(32)(1)])
# Bit 62 on ldx/ldy requests a paired load: 32 floats (128 bytes) per instruction.
AMX.ldx(loop_1_exit, loop_1_exit.add(int_const(1<<62), addr))
xptr = loop_1_exit.add(xptr, int_const(32))
# The second 32 floats of the chunk also come from the x buffer (xm); ym is unused here.
AMX.ldy(loop_1_exit, loop_1_exit.add(int_const(1<<62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))))
# fma32 operand bits, per corsix/amx (TODO confirm): 63 = vector mode, 28/29 = skip the
# Y/X input, 20+ = Z row, 10+ = X byte offset, 0+ = Y byte offset. Presumably this makes
# Z rows 0 and 1 accumulate the two 16-float halves of X, then of Y (the 32 column sums).
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28))
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28 | 1 << 20 | (16*4)<<10))
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29))
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29 | 1 << 20 | (16*4)))

# Enable AMX at the end of the entry block (before its branch is emitted below).
AMX.set(entry)
# exit: paired store of Z rows starting at row 0 to z, then disable AMX.
AMX.stz(exit_block, exit_block.add(zm, int_const(1 << 62 | (0 << 56) | 0)))
AMX.clr(exit_block)

# Control flow: entry -> loop_y -> loop_y_exit -> (loop_y | exit).
entry.branch(loop_1._block)
loop_1.branch(loop_1_exit._block)
loop_1_exit.cbranch(loop_1_exit.icmp_unsigned("==", yp, int_const(N*N)), exit_block._block, loop_1._block)
exit_block.ret(int_const(0))
# JIT-compile the textual IR and bind its "exec" symbol as a callable program.
device = LLVMDevice("llvm")
lib = LLVMCompiler(device).compile(str(module))
prog = LLVMProgram(device, "exec", lib)
- """
- loop_1 = ir.IRBuilder(func.append_basic_block(name="loop_y"))
- loop_2 = ir.IRBuilder(func.append_basic_block(name="loop_x"))
- loop_3 = ir.IRBuilder(func.append_basic_block(name="loop_k"))
- loop_3_exit = ir.IRBuilder(func.append_basic_block(name="loop_k_exit"))
- loop_2_exit = ir.IRBuilder(func.append_basic_block(name="loop_x_exit"))
- loop_1_exit = ir.IRBuilder(func.append_basic_block(name="loop_y_exit"))
- y = loop_1.phi(ir.IntType(64), name="y")
- x = loop_2.phi(ir.IntType(64), name="x")
- k = loop_3.phi(ir.IntType(64), name="k")
- exit = ir.IRBuilder(func.append_basic_block(name="exit"))
- AMX.set(loop_2)
- # stride
- xptr = loop_3_exit.add(x, loop_3_exit.mul(k, int_const(N)))
- yptr = loop_3_exit.add(y, loop_3_exit.mul(k, int_const(N)))
- # if you are okay with the wrong answer, this is faster
- #xptr = loop_3_exit.add(x, loop_3_exit.mul(k, int_const(32)))
- #yptr = loop_3_exit.add(y, loop_3_exit.mul(k, int_const(32)))
- # double loads load 32 floats
- AMX.ldx(loop_3_exit, loop_3_exit.add(int_const(1<<62), loop_3_exit.add(xm, loop_3_exit.mul(int_const(4), xptr))))
- AMX.ldy(loop_3_exit, loop_3_exit.add(int_const(1<<62), loop_3_exit.add(ym, loop_3_exit.mul(int_const(4), yptr))))
- # <Z row> <X offset> <Y offset>
- AMX.fma32(loop_3_exit, int_const(0<<20 | (0*16*4)<<10 | (0*16*4)))
- AMX.fma32(loop_3_exit, int_const(1<<20 | (1*16*4)<<10 | (0*16*4)))
- AMX.fma32(loop_3_exit, int_const(2<<20 | (0*16*4)<<10 | (1*16*4)))
- AMX.fma32(loop_3_exit, int_const(3<<20 | (1*16*4)<<10 | (1*16*4)))
- # store
- gptr = loop_2_exit.mul(loop_2_exit.add(loop_2.mul(y, int_const(N)), x), int_const(4))
- zmp = loop_2_exit.add(zm, gptr)
- for j in range(2):
- for r in range(16):
- z_row = j*2
- ptr = ((j*16)+r)*N
- AMX.stz(loop_2_exit, loop_2_exit.add(zmp, int_const(1 << 62 | ((r*4+z_row) << 56) | ptr*4)))
- AMX.clr(loop_2_exit)
- yp = loop_1_exit.add(y, int_const(32))
- xp = loop_2_exit.add(x, int_const(32))
- kp = loop_3_exit.add(k, int_const(1))
- y.add_incoming(int_const(0), entry._block)
- x.add_incoming(int_const(0), loop_1._block)
- k.add_incoming(int_const(0), loop_2._block)
- y.add_incoming(yp, loop_1_exit._block)
- x.add_incoming(xp, loop_2_exit._block)
- k.add_incoming(kp, loop_3_exit._block)
- entry.branch(loop_1._block)
- loop_1.branch(loop_2._block)
- loop_2.branch(loop_3._block)
- loop_3.branch(loop_3_exit._block)
- loop_3_exit.cbranch(loop_3_exit.icmp_unsigned("==", kp, int_const(N)), loop_2_exit._block, loop_3._block)
- loop_2_exit.cbranch(loop_2_exit.icmp_unsigned("==", xp, int_const(N)), loop_1_exit._block, loop_2._block)
- loop_1_exit.cbranch(loop_1_exit.icmp_unsigned("==", yp, int_const(N)), exit._block, loop_1._block)
- exit.ret(int_const(0))
- device = LLVMDevice("llvm")
- prog = LLVMProgram(device, "exec", LLVMCompiler(device).compile(str(module)))
- """
def timeit(fxn):
  """Call ``fxn()`` once and return the elapsed wall-clock time in seconds.

  Uses ``time.perf_counter`` (monotonic, high resolution). The return value of
  ``fxn`` is discarded; only the duration is reported.
  """
  st = time.perf_counter()
  fxn()
  return time.perf_counter() - st
# Best-of-20 wall time for one kernel launch.
# NOTE(review): N**2 is passed in addition to the three buffers the kernel's signature
# declares; presumably it is ignored — confirm against LLVMProgram's calling convention.
tm = min(timeit(lambda: prog(a, b, c, N**2)) for _ in range(20))
MallocAllocator.copyout(flat_mv(na.data), a)
print(f"{N*N:10d} {tm*1e6:9.2f} us, {BW*1e-9/tm:.2f} GB/s")
# Compare the kernel's output prefix against the numpy column-sum reference.
np.testing.assert_allclose(na[:len(ns)], ns, atol=1e-4, rtol=1e-4)
# Reference check for the disabled matmul variant:
# comp = (nb.T @ nc).T
# np.testing.assert_allclose(na, comp, atol=1e-4, rtol=1e-5)
|