# test_masked_st.py
import unittest

from tinygrad.tensor import Tensor
  3. class TestMaskedShapeTracker(unittest.TestCase):
  4. def test_mul_masked(self):
  5. a = Tensor([1,1,1,1,1])
  6. b = Tensor([1,1]).pad(((0,3),))
  7. c = a*b
  8. assert c.shape == a.shape
  9. #assert c.lazydata.st.views[0].mask is not None
  10. ret = c.data()
  11. assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]
  12. def test_mul_both_masked(self):
  13. a = Tensor([1,1]).pad(((0,3),))
  14. b = Tensor([1,1]).pad(((0,3),))
  15. c = a*b
  16. assert c.shape == a.shape
  17. #assert c.lazydata.st.views[0].mask is not None
  18. ret = c.data()
  19. assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]
  20. def test_add_masked(self):
  21. a = Tensor([1,1]).pad(((0,2),))
  22. b = Tensor([1,1]).pad(((0,2),))
  23. c = a+b
  24. #assert c.lazydata.st.views[0].mask is not None
  25. ret = c.data()
  26. assert ret.tolist() == [2.0, 2.0, 0.0, 0.0]
  27. if __name__ == '__main__':
  28. unittest.main()