# uopgraph.py
  1. from __future__ import annotations
  2. from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TYPE_CHECKING
  3. import functools, itertools, heapq, math
  4. from collections import defaultdict
  5. from tinygrad.dtype import dtypes, DType, PtrDType, ImageDType
  6. from tinygrad.shape.symbolic import Variable
  7. from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu
  8. from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI
  9. from tinygrad.codegen.uops import UOp, UOps, END_FOR_UOP, type_verify
  10. from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
  11. if TYPE_CHECKING:
  12. from tinygrad.renderer import Renderer
  13. # *** simplification logic ***
class UPat:
  """A pattern node for matching UOp trees: every field left as None acts as a wildcard."""
  def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,
               name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False):
    # normalize op/dtype to tuples so _match can test membership with `in`
    self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,))
    self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,))
    self.arg = arg
    self.src: Any = None
    if isinstance(src, list):
      # try all permutations if it's a list
      self.src = list(itertools.permutations(src))
    elif isinstance(src, tuple):
      # only one if it's a tuple
      self.src = [src]
    elif isinstance(src, UPat):
      # repeat if it's a UPat
      self.src = [itertools.repeat(src)]
    self.name: Optional[str] = name
    # 0 means "any number of sources is acceptable"
    self.allowed_len: int = 0 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
  @staticmethod
  def compile(u: UOp, name:Optional[str]=None) -> UPat:
    """Compile a template UOp into a UPat; UOps.VAR nodes become named wildcards."""
    if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
    # commutative ops match sources in any order (list -> permutations)
    return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None,
                name, u.dtype, allow_any_len=(isinstance(name, str) and 'allow_any_len' in name))
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
  """Return a list of capture dicts, one per way `pat` matches `uop` (empty list = no match)."""
  # a name captured earlier must bind to the same UOp object (identity, not equality)
  if (pat.name is not None and store.setdefault(pat.name, uop) is not uop) or \
     (pat.dtype is not None and uop.dtype is not None and uop.dtype not in pat.dtype) or \
     (pat.arg is not None and pat.arg != uop.arg) or \
     (pat.op is not None and uop.op not in pat.op): return []
  if pat.src is None: return [store]
  res: List[Dict[str, UOp]] = []
  # pat.src is a list of candidate source orderings (permutations for commutative ops)
  for vp in pat.src:
    if pat.allowed_len != 0 and len(uop.src) != pat.allowed_len: return []
    new_stores = [store.copy()]
    # each source must match; captures accumulate across siblings
    for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)]
    res.extend(new_stores)
  return res
class PatternMatcher:
  """Holds (pattern, rewrite_fn) rules and applies the first one whose pattern matches."""
  def __init__(self, patterns:List[Tuple[Union[UPat, UOp], Callable]]):
    self.patterns = patterns
    # rules indexed by (op, arg) so rewrite() only scans plausible candidates
    self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
    # uop is required, arg is optional
    for p,fxn in self.patterns:
      if isinstance(p, UOp): p = UPat.compile(p)
      assert p.op is not None
      for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
  def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
  def rewrite(self, uop:UOp) -> Optional[UOp]:
    """Try every rule for this uop; return the first non-None rewrite, else None."""
    # exact-arg rules first, then arg-wildcard rules
    for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
      if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
    return None
  65. # ***** image handling *****
  66. def fix_image_idx(ls:UOp):
  67. if ls.src[1].dtype is None or ls.src[1].dtype.count != 1: return None
  68. if not isinstance(ls.src[0].dtype, ImageDType): return None
  69. assert ls.op is not UOps.STORE or cast(DType, ls.src[2].dtype).count == 4, "image store must be float4"
  70. idxy = ls.src[1]
  71. #if not idxy.divides(4): raise RuntimeError("image index must divide 4")
  72. base_shape = ls.src[0].dtype.shape
  73. idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
  74. image_idx = UOp(UOps.VECTORIZE, cast(DType, idxy.dtype).vec(2), (idx, idy))
  75. if ls.op is UOps.LOAD and cast(DType, ls.dtype).count == 1:
  76. cconst = (UOp(UOps.VECTORIZE, cast(DType, ls.dtype).vec(4), src=(ls.src[3], ls.src[3], ls.src[3], ls.src[3])),) if len(ls.src) >= 3 else ()
  77. loaded = UOp(ls.op, cast(DType, ls.dtype).vec(4), (ls.src[0], image_idx) + ls.src[2:3] + cconst, ls.arg)
  78. subidx = idxy%4
  79. ret = UOp.const(ls.dtype, 0)
  80. for i in range(4): ret = UOp.alu(TernaryOps.WHERE, subidx.ne(i), ret, UOp(UOps.GEP, ls.dtype, (loaded,), i))
  81. return ret
  82. return UOp(ls.op, ls.dtype, (ls.src[0], image_idx) + ls.src[2:], ls.arg)
  83. # ***** float4 handling *****
