# lowerer.py
  1. from __future__ import annotations
  2. from typing import List, Tuple, cast, Optional, Any, Dict
  3. import functools
  4. from tinygrad.shape.shapetracker import ShapeTracker, View
  5. from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
  6. from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, KernelInfo
  7. from tinygrad.codegen.uops import UOp, UOps
  8. from tinygrad.renderer import Renderer
  9. from tinygrad.helpers import getenv, prod
  10. # TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
  11. from tinygrad.shape.symbolic import Variable, NumNode, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
  12. def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.bigint, x) if isinstance(x, int) else x.render(render_ops, ctx)
  13. render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.bigint, self.b),
  14. MulNode: lambda self, ops, ctx: self.a.render(ops, ctx)*variable_to_uop(self.b, ctx),
  15. DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
  16. ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
  17. LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
  18. Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else UOp(UOps.DEFINE_VAR, dtypes.int32, (), self),
  19. SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
  20. AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
if getenv("UOP_IS_SYMBOLIC"):
  # TODO: change this once UOps is ready to replace symbolic
  def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
    """Build (index, valid) UOp expressions for a single View given per-dimension index UOps.

    idxs must have one UOp per dimension of view.shape. vexpr is the incoming validity
    expression; mask bounds are multiplied in (bool-as-0/1 AND).
    """
    # TODO: dtypes.realint
    iexpr = variable_to_uop(view.offset)
    for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
      # size-1 or stride-0 dims contribute nothing to the flat index
      if sh != 1 and st != 0: iexpr = iexpr + idx*variable_to_uop(st)
      if m is not None:
        # only emit bound checks that aren't trivially true
        if m[0] != 0: vexpr = vexpr * idx.ge(variable_to_uop(m[0]))
        if m[1] != sh: vexpr = vexpr * idx.lt(variable_to_uop(m[1]))
    return iexpr, vexpr

  def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]:
    """Lower a ShapeTracker to (index, valid) UOps by composing its views from last to first."""
    idx, valid = _uop_view(st.views[-1], idxs, UOp.const(dtypes.bool, True))
    for view in reversed(st.views[0:-1]):
      view = view.minify()
      # unflatten the current flat idx into per-dimension indexes for this view:
      # walking shape right-to-left, each dim d takes (idx // acc) % d, acc accumulating the product.
      # NOTE: acc starts as python int 1 and becomes a UOp after the first `acc *= d`
      acc, idxs = 1, []
      for _d in reversed(view.shape):
        d = variable_to_uop(_d)
        idxs.append((idx//acc)%d)
        acc *= d
      idx, valid = _uop_view(view, idxs[::-1], valid)
    return idx, valid
else:
  def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]:
    """Lower a ShapeTracker to (index, valid) UOps via the symbolic path.

    Fake Variables stand in for each dimension, expr_idxs builds symbolic index/valid
    Nodes, and render_ops substitutes the real index UOps back in through ctx.
    """
    fake_idxs = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)]
    idx, valid = st.expr_idxs(fake_idxs)
    ctx = dict(zip(fake_idxs, idxs))
    uidx, uvalid = idx.render(render_ops, ctx), valid.render(render_ops, ctx)
    # a constant valid renders as bigint; re-const it as bool so downstream dtype checks hold
    if uvalid.op is UOps.CONST: uvalid = UOp.const(dtypes.bool, uvalid.arg)
    assert uvalid.dtype == dtypes.bool
    return uidx, uvalid
def get_grouped_dims(prefix, dims, max_sizes:Optional[Tuple[int, ...]]) -> List[UOp]:
  """Create SPECIAL index UOps (e.g. gidx0/lidx0...) for dims, collapsing excess dims.

  If there are more dims than the device supports (len(max_sizes)), the leading
  dims are flattened into one SPECIAL of size prod(...) and then split back out
  with % and // so callers still get one index UOp per original dim.
  NOTE(review): only the dim COUNT from max_sizes is honored, not the per-dim
  size limits — see the TODO below.
  """
  # TODO: this should be per dim max
  maxdim = len(max_sizes) if max_sizes is not None else 0
  # when over the limit, shape becomes (prod(leading dims), *trailing maxdim-1 dims)
  local_idxs = [UOp(UOps.SPECIAL, dtypes.bigint, (),
    (i, f"{prefix}{i}", s)) for i,s in enumerate((prod(dims[:-(maxdim-1)]),) + dims[-(maxdim-1):] if len(dims) > maxdim else dims)]
  if maxdim != 0 and len(dims) > maxdim:
    # recover one index per collapsed leading dim from the merged SPECIAL
    dd = local_idxs[0]
    nli = []
    for s in dims[:-(maxdim-1)]:
      nli.append(dd % s)
      dd //= s
    local_idxs = nli + local_idxs[-(maxdim-1):]
  return local_idxs
