test_tensor.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661
  1. import numpy as np
  2. import torch
  3. import unittest, copy, mmap, random, math
  4. from tinygrad import Tensor, Device, dtypes
  5. from tinygrad.engine.schedule import create_schedule
  6. from tinygrad.helpers import getenv, temp, CI, _METADATA
  7. from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
  8. from hypothesis import given, settings, strategies as strat
  9. from test.helpers import is_dtype_supported
  10. settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
  11. settings.load_profile("my_profile")
  12. x_init = np.random.randn(1,3).astype(np.float32)
  13. U_init = np.random.randn(3,3).astype(np.float32)
  14. V_init = np.random.randn(3,3).astype(np.float32)
  15. W_init = np.random.randn(3,3).astype(np.float32)
  16. m_init = np.random.randn(1,3).astype(np.float32)
  17. class TestTinygrad(unittest.TestCase):
  18. def test_zerodim_initialization(self):
  19. self.assertEqual(Tensor(55).shape, ())
  20. self.assertEqual(Tensor(3.14).shape, ())
  21. def test_plus_equals(self):
  22. a = Tensor.randn(10,10)
  23. b = Tensor.randn(10,10)
  24. c = a + b
  25. val1 = c.numpy()
  26. a += b
  27. val2 = a.numpy()
  28. np.testing.assert_allclose(val1, val2)
  29. def test_backward_pass(self):
  30. def test_tinygrad():
  31. x = Tensor(x_init, requires_grad=True)
  32. W = Tensor(W_init, requires_grad=True)
  33. m = Tensor(m_init)
  34. out = x.dot(W).relu()
  35. out = out.log_softmax()
  36. out = out.mul(m).add(m).sum()
  37. out.backward()
  38. return out.numpy(), x.grad.numpy(), W.grad.numpy()
  39. def test_pytorch():
  40. x = torch.tensor(x_init, requires_grad=True)
  41. W = torch.tensor(W_init, requires_grad=True)
  42. m = torch.tensor(m_init)
  43. out = x.matmul(W).relu()
  44. out = torch.nn.functional.log_softmax(out, dim=1)
  45. out = out.mul(m).add(m).sum()
  46. out.backward()
  47. return out.detach().numpy(), x.grad, W.grad
  48. for x,y in zip(test_tinygrad(), test_pytorch()):
  49. np.testing.assert_allclose(x, y, atol=1e-5)
  50. @unittest.skipIf(Device.DEFAULT == "WEBGPU", "this test uses more than 8 bufs which breaks webgpu") #TODO: remove after #1461
  51. def test_backward_pass_diamond_model(self):
  52. def test_tinygrad():
  53. u = Tensor(U_init, requires_grad=True)
  54. v = Tensor(V_init, requires_grad=True)
  55. w = Tensor(W_init, requires_grad=True)
  56. x = u.mul(v).relu()
  57. y = u.mul(w).relu()
  58. out = x.add(y).mul(y).relu()
  59. out = out.log_softmax()
  60. out = out.sum()
  61. out.backward()
  62. return out.numpy(), u.grad.numpy(), v.grad.numpy(), w.grad.numpy()
  63. def test_pytorch():
  64. u = torch.tensor(U_init, requires_grad=True)
  65. v = torch.tensor(V_init, requires_grad=True)
  66. w = torch.tensor(W_init, requires_grad=True)
  67. x = u.mul(v).relu()
  68. y = u.mul(w).relu()
  69. out = x.add(y).mul(y).relu()
  70. out = torch.nn.functional.log_softmax(out, dim=1)
  71. out = out.sum()
  72. out.backward()
  73. return out.detach().numpy(), u.grad, v.grad, w.grad
  74. for x,y in zip(test_tinygrad(), test_pytorch()):
  75. np.testing.assert_allclose(x, y, atol=1e-5)
  76. def test_nograd(self):
  77. x = Tensor(x_init, requires_grad=False)
  78. m = Tensor(m_init, requires_grad=False)
  79. W = Tensor(W_init, requires_grad=True)
  80. tmp = x.mul(m)
  81. mm = tmp.matmul(W)
  82. out = mm.relu()
  83. out = out.sum()
  84. out.backward()
  85. assert x.grad is None
  86. assert m.grad is None
  87. assert tmp.grad is None
  88. assert mm.grad is not None
  89. assert W.grad is not None
  90. def test_dropout(self):
  91. with Tensor.train():
  92. n, rate = 1_000_000, 0.1
  93. w = Tensor.ones(n).dropout(rate)
  94. non_zeros = np.count_nonzero(w.numpy())
  95. expected = n * (1 - rate)
  96. np.testing.assert_allclose(non_zeros, expected, rtol=2e-3)
  97. def test_jacobian(self):
  98. W = np.random.RandomState(42069).random((10, 5)).astype(np.float32)
  99. x = np.random.RandomState(69420).random((1, 10)).astype(np.float32)
  100. torch_x = torch.tensor(x, requires_grad=True)
  101. torch_W = torch.tensor(W, requires_grad=True)
  102. def torch_func(x): return torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1)
  103. PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy()
  104. tiny_x = Tensor(x, requires_grad=True)
  105. tiny_W = Tensor(W, requires_grad=True)
  106. def tiny_func(x): return x.dot(tiny_W).relu().log_softmax()
  107. J = jacobian(tiny_func, tiny_x)
  108. NJ = numerical_jacobian(tiny_func, tiny_x)
  109. np.testing.assert_allclose(PJ, J, atol = 1e-5)
  110. np.testing.assert_allclose(PJ, NJ, atol = 1e-3)
  111. def test_gradcheck(self):
  112. W = np.random.RandomState(1337).random((10, 5)).astype(np.float32)
  113. x = np.random.RandomState(7331).random((1, 10)).astype(np.float32)
  114. tiny_x = Tensor(x, requires_grad=True)
  115. tiny_W = Tensor(W, requires_grad=True)
  116. def tiny_func(x): return x.dot(tiny_W).relu().log_softmax()
  117. self.assertTrue(gradcheck(tiny_func, tiny_x, eps = 1e-3))
  118. # coarse approx. since a "big" eps and the non-linearities of the model
  119. self.assertFalse(gradcheck(tiny_func, tiny_x, eps = 1e-5))
  120. def test_random_fns_are_deterministic_with_seed(self):
  121. for random_fn in [Tensor.randn, Tensor.normal, Tensor.uniform, Tensor.scaled_uniform, Tensor.glorot_uniform, Tensor.kaiming_normal]:
  122. with self.subTest(msg=f"Tensor.{random_fn.__name__}"):
  123. Tensor.manual_seed(1337)
  124. a = random_fn(10,10).realize()
  125. Tensor.manual_seed(1337)
  126. b = random_fn(10,10).realize()
  127. np.testing.assert_allclose(a.numpy(), b.numpy())
  128. def test_randn_isnt_inf_on_zero(self):
  129. # simulate failure case of rand handing a zero to randn
  130. original_rand, Tensor.rand = Tensor.rand, Tensor.zeros
  131. try: self.assertNotIn(np.inf, Tensor.randn(16).numpy())
  132. except: raise
  133. finally: Tensor.rand = original_rand
  134. def test_zeros_like_has_same_dtype_and_shape(self):
  135. for datatype in [dtypes.float16, dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64, dtypes.uint8]:
  136. a = Tensor([1, 2, 3], dtype=datatype)
  137. b = Tensor.zeros_like(a)
  138. assert a.dtype == b.dtype, f"dtype mismatch {a.dtype=} != {b.dtype}"
  139. assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}"
  140. a = Tensor([1, 2, 3])
  141. b = Tensor.zeros_like(a, dtype=dtypes.int8)
  142. assert a.dtype == dtypes.default_int and b.dtype == dtypes.int8, "a.dtype should be int and b.dtype should be char"
  143. assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}"
  144. def test_ones_like_has_same_dtype_and_shape(self):
  145. for datatype in [dtypes.float16, dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64, dtypes.uint8]:
  146. a = Tensor([1, 2, 3], dtype=datatype)
  147. b = Tensor.ones_like(a)
  148. assert a.dtype == b.dtype, f"dtype mismatch {a.dtype=} != {b.dtype}"
  149. assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}"
  150. a = Tensor([1, 2, 3])
  151. b = Tensor.ones_like(a, dtype=dtypes.int8)
  152. assert a.dtype == dtypes.default_int and b.dtype == dtypes.int8, "a.dtype should be int and b.dtype should be char"
  153. assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}"
  154. def test_ndim(self):
  155. assert Tensor(1).ndim == 0
  156. assert Tensor.randn(1).ndim == 1
  157. assert Tensor.randn(2,2,2).ndim == 3
  158. assert Tensor.randn(1,1,1,1,1,1).ndim == 6
  159. def test_argfix(self):
  160. for f in [Tensor.zeros, Tensor.ones, Tensor.rand, Tensor.randn, Tensor.empty]:
  161. self.assertEqual(f().shape, ())
  162. self.assertEqual(f(1).shape, (1,))
  163. self.assertEqual(f(10,20,40).shape, (10,20,40))
  164. self.assertEqual(f([]).shape, ())
  165. self.assertEqual(f([1]).shape, (1,))
  166. self.assertEqual(f([10,20,40]).shape, (10,20,40))
  167. self.assertEqual(f(()).shape, ())
  168. self.assertEqual(f((1,)).shape, (1,))
  169. self.assertEqual(f((10,20,40)).shape, (10,20,40))
  170. with self.assertRaises(ValueError): f((2, 2), 2, 2)
  171. with self.assertRaises(ValueError): f((2, 2), (2, 2))
  172. with self.assertRaises(ValueError): f((128, 128), 0.0, 0.01)
  173. def test_numel(self):
  174. assert Tensor.randn(10, 10).numel() == 100
  175. assert Tensor.randn(1,2,5).numel() == 10
  176. assert Tensor.randn(1,1,1,1,1,1).numel() == 1
  177. assert Tensor([]).numel() == 0
  178. assert Tensor.randn(1,0,2,5).numel() == 0
  179. assert Tensor(3).numel() == 1
  180. def test_len(self):
  181. assert len(torch.zeros(7)) == len(Tensor.zeros(7))
  182. assert len(torch.zeros(10,20)) == len(Tensor.zeros(10,20))
  183. assert len(torch.zeros(10,20)) == len(Tensor.zeros(10,20,30))
  184. assert len(torch.zeros(1).flatten()) == len(Tensor.zeros(1).flatten())
  185. with self.assertRaises(TypeError): len(Tensor(3))
  186. def test_size(self):
  187. t1, t2 = torch.zeros(10,20), Tensor.zeros(10,20)
  188. assert t1.size() == t2.size()
  189. assert t1.size(0) == t2.size(0)
  190. assert t1.size(1) == t2.size(1)
  191. assert t1.size(-1) == t2.size(-1)
  192. assert t1.size(-2) == t2.size(-2)
  193. with self.assertRaises(IndexError): t2.size(2)
  194. def test_tolist(self):
  195. # NOTE: float16 Tensor.tolist() requires python 3.12
  196. for arr in [[1,2,3], [1.5,2,3], [[1,2,3], [4,5,6]], 3]:
  197. assert Tensor(arr).tolist() == torch.tensor(arr).tolist() == arr
  198. def test_element_size(self):
  199. for _, dtype in dtypes.fields().items():
  200. assert dtype.itemsize == Tensor.randn(3, dtype=dtype).element_size(), f"Tensor.element_size() not matching Tensor.dtype.itemsize for {dtype}"
  201. def test_deepwalk_ctx_check(self):
  202. layer = Tensor.uniform(1, 1, requires_grad=True)
  203. x = Tensor.randn(1, 1, 1)
  204. x.dot(layer).mean().backward()
  205. x = Tensor.randn(1, 1, 1)
  206. x.dot(layer).mean().backward()
  207. def test_zerosized_tensors(self):
  208. np.testing.assert_equal(Tensor([]).numpy(), np.array([]))
  209. np.testing.assert_equal(Tensor(None).numpy(), np.array([]))
  210. def test_tensor_ndarray_dtype(self):
  211. arr = np.array([1]) # where dtype is implicitly int64
  212. assert Tensor(arr).dtype == dtypes.int64
  213. assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 # check if ndarray correctly casts to Tensor dtype
  214. assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 # check that it works for something else
  215. def test_tensor_list_dtype(self):
  216. for arr in ([1], [[[1]]], [[1,1],[1,1]], [[[1,1],[1,1]],[[1,1],[1,1]]]):
  217. assert Tensor(arr).dtype == dtypes.default_int
  218. assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32
  219. assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64
  220. for arr in ([True], [[[False]]], [[True,False],[True,False]], [[[False,True],[False,False]],[[True,True],[False,True]]]):
  221. assert Tensor(arr).dtype == dtypes.bool
  222. assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32
  223. assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64
  224. # empty tensor defaults
  225. for arr in ([], [[[]]], [[],[]]):
  226. t = Tensor(arr)
  227. assert t.dtype == dtypes.default_float
  228. np.testing.assert_allclose(t.numpy(), np.array(arr))
  229. # mixture of bool and int
  230. for arr in ([True, 3], [[True],[3]], [[[True]], [[3]]], [[True, 3], [3, True]]):
  231. t = Tensor(arr)
  232. assert t.dtype == dtypes.default_int
  233. np.testing.assert_allclose(t.numpy(), np.array(arr))
  234. # mixture of bool, int and float
  235. for arr in ([[True,True],[3.,True]], [[0,1],[3.,4]], [[[0],[1]],[[3.],[4]]], [[[True],[1]],[[3.],[4]]]):
  236. t = Tensor(arr)
  237. assert t.dtype == dtypes.default_float
  238. np.testing.assert_allclose(t.numpy(), np.array(arr))
  239. def test_tensor_list_shapes(self):
  240. self.assertEqual(Tensor([[[]]]).shape, (1,1,0))
  241. self.assertEqual(Tensor([[],[]]).shape, (2,0))
  242. self.assertEqual(Tensor([[[[]],[[]]], [[[]],[[]]], [[[]],[[]]]]).shape, (3,2,1,0))
  243. def test_tensor_list_errors(self):
  244. # inhomogeneous shape
  245. with self.assertRaises(ValueError): Tensor([[],[[]]])
  246. with self.assertRaises(ValueError): Tensor([[1],[]])
  247. with self.assertRaises(ValueError): Tensor([[1],[1],1])
  248. with self.assertRaises(ValueError): Tensor([[[1,1,1],[1,1]]])
  249. with self.assertRaises(ValueError): Tensor([[1,1,1],[[1,1,1]]])
  250. def test_tensor_mixed_list_tuple(self):
  251. def _list_or_tuple(): return list if random.random() < 0.5 else tuple
  252. def _generate_data(depth):
  253. if depth == 0: return _list_or_tuple()()
  254. if depth == 1: return _list_or_tuple()([random.random(), random.random()])
  255. return _list_or_tuple()([_generate_data(depth-1), _generate_data(depth-1)])
  256. for depth in range(7):
  257. for _ in range(20):
  258. data = _generate_data(depth)
  259. np.testing.assert_allclose(Tensor(data).numpy(), np.array(data))
  260. def test_tensor_list_special_values(self):
  261. if is_dtype_supported(dtypes.float16):
  262. data = [math.nan, -math.inf, 65504, 65519, 65519.999, 65520, 65520.1]
  263. data = data + [-x for x in data]
  264. np.testing.assert_allclose(Tensor(data, dtype=dtypes.float16).numpy(), np.array(data).astype(np.float16))
  265. # uint32
  266. data = [1 << 33, 1 << 32, 1 << 32 - 1, 1]
  267. data = data + [-x for x in data]
  268. np.testing.assert_allclose(Tensor(data, dtype=dtypes.uint32).numpy(), np.array(data).astype(np.uint32))
  269. # int32
  270. data = [1 << 33, 1 << 32, 1 << 32 - 1, 1]
  271. data = data + [-x for x in data]
  272. np.testing.assert_allclose(Tensor(data, dtype=dtypes.int32).numpy(), np.array(data).astype(np.int32))
  273. def test_tensor_list_ndarray(self):
  274. data = [np.array([1, 2, 3]), np.array([1, 2, 3]), np.array([1, 2, 3])]
  275. np.testing.assert_equal(Tensor(data).numpy(), np.array(data))
  276. data = [np.array([1.0, 2.0, 3.0]), np.array([1, 2, 3]), np.array([1, 2, 3])]
  277. np.testing.assert_equal(Tensor(data).numpy(), np.array(data))
  278. data = [np.array(1.0), np.array(2.0), np.array(3.0)]
  279. np.testing.assert_equal(Tensor(data).numpy(), np.array(data))
  280. def test_tensor_bytes(self):
  281. data = b"abc123"
  282. t = Tensor(data)
  283. assert t.dtype == dtypes.uint8
  284. assert t.shape == (6,)
  285. np.testing.assert_equal(t.numpy(), list(data))
  286. def test_tensor_copy(self):
  287. x = copy.deepcopy(Tensor.ones((3,3,3)))
  288. np.testing.assert_allclose(x.numpy(), np.ones((3,3,3)))
  289. def test_copy_from_disk(self):
  290. t = Tensor.randn(30).to(f"disk:{temp('test_copy_from_disk')}")
  291. a = t[10:20]
  292. dev = a.to(Device.DEFAULT)
  293. np.testing.assert_allclose(a.numpy(), dev.numpy())
  294. # Regression test for https://github.com/tinygrad/tinygrad/issues/1751
  295. def test_copy_from_numpy_unaligned(self):
  296. # 2**15 is the minimum for repro
  297. arr = np.random.randn(2**15).astype(np.float32)
  298. fn = temp('test_copy_from_numpy_unaligned')
  299. with open(fn, 'wb') as f: f.write(b't' + arr.tobytes())
  300. with open(fn, "a+b") as f: memview = memoryview(mmap.mmap(f.fileno(), arr.nbytes + 1))
  301. ua_arr = np.frombuffer(memview[1:], dtype=arr.dtype, count=arr.shape[0])
  302. np.testing.assert_allclose(arr, ua_arr)
  303. assert not ua_arr.flags.aligned
  304. # force device copy - to() is opt'd away - Tensor(dev)/1 is ignored
  305. np.testing.assert_allclose(ua_arr, (Tensor(ua_arr)/Tensor(1)).numpy())
  306. def test_item_to_tensor_to_item(self):
  307. for a in [0, 1, 2, 3, -1, -100, 100, -101.1, 2.345, 100.1, True, False]:
  308. item = Tensor(a).item()
  309. assert type(item) is type(a), a
  310. np.testing.assert_allclose(item, a), a
  311. buffered_item = Tensor([a]).item()
  312. assert type(buffered_item) is type(a), a
  313. np.testing.assert_allclose(buffered_item, a), a
  314. reshaped_item = Tensor([a]).reshape((1, 1, 1, 1, 1)).item()
  315. assert type(reshaped_item) is type(a), a
  316. np.testing.assert_allclose(reshaped_item, a), a
  317. def test_no_bool(self):
  318. with self.assertRaises(TypeError):
  319. if Tensor(3):
  320. print("hi")
  321. with self.assertRaises(TypeError):
  322. _a = Tensor([3]) in [Tensor([3]), Tensor([4]), Tensor([5])]
  323. def test_repr_with_grad(self):
  324. a = Tensor([1], requires_grad=True)
  325. b = Tensor([1])
  326. c = (a + b).mean().backward()
  327. print(a)
  328. print(c)
  329. @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL", "NV", "AMD"}, "no GPU CI")
  330. class TestMoveTensor(unittest.TestCase):
  331. d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
  332. @given(strat.sampled_from([d0, d1]), strat.sampled_from([d0, d1]),
  333. strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
  334. def test_to_preserves(self, src, dest, dtype, requires_grad):
  335. s = Tensor([1, 2, 3], device=src, dtype=dtype, requires_grad=requires_grad)
  336. if requires_grad: s.sum().backward()
  337. t = s.to(dest)
  338. np.testing.assert_equal(s.numpy(), t.numpy())
  339. assert s.dtype == t.dtype
  340. assert s.requires_grad == t.requires_grad
  341. if requires_grad:
  342. np.testing.assert_equal(s.grad.numpy(), t.grad.numpy())
  343. @given(strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
  344. def test_shard_preserves(self, dtype, requires_grad):
  345. s = Tensor([1, 2, 3], dtype=dtype, requires_grad=requires_grad)
  346. t = s.shard((f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"))
  347. np.testing.assert_equal(s.numpy(), t.numpy())
  348. assert s.dtype == t.dtype
  349. assert s.requires_grad == t.requires_grad
  350. @given(strat.sampled_from([d0, d1]))
  351. def test_same_dev(self, dev):
  352. x = Tensor([1,2,3], device=dev)
  353. y = x.to(dev)
  354. assert x is y
  355. def test_to_grad(self):
  356. x = Tensor.eye(3, requires_grad=True, device=self.d0)
  357. y = Tensor([[2.0,0,-2.0]], requires_grad=True, device=self.d0)
  358. z = y.matmul(x).to(self.d1).sum()
  359. z.backward()
  360. np.testing.assert_equal(x.grad.numpy(), [[2,2,2],[0,0,0],[-2,-2,-2]])
  361. class TestZeroShapeTensor(unittest.TestCase):
  362. def test_shape_stride(self):
  363. t = Tensor.empty(3, 2, 0)
  364. assert t.shape == (3, 2, 0)
  365. # numpy has stride 0, 0, 0; torch has stride 2, 1, 1
  366. assert t.lazydata.st.real_strides() == (0, 0, 0)
  367. t = Tensor.empty(3, 0, 2)
  368. assert t.shape == (3, 0, 2)
  369. # numpy has stride 0, 0, 0; torch has stride 2, 2, 1
  370. assert t.lazydata.st.real_strides() == (0, 0, 0)
  371. t = Tensor.empty(0, 0, 0)
  372. assert t.shape == (0, 0, 0)
  373. # numpy has stride 0, 0, 0; torch has stride 1, 1, 1
  374. assert t.lazydata.st.real_strides() == (0, 0, 0)
  375. def test_rand(self):
  376. t = Tensor.rand(3, 2, 0)
  377. assert t.shape == (3, 2, 0)
  378. np.testing.assert_equal(t.numpy(), np.zeros((3, 2, 0)))
  379. t = Tensor.rand(0)
  380. assert t.shape == (0,)
  381. np.testing.assert_equal(t.numpy(), np.zeros((0,)))
  382. t = Tensor.rand(0, 0, 0)
  383. assert t.shape == (0, 0, 0)
  384. np.testing.assert_equal(t.numpy(), np.zeros((0, 0, 0)))
  385. def test_full(self):
  386. t = Tensor.zeros(3, 2, 0)
  387. assert t.shape == (3, 2, 0)
  388. np.testing.assert_equal(t.numpy(), np.zeros((3, 2, 0)))
  389. t = Tensor.full((3, 2, 0), 12)
  390. assert t.shape == (3, 2, 0)
  391. np.testing.assert_equal(t.numpy(), np.full((3, 2, 0), 12))
  392. def test_reshape(self):
  393. t = Tensor.zeros(3, 2, 0)
  394. a = t.reshape(7, 0)
  395. assert a.shape == (7, 0)
  396. np.testing.assert_equal(a.numpy(), np.zeros((7, 0)))
  397. a = t.reshape(0)
  398. assert a.shape == (0,)
  399. np.testing.assert_equal(a.numpy(), np.zeros((0,)))
  400. with self.assertRaises(AssertionError):
  401. # cannot reshape from size 0 to size 1
  402. a = t.reshape(())
  403. def test_expand(self):
  404. t = Tensor.full((1, 2, 0), 12).expand((6, 2, 0))
  405. assert t.shape == (6, 2, 0)
  406. np.testing.assert_equal(t.numpy(), np.full((6, 2, 0), 12))
  407. def test_pad(self):
  408. t = Tensor.rand(3, 2, 0).pad((None, None, (1, 1)), value=1)
  409. assert t.shape == (3, 2, 2)
  410. np.testing.assert_equal(t.numpy(), np.ones((3, 2, 2)))
  411. t = Tensor.rand(3, 2, 0).pad((None, (1, 1), None), value=1)
  412. assert t.shape == (3, 4, 0)
  413. np.testing.assert_equal(t.numpy(), np.ones((3, 4, 0)))
  414. t = Tensor.rand(3, 2, 0).pad(((1, 1), None, None), value=1)
  415. assert t.shape == (5, 2, 0)
  416. np.testing.assert_equal(t.numpy(), np.ones((5, 2, 0)))
  417. def test_shrink_into_zero(self):
  418. t = Tensor.rand(3, 4).realize()
  419. assert t.shrink((None, (2, 2))).realize().shape == (3, 0)
  420. assert t.shrink(((2, 2), None)).realize().shape == (0, 4)
  421. assert t.shrink(((2, 2), (2, 2))).realize().shape == (0, 0)
  422. def test_cat(self):
  423. a = Tensor.rand(3, 2, 2)
  424. b = Tensor.rand(3, 2, 0)
  425. t = a.cat(b, dim=2)
  426. assert t.shape == (3, 2, 2)
  427. np.testing.assert_equal(t.numpy(), a.numpy())
  428. t = b.cat(a, dim=2)
  429. assert t.shape == (3, 2, 2)
  430. np.testing.assert_equal(t.numpy(), a.numpy())
  431. t = b.cat(b, dim=0)
  432. assert t.shape == (6, 2, 0)
  433. np.testing.assert_equal(t.numpy(), np.zeros((6, 2, 0)))
  434. t = b.cat(b, dim=1)
  435. assert t.shape == (3, 4, 0)
  436. np.testing.assert_equal(t.numpy(), np.zeros((3, 4, 0)))
  437. t = b.cat(b, dim=2)
  438. assert t.shape == (3, 2, 0)
  439. np.testing.assert_equal(t.numpy(), np.zeros((3, 2, 0)))
  440. def test_elementwise(self):
  441. a = Tensor.rand(3, 2, 0)
  442. a_exp = a.exp()
  443. assert a_exp.shape == (3, 2, 0)
  444. np.testing.assert_equal(a_exp.numpy(), np.exp(a.numpy()))
  445. b = Tensor.rand(3, 2, 0)
  446. assert b.shape == (3, 2, 0)
  447. ab = a * b
  448. assert ab.shape == (3, 2, 0)
  449. np.testing.assert_equal(ab.numpy(), a.numpy() * b.numpy())
  450. mask = (Tensor.rand(3, 2, 0) > 0.5)
  451. assert mask.shape == (3, 2, 0)
  452. c = mask.where(a, b)
  453. assert c.shape == (3, 2, 0)
  454. np.testing.assert_equal(c.numpy(), np.where(mask.numpy(), a.numpy(), b.numpy()))
  455. def test_reduce_over_non_zero(self):
  456. a = Tensor.ones(3, 2, 0).sum(axis=1)
  457. assert a.shape == (3, 0)
  458. np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=1))
  459. def test_reduce_over_zero(self):
  460. a = Tensor.ones(3, 2, 0).sum(axis=2)
  461. assert a.shape == (3, 2)
  462. np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=2))
  463. a = Tensor.ones(3, 2, 0).sum(axis=2, keepdim=True)
  464. assert a.shape == (3, 2, 1)
  465. np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=2, keepdims=True))
  466. def test_reduce_default(self):
  467. np.testing.assert_equal(Tensor([]).max().numpy(), -float("inf"))
  468. np.testing.assert_equal(Tensor([]).min().numpy(), float("inf"))
  469. np.testing.assert_equal(Tensor([]).sum().numpy(), 0)
  470. np.testing.assert_equal(Tensor([]).mean().numpy(), float("nan"))
  471. class TestTensorCreationDevice(unittest.TestCase):
  472. # test auxiliary tensors are created on the same device
  473. def test_one_hot(self):
  474. y = Tensor([1, 2, 3]).to("CLANG")
  475. x = y.one_hot(10)
  476. x.realize()
  477. class TestTrainMode(unittest.TestCase):
  478. def test_train_mode(self):
  479. assert not Tensor.training
  480. @Tensor.train()
  481. def f():
  482. assert Tensor.training
  483. f()
  484. assert not Tensor.training
  485. class TestInferenceMode(unittest.TestCase):
  486. def test_inference_mode(self):
  487. x = Tensor(x_init, requires_grad=True)
  488. m = Tensor(m_init, requires_grad=True)
  489. W = Tensor(W_init, requires_grad=True)
  490. with Tensor.inference_mode():
  491. tmp = x.mul(m)
  492. mm = tmp.matmul(W)
  493. out = mm.relu()
  494. out = out.sum()
  495. out.backward()
  496. assert x.grad is None
  497. assert m.grad is None
  498. assert tmp.grad is None
  499. assert mm.grad is None
  500. assert W.grad is None
  501. assert W.requires_grad
  502. def test_no_grad_mode_context_manager(self):
  503. x = Tensor(x_init, requires_grad=True)
  504. m = Tensor(m_init, requires_grad=True)
  505. W = Tensor(W_init, requires_grad=True)
  506. @Tensor.inference_mode()
  507. def f(x, m, W):
  508. tmp = x.mul(m)
  509. mm = tmp.matmul(W)
  510. out = mm.relu()
  511. out = out.sum()
  512. out.backward()
  513. assert x.grad is None
  514. assert m.grad is None
  515. assert tmp.grad is None
  516. assert mm.grad is None
  517. assert W.grad is None
  518. f(x, m, W)
  519. class TestTensorMetadata(unittest.TestCase):
  520. def test_matmul(self):
  521. _METADATA.set(None)
  522. x = Tensor.rand(3, requires_grad=True)
  523. W = Tensor.rand(3, 3, requires_grad=True)
  524. out = x.matmul(W)
  525. assert out.lazydata.metadata.name == "matmul"
  526. s = create_schedule([out.lazydata])
  527. assert len(s[-1].metadata) == 1
  528. assert s[-1].metadata[0].name == "matmul"
  529. def test_relu(self):
  530. _METADATA.set(None)
  531. x = Tensor.rand(3, requires_grad=True)
  532. out = x.relu()
  533. assert out.lazydata.metadata.name == "relu"
  534. s = create_schedule([out.lazydata])
  535. assert len(s[-1].metadata) == 1
  536. assert s[-1].metadata[0].name == "relu"
  537. def test_complex(self):
  538. _METADATA.set(None)
  539. x = Tensor.rand(3, requires_grad=True)
  540. y = Tensor.rand(3, requires_grad=True)
  541. out = x.relu() * y.sigmoid()
  542. assert out.lazydata.metadata.name == "__mul__"
  543. assert out.lazydata.srcs[0].metadata.name == "relu"
  544. assert out.lazydata.srcs[1].metadata.name == "sigmoid"
  545. s = create_schedule([out.lazydata])
  546. assert len(s[-1].metadata) == 3
  547. assert s[-1].metadata[0].name == "relu"
  548. assert s[-1].metadata[1].name == "sigmoid"
  549. assert s[-1].metadata[2].name == "__mul__"
  550. def test_complex_backward(self):
  551. _METADATA.set(None)
  552. x = Tensor.rand(3, requires_grad=True)
  553. y = Tensor.rand(3, requires_grad=True)
  554. out = (x.relu() * y.sigmoid()).sum()
  555. assert out.lazydata.metadata.name == "sum"
  556. out.backward()
  557. assert x.grad.lazydata.metadata.name == "relu"
  558. assert x.grad.lazydata.metadata.backward
  559. assert y.grad.lazydata.metadata.name == "sigmoid"
  560. assert y.grad.lazydata.metadata.backward
  561. s = create_schedule([out.lazydata, x.grad.lazydata, y.grad.lazydata])
  562. assert len(s[-1].metadata) == 3
  563. assert s[-1].metadata[0].name == "sigmoid"
  564. assert s[-1].metadata[1].name == "sigmoid"
  565. assert s[-1].metadata[1].backward
  566. assert s[-1].metadata[2].name == "relu"
  567. if __name__ == '__main__':
  568. unittest.main()