test_conv_shapetracker.py 879 B

12345678910111213141516171819202122
  1. #!/usr/bin/env python
  2. import unittest
  3. from tinygrad.tensor import Tensor
  4. from tinygrad.ops import MetaOps, BufferOps
  5. from tinygrad.nn import Conv2d
  6. from tinygrad.engine.schedule import create_schedule
  class TestConvShapetracker(unittest.TestCase):
      """Checks on the shape trackers of the kernel scheduled for a Conv2d."""

      def test_conv_3x3_one_view(self):
          """A 3x3 Conv2d(16, 32) should schedule as exactly one KERNEL whose LOADs each use a single view."""
          conv = Conv2d(16, 32, (3, 3))
          seen = set()
          # First run initializes the weights; they are saved in `seen`, which is
          # passed to the second run below.
          create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
          # Second run: keep only the schedule items whose AST is a MetaOps.KERNEL.
          sched = [
              si
              for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
              if si.ast.op is MetaOps.KERNEL
          ]
          # Use self.assert* rather than bare `assert`: bare asserts are stripped
          # under `python -O`, which would make this test pass vacuously.
          self.assertEqual(len(sched), 1, f"conv should only have one kernel, getting {len(sched)}")
          load_sts = [x.arg.st for x in sched[0].ast.lazyops if x.op is BufferOps.LOAD]
          # Without this, an AST with no LOAD ops would skip the loop and pass silently.
          self.assertTrue(load_sts, "expected at least one BufferOps.LOAD in the conv kernel")
          for i, st in enumerate(load_sts):
              self.assertEqual(len(st.views), 1, f"load {i} has {len(st.views)} views: {st}")
  # Entry point: allows running this file directly to execute its test cases.
  if __name__ == "__main__":
      unittest.main()