# test_dtype.py
  1. import unittest, operator, subprocess, math
  2. import numpy as np
  3. import torch
  4. from typing import Any, List
  5. from tinygrad.helpers import getenv, DEBUG, CI
  6. from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
  7. from tinygrad import Device, Tensor, dtypes
  8. from tinygrad.tensor import _to_np_dtype
  9. from hypothesis import given, settings, strategies as strat
  10. from test.helpers import is_dtype_supported, rand_for_dtype
# Hypothesis configuration: 200 examples per property, no per-example deadline;
# DERANDOMIZE_CI makes CI runs reproducible.
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
# every dtype tinygrad declares, minus bfloat16 on CPU (see NOTE)
core_dtypes = list(DTYPES_DICT.values())
if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
# device-supported int/float dtypes, used by the hypothesis strategies below
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
  17. def get_available_cast_dtypes(dtype: DType) -> List[DType]:
  18. if not is_dtype_supported(dtype): return []
  19. return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")] # dont cast internal dtypes
  20. def _test_to_np(a:Tensor, np_dtype, target):
  21. if DEBUG >= 2: print(a)
  22. na = a.numpy()
  23. if DEBUG >= 2: print(na, na.dtype, a.lazydata.base.realized)
  24. try:
  25. assert na.dtype == np_dtype
  26. np.testing.assert_allclose(na, target)
  27. except AssertionError as e:
  28. raise AssertionError(f"\ntensor {a.numpy()} does not match target {target} with np_dtype {np_dtype}") from e
  29. def _assert_eq(tensor:Tensor, target_dtype:DType, target):
  30. if DEBUG >= 2: print(tensor.numpy())
  31. try:
  32. assert tensor.dtype == target_dtype
  33. np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, 1e-7))
  34. except AssertionError as e:
  35. raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
  36. def _test_op(fxn, target_dtype:DType, target):
  37. _assert_eq(fxn(), target_dtype, target)
  38. def _test_cast(a:Tensor, target_dtype:DType):
  39. if a.is_floating_point() and dtypes.is_unsigned(target_dtype):
  40. # converting negative float to unsigned integer is undefined
  41. a = a.abs()
  42. if target_dtype == dtypes.half and Device.DEFAULT == "PYTHON":
  43. # TODO: struct.pack cannot pack value > 65504 (max of half) into e format
  44. a = (a > 65504).where(65504, a)
  45. if CI and Device.DEFAULT == "CLANG" and (target_dtype, a.dtype) in [(dtypes.double, dtypes.half), (dtypes.half, dtypes.double)]:
  46. # TODO: cast between double and half are broken https://github.com/tinygrad/tinygrad/issues/4084
  47. return
  48. _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
  49. def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
  50. if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
  51. _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist())
  52. class TestDType(unittest.TestCase):
  53. DTYPE: Any = None
  54. DATA: Any = None
  55. @classmethod
  56. def setUpClass(cls):
  57. if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
  58. cls.DATA = rand_for_dtype(cls.DTYPE, 10)
  59. def setUp(self):
  60. if self.DTYPE is None: raise unittest.SkipTest("base class")
  61. def test_to_np(self):
  62. _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE)))
  63. def test_casts_to(self): list(map(
  64. lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE),
  65. get_available_cast_dtypes(self.DTYPE)
  66. ))
  67. def test_casts_from(self): list(map(
  68. lambda dtype: _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype),
  69. get_available_cast_dtypes(self.DTYPE)
  70. ))
  71. def test_same_size_ops(self):
  72. list(map(
  73. lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize == self.DTYPE.itemsize else None,
  74. get_available_cast_dtypes(self.DTYPE)
  75. ))
  76. def test_upcast_ops(self):
  77. list(map(
  78. lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize > self.DTYPE.itemsize else None,
  79. get_available_cast_dtypes(self.DTYPE)
  80. ))
  81. def test_upcast_to_ops(self):
  82. list(map(
  83. lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None,
  84. get_available_cast_dtypes(self.DTYPE)
  85. ))
  86. def test_bitcast(self):
  87. if Device.DEFAULT == "WEBGL": raise unittest.SkipTest("no bitcast in WebGL GLSL")
  88. if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
  89. list(map(
  90. lambda dtype:
  91. _test_bitcast(Tensor(self.DATA, dtype=self.DTYPE), dtype) if dtype.itemsize == self.DTYPE.itemsize and dtype != dtypes.bool else None,
  92. get_available_cast_dtypes(self.DTYPE)
  93. ))
  94. def test_dtypes_fields(self):
  95. fields = dtypes.fields()
  96. self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
  97. self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in fields.values() if _to_np_dtype(value) is not None))
  98. def test_resulting_and_init_dtypes_match(self):
  99. dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"]))
  100. data = [1., 2., 0., 0.5, -1.5, 5.25]
  101. for dt in dtypes:
  102. arr = np.asarray(data).astype(dt)
  103. tin = Tensor(arr).numpy()
  104. tor = torch.as_tensor(arr).detach().numpy()
  105. assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
  106. np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
  107. def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
  108. target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype)
  109. if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype) or not is_dtype_supported(target_dtype): return
  110. if a_dtype == dtypes.bool or b_dtype == dtypes.bool: return
  111. _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8])
  112. _assert_eq((Tensor([1], dtype=a_dtype).cast(b_dtype)+Tensor([1], dtype=a_dtype).cast(b_dtype)).cast(a_dtype), a_dtype, [2])
  113. _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16])
  114. _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
  115. _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
  116. @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
  117. class TestBFloat16(unittest.TestCase):
  118. def test_bf16_creation_numpy(self):
  119. data = [-1, 1, 2]
  120. t = Tensor(data, dtype=dtypes.bfloat16)
  121. assert t.dtype == dtypes.bfloat16
  122. tnp = t.numpy()
  123. assert tnp.dtype == np.float32
  124. np.testing.assert_allclose(tnp, np.array(data))
  125. def test_bf16_ones(self):
  126. t = Tensor.ones(3, 5, dtype=dtypes.bfloat16)
  127. assert t.dtype == dtypes.bfloat16
  128. np.testing.assert_allclose(t.numpy(), np.ones((3, 5)))
  129. def test_bf16_eye(self):
  130. t = Tensor.eye(3, dtype=dtypes.bfloat16)
  131. assert t.dtype == dtypes.bfloat16
  132. np.testing.assert_allclose(t.numpy(), np.eye(3))
  133. @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
  134. class TestBFloat16DType(unittest.TestCase):
  135. def test_bf16_to_float(self):
  136. _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
  137. def test_float_to_bf16(self):
  138. _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16)
  139. def test_bf16(self):
  140. t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16)
  141. t.realize()
  142. back = t.cast(dtypes.float32)
  143. assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
