test_print_tree.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. #%%
  2. import unittest
  3. from tinygrad.engine.graph import print_tree
  4. from tinygrad import Tensor, dtypes
  5. from tinygrad.codegen.uops import UOp
  6. import sys, io
  7. class TestPrintTree(unittest.TestCase):
  8. def _capture_print(self, fn):
  9. capturedOutput = io.StringIO()
  10. sys.stdout = capturedOutput
  11. fn()
  12. sys.stdout = sys.__stdout__
  13. return capturedOutput.getvalue()
  14. def test_print_uop(self):
  15. x = Tensor.arange(10).schedule()[-1].ast.src[0]
  16. output = self._capture_print(lambda: print_tree(x))
  17. assert output == '\
  18. 0 ━┳ BufferOps.STORE MemBuffer(idx=0, dtype=dtypes.int, \
  19. st=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))\n\
  20. 1 ┗━┳ BinaryOps.ADD None\n\
  21. 2 ┣━┳ ReduceOps.SUM (1,)\n\
  22. 3 ┃ ┗━━ BufferOps.CONST ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTrac\
  23. ker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19))\
  24. , contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))))\n\
  25. 4 ┗━━ BufferOps.CONST ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10,\
  26. 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)))\n'
  27. x = UOp.var("x", dtypes.int)
  28. x = (x + x) - UOp.const(dtypes.int, 2)
  29. output = self._capture_print(lambda: print_tree(x))
  30. assert output == '\
  31. 0 ━┳ UOps.ALU BinaryOps.ADD\n\
  32. 1 ┣━┳ UOps.ALU BinaryOps.ADD\n\
  33. 2 ┃ ┣━━ UOps.VAR x\n\
  34. 3 ┃ ┗━━ UOps.VAR x\n\
  35. 4 ┗━┳ UOps.ALU UnaryOps.NEG\n\
  36. 5 ┗━━ UOps.CONST 2\n'
  37. """
  38. x = UPat(UOp.alu(BinaryOps.ADD, UOp.var("x", dtypes.int), UOp.var("x", dtypes.int)))
  39. assert self._capture_print(lambda: print_tree(x)) == '\
  40. 0 ━━ UOps.ALU : dtypes.int [<UOps.VAR: 2>, <UOps.VAR: 2>] BinaryOps.ADD None\n'
  41. x = UPat.compile(UOp.store(UOp.var("buf"), UOp.var("idx"),
  42. UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store)
  43. assert self._capture_print(lambda: print_tree(x)) == '\
  44. 0 ━┳ UOps.STORE None\n\
  45. 1 ┣━━ None None\n\
  46. 2 ┣━━ None None\n\
  47. 3 ┗━┳ UOps.CAST None\n\
  48. 4 ┣━┳ UOps.GEP 0\n\
  49. 5 ┃ ┗━━ None None\n\
  50. 6 ┣━┳ UOps.GEP 1\n\
  51. 7 ┃ ┗━━ None None\n\
  52. 8 ┣━┳ UOps.GEP 2\n\
  53. 9 ┃ ┗━━ None None\n\
  54. 10 ┗━┳ UOps.GEP 3\n\
  55. 11 ┗━━ None None\n'
  56. """
  57. if __name__ == "__main__":
  58. unittest.main()