# assembly.py -- tinygrad PTX renderer
  1. from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
  2. import struct, math
  3. from collections import defaultdict
  4. from tinygrad.helpers import DEBUG
  5. from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
  6. from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
  7. from tinygrad.codegen.uops import UOps, UOp
  8. from tinygrad.codegen.uopgraph import UOpGraph, PatternMatcher, UPat
  9. from tinygrad.renderer import Renderer, TensorCore
  10. def render_val(x, dtype):
  11. if dtypes.is_float(dtype):
  12. if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
  13. if dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
  14. return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
  15. return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
class PTXRenderer(Renderer):
  """Renders a linearized UOp graph into NVIDIA PTX assembly text."""
  device = "CUDA"
  suffix = "PTX"
  # launch-dimension limits: grid (2**31-1, 65535, 65535) and block (x, y, z)
  global_max = (2147483647, 65535, 65535)
  local_max = (1024, 1024, 64)
  shared_max = 49152  # 48 KiB of shared memory per block
  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
  # tensor cores are only enabled for sm_80 and newer (arch looks like "sm_86"; arch[3:] is the numeric part)
  def __init__(self, arch:str, device="CUDA"): self.device, self.tensor_cores = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []

  # language options
  # NOTE(review): VERSION/TARGET look like placeholders substituted by the caller -- the substitution is not visible in this file
  kernel_prefix = """.version VERSION
.target TARGET
.address_size 64
.visible .entry"""
  barrier = "bar.sync\t0;"
  gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]    # block index registers %ctaid.x/y/z
  gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]  # grid dimension registers %nctaid.x/y/z
  lid = [f'%tid.{chr(120+i)}' for i in range(3)]      # thread index registers %tid.x/y/z
  # ALU op -> lambda producing one PTX instruction; `name` is the PTX type suffix
  # (e.g. "f32", "s32", "pred"), `dt` the tinygrad dtype. Booleans live in predicate
  # registers, hence the special "pred" forms.
  asm_for_op: Dict[Op, Callable] = {
    UnaryOps.NEG: lambda d,a,dt,name:
      f"not.pred {d}, {a};" if name == "pred" else f"sub.{name} {d}, 0, {a};" if dtypes.is_unsigned(dt) else f"neg.{name} {d}, {a};",
    UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
    UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
    UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
    BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
    BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
    BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
    BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
    BinaryOps.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
    BinaryOps.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
    BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
    BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
    BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
    TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
    TernaryOps.WHERE: lambda d,a,b,c,dt,name:
      f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
  }
  # ops that have native f16 instructions; everything else is rewritten to f32 by ptx_matcher
  supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
  # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
  # register types (what values live in while computing)
  types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
                              dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
                              dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }

  # memory types (what ld/st use); narrow types keep their true width in memory
  mem_types: Dict[DType, str] = types.copy()
  mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})

  # dtypes whose constants cannot appear as inline immediates and need a mov into a register
  const_requires_mov: List[DType] = [dtypes.half, dtypes.bool]

  def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
    """Render constant `x`: instruction list when `mov` names a destination register, else the bare literal."""
    val = render_val(x, dtype)
    # bool constants become predicates via a compare against zero
    if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
    return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val

  def render_local(self, dest, name, size, dtype) -> List[str]:
    """Declare a shared-memory buffer and move its base address into `dest`."""
    return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]

  # loop entry: initialize the index register and drop the loop label
  def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]

  # branch to label `b1`, optionally guarded by predicate `pred`
  def render_bra(self, b1, pred=None) -> List[str]: return [f"@{pred} bra {b1};"] if pred else [f"bra {b1};"]

  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
    """Load from [loc+offset]; when `gate` is set, take `alt` instead on a false predicate."""
    assert dtype != dtypes.bool
    if gate: return [f"@{gate} ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
    return [f"ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];"]

  def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
    """Store `val` to [loc+offset], optionally predicated by `gate`."""
    return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_types[dtype]} [{loc}+{offset}], {val};"]

  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
    """Convert register `a` (dtype `atype`) into `d` (dtype `dtype`)."""
    # bitcast: pure register move, no conversion
    if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
    # from bool: select between rendered 1 and 0 on the predicate
    if atype == dtypes.bool: return[f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
    # to bool: compare against zero
    if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
    # rounding modifier: .rzi for float->int truncation, .rn where cvt requires an explicit float rounding mode
    rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
           '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
    return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]

  def render_kernel(self, kernel, function_name, bufs, regs) -> str:
    """Assemble the full PTX text: prologue, .param list, .reg declarations, body, ret."""
    # regs is (prefix, count) pairs from ssa(); the PTX type is embedded in the prefix name
    kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
    # tab-align operands; lines starting with "$" are emitted verbatim
    def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
    return (f"{self.kernel_prefix} {function_name}(\n\t" +
            ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
            '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
            "\n}")

  def render(self, name:str, uops:UOpGraph) -> str:
    """Lower the UOp graph to PTX, one instruction group per UOp."""
    kernel:List[str] = []   # emitted instruction strings, in order
    bufs = []               # (param name, dtype) for the kernel signature

    # apply the PTX-specific rewrite rules before lowering
    uops.linearize(ptx_matcher)
    if DEBUG >= 4: uops.print()

    # append one or more instruction lines to the kernel
    def kk(*s: str): kernel.append("\n".join(s))

    c: DefaultDict[str, int] = defaultdict(int)    # per-prefix register counters (consumed by render_kernel)
    r: Dict[UOp, Union[List[str], str]] = {}       # UOp -> register name(s) holding its value
    def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
      # allocate a fresh virtual register "%<prefix>_<type>_<n>", optionally recording it for `u`
      nonlocal c, r
      prefix += f"_{dtype if dtype is not None else self.types[cast(DType, cast(UOp, u).dtype)]}_"
      c[prefix] += 1
      if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
      return f"%{prefix}{c[prefix]-1}"

    def const(x:ConstType, dtype:DType, mov=False):
      # materialize a constant into a register when required (or requested), else return the literal
      if mov or dtype in self.const_requires_mov:
        kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
        return out
      return self.render_const(x, dtype)

    def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
      # no-op when the types already match (or source is a pointer); otherwise emit a cvt into a fresh register
      if atype == dtype or isinstance(atype, PtrDType):
        if u: r[u] = a
        return a
      kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
      return ret

    for u in uops:
      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
      if uop is UOps.IF:
        assert src[0].dtype is not None
        # NOTE(review): this branches to the matching ENDIF label when the predicate is TRUE,
        # which assumes the gate arrives already negated upstream -- verify against the scheduler
        kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
      elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
      elif uop is UOps.ENDRANGE:
        # increment the loop index, then loop back while idx < bound
        kk(self.asm_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
           self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
        kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
      elif uop is UOps.ENDIF:
        # drop the label the matching IF branched to
        kk(f"IF_{r[src[0].src[0]][1:]}_{cast(List, uops._uops).index(src[0])}:")
      elif uop is UOps.STORE:
        # src = (address:int64, const offset, value[, gate]); ptx_matcher guarantees this shape
        assert src[0].dtype is not None and src[2].dtype is not None
        assert src[0].dtype == dtypes.int64, "store isn't int64"
        assert src[1].op is UOps.CONST, f"store isn't const {u}"
        mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
        if src[2].dtype.count > 1:
          # vectorized store: st.vN with a brace-enclosed register list
          kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
             f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
        else:
          kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
      else:
        # every remaining op produces a value, so it must carry a dtype
        assert dtype is not None, f"None dtype for uop {uop}"
        if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
        elif uop is UOps.ALU:
          assert src[0].dtype is not None
          if args is BinaryOps.CMPLT or args is BinaryOps.CMPNE:
            # pass in the other dtype here
            # (comparisons produce a pred but operate at the operand dtype)
            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], src[0].dtype, self.types[src[0].dtype]))
          else:
            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], dtype, self.types[dtype]))
        elif uop is UOps.DEFINE_ACC:
          if dtype.count > 1:
            # vector accumulator: one register per lane, all initialized to the same constant
            r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
            for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
          else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
        elif uop is UOps.SPECIAL:
          # args = (axis, name); name starts with 'g' for grid index, 'l' for local/thread index
          assert args[1][0] != "i", "idx not supported"
          kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
          r[u] = "%" + args[1]
          # this register is declared explicitly, prepended ahead of the body
          kernel = [f".reg .u32 %{args[1]};"] + kernel
        elif uop is UOps.CONST:
          if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
          else: r[u] = const(args, dtype, mov=True)
        elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]  # pick one lane's register out of a vector
        elif uop is UOps.LOAD:
          assert src[0].dtype == dtypes.int64, "load isn't int64"
          assert src[1].op is UOps.CONST, f"load isn't const {u}"
          mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
          has_gate = len(src) > 3 and src[2].op is UOps.ALU
          if dtype.count > 1:
            # vectorized load: pre-zero the lanes when gated, then a single predicated ld.vN
            r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
            if has_gate:
              for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
            kk((f"@{r[src[2]]}"if has_gate else "")
              + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
          else:
            kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if has_gate else None,
                                 alt=r[src[3]] if has_gate else None, ss=mem_type, offset=src[1].arg))
        elif uop is UOps.PHI:
          # loop-carried value: copy the new value back into the accumulator register(s)
          if dtype.count > 1:
            for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
          else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};")
          r[u] = r[src[0]]
        elif uop in {UOps.VECTORIZE}:
          # a vector is just the list of its lanes' registers, no code emitted
          assert src[0].dtype is not None and dtype.count > 1
          r[u] = [r[x] for x in src] # type: ignore
        elif uop in {UOps.CAST, UOps.BITCAST}:
          assert src[0].dtype is not None and dtype.count == 1
          _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
        elif uop is UOps.DEFINE_LOCAL:
          # TODO: we should sum these, and fetch 0xC000 from somewhere
          assert args[1]*dtype.itemsize <= 0xC000, "too large local"
          kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
        elif uop is UOps.DEFINE_VAR:
          # scalar kernel argument: declare the param and load it into a register
          bufs.append((args.expr, dtype))
          r[u] = f"%{args.expr}"
          kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
        elif uop is UOps.DEFINE_GLOBAL:
          # buffer kernel argument: pointers load as u64
          bufs.append((nm:=f"data{args[0]}", dtype))
          r[u] = f"%{nm}"
          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
          kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
        elif uop is UOps.WMMA:
          # pack pairs of f16 registers into b32 operands, then issue one mma.sync
          wmma = []
          for vv in src[:2]:
            for i in range(0, len(r[vv]), 2):
              wmma.append(ssa("wmma", dtype="b32"))
              kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
          r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
          # NOTE: the backslash continues the f-string; the next line's leading whitespace
          # becomes spacing inside the emitted instruction (PTX is whitespace-tolerant)
          kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
 {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[src[2]])}}};')
        else: raise NotImplementedError(f"no code for {uop}")

    return self.render_kernel(kernel, name, bufs, c.items())
