ops.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. from __future__ import annotations
  2. from typing import Union, Tuple, Any, List, Dict, Callable
  3. import functools, hashlib, math, operator, ctypes, struct
  4. from enum import Enum, auto
  5. from dataclasses import dataclass
  6. from tinygrad.helpers import prod, dedup
  7. from tinygrad.dtype import dtypes, DType, ConstType
  8. from tinygrad.shape.symbolic import Variable, sint
  9. from tinygrad.shape.shapetracker import ShapeTracker
  10. # these are the llops your accelerator must implement, along with toCpu
  11. # the Enum class doesn't work with mypy, this is static. sorry it's ugly
  12. # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
  13. # NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
  14. class UnaryOps(Enum):
  15. """A -> A (elementwise)"""
  16. EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto(); RECIP = auto() # noqa: E702
  17. class BinaryOps(Enum):
  18. """A + A -> A (elementwise)"""
  19. ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
  20. SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto() # noqa: E702
  21. class TernaryOps(Enum):
  22. """A + A + A -> A (elementwise)"""
  23. WHERE = auto(); MULACC = auto() # noqa: E702
  24. class ReduceOps(Enum):
  25. """A -> B (reduce)"""
  26. SUM = auto(); MAX = auto(); WMMA = auto() # noqa: E702
  27. class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
  28. class MetaOps(Enum):
  29. EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto(); KERNEL = auto() # noqa: E702
# union of all the op enums: the full instruction set an accelerator backend can receive
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, BufferOps]
# do not preserve f(0) = 0: padding an input with zeros changes these ops' output
UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
@dataclass(frozen=True)
class MemBuffer:
  """Arg for BufferOps.LOAD/STORE: a real backing buffer accessed through a view."""
  idx: int          # index of the buffer in the kernel's buffer list
  dtype: DType      # element dtype of the buffer
  st: ShapeTracker  # view (shape/strides/mask) used to access the buffer
@dataclass(frozen=True)
class ConstBuffer:
  """Arg for BufferOps.CONST: an inline constant (or symbolic Variable) broadcast through a view."""
  val: ConstType | Variable  # the constant value, possibly symbolic
  dtype: DType               # dtype the constant is interpreted as
  st: ShapeTracker           # view giving the constant its shape
@dataclass(frozen=True)
class KernelInfo:
  """Codegen options attached to a kernel."""
  local_dims: int = 0  # number of local dimensions (this is remapping RANGE to SPECIAL)
  upcasted: int = 0    # count that are upcasted (this is remapping RANGE to EXPAND)
  47. @dataclass(frozen=True, eq=False)
  48. class LazyOp:
  49. op: Op
  50. src: Tuple[LazyOp, ...] = ()
  51. arg: Any = None
  52. def cached_compare(self, x, context):
  53. if id(self) == id(x): return True
  54. if self.op != x.op or self.arg != x.arg or len(self.src) != len(x.src): return False
  55. if (key := (id(self), id(x))) in context: return context[key]
  56. ret = context[key] = all(a.cached_compare(b, context) for a,b in zip(self.src, x.src))
  57. return ret
  58. def __eq__(self, x): return self.cached_compare(x, context={})
  59. def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
  60. @functools.cached_property
  61. def dtype(self) -> DType:
  62. if self.op in BufferOps: return self.arg.dtype
  63. if self.op is ReduceOps.WMMA: return self.arg[3] # WMMA can change the type
  64. if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
  65. return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
  66. @functools.cached_property
  67. def full_shape(self):
  68. if len(self.src) == 0 and self.op in BufferOps: return self.arg.st.shape
  69. return tuple(max(x) for x in zip(*[x.full_shape for x in self.src]))
  70. @functools.cached_property
  71. def key(self) -> bytes:
  72. return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.arg)).encode())).digest()
  73. @functools.cached_property
  74. def hash(self): return hash((self.op, self.src, self.arg))
  75. def __hash__(self): return self.hash
  76. @functools.cached_property
  77. def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
  78. def vars(self) -> List[Variable]:
  79. extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
  80. const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
  81. return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)
  82. # TODO: support non-lazyop
  83. def __add__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, x))
  84. def __sub__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, -x))
  85. def __mul__(self, x:LazyOp): return LazyOp(BinaryOps.MUL, (self, x))
  86. def ne(self, x:LazyOp): return LazyOp(BinaryOps.CMPNE, (self, x))
  87. def eq(self, x:LazyOp): return -self.ne(x)
  88. def __neg__(self): return LazyOp(UnaryOps.NEG, (self,))
  89. @staticmethod
  90. def const(val, dtype:DType, shape:Tuple[sint, ...]):
  91. return LazyOp(BufferOps.CONST, (), ConstBuffer(val, dtype, ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)))
  92. # **************** independent FlopCounter ****************
  93. @dataclass
  94. class FlopCounter:
  95. shape: Tuple[int, ...]
  96. flops: sint
  97. mem: Dict[int, int]
  98. @property
  99. def mem_estimate(self): return sum(self.mem.values())
  100. def consume_flops(self):
  101. self.flops, ret = 0, self.flops
  102. return ret
