kernel.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769
  1. from __future__ import annotations
  2. import itertools, functools
  3. from dataclasses import replace
  4. from collections import defaultdict
  5. from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict
  6. from tinygrad.engine.graph import print_tree
  7. from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, MetaOps, UNSAFE_PAD_OPS, \
  8. verify_lazyop, KernelInfo, get_lazyop_info
  9. from tinygrad.device import Device
  10. from tinygrad.renderer import Renderer, TensorCore, Program
  11. from tinygrad.dtype import dtypes, ImageDType
  12. from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, round_up, all_int, \
  13. get_contraction, to_function_name, diskcache_put, ContextVar
  14. from tinygrad.shape.shapetracker import ShapeTracker
  15. from tinygrad.shape.symbolic import sint
  16. from tinygrad.shape.view import strides_for_shape
  17. from tinygrad.codegen.uops import UOps, flops_mem
  18. from tinygrad.codegen.uopgraph import UOpGraph
  19. from tinygrad.codegen.lowerer import lazyop_to_uop
  20. from dataclasses import dataclass
  21. from enum import Enum, auto
  22. class OptOps(Enum):
  23. TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
  24. GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); MERGE = auto() # noqa: E702
  25. def __lt__(self, x:OptOps): return self.value < x.value
  26. class KernelOptError(Exception): pass
  27. def check(cond:bool, msg:str=""):
  28. if not cond: raise KernelOptError(msg)
  29. @dataclass(frozen=True, order=True)
  30. class Opt:
  31. op: OptOps
  32. axis: Optional[int] = None
  33. amt: Optional[int] = None
  34. def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
  35. def real_axis(self, k:Kernel):
  36. if self.axis is None: return -1
  37. if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
  38. if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
  39. return self.axis
  40. @dataclass
  41. class TensorCoreOptions:
  42. axes: Tuple[int, ...] # the location of the original N and M axes if still in the shape
  43. axes_exist: Tuple[bool, ...] # true if the original N and M axes are still in the shape
  44. axis_pads: Tuple[Tuple[int, int], ...]
  45. def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when a dimension is removed
  46. axes, axes_exist = list(self.axes), list(self.axes_exist)
  47. for tc_dim in [i for i in range(2) if axes_exist[i]]:
  48. if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
  49. elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
  50. self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)
