test_view.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. #!/usr/bin/env python
  2. import unittest
  3. from tinygrad.shape.view import View
  4. class TestView(unittest.TestCase):
  5. def test_canonicalize_empty_mask(self):
  6. v = View.create(shape=(2,2,2), strides=(4,2,1), mask=((0,2),(0,2),(0,2)))
  7. assert v.mask is None
  8. v = View.create(shape=(4,3,2), strides=(1,4,10), mask=((0,4),(0,3),(0,2)))
  9. assert v.mask is None
  10. def test_minify_zero_strided_dims(self):
  11. target = View.create(shape=(2,2), strides=(30,2), offset=7, mask=None)
  12. v = View.create(shape=(2,1,2), strides=(30,0,2), offset=7, mask=None)
  13. assert v.minify() == target
  14. v = View.create(shape=(1,2,2), strides=(0,30,2), offset=7, mask=None)
  15. assert v.minify() == target
  16. v = View.create(shape=(2,2,1), strides=(30,2,0), offset=7, mask=None)
  17. assert v.minify() == target
  18. v = View.create(shape=(2,1,1,2), strides=(30,0,0,2), offset=7, mask=None)
  19. assert v.minify() == target
  20. v = View.create(shape=(1,1,2,2), strides=(0,0,30,2), offset=7, mask=None)
  21. assert v.minify() == target
  22. v = View.create(shape=(2,2,1,1), strides=(30,2,0,0), offset=7, mask=None)
  23. assert v.minify() == target
  24. v = View.create(shape=(1,2,2,1), strides=(0,30,2,0), offset=7, mask=None)
  25. assert v.minify() == target
  26. v = View.create(shape=(1,2,1,2), strides=(0,30,0,2), offset=7, mask=None)
  27. assert v.minify() == target
  28. def test_empty_mask_contiguous(self):
  29. v1 = View.create(shape=(2,2,2), strides=(4,2,1), mask=None)
  30. v2 = View.create(shape=(2,2,2), strides=(4,2,1), mask=((0,2),(0,2),(0,2)))
  31. assert v1.contiguous == v2.contiguous
  32. v1 = View.create(shape=(1,1,1,4), strides=(0,0,0,1), offset=0, mask=None)
  33. v2 = View.create(shape=(1,1,1,4), strides=(0,0,0,1), offset=0, mask=((0,1),(0,1),(0,1),(0,4)))
  34. assert v1.contiguous == v2.contiguous
  35. v = View.create(shape=(2,3,4), mask=((0,2),(0,3),(0,4)))
  36. assert v.contiguous
  37. if __name__ == '__main__':
  38. unittest.main()