uops.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. from __future__ import annotations
  2. from typing import Optional, Tuple, Any, Set, cast, List, Union
  3. import functools
  4. from enum import Enum, auto
  5. from dataclasses import dataclass
  6. from tinygrad.dtype import ConstType, dtypes, DType
  7. from tinygrad.shape.symbolic import sint, Variable
  8. from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
  9. from tinygrad.helpers import prod
  10. # the order of these UOps controls the order of the toposort
  11. class UOps(Enum):
  12. # ops that aren't rendered
  13. SINK = auto(); VAR = auto(); EXPAND = auto(); CONTRACT = auto() # noqa: E702
  14. DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
  15. CONST = auto(); SPECIAL = auto() # noqa: E702
  16. NOOP = auto(); UNMUL = auto(); GEP = auto() # noqa: E702
  17. # math ops
  18. CAST = auto(); BITCAST = auto(); VECTORIZE = auto() # noqa: E702
  19. ALU = auto(); REDUCE = auto(); WMMA = auto() # noqa: E702
  20. # memory/assignment ops
  21. LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
  22. # control flow ops
  23. BARRIER = auto(); IF = auto(); RANGE = auto() # noqa: E702
  24. # these two are not graph nodes
  25. ENDRANGE = auto(); ENDIF = auto() # noqa: E702
  26. END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
  27. def ufix(dtype: Optional[DType], x): return UOp.const(dtype, x) if not isinstance(x, UOp) else x
  28. @dataclass(frozen=True, eq=False)
  29. class UOp:
  30. op: UOps
  31. dtype: Optional[DType] = None
  32. src: Tuple[UOp, ...] = tuple()
  33. arg: Any = None
  34. def commutative(self) -> bool:
  35. return self.op is UOps.ALU and \
  36. self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
  37. @functools.cached_property
  38. def cmp_tuple(self):
  39. # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
  40. return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
  41. self.arg.value, self.dtype, self.src)
  42. def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
  43. def __repr__(self):
  44. return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}"
  45. def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,))
  46. def bitcast(self, dtype=None): return UOp(UOps.BITCAST, dtype, (self,))
  47. def name(self, name:Optional[str]): return UOp(UOps.VAR, src=(self,), arg=name)
  48. def __neg__(self): return UOp.alu(UnaryOps.NEG, self)
  49. def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, ufix(self.dtype, x))
  50. def __radd__(self, x): return UOp.alu(BinaryOps.ADD, ufix(self.dtype, x), self)
  51. def __sub__(self, x): return UOp.alu(BinaryOps.ADD, self, -ufix(self.dtype, x))
  52. def __mul__(self, x): return UOp.alu(BinaryOps.MUL, self, ufix(self.dtype, x))
  53. def __rmul__(self, x): return UOp.alu(BinaryOps.MUL, ufix(self.dtype, x), self)
  54. def __floordiv__(self, x): return UOp.alu(BinaryOps.IDIV, self, ufix(self.dtype, x))
  55. def __truediv__(self, x): return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x)))
  56. def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
  57. def __xor__(self, x): return UOp.alu(BinaryOps.XOR, self, ufix(self.dtype, x))
  58. def __and__(self, x): return UOp.alu(BinaryOps.AND, self, ufix(self.dtype, x))
  59. def __or__(self, x): return UOp.alu(BinaryOps.OR, self, ufix(self.dtype, x))
  60. def ne(self, x): return UOp.alu(BinaryOps.CMPNE, self, ufix(self.dtype, x))
  61. def eq(self, x): return -self.ne(x)
  62. def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
  63. def ge(self, x): return -self.lt(x)
  64. def max(self, x): return UOp.alu(BinaryOps.MAX, self, x)
  65. def min(self, x): return -UOp.alu(BinaryOps.MAX, -self, -x)
  66. def where(self, x, y): return UOp.alu(TernaryOps.WHERE, self, x, y)
  67. def recip(self): return UOp.alu(UnaryOps.RECIP, self)
  68. def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b)
  69. @staticmethod
  70. @functools.lru_cache(maxsize=None)
  71. def _const(dtype:Optional[DType], b:ConstType|Variable):
  72. if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
  73. return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
  74. @staticmethod
  75. def alu(arg, *src:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype, src, arg)
  76. @staticmethod
  77. def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
  78. @staticmethod
  79. def store(*src:UOp, **kwargs): return UOp(UOps.STORE, None, tuple(src)+tuple(kwargs.values()))
  80. @staticmethod
  81. def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name)
  82. @staticmethod
  83. def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.CONST, dtype=dtype).name(name)
  84. @functools.cached_property
  85. def parents(self) -> Set[UOp]: return set.union(set(self.src), *[x.parents for x in self.src])
  86. @property # parents with self
  87. def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
  88. def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.op is UOps.DEFINE_VAR])
  89. def divides(self, v):
  90. if self.op is UOps.CONST:
  91. return self.arg%v == 0
  92. if self.op is UOps.ALU:
  93. if self.arg is BinaryOps.ADD: return all(x.divides(v) for x in self.src)
  94. if self.arg is BinaryOps.MUL: return any(x.divides(v) for x in self.src)
  95. return False # generic false if we aren't sure
  96. def type_verify(uops):
  97. for u in uops:
  98. uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
  99. if uop in {UOps.CONST, UOps.DEFINE_ACC}:
  100. if uop is UOps.DEFINE_ACC:
  101. assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
  102. arg = src[0].arg
  103. assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
  104. if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
  105. if uop is UOps.CAST: assert dtype.count == 1 and len(src) == dtype.count
  106. if uop is UOps.VECTORIZE:
  107. assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
  108. assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
  109. if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
  110. if uop is UOps.STORE:
  111. assert dtype is None, f"{uop} dtype must be None, got {dtype}"
  112. if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
  113. if uop is UOps.ALU:
  114. if arg in UnaryOps:
  115. assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
  116. elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
  117. assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
  118. assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
  119. elif arg is BinaryOps.IDIV:
  120. assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
  121. f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
  122. assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
  123. elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
  124. # the distance to shift isn't typechecked
  125. assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
  126. elif arg in BinaryOps:
  127. assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
  128. elif arg == TernaryOps.WHERE:
  129. assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
  130. assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
  131. def uop_alu_resolve(u:UOp) -> sint:
  132. if u.op is UOps.SPECIAL: return u.arg[2]-1
  133. if u.op in {UOps.CONST, UOps.DEFINE_VAR}: return u.arg
  134. if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
  135. raise RuntimeError(f"ALU resolve fail @ {u.op}")
  136. def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
  137. flops: sint = 0
  138. mem: sint = 0
  139. mults: sint = 1
  140. mult_stack: List[sint] = []
  141. dont_count: Set[UOp] = set()
  142. if ignore_indexing:
  143. for u in uops:
  144. if u.op is UOps.LOAD:
  145. dont_count = dont_count.union(u.src[1].sparents)
  146. if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
  147. elif u.op is UOps.STORE:
  148. dont_count = dont_count.union(u.src[1].sparents)
  149. if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
  150. for u in uops:
  151. if u.op is UOps.RANGE:
  152. mult_stack.append(mults)
  153. mults *= uop_alu_resolve(u.src[1])
  154. elif u.op is UOps.ENDRANGE:
  155. mults = mult_stack.pop(-1)
  156. elif u.op is UOps.LOAD:
  157. assert u.dtype is not None
  158. mem += u.dtype.itemsize * mults
  159. elif u.op is UOps.STORE:
  160. assert u.src[2].dtype is not None
  161. mem += u.src[2].dtype.itemsize * mults
  162. elif u.op is UOps.ALU and u not in dont_count:
  163. flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
  164. elif u.op is UOps.WMMA and u not in dont_count:
  165. assert u.arg[1] is not None
  166. flops += 2 * prod(u.arg[1]) // 32 * mults
  167. return flops, mem