class Kernel:
  def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
    """Build a Kernel around one or more STORE LazyOps (or a single pre-wrapped MetaOps.KERNEL op).

    Verifies the AST, collects buffers and reduce ops, copies their ShapeTrackers
    (later optimizations reshape/permute these), moves all reduce axes to the end
    of the shape, and initializes the optimization state.
    """
    if len(ast) > 1 or ast[0].op is BufferOps.STORE:
      # a tuple of STOREs gets wrapped into a single MetaOps.KERNEL root
      assert all(x.op is BufferOps.STORE for x in ast)
      self.ast = LazyOp(MetaOps.KERNEL, ast)
    else:
      assert len(ast) == 1 and ast[0].op is MetaOps.KERNEL
      self.ast = ast[0]
    self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
    try: lazyop_sts_map = verify_lazyop(self.ast)
    except AssertionError as e:
      print("INVALID AST")
      for op in ast: print_tree(op)
      raise e
    # memoized post-order traversal: children are listed before their parents
    cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {}
    def ordered_lazyops(op):
      if op not in cached_ordered_lazyops: cached_ordered_lazyops[op] = dedup([item for x in op.src for item in ordered_lazyops(x)] + [op])
      return cached_ordered_lazyops[op]
    self.reduceops = dedup([x for x in ordered_lazyops(self.ast) if x.op in ReduceOps])
    self.vars = self.ast.vars()
    self.bufs: List[Union[MemBuffer, ConstBuffer]] = dedup([x.arg for x in self.ast.lazyops if x.op in BufferOps])
    # get earlybufs, before any reduceops
    earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
    self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
    # NOTE: full_shape can be wrong if there's a tree of reduces
    # create new shapetrackers inside this kernel, we will permute them
    self.sts: List[ShapeTracker] = [x.st for x in self.bufs]
    # add the shapetrackers for each reduce
    # we use this to track which axes are reduced in each reduce
    for x in self.reduceops:
      self.sts.append(lazyop_sts_map[x])
      self.sts.append(lazyop_sts_map[x.src[0]])
    # move all reduce axes to the end (axes where full_shape != output_shape are reduced)
    reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
    permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
    self.reshape_and_permute(None, permute)
    # parameters for optimization
    self.applied_opts: List[Opt] = []
    self.group_for_reduces: int = 0
    self.upcasted: int = 0
    self.local_dims: int = 0
    self.tensor_core: Optional[TensorCore] = None
    self.tensor_core_opts: Optional[TensorCoreOptions] = None
    # the local aliased buffers for A and B
    self.bufs_for_tensor_core: Dict[LazyOp, Tuple[int, int]] = {}
    self.dont_use_locals: bool = False
    # group simplifies
    self.simplify_ones()
    self.simplify_merge_adjacent()
    # cache
    self.applied_opts_cache: Optional[List[Opt]] = None
  def copy(self):
    """Return a copy with independent optimization state; the AST and derived buffers are shared."""
    ret = type(self).__new__(type(self))
    # base linearizer params
    ret.opts, ret.ast = self.opts, self.ast
    # things downstream of the AST (immutable here, safe to share)
    ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = \
      self.reduceops, self.vars, self.bufs, self.full_buf_index
    ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam
    # parameters for optimizations (applied_opts is list-copied so the clone can diverge)
    ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
      self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
    ret.tensor_core, ret.tensor_core_opts, ret.bufs_for_tensor_core = self.tensor_core, self.tensor_core_opts, self.bufs_for_tensor_core
    # uncached since linearize didn't run
    ret.applied_opts_cache = None
    return ret
  @property
  def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]

  # TODO: these need more tests or it might silently be no-op
  # all index combinations over buffer i's upcasted (trailing) dims, last dim fastest-varying
  def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] # noqa: E501
  # upcasted axes of buffer i with unit stride and size divisible by 4 (relative to the upcast region)
  def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] # noqa: E501

  def upcasted_axis(self, i:int) -> List[Tuple[int, Optional[sint], bool]]:
    """For each upcasted dim of buffer i: (size, real stride, is-reduce) — reduce meaning output and full shape differ there."""
    upcasted_shape, upcasted_stride = self.sts[i].shape[self.shape_len-self.upcasted:], self.sts[i].real_strides()[self.shape_len-self.upcasted:]
    assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
    return list(zip(upcasted_shape, upcasted_stride,
                    [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))

  # TODO: is there a better way to write this?
  def acc_offsets(self, i:int) -> List[int]:
    """Accumulator offsets for buffer i: strides over non-reduce upcasted dims (reduce dims contribute 0)."""
    if self.upcasted == 0: return [0]
    upcasted_i = self.upcasted_axis(i)
    # zero out strides on reduce axes so all their iterations hit the same accumulator
    acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
    return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]

  def get_float4_upcast_dim(self, i:int) -> List[int]:
    """Unit-stride upcasted dims of buffer i eligible for float4 vectorization (float/half/image dtypes only)."""
    should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in (dtypes.float, dtypes.half) or isinstance(self.bufs[i].dtype, ImageDType))
    return [x for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1] if should_upcast else []
  @property
  def first_reduce(self) -> int:
    # index of the first axis where output and full shape differ; sentinel (0,)/(1,) guarantees a True exists
    return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True) # noqa: E501
  @property
  def reduceop(self) -> Optional[LazyOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None
  # output_shape comes from the store buffer's shapetracker (sts[0])
  @property
  def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape
  # full_shape includes the reduce axes (taken from the earliest pre-reduce buffer)
  @property
  def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape
  @property
  def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.shape_len-self.upcasted]
  @property
  def shape_len(self) -> int: return len(self.sts[0].shape)
  # grouped-reduce axes that survive in the output shape (candidates for UPCASTMID)
  @property
  def upcast_in_mid_reduce_axes(self) -> List[int]:
    return [j for j in range(self.first_reduce, self.first_reduce+self.group_for_reduces) if self.full_shape[j] == self.sts[0].shape[j]]
  @property
  def global_dims(self) -> int: return self.first_reduce-self.local_dims
  # there's eight chunks of the shape
  # blue -- global dims
  # cyan -- local dims (warp ones first)
  #  *** self.first_reduce
  # green -- reduce-local dims
  # white -- reduce-late upcasted dim (self.upcast_in_mid_reduce_axes)
  # red -- reduce loops
  #  *** self.upcasted
  # purple -- reduce upcasted
  # yellow -- normal upcasted dimensions
  def colors(self) -> List[str]:
    """Return one color name per shape dimension classifying its role (see the chart above)."""
    # first non local non reduce dims are global (blue); uppercase BLUE when locals are disabled
    colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
    # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
    colors += ["cyan"] * self.local_dims
    # between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
    colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
    # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
    colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + self.group_for_reduces))
    # upcasted dimensions are reduce (magenta) or normal (yellow)
    colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
    assert len(colors) == self.shape_len, "colors size mismatch"
    return colors
  def colored_shape(self, pad:Optional[int]=None, dense=False) -> str:
    """Render full_shape as an ANSI-colored string (one color per dim from colors()); optionally right-pad to `pad` visible chars."""
    ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) and not dense else s for s in self.full_shape], self.colors())) # noqa: E501
    if pad: ret += ' '*(pad-ansilen(ret))
    return ret
  181. # ******************** base simplifiers ********************
  182. # apply reshape and permute to all shapetrackers
  183. def reshape_and_permute(self, new_shape_fxn, axis):
  184. new_sts = []
  185. for st in self.sts:
  186. if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape)))
  187. if axis is not None: st = st.permute(tuple(axis))
  188. new_sts.append(st)
  189. self.sts = new_sts
  190. # drops the final dimension
  191. def upcast(self):
  192. check(self.full_shape[-1] != 1, "can't upcast a dimension with size 1")
  193. self.upcasted += 1
  # axis : the axis to pull from
  # amount : the amount to take
  # top : if you want to pull that amount from the top
  # insert_before : place to insert the new stuff
  def shift_to(self, axis, amount, top=False, insert_before=None):
    """Split dimension `axis` by `amount` and move the new size-`amount` dim to `insert_before` (end by default)."""
    if insert_before is None: insert_before = self.shape_len
    # move_axis is the half of the split that gets relocated
    move_axis = axis if top else axis+1
    # the split adds one dim before insert_before, so shift the insertion point
    if move_axis < insert_before: insert_before += 1
    self.reshape_and_permute(
      # split axis into (amount, rest) when top else (rest, amount); size-1 axes become (1,1)
      lambda x: x[0:axis] + (((amount, x[axis]//amount) if top else (x[axis]//amount, amount)) if x[axis] > 1 else (1,1)) + x[axis+1:],
      [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])
  205. # ******************** complex simplifiers ********************
  def simplify_ones(self) -> bool:
    """Remove every size-1 dimension from the full shape, shrinking local/upcast counters as needed.

    Returns True if any dimension was removed.
    """
    # remove places where the shape is all ones
    # TODO: this should be factored in to multi shape stride
    if self.shape_len == 0: return False
    all_ones = [s==1 for s in self.full_shape]
    # ones that fall inside the local or upcasted regions shrink those counters
    self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
    self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:]) # TODO: not necessary since upcasted axis can't be un-upcasted
    self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
    return any(all_ones)
  def simplify_merge_adjacent(self):
    """Merge adjacent dims that are contiguous in every shapetracker (combined stride matches).

    For image dtypes, fake strides are appended so merging never crosses the image's base axes.
    """
    if self.shape_len == 0: return
    shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
    # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
    if isinstance(self.bufs[0].dtype, ImageDType):
      base_shape = self.bufs[0].dtype.shape
      if shape_idx_groups := get_contraction(self.output_shape, base_shape):
        special_strides: Tuple[sint, ...] = tuple()
        for i,g in enumerate(shape_idx_groups):
          shape_piece = tuple(self.output_shape[x] for x in g)
          assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
          special_strides += strides_for_shape(shape_piece)
        # adding the fake image shape
        shapes.append(self.output_shape)
        strides.append(special_strides)
    # merge dimensions if we can, multi _merge_dims
    # NOTE: this does not always preserve the reduce dimension
    # TODO: move this into shapetracker, with tests!
    # TODO: how does this work with multi-reduce?
    rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)]
    for i in range(1, len(shapes[0])):
      can_merge = []
      for s,st,ret in zip(shapes, strides, rets):
        # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
        si, sti, last_st = s[i], st[i], ret[-1][1]
        # mergeable when the previous dim's stride is exactly this dim's size*stride (or both broadcast with stride 0)
        can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
      # more can merge than this
      mergeable = all(can_merge) and i != self.first_reduce
      for j,(s,st) in enumerate(zip(shapes, strides)):
        if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
        else: rets[j].append((s[i], st[i]))
    # do the reshapes (skip the fake image entry, if any, via the slice)
    for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
  248. # ******************** high level optimizers ********************
  def _create_tc_opts(self, reduceop:LazyOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
    """Try to match `reduceop` against tensor core `tc`; return axis bookkeeping, or None if it doesn't fit.

    `axis` selects among the candidate (N, M, K) axis assignments; `opt_level` gates
    CAST'd inputs (>=1) and padding of non-multiple axes (>=2).
    """
    has_cast = tc.dtype_in != tc.dtype_out
    # if the TC changes dtype, the reduce input must be exactly a CAST to the output dtype
    if has_cast and not(reduceop.src[0].op is UnaryOps.CAST and reduceop.src[0].arg == tc.dtype_out): return None
    mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
    if mul_op.op is not BinaryOps.MUL: return None
    def buf_index(src: LazyOp) -> Optional[int]:
      # TODO: apply tc even if the sources are not from LOAD
      if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
      try:
        if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
      except ValueError: return None
      return None
    if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
    # candidate axes are those broadcast (stride 0) in one input; record the other input's stride there
    buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
    axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
    axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
    # without opt_level >= 1, only single-reduce-axis kernels qualify
    if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
    axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
    if not(axis < len(axis_choices)): return None
    s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
    axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0)
    if axis_pads and (opt_level < 2): return None
    self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
    if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
    return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
    """Find a tensor core matching this kernel and apply its device-specific shape opts.

    Returns True if one was applied. With use_tensor_cores == 2 the shape ops are
    applied but self.tensor_core is left unset (no WMMA emitted).
    """
    if use_tensor_cores and self.opts.has_local and self.reduceop is not None and self.reduceop.op is ReduceOps.SUM:
      for tc in self.opts.tensor_cores:
        tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
        # can only fuse reduces with the same tc options
        assert all_same(tensor_core_opts)
        if tensor_core_opts[0] is None: continue
        # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
        self.tensor_core_opts = tc_opts = tensor_core_opts[0]
        # attempt to pad the tensor axes that require it
        try:
          for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
        except KernelOptError: continue
        if self.opts.device == "AMD":
          # NOTE: AMD requires locals first
          self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
          for (tc_dim, tc_amt) in tc.threads:
            self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
          for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
            if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
        elif self.opts.device == "METAL":
          # METAL: unroll K, upcast the remainder of N/M, then the thread-pattern locals
          self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
          for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
            if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
          for (tc_dim, tc_amt) in tc.threads:
            self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
        elif self.opts.device in {"CUDA", "NV"}:
          # CUDA/NV: fixed 8x2 K unroll and a hard-coded 2x local/upcast thread pattern
          self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, 8), append_opt=False)
          self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, 2), append_opt=False)
          self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[0], 2), append_opt=False)
          self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], 2), append_opt=False)
          self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], 2), append_opt=False)
          self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[1], 2), append_opt=False)
          self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[1], 2), append_opt=False)
          self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[1], 2), append_opt=False)
          self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[1], 2), append_opt=False)
          # NOTE: MERGE is needed because we can't deal with two upcasted dimensions
          self.apply_opt(Opt(OptOps.MERGE, self.shape_len-2), append_opt=False)
        # assert tensor core
        if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
        return True
    return False
  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:Optional[int]=None) -> bool:
    """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
    Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).

    Keyword arguments:
    use_tensor_cores -- controls how tensor cores are applied (default 1)
      0: will disable any tensor core matching
      1: enable tensor cores
      2: apply tensor core shape but don't use UOp.WMMA
    extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
    tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
      0: applies to only kernels with a single reduce axis and direct BufferOps.LOAD into BinaryOps.MUL
      1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
      2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
    """
    if tc_opt is None: tc_opt = TC_OPT.value
    if not self.opts.tensor_cores and use_tensor_cores != 2: return False
    try: # check TC first and apply hand-coded opts if successful
      self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
      if (tc_opts:=self.tensor_core_opts) is not None:
        if extra_opts is not None:
          for opt in extra_opts: self.apply_opt(opt)
        else:
          # hand-coded TC opts
          def late_upcast_tc(tc_dim: int):
            # upcast by the largest divisor <= 5 of the axis size (no-op when only 1 divides)
            if tc_opts.axes_exist[tc_dim]:
              ax_div = [upc for upc in [5,4,3,2,1] if self.full_shape[tc_opts.axes[tc_dim]]%upc == 0][0]
              if ax_div != 1: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], ax_div))
          late_upcast_tc(1) # attempt to upcast M
          late_upcast_tc(0) # attempt to upcast N
          if self.tensor_core and tc_opts.axes_exist[0]: # attempt to local N
            for upc in [4,2]:
              if self.full_shape[tc_opts.axes[0]] % upc == 0:
                self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], upc))
                break
      return True
    except KernelOptError:
      return False
  def apply_opt(self, opt:Opt, append_opt:bool=True):
    """Validate and apply a single optimization, mutating the kernel's shape state.

    Raises KernelOptError (via check) when the opt is invalid for the current state.
    When append_opt is False the opt is applied but not recorded in applied_opts.
    """
    check(not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")

    if opt.op is OptOps.TC:
      check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
      check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
      check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
      check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
      self.applied_opts.append(opt)
      return

    axis = opt.real_axis(self)
    check(axis < len(self.full_shape), "invalid axis")

    # amt 0 means "take the whole axis"; amt None (axis-less opts) becomes the sentinel -1
    if opt.amt is not None:
      amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
      check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
      if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
    else: amt = -1

    # estimate the shared memory the grouped reduce would need and reject if over the device limit
    if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
      acc_sz, upcast_idx = dt.base.itemsize if isinstance((dt:=self.reduceop.dtype), ImageDType) else dt.itemsize, self.shape_len-self.upcasted
      upcast_sz = prod([a for a,b in zip(self.full_shape[upcast_idx:], self.sts[0].shape[upcast_idx:]) if a == b])
      local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
      smem_sz = amt*acc_sz*upcast_sz*local_sz
      check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")

    if opt.op is OptOps.LOCAL:    # cyan
      check(self.opts.has_local, "target does not support local")
      check(axis < self.global_dims, "local is for globals")
      self.shift_to(axis, amt, insert_before=self.first_reduce)
      self.local_dims += 1
    elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:   # green
      check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
      check(axis >= self.first_reduce + self.group_for_reduces and axis < self.shape_len-self.upcasted, "must be reduce axis to group")
      check(not self.tensor_core, "can't group with tensor cores")
      check(len(self.reduceops) == 1, "can't group with multiple reduces")
      self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
      self.group_for_reduces += 1
    elif opt.op is OptOps.UNROLL:   # purple
      check(axis < self.shape_len-self.upcasted, "can't upcasted already upcasted")
      check(amt <= 32, "don't unroll more than 32")
      # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
      #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
      #self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)
      if self.full_shape[axis] == amt and axis == self.first_reduce: self.local_dims += 1 # first_reduce will ++, so offset loss in simplify_ones
      if self.full_shape[axis] == amt and axis < self.first_reduce+self.group_for_reduces: self.group_for_reduces -= 1 # fully unrolling a GROUP
      self.shift_to(axis, amt, insert_before=None)
      self.upcast()
    elif opt.op is OptOps.UPCAST:   # yellow
      check(axis < self.first_reduce, "upcast is for non-reduce")
      check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
      check(amt <= 8, "don't upcast more than 8")
      self.shift_to(axis, amt, insert_before=None)
      self.upcast()
    elif opt.op is OptOps.UPCASTMID:    # white
      check(self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
      axes = self.sts[0].unit_stride_axes()
      check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
      check(axes[0] == axis, "wrong axis")
      check(amt == 4, "don't upcast mid anything but 4")
      self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
      self.group_for_reduces += 1
    elif opt.op is OptOps.NOLOCALS:
      check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
      check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
      self.dont_use_locals = True
    elif opt.op is OptOps.MERGE:
      check(axis >= self.shape_len-self.upcasted, "only merge upcasted")
      check(self.full_shape[axis:axis+2] == self.output_shape[axis:axis+2], "can't merge reduces")
      # swap the pair then fuse them into one dim
      self.reshape_and_permute(None, tuple(range(axis)) + (axis+1, axis) + tuple(range(axis+2, self.shape_len)))
      self.reshape_and_permute(lambda x: x[0:axis] + (x[axis] * x[axis+1],) + x[axis+2:], None)
      self.upcasted -= 1
    elif opt.op is OptOps.PADTO:
      check(not self.vars, "does not work with symbolic shape")
      check(axis < self.shape_len - self.upcasted, "cannot pad upcasted")
      # ok to pad SUM if all parent ops have f(0) = 0
      if self.first_reduce <= axis:
        check((r:=cast(LazyOp, self.reduceop)).op is ReduceOps.SUM and \
            all(op.op not in UNSAFE_PAD_OPS for sop in r.src for op in sop.lazyops), "cannot pad")
      padded = False
      for i,st in enumerate(self.sts):
        if self.sts[i].shape[axis] == 1: continue  # reduced
        check(self.sts[i].shape[axis] > amt//4, f"pad adds more than quadruple the work {self.sts[i].shape[axis]=} > {amt//4=}")
        if (ru := round_up(cast(int, self.sts[i].shape[axis]), cast(int, amt)) - self.sts[i].shape[axis]):
          # pad right seems to be faster
          self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
          padded = True
      check(padded, "nothing was padded")

    if append_opt: self.applied_opts.append(opt)
    if self.simplify_ones() and self.tensor_core_opts:
      self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
  def required_optimizations(self):
    """Apply opts mandatory for correctness: image output buffers must upcast a stride-1, mult-of-4 axis by 4."""
    if self.bufs[0].dtype.__class__ is ImageDType:
      unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
      assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[0]}"
      # only upcast when the axis isn't already upcasted or reserved for a mid-reduce upcast
      if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
        self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
def hand_coded_optimizations(self):
  """Heuristically optimize this kernel in place via a sequence of apply_opt calls.

  Order matters: matvec special-case (may return early), grouped reduces (may
  return early), masked-axis upcasts, heuristic float4 upcasts, reduce unrolls,
  and finally local dimension assignment. All changes go through apply_opt so
  they are validated and recorded in applied_opts.
  """
  self.required_optimizations()

  # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
  MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
  # matvec pattern: a SUM reduce over a MUL of two direct LOADs, on a backend with local+shared memory
  if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
      self.reduceop is not None and self.reduceop.op is ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
      (mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is BufferOps.LOAD and mulop.src[1].op is BufferOps.LOAD:
    st0, st1 = self.sts[self.bufs.index(mulop.src[0].arg)], self.sts[self.bufs.index(mulop.src[1].arg)]
    strides0, strides1 = st0.real_strides(), st1.real_strides()
    # an "expanded" axis has size > 1 but stride 0 (broadcast); two broadcast operands means this isn't a matvec
    def has_expanded_axis(shape, strides): return any(s > 1 and st == 0 for s,st in zip(shape,strides))
    if strides0[self.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
      for global_idx in range(self.global_dims):
        # both the reduce axis and the chosen global axis must divide evenly by the tiling factors
        if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
          if DEBUG >= 3:
            print(f"MATVEC: {self.full_shape=} {self.first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
          if MV_THREADS_PER_ROW > 1: self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
          if MV_BLOCKSIZE > 1: self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
          if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
          return

  if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
    # are we grouping? (requires local shape support)
    if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
      # TODO: use 1024 if it's allowed in a smarter way
      for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
        if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
          try: # may fail due to excessive smem usage
            self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
            break
          except KernelOptError: pass
    # are we upcasting in mid reduce? (only for images)
    if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
      axes = self.sts[0].unit_stride_axes()
      assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
      if self.sts[0].shape[axes[0]]%4 == 0:
        self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))

  # upcast float4 images
  for buf_index,buf in enumerate(self.bufs):
    unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
    if buf.dtype.__class__ is ImageDType:
      #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
      if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
        # non-reduce axes get UPCAST, reduce axes get UNROLL (same 4-wide effect on either side of first_reduce)
        if unit_stride_axes_mul_4[0] < self.first_reduce:
          self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
        else:
          self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))

  # no more opt if we are grouping
  if self.group_for_reduces: return

  # **** below this line need to be optional and benchmarked ****

  # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
  # to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below
  # expression and run test/test_ops.py with IMAGE=2
  # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
  # this can be made much smarter
  to_upcast: List[int] = []
  # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
  for axis in range(self.first_reduce):
    # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
    # for now skip upcasting here if there is a symbolic axis
    # cap: total upcast volume (already-upcasted * planned * this axis) must stay <= 49
    if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
      prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
      if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
      to_upcast.append(axis)
  # apply in reverse so earlier axis indices stay valid as the shape shrinks
  for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))

  # potentially do more upcasts of non reduce axes based on a heuristic
  upcasted_axis = set()
  while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
    xb_choices = []
    for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
      # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
      if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)):
        # rank choices by (number of buffers with nonzero stride, total stride) — sorted() below prefers the smallest
        xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
    if xb_choices:
      xb_choices = sorted(xb_choices)
      if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
      self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
      upcasted_axis.add(xb_choices[0][2])
    else: break

  # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
  if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64):
    if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
      self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
      # if it's small, upcast a second reduce dimension too
      if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
        self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
    else:
      for splits in [4]:
        if self.full_unupcasted_shape[-1]%splits == 0:
          self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits))
          break

  # if nothing at all is upcasted and it's easy to, do an upcast
  # TODO: this is breaking the tests
  for splits in [4]:
    if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0:
      self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits))

  # **** local groups ****

  if self.opts.has_local:
    if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduces:
      self.apply_opt(Opt(OptOps.NOLOCALS))
    else:
      # prioritize making expand axes local
      local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))]
      to_local: List[Tuple[int, int]] = []
      for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
        local_size = prod(sz for _, sz in to_local)
        # largest divisor of the axis that keeps the total local size <= 128 (32 only considered for axis 0)
        local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None)
        if local_sz is not None: to_local.append((axis, local_sz))
      deleted_shape = 0
      # at most 3 local axes; track axes consumed entirely (deleted) so later indices shift correctly
      for axis, local_sz in sorted(to_local[:3]):
        axis = axis - deleted_shape
        will_delete_shape = local_sz == self.full_shape[axis]
        self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
        if will_delete_shape: deleted_shape += 1