class TestBFloat16DTypeCast(unittest.TestCase):
  """float16 -> bfloat16 conversion: round-trips, IEEE edge cases, range/precision, fuzzing."""
  def test_f16_to_bf16_conversion(self):
    original_tensor = Tensor([1.0, 2.0, 3.0], dtype=dtypes.float16)
    converted_tensor = original_tensor.cast(dtypes.bfloat16)
    self.assertEqual(converted_tensor.dtype, dtypes.bfloat16)
    # round-trip both paths to float32 and compare loosely (bf16 has fewer mantissa bits)
    back_to_float32 = converted_tensor.cast(dtypes.float32)
    original_to_float32 = original_tensor.cast(dtypes.float32)
    np.testing.assert_allclose(back_to_float32.numpy(), original_to_float32.numpy(), rtol=1e-2, atol=1e-3)
  def test_f16_to_bf16_edge_cases(self):
    # signed zeros, infinities and nan must survive the cast exactly
    edge_cases = Tensor([0.0, -0.0, float('inf'), float('-inf'), float('nan')], dtype=dtypes.float16)
    converted = edge_cases.cast(dtypes.bfloat16).cast(dtypes.float32)
    np.testing.assert_equal(converted.numpy(), edge_cases.cast(dtypes.float32).numpy())
  def test_f16_to_bf16_range_precision(self):
    large_value = Tensor([65504.0], dtype=dtypes.float16) # Max representable in float16
    small_value = Tensor([6.1035e-5], dtype=dtypes.float16) # Smallest positive normal float16
    large_converted = large_value.cast(dtypes.bfloat16).cast(dtypes.float32)
    small_converted = small_value.cast(dtypes.bfloat16).cast(dtypes.float32)
    np.testing.assert_allclose(large_converted.numpy(), large_value.cast(dtypes.float32).numpy(), rtol=1e-2, atol=1e-3)
    np.testing.assert_equal(small_converted.numpy(), small_value.cast(dtypes.float32).numpy())
  def test_f16_to_bf16_randomized(self):
    np.random.seed(42) # For reproducibility
    random_values = Tensor(np.random.uniform(-65504, 65504, 1000), dtype=dtypes.float16)
    converted = random_values.cast(dtypes.bfloat16).cast(dtypes.float32)
    np.testing.assert_allclose(converted.numpy(), random_values.cast(dtypes.float32).numpy(), rtol=1e-2, atol=1e-3)
