# test_search.py (7.2 KB)
  1. import unittest
  2. from tinygrad.codegen.kernel import Opt, OptOps
  3. from tinygrad.codegen.kernel import Kernel
  4. from tinygrad.engine.schedule import create_schedule
  5. from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
  6. from tinygrad.device import Device, Buffer
  7. from tinygrad.ops import LazyOp, MetaOps, BufferOps, ReduceOps, BinaryOps, MemBuffer, ConstBuffer
  8. from tinygrad.tensor import Tensor
  9. from tinygrad.dtype import dtypes
  10. from tinygrad.helpers import Context, GlobalCounters
  11. from tinygrad.engine.realize import capturing
  12. from tinygrad.shape.shapetracker import ShapeTracker
  13. from tinygrad.shape.view import View
  14. class TestTimeLinearizer(unittest.TestCase):
  15. def test_reasonable_time(self):
  16. si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is MetaOps.KERNEL][0]
  17. out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
  18. memops = {x.arg.idx:x.arg.st.real_size() for x in si.ast.lazyops if x.op is BufferOps.LOAD}
  19. rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))]
  20. tm = time_linearizer(Kernel(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
  21. assert tm > 0 and tm != float('inf')
  22. def test_bufs_from_lin(self):
  23. si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is MetaOps.KERNEL][0]
  24. rawbufs = bufs_from_lin(lin:=Kernel(si.ast))
  25. assert len(rawbufs) == len(lin.membufs)
  26. assert all(r is not None for r in rawbufs)
  27. assert all(isinstance(r, Buffer) for r in rawbufs)
  28. assert all(r.size > 0 for r in rawbufs)
  29. def test_kernel_count(self):
  30. """
  31. Ensure that the kernel count is not incremented by time_linearizer when clearing l2
  32. """
  33. # ast of Tensor.zeros(16).contiguous().realize()
  34. ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
  35. lin = Kernel(ast)
  36. bufs = bufs_from_lin(lin)
  37. kernel_count = GlobalCounters.kernel_count
  38. time_linearizer(lin, bufs, allow_test_size=False, cnt=2, disable_cache=True, clear_l2=True)
  39. assert GlobalCounters.kernel_count == kernel_count, "kernel count was incremented by time_linearizer"
  40. class TestBEAM(unittest.TestCase):
  41. def test_dynamic_beam(self):
  42. # TODO: make this infra globally usable
  43. class Capture:
  44. def __init__(self): self.captured = []
  45. def add(self, x): self.captured.append(x)
  46. capturing.append(Capture())
  47. kernel_count = GlobalCounters.kernel_count
  48. with Context(BEAM=1): Tensor.zeros(16).contiguous().realize()
  49. assert GlobalCounters.kernel_count == kernel_count + 1
  50. k_beam_1 = capturing[0].captured
  51. capturing.clear()
  52. capturing.append(Capture())
  53. kernel_count = GlobalCounters.kernel_count
  54. with Context(BEAM=0): Tensor.zeros(16).contiguous().realize()
  55. assert GlobalCounters.kernel_count == kernel_count + 1
  56. k_beam_0 = capturing[0].captured
  57. capturing.clear()
  58. self.assertNotEqual(k_beam_0[-1].prg.p.src, k_beam_1[-1].prg.p.src)
  59. def test_get_kernel_actions(self):
  60. from test.test_linearizer import helper_realized_ast
  61. a = Tensor.rand(4, 3)
  62. b = Tensor.rand(3)
  63. realized_ast, _ = helper_realized_ast(a @ b)
  64. from tinygrad.engine.search import get_kernel_actions
  65. lins = get_kernel_actions(Kernel(realized_ast), False).values()
  66. # ensure amt=0 are not duplicated
  67. if Opt(OptOps.UPCAST, 0, 0) in actions:
  68. assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UPCAST, axis=0, amt=4)]) == 0, "did not de-dup UPCAST"
  69. if Opt(OptOps.LOCAL, 0, 0) in actions:
  70. assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.LOCAL, axis=0, amt=4)]) == 0, "did not de-dup LOCAL"
  71. if Opt(OptOps.UNROLL, 0, 0) in actions:
  72. assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UNROLL, axis=0, amt=3)]) == 0, "did not de-dup UNROLL"
  73. if Opt(OptOps.GROUP, 0, 0) in actions:
  74. assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUP, axis=0, amt=3)]) == 0, "did not de-dup GROUP"
  75. if Opt(OptOps.GROUPTOP, 0, 0) in actions:
  76. assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, amt=3)]) == 0, "did not de-dup GROUPTOP"
  77. def test_filter_global_buffer(self):
  78. # taken from https://github.com/tinygrad/tinygrad/issues/4612
  79. ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.MAX, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=5, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=6, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, 
mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.4285714285714286, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
  80. lin = Kernel(ast)
  81. bufs = bufs_from_lin(lin)
  82. best_lin = beam_search(lin, bufs, 3)
  83. assert best_lin
  84. # need disable_cache to trigger.
  85. tm = time_linearizer(best_lin, bufs, allow_test_size=False, cnt=2, disable_cache=True)
  86. assert tm
  87. if __name__ == '__main__':
  88. unittest.main()