# test_dtype_alu.py
  1. import unittest
  2. from tinygrad import Tensor, dtypes, Device
  3. import operator
  4. import numpy as np
  5. from hypothesis import given, strategies as strat, settings
  6. from tinygrad.dtype import DType
  7. from tinygrad.helpers import CI, getenv
  8. from tinygrad.engine.schedule import create_schedule
  9. from tinygrad.engine.realize import run_schedule
  10. from tinygrad.ops import UnaryOps
  11. from tinygrad.tensor import _to_np_dtype
  12. from test.helpers import is_dtype_supported
# Configure hypothesis: cap example count and disable the per-example deadline so
# slow backends don't flake; DERANDOMIZE_CI makes CI runs reproducible.
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
print(settings.default)

# dtype groups used below to pick the comparison method (allclose for floats, exact equal otherwise)
dtypes_float = (dtypes.float16, dtypes.float32, dtypes.float64)
dtypes_int = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
dtypes_bool = (dtypes.bool,)

# ops are either a single callable applied to both tinygrad and numpy, or a (tinygrad_fn, numpy_fn) pair
binary_operations = [operator.add, operator.sub, operator.mul, operator.lt, operator.eq]

# TODO: LLVM comparing with nan is incorrect
if Device.DEFAULT == "LLVM":
  binary_operations.remove(operator.lt)

# bitwise ops only make sense on integer dtypes, so they get their own list
integer_binary_operations = binary_operations + [(Tensor.xor, np.bitwise_xor), (Tensor.bitwise_and, np.bitwise_and),
                                                 (Tensor.bitwise_or, np.bitwise_or)]
unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (Tensor.sin, np.sin),
                    (Tensor.sqrt, np.sqrt), (Tensor.reciprocal, np.reciprocal)]

# TODO: enable this (this is a dtype issue)
#binary_operations.append(operator.truediv)
# TODO: enable mod on Tensor
#binary_operations.append(operator.mod)
# TODO: (a+b)/2 in tensor.py's maximum can overflow. This requires a new implementation of maximum that can be backpropagated
#binary_operations += [(Tensor.maximum, np.maximum)]
# TODO: CUDACPU segfaults on sin
if getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV"): unary_operations.remove((Tensor.sin, np.sin))
  35. class ht:
  36. float64 = strat.floats(width=64, allow_subnormal=False)
  37. float32 = strat.floats(width=32, allow_subnormal=False)
  38. float16 = strat.floats(width=16, allow_subnormal=False)
  39. uint8 = strat.integers(0, 255)
  40. uint16 = strat.integers(0, 65535)
  41. uint32 = strat.integers(0, 2**32-1)
  42. uint64 = strat.integers(0, 2**64-1)
  43. int8 = strat.integers(-128, 127)
  44. int16 = strat.integers(-32768, 32767)
  45. int32 = strat.integers(-2147483648, 2147483647)
  46. int64 = strat.integers(-9223372036854775808, 9223372036854775807)
  47. bool = strat.booleans()
  48. def universal_test(a, b, dtype, op):
  49. if not isinstance(op, tuple): op = (op, op)
  50. tensor_value = (op[0](Tensor([a], dtype=dtype), Tensor([b], dtype=dtype))).numpy()
  51. numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)), np.array([b]).astype(_to_np_dtype(dtype)))
  52. if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-10)
  53. else: np.testing.assert_equal(tensor_value, numpy_value)
  54. def universal_test_unary(a, dtype, op):
  55. if not isinstance(op, tuple): op = (op, op)
  56. out: Tensor = op[0](Tensor([a], dtype=dtype))
  57. sched = create_schedule([out.lazydata])
  58. ast = sched[-1].ast
  59. run_schedule(sched)
  60. tensor_value = out.numpy()
  61. numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)))
  62. if dtype in dtypes_float:
  63. np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
  64. else: np.testing.assert_equal(tensor_value, numpy_value)
  65. if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
  66. op = [x for x in ast.lazyops if x.op in UnaryOps][0]
  67. assert op.dtype == dtype
  68. def universal_test_cast(a, in_dtype, dtype):
  69. tensor_value = Tensor([a], dtype=in_dtype).cast(dtype)
  70. numpy_value = np.array([a]).astype(_to_np_dtype(dtype))
  71. np.testing.assert_equal(tensor_value.numpy(), numpy_value)
  72. def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType):
  73. if not isinstance(op1, tuple): op1 = (op1, op1)
  74. if not isinstance(op2, tuple): op2 = (op2, op2)
  75. at, bt, ct = Tensor([a], dtype=d1), Tensor([b], dtype=d1), Tensor([c], dtype=d2)
  76. an, bn, cn = np.array([a]).astype(_to_np_dtype(d1)), np.array([b]).astype(_to_np_dtype(d1)), np.array([c]).astype(_to_np_dtype(d2))
  77. tensor_value = op2[0](op1[0](at, bt).cast(d2), ct).numpy()
  78. numpy_value = op2[1](op1[1](an, bn).astype(_to_np_dtype(d2)), cn)
  79. np.testing.assert_allclose(tensor_value, numpy_value, rtol=1e-6 if getenv("PTX") else 1e-7)
