test_setitem.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import unittest
  2. from tinygrad import Tensor, TinyJit, Variable, dtypes
  3. import numpy as np
  4. class TestSetitem(unittest.TestCase):
  5. def test_simple_setitem(self):
  6. cases = (
  7. ((6,6), (slice(2,4), slice(3,5)), Tensor.ones(2,2)),
  8. ((6,6), (slice(2,4), slice(3,5)), Tensor([1.,2.])),
  9. ((6,6), (slice(2,4), slice(3,5)), 1.0),
  10. ((6,6), (3, 4), 1.0),
  11. ((6,6), (3, None, 4, None), 1.0),
  12. ((4,4,4,4), (Ellipsis, slice(1,3), slice(None)), Tensor(4)),
  13. ((4,4,4,4), (Ellipsis, slice(1,3)), 4),
  14. ((4,4,4,4), (2, slice(1,3), None, 1), 4),
  15. ((4,4,4,4), (slice(1,3), slice(None), slice(0,4,2)), 4),
  16. ((4,4,4,4), (slice(1,3), slice(None), slice(None), slice(0,3)), 4),
  17. ((6,6), (slice(1,5,2), slice(0,5,3)), 1.0),
  18. ((6,6), (slice(5,1,-2), slice(5,0,-3)), 1.0),
  19. )
  20. for shp, slc, val in cases:
  21. t = Tensor.zeros(shp).contiguous()
  22. t[slc] = val
  23. n = np.zeros(shp)
  24. n[slc] = val.numpy() if isinstance(val, Tensor) else val
  25. np.testing.assert_allclose(t.numpy(), n)
  26. def test_setitem_into_unrealized(self):
  27. t = Tensor.arange(4).reshape(2, 2)
  28. t[1] = 5
  29. np.testing.assert_allclose(t.numpy(), [[0, 1], [5, 5]])
  30. def test_setitem_dtype(self):
  31. for dt in (dtypes.int, dtypes.float, dtypes.bool):
  32. for v in (5., 5, True):
  33. t = Tensor.ones(6,6, dtype=dt).contiguous()
  34. t[1] = v
  35. assert t.dtype == dt
  36. def test_setitem_into_noncontiguous(self):
  37. t = Tensor.ones(4)
  38. assert not t.lazydata.st.contiguous
  39. with self.assertRaises(AssertionError): t[1] = 5
  40. def test_setitem_inplace_operator(self):
  41. t = Tensor.arange(4).reshape(2, 2).contiguous()
  42. t[1] += 2
  43. np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 5]])
  44. t = Tensor.arange(4).reshape(2, 2).contiguous()
  45. t[1] -= 1
  46. np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 2]])
  47. t = Tensor.arange(4).reshape(2, 2).contiguous()
  48. t[1] *= 2
  49. np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 6]])
  50. # NOTE: have to manually cast setitem target to least_upper_float for div
  51. t = Tensor.arange(4, dtype=dtypes.float).reshape(2, 2).contiguous()
  52. t[1] /= 2
  53. np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 1.5]])
  54. t = Tensor.arange(4).reshape(2, 2).contiguous()
  55. t[1] **= 2
  56. np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 9]])
  57. t = Tensor.arange(4).reshape(2, 2).contiguous()
  58. t[1] ^= 5
  59. np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]])
  60. @unittest.expectedFailure
  61. def test_setitem_consecutive_inplace_operator(self):
  62. t = Tensor.arange(4).reshape(2, 2).contiguous()
  63. t[1] += 2
  64. t = t.contiguous()
  65. # TODO: RuntimeError: must be contiguous for assign ShapeTracker(views=(View(shape=(2,), strides=(1,), offset=2, mask=None, contiguous=False),))
  66. t[1] -= 1
  67. np.testing.assert_allclose(t.numpy(), [[0, 1], [3, 4]])
  68. # TODO: implement fancy setitem
  69. @unittest.expectedFailure
  70. def test_fancy_setitem(self):
  71. t = Tensor.zeros(6,6).contiguous()
  72. t[[1,2], [3,2]] = 3
  73. n = np.zeros((6,6))
  74. n[[1,2], [3,2]] = 3
  75. np.testing.assert_allclose(t.numpy(), n)
  76. def test_simple_jit_setitem(self):
  77. @TinyJit
  78. def f(t:Tensor, a:Tensor):
  79. t[2:4, 3:5] = a
  80. for i in range(1, 6):
  81. t = Tensor.zeros(6, 6).contiguous().realize()
  82. a = Tensor.full((2, 2), fill_value=i, dtype=dtypes.float).contiguous()
  83. f(t, a)
  84. n = np.zeros((6, 6))
  85. n[2:4, 3:5] = np.full((2, 2), i)
  86. np.testing.assert_allclose(t.numpy(), n)
  87. def test_jit_setitem_variable_offset(self):
  88. @TinyJit
  89. def f(t:Tensor, a:Tensor, v:Variable):
  90. t.shrink(((v,v+1), None)).assign(a).realize()
  91. t = Tensor.zeros(6, 6).contiguous().realize()
  92. n = np.zeros((6, 6))
  93. for i in range(6):
  94. v = Variable("v", 0, 6).bind(i)
  95. a = Tensor.full((1, 6), fill_value=i+1, dtype=dtypes.float).contiguous()
  96. n[i, :] = i+1
  97. f(t, a, v)
  98. np.testing.assert_allclose(t.numpy(), n)
  99. np.testing.assert_allclose(t.numpy(), [[1,1,1,1,1,1],[2,2,2,2,2,2],[3,3,3,3,3,3],[4,4,4,4,4,4],[5,5,5,5,5,5],[6,6,6,6,6,6]])
  100. class TestWithGrad(unittest.TestCase):
  101. def test_no_requires_grad_works(self):
  102. z = Tensor.rand(8, 8)
  103. x = Tensor.rand(8)
  104. z[:3] = x
  105. def test_set_into_requires_grad(self):
  106. z = Tensor.rand(8, 8, requires_grad=True)
  107. x = Tensor.rand(8)
  108. with self.assertRaises(NotImplementedError):
  109. z[:3] = x
  110. def test_set_with_requires_grad(self):
  111. z = Tensor.rand(8, 8)
  112. x = Tensor.rand(8, requires_grad=True)
  113. with self.assertRaises(NotImplementedError):
  114. z[:3] = x
  115. if __name__ == '__main__':
  116. unittest.main()