| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189 |
- from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
- from tinygrad.codegen.kernel import UOps, MemOp, UOp
- from tinygrad.ops import BinaryOps, UnaryOps
- from tinygrad.dtype import DType, dtypes
- from tinygrad.helpers import DEBUG
- from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
- import functools
- import math
- from collections import defaultdict
# Short type tag used when naming registers; AssemblyLanguage.type_to_letter upper-cases it for scalar registers.
_type_to_letter = {
  dtypes.float32: 'f',
  dtypes.bool: 'p',
  dtypes.int32: 'i',
  dtypes.int64: 'a',
  dtypes.uint32: 'u',
  dtypes.uint64: 'b',
  dtypes.float.vec(4): 'x',
  dtypes.uint8: 'uc',
  dtypes.float16: 'h',
  dtypes.int8: 'c',
  dtypes.uint16: 'us',
  dtypes.float64: 'd',
}
class Register(NamedTuple):
  """A named register of a given dtype; `off`, when set, selects one lane of a vector register."""
  nm:str                    # register name as it appears in emitted code
  dtype:DType               # element type held by the register
  scalar:bool               # scalar-register flag (the type letter in its name is upper-cased when set)
  off:Optional[int] = None  # lane index within a vector register, or None for the whole register
  def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
  def subregs(self):
    """Return the four float lane registers of a float4 register, in lane order; [] for any other dtype.

    Each lane shares the parent's name and inherits its `scalar` flag. AssemblyLanguage.newreg builds lane
    registers with the parent's scalar flag as well; previously this method hard-coded False.
    """
    if self.dtype != dtypes.float.vec(4): return []
    return [Register(self.nm, dtypes.float, self.scalar, off=lane) for lane in range(4)]
