# test_fuzz_shape_ops.py
  1. from __future__ import annotations
  2. import unittest
  3. from math import prod
  4. from hypothesis import assume, given, settings, strategies as st
  5. from hypothesis.extra import numpy as stn
  6. import numpy as np
  7. import torch
  8. from tinygrad import Tensor, Device
  9. from tinygrad.helpers import CI, getenv
# Hypothesis profile for this file: fewer examples on CI to bound runtime,
# no deadline (device/compile timing varies), and optional derandomization
# via the DERANDOMIZE_CI env var for reproducible CI runs.
settings.register_profile(__file__, settings.default,
  max_examples=100 if CI else 250, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
# torch wraparound for large numbers
st_int32 = st.integers(-2147483648, 2147483647)
  14. @st.composite
  15. def st_shape(draw) -> tuple[int, ...]:
  16. s = draw(stn.array_shapes(min_dims=0, max_dims=6,
  17. min_side=0, max_side=512))
  18. assume(prod(s) <= 1024 ** 2)
  19. assume(prod([d for d in s if d]) <= 1024 ** 4)
  20. return s
  21. def tensors_for_shape(s:tuple[int, ...]) -> tuple[torch.tensor, Tensor]:
  22. x = np.arange(prod(s)).reshape(s)
  23. return torch.from_numpy(x), Tensor(x)
  24. def apply(tor, ten, tor_fn, ten_fn=None):
  25. ok = True
  26. try: tor = tor_fn(tor)
  27. except: tor, ok = None, not ok # noqa: E722
  28. try: ten = ten_fn(ten) if ten_fn is not None else tor_fn(ten)
  29. except: ten, ok = None, not ok # noqa: E722
  30. return tor, ten, ok
  31. @unittest.skipIf(CI and Device.DEFAULT == "CLANG", "slow")
  32. class TestShapeOps(unittest.TestCase):
  33. @settings.get_profile(__file__)
  34. @given(st_shape(), st_int32, st.one_of(st_int32, st.lists(st_int32)))
  35. def test_split(self, s:tuple[int, ...], dim:int, sizes:int|list[int]):
  36. tor, ten = tensors_for_shape(s)
  37. tor, ten, ok = apply(tor, ten, lambda t: t.split(sizes, dim))
  38. assert ok
  39. if tor is None and ten is None: return
  40. assert len(tor) == len(ten)
  41. assert all([np.array_equal(tor.numpy(), ten.numpy()) for (tor, ten) in zip(tor, ten)])
  42. @settings.get_profile(__file__)
  43. @given(st_shape(), st_int32, st_int32)
  44. def test_chunk(self, s:tuple[int, ...], dim:int, num:int):
  45. # chunking on a 0 dim is cloning and leads to OOM if done unbounded.
  46. assume((0 <= (actual_dim := len(s)-dim if dim < 0 else dim) < len(s) and s[actual_dim] > 0) or
  47. (num < 32))
  48. tor, ten = tensors_for_shape(s)
  49. tor, ten, ok = apply(tor, ten, lambda t: t.chunk(num, dim))
  50. assert ok
  51. if tor is None and ten is None: return
  52. assert len(tor) == len(ten)
  53. assert all([np.array_equal(tor.numpy(), ten.numpy()) for (tor, ten) in zip(tor, ten)])
  54. @settings.get_profile(__file__)
  55. @given(st_shape(), st_int32)
  56. def test_squeeze(self, s:tuple[int, ...], dim:int):
  57. tor, ten = tensors_for_shape(s)
  58. tor, ten, ok = apply(tor, ten, lambda t: t.squeeze(dim))
  59. assert ok
  60. if tor is None and ten is None: return
  61. assert np.array_equal(tor.numpy(), ten.numpy())
  62. @settings.get_profile(__file__)
  63. @given(st_shape(), st_int32)
  64. def test_unsqueeze(self, s:tuple[int, ...], dim:int):
  65. tor, ten = tensors_for_shape(s)
  66. tor, ten, ok = apply(tor, ten, lambda t: t.unsqueeze(dim))
  67. assert ok
  68. if tor is None and ten is None: return
  69. assert np.array_equal(tor.numpy(), ten.numpy())
# allow running this fuzz suite directly: python test_fuzz_shape_ops.py
if __name__ == '__main__':
  unittest.main()