test_uop_symbolic.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. #!/usr/bin/env python
  2. import unittest, pickle
  3. #from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
  4. # TODO: fix all the @unittest.expectedFailure
# *** fake symbolic uops ***
  6. from tinygrad.helpers import DEBUG
  7. from tinygrad.dtype import dtypes, PtrDType
  8. from tinygrad.codegen.uops import UOp, UOps
  9. from tinygrad.codegen.uopgraph import UOpGraph
  10. from tinygrad.ops import BinaryOps
  11. import functools
  12. def render(self) -> str:
  13. # NOTE: we need STORE so the ALU op has children
  14. glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=(0,True))
  15. graph = UOpGraph([UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self))])
  16. graph.linearize()
  17. if DEBUG>=5: graph.print()
  18. from tinygrad.renderer.cstyle import CStyleLanguage
  19. class TestRenderer(CStyleLanguage):
  20. code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
  21. fxn = TestRenderer().render("", graph)
  22. return fxn.split("data0[0] = ")[1].split(";")[0]
  23. def NumNode(val): return UOp.const(dtypes.int, val)
  24. def Variable(expr, nmin, nmax):
  25. # TODO: fix DEFINE_VAR to not need this
  26. class TempVar:
  27. def __init__(self, x): self.expr = x
  28. #return UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)), TempVar(expr))
  29. return UOp(UOps.DEFINE_VAR, dtypes.int, tuple(), TempVar(expr))
  30. class Node:
  31. @staticmethod
  32. def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
  33. @staticmethod
  34. def ands(ops): return functools.reduce(lambda x,y: x*y, ops)
  35. def __floordiv__(a,b,unk): return a//b
  36. def create_lt_node(v, n): return v.lt(n)
  37. def create_ge_node(v, n): return v.ge(n)
  38. def SumNode(x): return Node.sum(x)
  39. def MulNode(x, y): return x*y
