# rdna.py — experimental RDNA3 (gfx1100) assembly renderer for tinygrad UOp graphs
from typing import Dict, Set
import yaml
from tinygrad.codegen.uops import UOpGraph, UOps, UOp
from tinygrad.ops import BinaryOps
from tinygrad.dtype import dtypes
  6. def uops_to_rdna(function_name:str, uops:UOpGraph) -> str:
  7. replace: Dict[UOp, UOp] = {}
  8. seen: Set[UOp] = set()
  9. for u in uops:
  10. if u in seen: continue
  11. seen.add(u)
  12. for o,n in replace.items():
  13. if o in u.vin and u is not n:
  14. u.vin = tuple(n if x == o else x for x in u.vin)
  15. # pointer indexing
  16. if u.uop in {UOps.LOAD, UOps.STORE} and u.vin[0].dtype.itemsize > 1:
  17. val = UOp(UOps.CONST, dtypes.int, tuple(), arg=u.vin[0].dtype.itemsize, insert_before=uops.uops.index(u))
  18. ptr = UOp(UOps.ALU, dtypes.int, (u.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(u))
  19. u.vin = (u.vin[0], ptr) + u.vin[2:]
  20. #uops.print()
  21. args = []
  22. ins = []
  23. v_cnt = 3 # v[0:2] is local_xyz
  24. s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
  25. r: Dict[UOp, str] = {}
  26. for u in uops:
  27. if u.uop == UOps.SPECIAL:
  28. if u.arg[1].startswith("lidx"):
  29. r[u] = f'v{u.arg[0]}'
  30. elif u.arg[1].startswith("gidx"):
  31. r[u] = f's{2+u.arg[0]}'
  32. else:
  33. raise NotImplementedError
  34. elif u.uop == UOps.CONST:
  35. #r[u] = u.arg
  36. # TODO: sometimes we can use s
  37. #r[u] = f"s{s_cnt}"
  38. #s_cnt += 1
  39. #ins.append(f"s_mov_b32 {r[u]}, {u.arg}")
  40. r[u] = f"v{v_cnt}"
  41. v_cnt += 1
  42. ins.append(f"v_mov_b32 {r[u]}, {u.arg}")
  43. elif u.uop == UOps.ALU:
  44. if u.arg == BinaryOps.ADD:
  45. r[u] = f"v{v_cnt}"
  46. v_cnt += 1
  47. ins.append(f"v_add_f32_e32 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
  48. elif u.arg == BinaryOps.MUL:
  49. r[u] = f"v{v_cnt}"
  50. v_cnt += 1
  51. if dtypes.is_float(u.dtype):
  52. ins.append(f"v_mul_f32_e32 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
  53. else:
  54. ins.append(f"v_mul_u32_u24 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
  55. else:
  56. raise NotImplementedError
  57. elif u.uop == UOps.LOAD:
  58. r[u] = f"v{v_cnt}"
  59. v_cnt += 1
  60. ins.append(f"global_load_b32 {r[u]}, {r[u.vin[1]]}, {r[u.vin[0]]}")
  61. ins.append("s_waitcnt vmcnt(0)")
  62. elif u.uop == UOps.STORE:
  63. ins.append(f"global_store_b32 {r[u.vin[1]]}, {r[u.vin[2]]}, {r[u.vin[0]]}")
  64. elif u.uop == UOps.DEFINE_GLOBAL:
  65. i = u.arg[0]
  66. args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8,
  67. '.type_name': u.dtype.name+"*", '.value_kind': 'global_buffer'})
  68. s_cnt += s_cnt%2 # skip
  69. r[u] = f"s[{s_cnt}:{s_cnt+1}]"
  70. s_cnt += 2
  71. ins.append(f"s_load_b64 {r[u]}, s[0:1], {i*8}")
  72. ins.append("s_waitcnt lgkmcnt(0)")
  73. else:
  74. raise NotImplementedError(f"can't render {u.uop}")
  75. # *** boilerplate rendering ***
  76. metadata = {
  77. 'amdhsa.kernels': [{'.args': args,
  78. '.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
  79. '.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
  80. '.name': function_name, '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
  81. '.symbol': f'{function_name}.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
  82. '.wavefront_size': 32}],
  83. 'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
  84. boilerplate_start = f"""
  85. .rodata
  86. .global {function_name}.kd
  87. .type {function_name}.kd,STT_OBJECT
  88. .align 0x10
  89. .amdhsa_kernel {function_name}"""
  90. kernel_desc = {
  91. '.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
  92. '.amdhsa_next_free_vgpr': v_cnt, # this matters!
  93. '.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
  94. '.amdhsa_next_free_sgpr': s_cnt,
  95. '.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3,
  96. '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1, '.amdhsa_fp16_overflow': 0,
  97. '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
  98. '.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
  99. '.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
  100. '.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0,
  101. '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
  102. '.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0,
  103. '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
  104. '.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
  105. code_start = f""".end_amdhsa_kernel
  106. .text
  107. .global {function_name}
  108. .type {function_name},@function
  109. .p2align 8
  110. {function_name}:
  111. """
  112. ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
  113. return ".amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata" + \
  114. boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + \
  115. '\n'.join(ins) + f"\n.size {function_name}, .-{function_name}"