test_verify_lazyop.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from __future__ import annotations
  2. import unittest
  3. from tinygrad.codegen.kernel import Kernel
  4. #from tinygrad.codegen.kernel import Kernel
  5. from tinygrad.engine.graph import print_tree
  6. from tinygrad.helpers import DEBUG
  7. from tinygrad.ops import BufferOps, MemBuffer, LazyOp, ReduceOps, MetaOps, verify_lazyop
  8. from tinygrad.shape.shapetracker import ShapeTracker
  9. from tinygrad import dtypes
  10. from tinygrad.shape.view import View
  11. class InvalidLazyOpException(Exception): pass
  12. def lower(*ast:LazyOp):
  13. sink_ast = LazyOp(MetaOps.KERNEL, ast)
  14. if DEBUG >= 3:
  15. for op in ast: print_tree(op)
  16. try: verify_lazyop(sink_ast)
  17. except AssertionError: raise InvalidLazyOpException()
  18. k = Kernel(sink_ast)
  19. k.linearize()
  20. if DEBUG >= 6: k.uops.print()
  21. if DEBUG >= 4: print(k.to_program().src)
  22. return k
  23. class TestVerifyLazyOp(unittest.TestCase):
  24. def test_tiny_add(self):
  25. dtype = dtypes.int
  26. st = ShapeTracker.from_shape((32, 1))
  27. a = LazyOp(BufferOps.LOAD, arg=MemBuffer(1, dtype, st))
  28. b = LazyOp(BufferOps.LOAD, arg=MemBuffer(2, dtype, st))
  29. out = LazyOp(BufferOps.STORE, (a+b, ), arg=MemBuffer(0, dtype, st))
  30. lower(out)
  31. def test_exactly_one_full_shape(self):
  32. a = LazyOp(BufferOps.LOAD, arg=MemBuffer(1, dtypes.int, ShapeTracker.from_shape((32, 1))))
  33. b = LazyOp(BufferOps.LOAD, arg=MemBuffer(2, dtypes.int, ShapeTracker.from_shape((32, 1))))
  34. out0 = LazyOp(BufferOps.STORE, (a+b, ), MemBuffer(0, dtypes.int, ShapeTracker.from_shape((32, 1))))
  35. c = LazyOp(BufferOps.LOAD, arg=MemBuffer(3, dtypes.int, ShapeTracker.from_shape((32, 32))))
  36. d = LazyOp(BufferOps.LOAD, arg=MemBuffer(4, dtypes.int, ShapeTracker.from_shape((32, 32))))
  37. out1 = LazyOp(BufferOps.STORE, (c+d, ), MemBuffer(0, dtypes.int, ShapeTracker.from_shape((32, 32))))
  38. with self.assertRaises(InvalidLazyOpException): lower(out0, out1)
  39. def test_no_implicit_broadcasting(self):
  40. t = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, ShapeTracker.from_shape((4, 32))))
  41. b = t + LazyOp(ReduceOps.MAX, (t, ), (1, ))
  42. out = LazyOp(BufferOps.STORE, (b, ), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((4, 32))))
  43. with self.assertRaises(InvalidLazyOpException): lower(out)
  44. def test_shrink_ok(self):
  45. a = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),))))
  46. b = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),))))
  47. out = LazyOp(BufferOps.STORE, (a+b, ), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((32, 32))))
  48. lower(out)
  49. def test_reduce_store(self):
  50. a = LazyOp(BufferOps.LOAD, arg=MemBuffer(1, dtypes.int, ShapeTracker.from_shape((32, 1))))
  51. r = LazyOp(ReduceOps.SUM, (a, ), (0, ))
  52. out = LazyOp(BufferOps.STORE, (r, ), MemBuffer(0, dtypes.int, ShapeTracker.from_shape((32, 1))))
  53. with self.assertRaises(InvalidLazyOpException): lower(out)
  54. def test_reduce_add_store(self):
  55. a = LazyOp(BufferOps.LOAD, arg=MemBuffer(1, dtypes.int, ShapeTracker.from_shape((32, 1))))
  56. r = LazyOp(ReduceOps.SUM, (a, ), (0, ))
  57. out = LazyOp(BufferOps.STORE, (r+a, ), MemBuffer(0, dtypes.int, ShapeTracker.from_shape((32, 1))))
  58. with self.assertRaises(InvalidLazyOpException): lower(out)
  59. def test_multi_reduce_simple(self):
  60. early_st = ShapeTracker.from_shape((32, 32)).reshape((32, 1, 32)).expand((32, 32, 32))
  61. early_x = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=early_st))
  62. r0 = LazyOp(op=ReduceOps.SUM, src=(early_x, ), arg=(1, ))
  63. late_st = ShapeTracker.from_shape((32, 32)).reshape((32, 1, 32))
  64. late_x = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=late_st))
  65. r1 = LazyOp(op=ReduceOps.SUM, src=(late_x + r0, ), arg=(0, 1, 2))
  66. out = LazyOp(op=BufferOps.STORE, src=(r1, ), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((1, 1, 1))))
  67. lower(out)
  68. if __name__ == '__main__':
  69. unittest.main()