# **** kernel outputs ****

# class-level count of kernels per function name; `name` uses it to append a unique "n<k>" suffix on collisions
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
  560. @functools.cached_property
  561. def name(self) -> str:
  562. # kernel name (before late upcast)
  563. name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.ast.lazyops) else "E")) + \
  564. (f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \
  565. colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
  566. # name the function something unique
  567. Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
  568. suffix = f"{'n'+str(Kernel.kernel_cnt[function_name]-1)}" if Kernel.kernel_cnt[function_name] > 1 else ""
  569. return name+colored(suffix, 'BLACK')
def get_optimized_ast(self) -> LazyOp:
  # set the shapetrackers to the optimized ones, fixup reduceop
  # transformed to the final LazyOp
  @functools.lru_cache(None)
  def fixup_ast(op:LazyOp, apply_to_st=None) -> LazyOp:
    if op.op in BufferOps:
      # swap in this kernel's optimized ShapeTracker for the buffer, optionally transformed by apply_to_st
      idx = self.bufs.index(op.arg)
      arg = replace(op.arg, st=self.sts[idx] if apply_to_st is None else apply_to_st(self.sts[idx]))
    elif op.op in ReduceOps:
      # new reduce axes = axes where the pre-reduce (reduce_idx) and post-reduce (reduce_idx+1) shapes differ
      reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
      arg = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len)
        if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i])
      if op in self.bufs_for_tensor_core and (tc := self.tensor_core):
        # tensor core path: rewrite the MUL(+optional CAST) feeding this reduce into a WMMA op
        rsrc = op.src[0]
        if rsrc.op is UnaryOps.CAST: rsrc = rsrc.src[0]
        assert rsrc.op is BinaryOps.MUL
        def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
          # expand the tensor-core dims of st1 to tcd_expand, permute per pattern_1/pattern_2, reshape back;
          # pattern entries (x,y) select dim y from the warp dims (x==0) or the expanded tcd dims (x==1)
          wd = self.global_dims
          tcd = self.shape_len-self.upcasted
          assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st1.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
          assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st1.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
          new_shape = st1.shape[:tcd] + tcd_expand + st1.shape[tcd+len(tcd_dims):] # expand the tcd
          permaxis = list(range(wd))
          permaxis += [y + (wd if x == 0 else tcd) for x,y in pattern_1]
          permaxis += list(range(wd+len(warp_dims), tcd))
          permaxis += [y + (wd if x == 0 else tcd) for x,y in pattern_2]
          permaxis += list(range(tcd+len(tcd_expand), len(new_shape)))
          return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape).simplify()
        # device-specific wmma fragment layouts; the tuples encode each backend's warp/thread-local shapes
        if self.opts.device == "AMD":
          reduce_axes = [self.shape_len-self.upcasted]
          upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1)
          fix_st1 = functools.partial(fix_st, (8,2,2), (16,8), (16,2,4), ((1,2), (0,2), (1,1), (0,1)), ((1,0), (0,0)))
          fix_st2 = None
        elif self.opts.device == "METAL":
          reduce_axes = [self.shape_len-self.upcasted]
          upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1)
          fix_st1 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((1,1), (0,1), (1,0), (0,3)), ((0,0), (0,2), (1,3), (1,2)))
          fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3)))
        elif self.opts.device in {"CUDA", "NV"}:
          reduce_axes = [self.shape_len-self.upcasted, self.shape_len-self.upcasted+1]
          upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted+2, self.shape_len-self.upcasted+2)
          # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
          fix_st1 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2),
            ((1,1), (1,0), (0,2), (0,3), (0,4)), ((1,3), (1,4), (1,2), (0,0), (0,1), (1,5)))
          fix_st2 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2),
            ((1,1), (1,0), (1,5), (0,0), (0,1)), ((0,4), (0,2), (1,4), (0,3), (1,3), (1,2)))
        else:
          raise RuntimeError("unsupported device for tensor cores")
        assert apply_to_st is None, "double tensor core? not supported"
        wmma_sz = [prod(l) for l in tc.thread_local_sizes]
        wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), self.opts.device, upcast_axis, tuple(reduce_axes))
        ret = LazyOp(ReduceOps.WMMA, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
        # any reduce axes the WMMA did not consume are kept as a wrapping reduce
        new_reduce_axes = tuple(i for i in arg if i not in reduce_axes)
        return LazyOp(op.op, (ret,), new_reduce_axes) if new_reduce_axes else ret
      if self.group_for_reduces:
        # grouped reduce: partial reduce -> store/load through buffer index -1 (a staging buffer) -> final reduce over group axes
        start = LazyOp(op.op, tuple(fixup_ast(x, apply_to_st) for x in op.src), arg)
        local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces] + \
          (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
        local_buffer = MemBuffer(-1, start.dtype, ShapeTracker.from_shape(local_shape))
        local_store = LazyOp(BufferOps.STORE, (start,), local_buffer)
        local_load = LazyOp(BufferOps.LOAD, (local_store,), local_buffer)
        return LazyOp(op.op, (local_load,), tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces)))
    elif op.op is MetaOps.KERNEL:
      # the root kernel op carries the final local/upcast counts
      arg = KernelInfo(self.local_dims, self.upcasted)
    else:
      arg = op.arg
    return LazyOp(op.op, tuple(fixup_ast(x, apply_to_st) for x in op.src), arg)
  return fixup_ast(self.ast)
  638. # **** this is the lowerer ****
  639. def linearize(self) -> Kernel:
  640. modified_ast = self.get_optimized_ast()
  641. if DEBUG >= 3:
  642. print(self.name)
  643. print_tree(modified_ast)
  644. verify_lazyop(modified_ast)
  645. uop_sink = lazyop_to_uop(modified_ast, self.opts)
  646. # extract global/local sizes
  647. if self.opts.has_local:
  648. self.global_size: Optional[List[int]] = [1,1,1]
  649. self.local_size: Optional[List[int]] = [1,1,1]
  650. for u in uop_sink.parents:
  651. if u.op is UOps.SPECIAL:
  652. if u.arg[1][0] == 'l': self.local_size[u.arg[0]] = u.arg[2]
  653. else: self.global_size[u.arg[0]] = u.arg[2]
  654. else:
  655. self.global_size, self.local_size = None, None
  656. # generate the UOpGraph
  657. self.uops:UOpGraph = UOpGraph(uop_sink, self.opts)
  658. if DEBUG >= 5: self.uops.print()
  659. if getenv("GRAPHUOPS"):
  660. self.uops.graph()
  661. if getenv("GRAPHUOPS") == 2: exit(0)
  662. return self
  663. def to_program(self) -> Program:
  664. self.linearize()
  665. src = self.opts.render(name:=to_function_name(self.name), self.uops)
  666. if getenv("RUN_PROCESS_REPLAY"):
  667. table_name = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}"
  668. diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
  669. info = get_lazyop_info(self.ast.src[0]) # TODO: this should be removed
  670. ops, mem = flops_mem(self.uops.uops)
  671. run_count = prod((self.global_size or []) + (self.local_size or []))
  672. return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
  673. self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))