def float4_expand_load(load, buf, ex, idx=UOp.const(dtypes.int, 0), idx2=None):
  """Turn a scalar float LOAD under a 4-wide EXPAND into one float4 LOAD plus per-lane GEPs."""
  if len(ex.src) != 4: return None
  # EXPAND sources must be the consts 0..3, i.e. contiguous lanes
  if tuple(x.arg for x in ex.src if x.op is UOps.CONST) != tuple(range(len(ex.src))): return None
  if buf.dtype != PtrDType(dtypes.float) and not isinstance(buf.dtype, ImageDType): return None
  if idx2 is not None: idx = idx + idx2
  # base index must be aligned to the vector width
  if not idx.divides(len(ex.src)): return None
  if load.dtype.scalar() != load.dtype: return None # how does this happen?
  vec_load = UOp(UOps.LOAD, load.dtype.vec(len(ex.src)), (buf, idx))
  return UOp(UOps.EXPAND, load.dtype, tuple(UOp(UOps.GEP, load.dtype, (vec_load,), i) for i in range(len(ex.src))), ex.arg)
def float4_contract_store(buf, ex, var, store_allow_any_len, idx=UOp.const(dtypes.int, 0), idx2=None, idx3=None):
  """Turn a scalar STORE under a 2- or 4-wide EXPAND into one vectorized STORE via CONTRACT."""
  if len(ex.src) not in [2, 4]: return None
  # EXPAND sources must be the consts 0..n-1, i.e. contiguous lanes
  if tuple(x.arg for x in ex.src if x.op is UOps.CONST) != tuple(range(len(ex.src))): return None
  if buf.dtype != PtrDType(dtypes.float) and not isinstance(buf.dtype, ImageDType): return None
  if idx2 is not None: idx = idx + idx2
  if idx3 is not None: idx = idx + idx3
  # base index must be aligned to the vector width
  if not idx.divides(len(ex.src)): return None
  new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), (ex.arg[0][0],))
  # keep any trailing sources (e.g. gate) from the original store
  return UOp(UOps.STORE, None, (buf, idx, new_var) + store_allow_any_len.src[3:])
def no_float4_alu(alu):
  """Split a vector ALU into per-lane scalar ALUs, recombined with VECTORIZE."""
  if alu.dtype.count == 1: return None
  alus = tuple(UOp(UOps.ALU, alu.dtype.scalar(),
                   tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), i) for s in alu.src), alu.arg) for i in range(alu.dtype.count))
  return UOp(UOps.VECTORIZE, alu.dtype, alus)
# rules that vectorize scalar float loads/stores under EXPANDs into float(2,4) accesses
float4_folding = PatternMatcher([
  # reassociate idx + (ex + idx2) -> (idx + idx2) + ex so the EXPAND ends up outermost
  (UOp(UOps.STORE, dtype=dtypes.float, src=(UOp.var("buf"), UOp.var("idx")+
    (UOp(UOps.EXPAND, src=tuple(UOp.const(dtypes.int, i) for i in range(4))).name("ex")+UOp.var("idx2")), UOp.var("var"))).name("store"),
   lambda buf, store, idx, idx2, ex, var: UOp(UOps.STORE, store.dtype, (buf, idx+idx2+ex, var), store.arg)),
  # float(2,4) load
  (UOp(UOps.LOAD, dtype=dtypes.float, src=(UOp.var("buf"),
    UOp(UOps.EXPAND).name("ex")+UOp.var("idx")+UOp.var("idx2"))).name("load"),
   float4_expand_load),
  (UOp(UOps.LOAD, dtype=dtypes.float, src=(UOp.var("buf"),
    UOp(UOps.EXPAND).name("ex")+UOp.var("idx"))).name("load"), float4_expand_load),
  (UOp(UOps.LOAD, dtype=dtypes.float, src=(UOp.var("buf"),
    UOp(UOps.EXPAND).name("ex"))).name("load"), float4_expand_load),
  # float(2,4) store
  # TODO: fold ADDs into one UOp and remove add chains
  (UOp(UOps.STORE, src=(UOp.var("buf"),
    UOp(UOps.EXPAND).name("ex")+UOp.var("idx")+UOp.var("idx2")+UOp.var("idx3"), UOp.var("var"))).name("store_allow_any_len"),
   float4_contract_store),
  (UOp(UOps.STORE, src=(UOp.var("buf"),
    UOp(UOps.EXPAND).name("ex")+UOp.var("idx")+UOp.var("idx2"), UOp.var("var"))).name("store_allow_any_len"),
   float4_contract_store),
  (UOp(UOps.STORE, src=(UOp.var("buf"),
    UOp(UOps.EXPAND).name("ex")+UOp.var("idx"), UOp.var("var"))).name("store_allow_any_len"), float4_contract_store),
  (UOp(UOps.STORE, src=(UOp.var("buf"),
    UOp(UOps.EXPAND).name("ex"), UOp.var("var"))).name("store_allow_any_len"), float4_contract_store),
  # no ALU on float4 (float4 constructor doesn't work in METAL/GPU)
  (UOp(UOps.ALU).name("alu"), no_float4_alu),
])
  134. # ***** transcendental *****
