test_uop_graph.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. import unittest
  2. from test.helpers import TestUOps
  3. from tinygrad import dtypes, Variable
  4. from tinygrad.dtype import PtrDType
  5. from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps
  6. from tinygrad.codegen.uops import UOps, UOp
  7. from tinygrad.codegen.uopgraph import UOpGraph, PatternMatcher, graph_rewrite
  8. from tinygrad.engine.graph import print_tree # noqa: F401 # pylint: disable=unused-import
  9. simple_pm = PatternMatcher([
  10. (UOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
  11. (UOp.cvar('x') + UOp.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
  12. (UOp.cvar('x') * UOp.cvar('y') * UOp.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
  13. ((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x + UOp.const(x.dtype, c1.arg+c2.arg)),
  14. ])
  15. class TestGraphRewrite(unittest.TestCase):
  16. def test_dedup(self):
  17. v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
  18. v2 = UOp(UOps.DEFINE_VAR, dtypes.float)
  19. nout = graph_rewrite(v1+v2, PatternMatcher([]))
  20. self.assertIs(nout.src[0], nout.src[1])
  21. def test_simple(self):
  22. c1 = UOp.const(dtypes.float, 1.0)
  23. c2 = UOp.const(dtypes.float, 2.0)
  24. nout = graph_rewrite(c1+c2, simple_pm)
  25. self.assertEqual(nout.op, UOps.CONST)
  26. self.assertEqual(nout.arg, 3.0)
  27. def test_depth_2_late(self):
  28. c1 = UOp.const(dtypes.float, 1.0)
  29. c2 = UOp.const(dtypes.float, 2.0)
  30. c3 = UOp.const(dtypes.float, 3.0)
  31. nout = graph_rewrite(c1*c2*(c3+c3), simple_pm)
  32. self.assertEqual(nout.op, UOps.CONST)
  33. self.assertEqual(nout.arg, 12.0)
  34. def test_double(self):
  35. c1 = UOp.const(dtypes.float, 1.0)
  36. c2 = UOp.const(dtypes.float, 2.0)
  37. c3 = UOp.const(dtypes.float, 3.0)
  38. nout = graph_rewrite(c1+c2+c3, simple_pm)
  39. self.assertEqual(nout.op, UOps.CONST)
  40. self.assertEqual(nout.arg, 6.0)
  41. def test_triple(self):
  42. c1 = UOp.const(dtypes.float, 1.0)
  43. c2 = UOp.const(dtypes.float, 2.0)
  44. c3 = UOp.const(dtypes.float, 3.0)
  45. c4 = UOp.const(dtypes.float, 4.0)
  46. nout = graph_rewrite(c1+c2+c3+c4, simple_pm)
  47. self.assertEqual(nout.op, UOps.CONST)
  48. self.assertEqual(nout.arg, 10.0)
  49. def test_diamond(self):
  50. c1 = UOp.const(dtypes.float, 1.0)
  51. c2 = UOp.const(dtypes.float, 2.0)
  52. c3 = UOp.const(dtypes.float, 3.0)
  53. nout = graph_rewrite((c1+c2)+(c1+c3), simple_pm)
  54. self.assertEqual(nout.op, UOps.CONST)
  55. self.assertEqual(nout.arg, 7.0)
  56. def test_magic_4(self):
  57. c1 = UOp.const(dtypes.int, 4.0)
  58. nout = graph_rewrite(c1, simple_pm)
  59. self.assertEqual(nout.op, UOps.CONST)
  60. self.assertEqual(nout.arg, 3.0)
  61. def test_depth_2_fold(self):
  62. v = UOp(UOps.DEFINE_VAR, dtypes.float)
  63. c1 = UOp.const(dtypes.float, 1.0)
  64. c2 = UOp.const(dtypes.float, 2.0)
  65. nout = graph_rewrite(v+c1+c2, simple_pm)
  66. self.assertEqual(nout.op, UOps.ALU)
  67. self.assertEqual(nout.src[0].op, UOps.DEFINE_VAR)
  68. self.assertEqual(nout.src[1].op, UOps.CONST)
  69. self.assertEqual(nout.src[1].arg, 3.0)
  70. class TestUOpGraph(TestUOps):
  71. def test_add_constant_fold(self):
  72. c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  73. c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
  74. out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
  75. g = UOpGraph([out])
  76. self.assertEqual(len(g.uops), 1)
  77. out = g.uops[-1]
  78. self.assertEqual(out.op, UOps.CONST)
  79. self.assertEqual(out.arg, 3.0)
  80. def test_where_same_fold(self):
  81. v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
  82. c0 = UOp(UOps.CONST, dtypes.int, arg=0)
  83. vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
  84. c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  85. out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
  86. g = UOpGraph([out])
  87. self.assertEqual(len(g.uops), 1)
  88. out = g.uops[-1]
  89. self.assertEqual(out.op, UOps.CONST)
  90. self.assertEqual(out.arg, 1.0)
  91. def test_where_const_fold(self):
  92. bf = UOp(UOps.CONST, dtypes.bool, arg=False)
  93. c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  94. c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
  95. out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
  96. g = UOpGraph([out])
  97. self.assertEqual(len(g.uops), 1)
  98. out = g.uops[-1]
  99. self.assertEqual(out.op, UOps.CONST)
  100. self.assertEqual(out.arg, 2.0)
  101. def test_const_cast(self):
  102. bf = UOp(UOps.CONST, dtypes.bool, arg=False)
  103. out = UOp(UOps.CAST, dtypes.int, (bf,))
  104. g = UOpGraph([out])
  105. self.assertEqual(len(g.uops), 1)
  106. out = g.uops[-1]
  107. self.assertEqual(out.op, UOps.CONST)
  108. self.assertEqual(out.arg, 0)
  109. def test_const_vectorize_fold(self):
  110. c0 = UOp(UOps.CONST, dtypes.half, arg=0.0)
  111. out = UOp(UOps.VECTORIZE, dtypes.half.vec(2), (c0, c0))
  112. g = UOpGraph([out])
  113. self.assertEqual(len(g.uops), 1)
  114. out = g.uops[-1]
  115. self.assertEqual(out.op, UOps.CONST)
  116. self.assertEqual(out.arg, 0.0)
  117. def test_noop_vectorize_fold(self):
  118. d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=(0, True))
  119. idx = UOp.const(dtypes.int, 0)
  120. ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
  121. vec = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (ld,))
  122. x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0)
  123. alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
  124. out = UOp(UOps.STORE, None, (d0, idx, alu))
  125. g = UOpGraph([out])
  126. self.assertEqual(len([x for x in g.uops if x.op is UOps.VECTORIZE]), 0)
  127. def test_gep_vec_fold(self):
  128. d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
  129. d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, False))
  130. d2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (2, False))
  131. idx = UOp.const(dtypes.int, 0)
  132. def _test_vec(geps):
  133. vec = UOp(UOps.VECTORIZE, dtypes.float.vec(4), geps)
  134. out = UOp(UOps.STORE, None, (d0, idx, vec))
  135. return UOpGraph([out]).uops[-1].src[-1]
  136. # possible
  137. val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
  138. xyzw = tuple(UOp(UOps.GEP, dtypes.float, (val,), i) for i in range(4))
  139. self.assert_equiv_uops(_test_vec(xyzw), val)
  140. # unaligned
  141. val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
  142. wzyx = tuple(UOp(UOps.GEP, dtypes.float, (val,), i) for i in reversed(range(4)))
  143. self.assertIs(_test_vec(wzyx).op, UOps.VECTORIZE)
  144. # different_size
  145. val = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))
  146. xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), i) for i in range(2))
  147. self.assertIs(_test_vec(xy+xy).op, UOps.VECTORIZE)
  148. # different vals
  149. val1 = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))
  150. val2 = UOp(UOps.LOAD, dtypes.float.vec(2), (d2, idx))
  151. xy1 = tuple(UOp(UOps.GEP, dtypes.float, (val1, ), i) for i in range(2))
  152. xy2 = tuple(UOp(UOps.GEP, dtypes.float, (val2, ), i) for i in range(2))
  153. self.assertIs(_test_vec(xy1+xy2).op, UOps.VECTORIZE)
  154. def test_cast_alu_fold(self):
  155. d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=(0, True))
  156. d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=(1, False))
  157. idx = UOp.const(dtypes.int, 0)
  158. ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
  159. alu = ld.lt(1).cast(dtypes.bool)
  160. out = UOp(UOps.STORE, None, (d0, idx, alu))
  161. g = UOpGraph([out])
  162. self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 0)
  163. def test_double_cast_fold(self):
  164. d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=(0, True))
  165. d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=(1, False))
  166. idx = UOp.const(dtypes.int, 0)
  167. ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
  168. alu = ld.cast(dtypes.float).cast(dtypes.float)
  169. out = UOp(UOps.STORE, None, (d0, idx, alu))
  170. g = UOpGraph([out])
  171. self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 1)
  172. def test_depth_2_const_fold(self):
  173. v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
  174. c2 = UOp(UOps.CONST, dtypes.int, arg=2)
  175. c4 = UOp(UOps.CONST, dtypes.int, arg=4)
  176. vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
  177. out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
  178. g = UOpGraph([out])
  179. self.assertEqual(len(g.uops), 3)
  180. out = g.uops[-1]
  181. self.assertEqual(out.op, UOps.ALU)
  182. self.assertEqual(out.arg, BinaryOps.ADD)
  183. self.assertEqual(out.src[1].op, UOps.CONST)
  184. self.assertEqual(out.src[1].arg, 6)
  185. def test_fold_gated_load(self):
  186. glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
  187. glbl1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (1, False))
  188. glbl2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (2, False))
  189. idx = UOp.const(dtypes.int, 0)
  190. ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2)))
  191. ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3)))
  192. uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld0+ld1))])
  193. ld0, ld1 = uops[-1].src[2].src
  194. # ld0 becomes the invalid value
  195. self.assert_equiv_uops(ld0, UOp.const(dtypes.int, 2))
  196. # the gate and invalid value are deleted from ld1
  197. self.assert_equiv_uops(ld1, UOp.load(glbl2, idx, dtype=dtypes.int))
  198. def test_fold_gated_load_local(self):
  199. glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
  200. smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int), (), ("temp", 1))
  201. lidx = UOp(UOps.SPECIAL, dtypes.int, (), (0, "lidx1", 16))
  202. st = UOp(UOps.STORE, None, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
  203. barrier = UOp(UOps.BARRIER, None, (st, ))
  204. ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2), barrier))
  205. ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3), barrier))
  206. uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld0+ld1))])
  207. ld0, ld1 = uops[-1].src[2].src
  208. # ld0 becomes the invalid value
  209. self.assert_equiv_uops(ld0, UOp.const(dtypes.int, 2))
  210. # the gate and invalid value are deleted from ld1
  211. self.assert_equiv_uops(ld1, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
  212. def test_fold_gated_store(self):
  213. glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
  214. idx0 = UOp.const(dtypes.int, 0)
  215. idx1 = UOp.const(dtypes.int, 0)
  216. val = UOp.const(dtypes.int, 42)
  217. st0 = UOp(UOps.STORE, None, (glbl, idx0, val, UOp.const(dtypes.bool, False)))
  218. st1 = UOp(UOps.STORE, None, (glbl, idx1, val, UOp.const(dtypes.bool, True)))
  219. uops = UOpGraph([st0, st1])
  220. # only the second store happens
  221. self.assertEqual(len(uops.uops), 4)
  222. self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
  223. def test_asserts_bad_gate(self):
  224. glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
  225. idx = UOp.const(dtypes.int, 0)
  226. bad_gate = UOp.const(dtypes.int, 1)
  227. uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
  228. with self.assertRaises(AssertionError): uops.linearize()
  229. def test_switched_range_order(self):
  230. glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
  231. c0 = UOp.const(dtypes.int, 0)
  232. c2 = UOp.const(dtypes.int, 2)
  233. cf = UOp.const(dtypes.float, 0.0)
  234. r1 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 0, False))
  235. r2 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 1, False))
  236. alu = UOp(UOps.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
  237. store = UOp(UOps.STORE, None, (glbl, alu, cf))
  238. uops = UOpGraph([store]).uops
  239. ranges = [x for x in uops if x.op is UOps.RANGE]
  240. endranges = [x for x in uops if x.op is UOps.ENDRANGE]
  241. # ranges are closed in the right order
  242. self.assertEqual(endranges[-1].src[0], ranges[0])
  243. def expander_rewrite(sink):
  244. from tinygrad.codegen.uopgraph import expander, constant_folder
  245. together = PatternMatcher(expander.patterns + constant_folder.patterns)
  246. return graph_rewrite(sink, together)
  247. #out = UOpGraph(UOp(UOps.SINK, None, (sink,)))
  248. #out.linearize()
  249. #return out.uops[-1]
  250. class TestExpander(unittest.TestCase):
  251. def test_expand_add_broadcast(self):
  252. e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
  253. sink = expander_rewrite(e1+3)
  254. assert sink.op is UOps.EXPAND and len(sink.src) == 4
  255. self.assertListEqual([x.arg for x in sink.src], [3,4,5,6])
  256. def test_contract_simple(self):
  257. e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
  258. con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (1,))
  259. sink = expander_rewrite(con)
  260. assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
  261. self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
  262. def test_contract_axis_1(self):
  263. e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
  264. con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (1,))
  265. sink = expander_rewrite(con)
  266. assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((2,4),)
  267. assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4
  268. self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,8,12])
  269. self.assertListEqual([x.arg for x in sink.src[3].src], [3,7,11,15])
  270. def test_contract_axis_2(self):
  271. e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
  272. con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,))
  273. sink = expander_rewrite(con)
  274. assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,4),)
  275. assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4
  276. self.assertListEqual([x.arg for x in sink.src[0].src], [0,1,2,3])
  277. self.assertListEqual([x.arg for x in sink.src[3].src], [12,13,14,15])
  278. def test_contract_mid(self):
  279. e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2)))
  280. con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,))
  281. sink = expander_rewrite(con)
  282. assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,2),(3,2))
  283. assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 2
  284. self.assertListEqual([x.arg for x in sink.src[0].src], [0,2])
  285. self.assertListEqual([x.arg for x in sink.src[1].src], [1,3])
  286. self.assertListEqual([x.arg for x in sink.src[2].src], [4,6])
  287. self.assertListEqual([x.arg for x in sink.src[3].src], [5,7])
  288. def test_expand_same_axis(self):
  289. e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
  290. e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
  291. sink = expander_rewrite(e1+e2)
  292. assert sink.op is UOps.EXPAND and len(sink.src) == 4
  293. self.assertListEqual([x.arg for x in sink.src], [0,5,10,15])
  294. def test_expand_different_axis(self, flip=False):
  295. e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
  296. e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
  297. sink = expander_rewrite((e2+e1) if flip else (e1+e2))
  298. assert sink.op is UOps.EXPAND and len(sink.src) == 16
  299. assert sink.arg == ((1, 4), (2, 4))
  300. self.assertListEqual([x.arg for x in sink.src], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
  301. def test_expand_different_axis_flip(self): self.test_expand_different_axis(True)
  302. def test_reduce_known_axis(self):
  303. e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
  304. sink = UOp(UOps.REDUCE, dtypes.int, (3*e1,e1), ReduceOps.SUM)
  305. sink = expander_rewrite(sink)
  306. assert sink.op is UOps.CONST
  307. self.assertEqual(sink.arg, 3*(0+1+2+3))
  308. def test_reduce_const(self):
  309. e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
  310. sink = UOp(UOps.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), ReduceOps.SUM)
  311. sink = expander_rewrite(sink)
  312. assert sink.op is UOps.CONST
  313. self.assertEqual(sink.arg, 3*4)
  314. def test_double_expand(self):
  315. e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
  316. e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
  317. e = UOp(UOps.EXPAND, dtypes.int, (e1, e2), ((1,2),))
  318. sink = expander_rewrite(e)
  319. assert sink.op is UOps.EXPAND and len(sink.src) == 8
  320. assert sink.arg == ((1, 2), (2, 4))
  321. self.assertListEqual([x.arg for x in sink.src], [0,1,2,3,4,5,6,7])
  322. def test_double_expand_reverse(self):
  323. e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
  324. e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
  325. e = UOp(UOps.EXPAND, dtypes.int, (e1, e2), ((2,2),))
  326. sink = expander_rewrite(e)
  327. assert sink.op is UOps.EXPAND and len(sink.src) == 8
  328. assert sink.arg == ((1, 4), (2, 2))
  329. self.assertListEqual([x.arg for x in sink.src], [0, 4, 1, 5, 2, 6, 3, 7])
  330. def test_double_expand_middle(self):
  331. e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
  332. e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))
  333. e = UOp(UOps.EXPAND, dtypes.int, (e1, e2), ((2,2),))
  334. sink = expander_rewrite(e)
  335. assert sink.op is UOps.EXPAND and len(sink.src) == 8
  336. assert sink.arg == ((1, 2), (2, 2), (3, 2))
  337. self.assertListEqual([x.arg for x in sink.src], [0, 1, 4, 5, 2, 3, 6, 7])
  338. # does this need to work?
  339. @unittest.expectedFailure
  340. @unittest.skip
  341. def test_reduce_different_axis(self):
  342. e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
  343. e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
  344. sink = UOp(UOps.REDUCE, dtypes.int, (e1,e2), ReduceOps.SUM)
  345. sink = expander_rewrite(sink)
  346. print_tree(sink)
  347. if __name__ == '__main__':
  348. unittest.main(verbosity=2)