- import struct
- from platform import system
- from typing import Tuple, Dict, List, Optional
- from tinygrad import dtypes
- from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
- from tinygrad.codegen.kernel import UOps, UOp
- from tinygrad.helpers import CI
- from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
def float_to_hex(x):
  """Return the 8 uppercase hex digits of *x*'s IEEE-754 single-precision encoding, most significant byte first."""
  b3, b2, b1, b0 = struct.pack("f", x)[::-1]
  return "%02X%02X%02X%02X" % (b3, b2, b1, b0)
def compute_offsets(total):
  """Split *total* into a list of stack-pointer adjustments of at most 4096 bytes each.

  AArch64 immediate add/sub is limited, so a large frame is built from full
  4096-byte steps plus one final remainder step (omitted when it would be zero).
  """
  full_steps, tail = divmod(total, 4096)
  offsets = [4096] * full_steps
  if tail:
    offsets.append(tail)
  return offsets
#NOTE: Darwin needs names to start with a "_"
def get_name(name):
  """Return *name* decorated with the platform's symbol prefix (leading underscore on Darwin)."""
  prefix = '_' if system() == 'Darwin' else ''
  return prefix + name
class ARM64Language(AssemblyLanguage): pass  # no ARM64-specific overrides needed; the base AssemblyLanguage behavior suffices
def specialize_to_arm64(fn_nm, asm):
  """Lower the assembly-style op list *asm* into AArch64 assembly text for kernel *fn_nm*.

  Performs a very simple linear-scan register allocation driven by precomputed live
  ranges (spilling to the stack when a pool runs dry), then emits one or more ARM64
  instructions per op. Returns the complete assembly source as one string.
  """
  var_size = 16                    # bytes of stack reserved for spills; grown 16 at a time to keep alignment
  prev_uop:Optional[UOps] = None   # previous op, consulted by the COND_BRANCH hack below
  ins = []                         # accumulated instruction strings
  # int pool x0-x11; pop() hands out x0 first, which lets the first 8 'data' arguments
  # stay in their incoming ABI registers (see the SPECIAL case below)
  x_regs = ['x' + str(i) for i in reversed(range(12))]
  # float pool s3-s7 and s16-s31; s0-s2 are reserved as temporaries below
  s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
  # dtype -> register-name prefix for typed memory access (w/x int widths, h/s/d float widths)
  type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
  # op -> mnemonic; float ops gain an 'f' prefix at emit time, transcendentals become libm calls
  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
         BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
         UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
         UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
         TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}

  def mov_imm(value, reg):
    # Manually move value into reg if value can't fit
    if value.__class__ is not float and abs(value) > abs(65535):
      # wide integer: build the low/high 16-bit halves in w15 with movz/movk, then sign-extend
      ins.append(f"movz w15, #{value & 0xffff}")
      ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
      ins.append(f"sxtw {reg}, w15")
    elif reg[0] == 's':
      # float destination: materialize the IEEE-754 bit pattern in x15 and bounce it
      # through the stack, since a full 32-bit immediate cannot be moved into an s register
      ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
      ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
      ins.append("str x15, [sp, 16]")
      ins.append(f"ldr {reg}, [sp, 16]")
    else:
      ins.append(f"mov {reg}, #{value}")

  # Get variables intervals
  live_range:Dict[str, List[int]] = {}
  for i, (uop, out, vin, arg) in enumerate(asm):
    for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
      live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]

  mem_vars:Dict[str, int] = {}   # spilled variable name -> stack offset
  rtor:Dict[str, str] = {}       # variable name -> register currently holding it
  def allocate_regs(mvars):
    nonlocal var_size
    for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
      available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
      #NOTE: Very simple spill, everything that don't fit in regs goes to mem
      if not available_regs:
        # ARM needs the stack 16-byte aligned
        var_size += 16
        available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
        mem_vars[v.nm] = var_size
      rtor[v.nm] = available_regs.pop()

  temp_floats = ['s0', 's1', 's2']
  temp_ints = ['x12', 'x13', 'x16']
  for i, (uop, out, vin, arg) in enumerate(asm):
    # Clear regs out of interval
    for var, reg in list(rtor.items()):
      available_regs = s_regs if reg[0] == 's' else x_regs
      # NOTE(review): names whose second char is 'B' are never released — presumably
      # buffer pointers that must stay pinned for the whole kernel; confirm upstream.
      if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
        available_regs.append(rtor.pop(var))
    # Assign a registers to the variables using live ranges.
    allocate_regs([out] + vin)
    # Assign temp regs to vin and load them before direct use
    # NOTE(review): this loop shadows the outer index `i`; harmless today because `i`
    # is not read again within this iteration, but fragile.
    for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
      rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
      # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
      ins.append(f"mov x15, {mem_vars[v.nm]}")
      ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
    if uop == UOps.SPECIAL:
      if arg.startswith('data'):
        # data 8 to n into the stack
        if int(arg[4:]) >= 8:
          # x17 preserved the entry sp (see prologue in the return below), so
          # caller-stacked arguments sit at fixed offsets from it
          ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
          ins.append(f"mov {rtor[out.nm]}, x15")
        # data0-data7 arrive in x0-x7, which the allocator hands out in that order,
        # so no move is emitted for them — TODO confirm this ordering assumption holds
      else:
        # non-data SPECIAL is a loop counter: zero it and emit the loop-entry label
        ins.append(f"mov {rtor[out.nm]}, #0")
        ins.append(f"loop_{arg}:")
    elif uop == UOps.CAST:
      if arg == BinaryOps.CMPLT:
        # materialize the result of the preceding compare as 1.0/0.0 (float) or 1/0 (int)
        if rtor[out.nm][0] == 's':
          mov_imm(0.0, 's0')
          mov_imm(1.0, 's1')
          ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
        if rtor[out.nm][0] == 'x':
          mov_imm(0, 'x14')
          mov_imm(1, 'x15')
          ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
      else:
        # plain int widening cast: sign-extend the 32-bit view of the source register
        ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
    elif uop == UOps.ALU:
      if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
        # bool*bool is a logical AND
        ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
      elif arg == TernaryOps.WHERE:
        ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0")
        ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
      elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
        #NOTE: Not a real instruction, use to emulate a ext call in unicorn
        if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
        else:
          # real libm call: save live registers, set up a frame, pass/receive the value in s0
          save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
          ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
          # Save the registers before they are cleared by func call
          for i,k in enumerate(save_regs,1):
            ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
          ins.append("stp x29, x30, [sp, #0]!")
          ins.append("mov x29, sp")
          ins.append(f"fmov s0, {rtor[vin[0].nm]}")
          ins.append(alu[arg])
          ins.append(f"fmov {rtor[out.nm]}, s0")
          ins.append("mov sp, x29")
          ins.append("ldp x29, x30, [sp], #0")
          for i,k in enumerate(save_regs,1):
            ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
          ins.append(f"add sp, sp, #{len(save_regs)*16}")
      elif arg == BinaryOps.CMPLT:
        ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
      elif arg == BinaryOps.MOD:
        # no hardware modulo on AArch64: a % b == a - (a / b) * b, via udiv + msub
        rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm]
        ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
        ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
      else:
        # generic ALU op: 'f' prefix for float operands, 's' prefix only for signed int division
        ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
    elif uop == UOps.LOAD:
      if arg.__class__ in (int, float):
        # immediate constant load
        mov_imm(arg, rtor[out.nm])
      else:
        #NOTE: if need casting load var in s/h0 or x/w12 temp regs
        reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
        mov_imm(arg[0], "x15")
        ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
        ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
        if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
    elif uop == UOps.STORE:
      #NOTE: if need casting load var in s/h0 or x/w12 temp regs
      reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
      if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
      ins.append(f"mov x15, #{arg[0]}")
      ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
    elif uop == UOps.COND_BRANCH:
      #TODO: this is a hack it shouldn't always be a cmp before a cond branch?
      if prev_uop == UOps.LOAD:
        ins.append(f"cmp {rtor[vin[0].nm]}, #0")
      ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
    elif uop == UOps.LABEL:
      # arg carries a leading sigil that is stripped for the emitted label
      ins.append(f"{arg[1:]}:")
    elif uop == UOps.ENDLOOP:
      # increment the loop counter and branch back while below the trip count in x15
      mov_imm(arg[0], "x15")
      ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
      ins.append(f"cmp {rtor[vin[0].nm]}, x15")
      ins.append(f"b.lt loop_{arg[1]}")
    prev_uop = uop
    # store regs into memory if needed
    if out is not None and out.nm in mem_vars:
      ins.append(f"mov x15, {mem_vars[out.nm]}")
      ins.append(f"str {rtor[out.nm]}, [sp, x15]")
  # Prologue saves the entry sp in x17 (used by SPECIAL for stacked args), then drops the
  # frame in <=4096-byte steps; epilogue restores it symmetrically before ret.
  return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
  """Full lowering pipeline: uops -> assembly-style ops -> AArch64 source text.

  Returns (asm_source, global_size reversed, local_size reversed, True).
  The sizes are reversed to match the launch convention expected by the caller;
  the trailing True flag is passed through unchanged — presumably a capability
  marker consumed by the runtime (confirm against the caller).
  """
  lang = ARM64Language()
  global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
  asm_src = specialize_to_arm64(fn_nm, lang.ins)
  return asm_src, global_size[::-1], local_size[::-1], True
|