test_symbolic.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. #!/usr/bin/env python
  2. import unittest, pickle
  3. from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
  4. class TestSymbolicPickle(unittest.TestCase):
  5. def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x)))
  6. def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8))
  7. def test_pickle_variable_times_2(self): self._test_pickle_unpickle(Variable("a", 3, 8)*2)
  8. class TestSymbolic(unittest.TestCase):
  9. def helper_test_variable(self, v, n, m, s):
  10. self.assertEqual(v.render(), s)
  11. self.assertEqual(v.min, n)
  12. self.assertEqual(v.max, m)
  13. def test_ge(self):
  14. self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "0")
  15. self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "0")
  16. self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a*-1)<-7)")
  17. self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, "((a*-1)<-3)")
  18. self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1")
  19. self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1")
  20. def test_lt(self):
  21. self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "1")
  22. self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "1")
  23. self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 8), 0, 1, "(a<8)")
  24. self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
  25. self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 3), 0, 0, "0")
  26. self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 2), 0, 0, "0")
  27. def test_ge_divides(self):
  28. expr = create_lt_node(Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)
  29. self.helper_test_variable(expr, 0, 1, "(idx<128)")
  30. def test_ge_divides_and(self):
  31. expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
  32. create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)])
  33. self.helper_test_variable(expr, 0, 1, "((idx1<128) and (idx2<128))")
  34. expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
  35. create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7), 512)])
  36. self.helper_test_variable(expr//4, 0, 0, "0")
  37. def test_lt_factors(self):
  38. expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512)
  39. self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)")
  40. def test_div_becomes_num(self):
  41. assert isinstance(Variable("a", 2, 3)//2, NumNode)
  42. def test_var_becomes_num(self):
  43. assert isinstance(Variable("a", 2, 2), NumNode)
  44. def test_equality(self):
  45. idx1 = Variable("idx1", 0, 3)
  46. idx2 = Variable("idx2", 0, 3)
  47. assert idx1 == idx1
  48. assert idx1 != idx2
  49. assert idx1*4 == idx1*4
  50. assert idx1*4 != idx1*3
  51. assert idx1*4 != idx1+4
  52. assert idx1*4 != idx2*4
  53. assert idx1+idx2 == idx1+idx2
  54. assert idx1+idx2 == idx2+idx1
  55. assert idx1+idx2 != idx2
  56. assert idx1*idx2 == idx2*idx1
  57. def test_numnode_eq_int(self):
  58. n1 = NumNode(1)
  59. n2 = NumNode(2)
  60. assert n1 == 1
  61. assert n2 == 2
  62. assert n1 != n2
  63. assert hash(n1) == hash(1)
  64. assert hash(n2) == hash(2)
  65. def test_factorize(self):
  66. a = Variable("a", 0, 8)
  67. self.helper_test_variable(a*2+a*3, 0, 8*5, "(a*5)")
  68. def test_factorize_no_mul(self):
  69. a = Variable("a", 0, 8)
  70. self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)")
  71. def test_neg(self):
  72. self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)")
  73. def test_add_1(self):
  74. self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(1+a)")
  75. def test_add_num_1(self):
  76. self.helper_test_variable(Variable("a", 0, 8)+NumNode(1), 1, 9, "(1+a)")
  77. def test_sub_1(self):
  78. self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(-1+a)")
  79. def test_sub_num_1(self):
  80. self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, "(-1+a)")
  81. def test_mul_0(self):
  82. self.helper_test_variable(Variable("a", 0, 8)*0, 0, 0, "0")
  83. def test_mul_1(self):
  84. self.helper_test_variable(Variable("a", 0, 8)*1, 0, 8, "a")
  85. def test_mul_neg_1(self):
  86. self.helper_test_variable((Variable("a", 0, 2)*-1)//3, -1, 0, "((((a*-1)+3)//3)+-1)")
  87. def test_mul_2(self):
  88. self.helper_test_variable(Variable("a", 0, 8)*2, 0, 16, "(a*2)")
  89. def test_div_1(self):
  90. self.helper_test_variable(Variable("a", 0, 8)//1, 0, 8, "a")
  91. def test_mod_1(self):
  92. self.helper_test_variable(Variable("a", 0, 8)%1, 0, 0, "0")
  93. def test_add_min_max(self):
  94. self.helper_test_variable(Variable("a", 0, 8) * 2 + 12, 12, 16+12, "((a*2)+12)")
  95. def test_div_min_max(self):
  96. self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)")
  97. def test_div_neg_min_max(self):
  98. self.helper_test_variable(Variable("a", 0, 7) // -2, -4, 0, "((((a*-1)+8)//2)+-4)")
  99. self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "((((a*-1)+6)//2)+-3)")
  100. def test_sum_div_min_max(self):
  101. self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
  102. def test_sum_div_factor(self):
  103. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
  104. def test_sum_div_some_factor(self):
  105. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
  106. def test_sum_div_some_partial_factor(self):
  107. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
  108. self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
  109. def test_sum_div_no_factor(self):
  110. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
  111. def test_mod_factor(self):
  112. # NOTE: even though the mod max is 50, it can't know this without knowing about the mul
  113. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
  114. def test_mod_to_sub(self):
  115. # This is mod reduction
  116. self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, (Variable("a",1,2)-1).render())
  117. def test_sum_div_const(self):
  118. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a")
  119. def test_sum_div_const_big(self):
  120. self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)")
  121. def test_sum_lt_fold(self):
  122. self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]), 16), 0, 1, "(a<4)")
  123. self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]), 16), 0, 1, "(((a*4)+b)<16)")
  124. self.helper_test_variable(create_lt_node(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]), (4 * 67)), 0, 1, "(a<23)")
  125. def test_mod_mul(self):
  126. self.helper_test_variable((Variable("a", 0, 5)*10)%9, 0, 5, "a")
  127. def test_mod_mod(self):
  128. self.helper_test_variable((Variable("a", 0, 31)%12)%4, 0, 3, "(a%4)")
  129. self.helper_test_variable(((4*Variable("a", 0, 31)) % 12) % 4, 0, 0, "0")
  130. self.helper_test_variable((Variable("a", 0, 31) % 4) % 12, 0, 3, "(a%4)")
  131. def test_mul_mul(self):
  132. self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
  133. def test_mul_lt(self):
  134. self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
  135. self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
  136. self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, "((a*-1)<-2)")
  137. self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, "((a*-1)<-3)")
  138. def test_div_div(self):
  139. self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
  140. def test_distribute_mul(self):
  141. self.helper_test_variable(Node.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
  142. def test_mod_mul_sum(self):
  143. self.helper_test_variable(Node.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)")
  144. def test_sum_0(self):
  145. self.helper_test_variable(Node.sum([Variable("a", 0, 7)]), 0, 7, "a")
  146. def test_mod_remove(self):
  147. self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")
  148. def test_big_mod(self):
  149. # NOTE: we no longer support negative variables
  150. #self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)")
  151. #self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)")
  152. #self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)")
  153. self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
  154. #self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
  155. def test_ge_remove(self):
  156. self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "0")
  157. def test_lt_remove(self):
  158. self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "0")
  159. self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 3), 0, 1, "(a<3)")
  160. self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 8), 1, 1, "1")
  161. def test_lt_sum_remove(self):
  162. self.helper_test_variable(create_lt_node(Variable("a", 0, 6) + 2, 3), 0, 1, "(a<1)")
  163. def test_and_fold(self):
  164. self.helper_test_variable(Node.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0")
  165. def test_and_remove(self):
  166. self.helper_test_variable(Node.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a")
  167. def test_mod_factor_negative(self):
  168. self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
  169. self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
  170. def test_sum_combine_num(self):
  171. self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(6+a)")
  172. def test_sum_num_hoisted_and_factors_cancel_out(self):
  173. self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
  174. def test_div_factor(self):
  175. self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
  176. def test_mul_div(self):
  177. self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
  178. def test_mul_div_factor_mul(self):
  179. self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)")
  180. def test_mul_div_factor_div(self):
  181. self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
  182. def test_div_remove(self):
  183. self.helper_test_variable(Node.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
  184. def test_div_numerator_negative(self):
  185. self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")
  186. def test_div_into_mod(self):
  187. self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")
  188. class TestSymbolicNumeric(unittest.TestCase):
  189. def helper_test_numeric(self, f):
  190. # TODO: why are the negative tests broken? (even if we did support negative variables)
  191. #MIN, MAX = -10, 10
  192. MIN, MAX = 0, 10
  193. # one number
  194. for i in range(MIN, MAX):
  195. v = f(NumNode(i))
  196. #print(i, f(i), v.min, v.max)
  197. self.assertEqual(v.min, v.max)
  198. self.assertEqual(v.min, f(i))
  199. for kmin in range(MIN, MAX):
  200. for kmax in range(MIN, MAX):
  201. if kmin > kmax: continue
  202. v = f(Variable("tmp", kmin, kmax))
  203. values = [f(rv) for rv in range(kmin, kmax+1)]
  204. # the min and max may not be exact
  205. self.assertLessEqual(v.min, min(values))
  206. self.assertGreaterEqual(v.max, max(values))
  207. def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4))
  208. def test_div_4(self): self.helper_test_numeric(lambda x: (x//4))
  209. def test_plus_1_div_2(self): self.helper_test_numeric(lambda x: (x+1)//2)
  210. def test_plus_1_mod_2(self): self.helper_test_numeric(lambda x: (x+1)%2)
  211. def test_times_2(self): self.helper_test_numeric(lambda x: x*2)
  212. def test_times_2_plus_3(self): self.helper_test_numeric(lambda x: x*2 + 3)
  213. def test_times_2_plus_3_mod_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)%4)
  214. def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4)
  215. def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
  216. class TestSymbolicVars(unittest.TestCase):
  217. def test_simple(self):
  218. z = NumNode(0)
  219. a = Variable("a", 0, 10)
  220. b = Variable("b", 0, 10)
  221. c = Variable("c", 0, 10)
  222. assert z.vars() == z.vars() == set()
  223. assert a.vars() == a.vars() == {a}
  224. m = MulNode(a, 3)
  225. assert m.vars() == {a}
  226. s = SumNode([a, b, c])
  227. assert s.vars() == {a, b, c}
  228. def test_compound(self):
  229. a = Variable("a", 0, 10)
  230. b = Variable("b", 0, 10)
  231. c = Variable("c", 0, 10)
  232. assert (a + b * c).vars() == {a, b, c}
  233. assert (a % 3 + b // 5).vars() == {a, b}
  234. assert (a + b + c - a).vars() == {b, c}
  235. def test_dedup(self):
  236. a = Variable("a", 0, 10)
  237. assert (a * a).vars() == {a}
  238. assert (a//4 + a//6).vars() == {a}
  239. class TestSymbolicMinMax(unittest.TestCase):
  240. def test_min_max_known(self):
  241. a = Variable("a", 1, 8)
  242. assert max(1, a) == max(a, 1) == a
  243. assert min(1, a) == min(a, 1) == 1
  244. class TestSymRender(unittest.TestCase):
  245. def test_sym_render(self):
  246. a = Variable("a", 1, 8)
  247. b = Variable("b", 1, 10)
  248. assert sym_render(a) == "a"
  249. assert sym_render(1) == "1"
  250. assert sym_render(a+1) == "(1+a)"
  251. assert sym_render(a*b) == "(a*b)"
  252. class TestSymInfer(unittest.TestCase):
  253. def test_sym_infer(self):
  254. a = Variable("a", 0, 10)
  255. b = Variable("b", 0, 10)
  256. c = Variable("c", 0, 10)
  257. var_vals = {a: 2, b: 3, c: 4}
  258. assert sym_infer(5, var_vals) == 5
  259. assert sym_infer(a, var_vals) == 2
  260. assert sym_infer(b, var_vals) == 3
  261. assert sym_infer(a+b, var_vals) == 5
  262. assert sym_infer(a-b, var_vals) == -1
  263. assert sym_infer(a+b+c, var_vals) == 9
  264. assert sym_infer(a*b, var_vals) == 6
  265. assert sym_infer(a*b+c, var_vals) == 10
  266. class TestSymbolicSymbolicOps(unittest.TestCase):
  267. def test_node_divmod_node(self):
  268. i = Variable("i", 1, 10)
  269. idx0 = Variable("idx0", 0, i*3-1)
  270. assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
  271. assert NumNode(0) % (Variable("i", 1, 10)*128) == 0
  272. assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
  273. assert NumNode(127) % (Variable("i", 1, 10)*128) == 127
  274. assert 127 // (Variable("i", 1, 10)*128) == 0
  275. assert 127 % (Variable("i", 1, 10)*128) == 127
  276. assert NumNode(128) // (Variable("i", 1, 10)*128 + 128) == 0
  277. assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128
  278. assert 128 // (Variable("i", 1, 10)*128 + 128) == 0
  279. assert 128 % (Variable("i", 1, 10)*128 + 128) == 128
  280. assert 0 // (Variable("i", 1, 10)*128) == 0
  281. assert 0 % (Variable("i", 1, 10)*128) == 0
  282. assert idx0 // (i*3) == 0
  283. assert idx0 % (i*3) == idx0
  284. assert i // i == 1
  285. assert i % i == 0
  286. assert 128 // NumNode(4) == 32
  287. assert 128 % NumNode(4) == 0
  288. assert NumNode(128) // NumNode(4) == 32
  289. assert NumNode(128) % NumNode(4) == 0
  290. def test_mulnode_divmod_node(self):
  291. i = Variable("i", 1, 10)
  292. idx0 = Variable("idx0", 0, 31)
  293. # assert (idx0*(i*4+4)) // (i+1) == (idx0*4)
  294. # assert (idx0*(i*4+4)) % (i+1) == 0
  295. assert (idx0*i) % i == 0
  296. def test_sumnode_divmod_sumnode(self):
  297. i = Variable("i", 1, 10)
  298. # idx0 = Variable("idx0", 0, 7)
  299. # idx1 = Variable("idx1", 0, 3)
  300. # idx2 = Variable("idx2", 0, i)
  301. # assert (idx0*(i*4+4)+idx1*(i+1)+idx2) // (i+1) == idx0*4+idx1
  302. # assert (idx0*(i*4+4)+idx1*(i+1)+idx2) % (i+1) == idx2
  303. assert (i+1) // (i*128+128) == 0
  304. assert (i+1) % (i*128+128) == (i+1)
  305. # assert (i+1+idx2) // (i+1) == 1
  306. # assert (i+1+idx2) % (i+1) == idx2
  307. # assert (idx0*(i*4+4)+i+1+idx2) // (i+1) == idx0*4+1
  308. # assert (idx0*(i*4+4)+i+1+idx2) % (i+1) == idx2
  309. # assert (i*128+128)*2 // (i*128+128) == 2
  310. # assert (i*128+128)*2 % (i*128+128) == 0
  311. def test_sumnode_div_numnode_no_factoring(self):
  312. gid = Variable("gid", 0, 1023)
  313. lid = Variable("lid", 0, 3)
  314. expr_before_div = NumNode(-1019)-4*lid-gid
  315. unfactored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), False)
  316. factored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), True)
  317. self.assertEqual(unfactored_expr.render(), "(((lid*4)+1019+gid)//16)")
  318. self.assertEqual(factored_expr.render(), "(((((3+gid)//4)+2+lid)//4)+63)")
  319. def test_mod_node_max(self):
  320. i = Variable("i", 1, 128)
  321. gidx0 = Variable("gidx0", 0, i)
  322. mod = gidx0 % 8
  323. assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8
  324. mod = gidx0 % 2
  325. assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2
  326. gidx0 = Variable("gidx0", 0, i*8+7)
  327. mod = gidx0 % 8
  328. assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8
  329. mod = gidx0 % 2
  330. assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2
  331. def test_node_lt_node(self):
  332. a = Variable("a", 1, 5)
  333. b = Variable("b", 6, 9)
  334. c = Variable("c", 1, 10)
  335. d = Variable("d", 5, 10)
  336. # if the comparison output is always the same, it folds to num
  337. assert create_lt_node(a, b) == NumNode(1)
  338. assert create_lt_node(b, a) == NumNode(0)
  339. assert create_lt_node(d, a) == NumNode(0)
  340. assert create_lt_node(a, a) == NumNode(0)
  341. assert create_lt_node(a, a) == NumNode(0)
  342. # if it remains as a LtNode, bool is always true and (min, max) == (0, 1)
  343. a_lt_c = create_lt_node(a, c)
  344. assert isinstance(a_lt_c, LtNode) and a_lt_c.min == 0 and a_lt_c.max == 1
  345. assert a_lt_c
  346. # same when comparing with a constant
  347. a_lt_3 = create_lt_node(a, 3)
  348. assert a_lt_3 and a_lt_3.min == 0 and a_lt_3.max == 1
  349. def test_sumnode_mulnode_lt(self):
  350. a = Variable("a", 1, 2)
  351. b = Variable("b", 1, 2)
  352. c = Variable("c", 1, 2)
  353. x = SumNode([MulNode(a, b), c])
  354. with self.assertRaises(AssertionError):
  355. create_lt_node(x, 3)
  356. def test_nested_variable_mod(self):
  357. i = Variable("i", 1, 5)
  358. idx0 = Variable("idx0", 0, i)
  359. with self.assertRaises(AssertionError):
  360. assert idx0 % 2 == idx0
  361. def test_num_node_mul_node(self):
  362. a = Variable("a", 1, 5)
  363. b = NumNode(2) * a
  364. assert b == a * 2
  365. assert isinstance(b, MulNode)
  366. b = NumNode(1) * a
  367. assert b == a
  368. assert isinstance(b, Variable)
  369. b = NumNode(0) * a
  370. assert b == 0
  371. assert isinstance(b, NumNode)
  372. def test_substitute(self):
  373. a = Variable("idx0", 1, 3)
  374. b = a + 1
  375. c = b.substitute({a: NumNode(1)})
  376. assert c == NumNode(2)
  377. class TestSymbolicRealWorld(unittest.TestCase):
  378. def test_resnet_half(self):
  379. gidx0 = Variable("gidx0", 0, 3)
  380. gidx1 = Variable("gidx1", 0, 127)
  381. gidx2 = Variable("gidx2", 0, 7)
  382. lidx3 = Variable("lidx3", 0, 7)
  383. lidx4 = Variable("lidx4", 0, 1)
  384. lidx5 = Variable("lidx5", 0, 15)
  385. idx = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
  386. print(idx.render())
  387. # NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
  388. assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)"
  389. if __name__ == '__main__':
  390. unittest.main()