symbolic.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. from __future__ import annotations
  2. import functools
  3. from math import gcd
  4. from tinygrad.helpers import partition
  5. from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set, Mapping
# NOTE: Python's floor division and modulo behave differently from C's for negative operands.
# symbolic follows the Python behavior, but the rendered code is agnostic: it never contains negative numbers in div or mod.
  8. class Node:
  9. b: Union[Node, int]
  10. min: int
  11. max: sint
  12. def render(self, ops=None, ctx=None) -> Any:
  13. if ops is None: ops = render_python
  14. assert self.__class__ in (Variable, NumNode) or self.min != self.max
  15. return ops[type(self)](self, ops, ctx)
  16. def vars(self) -> Set[Variable]: return set()
  17. # substitute Variables with the values in var_vals
  18. def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: raise RuntimeError(self.__class__.__name__)
  19. def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
  20. @functools.cached_property
  21. def key(self) -> str: return self.render(ctx="DEBUG")
  22. def __repr__(self): return self.render(ctx="REPR")
  23. def __str__(self): return "<"+self.key+">"
  24. def __hash__(self): return hash(self.key)
  25. def __bool__(self): return not (self.max == self.min == 0)
  26. def __eq__(self, other:object) -> bool:
  27. if not isinstance(other, Node): return NotImplemented
  28. return self.key == other.key
  29. def __neg__(self): return self*-1
  30. def __add__(self, b:Union[Node,int]): return Node.sum([self, NumNode(b) if isinstance(b, int) else b])
  31. def __radd__(self, b:int): return self+b
  32. def __sub__(self, b:Union[Node,int]): return self+-b
  33. def __rsub__(self, b:int): return -self+b
  34. def __le__(self, b:Union[Node,int]): return self < (b+1)
  35. def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
  36. def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
  37. def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b))
  38. def __mul__(self, b:Union[Node, int]):
  39. if b == 0: return NumNode(0)
  40. if b == 1: return self
  41. return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
  42. def __rmul__(self, b:int): return self*b
  43. def __lshift__(self, b:int): return self*2**b
  44. # *** complex ops ***
  45. def __rfloordiv__(self, b:int): return NumNode(b) // self
  46. def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
  47. if isinstance(b, Node):
  48. if b.__class__ is NumNode: return self.__floordiv__(b.b, factoring_allowed)
  49. if self == b: return NumNode(1)
  50. if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
  51. raise RuntimeError(f"not supported: {self} // {b}")
  52. assert b != 0
  53. if b < 0: return (self*-1).__floordiv__(-b, factoring_allowed)
  54. if b == 1: return self
  55. # the numerator of div is not allowed to be negative
  56. if self.min < 0:
  57. offset = self.min//b
  58. # factor out an "offset" to make the numerator positive. don't allowing factoring again
  59. return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
  60. return create_node(DivNode(self, b))
  61. def __rmod__(self, b:int): return NumNode(b) % self
  62. def __mod__(self, b:Union[Node,int]):
  63. if isinstance(b, Node):
  64. if b.__class__ is NumNode: return self % b.b
  65. if self == b: return NumNode(0)
  66. if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node
  67. raise RuntimeError(f"not supported: {self} % {b}")
  68. assert b > 0
  69. if b == 1: return NumNode(0)
  70. if isinstance(self.max, int) and isinstance(self.min, int):
  71. if self.min >= 0 and self.max < b: return self
  72. if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
  73. if self.min < 0: return (self - ((self.min//b)*b)) % b
  74. return create_node(ModNode(self, b))
  75. @staticmethod
  76. def sum(nodes:List[Node]) -> Node:
  77. nodes = [x for x in nodes if x.max or x.min]
  78. if not nodes: return NumNode(0)
  79. if len(nodes) == 1: return nodes[0]
  80. mul_groups: Dict[Node, int] = {}
  81. num_node_sum = 0
  82. for node in SumNode(nodes).flat_components:
  83. if node.__class__ is NumNode: num_node_sum += node.b
  84. elif node.__class__ is MulNode: mul_groups[node.a] = mul_groups.get(node.a, 0) + node.b
  85. else: mul_groups[node] = mul_groups.get(node, 0) + 1
  86. new_nodes = [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
  87. if num_node_sum: new_nodes.append(NumNode(num_node_sum))
  88. return create_node(SumNode(new_nodes)) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
  89. @staticmethod
  90. def ands(nodes:List[Node]) -> Node:
  91. if not nodes: return NumNode(1)
  92. if len(nodes) == 1: return nodes[0]
  93. if any(not x for x in nodes): return NumNode(0)
  94. # filter 1s
  95. nodes = [x for x in nodes if x.min != x.max]
  96. return create_node(AndNode(nodes)) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
  97. # 4 basic node types
  98. class Variable(Node):
  99. def __new__(cls, *args):
  100. expr, nmin, nmax = args
  101. assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
  102. if nmin == nmax: return NumNode(nmin)
  103. return super().__new__(cls)
  104. def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling
  105. def __init__(self, expr:str, nmin:int, nmax:sint):
  106. self.expr, self.min, self.max = expr, nmin, nmax
  107. self._val: Optional[int] = None
  108. @property
  109. def val(self):
  110. assert self._val is not None, f"Variable isn't bound, can't access val of {self}"
  111. return self._val
  112. def bind(self, val):
  113. assert self._val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
  114. self._val = val
  115. return self
  116. def unbind(self) -> Tuple[Variable, int]:
  117. assert self.val is not None, f"cannot unbind {self}"
  118. return Variable(self.expr, self.min, self.max), self.val
  119. def vars(self): return {self}
  120. def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return var_vals.get(self, self)
  121. class NumNode(Node):
  122. def __init__(self, num:int):
  123. assert isinstance(num, int), f"{num} is not an int"
  124. self.b:int = num
  125. self.min, self.max = num, num
  126. def bind(self, val):
  127. assert self.b == val, f"cannot bind {val} to {self}"
  128. return self
  129. def __mul__(self, b:Union[Node,int]): return NumNode(self.b*b) if isinstance(b, int) else b*self.b
  130. def __eq__(self, other): return self.b == other
  131. def __hash__(self): return hash(self.b) # needed with __eq__ override
  132. def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self
  133. def create_node(ret:Node):
  134. assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
  135. if ret.min == ret.max: return NumNode(ret.min)
  136. return ret
  137. def create_lt_node(lhs:Node, b:Union[Node, int]):
  138. if isinstance(lhs, SumNode):
  139. if isinstance(b, int):
  140. new_sum = []
  141. for x in lhs.nodes:
  142. # TODO: should we just force the last one to always be the number
  143. if isinstance(x, NumNode): b -= x.b
  144. else: new_sum.append(x)
  145. lhs = Node.sum(new_sum)
  146. nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
  147. assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
  148. muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
  149. if muls:
  150. # NOTE: gcd in python 3.8 takes exactly 2 args
  151. mul_gcd = b
  152. for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
  153. all_others = Node.sum(others)
  154. if all_others.min >= 0 and all_others.max < mul_gcd:
  155. lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
  156. return create_node(LtNode(lhs, b)) if isinstance(lhs, SumNode) else create_lt_node(lhs, b)
  157. if isinstance(lhs, MulNode):
  158. if isinstance(b, Node) or isinstance(lhs.b, Node) or lhs.b == -1: return create_node(LtNode(lhs, b))
  159. sgn = 1 if lhs.b > 0 else -1
  160. return create_node(LtNode(lhs.a*sgn, (b + abs(lhs.b) - 1)//abs(lhs.b)))
  161. return create_node(LtNode(lhs, b))
  162. def create_ge_node(lhs:Node, b:Union[Node, int]): return create_lt_node(-lhs, -b+1)
  163. class OpNode(Node):
  164. def __init__(self, a:Node, b:Union[Node, int]):
  165. self.a, self.b = a, b
  166. self.min, self.max = self.get_bounds()
  167. def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())
  168. def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
  169. class LtNode(OpNode):
  170. def get_bounds(self) -> Tuple[int, int]:
  171. if self.a == self.b: return (0, 0)
  172. if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
  173. return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
  174. def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
  175. return create_lt_node(self.a.substitute(var_vals), (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)))
  176. class MulNode(OpNode):
  177. def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
  178. def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
  179. if self.b % b == 0: return self.a*(self.b//b)
  180. if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
  181. return Node.__floordiv__(self, b, factoring_allowed)
  182. def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
  183. def get_bounds(self) -> Tuple[int, sint]:
  184. assert self.a.min >= 0
  185. if isinstance(self.b, int): return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
  186. return (self.a.min*self.b.min, self.a.max*self.b.max) if self.b.min >= 0 else (self.a.max*self.b.min, self.a.min*self.b.max)
  187. def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
  188. return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
  189. class DivNode(OpNode):
  190. def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
  191. def get_bounds(self) -> Tuple[int, sint]:
  192. assert self.a.min >= 0 and isinstance(self.b, int)
  193. return self.a.min//self.b, self.a.max//self.b
  194. def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) // self.b
  195. class ModNode(OpNode):
  196. def __mod__(self, b: Union[Node, int]):
  197. if isinstance(b, int) and isinstance(self.b, int) and self.b % b == 0: return self.a % b
  198. return Node.__mod__(self, b)
  199. def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
  200. return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
  201. def get_bounds(self) -> Tuple[int, sint]:
  202. assert self.a.min >= 0 and isinstance(self.b, int)
  203. if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b): return (0, self.b-1)
  204. return (self.a.min%self.b, self.a.max%self.b)
  205. def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) % self.b
  206. class RedNode(Node):
  207. def __init__(self, nodes:List[Node]):
  208. self.nodes = nodes
  209. self.min, self.max = self.get_bounds()
  210. def vars(self) -> Set[Variable]: return set.union(*[x.vars() for x in self.nodes], set())
  211. def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
class SumNode(RedNode):
    """Sum of the child nodes in ``self.nodes``."""

    # bounds of a sum are the sums of the children's bounds
    def get_bounds(self) -> Tuple[int, sint]: return sum([x.min for x in self.nodes]), sum([x.max for x in self.nodes])

    # NOTE(review): lru_cache on a method keeps every (self, b) pair it has seen alive; presumably intentional memoization - confirm
    @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
    def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum

    @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
    def __floordiv__(self, b: Union[Node, sint], factoring_allowed=True):
        """Floor-divide the sum by ``b``, dividing individual terms exactly where possible."""
        if self == b: return NumNode(1)
        fully_divided: List[Node] = []  # terms already divided exactly by b
        rest: List[Node] = []  # terms that still need the division applied
        if isinstance(b, Node):
            for x in self.flat_components:
                if x % b == 0: fully_divided.append(x // b)
                else: rest.append(x)
            # only split the division when at least one term divided exactly
            if (sum_fully_divided:=create_node(SumNode(fully_divided))) != 0: return sum_fully_divided + create_node(SumNode(rest)) // b
            return Node.__floordiv__(self, b, False)
        if b == 1: return self
        if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
        _gcd = b  # running gcd of b and the int factors of the remaining terms
        divisor = 1  # first MulNode factor (among the remaining terms) that divides b
        for x in self.flat_components:
            if x.__class__ in (NumNode, MulNode):
                if x.b % b == 0: fully_divided.append(x // b)
                else:
                    # split a constant into its multiple-of-b part and the remainder
                    if x.__class__ is NumNode and (div := x.b // b):
                        fully_divided.append(NumNode(div))
                        x = NumNode(x.b - b * div)
                    rest.append(x)
                    if isinstance(x.b, int):
                        _gcd = gcd(_gcd, x.b)
                        if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
                    else:
                        _gcd = 1
            else:
                # any other node type prevents factoring out a common gcd
                rest.append(x)
                _gcd = 1
        # divide in two steps: first by the common factor, then by what is left of b
        if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd)
        if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
        return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)

    @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
    def __mod__(self, b: Union[Node, int]):
        """Take the sum modulo ``b``, first reducing each constant and MulNode term modulo ``b``."""
        if self == b: return NumNode(0)
        if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
        new_sum = Node.sum([node%b if node.__class__ in (NumNode, MulNode) else node for node in self.nodes])
        return Node.__mod__(new_sum, b)

    def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
        # substitute in every child and re-simplify the resulting sum
        return Node.sum([node.substitute(var_vals) for node in self.nodes])

    # recursively expand sumnode components
    # TODO: can remove this if there's no SumNode inside SumNode
    @property
    def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]
  262. class AndNode(RedNode):
  263. def get_bounds(self) -> Tuple[int, sint]: return min([x.min for x in self.nodes]), max([x.max for x in self.nodes])
  264. def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
  265. subed = []
  266. for node in self.nodes:
  267. if not (sub:=node.substitute(var_vals)): return NumNode(0)
  268. subed.append(sub)
  269. return Node.ands(subed)
  270. def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
  271. def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
  272. if isinstance(a, (int, float)): return a
  273. ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()}) if var_vals is not None else a
  274. assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
  275. return ret.b
# symbolic int, these are allowed in a Tensor shape
# (a plain int, or one of the Variable / MulNode / SumNode classes defined in this file)
sint = Union[int, Variable, MulNode, SumNode]
  278. def render_mulnode(node:MulNode, ops, ctx):
  279. # TODO: add ProdNode and remove this case
  280. if isinstance(node.a,Variable) and isinstance(node.b,Variable) and node.a.expr and node.b.expr and node.b.expr < node.a.expr:
  281. return f"({sym_render(node.b,ops,ctx)}*{node.a.render(ops,ctx)})"
  282. return f"({node.a.render(ops,ctx)}*{sym_render(node.b,ops,ctx)})"
  283. render_python: Dict[Type, Callable[..., str]] = {
  284. Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" \
  285. else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" \
  286. else f"{self.expr}"),
  287. NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
  288. MulNode: render_mulnode,
  289. DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
  290. ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
  291. LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
  292. SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
  293. AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
  294. }