test_sample.py 658 B

1234567891011121314151617181920
  1. import unittest
  2. import numpy as np
  3. from tinygrad.tensor import Tensor
  4. from tinygrad.shape.symbolic import Variable
  5. class TestSample(unittest.TestCase):
  6. def test_sample(self):
  7. X = Tensor.rand(10000, 50).realize()
  8. BS = 16
  9. idxs = np.random.randint(0, X.shape[0], size=(BS))
  10. # this uncovered a bug with arg sort order
  11. batch = [Variable(f'idx{i}', 0, X.shape[0]-1).bind(s) for i,s in enumerate(idxs.tolist())]
  12. x = Tensor.cat(*[X.shrink(((batch[i], batch[i]+1), None)) for i in range(BS)])
  13. print(idxs)
  14. ret = x.numpy()
  15. base = X.numpy()[idxs]
  16. np.testing.assert_equal(ret, base)
  17. if __name__ == '__main__':
  18. unittest.main()