1
0

test_pattern_matcher.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import unittest
  2. from test.helpers import TestUOps
  3. from tinygrad.dtype import dtypes
  4. from tinygrad.ops import BinaryOps, TernaryOps
  5. from tinygrad.codegen.uops import UOps, UOp
  6. from tinygrad.codegen.uopgraph import UOpGraph, PatternMatcher, UPat, _match
  7. class TestPatternMatcher(TestUOps):
  8. def test_simple_match(self):
  9. matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float), lambda x: x)])
  10. c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  11. c2 = UOp(UOps.CONST, dtypes.int, arg=1)
  12. self.assertEqual(matcher.rewrite(c1), c1)
  13. self.assertEqual(matcher.rewrite(c2), None)
  14. def test_uop(self):
  15. matcher = PatternMatcher([(UPat(UOps.CONST, name="x"), lambda x: x)])
  16. c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  17. c2 = UOp(UOps.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
  18. self.assertEqual(matcher.rewrite(c1), c1)
  19. self.assertEqual(matcher.rewrite(c2), None)
  20. def test_uop_set(self):
  21. matcher = PatternMatcher([(UPat({UOps.CONST, UOps.CAST}, name="x"), lambda x: x)])
  22. c1 = UOp(UOps.CONST, dtypes.bool, arg=False)
  23. c2 = UOp(UOps.CAST, dtypes.int, (c1,))
  24. c3 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  25. c4 = UOp(UOps.ALU, dtypes.float, (c3, c3), BinaryOps.ADD)
  26. self.assertEqual(matcher.rewrite(c1), c1)
  27. self.assertEqual(matcher.rewrite(c2), c2)
  28. self.assertEqual(matcher.rewrite(c4), None)
  29. def test_arg(self):
  30. matcher = PatternMatcher([
  31. (UPat(UOps.CONST, 0, name="x"), lambda x: x),
  32. (UPat(UOps.CONST, False, name="x"), lambda x: x),
  33. (UPat(UOps.ALU, BinaryOps.MAX, name="x"), lambda x: x),
  34. ])
  35. c1 = UOp(UOps.CONST, dtypes.float, arg=0.0)
  36. c2 = UOp(UOps.CONST, dtypes.bool, arg=False)
  37. c3 = UOp(UOps.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MAX)
  38. c4 = UOp(UOps.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MUL)
  39. c5 = UOp(UOps.CONST, dtypes.int, arg=-1)
  40. self.assertEqual(matcher.rewrite(c1), c1)
  41. self.assertEqual(matcher.rewrite(c2), c2)
  42. self.assertEqual(matcher.rewrite(c3), c3)
  43. self.assertEqual(matcher.rewrite(c4), None)
  44. self.assertEqual(matcher.rewrite(c5), None)
  45. @unittest.skip("this is not supported any more")
  46. def test_arg_set(self):
  47. matcher = PatternMatcher([(UPat(UOps.ALU, BinaryOps.MUL, (UPat(UOps.CONST, {-1, 1}), UPat(UOps.CONST, 2)), name="x"), lambda x: x)])
  48. y1 = UOp(UOps.CONST, dtypes.int, arg=1)
  49. y2 = UOp(UOps.CONST, dtypes.int, arg=2)
  50. y3 = UOp(UOps.CONST, dtypes.int, arg=-1)
  51. c1 = UOp(UOps.ALU, dtypes.int, (y1, y2), BinaryOps.MUL)
  52. c2 = UOp(UOps.ALU, dtypes.int, (y2, y2), BinaryOps.MUL)
  53. c3 = UOp(UOps.ALU, dtypes.int, (y3, y2), BinaryOps.MUL)
  54. self.assertEqual(matcher.rewrite(c1), c1)
  55. self.assertEqual(matcher.rewrite(c2), None)
  56. self.assertEqual(matcher.rewrite(c3), c3)
  57. def test_dup_name(self):
  58. matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)])
  59. y1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  60. y2 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  61. c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
  62. c2 = UOp(UOps.ALU, dtypes.float, (y1, y2), BinaryOps.ADD)
  63. self.assertEqual(matcher.rewrite(c1), c1)
  64. self.assertEqual(matcher.rewrite(c2), None)
  65. def test_dtype(self):
  66. matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
  67. c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  68. c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
  69. self.assertEqual(matcher.rewrite(c1), c1)
  70. self.assertEqual(matcher.rewrite(c2), None)
  71. def test_dtype_set(self):
  72. matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=set([dtypes.float32, dtypes.float64])), lambda x: x)])
  73. c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  74. c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
  75. c3 = UOp(UOps.CONST, dtypes.float16, arg=1.0)
  76. c4 = UOp(UOps.CONST, dtypes.int, arg=1)
  77. self.assertEqual(matcher.rewrite(c1), c1)
  78. self.assertEqual(matcher.rewrite(c2), c2)
  79. self.assertEqual(matcher.rewrite(c3), None)
  80. self.assertEqual(matcher.rewrite(c4), None)
  81. def test_vin_one(self):
  82. matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)])
  83. c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  84. c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
  85. c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
  86. self.assertEqual(matcher.rewrite(c3), c3)
  87. self.assertEqual(matcher.rewrite(c2), None)
  88. matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
  89. c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
  90. c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
  91. self.assertEqual(matcher.rewrite(c3), None)
  92. self.assertEqual(matcher.rewrite(c4), c4)
  93. self.assertEqual(matcher.rewrite(c5), None)
  94. def test_vin_permutations(self):
  95. matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)])
  96. c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  97. c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
  98. c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
  99. c4 = UOp(UOps.ALU, dtypes.float, (c3,c2), BinaryOps.ADD)
  100. c5 = UOp(UOps.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
  101. c6 = UOp(UOps.ALU, dtypes.float, (c3,c4), BinaryOps.ADD)
  102. self.assertEqual(matcher.rewrite(c3), None)
  103. self.assertEqual(matcher.rewrite(c4), c4)
  104. self.assertEqual(matcher.rewrite(c5), c5)
  105. self.assertEqual(matcher.rewrite(c6), None)
  106. def test_vin_repeat(self):
  107. matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=UPat(UOps.CONST)), lambda x: x)])
  108. c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  109. c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
  110. c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
  111. c4 = UOp(UOps.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
  112. self.assertEqual(matcher.rewrite(c3), c3)
  113. self.assertEqual(matcher.rewrite(c4), None)
  114. def test_allow_len(self):
  115. matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
  116. c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  117. c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
  118. c3 = UOp(UOps.CONST, dtypes.float, arg=3.0)
  119. #c4 = UOp(UOps.ALU, dtypes.float, (c1,), UnaryOps.NEG)
  120. c5 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
  121. c6 = UOp(UOps.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC)
  122. #self.assertEqual(matcher.rewrite(c4), c4)
  123. self.assertEqual(matcher.rewrite(c5), None)
  124. self.assertEqual(matcher.rewrite(c6), c6)
  125. def test_deep_src_permutations(self):
  126. c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  127. c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
  128. u1 = (c1 + c2) + c1
  129. u2 = (c2 + c1) + c1
  130. pat = UPat(UOps.ALU, src = (UPat(UOps.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')))
  131. assert _match(u1, pat, {})
  132. assert _match(u2, pat, {})
  133. @unittest.skip("no longer supported")
  134. def test_rewrite_graph_folds(self):
  135. uops = UOpGraph()
  136. UOp(UOps.CONST, dtypes.float, arg=2.0, simplify=False)
  137. matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float),
  138. lambda x: UOp(UOps.CAST, dtypes.int, (UOp(UOps.ALU, x.dtype, (x, x), BinaryOps.ADD),)))])
  139. matcher.rewrite_graph(uops)
  140. # TODO: fix this. it's 2 now
  141. # self.assertEqual(len(uops.uops), 1)
  142. self.assertEqual(len(uops.uops), 2)
  143. self.assert_equiv_uops(UOp(UOps.CONST, dtypes.int, arg=4), uops.uops[-1])
  144. @unittest.skip("no longer supported")
  145. def test_rewrite_graph_adds(self):
  146. uops = UOpGraph()
  147. UOp(UOps.CONST, dtypes.int, arg=2, simplify=False)
  148. matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.int),
  149. lambda x: UOp(UOps.STORE, x.dtype, (UOp(UOps.DEFINE_GLOBAL, x.dtype, tuple(), None), x)))])
  150. matcher.rewrite_graph(uops)
  151. uops.remove_childless(set(x for x in uops if x.op in {UOps.STORE}))
  152. self.assertEqual(len(uops.uops), 3)
  153. e1 = UOp(UOps.CONST, dtypes.int, arg=2)
  154. e2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
  155. e3 = UOp(UOps.STORE, dtypes.int, (e2,e1))
  156. self.assert_equiv_uops(e1, uops.uops[0])
  157. self.assert_equiv_uops(e2, uops.uops[1])
  158. self.assert_equiv_uops(e3, uops.uops[2])
  159. if __name__ == '__main__':
  160. unittest.main(verbosity=2)