test_beam_search.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import unittest
  2. import numpy as np
  3. from tinygrad.helpers import BEAM, Timing, CI
  4. from tinygrad.shape.symbolic import Variable
  5. from tinygrad.tensor import Tensor
  6. from tinygrad.nn import Conv2d
  def rand(*shape):
    """Build a float32 Tensor of the given shape from NumPy's global uniform RNG.

    Values come from ``np.random.rand`` and so lie in [0, 1); they are cast to
    float32 before being wrapped in a Tensor.
    """
    samples = np.random.rand(*shape)
    return Tensor(samples.astype(np.float32))
  class TestBeamSearch(unittest.TestCase):
    """Tests that realize kernels while BEAM search is switched on (BEAM.value = 2).

    Several cases use 367, which is prime; presumably this is so that no tiling
    divides the dimensions evenly and any padding introduced by the search must
    not change the numeric result.
    """

    def setUp(self):
      # Save the caller's BEAM setting so tearDown can put it back.
      self.old_beam = BEAM.value
      BEAM.value = 2

    def tearDown(self):
      BEAM.value = self.old_beam

    def _assert_close(self, actual, expected):
      # Shared float32 tolerance for the numeric comparisons in this class.
      np.testing.assert_allclose(actual, expected, atol=1e-4, rtol=1e-4)

    def test_variable_ast_beam(self):
      # Smoke test: a reshape with a bound symbolic dimension must realize under BEAM.
      base = rand(3, 3)
      rows = Variable("a", 1, 10).bind(3)
      (base.reshape((rows, 3)) + 1).realize()

    def test_big_prime_number(self):
      lhs, rhs = rand(367, 367), rand(367, 367)
      product = (lhs @ rhs).realize()
      self._assert_close(product.numpy(), lhs.numpy() @ rhs.numpy())

    def test_big_prime_number_max(self):
      # rand() lies in [0, 1), so every product below is <= 0; a reduction that
      # wrongly pads with 0 would report 0 instead of the true negative maximum.
      neg, pos = -rand(367, 367), rand(367, 367)
      row_max = (neg * pos).max(1)
      self._assert_close(row_max.numpy(), (neg.numpy() * pos.numpy()).max(1))

    def test_big_prime_number_sum(self):
      # Padding the divisor with 0 would inject inf into the row sums.
      num, den = rand(367, 367), rand(367, 367)
      row_sum = (num / den).sum(1).realize()
      self._assert_close(row_sum.numpy(), (num.numpy() / den.numpy()).sum(1))

    def test_variable_big_prime_number(self):
      # The shared inner dimension is a symbolic Variable bound to 367.
      inner = Variable("v", 1, 400).bind(367)
      lhs, rhs = rand(367, 367), rand(367, 367)
      product = (lhs.reshape(367, inner) @ rhs.reshape(inner, 367)).realize()
      self._assert_close(product.numpy(), lhs.numpy() @ rhs.numpy())

    def test_variable_shrink_prime_number(self):
      # Take the first 367 rows through a symbolic shrink, then add 1.
      length = Variable("v", 1, 400).bind(367)
      src = rand(400, 367)
      shifted = src.shrink(((0, length), None)) + 1
      result = shifted.reshape(367, 367).realize()
      self._assert_close(result.numpy(), src.numpy()[:367] + 1)

    def test_no_mutate_rawbuffers(self):
      # After assign(t + 1) under BEAM the tensor must hold exactly the old value + 1
      # (the test name suggests the search must not mutate the real raw buffers).
      t = rand(3, 3).realize()
      expected = t.numpy() + 1
      t.assign(t + 1)
      np.testing.assert_allclose(t.numpy(), expected)

    @unittest.skipIf(CI, "flaky. CL_OUT_OF_RESOURCES")
    def test_conv_beam(self):
      # Smoke test: time a single Conv2d forward pass realized under BEAM.
      conv = Conv2d(3, 16, (3, 3))
      img = rand(1, 3, 32, 32)
      with Timing():
        conv(img).realize()

    def test_large_ast(self):
      # Chain 5 * 4 = 20 rounds of (t + t) * t and realize only once at the end.
      t = Tensor.rand(3, 3)
      for _ in range(5 * 4):
        t = (t + t) * t
      t.realize()
  if __name__ == '__main__':
    # Run this module's tests when the file is executed directly.
    unittest.main()