class TestDTypeALU(unittest.TestCase):
  """Property-based tests: every ALU op must agree with numpy for every dtype the device supports."""
  @unittest.skipUnless(is_dtype_supported(dtypes.float64, Device.DEFAULT), f"no float64 on {Device.DEFAULT}")
  @given(ht.float64, ht.float64, strat.sampled_from(binary_operations))
  def test_float64(self, a, b, op): universal_test(a, b, dtypes.float64, op)

  @given(ht.float32, ht.float32, strat.sampled_from(binary_operations))
  def test_float32(self, a, b, op): universal_test(a, b, dtypes.float32, op)

  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
  @given(ht.float16, ht.float16, strat.sampled_from(binary_operations))
  def test_float16(self, a, b, op): universal_test(a, b, dtypes.float16, op)

  @given(ht.float32, strat.sampled_from(unary_operations))
  def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)

  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
  @given(ht.float16, strat.sampled_from(unary_operations))
  def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)

  # integer dtypes also get the bitwise ops (integer_binary_operations)
  @given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
  def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)

  @unittest.skipUnless(is_dtype_supported(dtypes.uint16, Device.DEFAULT), f"no uint16 on {Device.DEFAULT}")
  @given(ht.uint16, ht.uint16, strat.sampled_from(integer_binary_operations))
  def test_uint16(self, a, b, op): universal_test(a, b, dtypes.uint16, op)

  @unittest.skipUnless(is_dtype_supported(dtypes.uint32, Device.DEFAULT), f"no uint32 on {Device.DEFAULT}")
  @given(ht.uint32, ht.uint32, strat.sampled_from(integer_binary_operations))
  def test_uint32(self, a, b, op): universal_test(a, b, dtypes.uint32, op)

  @unittest.skipUnless(is_dtype_supported(dtypes.uint64, Device.DEFAULT), f"no uint64 on {Device.DEFAULT}")
  @given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
  def test_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)

  @given(ht.int8, ht.int8, strat.sampled_from(integer_binary_operations))
  def test_int8(self, a, b, op): universal_test(a, b, dtypes.int8, op)

  @given(ht.int16, ht.int16, strat.sampled_from(integer_binary_operations))
  def test_int16(self, a, b, op): universal_test(a, b, dtypes.int16, op)

  @given(ht.int32, ht.int32, strat.sampled_from(integer_binary_operations))
  def test_int32(self, a, b, op): universal_test(a, b, dtypes.int32, op)

  @given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
  def test_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op)

  # bool only supports add and mul here (acting as logical or/and)
  @given(ht.bool, ht.bool, strat.sampled_from(((operator.add, operator.add), (operator.mul, operator.mul))))
  def test_bool(self, a, b, op): universal_test(a, b, dtypes.bool, op)

  @given(ht.int32, ht.int32, ht.float32, strat.sampled_from(integer_binary_operations), strat.sampled_from(binary_operations))
  def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32)

  # Metal and CUDACPU and HIP behave differently than numpy in CI for overflows
  # NOTE(review): the comment above mentions Metal/HIP but the check below gates on AMD/NV/CUDACPU — confirm which is intended
  skip_overflow = CI and (Device.DEFAULT in {"AMD", "NV"} or getenv("CUDACPU"))
  # when overflow behavior differs from numpy, restrict floats to a small non-overflowing range
  @given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
         strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
         ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))
  @unittest.skipIf(Device.DEFAULT == "PYTHON", "TODO: fix cast inf to int32 in PYTHON")
  def test_float_midcast_int32(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.float32, dtypes.int32)

  @unittest.skip("broken. TODO: fix it")
  @given(ht.float32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
  def test_float_cast(self, a, dtype): universal_test_cast(a, dtypes.float32, dtype)

  @unittest.skip("broken. TODO: fix it")
  @given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
  def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype)
class TestFromFuzzer(unittest.TestCase):
  """Targeted value tests (fuzzer-derived inputs) with ULP-scaled tolerances."""
  @given(strat.sampled_from(dtypes_float))
  def test_sin(self, dtype):
    if not is_dtype_supported(dtype): return
    if dtype == dtypes.float64:
      # crashes in CUDACPU
      if (getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV")): return
    def _test_value(n: float, unit: float=1.0):
      # one ULP at 1.0 in this dtype, scaled by `unit`, used as the absolute tolerance
      next_float = np.nextafter(1.0, 2.0, dtype=_to_np_dtype(dtype))
      ulp = next_float - 1.0
      ulp = unit * ulp
      np.testing.assert_allclose(Tensor([n], dtype=dtype).sin().numpy(), np.sin(np.array([n], dtype=_to_np_dtype(dtype))), atol=ulp, rtol=1e-5)
    _test_value(-35.0)
    _test_value(-25.0)
    _test_value(25.0)
    _test_value(30.0) # 30.0 == switch_over
    _test_value(35.0)
    _test_value(0.0)
    _test_value(np.pi / 2)
    # worst case of ulp 1.5
    _test_value(np.pi * 2, unit=1.5)

  @given(strat.sampled_from(dtypes_float))
  def test_log2(self, dtype):
    if not is_dtype_supported(dtype): return
    if dtype == dtypes.float64:
      # crashes in CUDACPU
      if (getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV")): return
    def _test_value(n: float, unit: float=1.0):
      # one ULP at 1.0 in this dtype, scaled by `unit`, used as the absolute tolerance
      next_float = np.nextafter(1.0, 2.0, dtype=_to_np_dtype(dtype))
      ulp = next_float - 1.0
      ulp = unit * ulp
      np.testing.assert_allclose(Tensor([n], dtype=dtype).log2().numpy(), np.log2(np.array([n], dtype=_to_np_dtype(dtype))), atol=ulp, rtol=1e-5)
    # probe near the smallest normal magnitude across several scales, both signs, and zero
    fmin = np.finfo(_to_np_dtype(dtype)).tiny
    for scale in [1.0, 1e10, 1e20, 1e30]:
      _test_value(fmin * scale)
      _test_value(-fmin * scale)
    _test_value(0)
# allow running this test file directly
if __name__ == '__main__':
  unittest.main()