# maps each Op to a callable that builds the FlopCounter for that node from its sources'
# counters plus, when present, the node's arg (see run_ast in get_lazyop_info for the calling convention)
InterpretedFlopCounter: Dict[Op, Callable] = {
  BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
  BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
  BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
  UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
  UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
  # every other unary op: one flop per output element (lambdas don't capture op, so no late-binding issue)
  **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
  # binary ops: one flop per element, the two operands' memory maps are merged
  **{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
  # reduce ops: reduced axes become 1 in the output shape; one flop per INPUT element
  **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
  TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
  113. @functools.lru_cache(None)
  114. def get_lazyop_info(ast:LazyOp) -> FlopCounter:
  115. @functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
  116. def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
  117. return run_ast(ast)
  118. # **************** ops in python ****************
  119. def hook_overflow(dv, fxn):
  120. def wfxn(*args):
  121. try: return fxn(*args)
  122. except OverflowError: return dv
  123. return wfxn
# pure-python implementation of each ALU op, used for constant folding / interpreted execution
python_alu: Dict[Op, Callable] = {
  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
  UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
  UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan,
  UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
  # RECIP of +/-0.0 is +/-inf (copysign keeps the zero's sign)
  UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
  UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,  # NEG of a bool is logical not
  BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
  BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
  BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
  BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
  # MOD: C-style remainder whose sign follows the dividend x (unlike python %)
  # IDIV: truncated division; y == 0 yields x*inf (so +/-inf, or nan when x is also 0)
  BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
  TernaryOps.MULACC: lambda x,y,z: (x*y)+z,
  TernaryOps.WHERE: lambda x,y,z: y if x else z}
  138. def truncate_fp16(x):
  139. try:
  140. x = float(x)
  141. struct.pack("@e", x)
  142. return x
  143. except OverflowError: return math.copysign(math.inf, x)
# maps a DType to a function that clamps/wraps a python value into that dtype's representable range
# (via ctypes round-trip for fixed-width types); identity for dtypes not present
truncate: Dict[DType, Callable] = {dtypes.bool: bool,
  # TODO: bfloat16
  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
  # NOTE(review): int32 alone guards with isinstance(x, int) so non-int values pass through
  # unchanged -- presumably symbolic Variables; confirm against callers
  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
  if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value, dtypes.bigint: lambda x: x }
  151. def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
  152. def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(st.shape))
  153. # the living definition of LazyOps
def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
  """Validate a KERNEL ast's shape invariants; returns the ShapeTracker computed for every node.

  Raises AssertionError when the ast is malformed (implicit movement ops, wrong output
  buffer indices/ops, mismatched output sizes).
  """
  assert ast.op is MetaOps.KERNEL, "must be SINK"
  sts: Dict[LazyOp, ShapeTracker] = {}  # per-node ShapeTracker, also serves as the visited set
  def dfs(op:LazyOp, st:ShapeTracker):
    if op in sts: return
    # restore globals from the two stage reduce
    if op.op is BufferOps.LOAD and op.arg.idx == -1:
      # idx == -1 marks the intermediate buffer of a two-stage reduce: recurse into the
      # first stage's reduce input and reuse its ShapeTracker for this LOAD
      dfs(local_reduce:=op.src[0].src[0], op.arg.st)
      return sts.setdefault(op, sts[local_reduce])
    for x in op.src: dfs(x, st)
    # only reduceop is allowed to change shape, limited to turning n to 1
    if op.op in ReduceOps:
      axis = op.arg[-1] if op.op is ReduceOps.WMMA else op.arg  # WMMA packs the axis last in its arg
      assert isinstance(axis, tuple) and all(isinstance(i, int) for i in axis), f"reduceop must have axis {op.arg}"
      st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], axis))
    else:
      # movementops are pushed to the edges with LOAD
      if op.op in BufferOps: st = op.arg.st
      else: st = sts[op.src[0]]
    # every source must already match this node's shape: no implicit movement mid-tree
    for x in op.src: assert sts[x].shape == st.shape, f"found implicit movement op {x.op} {sts[x].shape} != {op.op} {st.shape}"
    sts[op] = st
  for i, out in enumerate(ast.src):
    # outputs must be STOREs into consecutive buffer indices, all with the same size
    assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
    assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
    assert out.arg.st.size == ast.src[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
    dfs(out, out.arg.st)
  return sts