# test_tensor_variable.py (2.0 KB)

  1. import unittest
  2. import numpy as np
  3. from tinygrad import Tensor, Variable
  4. class TestTensorVariable(unittest.TestCase):
  5. def test_add_tvar(self):
  6. vv = Variable("a", 0, 10)
  7. vv.bind(1)
  8. ret = (Tensor(vv) + 3).item()
  9. assert ret == 4
  10. def test_inner_tvar_node(self):
  11. vv = Variable("w", 0, 10)
  12. vv.bind(2)
  13. ret = Tensor.from_node(vv * 4).item()
  14. assert ret == 8
  15. def test_inner_tvar_mul(self):
  16. vv = Variable("w", 0, 10)
  17. vv.bind(2)
  18. assert (Tensor(3) * vv).item() == 6
  19. def test_inner_tvar_mul_node(self):
  20. vv = Variable("w", 0, 10)
  21. vv.bind(2)
  22. assert (Tensor(3) * (vv * 4)).item() == 24
  23. def test_symbolic_mean(self):
  24. vv = Variable("a", 1, 10).bind(2)
  25. t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
  26. ret = t.mean().item()
  27. assert ret == 1
  28. def test_symbolic_mean_2d(self):
  29. vv = Variable("a", 1, 10).bind(2)
  30. vv2 = Variable("b", 1, 10).bind(2)
  31. t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv)
  32. ret = t.mean().item()
  33. assert ret == 1
  34. def test_symbolic_mean_2d_axis_1(self):
  35. vv = Variable("a", 1, 10).bind(2)
  36. vv2 = Variable("b", 1, 10).bind(2)
  37. t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv)
  38. ret = t.mean(axis=1).reshape(2, 1).numpy()
  39. assert np.all(ret == 1)
  40. @unittest.expectedFailure
  41. def test_symbolic_mean_2d_add(self):
  42. add_term = Variable("c", 0, 10)
  43. add_term.bind(1)
  44. vv = Variable("a", 1, 10)
  45. vv.bind(1)
  46. vv2 = Variable("b", 1, 10)
  47. vv2.bind(1)
  48. t = Tensor.ones(2, 2).contiguous().reshape(vv2+add_term, vv+add_term)
  49. ret = t.mean().item()
  50. assert ret == 1
  51. def test_symbolic_var(self):
  52. vv = Variable("a", 1, 10).bind(2)
  53. t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
  54. ret = t.var().item()
  55. assert ret == 0
  56. @unittest.skip("symbolic arange isn't supported")
  57. def test_symbolic_arange(self):
  58. vv = Variable("a", 1, 10)
  59. vv.bind(2)
  60. ret = Tensor.arange(0, vv)
  61. ret.realize()
  62. if __name__ == '__main__':
  63. unittest.main()