test_indexing.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579
  1. # test cases are modified from pytorch test_indexing.py https://github.com/pytorch/pytorch/blob/597d3fb86a2f3b8d6d8ee067e769624dcca31cdb/test/test_indexing.py
  2. import unittest, random, copy, warnings
  3. import numpy as np
  4. from tinygrad import Tensor, dtypes, Device, TinyJit
  5. from tinygrad.shape.shapetracker import ShapeTracker
  6. from tinygrad.shape.view import View
  7. from tinygrad.helpers import CI, all_same, prod
# seed Python's `random` module so the randomly drawn slices (test_index) and randomly chosen indexer types (ri) are reproducible across runs
random.seed(42)
  9. def numpy_testing_assert_equal_helper(a, b):
  10. if isinstance(a, Tensor): a = a.numpy()
  11. if isinstance(b, Tensor): b = b.numpy()
  12. np.testing.assert_equal(a, b)
  13. def consec(shape, start=1):
  14. return Tensor.arange(prod(shape)).reshape(shape)+start
  15. # creates strided tensor with base set to reference tensor's base, equivalent to torch.set_()
  16. def set_(reference: Tensor, shape, strides, offset):
  17. if reference.lazydata.base.realized is None: reference.realize()
  18. assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base"
  19. strided = Tensor(reference.lazydata._view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
  20. assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
  21. return strided
  22. def clone(original:Tensor): return copy.copy(original)
  23. def copy_(src:Tensor, other:Tensor) -> Tensor: return copy.copy(src)
  24. # this is fine for tested usecases since as geohotstan understands,
  25. # data_ptr is used to compare if operations needed between tensors is the same
  26. def data_ptr(tensor:Tensor): return tensor.lazydata
  27. # https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html
  28. # TODO this is setitem
  29. def index_put_(tensor:Tensor, indices, values, accumulate) -> Tensor:
  30. tensor[indices] = values
  31. # https://pytorch.org/docs/stable/generated/torch.argsort.html
  32. def argsort(tensor:Tensor) -> Tensor:
  33. pass
  34. # https://pytorch.org/docs/stable/generated/torch.all.html
  35. def all_(tensor:Tensor) -> Tensor:
  36. return tensor != 0
  37. # https://pytorch.org/docs/stable/generated/torch.diagonal.html
  38. def diagonal(tensor:Tensor) -> Tensor:
  39. assert tensor.ndim == 2 and all_same(tensor.shape), 'only support 2 ndim square tensors'
  40. return (Tensor.eye(tensor.shape[0]) * tensor).sum(0)
  41. # https://numpy.org/doc/stable/reference/generated/numpy.unravel_index.html
  42. def unravel_index(tensor, shape):
  43. pass
  44. # https://github.com/pytorch/pytorch/blob/79811e765c23242210ebdc623539d2103a166463/torch/testing/_creation.py#L38
  45. def make_tensor(shape, dtype:dtypes, noncontiguous) -> Tensor:
  46. r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with
  47. values uniformly drawn from ``[low, high)``.
  48. If :attr:`low` or :attr:`high` are specified and are outside the range of the :attr:`dtype`'s representable
  49. finite values then they are clamped to the lowest or highest representable finite value, respectively.
  50. If ``None``, then the following table describes the default values for :attr:`low` and :attr:`high`,
  51. which depend on :attr:`dtype`.
  52. +---------------------------+------------+----------+
  53. | ``dtype`` | ``low`` | ``high`` |
  54. +===========================+============+==========+
  55. | boolean type | ``0`` | ``2`` |
  56. +---------------------------+------------+----------+
  57. | unsigned integral type | ``0`` | ``10`` |
  58. +---------------------------+------------+----------+
  59. | signed integral types | ``-9`` | ``10`` |
  60. +---------------------------+------------+----------+
  61. | floating types | ``-9`` | ``9`` |
  62. +---------------------------+------------+----------+
  63. | complex types | ``-9`` | ``9`` |
  64. +---------------------------+------------+----------+
  65. """
  66. contiguous = not noncontiguous
  67. if dtype == dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, contiguous=contiguous).cast(dtypes.bool)
  68. elif dtype.is_unsigned(): return Tensor.randint(shape=shape, low=0, high=10, contiguous=contiguous).cast(dtype)
  69. elif dtype.is_int(): return Tensor.randint(shape=shape, low=-9, high=10, contiguous=contiguous).cast(dtype) # signed int
  70. elif dtype.is_float(): return Tensor.rand(shape=shape, low=-9, high=9, dtype=dtype, contiguous=contiguous)
  71. else: raise NotImplementedError(f"{dtype} not implemented")
  72. class TestIndexing(unittest.TestCase):
  def test_index(self):
    """Basic (non-advanced) indexing: ints, slices, steps, Ellipsis and None, plus errors for invalid indices."""
    reference = consec((3, 3, 3))
    numpy_testing_assert_equal_helper(reference[0], consec((3, 3)))
    numpy_testing_assert_equal_helper(reference[1], consec((3, 3), 10))
    numpy_testing_assert_equal_helper(reference[2], consec((3, 3), 19))
    numpy_testing_assert_equal_helper(reference[0, 1], consec((3,), 4))
    numpy_testing_assert_equal_helper(reference[0:2], consec((2, 3, 3)))
    numpy_testing_assert_equal_helper(reference[2, 2, 2], 27)
    numpy_testing_assert_equal_helper(reference[:], consec((3, 3, 3)))

    # indexing with Ellipsis
    numpy_testing_assert_equal_helper(reference[..., 2], np.array([[3., 6., 9.],[12., 15., 18.],[21., 24., 27.]]))
    numpy_testing_assert_equal_helper(reference[0, ..., 2], np.array([3., 6., 9.]))
    numpy_testing_assert_equal_helper(reference[..., 2], reference[:, :, 2])
    numpy_testing_assert_equal_helper(reference[0, ..., 2], reference[0, :, 2])
    numpy_testing_assert_equal_helper(reference[0, 2, ...], reference[0, 2])
    numpy_testing_assert_equal_helper(reference[..., 2, 2, 2], 27)
    numpy_testing_assert_equal_helper(reference[2, ..., 2, 2], 27)
    numpy_testing_assert_equal_helper(reference[2, 2, ..., 2], 27)
    numpy_testing_assert_equal_helper(reference[2, 2, 2, ...], 27)
    numpy_testing_assert_equal_helper(reference[...], reference)

    reference_5d = consec((3, 3, 3, 3, 3))
    numpy_testing_assert_equal_helper(reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0])
    numpy_testing_assert_equal_helper(reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0])
    numpy_testing_assert_equal_helper(reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1])
    numpy_testing_assert_equal_helper(reference_5d[...], reference_5d)

    # None indexing
    numpy_testing_assert_equal_helper(reference[2, None], reference[2].unsqueeze(0))
    numpy_testing_assert_equal_helper(reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0))
    numpy_testing_assert_equal_helper(reference[2:4, None], reference[2:4].unsqueeze(1))
    numpy_testing_assert_equal_helper(reference[None, 2, None, None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0))
    numpy_testing_assert_equal_helper(reference[None, 2:5, None, None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2))

    # indexing 0-length slice
    numpy_testing_assert_equal_helper(np.empty((0, 3, 3)), reference[slice(0)])
    numpy_testing_assert_equal_helper(np.empty((0, 3)), reference[slice(0), 2])
    numpy_testing_assert_equal_helper(np.empty((0, 3)), reference[2, slice(0)])
    # the expected value must be a real empty 1-D array. np.empty([]) is a 0-d scalar, which
    # np.testing.assert_equal broadcasts against ANY empty result, so the (0,) shape would go unchecked
    numpy_testing_assert_equal_helper(np.empty((0,)), reference[2, 1:1, 2])

    # indexing with step
    reference = consec((10, 10, 10))
    numpy_testing_assert_equal_helper(reference[1:5:2], Tensor.stack(reference[1], reference[3], dim=0))
    numpy_testing_assert_equal_helper(reference[1:6:2], Tensor.stack(reference[1], reference[3], reference[5], dim=0))
    numpy_testing_assert_equal_helper(reference[1:9:4], Tensor.stack(reference[1], reference[5], dim=0))
    numpy_testing_assert_equal_helper(reference[2:4, 1:5:2], Tensor.stack(reference[2:4, 1], reference[2:4, 3], dim=1))
    numpy_testing_assert_equal_helper(reference[3, 1:6:2], Tensor.stack(reference[3, 1], reference[3, 3], reference[3, 5], dim=0))
    numpy_testing_assert_equal_helper(reference[None, 2, 1:9:4], Tensor.stack(reference[2, 1], reference[2, 5], dim=0).unsqueeze(0))
    numpy_testing_assert_equal_helper(reference[:, 2, 1:6:2], Tensor.stack(reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5], dim=1))

    # randomized positive-step slices, checked against the same slices applied to nested Python lists
    lst = [list(range(i, i+10)) for i in range(0, 100, 10)]
    tensor = Tensor(lst)
    for _ in range(100):
      idx1_start = random.randrange(10)
      idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1)
      idx1_step = random.randrange(1, 8)
      idx1 = slice(idx1_start, idx1_end, idx1_step)
      if random.randrange(2) == 0:
        idx2_start = random.randrange(10)
        idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1)
        idx2_step = random.randrange(1, 8)
        idx2 = slice(idx2_start, idx2_end, idx2_step)
        lst_indexed = [l[idx2] for l in lst[idx1]]
        tensor_indexed = tensor[idx1, idx2]
      else:
        lst_indexed = lst[idx1]
        tensor_indexed = tensor[idx1]
      numpy_testing_assert_equal_helper(tensor_indexed, np.array(lst_indexed))

    self.assertRaises(ValueError, lambda: reference[1:9:0])
    # NOTE torch doesn't support this but numpy does so we should too. Torch raises ValueError
    # see test_slice_negative_strides in test_ops.py
    # self.assertRaises(ValueError, lambda: reference[1:9:-1])

    self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
    self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
    self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])

    # float indices are rejected
    self.assertRaises(IndexError, lambda: reference[0.0])
    self.assertRaises(TypeError, lambda: reference[0.0:2.0])
    self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
    self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
    self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
    self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])

    # TODO: delitem
    # def delitem(): del reference[0]
    # self.assertRaises(TypeError, delitem)
  152. # TODO: LLVM is quite fast, why are other compiled backends slow?
  153. @unittest.skipIf(CI and Device.DEFAULT in ["CLANG", "GPU", "METAL", "NV", "AMD"], "slow")
  154. def test_advancedindex(self):
  155. # integer array indexing
  156. # pick a random valid indexer type
  157. def ri(indices):
  158. choice = random.randint(0, 2)
  159. if choice == 0: return Tensor(indices)
  160. if choice == 1: return list(indices)
  161. return tuple(indices)
  162. def validate_indexing(x):
  163. numpy_testing_assert_equal_helper(x[[0]], consec((1,)))
  164. numpy_testing_assert_equal_helper(x[ri([0]),], consec((1,)))
  165. numpy_testing_assert_equal_helper(x[ri([3]),], consec((1,), 4))
  166. numpy_testing_assert_equal_helper(x[[2, 3, 4]], consec((3,), 3))
  167. numpy_testing_assert_equal_helper(x[ri([2, 3, 4]),], consec((3,), 3))
  168. numpy_testing_assert_equal_helper(x[ri([0, 2, 4]),], np.array([1, 3, 5]))
  169. def validate_setting(x):
  170. x[[0]] = -2
  171. numpy_testing_assert_equal_helper(x[[0]], np.array([-2]))
  172. x[[0]] = -1
  173. numpy_testing_assert_equal_helper(x[ri([0]), ], np.array([-1]))
  174. x[[2, 3, 4]] = 4
  175. numpy_testing_assert_equal_helper(x[[2, 3, 4]], np.array([4, 4, 4]))
  176. x[ri([2, 3, 4]), ] = 3
  177. numpy_testing_assert_equal_helper(x[ri([2, 3, 4]), ], np.array([3, 3, 3]))
  178. x[ri([0, 2, 4]), ] = np.array([5, 4, 3])
  179. numpy_testing_assert_equal_helper(x[ri([0, 2, 4]), ], np.array([5, 4, 3]))
  180. # Case 1: Purely Integer Array Indexing
  181. reference = consec((10,))
  182. validate_indexing(reference)
  183. # setting values
  184. # TODO: advanced setitem
  185. '''
  186. validate_setting(reference)
  187. '''
  188. # Tensor with stride != 1
  189. # strided is [1, 3, 5, 7]
  190. reference = consec((10,))
  191. strided = set_(reference, (4,), (2,), 0)
  192. numpy_testing_assert_equal_helper(strided[[0]], np.array([1]))
  193. numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([1]))
  194. numpy_testing_assert_equal_helper(strided[ri([3]), ], np.array([7]))
  195. numpy_testing_assert_equal_helper(strided[[1, 2]], np.array([3, 5]))
  196. numpy_testing_assert_equal_helper(strided[ri([1, 2]), ], np.array([3, 5]))
  197. numpy_testing_assert_equal_helper(strided[ri([[2, 1], [0, 3]]), ],
  198. np.array([[5, 3], [1, 7]]))
  199. # stride is [4, 8]
  200. strided = set_(reference, (2,), (4,), offset=4)
  201. numpy_testing_assert_equal_helper(strided[[0]], np.array([5]))
  202. numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([5]))
  203. numpy_testing_assert_equal_helper(strided[ri([1]), ], np.array([9]))
  204. numpy_testing_assert_equal_helper(strided[[0, 1]], np.array([5, 9]))
  205. numpy_testing_assert_equal_helper(strided[ri([0, 1]), ], np.array([5, 9]))
  206. numpy_testing_assert_equal_helper(strided[ri([[0, 1], [1, 0]]), ],
  207. np.array([[5, 9], [9, 5]]))
  208. # reference is 1 2
  209. # 3 4
  210. # 5 6
  211. reference = consec((3, 2))
  212. numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], np.array([1, 3, 5]))
  213. numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([2, 4, 6]))
  214. numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], consec((1,)))
  215. numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], consec((1,), 6))
  216. numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([1, 2]))
  217. numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], np.array([2, 4, 4, 2, 6]))
  218. numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([1, 2, 3, 3]))
  219. rows = ri([[0, 0],
  220. [1, 2]])
  221. columns = [0],
  222. numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 1],
  223. [3, 5]]))
  224. rows = ri([[0, 0],
  225. [1, 2]])
  226. columns = ri([1, 0])
  227. numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[2, 1],
  228. [4, 5]]))
  229. rows = ri([[0, 0],
  230. [1, 2]])
  231. columns = ri([[0, 1],
  232. [1, 0]])
  233. numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 2],
  234. [4, 5]]))
  235. # TODO: advanced setitem
  236. '''
  237. # setting values
  238. reference[ri([0]), ri([1])] = -1
  239. numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])], np.array([-1]))
  240. reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
  241. numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
  242. np.array([-1, 2, -4]))
  243. reference[rows, columns] = np.array([[4, 6], [2, 3]])
  244. numpy_testing_assert_equal_helper(reference[rows, columns],
  245. np.array([[4, 6], [2, 3]]))
  246. '''
  247. # Verify still works with Transposed (i.e. non-contiguous) Tensors
  248. reference = Tensor([[0, 1, 2, 3],
  249. [4, 5, 6, 7],
  250. [8, 9, 10, 11]]).T
  251. # Transposed: [[0, 4, 8],
  252. # [1, 5, 9],
  253. # [2, 6, 10],
  254. # [3, 7, 11]]
  255. numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], np.array([0, 1, 2]))
  256. numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([4, 5, 6]))
  257. numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], np.array([0]))
  258. numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], np.array([6]))
  259. numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([0, 4]))
  260. numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], np.array([4, 5, 5, 4, 7]))
  261. numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([0, 4, 1, 1]))
  262. rows = ri([[0, 0],
  263. [1, 2]])
  264. columns = [0],
  265. numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 0], [1, 2]]))
  266. rows = ri([[0, 0],
  267. [1, 2]])
  268. columns = ri([1, 0])
  269. numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[4, 0], [5, 2]]))
  270. rows = ri([[0, 0],
  271. [1, 3]])
  272. columns = ri([[0, 1],
  273. [1, 2]])
  274. numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 4], [5, 11]]))
  275. # TODO: advanced setitem
  276. '''
  277. # setting values
  278. reference[ri([0]), ri([1])] = -1
  279. numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])],
  280. np.array([-1]))
  281. reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
  282. numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
  283. np.array([-1, 2, -4]))
  284. reference[rows, columns] = np.array([[4, 6], [2, 3]])
  285. numpy_testing_assert_equal_helper(reference[rows, columns],
  286. np.array([[4, 6], [2, 3]]))
  287. '''
  288. # stride != 1
  289. # strided is [[1 3 5 7],
  290. # [9 11 13 15]]
  291. reference = Tensor.arange(0., 24).reshape(3, 8)
  292. strided = set_(reference, (2,4), (8,2), 1)
  293. numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([0])],
  294. np.array([1, 9]))
  295. numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1])],
  296. np.array([3, 11]))
  297. numpy_testing_assert_equal_helper(strided[ri([0]), ri([0])],
  298. np.array([1]))
  299. numpy_testing_assert_equal_helper(strided[ri([1]), ri([3])],
  300. np.array([15]))
  301. numpy_testing_assert_equal_helper(strided[[ri([0, 0]), ri([0, 3])]],
  302. np.array([1, 7]))
  303. numpy_testing_assert_equal_helper(strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
  304. np.array([9, 11, 11, 9, 15]))
  305. numpy_testing_assert_equal_helper(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
  306. np.array([1, 3, 9, 9]))
  307. rows = ri([[0, 0],
  308. [1, 1]])
  309. columns = [0],
  310. numpy_testing_assert_equal_helper(strided[rows, columns],
  311. np.array([[1, 1], [9, 9]]))
  312. rows = ri([[0, 1],
  313. [1, 0]])
  314. columns = ri([1, 2])
  315. numpy_testing_assert_equal_helper(strided[rows, columns],
  316. np.array([[3, 13], [11, 5]]))
  317. rows = ri([[0, 0],
  318. [1, 1]])
  319. columns = ri([[0, 1],
  320. [1, 2]])
  321. numpy_testing_assert_equal_helper(strided[rows, columns],
  322. np.array([[1, 3], [11, 13]]))
  323. # setting values
  324. # strided is [[10, 11],
  325. # [17, 18]]
  326. reference = Tensor.arange(0., 24).reshape(3, 8)
  327. strided = set_(reference, (2,2), (7,1), 10)
  328. numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
  329. np.array([11]))
  330. # TODO advanced setitem
  331. '''
  332. strided[ri([0]), ri([1])] = -1
  333. numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
  334. Tensor([-1]))
  335. '''
  336. reference = Tensor.arange(0., 24).reshape(3, 8)
  337. strided = set_(reference, (2,2), (7,1), 10)
  338. numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
  339. np.array([11, 17]))
  340. # TODO advanced setitem
  341. '''
  342. strided[ri([0, 1]), ri([1, 0])] = Tensor([-1, 2])
  343. numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
  344. Tensor([-1, 2]))
  345. '''
  346. reference = Tensor.arange(0., 24).realize().reshape(3, 8)
  347. strided = set_(reference, (2,2), (7,1), 10)
  348. rows = ri([[0],
  349. [1]])
  350. columns = ri([[0, 1],
  351. [0, 1]])
  352. numpy_testing_assert_equal_helper(strided[rows, columns],
  353. np.array([[10, 11], [17, 18]]))
  354. # TODO advanced setitem
  355. '''
  356. strided[rows, columns] = Tensor([[4, 6], [2, 3]])
  357. numpy_testing_assert_equal_helper(strided[rows, columns],
  358. Tensor([[4, 6], [2, 3]]))
  359. '''
  360. # Tests using less than the number of dims, and ellipsis
  361. # reference is 1 2
  362. # 3 4
  363. # 5 6
  364. reference = consec((3, 2))
  365. numpy_testing_assert_equal_helper(reference[ri([0, 2]),], np.array([[1, 2], [5, 6]]))
  366. numpy_testing_assert_equal_helper(reference[ri([1]), ...], np.array([[3, 4]]))
  367. numpy_testing_assert_equal_helper(reference[..., ri([1])], np.array([[2], [4], [6]]))
  368. # verify too many indices fails
  369. with self.assertRaises(IndexError): reference[ri([1]), ri([0, 2]), ri([3])]
  370. # test invalid index fails
  371. reference = Tensor.empty(10)
  372. for err_idx in (10, -11):
  373. with self.assertRaises(IndexError):
  374. reference[err_idx]
  375. # NOTE cannot check for out of bounds with Tensor indexing
  376. # see tensor.py: __getitem__ (Tiny Things)
  377. '''
  378. with self.assertRaises(IndexError):
  379. reference[Tensor([err_idx], dtype=dtypes.int64)]
  380. with self.assertRaises(IndexError):
  381. reference[[err_idx]]
  382. '''
  383. def tensor_indices_to_np(tensor: Tensor, indices):
  384. npt = tensor.numpy()
  385. idxs = tuple(i.numpy().tolist() if isinstance(i, Tensor) and i.dtype == dtypes.int64 else
  386. i for i in indices)
  387. return npt, idxs
  388. def get_numpy(tensor, indices):
  389. npt, idxs = tensor_indices_to_np(tensor, indices)
  390. return Tensor(npt[idxs])
  391. def set_numpy(tensor:Tensor, indices, value):
  392. if not isinstance(value, int):
  393. value = value.numpy()
  394. npt, idxs = tensor_indices_to_np(tensor, indices)
  395. npt[idxs] = value
  396. return npt
  397. def assert_get_eq(tensor, indexer):
  398. numpy_testing_assert_equal_helper(tensor[indexer], get_numpy(tensor, indexer))
  399. def assert_set_eq(tensor: Tensor, indexer, val):
  400. pyt = clone(tensor)
  401. numt = clone(tensor)
  402. pyt[indexer] = val
  403. numt = set_numpy(numt, indexer, val)
  404. numpy_testing_assert_equal_helper(pyt, numt)
  405. # NOTE: torch initiates the gradients using g0cpu (rand as gradients)
  406. def assert_backward_eq(tensor: Tensor, indexer):
  407. cpu = clone(tensor.float())
  408. cpu.requires_grad = True
  409. outcpu = cpu[indexer].sum()
  410. outcpu.backward()
  411. dev = cpu.detach()
  412. dev.requires_grad = True
  413. outdev = dev[indexer].sum()
  414. outdev.backward()
  415. numpy_testing_assert_equal_helper(cpu.grad, dev.grad)
  416. def get_set_tensor(indexed: Tensor, indexer):
  417. set_size = indexed[indexer].shape
  418. set_count = indexed[indexer].numel()
  419. set_tensor = Tensor.randint(set_count, high=set_count).reshape(set_size).cast(dtypes.float64)
  420. return set_tensor
  421. # Tensor is 0 1 2 3 4
  422. # 5 6 7 8 9
  423. # 10 11 12 13 14
  424. # 15 16 17 18 19
  425. reference = Tensor.arange(0., 20).reshape(4, 5)
  426. indices_to_test = [
  427. # grab the second, fourth columns
  428. [slice(None), [1, 3]],
  429. # first, third rows,
  430. [[0, 2], slice(None)],
  431. # weird shape
  432. [slice(None), [[0, 1],
  433. [2, 3]]],
  434. # negatives
  435. [[-1], [0]],
  436. [[0, 2], [-1]],
  437. [slice(None), [-1]],
  438. ]
  439. # only test dupes on gets
  440. get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]
  441. for indexer in get_indices_to_test:
  442. assert_get_eq(reference, indexer)
  443. assert_backward_eq(reference, indexer)
  444. # TODO advanced setitem
  445. '''
  446. for indexer in indices_to_test:
  447. assert_set_eq(reference, indexer, 44)
  448. assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
  449. '''
  450. reference = Tensor.arange(0., 160).reshape(4, 8, 5)
  451. indices_to_test = [
  452. [slice(None), slice(None), [0, 3, 4]],
  453. [slice(None), [2, 4, 5, 7], slice(None)],
  454. [[2, 3], slice(None), slice(None)],
  455. [slice(None), [0, 2, 3], [1, 3, 4]],
  456. [slice(None), [0], [1, 2, 4]],
  457. [slice(None), [0, 1, 3], [4]],
  458. [slice(None), [[0, 1], [1, 0]], [[2, 3]]],
  459. [slice(None), [[0, 1], [2, 3]], [[0]]],
  460. [slice(None), [[5, 6]], [[0, 3], [4, 4]]],
  461. [[0, 2, 3], [1, 3, 4], slice(None)],
  462. [[0], [1, 2, 4], slice(None)],
  463. [[0, 1, 3], [4], slice(None)],
  464. [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
  465. [[[0, 1], [1, 0]], [[2, 3]], slice(None)],
  466. [[[0, 1], [2, 3]], [[0]], slice(None)],
  467. [[[2, 1]], [[0, 3], [4, 4]], slice(None)],
  468. [[[2]], [[0, 3], [4, 1]], slice(None)],
  469. # non-contiguous indexing subspace
  470. [[0, 2, 3], slice(None), [1, 3, 4]],
  471. # less dim, ellipsis
  472. [[0, 2], ],
  473. [[0, 2], slice(None)],
  474. [[0, 2], Ellipsis],
  475. [[0, 2], slice(None), Ellipsis],
  476. [[0, 2], Ellipsis, slice(None)],
  477. [[0, 2], [1, 3]],
  478. [[0, 2], [1, 3], Ellipsis],
  479. [Ellipsis, [1, 3], [2, 3]],
  480. [Ellipsis, [2, 3, 4]],
  481. [Ellipsis, slice(None), [2, 3, 4]],
  482. [slice(None), Ellipsis, [2, 3, 4]],
  483. # ellipsis counts for nothing
  484. [Ellipsis, slice(None), slice(None), [0, 3, 4]],
  485. [slice(None), Ellipsis, slice(None), [0, 3, 4]],
  486. [slice(None), slice(None), Ellipsis, [0, 3, 4]],
  487. [slice(None), slice(None), [0, 3, 4], Ellipsis],
  488. [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
  489. [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
  490. [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
  491. ]
  492. for indexer in indices_to_test:
  493. assert_get_eq(reference, indexer)
  494. # TODO advanced setitem
  495. '''
  496. assert_set_eq(reference, indexer, 212)
  497. assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
  498. '''
  499. assert_backward_eq(reference, indexer)
  500. reference = Tensor.arange(0., 1296).reshape(3, 9, 8, 6)
  501. indices_to_test = [
  502. [slice(None), slice(None), slice(None), [0, 3, 4]],
  503. [slice(None), slice(None), [2, 4, 5, 7], slice(None)],
  504. [slice(None), [2, 3], slice(None), slice(None)],
  505. [[1, 2], slice(None), slice(None), slice(None)],
  506. [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
  507. [slice(None), slice(None), [0], [1, 2, 4]],
  508. [slice(None), slice(None), [0, 1, 3], [4]],
  509. [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
  510. [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
  511. [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
  512. [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
  513. [slice(None), [0], [1, 2, 4], slice(None)],
  514. [slice(None), [0, 1, 3], [4], slice(None)],
  515. [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
  516. [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
  517. [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
  518. [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
  519. [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
  520. [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
  521. [[0], [1, 2, 4], slice(None), slice(None)],
  522. [[0, 1, 2], [4], slice(None), slice(None)],
  523. [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
  524. [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
  525. [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
  526. [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
  527. [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
  528. [slice(None), [2, 3, 4], [1, 3, 4], [4]],
  529. [slice(None), [0, 1, 3], [4], [1, 3, 4]],
  530. [slice(None), [6], [0, 2, 3], [1, 3, 4]],
  531. [slice(None), [2, 3, 5], [3], [4]],
  532. [slice(None), [0], [4], [1, 3, 4]],
  533. [slice(None), [6], [0, 2, 3], [1]],
  534. [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
  535. [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
  536. [[2, 0, 1], [1, 2, 3], [4], slice(None)],
  537. [[0, 1, 2], [4], [1, 3, 4], slice(None)],
  538. [[0], [0, 2, 3], [1, 3, 4], slice(None)],
  539. [[0, 2, 1], [3], [4], slice(None)],
  540. [[0], [4], [1, 3, 4], slice(None)],
  541. [[1], [0, 2, 3], [1], slice(None)],
  542. [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],
  543. # less dim, ellipsis
  544. [Ellipsis, [0, 3, 4]],
  545. [Ellipsis, slice(None), [0, 3, 4]],
  546. [Ellipsis, slice(None), slice(None), [0, 3, 4]],
  547. [slice(None), Ellipsis, [0, 3, 4]],
  548. [slice(None), slice(None), Ellipsis, [0, 3, 4]],
  549. [slice(None), [0, 2, 3], [1, 3, 4]],
  550. [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
  551. [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
  552. [[0], [1, 2, 4]],
  553. [[0], [1, 2, 4], slice(None)],
  554. [[0], [1, 2, 4], Ellipsis],
  555. [[0], [1, 2, 4], Ellipsis, slice(None)],
  556. [[1], ],
  557. [[0, 2, 1], [3], [4]],
  558. [[0, 2, 1], [3], [4], slice(None)],
  559. [[0, 2, 1], [3], [4], Ellipsis],
  560. [Ellipsis, [0, 2, 1], [3], [4]],
  561. ]
  562. for indexer in indices_to_test:
  563. assert_get_eq(reference, indexer)
  564. # TODO advanced setitem
  565. '''
  566. assert_set_eq(reference, indexer, 1333)
  567. assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
  568. '''
  569. indices_to_test += [
  570. [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
  571. [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
  572. ]
  573. for indexer in indices_to_test:
  574. assert_get_eq(reference, indexer)
  575. # TODO advanced setitem
  576. '''
  577. assert_set_eq(reference, indexer, 1333)
  578. '''
  579. assert_backward_eq(reference, indexer)
  # TODO setitem backward
  # Disabled: this test is kept inside a string literal, so it is never defined or collected.
  # Per the TODO above it is waiting on setitem backward (it calls backward() through `z[:, 0] = w`).
  '''
  def test_set_item_to_scalar_tensor(self):
    m = random.randint(1, 10)
    n = random.randint(1, 10)
    z = Tensor.randn([m, n])
    a = 1.0
    w = Tensor(a, requires_grad=True)
    z[:, 0] = w
    z.sum().backward()
    numpy_testing_assert_equal_helper(w.grad, m * a)
  '''
  def test_single_int(self):
    """A single integer index removes the leading axis."""
    src = Tensor.randn(5, 7, 3)
    picked = src[4]
    numpy_testing_assert_equal_helper(picked.shape, (7, 3))
  def test_multiple_int(self):
    """Each integer in a multi-axis index drops its axis; full slices keep theirs."""
    src = Tensor.randn(5, 7, 3)
    for index, expected_shape in ((4, (7, 3)), ((4, slice(None), 1), (7,))):
      numpy_testing_assert_equal_helper(src[index].shape, expected_shape)
  def test_none(self):
    """Every None in the index inserts a size-1 axis at that position."""
    src = Tensor.randn(5, 7, 3)
    cases = (
      (None, (1, 5, 7, 3)),
      ((slice(None), None), (5, 1, 7, 3)),
      ((slice(None), None, None), (5, 1, 1, 7, 3)),
      ((Ellipsis, None), (5, 7, 3, 1)),
    )
    for index, expected_shape in cases:
      numpy_testing_assert_equal_helper(src[index].shape, expected_shape)
  def test_step(self):
    """Slices with a positive step pick every k-th element."""
    seq = Tensor.arange(10)
    numpy_testing_assert_equal_helper(seq[::1], seq)
    strided_cases = (
      (slice(None, None, 2), [0, 2, 4, 6, 8]),
      (slice(None, None, 3), [0, 3, 6, 9]),
      (slice(None, None, 11), [0]),
      (slice(1, 6, 2), [1, 3, 5]),
    )
    for step_slice, expected in strided_cases:
      numpy_testing_assert_equal_helper(seq[step_slice], expected)
  def test_step_assignment(self):
    """Assigning through a strided slice writes only the selected positions of that row."""
    grid = Tensor.zeros(4, 4).contiguous()
    grid[0, 1::2] = Tensor([3., 4.])
    first_row = grid[0].numpy().tolist()
    numpy_testing_assert_equal_helper(first_row, [0, 3, 0, 4])
    # the remaining rows must be untouched
    numpy_testing_assert_equal_helper(grid[1:].sum(), 0)
  @unittest.skip("bool indexing not supported")
  def test_bool_indices(self):
    """Boolean masks select along the first axis; uint8 masks are expected to match them (with 2 warnings)."""
    src = Tensor.randn(5, 7, 3)
    row_mask = Tensor([True, False, True, True, False], dtype=dtypes.bool)
    numpy_testing_assert_equal_helper(src[row_mask].shape, (3, 7, 3))
    numpy_testing_assert_equal_helper(src[row_mask], Tensor.stack([src[0], src[2], src[3]]))
    flags = Tensor([True, False, True], dtype=dtypes.bool)
    bool_mask = Tensor([True, False, False], dtype=dtypes.bool)
    byte_mask = Tensor([1, 0, 0], dtype=dtypes.uint8)
    with warnings.catch_warnings(record=True) as caught:
      numpy_testing_assert_equal_helper(flags[bool_mask].shape, flags[byte_mask].shape)
      numpy_testing_assert_equal_helper(flags[bool_mask], flags[byte_mask])
      numpy_testing_assert_equal_helper(flags[bool_mask], Tensor([True]))
      numpy_testing_assert_equal_helper(len(caught), 2)
  @unittest.skip("bool indexing not supported")
  def test_bool_indices_accumulate(self):
    """Accumulating through an all-False bool mask leaves an all-ones tensor unchanged."""
    empty_mask = Tensor.zeros(size=(10, ), dtype=dtypes.bool)
    target = Tensor.ones(size=(10, 10))
    selected = target[empty_mask]
    index_put_(target, (empty_mask, ), selected, accumulate=True)
    numpy_testing_assert_equal_helper(target, Tensor.ones(size=(10, 10)))
  @unittest.skip("bool indexing not supported")
  def test_multiple_bool_indices(self):
    """Bool masks on the first and last axes combine into a single leading axis."""
    src = Tensor.randn(5, 7, 3)
    # note: these broadcast together and are transposed to the first dim
    first_mask = Tensor([1, 0, 1, 1, 0], dtype=dtypes.bool)
    last_mask = Tensor([1, 1, 1], dtype=dtypes.bool)
    picked = src[first_mask, :, last_mask]
    numpy_testing_assert_equal_helper(picked.shape, (3, 7))
  @unittest.skip("bool indexing not supported")
  def test_byte_mask(self):
    """A uint8 mask selects rows like a bool mask (2 warnings expected); an all-False mask gives an empty result."""
    src = Tensor.randn(5, 7, 3)
    byte_mask = Tensor([1, 0, 1, 1, 0], dtype=dtypes.uint8)
    with warnings.catch_warnings(record=True) as caught:
      numpy_testing_assert_equal_helper(src[byte_mask].shape, (3, 7, 3))
      numpy_testing_assert_equal_helper(src[byte_mask], Tensor.stack([src[0], src[2], src[3]]))
      numpy_testing_assert_equal_helper(len(caught), 2)
    single = Tensor([1.])
    numpy_testing_assert_equal_helper(single[single == 0], Tensor([]))
  @unittest.skip("bool indexing not supported")
  def test_byte_mask_accumulate(self):
    """Accumulating through an all-zero uint8 mask leaves an all-ones tensor unchanged (2 warnings expected)."""
    zero_mask = Tensor.zeros(size=(10, ), dtype=dtypes.uint8)
    target = Tensor.ones(size=(10, 10))
    with warnings.catch_warnings(record=True) as caught:
      warnings.simplefilter("always")
      index_put_(target, (zero_mask, ), target[zero_mask], accumulate=True)
      numpy_testing_assert_equal_helper(target, Tensor.ones(size=(10, 10)))
      numpy_testing_assert_equal_helper(len(caught), 2)
  # TODO setitem
  # NOTE: tinygrad doesn't support idx.max that big
  # Disabled: this test is kept inside a string literal, so it is never defined or collected.
  # NOTE(review): the second half builds `indices0`/`indices1`/`values` as numpy arrays with tinygrad dtypes
  # (`dtype=dtypes.int64`, `dtype=dt`); presumably these should be `Tensor(...)` like the first half — confirm
  # before re-enabling.
  '''
  def test_index_put_accumulate_large_tensor(self):
    # This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
    N = (1 << 31) + 5
    dt = dtypes.int8
    a = Tensor.ones(N, dtype=dt).contiguous()
    indices = Tensor([-2, 0, -2, -1, 0, -1, 1], dtype=dtypes.int64)
    values = Tensor([6, 5, 6, 6, 5, 7, 11], dtype=dt)
    index_put_(a, (indices, ), values, accumulate=True)
    numpy_testing_assert_equal_helper(a[0], 11)
    numpy_testing_assert_equal_helper(a[1], 12)
    numpy_testing_assert_equal_helper(a[2], 1)
    numpy_testing_assert_equal_helper(a[-3], 1)
    numpy_testing_assert_equal_helper(a[-2], 13)
    numpy_testing_assert_equal_helper(a[-1], 14)
    a = Tensor.ones((2, N), dtype=dt).contiguous()
    indices0 = np.array([0, -1, 0, 1], dtype=dtypes.int64)
    indices1 = np.array([-2, -1, 0, 1], dtype=dtypes.int64)
    values = np.array([12, 13, 10, 11], dtype=dt)
    index_put_(a, (indices0, indices1), values, accumulate=True)
    numpy_testing_assert_equal_helper(a[0, 0], 11)
    numpy_testing_assert_equal_helper(a[0, 1], 1)
    numpy_testing_assert_equal_helper(a[1, 0], 1)
    numpy_testing_assert_equal_helper(a[1, 1], 12)
    numpy_testing_assert_equal_helper(a[:, 2], Tensor.ones(2, dtype=dtypes.int8))
    numpy_testing_assert_equal_helper(a[:, -3], Tensor.ones(2, dtype=dtypes.int8))
    numpy_testing_assert_equal_helper(a[0, -2], 13)
    numpy_testing_assert_equal_helper(a[1, -2], 1)
    numpy_testing_assert_equal_helper(a[-1, -1], 14)
    numpy_testing_assert_equal_helper(a[0, -1], 1)
  '''
  # TODO fancy setitem
  # Disabled: this test is kept inside a string literal, so it is never defined or collected.
  # NOTE(review) before re-enabling:
  #   - `delta` is drawn without a shape and never uses the loop variable `i` (presumably it should hold `i`
  #     samples so the random walk grows with `i` — confirm);
  #   - `indices.shape(0)` calls a tuple; it should likely be `indices.shape[0]`;
  #   - the inner `for i, v in ...` loop shadows the outer loop's `i`.
  '''
  def test_index_put_accumulate_duplicate_indices(self):
    for i in range(1, 512):
      # generate indices by random walk, this will create indices with
      # lots of duplicates interleaved with each other
      delta = Tensor.uniform(low=-1, high=1, dtype=dtypes.double)
      indices = delta.cumsum(0).cast(dtypes.int64)
      # input = torch.randn(indices.abs().max() + 1)
      input = Tensor.randn(indices.abs().max().item() + 1)
      # values = torch.randn(indices.size(0))
      values = Tensor.randn(indices.shape(0))
      output = index_put_(input, (indices,), values, accumulate=True)
      input_list = input.numpy().tolist()
      indices_list = indices.numpy().tolist()
      values_list = values.numpy().tolist()
      for i, v in zip(indices_list, values_list):
        input_list[i] += v
      numpy_testing_assert_equal_helper(output, input_list)
  '''
  def test_index_ind_dtype(self):
    """Indexing with int32 index tensors matches indexing with the same int64 indices."""
    x = Tensor.randn(4, 4)
    # ind_long = torch.randint(4, (4,), dtype=torch.long)
    # TODO should we spend an extra line to allow for randint other dtypes?
    # copied from randint: scale Tensor.rand samples by 4 and truncate to int64
    ind_long = (Tensor.rand((4,)) * 4).cast(dtypes.int64)
    # ind_int = ind_long.int()
    ind_int = ind_long.cast(dtypes.int32)
    # the same three index layouts are tried with each index dtype: both axes, rows only, columns only
    index_layouts = (
      lambda ind: (ind, ind),
      lambda ind: (ind, slice(None)),
      lambda ind: (slice(None), ind),
    )
    for layout in index_layouts:
      ref = x[layout(ind_long)]
      res = x[layout(ind_int)]
      numpy_testing_assert_equal_helper(ref, res)
    # no repeating indices for index_put
    # TODO fancy setitem
    '''
    src = Tensor.randn(4)
    ind_long = Tensor.arange(4, dtype=dtypes.int64)
    ind_int = ind_long.cast(dtypes.int32)
    for accum in (True, False):
      inp_ref = clone(x)
      inp_res = clone(x)
      index_put_(inp_ref, (ind_long, ind_long), src, accum)
      index_put_(inp_res, (ind_int, ind_int), src, accum)
      numpy_testing_assert_equal_helper(inp_ref, inp_res)
    '''
  # TODO empty setitem
  # Disabled: this test is kept inside a string literal, so it is never defined or collected.
  # NOTE(review): it passes a numpy array (`np.array([1.0])`) as the values argument; check whether
  # index_put_ expects a Tensor there before re-enabling.
  '''
  def test_index_put_accumulate_empty(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/94667
    input = Tensor.rand([], dtype=dtypes.float32)
    with self.assertRaises(RuntimeError):
      index_put_(input, [], np.array([1.0]), True)
  '''
  @unittest.skip("bool indexing not supported")
  def test_multiple_byte_mask(self):
    """uint8 masks on the first and last axes combine into one leading axis (2 warnings expected)."""
    src = Tensor.randn(5, 7, 3)
    # note: these broadcast together and are transposed to the first dim
    first_mask = Tensor([1, 0, 1, 1, 0], dtype=dtypes.uint8)
    last_mask = Tensor([1, 1, 1], dtype=dtypes.uint8)
    with warnings.catch_warnings(record=True) as caught:
      warnings.simplefilter("always")
      numpy_testing_assert_equal_helper(src[first_mask, :, last_mask].shape, (3, 7))
      numpy_testing_assert_equal_helper(len(caught), 2)
  @unittest.skip("bool indexing not supported")
  def test_byte_mask2d(self):
    """A 2-D mask over the first two axes collapses them into one axis holding the selected rows."""
    v = Tensor.randn(5, 7, 3)
    c = Tensor.randn(5, 7)
    # .item() turns the count into a Python int; a Tensor inside the expected shape tuple cannot be
    # compared element-wise against the ints in r.shape
    num_ones = (c > 0).sum().item()
    r = v[c > 0]
    numpy_testing_assert_equal_helper(r.shape, (num_ones, 3))
  @unittest.skip("bool indexing not supported")
  def test_jit_indexing(self):
    """Masked setitem and slice setitem, each wrapped in TinyJit, both set the first 50 entries to 1."""
    def set_below_50(x):
      x[x < 50] = 1.0
      return x
    def set_first_50(x):
      x[0:50] = 1.0
      return x
    jitted_mask = TinyJit(set_below_50)
    jitted_slice = TinyJit(set_first_50)
    data = Tensor.arange(100, dtype=dtypes.float)
    expected = Tensor(np.concatenate((np.ones(50), np.arange(50, 100))), dtype=dtypes.float)
    for jitted in (jitted_mask, jitted_slice):
      numpy_testing_assert_equal_helper(jitted(clone(data)), expected)
  def test_int_indices(self):
    """An integer-list index replaces the indexed axis with the shape of the list."""
    src = Tensor.randn(5, 7, 3)
    cases = (
      ([0, 4, 2], (3, 7, 3)),
      ((slice(None), [0, 4, 2]), (5, 3, 3)),
      ((slice(None), [[0, 1], [4, 3]]), (5, 2, 2, 3)),
    )
    for index, expected_shape in cases:
      numpy_testing_assert_equal_helper(src[index].shape, expected_shape)
  # TODO fancy setitem
  # Disabled: this test is kept inside a string literal, so it is never defined or collected.
  # NOTE(review): the signature takes a `dtype` argument that plain unittest will not supply (looks like a
  # leftover from a dtype-parametrized test); give it a default or loop over dtypes before re-enabling.
  '''
  def test_index_put_src_datatype(self, dtype):
    src = Tensor.ones(3, 2, 4, dtype=dtype)
    vals = Tensor.ones(3, 2, 4, dtype=dtype)
    indices = (np.array([0, 2, 1]),)
    res = index_put_(src, indices, vals, accumulate=True)
    numpy_testing_assert_equal_helper(res.shape, src.shape)
  '''
  def test_index_src_datatype(self):
    """Permuting rows with an integer list keeps the source shape."""
    src = Tensor.ones(3, 2, 4)
    row_order = [0, 2, 1]
    # test index
    res = src[row_order, :, :]
    numpy_testing_assert_equal_helper(res.shape, src.shape)
    # test index_put, no accum
    # TODO fancy setitem
    '''
    src[[0, 2, 1], :, :] = res
    numpy_testing_assert_equal_helper(res.shape, src.shape)
    '''
  def test_int_indices2d(self):
    """Paired 2-D row/column index tensors pick individual elements (from the NumPy indexing example)."""
    grid = Tensor.arange(0, 12).reshape(4, 3)
    row_idx = Tensor([[0, 0], [3, 3]])
    col_idx = Tensor([[0, 2], [0, 2]])
    corners = grid[row_idx, col_idx].numpy().tolist()
    numpy_testing_assert_equal_helper(corners, [[0, 2], [9, 11]])
  def test_int_indices_broadcast(self):
    """A column of row indices broadcasts against a row of column indices (from the NumPy indexing example)."""
    grid = Tensor.arange(0, 12).reshape(4, 3)
    row_idx = Tensor([0, 3])
    col_idx = Tensor([0, 2])
    corners = grid[row_idx[:, None], col_idx]
    numpy_testing_assert_equal_helper(corners.numpy().tolist(), [[0, 2], [9, 11]])
  # TODO jax supports empty tensor indexing
  @unittest.skip("empty tensor indexing not supported")
  def test_empty_index(self):
    """Indexing with an empty int64 tensor selects zero elements."""
    x = Tensor.arange(0, 12).reshape(4, 3)
    idx = Tensor([], dtype=dtypes.int64)
    selected = x[idx]
    numpy_testing_assert_equal_helper(selected.numel(), 0)
    # TODO empty setitem
    '''
    # empty assignment should have no effect but not throw an exception
    y = clone(x)
    y[idx] = -1
    numpy_testing_assert_equal_helper(x, y)
    mask = Tensor.zeros(4, 3).cast(dtypes.bool)
    y[mask] = -1
    numpy_testing_assert_equal_helper(x, y)
    '''
  # TODO jax supports empty tensor indexing
  @unittest.skip("empty tensor indexing not supported")
  def test_empty_ndim_index(self):
    """Empty multi-dimensional index tensors give correspondingly empty results."""
    vec = Tensor.randn(5)
    numpy_testing_assert_equal_helper(Tensor.empty(0, 2), vec[Tensor.empty(0, 2, dtype=dtypes.int64)])
    big = Tensor.randn(2, 3, 4, 5)
    numpy_testing_assert_equal_helper(Tensor.empty(2, 0, 6, 4, 5),
                                      big[:, Tensor.empty(0, 6, dtype=dtypes.int64)])
    hollow = Tensor.empty(10, 0)
    numpy_testing_assert_equal_helper(hollow[[1, 2]].shape, (2, 0))
    numpy_testing_assert_equal_helper(hollow[[], []].shape, (0,))
    # any column index into a zero-width axis is out of range
    with self.assertRaises(IndexError):
      hollow[:, [0, 1]]
  def test_empty_slice(self):
    """A start == stop slice yields a zero-length axis."""
    src = Tensor.randn(2, 3, 4, 5)
    plane = src[:, :, :, 1]
    hollow = plane[:, 1:1, :]
    numpy_testing_assert_equal_helper((2, 0, 4), hollow.shape)
    # this isn't technically necessary, but matches NumPy stride calculations.
    # NOTE: this is empty and shouldn't have strides
    #numpy_testing_assert_equal_helper((60, 20, 5), hollow.lazydata.st.real_strides())
    # NOTE tinygrad's int slicing implementation makes this not contiguous
    # self.assertTrue(hollow.lazydata.st.contiguous)
  @unittest.skip("bool indexing not supported")
  def test_index_getitem_copy_bools_slices(self):
    """True/False indices (Python or uint8 tensor) copy; None and ... keep the same storage."""
    true = Tensor(1, dtype=dtypes.uint8)
    false = Tensor(0, dtype=dtypes.uint8)
    for src in (Tensor.randn(2, 3), Tensor(3.)):
      for yes, no in ((True, False), (true, false)):
        self.assertNotEqual(data_ptr(src), data_ptr(src[yes]))
        numpy_testing_assert_equal_helper(Tensor.empty(0, *src.shape), src[no])
      self.assertEqual(data_ptr(src), data_ptr(src[None]))
      self.assertEqual(data_ptr(src), data_ptr(src[...]))
  @unittest.skip("bool indexing not supported")
  def test_index_setitem_bools_slices(self):
    """Setitem through True/False/None/... (Python and uint8-tensor forms) with values carrying extra leading 1s."""
    true = Tensor(1, dtype=dtypes.uint8)
    false = Tensor(0, dtype=dtypes.uint8)
    tensors = [Tensor.randn(2, 3), Tensor(3)]
    for a in tensors:
      # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
      # (some of these ops already prefix a 1 to the size)
      neg_ones = Tensor.ones_like(a) * -1
      neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
      a[True] = neg_ones_expanded
      numpy_testing_assert_equal_helper(a, neg_ones)
      # assigning through a False index is expected to leave `a` unchanged
      a[False] = 5
      numpy_testing_assert_equal_helper(a, neg_ones)
      a[true] = neg_ones_expanded * 2
      numpy_testing_assert_equal_helper(a, neg_ones * 2)
      a[false] = 5
      numpy_testing_assert_equal_helper(a, neg_ones * 2)
      a[None] = neg_ones_expanded * 3
      numpy_testing_assert_equal_helper(a, neg_ones * 3)
      a[...] = neg_ones_expanded * 4
      numpy_testing_assert_equal_helper(a, neg_ones * 4)
      # rank via the `ndim` property, as used elsewhere in this file, instead of the torch-style `dim()` call
      if a.ndim == 0:
        with self.assertRaises(IndexError):
          a[:] = neg_ones_expanded * 5
  @unittest.skip("bool indexing not supported")
  def test_index_scalar_with_bool_mask(self):
    """On 0-d tensors, a uint8 scalar mask and a bool scalar mask give equal values and dtypes."""
    uint_mask = Tensor(True, dtype=dtypes.uint8)
    bool_mask = Tensor(True, dtype=dtypes.bool)
    for scalar in (Tensor(1), Tensor(True, dtype=dtypes.bool)):
      numpy_testing_assert_equal_helper(scalar[uint_mask], scalar[bool_mask])
      numpy_testing_assert_equal_helper(scalar[uint_mask].dtype, scalar[bool_mask].dtype)
  @unittest.skip("bool indexing not supported")
  def test_setitem_expansion_error(self):
    """Assigning a value with a non-1 extra leading dim through a True index must raise."""
    true = Tensor(True)
    a = Tensor.randn(2, 3)
    # check prefix with non-1s doesn't work
    # a_expanded = a.expand(torch.Size([5, 1]) + a.size())
    too_wide = a.expand((5, 1) + a.shape)
    # NumPy: ValueError
    for true_index in (True, true):
      with self.assertRaises(RuntimeError):
        a[true_index] = too_wide
  def test_getitem_scalars_simple(self):
    """Scalar-indexing an element-wise product returns the expected Python value."""
    src = Tensor([[[1.,2.],[3.,4.]], [[1,1],[1,1]]])
    product = src[0] * src[1]
    self.assertEqual(product[0,1].item(), 2)
  def test_getitem_scalars(self):
    """0-d int tensors index like Python ints; a 0-d tensor rejects `:` and int-tensor indices."""
    idx0 = Tensor(0, dtype=dtypes.int64)
    idx1 = Tensor(1, dtype=dtypes.int64)
    # non-scalar indexed with scalars
    mat = Tensor.randn(2, 3)
    numpy_testing_assert_equal_helper(mat[0], mat[idx0])
    numpy_testing_assert_equal_helper(mat[0][1], mat[idx0][idx1])
    numpy_testing_assert_equal_helper(mat[0, 1], mat[idx0, idx1])
    numpy_testing_assert_equal_helper(mat[0, idx1], mat[idx0, 1])
    # indexing by a scalar should slice (not copy)
    self.assertEqual(data_ptr(mat[0, 1]), data_ptr(mat[idx0, idx1]))
    for narrow_dtype in (dtypes.int32, dtypes.int16):
      self.assertEqual(data_ptr(mat[1]), data_ptr(mat[idx1.cast(narrow_dtype)]))
    # scalar indexed with scalar
    scalar = Tensor.randn()
    for bad_index in (slice(None), idx0):
      with self.assertRaises(IndexError):
        scalar[bad_index]
    numpy_testing_assert_equal_helper(scalar, scalar[...])
  # TODO fancy setitem
  # Disabled: this test is kept inside a string literal, so it is never defined or collected.
  # NOTE(review): the final `np.testing.assert_allclose(9.9, r, rtol=1e-7)` passes the Tensor itself, unlike
  # the earlier `a[1, 0].numpy()` check; it probably needs `r.numpy()` — confirm when re-enabling.
  '''
  def test_setitem_scalars(self):
    zero = Tensor(0, dtype=dtypes.int64)
    # non-scalar indexed with scalars
    a = Tensor.randn(2, 3).contiguous()
    a_set_with_number = clone(a).contiguous()
    a_set_with_scalar = clone(a).contiguous()
    b = Tensor.randn(3)
    a_set_with_number[0] = b
    a_set_with_scalar[zero] = b
    numpy_testing_assert_equal_helper(a_set_with_number, a_set_with_scalar)
    a[1, zero] = 7.7
    # TODO: weird inaccuracy Max relative difference: 2.47707621e-08
    # numpy_testing_assert_equal_helper(7.7, a[1, 0])
    np.testing.assert_allclose(7.7, a[1, 0].numpy(), rtol=1e-7)
    # scalar indexed with scalars
    r = Tensor.randn().contiguous()
    with self.assertRaises(IndexError):
      r[:] = 8.8
    with self.assertRaises(IndexError):
      r[zero] = 8.8
    r[...] = 9.9
    # TODO: weird inaccuracy Max relative difference: 3.85322971e-08
    # numpy_testing_assert_equal_helper(9.9, r)
    np.testing.assert_allclose(9.9, r, rtol=1e-7)
  '''
  def test_basic_advanced_combined(self):
    """A basic slice and the equivalent int-list index select the same block (from the NumPy indexing example)."""
    x = Tensor.arange(0, 12).reshape(4, 3)
    numpy_testing_assert_equal_helper(x[1:2, 1:3], x[1:2, [1, 2]])
    block = x[1:2, 1:3].numpy().tolist()
    numpy_testing_assert_equal_helper(block, [[4, 5]])
    # Check that it is a copy
    # NOTE(review): the zeros_like() result is discarded and nothing is written back into x, so this check
    # cannot fail as written; presumably an in-place write through the advanced-indexed view was intended.
    unmodified = clone(x)
    x[1:2, [1, 2]].zeros_like()
    numpy_testing_assert_equal_helper(x, unmodified)
    # But assignment should modify the original
    # TODO fancy setitem
    '''
    unmodified = clone(x)
    x[1:2, [1, 2]] = 0
    self.assertNotEqual(x, unmodified)
    '''
  def test_int_assignment(self):
    """Assigning a scalar or a row tensor to x[1] overwrites only that row."""
    for value, new_row in ((5, [5, 5]), (Tensor.arange(5, 7), [5, 6])):
      x = Tensor.arange(0, 4).reshape(2, 2)
      x[1] = value
      numpy_testing_assert_equal_helper(x.numpy().tolist(), [[0, 1], new_row])
  # TODO fancy setitem
  # Disabled: this test is kept inside a string literal, so it is never defined or collected.
  # It checks that assigning through a uint8 row mask overwrites rows 0 and 2 and leaves rows 1 and 3 as arange.
  '''
  def test_byte_tensor_assignment(self):
    x = Tensor.arange(0., 16).reshape(4, 4)
    b = Tensor([True, False, True, False], dtype=dtypes.uint8)
    value = Tensor([3., 4., 5., 6.])
    with warnings.catch_warnings(record=True) as w:
      x[b] = value
      numpy_testing_assert_equal_helper(len(w), 1)
    numpy_testing_assert_equal_helper(x[0], value)
    numpy_testing_assert_equal_helper(x[1], Tensor.arange(4., 8))
    numpy_testing_assert_equal_helper(x[2], value)
    numpy_testing_assert_equal_helper(x[3], Tensor.arange(12., 16))
  '''
  @unittest.skip("Tensor unpacking not supported")
  def test_variable_slicing(self):
    """Slice bounds taken from unpacked tensor elements behave like the equivalent ints."""
    x = Tensor.arange(0, 16).reshape(4, 4)
    start, stop = Tensor([0, 1], dtype=dtypes.int32)
    numpy_testing_assert_equal_helper(x[start:stop], x[0:1])
  def test_ellipsis_tensor(self):
    """A tensor index next to `...` selects along the axis it is placed on."""
    grid = Tensor.arange(0, 9).reshape(3, 3)
    idx = Tensor([0, 2])
    picked_cols = grid[..., idx].numpy().tolist()
    numpy_testing_assert_equal_helper(picked_cols, [[0, 2],
                                                    [3, 5],
                                                    [6, 8]])
    picked_rows = grid[idx, ...].numpy().tolist()
    numpy_testing_assert_equal_helper(picked_rows, [[0, 1, 2],
                                                    [6, 7, 8]])
  # TODO unravel_index
  # Disabled: this test is kept inside a string literal, so it is never defined or collected.
  # It lists argument combinations that unravel_index should reject (TypeError for bad index/shape types,
  # ValueError for a negative dimension).
  '''
  def test_unravel_index_errors(self):
    with self.assertRaises(TypeError):
      unravel_index(
        Tensor(0.5),
        (2, 2))
    with self.assertRaises(TypeError):
      unravel_index(
        Tensor([]),
        (10, 3, 5))
    with self.assertRaises(TypeError):
      unravel_index(
        Tensor([1], dtype=dtypes.int64),
        Tensor([1, 2, 3]))
    with self.assertRaises(TypeError):
      unravel_index(
        Tensor([1], dtype=dtypes.int64),
        (1, 2, 2.0))
    with self.assertRaises(ValueError):
      unravel_index(
        Tensor(0),
        (2, -3))
  '''
  def test_invalid_index(self):
    """Slice bounds given as strings raise TypeError."""
    x = Tensor.arange(0, 16).reshape(4, 4)
    with self.assertRaises(TypeError):
      x["0":"1"]
  def test_out_of_bound_index(self):
    """Integer indices past an axis' length raise IndexError."""
    x = Tensor.arange(0, 100).reshape(2, 5, 10)
    bad_indices = ((0, 5), (4, 5), (0, 1, 15), (slice(None), slice(None), 12))
    for bad in bad_indices:
      self.assertRaises(IndexError, lambda: x[bad])
  def test_zero_dim_index(self):
    """A 0-d tensor equals its item() and cannot be indexed with an integer."""
    x = Tensor(10)
    numpy_testing_assert_equal_helper(x, x.item())
    # index directly: the old runner printed x[0] (debug noise) and then evaluated it a second time
    self.assertRaises(IndexError, lambda: x[0])
  # TODO fancy setitem
  # Disabled: this test is kept inside a string literal, so it is never defined or collected.
  # It compares a tensor-indexed assignment (x[idx] = b) against the equivalent slice assignment, then reads it back.
  '''
  def test_cpu_indices(self):
    idx = Tensor([0, 1])
    b = Tensor.zeros(2)
    x = Tensor.ones(10).contiguous()
    x[idx] = b # index_put_
    ref = Tensor.ones(10).contiguous()
    ref[:2] = 0
    numpy_testing_assert_equal_helper(x, ref)
    out = x[idx] # index
    numpy_testing_assert_equal_helper(out, Tensor.zeros(2))
  '''
  def test_take_along_dim(self):
    """Tensor.gather agrees with numpy.take_along_axis for broadcast and empty index tensors."""
    def _test_against_numpy(t: Tensor, indices: Tensor, dim):
      actual = t.gather(dim, indices)
      expected = np.take_along_axis(t.numpy(), indices.numpy(), axis=dim)
      numpy_testing_assert_equal_helper(actual, expected)
    # TODO argsort
    '''
    for shape in [(3, 2), (2, 3, 5), (2, 4, 0), (2, 3, 1, 4)]:
      for noncontiguous in [True, False]:
        for dtype in (dtypes.float32, dtypes.int64):
          t = make_tensor(shape, dtype=dtype, noncontiguous=noncontiguous)
          for dim in list(range(t.ndim)) + [None]:
            if dim is None:
              indices = argsort(t.reshape(-1))
            else:
              indices = argsort(t, dim=dim)
          _test_against_numpy(t, indices, dim)
    '''
    # first case tests broadcasting, second case tests empty indices
    for src_shape, idx_shape in (((3, 4, 1), (1, 2, 5)), ((3, 4, 5), (3, 0, 5))):
      t = Tensor.ones(src_shape)
      indices = Tensor.ones(idx_shape, dtype=dtypes.int64)
      _test_against_numpy(t, indices, 1)
  # TODO argsort
  # Disabled: this test is kept inside a string literal, so it is never defined or collected.
  # NOTE(review): `self.assertRaises(RuntimeError, "input and indices ...")` passes the message where
  # assertRaises expects a callable, so it would call the string instead of returning a context manager;
  # presumably `assertRaisesRegex` was intended — fix when re-enabling.
  '''
  def test_take_along_dim_invalid(self):
    for dtype in (dtypes.int64, dtypes.float32):
      shape = (2, 3, 1, 4)
      dim = 0
      t = make_tensor(shape, dtype=dtype)
      indices = argsort(t, dim=dim)
      # dim of `t` and `indices` does not match
      with self.assertRaises(RuntimeError, "input and indices should have the same number of dimensions"):
        t.gather(0, indices[0])
      # invalid `indices` dtype
      with self.assertRaises(RuntimeError):
        t.gather(0, indices.cast(dtypes.bool))
      with self.assertRaises(RuntimeError):
        t.gather(0, indices.cast(dtypes.float32))
      with self.assertRaises(RuntimeError):
        t.gather(0, indices.cast(dtypes.int32))
      # invalid axis
      with self.assertRaises(IndexError):
        t.gather(-7, indices)
      with self.assertRaises(IndexError):
        t.gather(7, indices)
  '''
  1130. class TestNumpy(unittest.TestCase):
  def test_index_no_floats(self):
    """A float anywhere in an index, including slice bounds, raises IndexError."""
    a = Tensor([[[5.]]])
    full = slice(None)
    bad_indices = []
    for f in (0.0, -1.4):
      bad_indices += [f, (0, f), (f, 0), (f, full), (full, f), (full, f, full),
                      (f, full, full), (0, 0, f), (f, 0, 0), (0, f, 0)]
    # float slice start combined with a float scalar
    bad_indices += [(slice(0.0, None), 0.0), (slice(0.0, None), 0.0, full)]
    for bad in bad_indices:
      self.assertRaises(IndexError, lambda: a[bad])
  def test_none_index(self):
    """A `None` index adds one new axis."""
    a = Tensor([1, 2, 3])
    expanded = a[None]
    numpy_testing_assert_equal_helper(expanded.ndim, a.ndim + 1)
  def test_empty_tuple_index(self):
    """An empty tuple index returns a view equal to the original and sharing its storage."""
    a = Tensor([1, 2, 3])
    empty_index = ()
    numpy_testing_assert_equal_helper(a[empty_index], a)
    self.assertEqual(data_ptr(a[empty_index]), data_ptr(a))
  # TODO jax supports empty tensor indexing
  @unittest.skip("empty tensor indexing not supported")
  def test_empty_fancy_index(self):
    """Empty integer indices (list or int64 tensor) give an empty result; an empty float index is rejected."""
    # Empty list index creates an empty array
    a = Tensor([1, 2, 3])
    numpy_testing_assert_equal_helper(a[[]], np.array([]))
    # an empty int64 tensor index must behave like the empty list (b used to be built but never indexed with)
    b = Tensor([]).cast(dtypes.int64)
    numpy_testing_assert_equal_helper(a[b], np.array([]))
    b = Tensor([]).float()
    self.assertRaises(IndexError, lambda: a[b])
  def test_ellipsis_index(self):
    """`...` returns a new Tensor object over the same storage and can stand in for any run of axes."""
    a = Tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    self.assertIsNot(a[...], a)
    numpy_testing_assert_equal_helper(a[...], a)
    # `a[...]` was `a` in numpy <1.9.
    numpy_testing_assert_equal_helper(data_ptr(a[...]), data_ptr(a))
    # Slicing with ellipsis can skip an
    # arbitrary number of dimensions
    for with_ellipsis, without_ellipsis in ((a[0, ...], a[0]), (a[0, ...], a[0, :]), (a[..., 0], a[:, 0])):
      numpy_testing_assert_equal_helper(with_ellipsis, without_ellipsis)
    # In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch
    # we don't have separate 0-dim arrays and scalars.
    numpy_testing_assert_equal_helper(a[0, ..., 1], np.array(2))
    # Assignment with `(Ellipsis,)` on 0-d arrays
    # NOTE(review): `scalar` is a numpy array, so this only exercises numpy, not tinygrad.
    scalar = np.array(1)
    scalar[(Ellipsis,)] = 2
    numpy_testing_assert_equal_helper(scalar, 2)
  def test_single_int_index(self):
    """A single integer selects one row; indices far out of range raise IndexError."""
    a = Tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    for row, expected in ((0, [1, 2, 3]), (-1, [7, 8, 9])):
      numpy_testing_assert_equal_helper(a[row], expected)
    for huge in (1 << 30, 1 << 64):
      self.assertRaises(IndexError, a.__getitem__, huge)
  @unittest.skip("bool indexing not supported")
  def test_single_bool_index(self):
    """A lone True index acts like None; a lone False gives an empty new leading axis."""
    a = Tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    numpy_testing_assert_equal_helper(a[True], a[None])
    numpy_testing_assert_equal_helper(a[False], a[None][0:0])
  @unittest.skip("bool indexing not supported")
  def test_boolean_shape_mismatch(self):
    """Masks whose shape does not match the indexed axes raise IndexError."""
    arr = Tensor.ones((5, 4, 3))
    for wrong_length in ([True], [False] * 6):
      mask = Tensor(wrong_length)
      self.assertRaises(IndexError, lambda: arr[mask])
    square = Tensor.zeros(4, 4, dtype=dtypes.uint8)
    self.assertRaises(IndexError, lambda: arr[square])
    self.assertRaises(IndexError, lambda: arr[(slice(None), square)])
  @unittest.skip("bool indexing not supported")
  def test_boolean_indexing_onedim(self):
    """A length-1 bool mask on a one-row 2-D tensor selects, and assigns to, the whole row."""
    row = Tensor([[0., 0., 0.]])
    keep = Tensor([True])
    numpy_testing_assert_equal_helper(row[keep], row)
    # boolean assignment
    row[keep] = 1.
    numpy_testing_assert_equal_helper(row, Tensor([[1., 1., 1.]]))
  @unittest.skip("bool indexing not supported")
  def test_boolean_assignment_value_mismatch(self):
    """Masked assignment fails when the values cannot be broadcast to the selection (see also gh-3458)."""
    a = Tensor.arange(0, 4)
    def assign_masked(target, values):
      target[target > -1] = Tensor(values)
    for values in ([], [1, 2, 3]):
      self.assertRaises(Exception, assign_masked, a, values)
    self.assertRaises(Exception, assign_masked, a[:1], [1, 2, 3])
  @unittest.skip("bool indexing not supported")
  def test_boolean_indexing_twodim(self):
    # A full-shape 2-D mask over a 2-D tensor, plus 1-D masks taken from its rows.
    grid = Tensor([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
    mask = Tensor([[True, False, True],
                   [False, True, False],
                   [True, False, True]])
    # full-shape mask: selected elements come back flattened in row-major order
    numpy_testing_assert_equal_helper(grid[mask], Tensor([1, 3, 5, 7, 9]))
    # 1-D mask: selects whole rows
    numpy_testing_assert_equal_helper(grid[mask[1]], Tensor([[4, 5, 6]]))
    numpy_testing_assert_equal_helper(grid[mask[0]], grid[mask[2]])
    # masked assignment overwrites exactly the True positions
    grid[mask] = 0
    numpy_testing_assert_equal_helper(grid, Tensor([[0, 2, 0],
                                                    [4, 0, 6],
                                                    [0, 8, 0]]))
  @unittest.skip("bool indexing not supported")
  def test_boolean_indexing_weirdness(self):
    # Python bools mixed with Ellipsis and integer-list indices.
    ones = Tensor.ones((2, 3, 4))
    empty_view = ones[False, True, ...]
    numpy_testing_assert_equal_helper((0, 2, 3, 4), empty_view.shape)
    picked = ones[True, [0, 1], True, True, [1], [[2]]]
    numpy_testing_assert_equal_helper(Tensor.ones(1, 2), picked)
    # a False mask cannot be combined with a non-empty integer-list index
    with self.assertRaises(IndexError):
      ones[False, [0, 1], ...]
  @unittest.skip("bool indexing not supported")
  def test_boolean_indexing_weirdness_tensors(self):
    # Mirrors test_boolean_indexing_weirdness, but every bool index is a scalar
    # boolean Tensor instead of a Python bool.
    false = Tensor(False)
    true = Tensor(True)
    a = Tensor.ones((2, 3, 4))
    # use the tensor bools here too, so the shape check covers the tensor path
    # rather than duplicating the Python-bool test
    numpy_testing_assert_equal_helper((0, 2, 3, 4), a[false, true, ...].shape)
    numpy_testing_assert_equal_helper(Tensor.ones(1, 2), a[true, [0, 1], true, true, [1], [[2]]])
    self.assertRaises(IndexError, lambda: a[false, [0, 1], ...])
  @unittest.skip("bool indexing not supported")
  def test_boolean_indexing_alldims(self):
    # One True per dimension, as Python bools or as scalar Tensors, yields a
    # result with an extra leading axis of size 1.
    a = Tensor.ones((2, 3))
    true = Tensor(True)
    for index in ((True, True), (true, true)):
      numpy_testing_assert_equal_helper((1, 2, 3), a[index].shape)
  @unittest.skip("bool indexing not supported")
  def test_boolean_list_indexing(self):
    # Python lists of bools used as masks on a 2-D tensor, either alone (rows)
    # or paired one-per-dimension (element-wise picks).
    a = Tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    first_only = [True, False, False]
    first_two = [True, True, False]
    for index, expected in (
        (first_only, Tensor([[1, 2, 3]])),
        ((first_only, first_only), Tensor([1])),
        (first_two, Tensor([[1, 2, 3], [4, 5, 6]])),
        ((first_two, first_two), Tensor([1, 5])),
    ):
      numpy_testing_assert_equal_helper(a[index], expected)
  def test_everything_returns_views(self):
    # Indexing with (), Ellipsis or a full slice must return a new object,
    # never the tensor itself (historically `...` returned `a` unchanged).
    a = Tensor([5])
    for key in ((), Ellipsis, slice(None)):
      self.assertIsNot(a, a[key])
  def test_broaderrors_indexing(self):
    # Integer-list indices of lengths 2 and 3 cannot be broadcast together,
    # so indexing with them must raise IndexError.
    a = Tensor.zeros(5, 5)
    with self.assertRaises(IndexError):
      a[[0, 1], [0, 1, 2]]
    # TODO: fancy setitem; re-enable once it is supported:
    # self.assertRaises(IndexError, a.contiguous().__setitem__, ([0, 1], [0, 1, 2]), 0)
  # Disabled test kept as an inert string literal (never collected or run).
  1307. # TODO: out-of-bounds index-tensor getitem does not raise IndexError yet (the test also exercises fancy setitem); re-enable once both raise
  1308. '''
  1309. def test_trivial_fancy_out_of_bounds(self):
  1310. a = Tensor.zeros(5)
  1311. ind = Tensor.ones(20, dtype=dtypes.int64)
  1312. ind[-1] = 10
  1313. self.assertRaises(IndexError, a.__getitem__, ind)
  1314. self.assertRaises(IndexError, a.__setitem__, ind, 0)
  1315. ind = Tensor.ones(20, dtype=dtypes.int64)
  1316. ind[0] = 11
  1317. self.assertRaises(IndexError, a.__getitem__, ind)
  1318. self.assertRaises(IndexError, a.__setitem__, ind, 0)
  1319. '''
  # Disabled test kept as an inert string literal (never collected or run).
  # NOTE(review): all_ is not defined in the visible part of this file; confirm it exists and that the assertTrue comparison is what was intended before re-enabling.
  1320. # TODO: fancy setitem (assignment through integer-list indices) is not supported yet; re-enable test_index_is_larger once it is
  1321. '''
  1322. def test_index_is_larger(self):
  1323. # Simple case of fancy index broadcasting of the index.
  1324. a = Tensor.zeros((5, 5))
  1325. a[[[0], [1], [2]], [0, 1, 2]] = Tensor([2., 3., 4.])
  1326. self.assertTrue((a[:3, :3] == all_(Tensor([2., 3., 4.]))))
  1327. '''
  # Disabled test kept as an inert string literal (never collected or run).
  1328. # TODO: fancy setitem (assignment through an int64 index tensor) is not supported yet; re-enable test_broadcast_subspace once it is
  1329. '''
  1330. def test_broadcast_subspace(self):
  1331. a = Tensor.zeros((100, 100))
  1332. v = Tensor.arange(0., 100)[:, None]
  1333. b = Tensor.arange(99, -1, -1).cast(dtypes.int64)
  1334. a[b] = v
  1335. expected = b.float().unsqueeze(1).expand(100, 100)
  1336. numpy_testing_assert_equal_helper(a, expected)
  1337. '''
  # Disabled test kept as an inert string literal (never collected or run).
  # NOTE(review): clone, diagonal and copy_ are not defined in the visible part of this file; confirm they exist before re-enabling.
  1338. # TODO: fancy setitem (assignment through range() index pairs) is not supported yet; re-enable test_truncate_leading_1s once it is
  1339. '''
  1340. def test_truncate_leading_1s(self):
  1341. col_max = Tensor.randn(1, 4)
  1342. kernel = col_max.T * col_max # [4, 4] tensor
  1343. kernel2 = clone(kernel)
  1344. # Set the diagonal
  1345. # len(torch.tensor) is just tensor.shape[0]
  1346. kernel[range(kernel.shape[0]), range(kernel.shape[0])] = col_max.square()
  1347. kernel2 = diagonal(kernel2)
  1348. # torch.diagonal(kernel2).copy_(torch.square(col_max.view(4)))
  1349. kernel2 = copy_(kernel2, col_max.reshape(4).square())
  1350. numpy_testing_assert_equal_helper(kernel, kernel2)
  1351. '''
  # Entry point: run this module's tests through unittest when executed as a script.
  1352. if __name__ == '__main__':
  1353. unittest.main()