# half gets the full inherited TestDType suite; only the dtype varies
class TestHalfDType(TestDType): DTYPE = dtypes.half
class TestFloatDType(TestDType):
  # float32 gets the full inherited suite plus a float->uint truncation check
  DTYPE = dtypes.float
  def test_float_to_uint(self):
    # fractional and negative values truncate toward zero on the way to uint32
    _test_op(lambda: Tensor([-0.9, -0.3, 1.2], dtype=dtypes.float32).cast(dtypes.uint32), dtypes.uint32,
             [0, 0, 1])
  175. class TestDoubleDType(TestDType):
  176. DTYPE = dtypes.double
  177. @unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or getenv("PTX"), "conversion not supported on CUDACPU and PTX") # TODO: why not?
  178. def test_float64_increased_precision(self):
  179. for func in [
  180. lambda t: t.exp(),
  181. lambda t: t.exp2(),
  182. lambda t: t.log(),
  183. lambda t: t.log2(),
  184. lambda t: t.sqrt(),
  185. lambda t: t.rsqrt(),
  186. lambda t: t.sin(),
  187. lambda t: t.cos(),
  188. lambda t: t.tan(),
  189. lambda t: t.sigmoid(),
  190. ]:
  191. a = [2, 3, 4]
  192. np.testing.assert_allclose(func(Tensor(a, dtype=self.DTYPE)).numpy(), func(torch.tensor(a, dtype=torch.float64)), rtol=1e-12, atol=1e-12)
  193. def test_float64_to_float32_cast_inf(self):
  194. _test_op(lambda: Tensor([3.4e40, 3.4e38, 1, 0], dtype=dtypes.float64).cast(dtypes.float32),
  195. dtypes.float32, [float('inf'), 3.4e38, 1, 0])
class TestInt8DType(TestDType):
  DTYPE = dtypes.int8
  @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
  def test_int8_to_uint8_negative(self):
    # negative int8 values wrap modulo 2**8
    _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
  def test_int8_to_uint16_negative(self):
    # negative int8 values wrap modulo 2**16
    _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint16), dtypes.uint16, [2**16-1, 2**16-2, 2**16-3, 2**16-4])
class TestUint8DType(TestDType):
  DTYPE = dtypes.uint8
  @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
  def test_uint8_to_int8_overflow(self):
    # values above int8 max wrap around to negatives
    _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
  208. @unittest.skipIf(Device.DEFAULT == "WEBGL", "No bitcast on WebGL")
  209. class TestBitCast(unittest.TestCase):
  210. def test_shape_change_bitcast(self):
  211. with self.assertRaises(RuntimeError):
  212. _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])
  213. def test_bitcast_float_to_int32(self):
  214. a = Tensor([1.,2,3])
  215. b = a.bitcast(dtypes.int32)
  216. assert b.numpy()[0] == 0x3f800000
  217. def test_bitcast_upcasted(self):
  218. a = Tensor.zeros(100, 4, dtype=dtypes.int32).contiguous() + 0x3f800000
  219. b = a.bitcast(dtypes.float32)
  220. assert b.numpy()[0,0] == 1.
# int16 gets the full inherited TestDType suite; only the dtype varies
class TestInt16DType(TestDType): DTYPE = dtypes.int16
class TestUint16DType(TestDType):
  DTYPE = dtypes.uint16
  def test_uint16_to_int8_overflow(self):
    # 2**16-1 truncates to 0xFF, which reads back as -1 in int8
    _test_op(lambda: Tensor([2**16-1, 2**16-2, 1, 0], dtype=dtypes.uint16).cast(dtypes.int8), dtypes.int8, [-1, -2, 1, 0])