class AssemblyInstruction(NamedTuple):
  """One backend-agnostic instruction: an op kind, an optional destination, source operands and an op-specific payload."""
  op: UOps                                # instruction kind (ALU, LOAD, STORE, LABEL, COND_BRANCH, SPECIAL, ...)
  out: Union[Register, None]              # destination register; None for instructions that produce nothing
  vin: List[Union[Register, int, float]]  # source operands: registers or immediate numbers
  arg: Any = None                         # payload, e.g. the ALU op, a label name, or a memory-access tuple
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyLanguage:
  """Accumulates a flat list of AssemblyInstructions for one kernel and tracks which register holds each value.

  Targets customize lowering through the capability flags below.
  """
  supports_load3: bool = False  # memory ops take (buffer register, index register, offset) instead of a single uint64 address
  sin_is_sin2pi: bool = False   # SIN is applied to its input pre-multiplied by 1/(2*pi)
  no_div: bool = False          # DIV is lowered to a MUL by a RECIP
  #TODO: these should be global vars
  # NOTE: these containers are class attributes, so they are shared by every instance (and by subclasses that do
  # not redefine them); uops_to_asmstyle clear()s them before lowering each kernel.
  cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)  # next register number per (dtype, scalar)
  tor: Dict[Any, Register] = {}                                 # value / cache key -> register holding it
  ins: List[AssemblyInstruction] = []                           # emitted instructions, in program order

  def type_to_letter(self,x):
    """Register-name tag for x=(dtype, scalar); upper-cased for scalar registers."""
    return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]

  def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
    """Allocate a fresh register of `dtype`, bind it to `tok` in self.tor and return it.

    For float4 registers each lane is additionally bound under the key (tok, lane) as a float lane register,
    so self.tor[tok] always refers to the same register this method returns.
    """
    ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
    self.tor[tok] = ret
    if dtype == dtypes.float.vec(4):
      # previously every lane was written to self.tor[tok], clobbering the vector register with lane 3
      for off in range(4):
        self.tor[(tok, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
    self.cnts[(dtype, scalar)] += 1
    return ret

  def render_numnode(self, b) -> Register:
    """Return a scalar int32 register holding the constant `b`, emitting its LOAD only the first time."""
    key = ("num", b)
    if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
    return self.tor[key]

  def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
    """Return a register holding `op(a, b)`, emitting the ALU instruction at most once per (op, a, b).

    The result is scalar only when `a` is scalar and `b` is an immediate or a scalar register.
    NOTE(review): the cache is keyed on register identity, so a register later updated in place
    (e.g. a loop counter) may reuse a result computed from its earlier value.
    """
    key = (op, a, b)
    if key not in self.tor:
      self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
    return self.tor[key]

  def render_cast(self, a:Register, new_dtype:DType) -> Register:
    """Return `a` as `new_dtype`: `a` itself if already that dtype, else a CAST emitted at most once per (a, new_dtype)."""
    if a.dtype == new_dtype: return a
    key = (a, new_dtype)
    if key not in self.tor:
      self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
    return self.tor[key]

  # Dispatch table passed as `ops` to symbolic Node.render; each lambda receives (node, ops, ctx) with ctx being
  # this AssemblyLanguage. Logical AND of bool predicates is emitted as a MUL.
  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
                      MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
                      DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
                      ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
                      LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
                      SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
                      AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

  def addr_w_offset(self, args):
    """Render the address of MemOp `args`, folding a small constant term of the byte index into an immediate offset.

    Returns (buffer register, index register, off) when supports_load3, otherwise (address register, None, off)
    where the address register is the uint64 sum of the byte index and the buffer's register.
    """
    assert isinstance(args, MemOp)
    idx = args.idx*args.memory_dtype.itemsize  # element index -> byte index
    off = 0 # TODO: should this be None?
    if isinstance(idx, SumNode):
      nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
      # fold the constant only if it is below 4096 (presumably the immediate-offset limit) and the rest stays >= 0
      if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
        idx -= nums[0]
        off = cast(int, nums[0])
    reg = idx.render(self.render_ops, self)
    if self.supports_load3:
      if reg.scalar:
        # copy a scalar index into a non-scalar register (presumably the index operand must be a vector register)
        new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
        self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
        reg = new_reg
      return self.tor[args.name], reg, off
    reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
    return reg, None, off
def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
  """Lower a linearized uop list into lang.ins, a flat list of AssemblyInstructions.

  Resets lang's shared instruction/register state, binds one uint64 scalar register per global buffer,
  then walks `uops` in order. `function_name` is not used here.
  Returns (global_size, local_size): the extent (max+1) of each global/local loop variable, in declaration order.
  """
  def emit(instr): lang.ins.append(instr)

  def mem_arg(memop, off):
    # (immediate offset, address space, explicit memory dtype unless it is plain float)
    return (off, 'shared' if memop.local else 'global', None if memop.memory_dtype == dtypes.float else memop.memory_dtype)

  #TODO: Do not use clear()
  for state in (lang.ins, lang.tor, lang.cnts): state.clear()

  # one register per distinct global buffer name, in first-seen order
  for buf in dict.fromkeys(a for op, _, _, a, _ in uops if op == UOps.DEFINE_GLOBAL):
    emit(AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf))

  global_size: List[int] = []
  local_size: List[int] = []
  skip_label_id = 0

  for u in uops:
    uop, dtype, vin, args, _ = u

    if uop == UOps.DEFINE_LOCAL:
      emit(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
      emit(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))

    elif uop == UOps.LOOP:
      loop_vars, kind = args[0], args[1]
      if kind in ("global", "local"):
        # hardware-indexed loops: record the extent and read the index from gid/lid (innermost gets index 0)
        sizes, prefix = (global_size, "gid") if kind == "global" else (local_size, "lid")
        last = len(loop_vars) - 1
        for i, var in enumerate(loop_vars):
          sizes.append(var.max+1)
          emit(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"{prefix}{last-i}"))
      else:
        # software loop: zero the counter and place the label the ENDLOOP branches back to
        for var in loop_vars:
          if isinstance(var, NumNode): continue  # TODO: why is this coming through?
          emit(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
          emit(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))

    elif uop == UOps.ENDLOOP:
      loop_vars, kind = args[0], args[1]
      if kind == "global+local" or kind == "local":
        prefix = "gid" if kind == "global+local" else "lid"
        for i, var in enumerate(reversed(loop_vars)):
          emit(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"{prefix}{i}")))
      elif kind != "global":
        # software loop: counter += 1, then branch back while counter < extent
        for var in reversed(loop_vars):
          if isinstance(var, NumNode): continue  # TODO: why is this coming through?
          counter = lang.tor[var]
          emit(AssemblyInstruction(UOps.ALU, counter, [counter, 1], BinaryOps.ADD))
          keep_going = lang.render_alu(BinaryOps.CMPLT, counter, var.max+1, dtypes.bool)
          emit(AssemblyInstruction(UOps.COND_BRANCH, None, [keep_going], ("$loop_"+var.expr, True)))

    elif uop == UOps.CAST:
      # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
      packed = lang.newreg(u, dtype)
      for lane, sub in enumerate(packed.subregs()):
        emit(AssemblyInstruction(UOps.ALU, sub, [lang.tor[vin[lane]]], UnaryOps.NOOP))

    elif uop == UOps.ALU:
      # reusing an already-bound register for u is the only thing that can violate SSA
      out = lang.tor[u] if u in lang.tor else lang.newreg(u, dtype)
      if args == BinaryOps.CMPLT:
        flag = lang.newreg((u, 'pred'), dtype=dtypes.bool)
        emit(AssemblyInstruction(UOps.ALU, flag, [lang.tor[x] for x in vin], args))
        emit(AssemblyInstruction(UOps.CAST, out, [flag], args))
      elif args == BinaryOps.DIV and lang.no_div:
        inverse = lang.newreg((u, "rcp"))
        emit(AssemblyInstruction(UOps.ALU, inverse, [lang.tor[vin[1]]], UnaryOps.RECIP))
        emit(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], inverse], BinaryOps.MUL))
      elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
        turns = lang.newreg((u, "2pi"))
        emit(AssemblyInstruction(UOps.ALU, turns, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
        emit(AssemblyInstruction(UOps.ALU, out, [turns], args))
      else:
        emit(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))

    elif uop == UOps.DEFINE_ACC or uop == UOps.CONST:
      emit(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args))

    elif uop == UOps.SPECIAL:
      lang.tor[u] = lang.tor[args]

    elif uop == UOps.LOAD:
      idx, treg, off = lang.addr_w_offset(args)
      reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
      can_be_invalid = args.valid.min == 0
      can_be_valid = args.valid.max == 1
      if can_be_invalid:
        emit(AssemblyInstruction(UOps.LOAD, reg, [], 0))  # value used when the load is masked off
        if can_be_valid:
          guard = args.valid.render(lang.render_ops, lang)
          emit(AssemblyInstruction(UOps.COND_BRANCH, None, [guard], (f"$skipload_{skip_label_id}", False)))
      if can_be_valid:
        # NOTE: you can't compute the index in here, because it assumes it's all available later
        operands = [idx, treg] if treg is not None else [idx]
        emit(AssemblyInstruction(UOps.LOAD, reg, operands, mem_arg(args, off)))
      if can_be_invalid and can_be_valid:
        emit(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skip_label_id}"))
        skip_label_id += 1

    elif uop == UOps.STORE:
      if args is None:
        emit(AssemblyInstruction(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
      else:
        idx, treg, off = lang.addr_w_offset(args)
        operands = [idx, lang.tor[vin[0]]]
        if treg is not None: operands.append(treg)
        emit(AssemblyInstruction(UOps.STORE, None, operands, mem_arg(args, off)))

  if DEBUG >= 4:
    for tins in lang.ins: print(tins)
  return global_size, local_size
|