# lazy.py — extraction banner and fused line-number run removed (copy/paste residue, not source)
  1. from __future__ import annotations
  2. from typing import Union, Optional, Any, Tuple, List
  3. from tinygrad.dtype import dtypes, DType, ConstType
  4. from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata
  5. from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st
  6. from tinygrad.shape.symbolic import sint, Variable
  7. from tinygrad.shape.shapetracker import ShapeTracker
  8. from tinygrad.device import Buffer
  9. from weakref import ref, ReferenceType, WeakValueDictionary
# Dedup cache of live LazyBuffers keyed by their creation arguments.
# Values are weak, so a LazyBuffer no longer referenced anywhere else can be collected.
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
                      base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
  """Create a LazyBuffer, or return an identical live one from the cache.

  base=None creates a "base" buffer described by (op, arg, srcs); base=<lb> creates a
  view of that base through ShapeTracker st. Caching is controlled by LAZYCACHE (default on).
  """
  # zero-size buffers are collapsed to a bare CONST 0 with no srcs and no base
  if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None
  # CONST args are normalized to the target dtype (symbolic Variables pass through) and CONSTs are always cached
  if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True

  # bases are keyed by their full creation signature (srcs held by weakref); views are keyed by (st, base)
  cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
  # NOTE: truthiness check — a dead weak value reads as a miss
  if enable_cache and (rret := lazycache.get(cache_key, None)): return rret

  ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get())
  if enable_cache: lazycache[cache_key] = ret
  return ret
  20. view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "DISK"}
