view.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. from __future__ import annotations
  2. import functools, operator, itertools, math
  3. from dataclasses import dataclass
  4. from typing import Tuple, List, Optional, Dict, Set, cast
  5. from tinygrad.helpers import prod, all_int, argsort
  6. from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer, create_lt_node, create_ge_node
  7. @functools.lru_cache(maxsize=None)
  8. def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
  9. return tuple(0 if s == 1 else st for s, st in zip(shape, strides))
  10. @functools.lru_cache(maxsize=None)
  11. def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
  12. if not shape: return ()
  13. strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1]
  14. return canonicalize_strides(shape, strides)
@functools.lru_cache(maxsize=None)
def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
  """Merge contiguous sub-parts or zero-strided dims of (shape, strides).

  Returns a tuple of (merged_size, stride, merged size w/o zero stride) triples.
  The third element tracks the merged extent excluding zero-strided (expanded) parts;
  it is 0 whenever the merged stride is 0.
  """
  if not shape: return ()
  assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask))
  # seed with the first dim; pre-expand size is 0 when its stride is 0
  ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)]
  # merge this dim to next dim if size is 1 (with a mask, "effective size 1" means the mask spans a single index)
  merging = (mask[0][1] - mask[0][0] == 1) if mask is not None else shape[0] == 1
  for i, (s, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
    last_s, last_st, last_pre_expand_s = ret[-1]
    # always merge 1
    if s == 1: continue
    # merge last dim with this dim if merging or strides matched (last_st == s * st is the contiguity condition)
    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st else 0)
    else: ret.append((s, st, s if st else 0))
    # merge this dim to next dim if size is 1
    merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
  return tuple(ret)
@functools.lru_cache(maxsize=None)
def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) \
    -> Optional[Tuple[Tuple[sint, sint], ...]]:
  """Returns the new mask if reshape is possible, and None if not possible.

  Walks old and new shapes right-to-left, splitting an old-dim mask across several
  new dims or combining several old-dim masks into one new dim. Symbolic (non-int)
  mask bounds are rejected; an empty mask collapses to an all-zero mask.
  """
  if _mask is None: return tuple((0, s) for s in new_shape)
  if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in _mask): return None
  if any(m[1] - m[0] < 1 for m in _mask): return ((0, 0),) * len(new_shape)  # zero mask
  new_mask: List[Tuple[int, int]] = []
  # _mask is all int here
  r_masks, r_shape, r_new_shape = reversed(cast(Tuple[Tuple[int, int], ...], _mask)), reversed(old_shape), reversed(new_shape)
  curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
  while len(new_mask) < len(new_shape):
    (l, r), next_stride = mask, new_dim * curr_stride
    if old_dim >= next_stride:  # need to split mask.
      if old_dim == next_stride:  # simply copy the mask and get next batch for merging
        new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
        curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
      else:  # mask can only be splitted if reshape doesn't cut across the mask.
        if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
            or old_dim % next_stride != 0): return None
        new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
        curr_stride, new_dim = next_stride, next(r_new_shape, 1)  # need to get mask for next dimension
    else:
      next_mask = next(r_masks, (0, 1))
      # combine if the mask can unfold continuously
      if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return None
      mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
  for mask in r_masks:  # if the old shape has leading 1s, need to make sure their mask is (0,1)
    if mask != (0, 1): return ((0, 0),) * len(new_shape)  # invalid mask
  return tuple(reversed(new_mask))
  63. def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
  64. strides = strides_for_shape(shape)
  65. result = []
  66. for stride in strides:
  67. here = offs // stride if stride else 0
  68. result.append(here)
  69. offs -= here * stride
  70. return result
