# assembly.py
  1. from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
  2. from tinygrad.codegen.kernel import UOps, MemOp, UOp
  3. from tinygrad.ops import BinaryOps, UnaryOps
  4. from tinygrad.dtype import DType, dtypes
  5. from tinygrad.helpers import DEBUG
  6. from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
  7. import functools
  8. import math
  9. from collections import defaultdict
  10. _type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
  11. dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}
  12. class Register(NamedTuple):
  13. nm:str
  14. dtype:DType
  15. scalar:bool
  16. off:Optional[int] = None
  17. def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
  18. def subregs(self):
  19. if self.dtype == dtypes.float.vec(4):
  20. return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
  21. return []
class AssemblyInstruction(NamedTuple):
  # One pseudo-assembly instruction: an op with an optional output register
  # and a list of inputs (registers or immediate int/float constants).
  op: UOps                                # which operation this instruction performs
  out: Optional[Register]                 # destination; None for stores/branches/labels
  vin: List[Union[Register, int, float]]  # inputs: registers or immediates
  arg: Any = None                         # op-specific payload (ALU op, label, memory info, ...)
  27. # warp size of 32, s registers are shared across the warp, v are 32-wide vectors
  28. class AssemblyLanguage:
  29. supports_load3: bool = False
  30. sin_is_sin2pi: bool = False
  31. no_div: bool = False
  32. #TODO: these should be global vars
  33. cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
  34. tor: Dict[Any, Register] = {}
  35. ins: List[AssemblyInstruction] = []
  36. def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
  37. def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
  38. self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
  39. if dtype == dtypes.float.vec(4):
  40. for off in range(4):
  41. self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
  42. self.cnts[(dtype, scalar)] += 1
  43. return ret
  44. def render_numnode(self, b) -> Register:
  45. key = ("num", b)
  46. if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
  47. return self.tor[key]
  48. def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
  49. key = (op, a, b)
  50. if key not in self.tor:
  51. #if not isinstance(b, Register): b = render_numnode(b)
  52. self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
  53. return self.tor[key]
  54. def render_cast(self, a:Register, new_dtype:DType) -> Register:
  55. if a.dtype == new_dtype: return a
  56. key = (a, new_dtype)
  57. if key not in self.tor:
  58. self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
  59. return self.tor[key]
  60. render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
  61. MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
  62. DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
  63. ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
  64. LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
  65. SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
  66. AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
  67. def addr_w_offset(self, args):
  68. assert isinstance(args, MemOp)
  69. idx = args.idx*args.memory_dtype.itemsize
  70. off = 0 # TODO: should this be None?
  71. if isinstance(idx, SumNode):
  72. nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
  73. if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
  74. idx -= nums[0]
  75. off = cast(int, nums[0])
  76. reg = idx.render(self.render_ops, self)
  77. if self.supports_load3:
  78. if reg.scalar:
  79. new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
  80. self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
  81. reg = new_reg
  82. return self.tor[args.name], reg, off
  83. reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
  84. return reg, None, off
def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
  """Lower a linearized list of UOps into lang.ins (AssemblyInstructions).

  lang is an AssemblyLanguage instance; its ins/tor/cnts state is rebuilt.
  function_name is unused in this function (kept for the renderer interface).
  Returns (global_size, local_size) launch dimensions collected from LOOPs.
  """
  #TODO: Do not use clear()
  lang.ins.clear()
  lang.tor.clear()
  lang.cnts.clear()
  buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
  global_size, local_size = [], []
  skipload_branch = 0  # counter for unique $skipload_N labels (guarded loads)
  # bind each global buffer argument to a scalar 64-bit pointer register
  lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
  for u in uops:
    uop,dtype,vin,args,_ = u
    if uop == UOps.DEFINE_LOCAL:
      lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
      # copy the local buffer's address into a 64-bit register
      lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
    elif uop == UOps.LOOP:
      if args[1] == "global":
        for i,var in enumerate(args[0]):
          global_size.append(var.max+1)
          # NOTE: gid/lid indices are reversed relative to loop-variable order
          lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
      elif args[1] == "local":
        for i,var in enumerate(args[0]):
          local_size.append(var.max+1)
          lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
      else:
        # a real (e.g. reduce) loop: init the counter to 0, drop a back-edge label
        for var in args[0]:
          if not isinstance(var, NumNode): # TODO: why is this coming through?
            lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
            lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
    elif uop == UOps.ENDLOOP:
      if args[1] not in ["global", "local", "global+local"]:
        # close a real loop: increment the counter, branch back while counter < max+1
        for var in reversed(args[0]):
          if not isinstance(var, NumNode): # TODO: why is this coming through?
            lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
            pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
            lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
      elif args[1] == "global+local":
        for i, var in enumerate(reversed(args[0])):
          lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
      elif args[1] == 'local':
        for i, var in enumerate(reversed(args[0])):
          lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
    elif uop == UOps.CAST:
      # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
      # for a float4 output this copies each input into the matching scalar lane;
      # for non-vector dtypes subregs() is empty and nothing is emitted
      out = lang.newreg(u, dtype)
      for i,sr in enumerate(out.subregs()):
        lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
    elif uop == UOps.ALU:
      out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
      # this is the only thing that can violate SSA
      if args in [BinaryOps.CMPLT]:
        # compare into a predicate register, then cast to the output dtype
        pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
        lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
        lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
      elif args == BinaryOps.DIV and lang.no_div:
        # no hardware divide: lower a/b to a * (1/b)
        tmp = lang.newreg((u, "rcp"))
        lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
        lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
      elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
        # backend sin takes turns, not radians: pre-multiply by 1/(2*pi)
        tmp = lang.newreg((u, "2pi"))
        lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
        lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
      else:
        lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
    elif uop == UOps.DEFINE_ACC:
      # accumulator starts at the init value carried in args
      reg = lang.newreg(u, dtype=dtype)
      lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args))
    elif uop == UOps.SPECIAL:
      lang.tor[u] = lang.tor[args]
    elif uop == UOps.CONST:
      lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args))
    elif uop == UOps.LOAD:
      idx, treg, off = lang.addr_w_offset(args)
      reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
      if args.valid.min == 0:
        # load may be invalid: default the destination to 0 ...
        lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
        if args.valid.max == 1:
          # ... and branch over the real load when the predicate is false
          pred = args.valid.render(lang.render_ops, lang)
          lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
      if args.valid.max == 1:
        # NOTE: you can't compute the index in here, because it assumes it's all available later
        lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
      if args.valid.min == 0 and args.valid.max == 1:
        lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
        skipload_branch += 1
    elif uop == UOps.STORE:
      if args is None:
        # register-to-register move (e.g. writing back an accumulator)
        lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
      else:
        idx, treg, off = lang.addr_w_offset(args)
        lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
  if DEBUG >= 4:
    for tins in lang.ins: print(tins)
  return global_size, local_size