| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579 |
- # test cases are modified from pytorch test_indexing.py https://github.com/pytorch/pytorch/blob/597d3fb86a2f3b8d6d8ee067e769624dcca31cdb/test/test_indexing.py
- import unittest, random, copy, warnings
- import numpy as np
- from tinygrad import Tensor, dtypes, Device, TinyJit
- from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.shape.view import View
- from tinygrad.helpers import CI, all_same, prod
- random.seed(42)
def numpy_testing_assert_equal_helper(a, b):
  """Assert that `a` and `b` are equal using ``np.testing.assert_equal``.

  Either argument may be a ``Tensor``; a ``Tensor`` is converted with ``.numpy()``
  first, anything else is handed to numpy unchanged.
  """
  lhs = a.numpy() if isinstance(a, Tensor) else a
  rhs = b.numpy() if isinstance(b, Tensor) else b
  np.testing.assert_equal(lhs, rhs)
def consec(shape, start=1):
  """Return ``Tensor.arange(prod(shape))`` reshaped to `shape` and shifted by `start`."""
  element_count = prod(shape)
  flat = Tensor.arange(element_count)
  return flat.reshape(shape)+start
def set_(reference: Tensor, shape, strides, offset):
  """Create a strided tensor whose base is `reference`'s base (stand-in for torch.set_()).

  `reference` is realized first when its base has no realized buffer yet. The view is
  described by a single ``View`` built from `shape`, `strides` and `offset`.

  Raises AssertionError if the base is still unrealized, or if the resulting view's
  ``real_strides()`` differ from `strides`.
  """
  if reference.lazydata.base.realized is None:
    reference.realize()
  assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base"
  view = View.create(shape=shape, strides=strides, offset=offset)
  tracker = ShapeTracker((view,))
  strided = Tensor(reference.lazydata._view(tracker))
  assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
  return strided
def clone(original:Tensor):
  """Return a shallow copy of `original`, produced by ``copy.copy``."""
  duplicate = copy.copy(original)
  return duplicate
def copy_(src:Tensor, other:Tensor) -> Tensor:
  """Return a shallow copy (``copy.copy``) of `src`.

  `other` is accepted but never read by this helper.
  """
  copied = copy.copy(src)
  return copied
# Stand-in for torch's data_ptr(). Per the original author's note this is fine for the
# tested use cases: data_ptr is only used to compare whether the operations needed
# between tensors are the same, so the tensor's lazydata is handed back instead.
def data_ptr(tensor:Tensor):
  handle = tensor.lazydata
  return handle
# https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html
# TODO this is setitem
def index_put_(tensor:Tensor, indices, values, accumulate) -> Tensor:
  """Assign `values` into `tensor` at `indices` and return `tensor`.

  This is a thin wrapper around ``tensor[indices] = values``.

  Args:
    tensor: object to write into (mutated in place through ``__setitem__``).
    indices: index expression forwarded unchanged to ``__setitem__``.
    values: value(s) to store.
    accumulate: accepted for signature parity with torch but currently ignored;
      the write always overwrites, it never adds to the existing contents.

  Returns:
    The same `tensor` object, as the ``-> Tensor`` annotation promises (the previous
    implementation fell off the end and returned None).
  """
  tensor[indices] = values
  return tensor
# https://pytorch.org/docs/stable/generated/torch.argsort.html
def argsort(tensor:Tensor) -> Tensor:
  """Placeholder for torch.argsort: not implemented yet, always returns None."""
  return None
# https://pytorch.org/docs/stable/generated/torch.all.html
def all_(tensor:Tensor) -> Tensor:
  """Return the result of comparing `tensor` against zero with ``!=``.

  NOTE(review): this yields whatever ``tensor != 0`` produces (an element-wise result
  for array-like inputs); it does not reduce to a single value the way the linked
  torch.all does -- confirm that callers expect this.
  """
  nonzero = tensor != 0
  return nonzero
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
def diagonal(tensor:Tensor) -> Tensor:
  """Extract the diagonal of a square 2-D tensor.

  The input is multiplied by an identity matrix of the same size and summed over
  axis 0. Raises AssertionError for inputs that are not 2-D with all-equal dims.
  """
  assert tensor.ndim == 2 and all_same(tensor.shape), 'only support 2 ndim square tensors'
  identity = Tensor.eye(tensor.shape[0])
  masked = identity * tensor
  return masked.sum(0)
# https://numpy.org/doc/stable/reference/generated/numpy.unravel_index.html
def unravel_index(tensor, shape):
  """Placeholder for numpy.unravel_index: not implemented yet, always returns None."""
  return None
# https://github.com/pytorch/pytorch/blob/79811e765c23242210ebdc623539d2103a166463/torch/testing/_creation.py#L38
def make_tensor(shape, dtype:dtypes, noncontiguous) -> Tensor:
  r"""Create a random tensor of the given `shape` and `dtype` (after torch.testing.make_tensor).

  Unlike the torch original there are no ``low``/``high``/``device`` parameters; the
  ``low`` and ``high`` values passed to the random constructor are fixed per dtype category:

  +---------------------------+------------+----------+
  | ``dtype``                 | ``low``    | ``high`` |
  +===========================+============+==========+
  | boolean type              | ``0``      | ``2``    |
  +---------------------------+------------+----------+
  | unsigned integral type    | ``0``      | ``10``   |
  +---------------------------+------------+----------+
  | signed integral types     | ``-9``     | ``10``   |
  +---------------------------+------------+----------+
  | floating types            | ``-9``     | ``9``    |
  +---------------------------+------------+----------+

  Args:
    shape: forwarded as ``shape=`` to ``Tensor.randint`` / ``Tensor.rand``.
    dtype: requested dtype; bool/unsigned/int results are cast to it, float results
      are requested directly through ``dtype=``.
    noncontiguous: its negation is forwarded as ``contiguous=``.

  Raises:
    NotImplementedError: `dtype` is not bool, unsigned, int or float (e.g. complex).
  """
  contiguous = not noncontiguous
  if dtype == dtypes.bool:
    return Tensor.randint(shape=shape, low=0, high=2, contiguous=contiguous).cast(dtypes.bool)
  if dtype.is_unsigned():
    return Tensor.randint(shape=shape, low=0, high=10, contiguous=contiguous).cast(dtype)
  if dtype.is_int():  # signed int: the unsigned case was handled above
    return Tensor.randint(shape=shape, low=-9, high=10, contiguous=contiguous).cast(dtype)
  if dtype.is_float():
    return Tensor.rand(shape=shape, low=-9, high=9, dtype=dtype, contiguous=contiguous)
  raise NotImplementedError(f"{dtype} not implemented")
- class TestIndexing(unittest.TestCase):
def test_index(self):
  # Basic getitem coverage: ints, slices, Ellipsis, None, empty slices, steps and invalid indices.
  check = numpy_testing_assert_equal_helper
  reference = consec((3, 3, 3))

  # plain integer / slice indexing
  check(reference[0], consec((3, 3)))
  check(reference[1], consec((3, 3), 10))
  check(reference[2], consec((3, 3), 19))
  check(reference[0, 1], consec((3,), 4))
  check(reference[0:2], consec((2, 3, 3)))
  check(reference[2, 2, 2], 27)
  check(reference[:], consec((3, 3, 3)))

  # indexing with Ellipsis
  check(reference[..., 2], np.array([[3., 6., 9.],[12., 15., 18.],[21., 24., 27.]]))
  check(reference[0, ..., 2], np.array([3., 6., 9.]))
  check(reference[..., 2], reference[:, :, 2])
  check(reference[0, ..., 2], reference[0, :, 2])
  check(reference[0, 2, ...], reference[0, 2])
  check(reference[..., 2, 2, 2], 27)
  check(reference[2, ..., 2, 2], 27)
  check(reference[2, 2, ..., 2], 27)
  check(reference[2, 2, 2, ...], 27)
  check(reference[...], reference)

  reference_5d = consec((3, 3, 3, 3, 3))
  check(reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0])
  check(reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0])
  check(reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1])
  check(reference_5d[...], reference_5d)

  # None indexing
  check(reference[2, None], reference[2].unsqueeze(0))
  check(reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0))
  check(reference[2:4, None], reference[2:4].unsqueeze(1))
  check(reference[None, 2, None, None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0))
  check(reference[None, 2:5, None, None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2))

  # indexing 0-length slice
  check(np.empty((0, 3, 3)), reference[slice(0)])
  check(np.empty((0, 3)), reference[slice(0), 2])
  check(np.empty((0, 3)), reference[2, slice(0)])
  check(np.empty([]), reference[2, 1:1, 2])

  # indexing with step
  reference = consec((10, 10, 10))
  check(reference[1:5:2], Tensor.stack(reference[1], reference[3], dim=0))
  check(reference[1:6:2], Tensor.stack(reference[1], reference[3], reference[5], dim=0))
  check(reference[1:9:4], Tensor.stack(reference[1], reference[5], dim=0))
  check(reference[2:4, 1:5:2], Tensor.stack(reference[2:4, 1], reference[2:4, 3], dim=1))
  check(reference[3, 1:6:2], Tensor.stack(reference[3, 1], reference[3, 3], reference[3, 5], dim=0))
  check(reference[None, 2, 1:9:4], Tensor.stack(reference[2, 1], reference[2, 5], dim=0).unsqueeze(0))
  check(reference[:, 2, 1:6:2], Tensor.stack(reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5], dim=1))

  # random stepped slices, compared against the same slices applied to nested Python lists
  def random_slice():
    # draw order (start, length, step) is kept so the seeded random sequence is unchanged
    start = random.randrange(10)
    stop = start + random.randrange(1, 10 - start + 1)
    step = random.randrange(1, 8)
    return slice(start, stop, step)

  lst = [list(range(i, i+10)) for i in range(0, 100, 10)]
  tensor = Tensor(lst)
  for _ in range(100):
    idx1 = random_slice()
    if random.randrange(2) == 0:
      # slice both dimensions
      idx2 = random_slice()
      lst_indexed = [row[idx2] for row in lst[idx1]]
      tensor_indexed = tensor[idx1, idx2]
    else:
      # slice only the first dimension
      lst_indexed = lst[idx1]
      tensor_indexed = tensor[idx1]
    check(tensor_indexed, np.array(lst_indexed))

  # a zero step is rejected
  self.assertRaises(ValueError, lambda: reference[1:9:0])
  # NOTE torch doesn't support this but numpy does so we should too. Torch raises ValueError
  # see test_slice_negative_strides in test_ops.py
  # self.assertRaises(ValueError, lambda: reference[1:9:-1])

  # too many indices
  self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
  self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
  self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])

  # float indices / float slice bounds
  self.assertRaises(IndexError, lambda: reference[0.0])
  self.assertRaises(TypeError, lambda: reference[0.0:2.0])
  self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
  self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
  self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
  self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])

  # TODO: delitem
  # def delitem(): del reference[0]
  # self.assertRaises(TypeError, delitem)