class IndependentLowerer:
  """Lowers a LazyOp AST into a UOp graph, building the loop/index structure from the kernel shape."""
  def lower(self, ast:LazyOp, opts:Renderer) -> UOp:
    # number of output buffers; used to mark DEFINE_GLOBALs as outputs (idx < output_count)
    self.output_count = len(ast.src)
    ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
    # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
    full_shape = ast.full_shape
    first_upcasted = len(full_shape)-ki.upcasted
    # first axis where output shape diverges from full shape; the (0,)/(1,) sentinels
    # guarantee a True entry, so if there's no reduce, this is first_upcasted
    first_reduce = [x!=y for x,y in zip(ast.src[0].arg.st.shape[:first_upcasted]+(0,), full_shape[:first_upcasted]+(1,))].index(True)
    # loads with idx == -1 read the local (shared-memory) temp buffer
    local_loads = [x for x in ast.lazyops if x.op is BufferOps.LOAD and x.arg.idx == -1]
    # NOTE: this is taking the first one...there may be subtleties here with multireduces
    group_for_reduces = sum([x!=y for x,y in zip(
      local_loads[0].arg.st.shape[first_reduce:first_upcasted], ast.src[0].arg.st.shape[first_reduce:first_upcasted])]) if local_loads else 0
    global_dims = first_reduce-ki.local_dims

    if opts.has_local:
      # define indexes for GPU-like execution
      self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max) + \
                  get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
    else:
      # all loops are RANGES
      self.idxs = [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, False))
        for i,g in enumerate(full_shape[:first_reduce])]

    # reduce loops (arg second element True marks a reduce RANGE)
    self.idxs += [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, True))
      for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]

    # upcast loops: EXPAND over the concrete unrolled values
    for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
      assert isinstance(g, int), "needs to be int to upcast/unroll"
      self.idxs.append(UOp(UOps.EXPAND, dtypes.bigint, tuple(UOp.const(dtypes.bigint, j) for j in range(0, g)), ((i,g),)))

    # late indexes (group for reduce): ridxs replaces the grouped axes with fresh
    # RANGEs (offset by 1000 to avoid id collision with the early indexes)
    self.ridxs = self.idxs[:]
    for a in range(first_reduce, first_reduce+group_for_reduces):
      self.ridxs[a] = UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(full_shape[a])), (1000+a, True))

    self.uop_cache: Dict[LazyOp, UOp] = {}
    return self.to_uop(ast)

  def to_uop(self, x:LazyOp) -> UOp:
    """Memoized wrapper around _to_uop so shared LazyOp subtrees lower to one UOp."""
    if uop:=self.uop_cache.get(x, None): return uop
    ret = self._to_uop(x)
    self.uop_cache[x] = ret
    return ret

  def _to_uop(self, x:LazyOp) -> UOp:
    """Lower a single LazyOp node (recursing into sources through to_uop)."""
    if x.op in BufferOps:
      # local-buffer loads (idx == -1) index with the late ridxs, everything else with idxs
      idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs)
      # TODO: check has_valid in UPat, not here
      has_valid = valid.op is not UOps.CONST or valid.arg is not True
      if x.op is BufferOps.CONST:
        dtype = x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype
        # masked const: value where valid, 0 elsewhere
        return UOp.alu(TernaryOps.WHERE, valid, UOp.const(dtype, x.arg.val), UOp.const(dtype, 0))
      if x.arg.idx == -1:
        buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size))
      else:
        # images keep their dtype; everything else gets a pointer dtype. arg = (buffer index, is_output)
        buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (),
                  (x.arg.idx, x.arg.idx < self.output_count))
      if x.op is BufferOps.LOAD:
        # a LOAD with a source is reading a local buffer another op wrote: BARRIER on that store first
        barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
        return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier)
      # NOTE: only store the local reduceop in the first thread
      if x.arg.idx != -1:
        has_valid = True
        # any axis where ridx diverged from idx was grouped; gate the store to thread 0 of it
        for oidx, ridx in zip(self.idxs, self.ridxs):
          if oidx != ridx: valid = valid * oidx.eq(0)
      return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[0])) + ((valid,) if has_valid else ()))

    in_uops = tuple(self.to_uop(y) for y in x.src)
    if x.op is MetaOps.KERNEL: return UOp(UOps.SINK, src=in_uops)
    if x.op is UnaryOps.CAST: return UOp(UOps.CAST, x.arg.scalar(), in_uops)
    if x.op is UnaryOps.BITCAST: return UOp(UOps.BITCAST, x.arg.scalar(), in_uops)
    if x.op in ReduceOps:
      dtype = x.dtype.base if isinstance(x.dtype, ImageDType) else x.dtype
      if x.op is ReduceOps.WMMA:
        # tensor-core path: CONTRACT both operands to the wmma vector widths, then EXPAND the result back out
        wmma_sz, upcast_axis = x.arg[4], x.arg[6]
        ret = UOp(UOps.WMMA, dtype=dtype.vec(wmma_sz[2]), src=(
          UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=(upcast_axis[0],)),
          UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
          UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
        return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),))
      # NOTE: always using ridxs is fine here
      return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
    return UOp.alu(x.op, *in_uops)
  143. def lazyop_to_uop(ast:LazyOp, opts:Renderer) -> UOp: return IndependentLowerer().lower(ast, opts)