# powers of two up to 2**63 -- constants that can be turned into shifts
shiftable_consts = set([2**i for i in range(64)])

# PTX-specific rewrites applied before lowering (see PTXRenderer.render):
# strength reduction, f16/bool workarounds, and address canonicalization.
ptx_matcher = PatternMatcher([
  # int multiply by a power-of-two constant -> left shift
  (UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
        src=[UPat(UOps.CONST, name="const"), UPat(name="mul")]),
   lambda root, mul, const: UOp(UOps.ALU, root.dtype,
     (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None),
  # int divide by a power-of-two constant -> right shift
  (UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
        src=[UPat(UOps.CONST, name="const"), UPat(name="div")]),
   lambda root, div, const: UOp(UOps.ALU, root.dtype,
     (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None),
  # bool != bool is XOR on predicates
  (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
  # bool < y rewritten as (!x) * y (no setp.lt on predicates)
  (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
   lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
  # a + (b * c) fused into MULACC (fma.rn / mad.lo)
  (UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
   lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
  # f16 ops without native half support: compute in f32, cast back to half
  *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
     lambda x: (UOp(x.op, dtypes.float32, tuple([vv.cast(dtypes.float32) for vv in x.src]), x.arg).cast(dtypes.half)))
    for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
  # no max on predicates: widen bool MAX to uint8 and narrow back
  (UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX),
   lambda x: UOp(UOps.ALU, dtypes.uint8, tuple(s.cast(dtypes.uint8) for s in x.src), x.arg).cast(dtypes.bool)),
  # bool loads/stores go through u8 memory (see PTXRenderer.mem_types); the 4-src
  # load form also casts the alt value, gated stores cast the stored/gate value
  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
   lambda root,x,y,z,k: UOp(root.op, dtypes.uint8, (x,y,z,k.cast(dtypes.uint8))).cast(dtypes.bool)),
  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
   lambda root: UOp(root.op, dtypes.uint8, root.src, root.arg).cast(dtypes.bool)),
  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
   lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
   lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
   lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (g.cast(dtypes.uint8),), root.arg)),
  # ptr_ar (load/store)
  # canonicalize addressing to (int64 byte address, const byte offset) as render() asserts:
  # case 1: index is alu + const -> fold the const into the byte offset
  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
    UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
   lambda root, alu, const: UOp(root.op, root.dtype,
     (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
      UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
  # case 2: index is a pure constant -> base pointer plus scaled const offset
  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
    UPat(UOps.CONST, name="const"))),
   lambda root, const: UOp(root.op, root.dtype,
     (root.src[0].cast(dtypes.int64),
      UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])),
  # case 3: any other index expression -> scale it into the address, offset 0
  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
    UPat(name="alu"))), # no const here
   lambda root, alu: UOp(root.op, root.dtype,
     (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
      UOp.const(dtypes.int64, 0))+root.src[2:])),
])