# one-line subclasses: each runs the full inherited TestDType suite for its dtype
class TestInt32DType(TestDType): DTYPE = dtypes.int32
class TestUint32DType(TestDType): DTYPE = dtypes.uint32
class TestInt64DType(TestDType): DTYPE = dtypes.int64
class TestUint64DType(TestDType): DTYPE = dtypes.uint64
class TestBoolDType(TestDType): DTYPE = dtypes.bool
  231. class TestImageDType(unittest.TestCase):
  232. def test_image_scalar(self):
  233. assert dtypes.imagef((10,10)).scalar() == dtypes.float32
  234. assert dtypes.imageh((10,10)).scalar() == dtypes.float32
  235. def test_image_vec(self):
  236. assert dtypes.imagef((10,10)).vec(4) == dtypes.float32.vec(4)
  237. assert dtypes.imageh((10,10)).vec(4) == dtypes.float32.vec(4)
class TestEqStrDType(unittest.TestCase):
  """Equality and string formatting of image and pointer dtypes."""
  def test_image_ne(self):
    if ImageDType is None: raise unittest.SkipTest("no ImageDType support")
    assert dtypes.float == dtypes.float32, "float doesn't match?"
    assert dtypes.imagef((1,2,4)) != dtypes.imageh((1,2,4)), "different image dtype doesn't match"
    assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match"
    assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches"
    assert isinstance(dtypes.imageh((1,2,4)), ImageDType)
  def test_ptr_ne(self):
    if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
    # TODO: is this the wrong behavior?
    # a pointer dtype currently compares equal to its base dtype
    assert PtrDType(dtypes.float32) == dtypes.float32
    assert not (PtrDType(dtypes.float32) != dtypes.float32)
    assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32)
    assert not (PtrDType(dtypes.float32) != PtrDType(dtypes.float32))
    #assert PtrDType(dtypes.float32) != dtypes.float32
  def test_strs(self):
    if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
    self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
    self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")
class TestHelpers(unittest.TestCase):
  """Property tests for the dtypes.is_* predicates, vec/scalar, from_py and min/max."""
  signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
  uints = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
  floats = (dtypes.float16, dtypes.float32, dtypes.float64)
  @given(strat.sampled_from(signed_ints+uints), strat.integers(min_value=1, max_value=8))
  def test_is_int(self, dtype, amt):
    # holds for the scalar dtype and any vectorized width up to 8
    assert dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
    assert not dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype)
  @given(strat.sampled_from(uints), strat.integers(min_value=1, max_value=8))
  def test_is_unsigned_uints(self, dtype, amt):
    assert dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
  @given(strat.sampled_from(signed_ints), strat.integers(min_value=1, max_value=8))
  def test_is_unsigned_signed_ints(self, dtype, amt):
    assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
  @given(strat.sampled_from(floats), strat.integers(min_value=1, max_value=8))
  def test_is_float(self, dtype, amt):
    assert dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype)
    assert not dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
    assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
  def test_bf16_is_float(self):
    assert dtypes.is_float(dtypes.bfloat16)
  @given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), strat.integers(min_value=2, max_value=8))
  def test_scalar(self, dtype, amt):
    # .scalar() undoes .vec()
    assert dtype.vec(amt).scalar() == dtype
  def test_from_py(self):
    assert dtypes.from_py(True) == dtypes.bool
    assert dtypes.from_py(2) == dtypes.default_int
    assert dtypes.from_py(3.0) == dtypes.default_float
    # empty containers default to float
    assert dtypes.from_py([]) == dtypes.default_float
    assert dtypes.from_py(()) == dtypes.default_float
    assert dtypes.from_py([True]) == dtypes.bool
    # mixed lists promote: bool < int < float
    assert dtypes.from_py([True, 2]) == dtypes.default_int
    assert dtypes.from_py([True, 3.0]) == dtypes.default_float
    assert dtypes.from_py([2, 3.0]) == dtypes.default_float
    assert dtypes.from_py([True, 2, 3.0]) == dtypes.default_float
    # unsupported python values raise RuntimeError
    with self.assertRaises(RuntimeError): dtypes.from_py(None)
    with self.assertRaises(RuntimeError): dtypes.from_py([None])
    with self.assertRaises(RuntimeError): dtypes.from_py({})
    with self.assertRaises(RuntimeError): dtypes.from_py(set())
  def test_dtype_range(self):
    # dtypes.min/max should agree with numpy's limits for each core dtype
    for dt in core_dtypes:
      if dtypes.is_float(dt):
        np.testing.assert_equal(dtypes.min(dt), -math.inf)
        np.testing.assert_equal(dtypes.max(dt), math.inf)
      elif dtypes.is_int(dt):
        info = np.iinfo(_to_np_dtype(dt))
        np.testing.assert_equal(dtypes.min(dt), info.min)
        np.testing.assert_equal(dtypes.max(dt), info.max)
      else:
        assert dt == dtypes.bool, dt
        np.testing.assert_equal(dtypes.min(dt), False)
        np.testing.assert_equal(dtypes.max(dt), True)