@dataclass(frozen=True)
class View:
  """An immutable strided view over a linear buffer: shape + strides + offset, with an optional valid-region mask."""
  shape:Tuple[sint, ...]
  strides:Tuple[sint, ...]
  offset:sint
  # mask[i] = (begin, end): indices outside [begin, end) in dim i are invalid (padded region)
  mask:Optional[Tuple[Tuple[sint, sint], ...]]
  # True iff offset == 0, mask is None, and strides are the default row-major strides
  contiguous:bool

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def size(self) -> int:
    """Number of elements spanned by the shape; symbolic dims count as their max."""
    # NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
    ret = prod([x.max if isinstance(x, Node) else x for x in self.shape])
    assert isinstance(ret, int), f"{ret=} is not int"
    return ret

  @staticmethod
  @functools.lru_cache(maxsize=None)
  def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None) -> View:
    """Canonicalizing constructor: normalizes strides/mask/offset and computes `contiguous`."""
    strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
    # canonicalize 0 in shape
    if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
    # canonicalize empty mask
    if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
    # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
    # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
    # TODO: assert comparison with LtNode to avoid mis-using symbolic
    if mask and any(elim := [not (b+1 < e) for b,e in mask]):
      if any(not (b < e) for b,e in mask):
        # fully-masked view: everything is invalid
        strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
      # fold the single valid index of each eliminated dim into the offset, then zero its stride
      offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
      strides = tuple(0 if e else st for st,e in zip(strides, elim))
    contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
    return View(shape, strides, offset, mask, contiguous)

  @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
  def vars(self) -> Set[Variable]:
    """All symbolic Variables appearing in shape, strides, offset, or mask."""
    flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
    return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set())

  @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
  def unbind(self) -> Tuple[View, Dict[Variable, int]]:
    """Replace bound Variables with their unbound forms; returns the new View and the {var: value} map."""
    var_unboundvar_val = [(v, v.unbind()) for v in self.vars()]
    unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
    def substitute(x): return x if isinstance(x, int) else x.substitute(unbound_vars)
    new_shape = tuple(map(substitute, self.shape))
    new_strides = tuple(map(substitute, self.strides))
    new_offset = substitute(self.offset)
    new_mask = tuple((substitute(x[0]), substitute(x[1])) for x in self.mask) if self.mask is not None else None
    return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def __add__(self, vm1:View) -> Optional[View]:
    """Compose vm1 (outer view) on top of self (inner view); returns the merged View or None if not mergeable."""
    vm2 = self
    if vm2.contiguous: return vm1
    if vm1.contiguous and vm1.shape == vm2.shape: return vm2
    if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
    if vm1.mask:
      for b,e in vm1.mask:
        # an empty mask range means nothing is valid: collapse to an all-masked view
        if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
      # merge the unmasked core, then re-apply vm1's mask as padding
      return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))

    # Project vm1's offset and strides on to vm2.
    origin = un1d(vm2.shape, vm1.offset)
    terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
    strides: List[sint] = [0] * len(vm1.shape)
    for d1, st in enumerate(vm1.strides):
      if st == 0: continue
      for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
        # s1 becomes the per-step movement of vm1 dim d1 within vm2 dim d2
        if (s1 := s1 - o) == 0: continue
        terms[d2].append((d1, s1))
        strides[d1] += s1 * vm2.strides[d2]

    # Merge dimensions in vm2 if required.
    # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
    idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
    merged_size, merged_term = 1, NumNode(0)
    extents: List[Tuple[sint, Node]] = []
    for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
      merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
      merged_size *= s
      # close out an extent once the accumulated index expression provably stays in [0, merged_size)
      if not (merged_term >= merged_size) and not (merged_term < 0):
        extents.append((merged_size, merged_term))
        merged_size, merged_term = 1, NumNode(0)
    if merged_term: return None  # leftover term couldn't be bounded: views don't merge
    if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
      # retry after reshaping vm2 to the merged extents
      return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1

    if vm2.mask:
      # Try to project vm2's mask on to vm1.
      newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
      for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
        if not (t.min < b or t.max >= e): continue  # this dim's indices never leave the valid range
        if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
          bad = True
          continue
        term = terms[d2]
        if len(term) != 1:
          # mask depends on 0 or >1 vm1 dims; only the trivial "nothing valid" case is projectable
          if not term and newe: newe[0] = 0
          else: bad = True
          continue
        d1, s1 = term[0]
        if not isinstance(s1, int) or not isinstance(newe[d1], int):
          bad = True
          continue
        # tighten vm1's bounds in dim d1 so vm2's [b, e) is respected (sign of s1 picks the direction)
        newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
        newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)

      # If any of vm1 was masked off, try again with that mask in place.
      for b, e, s in zip(newb, newe, vm1.shape):
        if b != 0 or e != s:
          return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
      # Otherwise if vm2's mask was violated, then cannot merge.
      if bad: return None

    return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
    """Inverse view of a permute/flip composition, or None if self isn't a pure permute/flip of out_shape."""
    ret = View.create(self.shape)
    if self.mask: ret = ret.shrink(self.mask)
    ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
    return ret if prod(ret.shape) == prod(out_shape) else None  # don't support shrink, expand, or stride != (-1, 1)

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def minify(self) -> View:
    """Reshape to the minimal merged shape (via _merge_dims) when possible, else return self."""
    min_shape = tuple(x[0] for x in _merge_dims(self.shape, self.strides, self.mask))
    return nv if (nv := self.reshape(min_shape)) else self

  def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
    """Resize each dim to the [begin, end) window in `arg` without bounds checking; shared core of pad/shrink."""
    offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
    if self.mask:
      # move the old mask
      nmask = tuple([(max(0, min(mx-ax,ay-ax)), max(0, min(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
      # merge the masks if we have two
      mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
    shape = [y-x for x,y in arg]
    if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
    return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
    """Pad each dim by (before, after) invalid elements; the mask marks the original data as the valid region."""
    assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}"
    if any(b or e for b, e in arg):
      # negative begin in zvarg widens the window; the mask keeps only the original [b, s+b) span valid
      zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
      mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
      return self.__unsafe_resize(zvarg, mask=mask)
    return self

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
    """Restrict each dim to the in-bounds window [begin, end)."""
    assert all((0<=b<=e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
    return self.__unsafe_resize(arg)

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def expand(self, new_shape: Tuple[sint, ...]) -> View:
    """Broadcast size-1 (stride-0) dims up to new_shape."""
    if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
    if 0 in self.shape:
      assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
      return View.create(new_shape)
    assert all((s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
    # NOTE: can the mask ever be (0,0)?
    mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
    return View.create(new_shape, self.strides, self.offset, mask)

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def permute(self, axis: Tuple[int, ...]) -> View:
    """Reorder dimensions by `axis` (a permutation of range(len(shape)))."""
    assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}"
    return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
                       tuple(self.mask[a] for a in axis) if self.mask is not None else None)

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def stride(self, mul: Tuple[int, ...]) -> View:
    """Take every mul[i]-th element in dim i; negative values flip the dim."""
    # except for the negative case, you can build this from the others. invertible in the negative case
    assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
    strides = tuple([z*m for z,m in zip(self.strides, mul)])
    new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
    # flipped dims start from the far end, hence the (s-1)*z offset contribution
    offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) \
                  for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
    return View.create(new_shape, strides, self.offset + offset, mask)

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
    """Reshape to new_shape without copying; returns None when the strided/masked layout can't express it."""
    if self.shape == new_shape: return self

    assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}"
    if 0 in self.shape:
      assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
      return View.create(new_shape)
    # check for the same size
    if (self_all_int := all_int(self.shape)):
      assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
      if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
        raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")

    if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None

    # after the asserts, it's okay to check contiguous
    if self.contiguous: return View.create(new_shape)

    # if it's not contiguous and new shape is symbolic, check if it's directly replaceable
    if self_all_int and not all_int(new_shape):
      if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
      for si, so in zip(self.shape, new_shape):
        if isinstance(so, int):
          if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
        else:
          var_vals = {v: v.unbind()[1] for v in so.vars()}
          if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
      # all dimensions matched, return the new view directly
      return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)

    # distribute new dims over each merged (contiguous) chunk of the old layout, right-to-left
    strides, r_new_shape = [], reversed(new_shape)
    for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
      acc = 1
      # TODO: this <= and != is for symbolic!?
      while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, None)):
        strides.append(new_stride)
        if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
      if acc != merged_dim: break  # new dims don't tile this merged chunk: reshape not expressible
    else:
      strides += [0,] * (len(new_shape) - len(strides))
      new_mask = _reshape_mask(self.mask, self.shape, new_shape)
      if new_mask is not None:
        new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask), tuple(reversed(strides)))
        # keep the absolute position of the first valid element unchanged across the mask rewrite
        extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
                       (sum(m[0] * s for m,s in zip(new_mask, new_strides)))
        return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)

    return None

  def expr(self, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
    """Build the (index, valid) symbolic expressions for this view given one idx Node per dimension."""
    assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
    iexpr: List[Node] = [NumNode(self.offset) if isinstance(self.offset, int) else self.offset]
    vexpr: List[Node] = [valid] if valid is not None else []
    for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
      if sh != 1 and st != 0: iexpr.append(idx*st)
      if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])]  # idx >= m[0], idx < m[1]
    return Node.sum(iexpr), Node.ands(vexpr)