assembly_arm64.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import struct
  2. from platform import system
  3. from typing import Tuple, Dict, List, Optional
  4. from tinygrad import dtypes
  5. from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
  6. from tinygrad.codegen.kernel import UOps, UOp
  7. from tinygrad.helpers import CI
  8. from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
  9. def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
  10. def compute_offsets(total):
  11. quotient, remainder = divmod(total, 4096)
  12. return [4096]*quotient + [remainder] if remainder else [4096]*quotient
  13. #NOTE: Darwin needs names to start with a "_"
  14. def get_name(name): return ('_' if system() == 'Darwin' else '') + name
  15. class ARM64Language(AssemblyLanguage): pass
  16. def specialize_to_arm64(fn_nm, asm):
  17. var_size = 16
  18. prev_uop:Optional[UOps] = None
  19. ins = []
  20. x_regs = ['x' + str(i) for i in reversed(range(12))]
  21. s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
  22. type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
  23. alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
  24. BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
  25. UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
  26. UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
  27. TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
  28. def mov_imm(value, reg):
  29. # Manually move value into reg if value can't fit
  30. if value.__class__ is not float and abs(value) > abs(65535):
  31. ins.append(f"movz w15, #{value & 0xffff}")
  32. ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
  33. ins.append(f"sxtw {reg}, w15")
  34. elif reg[0] == 's':
  35. ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
  36. ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
  37. ins.append("str x15, [sp, 16]")
  38. ins.append(f"ldr {reg}, [sp, 16]")
  39. else:
  40. ins.append(f"mov {reg}, #{value}")
  41. # Get variables intervals
  42. live_range:Dict[str, List[int]] = {}
  43. for i, (uop, out, vin, arg) in enumerate(asm):
  44. for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
  45. live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
  46. mem_vars:Dict[str, int] = {}
  47. rtor:Dict[str, str] = {}
  48. def allocate_regs(mvars):
  49. nonlocal var_size
  50. for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
  51. available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
  52. #NOTE: Very simple spill, everything that don't fit in regs goes to mem
  53. if not available_regs:
  54. # ARM needs the stack 16-byte aligned
  55. var_size += 16
  56. available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
  57. mem_vars[v.nm] = var_size
  58. rtor[v.nm] = available_regs.pop()
  59. temp_floats = ['s0', 's1', 's2']
  60. temp_ints = ['x12', 'x13', 'x16']
  61. for i, (uop, out, vin, arg) in enumerate(asm):
  62. # Clear regs out of interval
  63. for var, reg in list(rtor.items()):
  64. available_regs = s_regs if reg[0] == 's' else x_regs
  65. if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
  66. available_regs.append(rtor.pop(var))
  67. # Assign a registers to the variables using live ranges.
  68. allocate_regs([out] + vin)
  69. # Assign temp regs to vin and load them before direct use
  70. for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
  71. rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
  72. # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
  73. ins.append(f"mov x15, {mem_vars[v.nm]}")
  74. ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
  75. if uop == UOps.SPECIAL:
  76. if arg.startswith('data'):
  77. # data 8 to n into the stack
  78. if int(arg[4:]) >= 8:
  79. ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
  80. ins.append(f"mov {rtor[out.nm]}, x15")
  81. else:
  82. ins.append(f"mov {rtor[out.nm]}, #0")
  83. ins.append(f"loop_{arg}:")
  84. elif uop == UOps.CAST:
  85. if arg == BinaryOps.CMPLT:
  86. if rtor[out.nm][0] == 's':
  87. mov_imm(0.0, 's0')
  88. mov_imm(1.0, 's1')
  89. ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
  90. if rtor[out.nm][0] == 'x':
  91. mov_imm(0, 'x14')
  92. mov_imm(1, 'x15')
  93. ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
  94. else:
  95. ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
  96. elif uop == UOps.ALU:
  97. if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
  98. if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
  99. ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
  100. elif arg == TernaryOps.WHERE:
  101. ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0")
  102. ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
  103. elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
  104. #NOTE: Not a real instruction, use to emulate a ext call in unicorn
  105. if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
  106. else:
  107. save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
  108. ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
  109. # Save the registers before they are cleared by func call
  110. for i,k in enumerate(save_regs,1):
  111. ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
  112. ins.append("stp x29, x30, [sp, #0]!")
  113. ins.append("mov x29, sp")
  114. ins.append(f"fmov s0, {rtor[vin[0].nm]}")
  115. ins.append(alu[arg])
  116. ins.append(f"fmov {rtor[out.nm]}, s0")
  117. ins.append("mov sp, x29")
  118. ins.append("ldp x29, x30, [sp], #0")
  119. for i,k in enumerate(save_regs,1):
  120. ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
  121. ins.append(f"add sp, sp, #{len(save_regs)*16}")
  122. elif arg == BinaryOps.CMPLT:
  123. ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
  124. elif arg == BinaryOps.MOD:
  125. rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm]
  126. ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
  127. ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
  128. else:
  129. ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
  130. elif uop == UOps.LOAD:
  131. if arg.__class__ in (int, float):
  132. mov_imm(arg, rtor[out.nm])
  133. else:
  134. #NOTE: if need casting load var in s/h0 or x/w12 temp regs
  135. reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
  136. mov_imm(arg[0], "x15")
  137. ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
  138. ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
  139. if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
  140. elif uop == UOps.STORE:
  141. #NOTE: if need casting load var in s/h0 or x/w12 temp regs
  142. reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
  143. if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
  144. ins.append(f"mov x15, #{arg[0]}")
  145. ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
  146. elif uop == UOps.COND_BRANCH:
  147. #TODO: this is a hack it shouldn't always be a cmp before a cond branch?
  148. if prev_uop == UOps.LOAD:
  149. ins.append(f"cmp {rtor[vin[0].nm]}, #0")
  150. ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
  151. elif uop == UOps.LABEL:
  152. ins.append(f"{arg[1:]}:")
  153. elif uop == UOps.ENDLOOP:
  154. mov_imm(arg[0], "x15")
  155. ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
  156. ins.append(f"cmp {rtor[vin[0].nm]}, x15")
  157. ins.append(f"b.lt loop_{arg[1]}")
  158. prev_uop = uop
  159. # store regs into memory if needed
  160. if out is not None and out.nm in mem_vars:
  161. ins.append(f"mov x15, {mem_vars[out.nm]}")
  162. ins.append(f"str {rtor[out.nm]}, [sp, x15]")
  163. return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
  164. def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
  165. lang = ARM64Language()
  166. global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
  167. return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True