class TestTypeSpec(unittest.TestCase):
  """Resulting dtypes of creation/reduction/indexing ops under varying default dtypes."""
  def setUp(self):
    # save the global defaults so each test can mutate them freely
    self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
  def tearDown(self):
    dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
  def test_set_dtype_default(self):
    for default_int in [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64]:
      dtypes.default_int = default_int
      assert dtypes.default_int == default_int
    for default_float in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
      dtypes.default_float = default_float
      assert dtypes.default_float == default_float
  def test_env_set_default_float(self):
    # spawn fresh interpreters so the DEFAULT_FLOAT env var is read at import time
    # check default
    subprocess.run(['python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.float"'],
                    shell=True, check=True)
    # check change
    subprocess.run(['DEFAULT_FLOAT=HALF python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.half"'],
                    shell=True, check=True)
    # check invalid
    with self.assertRaises(subprocess.CalledProcessError):
      subprocess.run(['DEFAULT_FLOAT=INT32 python3 -c "from tinygrad import dtypes"'],
                      shell=True, check=True)
    with self.assertRaises(subprocess.CalledProcessError):
      subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
                      shell=True, check=True)
  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_creation(self, default_int, default_float):
    dtypes.default_int, dtypes.default_float = default_int, default_float
    _assert_eq(Tensor(True), dtypes.bool, True)
    _assert_eq(Tensor(None), dtypes.default_float, [])
    _assert_eq(Tensor(2), dtypes.default_int, 2)
    _assert_eq(Tensor(2.34), dtypes.default_float, 2.34)
    _assert_eq(Tensor([]), dtypes.default_float, [])
    _assert_eq(Tensor([1]), dtypes.default_int, [1])
    _assert_eq(Tensor([1.1]), dtypes.default_float, [1.1])
    _assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
    _assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
    _assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
    if is_dtype_supported(dtypes.float16):
      _assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_full(self, default_int, default_float):
    dtypes.default_int, dtypes.default_float = default_int, default_float
    _assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
    _assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
    if is_dtype_supported(dtypes.float16):
      _assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
    _assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
    _assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
    if is_dtype_supported(dtypes.float16):
      _assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
    # full infers its dtype from the fill value unless one is given explicitly
    _assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
    _assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3))
    _assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
    _assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
    _assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
    if is_dtype_supported(dtypes.float16):
      _assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
      _assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_reduce_0d_default(self, default_int, default_float):
    dtypes.default_int, dtypes.default_float = default_int, default_float
    _assert_eq(Tensor.ones((2,3,0)).sum(2), dtypes.default_float, np.zeros((2, 3)))
    # TODO: what should this one be?
    # _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.default_int).sum(2), dtypes.default_int, np.zeros((2, 3)))
    _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.int32).sum(2), dtypes.int32, np.zeros((2, 3)))
  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_arange(self, default_int, default_float):
    dtypes.default_int, dtypes.default_float = default_int, default_float
    _assert_eq(Tensor.arange(5), dtypes.default_int, np.arange(5))
    _assert_eq(Tensor.arange(120), dtypes.default_int, np.arange(120))
    _assert_eq(Tensor.arange(5.0), dtypes.default_float, np.arange(5))
    _assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
    _assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
    if is_dtype_supported(dtypes.float16):
      _assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
    # a float step/stop makes the result float
    _assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7))
    _assert_eq(Tensor.arange(3, 8.5, 3), dtypes.default_float, np.arange(3, 8.5, 3))
  @given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
  def test_bool_ops(self, dtype, op):
    # comparisons always produce bool, whatever the operand dtype
    assert op(Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)).dtype == dtypes.bool
  @given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_functions_return_index(self, dtype, default_int, default_float):
    # index-returning functions yield int32 regardless of the defaults
    dtypes.default_int, dtypes.default_float = default_int, default_float
    assert Tensor([0, 1], dtype=dtype).argmax().dtype == dtypes.int32
    assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32
    assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32
  @given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
  def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype):
    X_data = Tensor.rand(60000, 1, 28, 28, dtype=data_dtype)
    indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype)
    assert X_data[indices].dtype == X_data.dtype
  @given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
  def test_gather_returns_same_dtype(self, data_dtype, indices_dtype):
    X_data = Tensor([[1, 0], [0, 1]], dtype=data_dtype)
    indices = Tensor([[0, 0], [1, 0]], dtype=indices_dtype)
    assert X_data.gather(0, indices).dtype == X_data.dtype
    assert X_data.gather(1, indices).dtype == X_data.dtype
  @given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
  def test_attention_returns_same_dtype(self, data_dtype, default_float):
    dtypes.default_float = default_float
    query = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
    key = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
    value = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
    mask = (Tensor.rand(32, 8, 128, 128) < 0.5)
    assert query.scaled_dot_product_attention(key, value, is_causal=True).dtype == data_dtype
    assert query.scaled_dot_product_attention(key, value, is_causal=True, dropout_p=0.3).dtype == data_dtype
    assert query.scaled_dot_product_attention(key, value, is_causal=False).dtype == data_dtype
    assert query.scaled_dot_product_attention(key, value, attn_mask=mask).dtype == data_dtype
