# test_arange.py — tinygrad tests for arange flop complexity and fancy-indexing scheduling
  1. import unittest
  2. import numpy as np
  3. from tinygrad import Tensor, GlobalCounters, dtypes
  4. from tinygrad.helpers import Context, getenv
  5. from tinygrad.engine.realize import run_schedule
  6. class TestArange(unittest.TestCase):
  7. def _get_flops(self, N):
  8. GlobalCounters.reset()
  9. with Context(NOOPT=1):
  10. Tensor.arange(N).realize()
  11. return GlobalCounters.global_ops
  12. def test_complexity(self):
  13. # add 1 to avoid divide by 0. arange is 0 flops now!
  14. f1 = self._get_flops(256) + 1
  15. f2 = self._get_flops(2560) + 1
  16. print(f"{f1=}, {f2=}")
  17. assert f2 / f1 < 15, f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
  18. class TestIndexing(unittest.TestCase):
  19. def test_arange_2_reduce(self):
  20. needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
  21. needle[1337] = 1
  22. needle.realize()
  23. with Context(NOOPT=1):
  24. GlobalCounters.reset()
  25. # TODO: it should work without these reshapes
  26. out = ((Tensor.arange(1,16385).reshape(16384,1)-1)*needle.reshape(16384,1)).sum()
  27. sched = out.schedule()
  28. assert len(sched) == 1
  29. run_schedule(sched)
  30. assert out.item() == 1337, f"expected 1337, got {out.item()}"
  31. @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
  32. def test_manual_index(self):
  33. dataset = Tensor.rand(16384, 256).realize()
  34. idxs = Tensor([0,3,5,6]).realize()
  35. real_index = dataset.numpy()[idxs.numpy()]
  36. print("*** indexing ***")
  37. with Context(NOOPT=1):
  38. GlobalCounters.reset()
  39. rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumsum(axis=-1, _first_zero=True).reshape(4, 256, 16384, 1)
  40. idxs = idxs.reshape(4,1,1,1).expand(4, 256, 16384, 1)
  41. reshape_dataset = dataset.T.reshape(1, 256, 16384, 1).expand(4, 256, 16384, 1)
  42. full = (rng==idxs).where(reshape_dataset, Tensor.zeros(4, 256, 16384, 1))
  43. X = full.sum(axis=(2,3))
  44. sched = X.schedule()
  45. assert len(sched) == 1
  46. run_schedule(sched)
  47. assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
  48. np.testing.assert_allclose(real_index, X.numpy())
  49. def test_index(self):
  50. dataset = Tensor.rand(16384, 256).realize()
  51. idxs = Tensor([0,3,5,6]).realize()
  52. real_index = dataset.numpy()[idxs.numpy()]
  53. print("*** indexing ***")
  54. with Context(NOOPT=1):
  55. GlobalCounters.reset()
  56. X = dataset[idxs]
  57. assert X.shape == (4,256)
  58. sched = X.schedule()
  59. # TODO: enable these asserts when the scheduler can handle this
  60. #assert len(sched) == 1, f"{len(sched)} != 1"
  61. run_schedule(sched)
  62. #assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
  63. np.testing.assert_allclose(real_index, X.numpy())
  64. # TODO: AssertionError: ReduceOps late fusion must be contiguous
  65. @unittest.expectedFailure
  66. def test_index_fused(self):
  67. dataset = Tensor.rand(16384, 256).realize()
  68. idxs = Tensor([0,3,5,6]).realize()
  69. real_index = dataset.numpy()[idxs.numpy()]
  70. print("*** indexing ***")
  71. with Context(NOOPT=1):
  72. GlobalCounters.reset()
  73. X = dataset[idxs]
  74. assert X.shape == (4,256)
  75. sched = X.schedule()
  76. assert len(sched) == 1, f"{len(sched)} != 1"
  77. run_schedule(sched)
  78. assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
  79. np.testing.assert_allclose(real_index, X.numpy())
# Allow running this file directly: discover and run all TestCase classes above.
if __name__ == "__main__":
  unittest.main()