# *** the tests below are left unchanged from the original symbolic-Node tests
  41. @unittest.skip("not supported on uops yet")
  42. class TestSymbolicPickle(unittest.TestCase):
  43. def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x)))
  44. def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8))
  45. def test_pickle_variable_times_2(self): self._test_pickle_unpickle(Variable("a", 3, 8)*2)
  46. class TestSymbolic(unittest.TestCase):
  47. def helper_test_variable(self, v, n, m, s):
  48. if isinstance(s, set):
  49. self.assertIn(render(v), s)
  50. else:
  51. self.assertEqual(render(v), s)
  52. #self.assertEqual(v.min, n)
  53. #self.assertEqual(v.max, m)
  54. def test_cmp_simple(self):
  55. self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
  56. self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "(7<a)", "(!(a<8))"})
  57. @unittest.expectedFailure
  58. def test_ge(self):
  59. self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "0")
  60. self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "0")
  61. self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a*-1)<-7)")
  62. self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, "((a*-1)<-3)")
  63. self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1")
  64. self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1")
  65. @unittest.expectedFailure
  66. def test_lt(self):
  67. self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "1")
  68. self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "1")
  69. self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 8), 0, 1, "(a<8)")
  70. self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
  71. self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 3), 0, 0, "0")
  72. self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 2), 0, 0, "0")
  73. @unittest.expectedFailure
  74. def test_ge_divides(self):
  75. expr = create_lt_node(Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)
  76. self.helper_test_variable(expr, 0, 1, "(idx<128)")
  77. @unittest.expectedFailure
  78. def test_ge_divides_and(self):
  79. expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
  80. create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)])
  81. self.helper_test_variable(expr, 0, 1, "((idx1<128) and (idx2<128))")
  82. expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
  83. create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7), 512)])
  84. self.helper_test_variable(expr//4, 0, 0, "0")
  85. def test_lt_factors(self):
  86. expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512)
  87. self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)")
  88. #def test_div_becomes_num(self):
  89. # assert isinstance(Variable("a", 2, 3)//2, NumNode)
  90. #def test_var_becomes_num(self):
  91. # assert isinstance(Variable("a", 2, 2), NumNode)
  92. @unittest.expectedFailure
  93. def test_equality(self):
  94. idx1 = Variable("idx1", 0, 3)
  95. idx2 = Variable("idx2", 0, 3)
  96. assert idx1 == idx1
  97. assert idx1 != idx2
  98. assert idx1*4 == idx1*4
  99. assert idx1*4 != idx1*3
  100. assert idx1*4 != idx1+4
  101. assert idx1*4 != idx2*4
  102. assert idx1+idx2 == idx1+idx2
  103. assert idx1+idx2 == idx2+idx1
  104. assert idx1+idx2 != idx2
  105. assert idx1*idx2 == idx2*idx1
  106. #def test_numnode_eq_int(self):
  107. # n1 = NumNode(1)
  108. # n2 = NumNode(2)
  109. # assert n1 == 1
  110. # assert n2 == 2
  111. # assert n1 != n2
  112. # assert hash(n1) == hash(1)
  113. # assert hash(n2) == hash(2)
  114. def test_factorize(self):
  115. a = Variable("a", 0, 8)
  116. self.helper_test_variable(a*2+a*3, 0, 8*5, "(a*5)")
  117. def test_factorize_no_mul(self):
  118. a = Variable("a", 0, 8)
  119. self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)")
  120. def test_neg(self):
  121. self.helper_test_variable(-Variable("a", 0, 8), -8, 0, {"(a*-1)", "(-a)"})
  122. def test_add_1(self):
  123. self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, {"(1+a)", "(a+1)"})
  124. def test_add_num_1(self):
  125. self.helper_test_variable(Variable("a", 0, 8)+NumNode(1), 1, 9, {"(1+a)", "(a+1)"})
  126. def test_sub_1(self):
  127. self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, {"(-1+a)", "(a+(-1))"})
  128. def test_sub_num_1(self):
  129. self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, {"(-1+a)", "(a+(-1))"})
  130. def test_mul_0(self):
  131. self.helper_test_variable(Variable("a", 0, 8)*0, 0, 0, "0")
  132. def test_mul_1(self):
  133. self.helper_test_variable(Variable("a", 0, 8)*1, 0, 8, "a")
  134. @unittest.expectedFailure
  135. def test_mul_neg_1(self):
  136. self.helper_test_variable((Variable("a", 0, 2)*-1)//3, -1, 0, "((((a*-1)+3)//3)+-1)")
  137. def test_mul_2(self):
  138. self.helper_test_variable(Variable("a", 0, 8)*2, 0, 16, "(a*2)")
  139. def test_div_1(self):
  140. self.helper_test_variable(Variable("a", 0, 8)//1, 0, 8, "a")
  141. def test_mod_1(self):
  142. self.helper_test_variable(Variable("a", 0, 8)%1, 0, 0, "0")
  143. def test_add_min_max(self):
  144. self.helper_test_variable(Variable("a", 0, 8) * 2 + 12, 12, 16+12, "((a*2)+12)")
  145. def test_div_min_max(self):
  146. self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)")
  147. @unittest.expectedFailure
  148. def test_div_neg_min_max(self):
  149. self.helper_test_variable(Variable("a", 0, 7) // -2, -4, 0, "((((a*-1)+8)//2)+-4)")
  150. self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "((((a*-1)+6)//2)+-3)")
  151. def test_sum_div_min_max(self):
  152. self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
  153. @unittest.expectedFailure
  154. def test_sum_div_factor(self):
  155. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
  156. @unittest.expectedFailure
  157. def test_sum_div_some_factor(self):
  158. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
  159. @unittest.expectedFailure
  160. def test_sum_div_some_partial_factor(self):
  161. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
  162. self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
  163. def test_sum_div_no_factor(self):
  164. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
  165. @unittest.expectedFailure
  166. def test_mod_factor(self):
  167. # NOTE: even though the mod max is 50, it can't know this without knowing about the mul
  168. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
  169. @unittest.expectedFailure
  170. def test_mod_to_sub(self):
  171. # This is mod reduction
  172. self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, (Variable("a",1,2)-1).render())
  173. @unittest.expectedFailure
  174. def test_sum_div_const(self):
  175. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a")
  176. @unittest.expectedFailure
  177. def test_sum_div_const_big(self):
  178. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)")
  179. @unittest.expectedFailure
  180. def test_sum_lt_fold(self):
  181. self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]), 16), 0, 1, "(a<4)")
  182. self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]), 16), 0, 1, "(((a*4)+b)<16)")
  183. self.helper_test_variable(create_lt_node(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]), (4 * 67)), 0, 1, "(a<23)")
  184. @unittest.expectedFailure
  185. def test_mod_mul(self):
  186. self.helper_test_variable((Variable("a", 0, 5)*10)%9, 0, 5, "a")
  187. @unittest.expectedFailure
  188. def test_mod_mod(self):
  189. self.helper_test_variable((Variable("a", 0, 31)%12)%4, 0, 3, "(a%4)")
  190. self.helper_test_variable(((4*Variable("a", 0, 31)) % 12) % 4, 0, 0, "0")
  191. self.helper_test_variable((Variable("a", 0, 31) % 4) % 12, 0, 3, "(a%4)")
  192. def test_mul_mul(self):
  193. self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
  194. @unittest.expectedFailure
  195. def test_mul_lt(self):
  196. self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
  197. self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
  198. self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, "((a*-1)<-2)")
  199. self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, "((a*-1)<-3)")
  200. def test_div_div(self):
  201. self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
  202. def test_distribute_mul(self):
  203. self.helper_test_variable(Node.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, {"((a*3)+(b*3))", "((a+b)*3)"})
  204. @unittest.expectedFailure
  205. def test_mod_mul_sum(self):
  206. self.helper_test_variable(Node.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)")
  207. def test_sum_0(self):
  208. self.helper_test_variable(Node.sum([Variable("a", 0, 7)]), 0, 7, "a")
  209. @unittest.expectedFailure
  210. def test_mod_remove(self):
  211. self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")
  212. def test_big_mod(self):
  213. # NOTE: we no longer support negative variables
  214. #self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)")
  215. #self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)")
  216. #self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)")
  217. self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
  218. #self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
  219. @unittest.expectedFailure
  220. def test_ge_remove(self):
  221. self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "0")
  222. @unittest.expectedFailure
  223. def test_lt_remove(self):
  224. self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "0")
  225. self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 3), 0, 1, "(a<3)")
  226. self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 8), 1, 1, "1")
  227. def test_lt_sum_remove(self):
  228. self.helper_test_variable(create_lt_node(Variable("a", 0, 6) + 2, 3), 0, 1, "(a<1)")
  229. def test_and_fold(self):
  230. self.helper_test_variable(Node.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0")
  231. def test_and_remove(self):
  232. self.helper_test_variable(Node.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a")
  233. @unittest.expectedFailure
  234. def test_mod_factor_negative(self):
  235. self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
  236. self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
  237. def test_sum_combine_num(self):
  238. self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, {"(6+a)", "(a+6)"})
  239. @unittest.expectedFailure
  240. def test_sum_num_hoisted_and_factors_cancel_out(self):
  241. self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
  242. @unittest.expectedFailure
  243. def test_div_factor(self):
  244. self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
  245. # TODO: this one should already work!
  246. def test_mul_div(self):
  247. self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
  248. @unittest.expectedFailure
  249. def test_mul_div_factor_mul(self):
  250. self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)")
  251. @unittest.expectedFailure
  252. def test_mul_div_factor_div(self):
  253. self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
  254. @unittest.expectedFailure
  255. def test_div_remove(self):
  256. self.helper_test_variable(Node.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
  257. @unittest.expectedFailure
  258. def test_div_numerator_negative(self):
  259. self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")
  260. @unittest.expectedFailure
  261. def test_div_into_mod(self):
  262. self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")
  263. @unittest.skip("not supported on uops yet")
  264. class TestSymbolicNumeric(unittest.TestCase):
  265. def helper_test_numeric(self, f):
  266. # TODO: why are the negative tests broken? (even if we did support negative variables)
  267. #MIN, MAX = -10, 10
  268. MIN, MAX = 0, 10
  269. # one number
  270. for i in range(MIN, MAX):
  271. v = f(NumNode(i))
  272. #print(i, f(i), v.min, v.max)
  273. self.assertEqual(v.min, v.max)
  274. self.assertEqual(v.min, f(i))
  275. for kmin in range(MIN, MAX):
  276. for kmax in range(MIN, MAX):
  277. if kmin > kmax: continue
  278. v = f(Variable("tmp", kmin, kmax))
  279. values = [f(rv) for rv in range(kmin, kmax+1)]
  280. # the min and max may not be exact
  281. self.assertLessEqual(v.min, min(values))
  282. self.assertGreaterEqual(v.max, max(values))
  283. def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4))
  284. def test_div_4(self): self.helper_test_numeric(lambda x: (x//4))
  285. def test_plus_1_div_2(self): self.helper_test_numeric(lambda x: (x+1)//2)
  286. def test_plus_1_mod_2(self): self.helper_test_numeric(lambda x: (x+1)%2)
  287. def test_times_2(self): self.helper_test_numeric(lambda x: x*2)
  288. def test_times_2_plus_3(self): self.helper_test_numeric(lambda x: x*2 + 3)
  289. def test_times_2_plus_3_mod_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)%4)
  290. def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4)
  291. def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
  292. class TestSymbolicVars(unittest.TestCase):
  293. def test_simple(self):
  294. z = NumNode(0)
  295. a = Variable("a", 0, 10)
  296. b = Variable("b", 0, 10)
  297. c = Variable("c", 0, 10)
  298. assert z.vars() == z.vars() == set()
  299. print(a.vars())
  300. assert a.vars() == a.vars() == {a}
  301. m = MulNode(a, 3)
  302. assert m.vars() == {a}
  303. s = SumNode([a, b, c])
  304. assert s.vars() == {a, b, c}
  305. @unittest.skip("TODO: fix me")
  306. def test_compound(self):
  307. a = Variable("a", 0, 10)
  308. b = Variable("b", 0, 10)
  309. c = Variable("c", 0, 10)
  310. assert (a + b * c).vars() == {a, b, c}
  311. assert (a % 3 + b // 5).vars() == {a, b}
  312. assert (a + b + c - a).vars() == {b, c}
  313. def test_dedup(self):
  314. a = Variable("a", 0, 10)
  315. assert (a * a).vars() == {a}
  316. assert (a//4 + a//6).vars() == {a}
  317. @unittest.skip("not supported on uops yet")
  318. class TestSymbolicMinMax(unittest.TestCase):
  319. def test_min_max_known(self):
  320. a = Variable("a", 1, 8)
  321. assert max(1, a) == max(a, 1) == a
  322. assert min(1, a) == min(a, 1) == 1
# NOTE: the tests below are disabled by wrapping them in a module-level string literal, so they never run.
# They rely on old symbolic-Node APIs (sym_render, sym_infer, LtNode, ModNode, .render(), .substitute(), ...)
# that are not provided by the UOp shims at the top of this file (the symbolic import there is commented out).
"""
@unittest.skip("not supported on uops yet")
class TestSymRender(unittest.TestCase):
  def test_sym_render(self):
    a = Variable("a", 1, 8)
    b = Variable("b", 1, 10)
    assert sym_render(a) == "a"
    assert sym_render(1) == "1"
    assert sym_render(a+1) == "(1+a)"
    assert sym_render(a*b) == "(a*b)"
@unittest.skip("not supported on uops yet")
class TestSymInfer(unittest.TestCase):
  def test_sym_infer(self):
    a = Variable("a", 0, 10)
    b = Variable("b", 0, 10)
    c = Variable("c", 0, 10)
    var_vals = {a: 2, b: 3, c: 4}
    assert sym_infer(5, var_vals) == 5
    assert sym_infer(a, var_vals) == 2
    assert sym_infer(b, var_vals) == 3
    assert sym_infer(a+b, var_vals) == 5
    assert sym_infer(a-b, var_vals) == -1
    assert sym_infer(a+b+c, var_vals) == 9
    assert sym_infer(a*b, var_vals) == 6
    assert sym_infer(a*b+c, var_vals) == 10
@unittest.skip("not supported on uops yet")
class TestSymbolicSymbolicOps(unittest.TestCase):
  def test_node_divmod_node(self):
    i = Variable("i", 1, 10)
    idx0 = Variable("idx0", 0, i*3-1)
    assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
    assert NumNode(0) % (Variable("i", 1, 10)*128) == 0
    assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
    assert NumNode(127) % (Variable("i", 1, 10)*128) == 127
    assert 127 // (Variable("i", 1, 10)*128) == 0
    assert 127 % (Variable("i", 1, 10)*128) == 127
    assert NumNode(128) // (Variable("i", 1, 10)*128 + 128) == 0
    assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128
    assert 128 // (Variable("i", 1, 10)*128 + 128) == 0
    assert 128 % (Variable("i", 1, 10)*128 + 128) == 128
    assert 0 // (Variable("i", 1, 10)*128) == 0
    assert 0 % (Variable("i", 1, 10)*128) == 0
    assert idx0 // (i*3) == 0
    assert idx0 % (i*3) == idx0
    assert i // i == 1
    assert i % i == 0
    assert 128 // NumNode(4) == 32
    assert 128 % NumNode(4) == 0
    assert NumNode(128) // NumNode(4) == 32
    assert NumNode(128) % NumNode(4) == 0
  def test_mulnode_divmod_node(self):
    i = Variable("i", 1, 10)
    idx0 = Variable("idx0", 0, 31)
    # assert (idx0*(i*4+4)) // (i+1) == (idx0*4)
    # assert (idx0*(i*4+4)) % (i+1) == 0
    assert (idx0*i) % i == 0
  def test_sumnode_divmod_sumnode(self):
    i = Variable("i", 1, 10)
    # idx0 = Variable("idx0", 0, 7)
    # idx1 = Variable("idx1", 0, 3)
    # idx2 = Variable("idx2", 0, i)
    # assert (idx0*(i*4+4)+idx1*(i+1)+idx2) // (i+1) == idx0*4+idx1
    # assert (idx0*(i*4+4)+idx1*(i+1)+idx2) % (i+1) == idx2
    assert (i+1) // (i*128+128) == 0
    assert (i+1) % (i*128+128) == (i+1)
    # assert (i+1+idx2) // (i+1) == 1
    # assert (i+1+idx2) % (i+1) == idx2
    # assert (idx0*(i*4+4)+i+1+idx2) // (i+1) == idx0*4+1
    # assert (idx0*(i*4+4)+i+1+idx2) % (i+1) == idx2
    # assert (i*128+128)*2 // (i*128+128) == 2
    # assert (i*128+128)*2 % (i*128+128) == 0
  def test_sumnode_div_numnode_no_factoring(self):
    gid = Variable("gid", 0, 1023)
    lid = Variable("lid", 0, 3)
    expr_before_div = NumNode(-1019)-4*lid-gid
    unfactored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), False)
    factored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), True)
    self.assertEqual(unfactored_expr.render(), "(((lid*4)+1019+gid)//16)")
    self.assertEqual(factored_expr.render(), "(((((3+gid)//4)+2+lid)//4)+63)")
  def test_mod_node_max(self):
    i = Variable("i", 1, 128)
    gidx0 = Variable("gidx0", 0, i)
    mod = gidx0 % 8
    assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8
    mod = gidx0 % 2
    assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2
    gidx0 = Variable("gidx0", 0, i*8+7)
    mod = gidx0 % 8
    assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8
    mod = gidx0 % 2
    assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2
  def test_node_lt_node(self):
    a = Variable("a", 1, 5)
    b = Variable("b", 6, 9)
    c = Variable("c", 1, 10)
    d = Variable("d", 5, 10)
    # if the comparison output is always the same, it folds to num
    assert create_lt_node(a, b) == NumNode(1)
    assert create_lt_node(b, a) == NumNode(0)
    assert create_lt_node(d, a) == NumNode(0)
    assert create_lt_node(a, a) == NumNode(0)
    assert create_lt_node(a, a) == NumNode(0)
    # if it remains as a LtNode, bool is always true and (min, max) == (0, 1)
    a_lt_c = create_lt_node(a, c)
    assert isinstance(a_lt_c, LtNode) and a_lt_c.min == 0 and a_lt_c.max == 1
    assert a_lt_c
    # same when comparing with a constant
    a_lt_3 = create_lt_node(a, 3)
    assert a_lt_3 and a_lt_3.min == 0 and a_lt_3.max == 1
  def test_sumnode_mulnode_lt(self):
    a = Variable("a", 1, 2)
    b = Variable("b", 1, 2)
    c = Variable("c", 1, 2)
    x = SumNode([MulNode(a, b), c])
    with self.assertRaises(AssertionError):
      create_lt_node(x, 3)
  def test_nested_variable_mod(self):
    i = Variable("i", 1, 5)
    idx0 = Variable("idx0", 0, i)
    with self.assertRaises(AssertionError):
      assert idx0 % 2 == idx0
  def test_num_node_mul_node(self):
    a = Variable("a", 1, 5)
    b = NumNode(2) * a
    assert b == a * 2
    assert isinstance(b, MulNode)
    b = NumNode(1) * a
    assert b == a
    assert isinstance(b, Variable)
    b = NumNode(0) * a
    assert b == 0
    assert isinstance(b, NumNode)
  def test_substitute(self):
    a = Variable("idx0", 1, 3)
    b = a + 1
    c = b.substitute({a: NumNode(1)})
    assert c == NumNode(2)
"""
  461. class TestSymbolicRealWorld(unittest.TestCase):
  462. @unittest.expectedFailure
  463. def test_resnet_half(self):
  464. gidx0 = Variable("gidx0", 0, 3)
  465. gidx1 = Variable("gidx1", 0, 127)
  466. gidx2 = Variable("gidx2", 0, 7)
  467. lidx3 = Variable("lidx3", 0, 7)
  468. lidx4 = Variable("lidx4", 0, 1)
  469. lidx5 = Variable("lidx5", 0, 15)
  470. idx = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
  471. print(idx.render())
  472. # NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
  473. assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)"
# run this file's tests with the unittest runner when it is executed directly
if __name__ == '__main__':
  unittest.main()