class TestTypePromotion(unittest.TestCase):
  """least_upper_dtype: idempotence, upper-bound property, and the concrete promotion table."""
  @given(strat.sampled_from(core_dtypes))
  def test_self_promo_to_self(self, dtype):
    assert least_upper_dtype(dtype) == dtype
    assert least_upper_dtype(dtype, dtype) == dtype
    assert least_upper_dtype(dtype, dtype, dtype) == dtype
  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  def test_promo_resulted_higher_than_inputs(self, dtype1, dtype2):
    # the result is an upper bound of both inputs
    result = least_upper_dtype(dtype1, dtype2)
    assert result >= dtype1 and result >= dtype2
  def test_dtype_promo(self):
    assert least_upper_dtype(dtypes.bool, dtypes.int8) == dtypes.int8
    # mixed signed/unsigned of the same width promote to the next wider signed int
    assert least_upper_dtype(dtypes.int8, dtypes.uint8) == dtypes.int16
    assert least_upper_dtype(dtypes.uint8, dtypes.int16) == dtypes.int16
    assert least_upper_dtype(dtypes.int16, dtypes.uint16) == dtypes.int32
    assert least_upper_dtype(dtypes.uint16, dtypes.int32) == dtypes.int32
    assert least_upper_dtype(dtypes.int32, dtypes.uint32) == dtypes.int64
    assert least_upper_dtype(dtypes.uint32, dtypes.int64) == dtypes.int64
    # similar to jax but we don't use weak type
    assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.float16
    assert least_upper_dtype(dtypes.float16, dtypes.float32) == dtypes.float32
    assert least_upper_dtype(dtypes.float32, dtypes.float64) == dtypes.float64
    assert least_upper_dtype(dtypes.bool, dtypes.float32) == dtypes.float32
    assert least_upper_dtype(dtypes.bool, dtypes.float64) == dtypes.float64
    # a float beats any int, keeping the float's width
    assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16
    assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16
  @given(strat.sampled_from(dtype_floats))
  def test_float_to_float(self, dt):
    assert least_upper_float(dt) == dt
  449. class TestAutoCastType(unittest.TestCase):
  450. def setUp(self):
  451. self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
  452. def tearDown(self):
  453. dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
  454. @given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_int(d) and is_dtype_supported(d)]))
  455. def test_int_to_float_unary_func(self, dtype):
  456. for func in [
  457. lambda t: t.exp(),
  458. lambda t: t.exp2(),
  459. lambda t: t.log(),
  460. lambda t: t.log2(),
  461. lambda t: t.sqrt(),
  462. lambda t: t.rsqrt(),
  463. lambda t: t.sin(),
  464. lambda t: t.cos(),
  465. lambda t: t.tan(),
  466. lambda t: t.sigmoid(),
  467. ]:
  468. a = [2, 3, 4]
  469. # float16 can have larger precision errors
  470. np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-3, atol=1e-3)
  471. @given(strat.sampled_from(core_dtypes))
  472. def test_broadcast_scalar(self, dt):
  473. assert (Tensor.rand(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
  474. assert (Tensor.rand(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
  475. if Device.DEFAULT != "WEBGPU" and dt != dtypes.bool:
  476. assert (Tensor.rand(4, 4, dtype=dt) + True).dtype == dt
  477. def test_sum(self):
  478. assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
  479. assert (Tensor([0, 1], dtype=dtypes.int8)).sum().dtype == dtypes.int32
  480. assert (Tensor([0, 1], dtype=dtypes.int16)).sum().dtype == dtypes.int32
  481. assert (Tensor([0, 1], dtype=dtypes.int32)).sum().dtype == dtypes.int32
  482. assert (Tensor([0, 1], dtype=dtypes.int64)).sum().dtype == dtypes.int64
  483. assert (Tensor([0, 1], dtype=dtypes.uint8)).sum().dtype == dtypes.uint32
  484. assert (Tensor([0, 1], dtype=dtypes.uint16)).sum().dtype == dtypes.uint32
  485. assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32
  486. assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64
  487. assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16
  488. #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
  489. assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
  490. assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
  491. @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
  492. def test_sum_acc_dtype(self):
  493. t = Tensor([40000, 40000], dtype=dtypes.float16)
  494. # default float16 sum returns in float16, overflowed in this case
  495. assert t.sum().dtype == dtypes.float16
  496. assert math.isinf(t.sum().numpy().item())
  497. # specifiying acc_dtype and it's not downcasted
  498. assert t.sum(acc_dtype=dtypes.float32).dtype == dtypes.float32
  499. np.testing.assert_allclose(t.sum(acc_dtype=dtypes.float32).numpy(), 80000)
  500. def test_mean(self):
  501. assert (Tensor([0, 1], dtype=dtypes.bool)).mean().dtype == dtypes.float32
  502. assert (Tensor([0, 1], dtype=dtypes.int8)).mean().dtype == dtypes.float32
  503. assert (Tensor([0, 1], dtype=dtypes.int16)).mean().dtype == dtypes.float32
  504. assert (Tensor([0, 1], dtype=dtypes.int32)).mean().dtype == dtypes.float32
  505. assert (Tensor([0, 1], dtype=dtypes.int64)).mean().dtype == dtypes.float32
  506. assert (Tensor([0, 1], dtype=dtypes.uint8)).mean().dtype == dtypes.float32
  507. assert (Tensor([0, 1], dtype=dtypes.uint16)).mean().dtype == dtypes.float32
  508. assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32
  509. assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32
  510. assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
  511. #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
  512. assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
  513. assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64
  514. def test_cumsum(self):
  515. assert (Tensor([0, 1], dtype=dtypes.bool)).cumsum(0).dtype == dtypes.int32
  516. assert (Tensor([0, 1], dtype=dtypes.int8)).cumsum(0).dtype == dtypes.int32
  517. assert (Tensor([0, 1], dtype=dtypes.int16)).cumsum(0).dtype == dtypes.int32
  518. assert (Tensor([0, 1], dtype=dtypes.int32)).cumsum(0).dtype == dtypes.int32
  519. assert (Tensor([0, 1], dtype=dtypes.int64)).cumsum(0).dtype == dtypes.int64
  520. assert (Tensor([0, 1], dtype=dtypes.uint8)).cumsum(0).dtype == dtypes.uint32
  521. assert (Tensor([0, 1], dtype=dtypes.uint16)).cumsum(0).dtype == dtypes.uint32
  522. assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32
  523. assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64
  524. assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16
  525. #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
  526. assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32
  527. assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64
  528. @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  529. def test_matmul(self, dt1, dt2, acc_dt):
  530. t1 = Tensor([0, 1], dtype=dt1)
  531. t2 = Tensor([0, 1], dtype=dt2)
  532. assert (t1 @ t2).dtype == least_upper_dtype(dt1, dt2)
  533. # if acc_dtype is specified, return in acc_dtype
  534. assert (t1.matmul(t2, acc_dtype=acc_dt).dtype == acc_dt)
  535. @staticmethod
  536. def check_where_alternate_input_other(input_, other, data_type):
  537. assert (Tensor([True, False]).where(input_, other)).dtype == data_type
  538. assert (Tensor([True, False]).where(other, input_)).dtype == data_type
  539. @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  540. def test_where_no_scalar(self, dt1, dt2):
  541. self.check_where_alternate_input_other(Tensor(2, dtype=dt1), Tensor(3, dtype=dt2), least_upper_dtype(dt1, dt2))
  542. @given(strat.sampled_from(core_dtypes))
  543. def test_where_one_scalar(self, dt):
  544. t = Tensor(2, dtype=dt)
  545. self.check_where_alternate_input_other(t, 3.2, (dt if dtypes.is_float(dt) else dtypes.default_float))
  546. self.check_where_alternate_input_other(t, 3, (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int))
  547. self.check_where_alternate_input_other(t, True, dt)
  548. def test_where_two_scalars(self):
  549. self.check_where_alternate_input_other(3.1, 3.2, dtypes.default_float)
  550. self.check_where_alternate_input_other(3.1, 3, dtypes.default_float)
  551. self.check_where_alternate_input_other(3.1, True, dtypes.default_float)
  552. self.check_where_alternate_input_other(3, 2, dtypes.default_int)
  553. self.check_where_alternate_input_other(3, True, dtypes.default_int)
  554. self.check_where_alternate_input_other(False, True, dtypes.bool)
  555. @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  556. def test_maximum(self, dt1, dt2):
  557. assert Tensor([0, 1, 2], dtype=dt1).maximum(Tensor([2, 0, 5], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
  558. @given(strat.sampled_from(core_dtypes))
  559. def test_maximum_const(self, dt):
  560. assert Tensor([1, 2], dtype=dt).maximum(3.1).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
  561. assert Tensor([1, 2], dtype=dt).maximum(3).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
  562. assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt
  563. def test_div(self):
  564. assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float
  565. assert (Tensor([1, 2], dtype=dtypes.int16) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float
  566. assert (Tensor([1, 2], dtype=dtypes.float32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float32
  567. assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float16
  568. def test_div_const(self):
  569. assert (Tensor([1, 2], dtype=dtypes.int32) / 2).dtype == dtypes.default_float
  570. assert (Tensor([1, 2], dtype=dtypes.int32) / 2.0).dtype == dtypes.default_float
  571. assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16
  572. assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16
  573. def test_gradient_dtype(self):
  574. old_default_float = dtypes.default_float
  575. for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
  576. if not is_dtype_supported(default_dtype): continue
  577. dtypes.default_float = default_dtype
  578. for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
  579. if not is_dtype_supported(dtype): continue
  580. if DEBUG >= 2:
  581. print(f"testing {default_dtype=}, {dtype=}")
  582. a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
  583. b = (a * 5).sum()
  584. b.backward() # if there is dtype mismatch, lazy should assert
  585. assert a.grad.dtype == a.dtype
  586. np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])
  587. dtypes.default_float = old_default_float
  588. @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  589. def test_backward_sum_acc_dtype(self):
  590. # test acc of sum in the backward is upcasted to float
  591. t = Tensor([5, -5], dtype=dtypes.half, requires_grad=True)
  592. t.reshape(2, 1).expand(2, 10001).max().backward()
  593. np.testing.assert_allclose(t.grad.numpy(), [1, 0])
  594. @unittest.skipIf(Device.DEFAULT=="PYTHON", "very slow")
  595. @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  596. def test_mean_half_precision_underflow(self):
  597. N = 10000
  598. x = 0.001
  599. t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
  600. np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
  601. @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  602. def test_mean_half_precision_overflow(self):
  603. N = 256
  604. t = Tensor([60000] * N*N, dtype=dtypes.half, requires_grad=True).reshape(N, N)
  605. np.testing.assert_allclose(t.mean().numpy(), 60000)
  606. t.square().mean().backward()
  607. np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N)
  608. class TestImplicitFunctionTypeChange(unittest.TestCase):
  609. def test_functions(self):
  610. result = []
  611. for func in [
  612. lambda t: t.exp(),
  613. lambda t: t.exp2(),
  614. lambda t: t.log(),
  615. lambda t: t.log2(),
  616. lambda t: t.sqrt(),
  617. lambda t: t.sin(),
  618. ]:
  619. t = func(Tensor([4.0, 3.0])).max() == func(Tensor([4.0, 3.0]))
  620. result.append(t.numpy().sum())
  621. assert all(result)
  622. class TestTensorMethod(unittest.TestCase):
  623. @given(strat.sampled_from(core_dtypes))
  624. def test_abs_diff(self, dt):
  625. if dt == dtypes.bool or not is_dtype_supported(dt): return
  626. a, b = Tensor([2], dtype=dt), Tensor([1], dtype=dt)
  627. ret = (a - b).abs()
  628. np.testing.assert_allclose(ret.numpy(), np.abs(a.numpy()-b.numpy()))
  629. if __name__ == '__main__':
  630. unittest.main()