class LazyBuffer:
  """A lazily-evaluated buffer in the compute graph.

  Every LazyBuffer is either a "base" — it owns (op, arg, srcs) describing how to compute
  it, plus a device Buffer — or a "view" — a ShapeTracker reinterpretation of some base
  (self._base is set, and op/arg/srcs/buffer live on the base instead).
  """
  def __init__(self, device:str, st:ShapeTracker, dtype:DType,
               op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
               base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
    self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, dtype, st.shape, st.size, metadata
    self._base: Optional[LazyBuffer] = None
    if base is None:
      # properties on base
      self.op, self.arg, self.srcs = op, arg, srcs  # this is a LazyOp, except the src is LazyBuffers and not LazyOps
      assert self.op is not MetaOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"

      if self.op is MetaOps.VIEW:
        # some LazyBuffers can be processed with only a view, no AST required
        self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
      else:
        # ASSIGN writes into its target's existing buffer; everything else gets a fresh allocation
        self.buffer = srcs[1].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, dtype)
      self.buffer.ref(1)  # manual refcount on the Buffer, released in __del__
      # set by contiguous(): (weakref to the contiguous copy, inverse ShapeTracker) so e() can reuse it
      self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
      self.forced_realize = False
    else:
      # properties on view
      assert base.base == base, "base must be a base itself"
      self._base = base

  def __del__(self):
    # only bases hold a buffer; drop the refcount taken in __init__
    if hasattr(self, 'buffer'): self.buffer.ref(-1)

  def __repr__(self) -> str:
    # NOTE(review): [7:] presumably strips a "dtypes." prefix from str(self.dtype) — confirm against DType.__repr__
    return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base != self else (self.op, self.realized)}>"

  @property
  def realized(self) -> Optional[Buffer]:
    # NOTE: we check for a lack of srcs instead of an allocated buffer to make unrealized assigns return None here
    return self.buffer if self._base is None and not hasattr(self, 'srcs') else None

  # NOTE: this has to be a function to prevent self reference
  @property
  def base(self) -> LazyBuffer:
    """The base this buffer views, or self if it is a base."""
    return self._base if self._base is not None else self

  # same API as multi
  @property
  def lbs(self) -> List[LazyBuffer]: return [self]

  @staticmethod
  def metaop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
    """Create a LazyBuffer for a MetaOp (CONST/COPY/ASSIGN/...) with a fresh contiguous ShapeTracker."""
    assert isinstance(src, tuple)
    return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)

  def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
    """A CONST buffer holding val, on this buffer's device/dtype, expanded to shape (default: self.shape)."""
    assert isinstance(val, (int,float,bool)), f"{val=} has {type(val)=}, not a ConstType"
    shape = self.shape if shape is None else shape
    # CONST is created 0-d, then reshaped to all-ones rank and expanded (expand of a size-1 dim)
    return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)

  def is_realized(self) -> bool: return self.base.realized is not None

  def assign(self, x:LazyBuffer) -> LazyBuffer:
    """Schedule an in-place assignment of x into this buffer's storage (ASSIGN metaop)."""
    assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
    # arg carries the target's ShapeTracker only when the target is non-contiguous
    return LazyBuffer.metaop(MetaOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))

  def can_view(self):
    """True if this can be materialized as a zero-kernel buffer view (consecutive st, not a const, view-capable device)."""
    return self.st.consecutive and not self.is_unrealized_const() and self.device.split(":")[0] in view_supported_devices

  def contiguous(self, allow_buffer_view=True):
    """Return a contiguous version of this buffer (VIEW when possible, else a CONTIGUOUS kernel)."""
    if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
      ret = self.e(MetaOps.VIEW) if allow_buffer_view and self.can_view() else self.e(MetaOps.CONTIGUOUS)
      # if the movement is invertible, remember the copy on the base so e() can route future ops through it
      if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
      return ret
    # already contiguous: just pin the base so the scheduler realizes it
    self.base.forced_realize = True
    return self

  def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
    """Cast (or bitcast) to dtype. Shape-changing bitcasts are DISK-only."""
    if self.dtype == dtype: return self
    if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
    # const folding: cast an unrealized const by re-creating it with the converted arg
    if self.is_unrealized_unmasked_const() and not bitcast:
      return create_lazybuffer(self.device, self.st, dtype, MetaOps.CONST, dtypes.as_const(self.base.arg, dtype))
    new_shape = self.shape
    if bitcast and self.dtype.itemsize != dtype.itemsize:
      if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
      if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet")
      # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
      # last-axis byte count must divide evenly into the new itemsize
      if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
      new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
    elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
      # narrowing cast of a view: cast the base first, then reapply the view (less data to cast)
      # TODO: applying this makes gpt2 slower
      return self.base.cast(dtype, bitcast)._view(self.st)
    cast_op: Union[MetaOps, UnaryOps] = (MetaOps.VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))

  def is_unrealized_const(self):
    """True for a not-yet-realized non-symbolic CONST base."""
    return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, Variable)
  def is_unrealized_unmasked_const(self):
    """is_unrealized_const and no view applies a mask (so every element equals the const arg)."""
    return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)

  def _copy(self, device:str) -> LazyBuffer:
    # raw device-to-device COPY of this buffer's bytes; never cached
    return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, MetaOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)

  def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
    """Move this buffer to device, folding away no-op / redundant copies unless force=True."""
    # no COPY
    if self.device == device: return self
    # double COPY = one COPY
    if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is MetaOps.COPY:
      return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)
    # const doesn't have to be copied (issues with disk tensor)
    if self.is_unrealized_const():
      return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
    # if it's a shrink, do the shrink before the copy with CONTIGUOUS
    if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)
    # copy the base and apply the shapetracker on the new device
    return self.base._copy(device)._view(self.st)

  def e(self, op:Union[MetaOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
    """Elementwise op over (self,)+in_srcs. All srcs must share shape; dtypes must match
    (WHERE's condition excepted). Applies const folding and simple algebraic identities."""
    srcs: List[LazyBuffer] = []
    for s in (self,)+in_srcs:
      # if a base has a live contiguous copy, route through it (undoing the movement via the stored inverse st)
      if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
        srcs.append(root._view(s.base.contiguous_child[1]))
      else:
        srcs.append(s)
    assert all_same(dts:=[x.dtype.scalar() for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]), f"all dtypes must match {dts} on {op}"
    assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
    if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
    if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
    # comparisons produce bool; otherwise the (matching) src dtype is kept
    out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else srcs[-1].dtype
    # const folding
    if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
      return self.cast(out_dtype).const(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
    # NEG(NEG(x)) -> x
    if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG and self.base.realized is None: return self.base.srcs[0]
    if op in BinaryOps:
      x, y = self, in_srcs[0]
      # x+0 / 0+x identities
      if op is BinaryOps.ADD:
        if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
        if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
      # x*1 -> x, x*0 -> 0, x*-1 -> -x (either operand)
      if op is BinaryOps.MUL:
        if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
          return y if val == 1 else y.const(0) if val == 0 else y.e(UnaryOps.NEG)
        if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0, -1):
          return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))

  # *** reduce ops ***

  def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
    # single-kernel reduce over axis (size-1 axes dropped; no-op if nothing remains)
    assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
    axis = tuple(sorted([x for x in axis if self.shape[x] != 1]))
    if len(axis) == 0: return self
    return create_lazybuffer(self.device, ShapeTracker.from_shape(reduce_st(self.st, axis)), self.dtype, op, axis, (self,))

  def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
    """Reduce over axis, with const folding and an optional two-kernel split for large reductions."""
    new_shape = reduce_st(self.st, axis)
    # TODO: this logic should move to the scheduler
    # reducing a zero-size buffer to a nonzero shape yields the reduce identity (0 for SUM, dtype min for MAX)
    if 0 in self.shape and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: dtypes.min(self.dtype)}[op], new_shape)
    # const folding
    # TODO: fold this for symbolic?
    if self.is_unrealized_unmasked_const() and all_int(self.shape):
      return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
    # TODO: can we split symbolic shape if the reduce axis is not symbolic?
    if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
      prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
      return self._reduce_op(op, axis)

    # if there are few globals, make some reduces into globals by splitting into two kernels
    # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
    #   ~2**10 should be enough if GROUP is used
    # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
    # split is moved to the end to provide maximum locality for the second phase reduce.
    self_real_strides = self.st.real_strides(ignore_valid=True)
    split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
                        if self.shape[i] % x == 0 and self_real_strides[i] != 0]
    if not split_candidates: return self._reduce_op(op, axis)
    dim_to_split, divisor = split_candidates[0]
    splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
    splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
    if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
    return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape)  # reduce original axes, then split

  # *** movement ops ***

  def _view(self, new_st:ShapeTracker) -> LazyBuffer:
    # fully-masked-out (or zero-size) views are just const 0
    if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
      return self.const(0, new_st.shape)
    # a contiguous view with the base's shape IS the base
    if new_st.contiguous and self.base.shape == new_st.shape: return self.base
    return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)

  # movement ops delegate to the ShapeTracker and wrap the result via _view
  def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
  def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.pad(arg))
  def expand(self, arg:Tuple[sint, ...]): return self._view(self.st.expand(arg))
  def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
  def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))