# TODO: LLVM is quite fast, why are other compiled backends slow?
@unittest.skipIf(CI and Device.DEFAULT in ["CLANG", "GPU", "METAL", "NV", "AMD"], "slow")
def test_advancedindex(self):
  # Integer-array ("advanced") getitem, checked against literal expectations and numpy.
  # Blocks wrapped in ''' ... ''' are disabled setitem checks kept for the advanced-setitem TODO.
  check = numpy_testing_assert_equal_helper

  # wrap `indices` in a randomly picked valid indexer type: Tensor, list or tuple
  def ri(indices):
    wrap = (Tensor, list, tuple)[random.randint(0, 2)]
    return wrap(indices)

  def validate_indexing(x):
    check(x[[0]], consec((1,)))
    check(x[ri([0]),], consec((1,)))
    check(x[ri([3]),], consec((1,), 4))
    check(x[[2, 3, 4]], consec((3,), 3))
    check(x[ri([2, 3, 4]),], consec((3,), 3))
    check(x[ri([0, 2, 4]),], np.array([1, 3, 5]))

  # only referenced from the disabled setitem block below
  def validate_setting(x):
    x[[0]] = -2
    check(x[[0]], np.array([-2]))
    x[[0]] = -1
    check(x[ri([0]), ], np.array([-1]))
    x[[2, 3, 4]] = 4
    check(x[[2, 3, 4]], np.array([4, 4, 4]))
    x[ri([2, 3, 4]), ] = 3
    check(x[ri([2, 3, 4]), ], np.array([3, 3, 3]))
    x[ri([0, 2, 4]), ] = np.array([5, 4, 3])
    check(x[ri([0, 2, 4]), ], np.array([5, 4, 3]))

  # Case 1: Purely Integer Array Indexing
  reference = consec((10,))
  validate_indexing(reference)

  # setting values
  # TODO: advanced setitem
  '''
  validate_setting(reference)
  '''

  # Tensor with stride != 1
  # strided is [1, 3, 5, 7]
  reference = consec((10,))
  strided = set_(reference, (4,), (2,), 0)
  check(strided[[0]], np.array([1]))
  check(strided[ri([0]), ], np.array([1]))
  check(strided[ri([3]), ], np.array([7]))
  check(strided[[1, 2]], np.array([3, 5]))
  check(strided[ri([1, 2]), ], np.array([3, 5]))
  check(strided[ri([[2, 1], [0, 3]]), ],
        np.array([[5, 3], [1, 7]]))

  # stride 4 starting at offset 4: strided is [5, 9]
  strided = set_(reference, (2,), (4,), offset=4)
  check(strided[[0]], np.array([5]))
  check(strided[ri([0]), ], np.array([5]))
  check(strided[ri([1]), ], np.array([9]))
  check(strided[[0, 1]], np.array([5, 9]))
  check(strided[ri([0, 1]), ], np.array([5, 9]))
  check(strided[ri([[0, 1], [1, 0]]), ],
        np.array([[5, 9], [9, 5]]))

  # reference is 1 2
  #              3 4
  #              5 6
  reference = consec((3, 2))
  check(reference[ri([0, 1, 2]), ri([0])], np.array([1, 3, 5]))
  check(reference[ri([0, 1, 2]), ri([1])], np.array([2, 4, 6]))
  check(reference[ri([0]), ri([0])], consec((1,)))
  check(reference[ri([2]), ri([1])], consec((1,), 6))
  check(reference[[ri([0, 0]), ri([0, 1])]], np.array([1, 2]))
  check(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], np.array([2, 4, 4, 2, 6]))
  check(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([1, 2, 3, 3]))

  rows = ri([[0, 0],
             [1, 2]])
  columns = [0],  # trailing comma: a one-element tuple holding the list [0]
  check(reference[rows, columns], np.array([[1, 1],
                                            [3, 5]]))

  rows = ri([[0, 0],
             [1, 2]])
  columns = ri([1, 0])
  check(reference[rows, columns], np.array([[2, 1],
                                            [4, 5]]))

  rows = ri([[0, 0],
             [1, 2]])
  columns = ri([[0, 1],
                [1, 0]])
  check(reference[rows, columns], np.array([[1, 2],
                                            [4, 5]]))

  # TODO: advanced setitem
  '''
  # setting values
  reference[ri([0]), ri([1])] = -1
  numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])], np.array([-1]))
  reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
  numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
  np.array([-1, 2, -4]))
  reference[rows, columns] = np.array([[4, 6], [2, 3]])
  numpy_testing_assert_equal_helper(reference[rows, columns],
  np.array([[4, 6], [2, 3]]))
  '''

  # Verify still works with Transposed (i.e. non-contiguous) Tensors
  reference = Tensor([[0, 1, 2, 3],
                      [4, 5, 6, 7],
                      [8, 9, 10, 11]]).T

  # Transposed: [[0, 4, 8],
  #              [1, 5, 9],
  #              [2, 6, 10],
  #              [3, 7, 11]]
  check(reference[ri([0, 1, 2]), ri([0])], np.array([0, 1, 2]))
  check(reference[ri([0, 1, 2]), ri([1])], np.array([4, 5, 6]))
  check(reference[ri([0]), ri([0])], np.array([0]))
  check(reference[ri([2]), ri([1])], np.array([6]))
  check(reference[[ri([0, 0]), ri([0, 1])]], np.array([0, 4]))
  check(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], np.array([4, 5, 5, 4, 7]))
  check(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([0, 4, 1, 1]))

  rows = ri([[0, 0],
             [1, 2]])
  columns = [0],
  check(reference[rows, columns], np.array([[0, 0], [1, 2]]))

  rows = ri([[0, 0],
             [1, 2]])
  columns = ri([1, 0])
  check(reference[rows, columns], np.array([[4, 0], [5, 2]]))

  rows = ri([[0, 0],
             [1, 3]])
  columns = ri([[0, 1],
                [1, 2]])
  check(reference[rows, columns], np.array([[0, 4], [5, 11]]))

  # TODO: advanced setitem
  '''
  # setting values
  reference[ri([0]), ri([1])] = -1
  numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])],
  np.array([-1]))
  reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
  numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
  np.array([-1, 2, -4]))
  reference[rows, columns] = np.array([[4, 6], [2, 3]])
  numpy_testing_assert_equal_helper(reference[rows, columns],
  np.array([[4, 6], [2, 3]]))
  '''

  # stride != 1
  # strided is [[1 3 5 7],
  #             [9 11 13 15]]
  reference = Tensor.arange(0., 24).reshape(3, 8)
  strided = set_(reference, (2,4), (8,2), 1)
  check(strided[ri([0, 1]), ri([0])],
        np.array([1, 9]))
  check(strided[ri([0, 1]), ri([1])],
        np.array([3, 11]))
  check(strided[ri([0]), ri([0])],
        np.array([1]))
  check(strided[ri([1]), ri([3])],
        np.array([15]))
  check(strided[[ri([0, 0]), ri([0, 3])]],
        np.array([1, 7]))
  check(strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
        np.array([9, 11, 11, 9, 15]))
  check(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
        np.array([1, 3, 9, 9]))

  rows = ri([[0, 0],
             [1, 1]])
  columns = [0],
  check(strided[rows, columns],
        np.array([[1, 1], [9, 9]]))

  rows = ri([[0, 1],
             [1, 0]])
  columns = ri([1, 2])
  check(strided[rows, columns],
        np.array([[3, 13], [11, 5]]))

  rows = ri([[0, 0],
             [1, 1]])
  columns = ri([[0, 1],
                [1, 2]])
  check(strided[rows, columns],
        np.array([[1, 3], [11, 13]]))

  # setting values
  # strided is [[10, 11],
  #             [17, 18]]
  reference = Tensor.arange(0., 24).reshape(3, 8)
  strided = set_(reference, (2,2), (7,1), 10)
  check(strided[ri([0]), ri([1])],
        np.array([11]))
  # TODO advanced setitem
  '''
  strided[ri([0]), ri([1])] = -1
  numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
  Tensor([-1]))
  '''

  reference = Tensor.arange(0., 24).reshape(3, 8)
  strided = set_(reference, (2,2), (7,1), 10)
  check(strided[ri([0, 1]), ri([1, 0])],
        np.array([11, 17]))
  # TODO advanced setitem
  '''
  strided[ri([0, 1]), ri([1, 0])] = Tensor([-1, 2])
  numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
  Tensor([-1, 2]))
  '''

  reference = Tensor.arange(0., 24).realize().reshape(3, 8)
  strided = set_(reference, (2,2), (7,1), 10)
  rows = ri([[0],
             [1]])
  columns = ri([[0, 1],
                [0, 1]])
  check(strided[rows, columns],
        np.array([[10, 11], [17, 18]]))
  # TODO advanced setitem
  '''
  strided[rows, columns] = Tensor([[4, 6], [2, 3]])
  numpy_testing_assert_equal_helper(strided[rows, columns],
  Tensor([[4, 6], [2, 3]]))
  '''

  # Tests using less than the number of dims, and ellipsis
  # reference is 1 2
  #              3 4
  #              5 6
  reference = consec((3, 2))
  check(reference[ri([0, 2]),], np.array([[1, 2], [5, 6]]))
  check(reference[ri([1]), ...], np.array([[3, 4]]))
  check(reference[..., ri([1])], np.array([[2], [4], [6]]))

  # verify too many indices fails
  with self.assertRaises(IndexError): reference[ri([1]), ri([0, 2]), ri([3])]

  # test invalid index fails
  reference = Tensor.empty(10)
  for err_idx in (10, -11):
    with self.assertRaises(IndexError):
      reference[err_idx]
    # NOTE cannot check for out of bounds with Tensor indexing
    # see tensor.py: __getitem__ (Tiny Things)
    '''
    with self.assertRaises(IndexError):
      reference[Tensor([err_idx], dtype=dtypes.int64)]
    with self.assertRaises(IndexError):
      reference[[err_idx]]
    '''

  # convert a tensor and its indexer into numpy-friendly form (int64 Tensor indices become lists)
  def tensor_indices_to_np(tensor: Tensor, indices):
    npt = tensor.numpy()
    idxs = tuple(i.numpy().tolist() if isinstance(i, Tensor) and i.dtype == dtypes.int64 else
                 i for i in indices)
    return npt, idxs

  def get_numpy(tensor, indices):
    npt, idxs = tensor_indices_to_np(tensor, indices)
    return Tensor(npt[idxs])

  def set_numpy(tensor:Tensor, indices, value):
    if not isinstance(value, int):
      value = value.numpy()
    npt, idxs = tensor_indices_to_np(tensor, indices)
    npt[idxs] = value
    return npt

  def assert_get_eq(tensor, indexer):
    check(tensor[indexer], get_numpy(tensor, indexer))

  def assert_set_eq(tensor: Tensor, indexer, val):
    pyt = clone(tensor)
    numt = clone(tensor)
    pyt[indexer] = val
    numt = set_numpy(numt, indexer, val)
    check(pyt, numt)

  # NOTE: torch initiates the gradients using g0cpu (rand as gradients)
  def assert_backward_eq(tensor: Tensor, indexer):
    cpu = clone(tensor.float())
    cpu.requires_grad = True
    outcpu = cpu[indexer].sum()
    outcpu.backward()
    dev = cpu.detach()
    dev.requires_grad = True
    outdev = dev[indexer].sum()
    outdev.backward()
    check(cpu.grad, dev.grad)

  def get_set_tensor(indexed: Tensor, indexer):
    set_size = indexed[indexer].shape
    set_count = indexed[indexer].numel()
    set_tensor = Tensor.randint(set_count, high=set_count).reshape(set_size).cast(dtypes.float64)
    return set_tensor

  # Tensor is  0  1  2  3  4
  #            5  6  7  8  9
  #           10 11 12 13 14
  #           15 16 17 18 19
  reference = Tensor.arange(0., 20).reshape(4, 5)

  indices_to_test = [
    # grab the second, fourth columns
    [slice(None), [1, 3]],

    # first, third rows,
    [[0, 2], slice(None)],

    # weird shape
    [slice(None), [[0, 1],
                   [2, 3]]],
    # negatives
    [[-1], [0]],
    [[0, 2], [-1]],
    [slice(None), [-1]],
  ]

  # only test dupes on gets
  get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]

  for indexer in get_indices_to_test:
    assert_get_eq(reference, indexer)
    assert_backward_eq(reference, indexer)

  # TODO advanced setitem
  '''
  for indexer in indices_to_test:
    assert_set_eq(reference, indexer, 44)
    assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
  '''

  reference = Tensor.arange(0., 160).reshape(4, 8, 5)

  indices_to_test = [
    [slice(None), slice(None), [0, 3, 4]],
    [slice(None), [2, 4, 5, 7], slice(None)],
    [[2, 3], slice(None), slice(None)],
    [slice(None), [0, 2, 3], [1, 3, 4]],
    [slice(None), [0], [1, 2, 4]],
    [slice(None), [0, 1, 3], [4]],
    [slice(None), [[0, 1], [1, 0]], [[2, 3]]],
    [slice(None), [[0, 1], [2, 3]], [[0]]],
    [slice(None), [[5, 6]], [[0, 3], [4, 4]]],
    [[0, 2, 3], [1, 3, 4], slice(None)],
    [[0], [1, 2, 4], slice(None)],
    [[0, 1, 3], [4], slice(None)],
    [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
    [[[0, 1], [1, 0]], [[2, 3]], slice(None)],
    [[[0, 1], [2, 3]], [[0]], slice(None)],
    [[[2, 1]], [[0, 3], [4, 4]], slice(None)],
    [[[2]], [[0, 3], [4, 1]], slice(None)],
    # non-contiguous indexing subspace
    [[0, 2, 3], slice(None), [1, 3, 4]],

    # less dim, ellipsis
    [[0, 2], ],
    [[0, 2], slice(None)],
    [[0, 2], Ellipsis],
    [[0, 2], slice(None), Ellipsis],
    [[0, 2], Ellipsis, slice(None)],
    [[0, 2], [1, 3]],
    [[0, 2], [1, 3], Ellipsis],
    [Ellipsis, [1, 3], [2, 3]],
    [Ellipsis, [2, 3, 4]],
    [Ellipsis, slice(None), [2, 3, 4]],
    [slice(None), Ellipsis, [2, 3, 4]],

    # ellipsis counts for nothing
    [Ellipsis, slice(None), slice(None), [0, 3, 4]],
    [slice(None), Ellipsis, slice(None), [0, 3, 4]],
    [slice(None), slice(None), Ellipsis, [0, 3, 4]],
    [slice(None), slice(None), [0, 3, 4], Ellipsis],
    [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
    [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
    [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
  ]

  for indexer in indices_to_test:
    assert_get_eq(reference, indexer)
    # TODO advanced setitem
    '''
    assert_set_eq(reference, indexer, 212)
    assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
    '''
    assert_backward_eq(reference, indexer)

  reference = Tensor.arange(0., 1296).reshape(3, 9, 8, 6)

  indices_to_test = [
    [slice(None), slice(None), slice(None), [0, 3, 4]],
    [slice(None), slice(None), [2, 4, 5, 7], slice(None)],
    [slice(None), [2, 3], slice(None), slice(None)],
    [[1, 2], slice(None), slice(None), slice(None)],
    [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
    [slice(None), slice(None), [0], [1, 2, 4]],
    [slice(None), slice(None), [0, 1, 3], [4]],
    [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
    [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
    [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
    [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
    [slice(None), [0], [1, 2, 4], slice(None)],
    [slice(None), [0, 1, 3], [4], slice(None)],
    [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
    [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
    [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
    [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
    [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
    [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
    [[0], [1, 2, 4], slice(None), slice(None)],
    [[0, 1, 2], [4], slice(None), slice(None)],
    [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
    [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
    [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
    [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
    [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
    [slice(None), [2, 3, 4], [1, 3, 4], [4]],
    [slice(None), [0, 1, 3], [4], [1, 3, 4]],
    [slice(None), [6], [0, 2, 3], [1, 3, 4]],
    [slice(None), [2, 3, 5], [3], [4]],
    [slice(None), [0], [4], [1, 3, 4]],
    [slice(None), [6], [0, 2, 3], [1]],
    [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
    [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
    [[2, 0, 1], [1, 2, 3], [4], slice(None)],
    [[0, 1, 2], [4], [1, 3, 4], slice(None)],
    [[0], [0, 2, 3], [1, 3, 4], slice(None)],
    [[0, 2, 1], [3], [4], slice(None)],
    [[0], [4], [1, 3, 4], slice(None)],
    [[1], [0, 2, 3], [1], slice(None)],
    [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],

    # less dim, ellipsis
    [Ellipsis, [0, 3, 4]],
    [Ellipsis, slice(None), [0, 3, 4]],
    [Ellipsis, slice(None), slice(None), [0, 3, 4]],
    [slice(None), Ellipsis, [0, 3, 4]],
    [slice(None), slice(None), Ellipsis, [0, 3, 4]],
    [slice(None), [0, 2, 3], [1, 3, 4]],
    [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
    [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
    [[0], [1, 2, 4]],
    [[0], [1, 2, 4], slice(None)],
    [[0], [1, 2, 4], Ellipsis],
    [[0], [1, 2, 4], Ellipsis, slice(None)],
    [[1], ],
    [[0, 2, 1], [3], [4]],
    [[0, 2, 1], [3], [4], slice(None)],
    [[0, 2, 1], [3], [4], Ellipsis],
    [Ellipsis, [0, 2, 1], [3], [4]],
  ]

  for indexer in indices_to_test:
    assert_get_eq(reference, indexer)
    # TODO advanced setitem
    '''
    assert_set_eq(reference, indexer, 1333)
    assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
    '''

  indices_to_test += [
    [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
    [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
  ]

  for indexer in indices_to_test:
    assert_get_eq(reference, indexer)
    # TODO advanced setitem
    '''
    assert_set_eq(reference, indexer, 1333)
    '''
    assert_backward_eq(reference, indexer)
- # TODO setitem backward
- '''
- def test_set_item_to_scalar_tensor(self):
- m = random.randint(1, 10)
- n = random.randint(1, 10)
- z = Tensor.randn([m, n])
- a = 1.0
- w = Tensor(a, requires_grad=True)
- z[:, 0] = w
- z.sum().backward()
- numpy_testing_assert_equal_helper(w.grad, m * a)
- '''
def test_single_int(self):
  # a single integer index drops the leading dimension
  v = Tensor.randn(5, 7, 3)
  picked = v[4]
  numpy_testing_assert_equal_helper(picked.shape, (7, 3))
def test_multiple_int(self):
  # each integer index removes one dimension; the full slice in between keeps its axis
  check = numpy_testing_assert_equal_helper
  v = Tensor.randn(5, 7, 3)
  check(v[4].shape, (7, 3))
  check(v[4, :, 1].shape, (7,))
def test_none(self):
    """``None`` in an index inserts a length-1 axis at that position."""
    tensor = Tensor.randn(5, 7, 3)
    everything = slice(None)
    # (index, expected shape) pairs, checked in the listed order.
    cases = (
        (None, (1, 5, 7, 3)),
        ((everything, None), (5, 1, 7, 3)),
        ((everything, None, None), (5, 1, 1, 7, 3)),
        ((Ellipsis, None), (5, 7, 3, 1)),
    )
    for index, expected_shape in cases:
        numpy_testing_assert_equal_helper(tensor[index].shape, expected_shape)
def test_step(self):
    """Stepped slices of ``arange(10)`` pick the expected elements."""
    sequence = Tensor.arange(10)
    # A unit step must reproduce the whole tensor.
    numpy_testing_assert_equal_helper(sequence[::1], sequence)
    stepped_cases = (
        (slice(None, None, 2), [0, 2, 4, 6, 8]),
        (slice(None, None, 3), [0, 3, 6, 9]),
        (slice(None, None, 11), [0]),
        (slice(1, 6, 2), [1, 3, 5]),
    )
    for stepped, expected in stepped_cases:
        numpy_testing_assert_equal_helper(sequence[stepped], expected)
def test_step_assignment(self):
    """Assigning through a stepped slice writes only the addressed elements."""
    grid = Tensor.zeros(4, 4).contiguous()
    grid[0, 1::2] = Tensor([3., 4.])
    first_row = grid[0].numpy().tolist()
    numpy_testing_assert_equal_helper(first_row, [0, 3, 0, 4])
    # Every other row must still sum to zero.
    remaining_total = grid[1:].sum()
    numpy_testing_assert_equal_helper(remaining_total, 0)
@unittest.skip("bool indexing not supported")
def test_bool_indices(self):
    """Bool masks select rows, and bool/uint8 masks must index identically.

    The second half also counts the warnings recorded while indexing with
    the uint8 mask and expects exactly two.
    """
    v = Tensor.randn(5, 7, 3)
    boolIndices = Tensor([True, False, True, True, False], dtype=dtypes.bool)
    numpy_testing_assert_equal_helper(v[boolIndices].shape, (3, 7, 3))
    numpy_testing_assert_equal_helper(v[boolIndices], Tensor.stack([v[0], v[2], v[3]]))

    v = Tensor([True, False, True], dtype=dtypes.bool)
    boolIndices = Tensor([True, False, False], dtype=dtypes.bool)
    uint8Indices = Tensor([1, 0, 0], dtype=dtypes.uint8)
    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the global filter state, so the
        # count below does not depend on what ran earlier in the process.
        warnings.simplefilter("always")
        numpy_testing_assert_equal_helper(v[boolIndices].shape, v[uint8Indices].shape)
        numpy_testing_assert_equal_helper(v[boolIndices], v[uint8Indices])
        numpy_testing_assert_equal_helper(v[boolIndices], Tensor([True]))
        numpy_testing_assert_equal_helper(len(w), 2)
@unittest.skip("bool indexing not supported")
def test_bool_indices_accumulate(self):
    """Accumulating through a zero-filled bool mask must leave the target equal to ones."""
    empty_mask = Tensor.zeros(size=(10, ), dtype=dtypes.bool)
    target = Tensor.ones(size=(10, 10))
    selected = target[empty_mask]
    index_put_(target, (empty_mask, ), selected, accumulate=True)
    untouched = Tensor.ones(size=(10, 10))
    numpy_testing_assert_equal_helper(target, untouched)
@unittest.skip("bool indexing not supported")
def test_multiple_bool_indices(self):
    """Two bool masks separated by a full slice give a (3, 7) result."""
    tensor = Tensor.randn(5, 7, 3)
    # note: these broadcast together and are transposed to the first dim
    row_mask = Tensor([1, 0, 1, 1, 0], dtype=dtypes.bool)
    depth_mask = Tensor([1, 1, 1], dtype=dtypes.bool)
    picked = tensor[row_mask, :, depth_mask]
    numpy_testing_assert_equal_helper(picked.shape, (3, 7))
@unittest.skip("bool indexing not supported")
def test_byte_mask(self):
    """A uint8 mask selects rows, and exactly two warnings are expected while doing so.

    Also checks that an all-false comparison mask yields an empty tensor.
    """
    v = Tensor.randn(5, 7, 3)
    mask = Tensor([1, 0, 1, 1, 0], dtype=dtypes.uint8)
    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the global filter state, so the
        # count below does not depend on what ran earlier in the process.
        warnings.simplefilter("always")
        numpy_testing_assert_equal_helper(v[mask].shape, (3, 7, 3))
        numpy_testing_assert_equal_helper(v[mask], Tensor.stack([v[0], v[2], v[3]]))
        numpy_testing_assert_equal_helper(len(w), 2)

    v = Tensor([1.])
    numpy_testing_assert_equal_helper(v[v == 0], Tensor([]))
@unittest.skip("bool indexing not supported")
def test_byte_mask_accumulate(self):
    """Accumulating through a zero uint8 mask leaves the target equal to ones.

    Two warnings are expected to be recorded along the way.
    """
    zero_mask = Tensor.zeros(size=(10, ), dtype=dtypes.uint8)
    target = Tensor.ones(size=(10, 10))
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        selected = target[zero_mask]
        index_put_(target, (zero_mask, ), selected, accumulate=True)
        untouched = Tensor.ones(size=(10, 10))
        numpy_testing_assert_equal_helper(target, untouched)
        numpy_testing_assert_equal_helper(len(caught), 2)
- # TODO setitem
- # NOTE: tinygrad doesn't support idx.max that big
- '''
- def test_index_put_accumulate_large_tensor(self):
- # This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
- N = (1 << 31) + 5
- dt = dtypes.int8
- a = Tensor.ones(N, dtype=dt).contiguous()
- indices = Tensor([-2, 0, -2, -1, 0, -1, 1], dtype=dtypes.int64)
- values = Tensor([6, 5, 6, 6, 5, 7, 11], dtype=dt)
- index_put_(a, (indices, ), values, accumulate=True)
- numpy_testing_assert_equal_helper(a[0], 11)
- numpy_testing_assert_equal_helper(a[1], 12)
- numpy_testing_assert_equal_helper(a[2], 1)
- numpy_testing_assert_equal_helper(a[-3], 1)
- numpy_testing_assert_equal_helper(a[-2], 13)
- numpy_testing_assert_equal_helper(a[-1], 14)
- a = Tensor.ones((2, N), dtype=dt).contiguous()
- indices0 = np.array([0, -1, 0, 1], dtype=dtypes.int64)
- indices1 = np.array([-2, -1, 0, 1], dtype=dtypes.int64)
- values = np.array([12, 13, 10, 11], dtype=dt)
- index_put_(a, (indices0, indices1), values, accumulate=True)
- numpy_testing_assert_equal_helper(a[0, 0], 11)
- numpy_testing_assert_equal_helper(a[0, 1], 1)
- numpy_testing_assert_equal_helper(a[1, 0], 1)
- numpy_testing_assert_equal_helper(a[1, 1], 12)
- numpy_testing_assert_equal_helper(a[:, 2], Tensor.ones(2, dtype=dtypes.int8))
- numpy_testing_assert_equal_helper(a[:, -3], Tensor.ones(2, dtype=dtypes.int8))
- numpy_testing_assert_equal_helper(a[0, -2], 13)
- numpy_testing_assert_equal_helper(a[1, -2], 1)
- numpy_testing_assert_equal_helper(a[-1, -1], 14)
- numpy_testing_assert_equal_helper(a[0, -1], 1)
- '''
- # TODO fancy setitem
- '''
- def test_index_put_accumulate_duplicate_indices(self):
- for i in range(1, 512):
- # generate indices by random walk, this will create indices with
- # lots of duplicates interleaved with each other
- delta = Tensor.uniform(low=-1, high=1, dtype=dtypes.double)
- indices = delta.cumsum(0).cast(dtypes.int64)
- # input = torch.randn(indices.abs().max() + 1)
- input = Tensor.randn(indices.abs().max().item() + 1)
- # values = torch.randn(indices.size(0))
- values = Tensor.randn(indices.shape(0))
- output = index_put_(input, (indices,), values, accumulate=True)
- input_list = input.numpy().tolist()
- indices_list = indices.numpy().tolist()
- values_list = values.numpy().tolist()
- for i, v in zip(indices_list, values_list):
- input_list[i] += v
- numpy_testing_assert_equal_helper(output, input_list)
- '''
def test_index_ind_dtype(self):
    """int64 and int32 index tensors must select the same elements."""
    x = Tensor.randn(4, 4)
    # ind_long = torch.randint(4, (4,), dtype=torch.long)
    # TODO should we spend an extra line to allow for randint other dtypes?
    # copied from randint
    low, high = 0, 4
    ind_long = (Tensor.rand((4,),)*(high-low)+low).cast(dtypes.int64)
    # ind_int = ind_long.int()
    ind_int = ind_long.cast(dtypes.int32)

    everything = slice(None)
    # Each builder places the index tensor on both axes, the first, or the second.
    index_builders = (
        lambda ind: (ind, ind),
        lambda ind: (ind, everything),
        lambda ind: (everything, ind),
    )
    for build in index_builders:
        via_long = x[build(ind_long)]
        via_int = x[build(ind_int)]
        numpy_testing_assert_equal_helper(via_long, via_int)

    # no repeating indices for index_put
    # TODO fancy setitem
    '''
    src = Tensor.randn(4)
    ind_long = Tensor.arange(4, dtype=dtypes.int64)
    ind_int = ind_long.cast(dtypes.int32)
    for accum in (True, False):
        inp_ref = clone(x)
        inp_res = clone(x)
        index_put_(inp_ref, (ind_long, ind_long), src, accum)
        index_put_(inp_res, (ind_int, ind_int), src, accum)
        numpy_testing_assert_equal_helper(inp_ref, inp_res)
    '''
- # TODO empty setitem
- '''
- def test_index_put_accumulate_empty(self):
- # Regression test for https://github.com/pytorch/pytorch/issues/94667
- input = Tensor.rand([], dtype=dtypes.float32)
- with self.assertRaises(RuntimeError):
- index_put_(input, [], np.array([1.0]), True)
- '''
@unittest.skip("bool indexing not supported")
def test_multiple_byte_mask(self):
    """Two uint8 masks separated by a full slice give a (3, 7) result and two warnings."""
    tensor = Tensor.randn(5, 7, 3)
    # note: these broadcast together and are transposed to the first dim
    row_mask = Tensor([1, 0, 1, 1, 0], dtype=dtypes.uint8)
    depth_mask = Tensor([1, 1, 1], dtype=dtypes.uint8)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        picked = tensor[row_mask, :, depth_mask]
        numpy_testing_assert_equal_helper(picked.shape, (3, 7))
        numpy_testing_assert_equal_helper(len(caught), 2)
@unittest.skip("bool indexing not supported")
def test_byte_mask2d(self):
    """A 2-D comparison mask collapses the two leading axes into one."""
    values = Tensor.randn(5, 7, 3)
    selector = Tensor.randn(5, 7)
    # Number of positive entries in the selector, i.e. rows expected in the result.
    positive_count = (selector > 0).sum()
    picked = values[selector > 0]
    numpy_testing_assert_equal_helper(picked.shape, (positive_count, 3))
@unittest.skip("bool indexing not supported")
def test_jit_indexing(self):
    """Mask-based and slice-based in-place fills must agree when wrapped in TinyJit."""
    def fn1(x):
        # Fill through a comparison mask.
        x[x < 50] = 1.0
        return x

    def fn2(x):
        # Fill through a plain slice.
        x[0:50] = 1.0
        return x

    jitted_mask_fill = TinyJit(fn1)
    jitted_slice_fill = TinyJit(fn2)
    data = Tensor.arange(100, dtype=dtypes.float)

    masked_out = jitted_mask_fill(clone(data))
    # Fifty ones followed by 50..99.
    expected_values = np.concatenate((np.ones(50), np.arange(50, 100)))
    expected = Tensor(expected_values, dtype=dtypes.float)
    numpy_testing_assert_equal_helper(masked_out, expected)

    sliced_out = jitted_slice_fill(clone(data))
    numpy_testing_assert_equal_helper(sliced_out, expected)
def test_int_indices(self):
    """List indices replace the indexed axis with the shape of the list."""
    tensor = Tensor.randn(5, 7, 3)
    everything = slice(None)
    cases = (
        ([0, 4, 2], (3, 7, 3)),
        ((everything, [0, 4, 2]), (5, 3, 3)),
        ((everything, [[0, 1], [4, 3]]), (5, 2, 2, 3)),
    )
    for index, expected_shape in cases:
        numpy_testing_assert_equal_helper(tensor[index].shape, expected_shape)
- # TODO fancy setitem
- '''
- def test_index_put_src_datatype(self, dtype):
- src = Tensor.ones(3, 2, 4, dtype=dtype)
- vals = Tensor.ones(3, 2, 4, dtype=dtype)
- indices = (np.array([0, 2, 1]),)
- res = index_put_(src, indices, vals, accumulate=True)
- numpy_testing_assert_equal_helper(res.shape, src.shape)
- '''
def test_index_src_datatype(self):
    """Indexing the first axis with a permutation list keeps the source shape."""
    src = Tensor.ones(3, 2, 4)
    # test index
    first_axis_order = [0, 2, 1]
    res = src[first_axis_order, :, :]
    numpy_testing_assert_equal_helper(res.shape, src.shape)
    # test index_put, no accum
    # TODO fancy setitem
    '''
    src[[0, 2, 1], :, :] = res
    numpy_testing_assert_equal_helper(res.shape, src.shape)
    '''
def test_int_indices2d(self):
    """Paired 2-D row/column index tensors pick the four corners of a 4x3 grid."""
    # From the NumPy indexing example
    grid = Tensor.arange(0, 12).reshape(4, 3)
    row_idx = Tensor([[0, 0], [3, 3]])
    col_idx = Tensor([[0, 2], [0, 2]])
    corners = grid[row_idx, col_idx].numpy().tolist()
    numpy_testing_assert_equal_helper(corners, [[0, 2], [9, 11]])
def test_int_indices_broadcast(self):
    """A column-shaped row index broadcasts against a 1-D column index."""
    # From the NumPy indexing example
    grid = Tensor.arange(0, 12).reshape(4, 3)
    row_idx = Tensor([0, 3])
    col_idx = Tensor([0, 2])
    row_column = row_idx[:, None]
    corners = grid[row_column, col_idx]
    numpy_testing_assert_equal_helper(corners.numpy().tolist(), [[0, 2], [9, 11]])
- # TODO jax supports empty tensor indexing
@unittest.skip("empty tensor indexing not supported")
def test_empty_index(self):
    """Indexing with an empty int64 tensor yields a result with zero elements."""
    x = Tensor.arange(0, 12).reshape(4, 3)
    idx = Tensor([], dtype=dtypes.int64)
    selected = x[idx]
    numpy_testing_assert_equal_helper(selected.numel(), 0)

    # TODO empty setitem
    '''
    # empty assignment should have no effect but not throw an exception
    y = clone(x)
    y[idx] = -1
    numpy_testing_assert_equal_helper(x, y)
    mask = Tensor.zeros(4, 3).cast(dtypes.bool)
    y[mask] = -1
    numpy_testing_assert_equal_helper(x, y)
    '''
- # TODO jax supports empty tensor indexing
@unittest.skip("empty tensor indexing not supported")
def test_empty_ndim_index(self):
    """Multi-dimensional empty index tensors give correspondingly shaped empty results."""
    source = Tensor.randn(5)
    expected = Tensor.empty(0, 2)
    empty_index = Tensor.empty(0, 2, dtype=dtypes.int64)
    numpy_testing_assert_equal_helper(expected, source[empty_index])

    source = Tensor.randn(2, 3, 4, 5)
    expected = Tensor.empty(2, 0, 6, 4, 5)
    empty_index = Tensor.empty(0, 6, dtype=dtypes.int64)
    numpy_testing_assert_equal_helper(expected, source[:, empty_index])

    # Source with a zero-length trailing axis.
    source = Tensor.empty(10, 0)
    numpy_testing_assert_equal_helper(source[[1, 2]].shape, (2, 0))
    numpy_testing_assert_equal_helper(source[[], []].shape, (0,))
    # Indexing into the zero-length axis is expected to raise.
    with self.assertRaises(IndexError):
        source[:, [0, 1]]
def test_empty_slice(self):
    """An empty range slice (``1:1``) produces a zero-length axis."""
    full = Tensor.randn(2, 3, 4, 5)
    last_axis_picked = full[:, :, :, 1]
    empty_view = last_axis_picked[:, 1:1, :]
    numpy_testing_assert_equal_helper((2, 0, 4), empty_view.shape)
    # this isn't technically necessary, but matches NumPy stride calculations.
    # NOTE: this is empty and shouldn't have strides
    #numpy_testing_assert_equal_helper((60, 20, 5), empty_view.lazydata.st.real_strides())
    # NOTE tinygrad's int slicing implementation makes this not contiguous
    # self.assertTrue(empty_view.lazydata.st.contiguous)
@unittest.skip("bool indexing not supported")
def test_index_getitem_copy_bools_slices(self):
    """Truthy/falsy indices must not share ``data_ptr``; ``None`` and ``...`` must share it."""
    one_mask = Tensor(1, dtype=dtypes.uint8)
    zero_mask = Tensor(0, dtype=dtypes.uint8)
    for tensor in (Tensor.randn(2, 3), Tensor(3.)):
        # Python bools first, then their uint8 tensor counterparts.
        for truthy, falsy in ((True, False), (one_mask, zero_mask)):
            self.assertNotEqual(data_ptr(tensor), data_ptr(tensor[truthy]))
            numpy_testing_assert_equal_helper(Tensor.empty(0, *tensor.shape), tensor[falsy])
        for view_index in (None, Ellipsis):
            self.assertEqual(data_ptr(tensor), data_ptr(tensor[view_index]))
@unittest.skip("bool indexing not supported")
def test_index_setitem_bools_slices(self):
    """Assignment through bool, ``None`` and ``...`` indices on 2-D and scalar tensors.

    Truthy indices must write the value, falsy indices must leave the
    tensor untouched.
    """
    one_mask = Tensor(1, dtype=dtypes.uint8)
    zero_mask = Tensor(0, dtype=dtypes.uint8)
    for target in (Tensor.randn(2, 3), Tensor(3)):
        # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
        # (some of these ops already prefix a 1 to the size)
        minus_one = Tensor.ones_like(target) * -1
        lifted = minus_one.unsqueeze(0).unsqueeze(0)

        target[True] = lifted
        numpy_testing_assert_equal_helper(target, minus_one)
        target[False] = 5
        numpy_testing_assert_equal_helper(target, minus_one)

        target[one_mask] = lifted * 2
        numpy_testing_assert_equal_helper(target, minus_one * 2)
        target[zero_mask] = 5
        numpy_testing_assert_equal_helper(target, minus_one * 2)

        target[None] = lifted * 3
        numpy_testing_assert_equal_helper(target, minus_one * 3)
        target[...] = lifted * 4
        numpy_testing_assert_equal_helper(target, minus_one * 4)

        # Only the 0-d tensor is expected to reject a full-slice assignment.
        if target.dim() != 0:
            continue
        with self.assertRaises(IndexError):
            target[:] = lifted * 5
@unittest.skip("bool indexing not supported")
def test_index_scalar_with_bool_mask(self):
    """uint8 and bool scalar masks must give equal values and dtypes."""
    int_scalar = Tensor(1)
    uint_mask = Tensor(True, dtype=dtypes.uint8)
    bool_mask = Tensor(True, dtype=dtypes.bool)

    def check_masks_agree(scalar):
        # Compare the indexed values, then the resulting dtypes.
        numpy_testing_assert_equal_helper(scalar[uint_mask], scalar[bool_mask])
        numpy_testing_assert_equal_helper(scalar[uint_mask].dtype, scalar[bool_mask].dtype)

    check_masks_agree(int_scalar)
    check_masks_agree(Tensor(True, dtype=dtypes.bool))
@unittest.skip("bool indexing not supported")
def test_setitem_expansion_error(self):
    """Assigning a value with a non-1 leading prefix through a truthy index must raise."""
    true_tensor = Tensor(True)
    target = Tensor.randn(2, 3)
    # check prefix with non-1s doesn't work
    # a_expanded = a.expand(torch.Size([5, 1]) + a.size())
    oversized = target.expand((5, 1) + target.shape)
    # NumPy: ValueError
    for truthy in (True, true_tensor):
        with self.assertRaises(RuntimeError):
            target[truthy] = oversized
def test_getitem_scalars_simple(self):
    """A scalar pulled out of an elementwise product has the expected value."""
    src = Tensor([[[1.,2.],[3.,4.]], [[1,1],[1,1]]])
    lhs = src[0]
    rhs = src[1]
    product = lhs.mul(rhs)
    self.assertEqual(product[0,1].item(), 2)
def test_getitem_scalars(self):
    """0-d integer tensors must index like the Python ints they hold."""
    zero = Tensor(0, dtype=dtypes.int64)
    one = Tensor(1, dtype=dtypes.int64)

    # non-scalar indexed with scalars
    matrix = Tensor.randn(2, 3)
    numpy_testing_assert_equal_helper(matrix[0], matrix[zero])
    numpy_testing_assert_equal_helper(matrix[0][1], matrix[zero][one])
    numpy_testing_assert_equal_helper(matrix[0, 1], matrix[zero, one])
    numpy_testing_assert_equal_helper(matrix[0, one], matrix[zero, 1])

    # indexing by a scalar should slice (not copy)
    self.assertEqual(data_ptr(matrix[0, 1]), data_ptr(matrix[zero, one]))
    for narrow_dtype in (dtypes.int32, dtypes.int16):
        self.assertEqual(data_ptr(matrix[1]), data_ptr(matrix[one.cast(narrow_dtype)]))

    # scalar indexed with scalar
    scalar = Tensor.randn()
    for bad_index in (slice(None), zero):
        with self.assertRaises(IndexError):
            scalar[bad_index]
    numpy_testing_assert_equal_helper(scalar, scalar[...])
- # TODO fancy setitem
- '''
- def test_setitem_scalars(self):
- zero = Tensor(0, dtype=dtypes.int64)
- # non-scalar indexed with scalars
- a = Tensor.randn(2, 3).contiguous()
- a_set_with_number = clone(a).contiguous()
- a_set_with_scalar = clone(a).contiguous()
- b = Tensor.randn(3)
- a_set_with_number[0] = b
- a_set_with_scalar[zero] = b
- numpy_testing_assert_equal_helper(a_set_with_number, a_set_with_scalar)
- a[1, zero] = 7.7
- # TODO: weird inaccuracy Max relative difference: 2.47707621e-08
- # numpy_testing_assert_equal_helper(7.7, a[1, 0])
- np.testing.assert_allclose(7.7, a[1, 0].numpy(), rtol=1e-7)
- # scalar indexed with scalars
- r = Tensor.randn().contiguous()
- with self.assertRaises(IndexError):
- r[:] = 8.8
- with self.assertRaises(IndexError):
- r[zero] = 8.8
- r[...] = 9.9
- # TODO: weird inaccuracy Max relative difference: 3.85322971e-08
- # numpy_testing_assert_equal_helper(9.9, r)
- np.testing.assert_allclose(9.9, r, rtol=1e-7)
- '''
def test_basic_advanced_combined(self):
    """A slice and an equivalent index list on the same axis select the same data."""
    # From the NumPy indexing example
    x = Tensor.arange(0, 12).reshape(4, 3)
    by_slice = x[1:2, 1:3]
    by_list = x[1:2, [1, 2]]
    numpy_testing_assert_equal_helper(by_slice, by_list)
    numpy_testing_assert_equal_helper(x[1:2, 1:3].numpy().tolist(), [[4, 5]])

    # Check that it is a copy
    unmodified = clone(x)
    # The result of zeros_like() is discarded here.
    x[1:2, [1, 2]].zeros_like()
    numpy_testing_assert_equal_helper(x, unmodified)

    # But assignment should modify the original
    # TODO fancy setitem
    '''
    unmodified = clone(x)
    x[1:2, [1, 2]] = 0
    self.assertNotEqual(x, unmodified)
    '''
def test_int_assignment(self):
    """Assigning to an integer-indexed row accepts a Python scalar or a tensor."""
    # The value is built lazily so it is created after the target, as before.
    cases = (
        (lambda: 5, [[0, 1], [5, 5]]),
        (lambda: Tensor.arange(5, 7), [[0, 1], [5, 6]]),
    )
    for make_value, expected in cases:
        grid = Tensor.arange(0, 4).reshape(2, 2)
        grid[1] = make_value()
        numpy_testing_assert_equal_helper(grid.numpy().tolist(), expected)
- # TODO fancy setitem
- '''
- def test_byte_tensor_assignment(self):
- x = Tensor.arange(0., 16).reshape(4, 4)
- b = Tensor([True, False, True, False], dtype=dtypes.uint8)
- value = Tensor([3., 4., 5., 6.])
- with warnings.catch_warnings(record=True) as w:
- x[b] = value
- numpy_testing_assert_equal_helper(len(w), 1)
- numpy_testing_assert_equal_helper(x[0], value)
- numpy_testing_assert_equal_helper(x[1], Tensor.arange(4., 8))
- numpy_testing_assert_equal_helper(x[2], value)
- numpy_testing_assert_equal_helper(x[3], Tensor.arange(12., 16))
- '''
@unittest.skip("Tensor unpacking not supported")
def test_variable_slicing(self):
    """Slice bounds unpacked from a tensor behave like the ints they hold."""
    grid = Tensor.arange(0, 16).reshape(4, 4)
    bounds = Tensor([0, 1], dtype=dtypes.int32)
    start, stop = bounds
    numpy_testing_assert_equal_helper(grid[start:stop], grid[0:1])
def test_ellipsis_tensor(self):
    """An index tensor can be combined with ``...`` on either side."""
    grid = Tensor.arange(0, 9).reshape(3, 3)
    picks = Tensor([0, 2])

    picked_columns = grid[..., picks].numpy().tolist()
    numpy_testing_assert_equal_helper(picked_columns, [[0, 2], [3, 5], [6, 8]])

    picked_rows = grid[picks, ...].numpy().tolist()
    numpy_testing_assert_equal_helper(picked_rows, [[0, 1, 2], [6, 7, 8]])
- # TODO unravel_index
- '''
- def test_unravel_index_errors(self):
- with self.assertRaises(TypeError):
- unravel_index(
- Tensor(0.5),
- (2, 2))
- with self.assertRaises(TypeError):
- unravel_index(
- Tensor([]),
- (10, 3, 5))
- with self.assertRaises(TypeError):
- unravel_index(
- Tensor([1], dtype=dtypes.int64),
- Tensor([1, 2, 3]))
- with self.assertRaises(TypeError):
- unravel_index(
- Tensor([1], dtype=dtypes.int64),
- (1, 2, 2.0))
- with self.assertRaises(ValueError):
- unravel_index(
- Tensor(0),
- (2, -3))
- '''
def test_invalid_index(self):
    """String slice bounds are rejected with TypeError."""
    grid = Tensor.arange(0, 16).reshape(4, 4)
    with self.assertRaises(TypeError):
        grid["0":"1"]
def test_out_of_bound_index(self):
    """Integer indices past the end of an axis of a (2, 5, 10) tensor raise IndexError."""
    cube = Tensor.arange(0, 100).reshape(2, 5, 10)
    everything = slice(None)
    out_of_range = (
        (0, 5),
        (4, 5),
        (0, 1, 15),
        (everything, everything, 12),
    )
    for index in out_of_range:
        with self.assertRaises(IndexError):
            cube[index]
def test_zero_dim_index(self):
    """A tensor built from a bare scalar equals its ``item()`` and rejects integer indexing."""
    x = Tensor(10)
    numpy_testing_assert_equal_helper(x, x.item())
    # Indexing must raise IndexError. The previous nested runner also printed
    # x[0], which only added noise to test output; the lambda form matches the
    # other out-of-range checks in this file.
    self.assertRaises(IndexError, lambda: x[0])
- # TODO fancy setitem
- '''
- def test_cpu_indices(self):
- idx = Tensor([0, 1])
- b = Tensor.zeros(2)
- x = Tensor.ones(10).contiguous()
- x[idx] = b # index_put_
- ref = Tensor.ones(10).contiguous()
- ref[:2] = 0
- numpy_testing_assert_equal_helper(x, ref)
- out = x[idx] # index
- numpy_testing_assert_equal_helper(out, Tensor.zeros(2))
- '''
def test_take_along_dim(self):
    """``Tensor.gather`` must agree with ``np.take_along_axis``."""
    def _test_against_numpy(t: Tensor, indices: Tensor, dim):
        # Gather on the tensor side, take_along_axis on the NumPy side.
        actual = t.gather(dim, indices)
        expected = np.take_along_axis(t.numpy(), indices.numpy(), axis=dim)
        numpy_testing_assert_equal_helper(actual, expected)

    # TODO argsort
    '''
    for shape in [(3, 2), (2, 3, 5), (2, 4, 0), (2, 3, 1, 4)]:
        for noncontiguous in [True, False]:
            for dtype in (dtypes.float32, dtypes.int64):
                t = make_tensor(shape, dtype=dtype, noncontiguous=noncontiguous)
                for dim in list(range(t.ndim)) + [None]:
                    if dim is None:
                        indices = argsort(t.reshape(-1))
                    else:
                        indices = argsort(t, dim=dim)
                    _test_against_numpy(t, indices, dim)
    '''

    # (source shape, index shape): test broadcasting, then test empty indices
    shape_pairs = (
        ((3, 4, 1), (1, 2, 5)),
        ((3, 4, 5), (3, 0, 5)),
    )
    for source_shape, index_shape in shape_pairs:
        t = Tensor.ones(source_shape)
        indices = Tensor.ones(index_shape, dtype=dtypes.int64)
        _test_against_numpy(t, indices, 1)
- # TODO argsort
- '''
- def test_take_along_dim_invalid(self):
- for dtype in (dtypes.int64, dtypes.float32):
- shape = (2, 3, 1, 4)
- dim = 0
- t = make_tensor(shape, dtype=dtype)
- indices = argsort(t, dim=dim)
- # dim of `t` and `indices` does not match
- with self.assertRaises(RuntimeError, "input and indices should have the same number of dimensions"):
- t.gather(0, indices[0])
- # invalid `indices` dtype
- with self.assertRaises(RuntimeError):
- t.gather(0, indices.cast(dtypes.bool))
- with self.assertRaises(RuntimeError):
- t.gather(0, indices.cast(dtypes.float32))
- with self.assertRaises(RuntimeError):
- t.gather(0, indices.cast(dtypes.int32))
- # invalid axis
- with self.assertRaises(IndexError):
- t.gather(-7, indices)
- with self.assertRaises(IndexError):
- t.gather(7, indices)
- '''
class TestNumpy(unittest.TestCase):
    """Indexing checks modelled on NumPy's indexing behaviour.

    Cases that need bool or empty-tensor indexing are skipped, and cases
    that need fancy ``__setitem__`` or out-of-bounds detection are kept as
    disabled string blocks with a TODO above them.
    """

    def test_index_no_floats(self):
        """Float indices raise IndexError at every position and inside slices."""
        a = Tensor([[[5.]]])

        self.assertRaises(IndexError, lambda: a[0.0])
        self.assertRaises(IndexError, lambda: a[0, 0.0])
        self.assertRaises(IndexError, lambda: a[0.0, 0])
        self.assertRaises(IndexError, lambda: a[0.0, :])
        self.assertRaises(IndexError, lambda: a[:, 0.0])
        self.assertRaises(IndexError, lambda: a[:, 0.0, :])
        self.assertRaises(IndexError, lambda: a[0.0, :, :])
        self.assertRaises(IndexError, lambda: a[0, 0, 0.0])
        self.assertRaises(IndexError, lambda: a[0.0, 0, 0])
        self.assertRaises(IndexError, lambda: a[0, 0.0, 0])
        self.assertRaises(IndexError, lambda: a[-1.4])
        self.assertRaises(IndexError, lambda: a[0, -1.4])
        self.assertRaises(IndexError, lambda: a[-1.4, 0])
        self.assertRaises(IndexError, lambda: a[-1.4, :])
        self.assertRaises(IndexError, lambda: a[:, -1.4])
        self.assertRaises(IndexError, lambda: a[:, -1.4, :])
        self.assertRaises(IndexError, lambda: a[-1.4, :, :])
        self.assertRaises(IndexError, lambda: a[0, 0, -1.4])
        self.assertRaises(IndexError, lambda: a[-1.4, 0, 0])
        self.assertRaises(IndexError, lambda: a[0, -1.4, 0])
        self.assertRaises(IndexError, lambda: a[0.0:, 0.0])
        self.assertRaises(IndexError, lambda: a[0.0:, 0.0,:])

    def test_none_index(self):
        """Indexing with ``None`` increases ``ndim`` by one."""
        # `None` index adds newaxis
        a = Tensor([1, 2, 3])
        numpy_testing_assert_equal_helper(a[None].ndim, a.ndim+1)

    def test_empty_tuple_index(self):
        """An empty tuple index returns equal data that shares ``data_ptr``."""
        # Empty tuple index creates a view
        a = Tensor([1, 2, 3])
        numpy_testing_assert_equal_helper(a[()], a)
        self.assertEqual(data_ptr(a[()]), data_ptr(a))

    # TODO jax supports empty tensor indexing
    @unittest.skip("empty tensor indexing not supported")
    def test_empty_fancy_index(self):
        """Empty list / empty int tensor indices give an empty result; float ones raise."""
        # Empty list index creates an empty array
        a = Tensor([1, 2, 3])
        numpy_testing_assert_equal_helper(a[[]], np.array([]))

        # An empty int64 index tensor must behave like the empty list above.
        # (Previously `b` was built but the assertion re-tested `a[[]]`.)
        b = Tensor([]).cast(dtypes.int64)
        numpy_testing_assert_equal_helper(a[b], np.array([]))

        b = Tensor([]).float()
        self.assertRaises(IndexError, lambda: a[b])

    def test_ellipsis_index(self):
        """``...`` returns a new object with equal data and fills skipped axes."""
        a = Tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])
        self.assertIsNot(a[...], a)
        numpy_testing_assert_equal_helper(a[...], a)
        # `a[...]` was `a` in numpy <1.9.
        numpy_testing_assert_equal_helper(data_ptr(a[...]), data_ptr(a))

        # Slicing with ellipsis can skip an
        # arbitrary number of dimensions
        numpy_testing_assert_equal_helper(a[0, ...], a[0])
        numpy_testing_assert_equal_helper(a[0, ...], a[0, :])
        numpy_testing_assert_equal_helper(a[..., 0], a[:, 0])

        # In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch
        # we don't have separate 0-dim arrays and scalars.
        numpy_testing_assert_equal_helper(a[0, ..., 1], np.array(2))

        # Assignment with `(Ellipsis,)` on 0-d arrays
        # NOTE: this part exercises a NumPy array, not a Tensor.
        b = np.array(1)
        b[(Ellipsis,)] = 2
        numpy_testing_assert_equal_helper(b, 2)

    def test_single_int_index(self):
        """One integer selects a row; negative wraps; huge values raise IndexError."""
        # Single integer index selects one row
        a = Tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])

        numpy_testing_assert_equal_helper(a[0], [1, 2, 3])
        numpy_testing_assert_equal_helper(a[-1], [7, 8, 9])

        self.assertRaises(IndexError, a.__getitem__, 1 << 30)
        self.assertRaises(IndexError, a.__getitem__, 1 << 64)

    @unittest.skip("bool indexing not supported")
    def test_single_bool_index(self):
        """``True`` acts like ``None``; ``False`` gives the empty slice of that."""
        # Single boolean index
        a = Tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])

        numpy_testing_assert_equal_helper(a[True], a[None])
        numpy_testing_assert_equal_helper(a[False], a[None][0:0])

    @unittest.skip("bool indexing not supported")
    def test_boolean_shape_mismatch(self):
        """Masks whose shape does not match the indexed axes raise IndexError."""
        arr = Tensor.ones((5, 4, 3))

        index = Tensor([True])
        self.assertRaises(IndexError, lambda: arr[index])

        index = Tensor([False] * 6)
        self.assertRaises(IndexError, lambda: arr[index])

        index = Tensor.zeros(4, 4, dtype=dtypes.uint8)
        self.assertRaises(IndexError, lambda: arr[index])
        self.assertRaises(IndexError, lambda: arr[(slice(None), index)])

    @unittest.skip("bool indexing not supported")
    def test_boolean_indexing_onedim(self):
        """A length-1 bool mask selects and assigns the only row of a (1, 3) tensor."""
        # Indexing a 2-dimensional array with
        # boolean array of length one
        a = Tensor([[0., 0., 0.]])
        b = Tensor([True])
        numpy_testing_assert_equal_helper(a[b], a)
        # boolean assignment
        a[b] = 1.
        numpy_testing_assert_equal_helper(a, Tensor([[1., 1., 1.]]))

    @unittest.skip("bool indexing not supported")
    def test_boolean_assignment_value_mismatch(self):
        """Bool-mask assignment fails when the value cannot be broadcast."""
        # A boolean assignment should fail when the shape of the values
        # cannot be broadcast to the subscription. (see also gh-3458)
        a = Tensor.arange(0, 4)

        def f(a, v):
            # Assign `v` to every element selected by an all-true mask.
            a[a > -1] = Tensor(v)

        self.assertRaises(Exception, f, a, [])
        self.assertRaises(Exception, f, a, [1, 2, 3])
        self.assertRaises(Exception, f, a[:1], [1, 2, 3])

    @unittest.skip("bool indexing not supported")
    def test_boolean_indexing_twodim(self):
        """2-D and row-wise bool masks select and assign on a 3x3 tensor."""
        # Indexing a 2-dimensional array with
        # 2-dimensional boolean array
        a = Tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])
        b = Tensor([[True, False, True],
                    [False, True, False],
                    [True, False, True]])
        numpy_testing_assert_equal_helper(a[b], Tensor([1, 3, 5, 7, 9]))
        numpy_testing_assert_equal_helper(a[b[1]], Tensor([[4, 5, 6]]))
        numpy_testing_assert_equal_helper(a[b[0]], a[b[2]])

        # boolean assignment
        a[b] = 0
        numpy_testing_assert_equal_helper(a, Tensor([[0, 2, 0],
                                                     [4, 0, 6],
                                                     [0, 8, 0]]))

    @unittest.skip("bool indexing not supported")
    def test_boolean_indexing_weirdness(self):
        """Python bool scalars mixed with lists and ``...`` in one index."""
        # Weird boolean indexing things
        a = Tensor.ones((2, 3, 4))
        numpy_testing_assert_equal_helper((0, 2, 3, 4), a[False, True, ...].shape)
        numpy_testing_assert_equal_helper(Tensor.ones(1, 2), a[True, [0, 1], True, True, [1], [[2]]])
        self.assertRaises(IndexError, lambda: a[False, [0, 1], ...])

    @unittest.skip("bool indexing not supported")
    def test_boolean_indexing_weirdness_tensors(self):
        """Same mixed-index cases, using 0-d bool tensors for the later checks."""
        # Weird boolean indexing things
        false = Tensor(False)
        true = Tensor(True)
        a = Tensor.ones((2, 3, 4))
        numpy_testing_assert_equal_helper((0, 2, 3, 4), a[False, True, ...].shape)
        numpy_testing_assert_equal_helper(Tensor.ones(1, 2), a[true, [0, 1], true, true, [1], [[2]]])
        self.assertRaises(IndexError, lambda: a[false, [0, 1], ...])

    @unittest.skip("bool indexing not supported")
    def test_boolean_indexing_alldims(self):
        """Two truthy indices on a (2, 3) tensor give shape (1, 2, 3)."""
        true = Tensor(True)
        a = Tensor.ones((2, 3))
        numpy_testing_assert_equal_helper((1, 2, 3), a[True, True].shape)
        numpy_testing_assert_equal_helper((1, 2, 3), a[true, true].shape)

    @unittest.skip("bool indexing not supported")
    def test_boolean_list_indexing(self):
        """Python lists of bools act as masks, alone or paired."""
        # Indexing a 2-dimensional array with
        # boolean lists
        a = Tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])
        b = [True, False, False]
        c = [True, True, False]
        numpy_testing_assert_equal_helper(a[b], Tensor([[1, 2, 3]]))
        numpy_testing_assert_equal_helper(a[b, b], Tensor([1]))
        numpy_testing_assert_equal_helper(a[c], Tensor([[1, 2, 3], [4, 5, 6]]))
        numpy_testing_assert_equal_helper(a[c, c], Tensor([1, 5]))

    def test_everything_returns_views(self):
        """``()``, ``...`` and ``:`` each return an object distinct from the source."""
        # Before `...` would return a itself.
        a = Tensor([5])

        self.assertIsNot(a, a[()])
        self.assertIsNot(a, a[...])
        self.assertIsNot(a, a[:])

    def test_broaderrors_indexing(self):
        """Index lists of lengths 2 and 3 cannot be broadcast together and raise IndexError."""
        a = Tensor.zeros(5, 5)
        self.assertRaises(IndexError, a.__getitem__, ([0, 1], [0, 1, 2]))
        # TODO: fancy setitem
        '''
        self.assertRaises(IndexError, a.contiguous().__setitem__, ([0, 1], [0, 1, 2]), 0)
        '''

    # TODO out of bound getitem does not raise error
    '''
    def test_trivial_fancy_out_of_bounds(self):
        a = Tensor.zeros(5)
        ind = Tensor.ones(20, dtype=dtypes.int64)
        ind[-1] = 10
        self.assertRaises(IndexError, a.__getitem__, ind)
        self.assertRaises(IndexError, a.__setitem__, ind, 0)
        ind = Tensor.ones(20, dtype=dtypes.int64)
        ind[0] = 11
        self.assertRaises(IndexError, a.__getitem__, ind)
        self.assertRaises(IndexError, a.__setitem__, ind, 0)
    '''

    # TODO fancy setitem
    '''
    def test_index_is_larger(self):
        # Simple case of fancy index broadcasting of the index.
        a = Tensor.zeros((5, 5))
        a[[[0], [1], [2]], [0, 1, 2]] = Tensor([2., 3., 4.])
        self.assertTrue((a[:3, :3] == all_(Tensor([2., 3., 4.]))))
    '''

    # TODO fancy setitem
    '''
    def test_broadcast_subspace(self):
        a = Tensor.zeros((100, 100))
        v = Tensor.arange(0., 100)[:, None]
        b = Tensor.arange(99, -1, -1).cast(dtypes.int64)
        a[b] = v
        expected = b.float().unsqueeze(1).expand(100, 100)
        numpy_testing_assert_equal_helper(a, expected)
    '''

    # TODO fancy setitem
    '''
    def test_truncate_leading_1s(self):
        col_max = Tensor.randn(1, 4)
        kernel = col_max.T * col_max    # [4, 4] tensor
        kernel2 = clone(kernel)
        # Set the diagonal
        # len(torch.tensor) is just tensor.shape[0]
        kernel[range(kernel.shape[0]), range(kernel.shape[0])] = col_max.square()
        kernel2 = diagonal(kernel2)
        # torch.diagonal(kernel2).copy_(torch.square(col_max.view(4)))
        kernel2 = copy_(kernel2, col_max.reshape(4).square())
        numpy_testing_assert_equal_helper(kernel, kernel2)
    '''
# Run the test cases through unittest's command-line runner when this file
# is executed directly rather than imported.
if __name__ == '__main__':
    unittest.main()
|