# intel_xmx.py (1.8 KB, 43 lines)

  1. #!/usr/bin/env python3
  2. import numpy as np
  3. from tinygrad.runtime.ops_gpu import CLProgram, CLCompiler
  4. from tinygrad import Device, dtypes
  5. from tinygrad.device import Buffer
  6. from hexdump import hexdump
  7. # https://github.com/intel/intel-graphics-compiler/blob/master/documentation/visa/instructions/DPAS.md
  8. # https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html
  9. # https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
  10. # https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_split_matrix_multiply_accumulate.html
  11. # https://hc34.hotchips.org/assets/program/conference/day1/GPU%20HPC/Intel_s%20Ponte%20Vecchio%20GPU%20-%20Architecture%20Systems%20and%20Software%20FINAL.pdf
  12. device = Device["GPU"]
  13. # NOTE: only the subgroup type 8 ones work
  14. prog = CLProgram(device, "test", CLCompiler(device, "test").compile(f"""
  15. __attribute__((intel_reqd_sub_group_size(8)))
  16. __kernel void test(__global float* data0, const __global int* data1, const __global int8* data2) {{
  17. int lidx0 = get_local_id(0);
  18. int a = data1[lidx0];
  19. int8 b = data2[lidx0];
  20. float out = intel_sub_group_f16_f16_matrix_mad_k16(a, b, 0.0f);
  21. data0[lidx0] = out;
  22. }}
  23. """))
  24. #with open("/tmp/test.elf", "wb") as f: f.write(prog.lib)
  25. a = Buffer("GPU", 8, dtypes.float32)
  26. b = Buffer("GPU", 0x10, dtypes.float16)
  27. c = Buffer("GPU", 8*0x10, dtypes.float16)
  28. row = np.array([1,2,3,4,5,6,7,8,1,2,3,4,5,6,7,8], np.float16)
  29. mat = np.random.random((8, 0x10)).astype(np.float16)
  30. b.copyin(row.data)
  31. c.copyin(mat.data)
  32. ret = prog(a._buf, b._buf, c._buf, global_size=[1,1,1], local_size=[8,1,1], wait=True)
  33. print(ret)
  34. out = np.frombuffer(a.as_buffer(), np.float32)
  35. real = row.astype(np.float32)@mat.T.astype(np.float32)
  36. print("out:", out)
  37. print("real", real)