# rewrite EXP2/LOG2/SIN into polynomial approximations for devices without native support
transcendental_folding = PatternMatcher([
  (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="x"),), arg=UnaryOps.EXP2), xexp2),
  (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=UnaryOps.LOG2), xlog2),
  (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=UnaryOps.SIN), xsin),
])
  140. # ***** threefry *****
def threefry2x32(x: UOp, seed: UOp):
  """Build the Threefry-2x32 counter-based PRNG round function as UOp arithmetic on a uint64 counter."""
  # split x into two uint32, since x in a uint64
  x0, x1 = (x & 0xffffffff).cast(dtypes.uint32), ((x // 2**32) & 0xffffffff).cast(dtypes.uint32)
  rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
  # key schedule: middle key is seed ^ parity constant 0x1BD11BDA
  ks = [0x0, (seed := seed.cast(dtypes.uint32)) ^ 0x1BD11BDA, seed]
  xr = [x0 + ks[-1], x1 + ks[0]]
  for i in range(5):
    # rotate-left by r is emulated as (v << r) | (v >> (32-r)) using mul/div by powers of two
    for r in rotations[i % 2]: xr[0], xr[1] = (x0 := xr[0] + xr[1]), x0 ^ ((xr[1] * 2**r) + (xr[1] // 2**(32 - r)))
    xr = [(xr[0] + ks[i % 3]), (xr[1] + ks[(i + 1) % 3] + i + 1)]
  # recombine the two uint32 halves into one uint64
  return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)
  151. # ***** main rewriter *****
def reduce_before_expand(reduce_allow_any_len, expand, x):
  # REDUCE(EXPAND(GEP(x))) -> EXPAND(GEP(REDUCE(x))): reduce the vector once, then split lanes
  red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce_allow_any_len.src[1:], reduce_allow_any_len.arg)
  gep = tuple(UOp(UOps.GEP, reduce_allow_any_len.dtype, (red,), i) for i in range(x.dtype.count))
  return UOp(expand.op, expand.dtype, gep, expand.arg)
def sum_collapse(phi_input, loop, val1, val2):
  """Summing a loop-invariant value over a loop collapses to value * trip_count."""
  # the ADD is commutative, so try the invariant operand on either side
  for v1,v2 in [(val1, val2), (val2, val1)]:
    if loop not in v1.parents:
      loop_range = loop.src[1]-loop.src[0]
      ret = v1*loop_range.cast(v1.dtype)
      return UOp(UOps.PHI, phi_input.dtype, (phi_input, v2))+ret
  return None
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng):
  """Fold the arange pattern where(idx + mval*range < compval, multconst, 0) into clamp * multconst under UNMUL."""
  if getenv("DISABLE_LOOP_COLLAPSE") or not rng.arg[1]: return None # must be a REDUCE
  # only handles negative mval with a zero loop start
  if mval.arg >= 0 or loop_start.arg != 0:
    # TODO: support and test this with other mvals and loop_starts
    if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
    return None
  # count of loop iterations where the comparison holds, clamped to [loop_start, loop_end]
  comprange = UOp.min(loop_end, UOp.max(UOp.alu(BinaryOps.IDIV, idx-compval-mval, mval) + (loop_end-loop_start), loop_start))
  return UOp(UOps.UNMUL, multconst.dtype, (comprange.cast(multconst.dtype) * multconst, loop_end-loop_start))
# this is symbolic 2.0
# the main rewrite rule set: constant folding, algebraic identities, and graph cleanups.
# NOTE: rule order matters; PatternMatcher tries exact-arg rules before arg-wildcard rules.
constant_folder = PatternMatcher([
  # CONTRACT before REDUCE
  (UPat(UOps.CONTRACT, name="con", src=UPat(UOps.REDUCE, name="red")),
   lambda con, red: UOp(UOps.REDUCE, con.dtype, (UOp(UOps.CONTRACT, con.dtype, red.src[0:1], con.arg),)+red.src[1:], red.arg)),
  # bigint is rewritten to int32
  (UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND}, dtype=dtypes.bigint, name="x"),
   lambda x: UOp(x.op, dtypes.int32, x.src, x.arg)),
  # VECTORIZE/GEP
  (UOp(UOps.GEP, src=(UOp(UOps.VECTORIZE).name("cast"),)).name("gep"), lambda gep, cast: cast.src[gep.arg]),
  # VECTORIZE of GEPs of the same vector in lane order is that vector
  *[(UOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=j)
    for j in range(i))), lambda x: x) for i in [2, 4, 8]],
  # tensor core with a 0 input is acc
  (UOp(UOps.WMMA, src=(UOp.const(None, 0.0), UOp.var(), UOp.var('acc'))), lambda acc: acc),
  (UOp(UOps.WMMA, src=(UOp.var(), UOp.const(None, 0.0), UOp.var('acc'))), lambda acc: acc),
  # tensor core cleanups
  (UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(2))).name("expand"),))
   .name("reduce_allow_any_len"), reduce_before_expand),
  (UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(8))).name("expand"),))
   .name("reduce_allow_any_len"), reduce_before_expand),
  # fold an ADD into the WMMA accumulator
  (UOp.var("add") + UOp(UOps.WMMA).name("wmma"),
   lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
  # threefry
  (UOp(UOps.ALU, dtype=dtypes.uint64, src=(UOp.var("x"), UOp.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
  # arange loop folding (early)
  (UOp.where(UOp.alu(BinaryOps.CMPLT, UOp.alu(BinaryOps.ADD, UOp.var("idx"), UOp.alu(BinaryOps.MUL,
    UOp.cvar("mval"), UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng"))),
    UOp.cvar("compval")), UOp.cvar("multconst"), UOp.const(None,0)), loop_collapse),
  # same pattern with NEG instead of MUL by a const: treat it as mval = -1
  (UOp.where(UOp.alu(BinaryOps.CMPLT, UOp.alu(BinaryOps.ADD, UOp.var("idx"), UOp.alu(UnaryOps.NEG,
    UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng"))),
    UOp.cvar("compval")), UOp.cvar("multconst"), UOp.const(None, 0)),
   lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
  # sum collapse to mul (with possible GEP)
  (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
    UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
  (UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)),
    UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
  # deal with UNMUL
  (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
   lambda c1,c2,v: v if c1.arg == c2.arg else None),
  (UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
  # push CAST through UNMUL
  (UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))),
  # indexing (with a multiply offset)!
  (UOp.var('idx').eq(UOp(UOps.RANGE).name("rng")).where(
    UOp(UOps.LOAD, src=(UOp.var("buf"), UOp.var('add')+UOp.var('mul')*UOp(UOps.RANGE).name("rng"))).name("ld"), UOp.const(None, 0.0)),
   lambda idx,rng,buf,add,mul,ld: UOp(UOps.UNMUL, ld.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx)), rng.src[1]-rng.src[0]))),
  # other arange folders
  (UOp.cvar("c1") - (UOp.var("x") + UOp.cvar("c2")), lambda c1, c2, x: (c1-c2)-x), # c1 - (x + c2) -> (c1-c2) - x
  # max on special can go away (TODO: special should be variable, same thing applies)
  (UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')), lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
  (UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')+UOp.cvar('c2')), lambda c,s,c2: (s+c2) if 0 >= c.arg else None), # TODO: generic
  (UOp.max(UOp.cvar('c'), -(UOp(UOps.SPECIAL).name('s')+UOp.cvar('c2'))), lambda c,s,c2: -(s+c2) if -(s.arg[2]-1+c2.arg) >= c.arg else None),
  # max on range can go away (ugh: copy of SPECIAL, and with/without const)
  (UOp.max(UOp.cvar('c'), UOp(UOps.RANGE).name('s')), lambda c,s: s if s.src[0].arg >= c.arg else None), # TODO: generic
  (UOp.max(UOp.cvar('c'), UOp(UOps.RANGE).name('s')+UOp.cvar('c2')), lambda c,s,c2: (s+c2) if s.src[0].arg >= c.arg else None), # TODO: generic
  (UOp.max(UOp.cvar('c'), -(UOp(UOps.RANGE).name('s'))), lambda c,s: -s if -(s.src[1].arg-1) >= c.arg else None),
  (UOp.max(UOp.cvar('c'), -(UOp(UOps.RANGE).name('s')+UOp.cvar('c2'))), lambda c,s,c2: -(s+c2) if -(s.src[1].arg-1+c2.arg) >= c.arg else None),
  # const rules
  (UOp(UOps.GEP, src=(UOp.cvar("c"),)).name("root"), lambda root, c: UOp.const(root.dtype, c.arg)),
  (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
  (UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
  # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
  (UOp(UOps.PHI, src=(UOp(UOps.DEFINE_ACC).name("acc"), UOp.var("acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
  (UOp(UOps.PHI, src=(UOp(UOps.DEFINE_ACC, src=(UOp.cvar(),)), UOp.var("x"))), lambda x: x),
  (UOp(UOps.PHI, src=(UOp.cvar(), UOp.var("x"))), lambda x: x),
  # a DEFINE_ACC without inputs is a const + GEP on a const is the const
  (UOp(UOps.DEFINE_ACC, src=(UOp.cvar(),)).name("root"), lambda root: UOp.cast(root.src[0], root.dtype)),
  # NOTE(review): this looks like a duplicate of the GEP-on-const rule in "const rules" above
  (UOp(UOps.GEP, src=(UOp.cvar("x"),)).name("root"), lambda root,x: UOp.const(root.dtype, x.arg)),
  # max -2147483648
  (UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
  # bool < False is always false, True < bool is always false
  (UOp.var().lt(UOp.const(dtypes.bool, False)), lambda: UOp.const(dtypes.bool, False)),
  (UOp.const(dtypes.bool, True).lt(UOp.var()), lambda: UOp.const(dtypes.bool, False)),
  # a conditional with the same results either way is a noop, also fold const conditionals
  (UOp.var().where(UOp.var("val"), UOp.var("val")), lambda val: val),
  (UOp.cvar('gate').where(UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
  # ** constant folding **
  (UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
  # ** self folding **
  (-(-UOp.var('x')), lambda x: x), # -(-x) -> x
  (UOp.var('x') + 0, lambda x: x), # x+0 -> x
  (UOp.var('x') * 1, lambda x: x), # x*1 -> x
  (UOp.var('x') * -1, lambda x: -x), # x*-1 -> -x
  (UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1
  (UOp.var('x') // 1, lambda x: x), # x//1 -> x
  (UOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
  (UOp.var('x') / UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x/x -> 1
  (UOp.var('x') / UOp.cvar('c'), lambda x,c: x*exec_alu(UnaryOps.RECIP, c.dtype, [c.arg])), # x/c -> x*(1/c)
  (UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x), # max(x, False) -> x
  # ** zero folding **
  #x*0 -> 0 or 0*x -> 0
  #if x is nan or inf it should render the nan value.
  # NOTE: this can be wrong for loaded NaN
  (UOp.var('x') * 0, lambda x: UOp.const(x.dtype, float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
  (UOp.var('x') - UOp.var('x'), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
  # ** load/store folding **
  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
  # ** two stage add/sub folding **
  ((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
  ((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
  # *** rules from symbolic ***
  # two stage mul, (x*c1)*c2 = x*(c1*c2)
  ((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
  # -(x+y) -> -x + -y
  #(-(UOp.var("x") + UOp.var("y")), lambda x,y: (-x)+(-y)),
  # x%1 -> 0
  (UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)),
  # (x*c0)+(x*c1) -> x*(c0+c1)
  (UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
  # (x*c0)+(y*c0) -> (x+y)*c0
  #((UOp.var("x") * UOp.cvar("c0")) + (UOp.var("y") * UOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
  # (x*c0)//c0 -> x
  ((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None),
  # (x*x2)/x2 -> x
  ((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x),
  # (x//c0)//c1 -> x//(c0*c1)
  ((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
  # (x/x1)/x2 -> x/(x1*x2)
  ((UOp.var("x") / UOp.var("x2")) / UOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
  # c0 + x < c1 -> x < c1 - c0
  ((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
   lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
  # (x+x*c0)-> x*(c0+1)
  (UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*UOp.const(x.dtype, c0.arg+1)),
  # x!=0 -> (bool)x
  (UOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
  # bool != 1 -> not bool
  (UOp.var("x", dtype=dtypes.bool).ne(1), lambda x: -x),
  # TODO: can do the invert of this (flip alt/load) when we fix double ops
  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
   lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
  # VECTORIZE-PHI-GEP -> PHI-VECTORIZE
  (UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"),
   lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
  (UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(2))).name("root"),
   lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1))))),
  # NEG/CMPLT -> CMPLT
  (UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
  # cast NOOP (NOTE: it's str to deal with PtrDType)
  (UOp(UOps.CAST).name("root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
  (UOp(UOps.VECTORIZE).name("root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
  # fold gated LOAD/STORE
  (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
  (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var"), UOp.var("barrier")),
   lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
  (UOp.load(UOp.var(), UOp.var(), UOp.const(dtypes.bool, False), UOp.cvar("var")), lambda var: var),
  (UOp.load(UOp.var(), UOp.var(), UOp.const(dtypes.bool, False), UOp.cvar("var"), UOp.var()), lambda var: var),
  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(dtypes.bool, True)), UOp.store),
  (UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
  # remove NOOPs from SINK
  (UOp(UOps.SINK).name("root"),
   lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
])
  324. # *** uop expander ***
  325. def _expand_arg_to_idx(args:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]) -> int:
  326. idx, mul = 0, 1
  327. for axis,m in args[::-1]:
  328. idx += rpk[axis] * mul
  329. mul *= m
  330. return idx
  331. def _choices_from_args(args:Tuple[Tuple[int, int], ...]) -> List[Dict[int, int]]:
  332. return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
def do_expand(root:UOp):
  """Push EXPAND sources above `root`: replicate `root` once per choice of expand axes."""
  if root.op is UOps.REDUCE:
    if root.src[0].op is not UOps.EXPAND: return None
    # axes that the REDUCE itself ranges over stay inside (dont_expand_args)
    reduce_expand_args = flatten([x.arg for x in root.src[1:] if x.op is UOps.EXPAND])
    expand_args = tuple(x for x in root.src[0].arg if x not in reduce_expand_args)
    if len(expand_args) == 0: return None
    dont_expand_args = tuple(x for x in root.src[0].arg if x in reduce_expand_args)
  else:
    expands = [x for x in root.src if x.op is UOps.EXPAND]
    if len(expands) == 0: return None
    expand_args = tuple(sorted(dedup(flatten([x.arg for x in expands]))))
    if root.op is UOps.WMMA:
      # WMMA keeps the axes named in its arg inside the op
      dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in root.arg[-2])
      expand_args = tuple(x for x in expand_args if x not in dont_expand_args)
    else:
      dont_expand_args = ()
  new_srcs: List[UOp] = []
  lrpks = _choices_from_args(dont_expand_args)
  # one copy of root per assignment of the expanded axes
  for rpk in _choices_from_args(expand_args):
    new_src: List[UOp] = []
    for src in root.src:
      if src.op is UOps.EXPAND:
        lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks]
        if len(dont_expand_args):
          # TODO: is this right for UOps.WMMA? all lnew_src should be the same
          new_src.append(lnew_src[0] if root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))
        else:
          assert len(lnew_src) == 1
          new_src.append(lnew_src[0])
      else:
        new_src.append(src)
    new_srcs.append(UOp(root.op, root.dtype, tuple(new_src), root.arg))
  if root.op is UOps.EXPAND:
    # an EXPAND of an EXPAND: merge the axis lists and reindex the flattened sources
    expand_args, old_args = tuple(sorted(root.arg+expand_args)), expand_args
    assert len(expand_args) == (len(old_args) + len(root.arg))
    new_srcs = [new_srcs[_expand_arg_to_idx(old_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)] for rpk in _choices_from_args(expand_args)]
  assert prod([x[1] for x in expand_args]) == len(new_srcs)
  return UOp(UOps.EXPAND, root.dtype, tuple(new_srcs), expand_args)
# global counter giving each DEFINE_ACC a unique id
acc_number = 0
def do_reduce_with_expand(root):
  """Lower a REDUCE into DEFINE_ACC + per-element ALU accumulation + PHI."""
  global acc_number
  expands = [x for x in root.src[1:] if x.op is UOps.EXPAND]
  # EXPANDs whose axes all appear in the reduced source are reduced over
  expands_reduce = [x for x in expands if root.src[0].op is UOps.EXPAND and all(y in root.src[0].arg for y in x.arg)]
  expands_non_reduce = [x for x in expands if x not in expands_reduce]
  # identity element: 0 for SUM, dtype min for MAX
  const = UOp.const(root.dtype.scalar(), dtypes.as_const(0, root.dtype.scalar()) if root.arg is ReduceOps.SUM else dtypes.min(root.dtype.scalar()))
  ret = acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(x for x in root.src[1:] if x.op is not UOps.EXPAND), (acc_number,))
  acc_number += 1
  if len(expands_reduce):
    assert root.src[0].op is UOps.EXPAND
    expand_reduce_args = dedup(flatten([x.arg for x in expands_reduce]))
    assert prod([y[1] for y in expand_reduce_args]) == len(root.src[0].src)
    # accumulate every expanded element into the acc
    for xx in root.src[0].src:
      ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, xx)
  else:
    ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, root.src[0])
  ret = UOp(UOps.PHI, ret.dtype, (acc, ret))
  # non-reduced EXPANDs just multiply the count of contributions
  if len(expands_non_reduce): ret = ret * prod([sz for _,sz in flatten([x.arg for x in expands_non_reduce])])
  return ret
def do_contract(con:UOp):
  """Lower a CONTRACT: turn one EXPAND axis into the vector lanes of a VECTORIZE."""
  ex = con.src[0]
  assert con.dtype is not None
  # CONTRACT without EXPAND repeats the element VECTORIZED
  if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
  # simple CONTRACT and EXPAND cancel out
  if len(ex.arg) == 1 and len(con.arg) == 1 and ex.arg[0][0] in con.arg: return UOp(UOps.VECTORIZE, con.dtype, ex.src)
  # complex CONTRACT may only remove one axis from EXPAND
  assert len(con.arg) == 1, "contract arg one is all that's supported"
  try:
    split_index = [x[0] for x in ex.arg].index(con.arg[0])
  except ValueError:
    # CONTRACT without EXPAND (still) repeats the element VECTORIZED
    return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
  assert con.dtype.count == ex.arg[split_index][1], "contract arg must match"
  # regroup the flattened expand sources so the contracted axis becomes contiguous lanes
  number_after = prod([x[1] for x in ex.arg[split_index+1:]])
  to_join = [ex.src[i:i+number_after] for i in range(0, len(ex.src), number_after)]
  srcs = []
  for i in range(0, len(to_join), con.dtype.count):
    srcs += [UOp(UOps.VECTORIZE, con.dtype, tuple(src)) for src in zip(*to_join[i:i+con.dtype.count])]
  return UOp(UOps.EXPAND, con.dtype, tuple(srcs), tuple(x for x in ex.arg if x[0] != con.arg[0]))
# rules that eliminate EXPAND/CONTRACT/REDUCE by replicating or vectorizing ops
expander = PatternMatcher([
  (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
         UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root"), do_expand),
  (UOp(UOps.REDUCE).name("root"), do_reduce_with_expand),
  (UOp(UOps.CONTRACT).name("con"), do_contract),
  # remove EXPANDs from SINK
  (UOp(UOps.SINK).name("root"),
   lambda root: UOp(UOps.SINK, root.dtype, a, root.arg)
    if len(a:=tuple(flatten(x.src if x.op is UOps.EXPAND else (x,) for x in root.src))) != len(root.src) else None),
  # BARRIERs aren't actually expanded
  (UOp(UOps.BARRIER, src=(UOp(UOps.EXPAND).name("ex"),)), lambda ex: UOp(UOps.EXPAND, None, (UOp(UOps.BARRIER, None, ex.src),)*len(ex.src), ex.arg)),
  # image indexing (needs to be here)
  (UPat({UOps.LOAD, UOps.STORE}, name="ls"), fix_image_idx),
  # empty EXPAND is NOOP
  (UOp(UOps.EXPAND, src=(UOp.var('x'),), arg=()), lambda x: x),
])
  428. # *** uop graph ***
  429. def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]):
  430. if u in children: return
  431. children[u] = []
  432. for x in u.src:
  433. get_children_dfs(x, children, in_degree)
  434. children[x].append(u)
  435. in_degree[u] = len(u.src)
  436. def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
  437. nodes: Dict[Tuple, UOp] = {}
  438. replace: Dict[UOp, UOp] = {}
  439. def __inner_rewrite(n:UOp) -> UOp:
  440. if n in replace: return replace[n]
  441. replace_source = (n.op, n.dtype, tuple(__inner_rewrite(y) for y in n.src), n.arg)
  442. if found := nodes.get(replace_source): replace[n] = found
  443. else: nodes[replace_source] = replace[n] = found = __inner_rewrite(new_x) if (new_x := pm.rewrite(x:=UOp(*replace_source))) else x
  444. return found
  445. return __inner_rewrite(sink)
class UOpGraph:
  """Holds a UOp graph rooted at a SINK and linearizes it into a flat list of uops."""
  def __init__(self, sink:Union[UOp, List[UOp]], opts:Optional[Renderer]=None):
    # accept either a prebuilt SINK UOp or a list of uops to wrap into one
    self.sink: UOp = sink if isinstance(sink, UOp) else UOp(UOps.SINK, None, tuple(sink))
    assert self.sink.op is UOps.SINK, f"sink isn't sink, it's {self.sink.op}"
    # used by linearizer; None until linearize() has run
    self._uops: Optional[List[UOp]] = None
    self.opts = opts
    # add float4 folding only when the target renderer supports it
    self.folder = constant_folder if opts is None or not opts.supports_float4 else (constant_folder+float4_folding)
    # TRANSCENDENTAL>=2 always folds transcendentals; >=1 only on the CLANG/LLVM devices
    if TRANSCENDENTAL >= 2 or (opts is not None and TRANSCENDENTAL >= 1 and opts.device in {"CLANG", "LLVM"}):
      self.folder = self.folder + transcendental_folding
# pickle support: reconstruct from the original constructor arguments
def __reduce__(self): return self.__class__, (self.sink, self.opts)
# iteration and indexing go through the (lazily linearized) uop list
def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
def __getitem__(self, index) -> UOp: return self.uops[index]
# args of all DEFINE_VAR uops, sorted by variable expression for determinism
def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.op is UOps.DEFINE_VAR], key=lambda v: v.expr)
# args of all DEFINE_GLOBAL uops, in linearized order
def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.op is UOps.DEFINE_GLOBAL]
  461. @property
  462. def uops(self) -> List[UOp]:
  463. if self._uops is None: self.linearize()
  464. return cast(List[UOp], self._uops)
def graph(self):
  # visualize the linearized uops; imported lazily so the graphing module is only loaded when used
  from tinygrad.engine.graph import graph_uops
  graph_uops(self.uops)
  468. def print(self):
  469. for i,u in enumerate(self):
  470. formatted_parents = [self.uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src]
  471. print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}")
# class-level count of linearize() calls, compared against DEBUG_EXPAND in linearize()
cnt = 0
def linearize(self, extra_pm:Optional[PatternMatcher]=None):
  """Rewrite, expand, and topologically sort self.sink into the flat list self._uops.

  extra_pm: optional extra PatternMatcher applied after expansion (used for PTX).
  Raises AssertionError (after printing/graphing the uops) if sanity checks fail.
  """
  global acc_number
  acc_number = 0
  # NOTE: relinearizing should be okay
  #assert self._uops is None, "already linearized"
  # fixup gated stores with an IF block to save extra local loads
  @functools.lru_cache(None)
  def _dfs(u:UOp, gate:UOp) -> UOp:
    # a LOAD behind a BARRIER gets the gate attached via an IF uop
    if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER:
      if_uop = UOp(UOps.IF, None, (gate, u.src[-1]))
      return UOp(u.op, u.dtype, u.src[:-1]+(if_uop,), u.arg)
    # rebuild the node only if any source actually changed
    if (replace_source:=tuple(_dfs(x, gate) for x in u.src)) != u.src: return UOp(u.op, u.dtype, replace_source, u.arg)
    return u
  sink_srcs = list(self.sink.src)
  for i, s in enumerate(sink_srcs):
    # breaks for WMMA
    if all(x.op is not UOps.WMMA for x in s.parents):
      # gated STORE (4 sources): push the gate (src[3]) down as IFs, then drop it from the STORE
      if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_dfs(s, s.src[3])) != s:
        sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src[:3], rw.arg)
  sink = UOp(UOps.SINK, None, tuple(sink_srcs))
  # do graph rewrite
  sink = graph_rewrite(sink, self.folder)
  # expand (skipped only on the linearize call whose count matches DEBUG_EXPAND)
  UOpGraph.cnt += 1
  if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0): sink = graph_rewrite(sink, expander+self.folder)
  # for PTX only
  if extra_pm: sink = graph_rewrite(sink, self.folder+extra_pm)
  # filter nodes that don't link to a sink
  # BFS toposort
  children: Dict[UOp, List[UOp]] = {}
  in_degree: Dict[UOp, int] = {}
  get_children_dfs(sink, children, in_degree)
  @functools.lru_cache(None)
  def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
    # all transitive consumers of x, stopping at SINK and at nodes whose op is `end`
    if x.op is UOps.SINK: return set()
    return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
  # scope children impact the toposort and END* insertion
  scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
  queue:List[Tuple[int, UOp]] = []
  def push(u:UOp):
    priority = 0
    # prefer uops that are loop children (lower priority value pops first)
    for l, ss in scope_children.items():
      if l.op is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
    heapq.heappush(queue, (priority, u))
  # seed the queue with all nodes that have no pending sources
  for u in children:
    if in_degree[u] == 0: push(u)
  scope_end: Dict[UOp, UOp] = {}
  self._uops = []
  while queue:
    p,x = heapq.heappop(queue)
    if DEBUG >= 7: print(p,x)
    # a scope opener with no children ends at itself
    if x in scope_children: scope_end[x] = x
    if x.op is UOps.DEFINE_ACC:
      # DEFINE_ACC is hoisted just before the earliest RANGE it depends on
      idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE])
      self._uops.insert(idx, x)
    else: self._uops.append(x)
    # once the last child of a scope is emitted, remember where that scope ends
    for u, ss in scope_children.items():
      if x in ss:
        ss.remove(x)
        if len(ss) == 0: scope_end[u] = x
    for u in children[x]:
      in_degree[u] -= 1
      if in_degree[u] == 0: push(u)
  # end scopes in toposort order
  for u, x in scope_end.items(): self._uops.insert(self._uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], None, (u,)))
  # sanity checks (NOTE: these can cause things to be skipped in BEAM)
  bad_ops = dedup([x.op for x in self._uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE, UOps.UNMUL}])
  try:
    type_verify(self.uops)
    assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
    assert len(bad_ops) == 0, f"bad UOps left in list: {bad_ops}"
    # TODO: this should be enabled, and the valid clause should be removed
    # NOTE: multiple identical stores to DEFINE_LOCAL is okay
    assert len(all_stores := [x.src[0:2]+x.src[3:] for x in self._uops if x.op is UOps.STORE and x.src[0].op is not UOps.DEFINE_LOCAL]) \
      == len(dedup(all_stores)), "repeated stores in uops"
  except AssertionError as e:
    # dump the uops (and the graph when not in CI) to make the failure debuggable
    self.print()
    if not CI: self.graph()
    raise e
  # strip the SINK
  self._uops = self._uops[:-1]
  if getenv("FUZZ_UOPS"):
    from test.external.fuzz_uops import fuzz_uops
    self._fuzz_paths = fuzz_uops(self)