# ShapeTracker allows movement operations on a buffer that don't require a copy to be made.
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
from tinygrad.shape.view import View, strides_for_shape
@dataclass(frozen=True)
class ShapeTracker:
  """Tracks a stack of Views over a buffer so movement ops don't require a copy.

  The last view in `views` defines the externally visible shape; earlier views
  describe how indices map back toward the underlying buffer.
  """
  views: Tuple[View, ...]

  def __add__(self, st:ShapeTracker) -> ShapeTracker:
    """Compose `st` on top of this tracker, simplifying after each appended view."""
    ret = self
    for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
    return ret

  def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
    """Return a tracker that undoes this one (ending at `out_shape`), or None if any view is not invertible."""
    inverted_views:List[View] = []
    # each view is inverted against the shape of the view that preceded it; the first view inverts against out_shape
    for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
      if (inverted:= v.invert(s)) is None: return None
      inverted_views.append(inverted)
    return ShapeTracker(tuple(inverted_views)).reshape(out_shape)

  @staticmethod
  def from_shape(shape:Tuple[sint, ...]):
    """Create a fresh single-view tracker for `shape`."""
    return ShapeTracker((View.create(shape),))

  @property
  def contiguous(self) -> bool:
    # contiguous means a single view that is itself contiguous
    return len(self.views) == 1 and self.views[0].contiguous

  @property
  def consecutive(self) -> bool:
    # a single unmasked view whose strides are the default strides for its shape
    return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)

  @property
  def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape

  @property
  def size(self) -> int: return self.views[-1].size()

  def real_size(self) -> int:
    """Return an upper bound on the number of buffer elements this tracker can address (max index + 1)."""
    if 0 in self.shape: return 0
    idx, valid = self.expr_idxs()
    if not valid: return 0  # valid is constant-false: nothing is ever addressed
    # TODO: it's possible that the real_size is smaller condition on valid being true
    ret = idx.max
    if not isinstance(ret, int): ret = ret.max # might be represent by symbolic shape, take one more max for int max
    assert isinstance(ret, int), f"ret must be integer, {ret=} isn't"
    return ret+1

  def vars(self) -> Set[Variable]:
    """Union of the symbolic Variables used by all views."""
    return set.union(*[v.vars() for v in self.views], set())

  @property
  def var_vals(self) -> Dict[Variable, int]:
    # unbind each Variable to get its (variable, value) pair, then merge into one dict
    return merge_dicts([dict([v.unbind()]) for v in self.vars()])

  def unbind(self) -> Tuple[ShapeTracker, Dict[Variable, int]]:
    """Return (tracker with unbound variables, dict of the variables' bound values)."""
    unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
    return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)

  # NOTE: if a stride is not always valid, it will be None
  def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
    """Per-axis effective strides; None where the stride isn't a constant (or depends on valid)."""
    if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
    idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
    idx, valid = self.expr_idxs(idxs)
    ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
    bad_idx_vars: Set[Variable] = set()
    # a SumNode of idx*stride terms yields one stride per recognized idx variable
    for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
      idx_maybe, stride_maybe = (this_dim.a, this_dim.b) if isinstance(this_dim, MulNode) else (this_dim, 1)
      try: ret[idxs.index(idx_maybe)] = cast(sint, stride_maybe)
      except ValueError: bad_idx_vars = bad_idx_vars.union(idx_maybe.vars())
    idx_vars, valid_vars = idx.vars(), valid.vars()
    for i,tidx in enumerate(idxs):
      if tidx in bad_idx_vars or (tidx in valid_vars and not ignore_valid): ret[i] = None
      elif tidx not in idx_vars: ret[i] = 0  # axis doesn't affect the index at all
    return tuple(ret)

  def unit_stride_axes(self, ignore_valid=False) -> List[int]:
    """Indices of axes whose real stride is exactly 1."""
    return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]

  def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
    """Build the symbolic (index, valid) expressions mapping output indices through all views."""
    idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
    idx, valid = self.views[-1].expr(idxs)
    for view in reversed(self.views[0:-1]):
      if valid.max == 0: return NumNode(-1), valid  # nothing can be valid, short-circuit
      view = view.minify()
      # decompose the flat idx into per-dimension indices for this view
      acc, idxs = 1, []
      for d in reversed(view.shape):
        idxs.append((idx//acc)%d)
        acc *= d
      idx, valid = view.expr(idxs[::-1], valid)
    # guard against indices that would overflow a 32-bit int
    assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
    assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
    return idx, valid

  def axis_is_masked(self, axis:int) -> bool:
    """True if `axis` participates in the valid (mask) expression."""
    _, valid = self.expr_idxs()
    return f'idx{axis}' in [v.expr for v in valid.vars()]

  def simplify(self) -> ShapeTracker:
    """Repeatedly merge the last two views while View.__add__ can combine them."""
    if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
      return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
    return self

  # *** under this line are the movement ops ***
  # each op applies to the last view only; reshape may append a new view

  def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
  def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
  def permute(self, axis: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
  def stride(self, mul: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))

  def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
    # try to merge the reshape into the last view (unless MERGE_VIEW=0); otherwise stack a new view
    if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
    return ShapeTracker(self.views + (View.create(new_shape), ))