test_subbuffer.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import unittest
  2. from tinygrad import Device, dtypes, Tensor
  3. from tinygrad.helpers import getenv
  4. from tinygrad.device import Buffer
  5. from tinygrad.lazy import view_supported_devices
  6. @unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
  7. class TestSubBuffer(unittest.TestCase):
  8. def setUp(self):
  9. self.buf = Buffer(Device.DEFAULT, 10, dtypes.uint8).ensure_allocated()
  10. self.buf.copyin(memoryview(bytearray(range(10))))
  11. def test_subbuffer(self):
  12. vbuf = self.buf.view(2, dtypes.uint8, offset=3).ensure_allocated()
  13. tst = vbuf.as_buffer().tolist()
  14. assert tst == [3, 4]
  15. def test_subbuffer_cast(self):
  16. # NOTE: bitcast depends on endianness
  17. vbuf = self.buf.view(2, dtypes.uint16, offset=3).ensure_allocated()
  18. tst = vbuf.as_buffer().cast("H").tolist()
  19. assert tst == [3|(4<<8), 5|(6<<8)]
  20. def test_subbuffer_double(self):
  21. vbuf = self.buf.view(4, dtypes.uint8, offset=3).ensure_allocated()
  22. vvbuf = vbuf.view(2, dtypes.uint8, offset=1).ensure_allocated()
  23. tst = vvbuf.as_buffer().tolist()
  24. assert tst == [4, 5]
  25. def test_subbuffer_len(self):
  26. vbuf = self.buf.view(5, dtypes.uint8, 2).ensure_allocated()
  27. mv = vbuf.as_buffer()
  28. assert len(mv) == 5
  29. mv = vbuf.as_buffer(allow_zero_copy=True)
  30. assert len(mv) == 5
  31. def test_subbuffer_used(self):
  32. t = Tensor.arange(0, 10, dtype=dtypes.uint8).realize()
  33. # TODO: why does it needs contiguous
  34. vt = t[2:4].contiguous().realize()
  35. out = (vt + 100).tolist()
  36. assert out == [102, 103]
  37. @unittest.skipIf(Device.DEFAULT not in {"CUDA", "NV", "AMD"} or getenv("CUDACPU"), "only NV, AMD, CUDA but not CUDACPU")
  38. def test_subbuffer_transfer(self):
  39. t = Tensor.arange(0, 10, dtype=dtypes.uint8).realize()
  40. vt = t[2:5].contiguous().realize()
  41. out = vt.to(f"{Device.DEFAULT}:1").realize().tolist()
  42. assert out == [2, 3, 4]
  43. if __name__ == '__main__':
  44. unittest.main()