# test_fusion_op.py (1.6 KB)
import time
import unittest

import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.engine.realize import lower_schedule_item, run_schedule
from tinygrad.engine.schedule import create_schedule
  7. class TestFusionOp(unittest.TestCase):
  8. def test_contiguous_add(self):
  9. def test(contig=False):
  10. bt = Tensor(np.arange(16), dtype=dtypes.float32).reshape(4,4)
  11. x = bt.permute(1,0)
  12. if contig: x = x.contiguous()
  13. return (x.permute(1,0) + bt).data()
  14. assert test() == test(True)
  15. def test_expand_fuse(self):
  16. bt = Tensor(np.ones((10, 1)), dtype=dtypes.float32)
  17. out = (bt*2).expand(10,10).sum(1)
  18. sched = create_schedule([out.lazydata], None)
  19. run_schedule(sched)
  20. outd = out.tolist()
  21. assert all(x == 20.0 for x in outd)
  22. def test_recursive_add(self):
  23. st = time.perf_counter()
  24. a = Tensor([1,2,3,4])
  25. for _ in range(24): a = a + a
  26. sched = create_schedule([a.lazydata], None)
  27. ei = lower_schedule_item(sched[-1])
  28. self.assertLess(time.perf_counter()-st, 1.0)
  29. assert len(ei.prg.p.src.splitlines()) < 250
  30. def test_recursive_add_cmp(self):
  31. st = time.perf_counter()
  32. a = Tensor([1,2,3,4])
  33. for _ in range(24): a = a + a
  34. sched1 = create_schedule([a.lazydata], None)
  35. b = Tensor([1,2,3,4])
  36. for _ in range(24): b = b + b
  37. sched2 = create_schedule([b.lazydata], None)
  38. c = Tensor([1,2,3,4])
  39. for _ in range(23): c = c + c
  40. sched3 = create_schedule([c.lazydata], None)
  41. assert sched1[-1].ast == sched2[-1].ast
  42. assert sched1[-1].ast != sched3[-1].ast
  43. self.assertLess(time.perf_counter()-st, 1.0)
  44. if __name__ == '__main__':
  45. unittest.main(verbosity=2)