# assembly_rdna.py
  1. import yaml
  2. from typing import Tuple, Set, Dict
  3. from tinygrad import dtypes
  4. from tinygrad.codegen.assembly import AssemblyCodegen, Register
  5. from tinygrad.codegen.kernel import UOps
  6. from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
  7. from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH
  8. # ugh, is this really needed?
  9. from extra.helpers import enable_early_exec
  10. early_exec = enable_early_exec()
  11. boilerplate_start = """
  12. .global _start
  13. _start:
  14. .rodata
  15. .align 0x10
  16. .global code.kd
  17. .type code.kd,STT_OBJECT
  18. .amdhsa_kernel code"""
  19. code_start = """.end_amdhsa_kernel
  20. .text
  21. code:
  22. """
  23. # https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
  24. # https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
  25. # RDNA3 is actually a SIMD machine!
  26. class RDNACodegen(AssemblyCodegen):
  27. supports_float4: bool = True
  28. supports_float4_alu: bool = True
  29. supports_load3: bool = True
  30. sin_is_sin2pi: bool = True
  31. no_div: bool = True
  32. def specialize(self, asm) -> Tuple[str, str]:
  33. args = []
  34. for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'})
  35. ins = []
  36. v_cnt = 3 # v[0:2] is local_xyz
  37. s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
  38. dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
  39. alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma",
  40. BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
  41. UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
  42. BinaryOps.CMPLT: "cmp_lt"}
  43. pend_regs:Set[Register] = set()
  44. rtor:Dict[Register, str] = {}
  45. def reg_in(x):
  46. nonlocal pend_regs
  47. #print("reg_in", x, rtor[x], pend_regs)
  48. if x in pend_regs:
  49. #print("clear")
  50. ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
  51. pend_regs.clear()
  52. return rtor[x]
  53. def reg_out(x):
  54. return rtor[x]
  55. for uop, out, vin, arg in asm:
  56. if uop == UOps.DEFINE_REGISTER:
  57. if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
  58. for i in range(arg[2]):
  59. # TODO: Re-use gaps created by this to avoid wasting registers
  60. align = int(arg[0][0].itemsize / 4)
  61. if arg[0][1]:
  62. s_cnt += s_cnt % align
  63. reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}"
  64. s_cnt += align
  65. else:
  66. v_cnt += v_cnt % align
  67. reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}"
  68. v_cnt += align
  69. rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
  70. if arg[0][0] == dtypes.float.vec(4):
  71. for off in range(4):
  72. reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
  73. rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
  74. elif arg[0][0] == dtypes.bool:
  75. for i in range(arg[2]):
  76. reg_name = "scc" if arg[0][1] else "vcc_lo" # `_lo` suffix since we're running wavefront_size=32
  77. rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
  78. else:
  79. raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
  80. elif uop == UOps.SPECIAL:
  81. if arg.startswith('buf'):
  82. i = int(arg[3:])
  83. ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
  84. pend_regs.add(out)
  85. for r in out.subregs(): pend_regs.add(r)
  86. elif arg.startswith('gid'):
  87. ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}')
  88. # the docs lied, this is actually y
  89. if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
  90. if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10")
  91. elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0")
  92. # get local size
  93. offset = len(args)*8
  94. args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8})
  95. ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}')
  96. ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
  97. pend_regs.clear()
  98. ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
  99. ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
  100. elif uop == UOps.CONST:
  101. if arg == float('inf'): arg = "0x7f800000"
  102. elif arg == float('-inf'): arg = "0xff800000"
  103. if out.dtype == dtypes.float.vec(4):
  104. for off in range(4):
  105. ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
  106. else:
  107. ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
  108. elif uop == UOps.ALU:
  109. if arg in [BinaryOps.CMPLT]:
  110. ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
  111. else:
  112. alu_arg = alu[arg]
  113. if arg == TernaryOps.MULACC and out == vin[2]:
  114. alu_arg = "fmac"
  115. vin = vin[0:2]
  116. if out.dtype == dtypes.float.vec(4):
  117. for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]):
  118. ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
  119. else:
  120. ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
  121. elif uop == UOps.LOAD:
  122. if out.scalar:
  123. # swap arg order
  124. ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
  125. else:
  126. ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
  127. pend_regs.add(out)
  128. for r in out.subregs(): pend_regs.add(r)
  129. elif uop == UOps.STORE:
  130. ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
  131. elif uop == UOps.LABEL:
  132. ins.append(f"{arg}:")
  133. elif uop == UOps.COND_BRANCH:
  134. ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
  135. elif uop == UOps.CAST:
  136. if vin[0].dtype == dtypes.bool:
  137. if out.dtype == dtypes.float32:
  138. ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
  139. else:
  140. raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
  141. else:
  142. raise NotImplementedError(uop)
  143. ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
  144. # dual alu group
  145. seen = set()
  146. new_ins = []
  147. for i,tins in enumerate(ins):
  148. if tins in seen: continue
  149. if tins.startswith("v_fmac_f32"):
  150. for gins in reversed(ins[i+1:]):
  151. if gins in seen: continue
  152. if gins.startswith("v_fmac_f32"):
  153. r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]]
  154. r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]]
  155. if r0[0]%2 == r1[0]%2: continue
  156. if r0[1]%2 == r1[1]%2: continue
  157. if r0[2]%2 == r1[2]%2: continue
  158. new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_"))
  159. seen.add(tins)
  160. seen.add(gins)
  161. break
  162. if tins not in seen:
  163. new_ins.append(tins)
  164. ins = new_ins
  165. return 'code', self.assemble(args, ins, v_cnt, s_cnt)
  166. def assemble(self, args, ins, v_cnt, s_cnt):
  167. kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
  168. '.amdhsa_next_free_vgpr': v_cnt, # this matters!
  169. '.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
  170. '.amdhsa_next_free_sgpr': s_cnt,
  171. '.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
  172. '.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
  173. '.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
  174. '.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
  175. '.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
  176. '.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
  177. '.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
  178. metadata = {'amdhsa.kernels': [{'.args': args,
  179. '.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
  180. '.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
  181. '.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
  182. '.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
  183. '.wavefront_size': 32}],
  184. 'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
  185. code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
  186. obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
  187. asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
  188. return asm