# test_transcendental.py
  1. import unittest
  2. from tinygrad import Tensor, Device, dtypes
  3. from tinygrad.tensor import _to_np_dtype
  4. from tinygrad.helpers import Context, getenv
  5. from test.test_schedule import check_schedule
  6. from test.test_dtype_alu import ht
  7. from test.helpers import is_dtype_supported
  8. import numpy as np
  9. from hypothesis import given, settings, strategies as strat
# Hypothesis profile: 200 examples per property, no per-example deadline;
# derandomize when DERANDOMIZE_CI is set so CI runs are reproducible.
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
  12. class TestTranscendentalMath(unittest.TestCase):
  13. @unittest.skipUnless(is_dtype_supported(dtypes.float64, Device.DEFAULT), f"no float64 on {Device.DEFAULT}")
  14. @unittest.skipIf(getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV"), "crashed")
  15. @given(ht.float64, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)]))
  16. def test_float64(self, x, op):
  17. if op[0] == Tensor.sin:
  18. # TODO: reduction does not work # 536870912.125 # 2914593.01171875 # 134217728.03125
  19. if abs(x) > 536870912: return
  20. with Context(TRANSCENDENTAL=2):
  21. np.testing.assert_allclose(op[0](Tensor([x], dtype=dtypes.float64)).numpy(),
  22. op[1](np.array([x], dtype=_to_np_dtype(dtypes.float64))),
  23. atol=3e-2, rtol=1e-5) # sin can have bigger atol for very big x
  24. @unittest.skipIf(getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV"), "crashed")
  25. @given(ht.float32, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)]))
  26. def test_float32(self, x, op):
  27. with Context(TRANSCENDENTAL=2):
  28. np.testing.assert_allclose(op[0](Tensor([x], dtype=dtypes.float32)).numpy(),
  29. op[1](np.array([x], dtype=_to_np_dtype(dtypes.float32))),
  30. atol=2e-5, rtol=1e-5)
  31. @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
  32. @given(ht.float16, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)]))
  33. def test_float16(self, x, op):
  34. with Context(TRANSCENDENTAL=2):
  35. np.testing.assert_allclose(op[0](Tensor([x], dtype=dtypes.float16)).numpy(),
  36. op[1](np.array([x], dtype=_to_np_dtype(dtypes.float16))),
  37. atol=1e-2, rtol=4e-3) # exp can have bigger rtol
  38. class TestTranscendentalSchedule(unittest.TestCase):
  39. # w/ payne_hanek_reduction (fp32)
  40. def test_transcendental_sin_fusion(self):
  41. with Context(TRANSCENDENTAL=2):
  42. a = Tensor.empty(10)
  43. b = Tensor.empty(10)
  44. c = a.sin() + b.sin()
  45. c = c.sin()
  46. check_schedule(c, 1)
  47. def test_transcendental_log2_fusion(self):
  48. with Context(TRANSCENDENTAL=2):
  49. a = Tensor.empty(10)
  50. b = Tensor.empty(10)
  51. c = a.log2() + b.log2()
  52. c = c.log2()
  53. check_schedule(c, 1)
  54. def test_transcendental_exp2_fusion(self):
  55. with Context(TRANSCENDENTAL=2):
  56. a = Tensor.empty(10)
  57. b = Tensor.empty(10)
  58. c = a.exp2() + b.exp2()
  59. c = c.exp2()
  60. check_schedule(c, 1)
# Allow running this file directly (e.g. `python test/test_transcendental.py`).
if __name__ == '__main__':
  unittest.main()