# shapetracker.py (5.5 KB)
# ShapeTracker allows movement operations on a buffer that don't require a copy to be made.
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, Iterable, cast

from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
from tinygrad.shape.view import View, strides_for_shape
  8. @dataclass(frozen=True)
  9. class ShapeTracker:
  10. views: Tuple[View, ...]
  11. def __add__(self, st:ShapeTracker) -> ShapeTracker:
  12. ret = self
  13. for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
  14. return ret
  15. def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
  16. inverted_views:List[View] = []
  17. for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
  18. if (inverted:= v.invert(s)) is None: return None
  19. inverted_views.append(inverted)
  20. return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
  21. @staticmethod
  22. def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
  23. @property
  24. def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
  25. @property
  26. def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)
  27. @property
  28. def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
  29. @property
  30. def size(self) -> int: return self.views[-1].size()
  31. def real_size(self) -> int:
  32. if 0 in self.shape: return 0
  33. idx, valid = self.expr_idxs()
  34. if not valid: return 0
  35. # TODO: it's possible that the real_size is smaller condition on valid being true
  36. ret = idx.max
  37. if not isinstance(ret, int): ret = ret.max # might be represent by symbolic shape, take one more max for int max
  38. assert isinstance(ret, int), f"ret must be integer, {ret=} isn't"
  39. return ret+1
  40. def vars(self) -> Set[Variable]: return set.union(*[v.vars() for v in self.views], set())
  41. @property
  42. def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
  43. def unbind(self) -> Tuple[ShapeTracker, Dict[Variable, int]]:
  44. unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
  45. return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
  46. # NOTE: if a stride is not always valid, it will be None
  47. def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
  48. if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
  49. idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
  50. idx, valid = self.expr_idxs(idxs)
  51. ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
  52. bad_idx_vars: Set[Variable] = set()
  53. for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
  54. idx_maybe, stride_maybe = (this_dim.a, this_dim.b) if isinstance(this_dim, MulNode) else (this_dim, 1)
  55. try: ret[idxs.index(idx_maybe)] = cast(sint, stride_maybe)
  56. except ValueError: bad_idx_vars = bad_idx_vars.union(idx_maybe.vars())
  57. idx_vars, valid_vars = idx.vars(), valid.vars()
  58. for i,tidx in enumerate(idxs):
  59. if tidx in bad_idx_vars or (tidx in valid_vars and not ignore_valid): ret[i] = None
  60. elif tidx not in idx_vars: ret[i] = 0
  61. return tuple(ret)
  62. def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
  63. def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
  64. idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
  65. idx, valid = self.views[-1].expr(idxs)
  66. for view in reversed(self.views[0:-1]):
  67. if valid.max == 0: return NumNode(-1), valid
  68. view = view.minify()
  69. acc, idxs = 1, []
  70. for d in reversed(view.shape):
  71. idxs.append((idx//acc)%d)
  72. acc *= d
  73. idx, valid = view.expr(idxs[::-1], valid)
  74. assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
  75. assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
  76. return idx, valid
  77. def axis_is_masked(self, axis:int) -> bool:
  78. _, valid = self.expr_idxs()
  79. return f'idx{axis}' in [v.expr for v in valid.vars()]
  80. def simplify(self) -> ShapeTracker:
  81. if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
  82. return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
  83. return self
  84. # *** under this line are the movement ops ***
  85. def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
  86. def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
  87. def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
  88. def permute(self, axis: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
  89. def stride(self, mul: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))
  90. def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
  91. if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
  92. return ShapeTracker(self.views + (View.create(new_shape), ))