import unittest, math
from functools import partial
import numpy as np
import torch
from tinygrad import nn, dtypes, Tensor, Device, TinyJit
from tinygrad.helpers import THREEFRY, getenv, CI
from test.helpers import is_dtype_supported
from hypothesis import given, settings, strategies as strat

settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
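# The two helpers below implement a two-sample Kolmogorov-Smirnov test.
# ksprob converts a scaled KS statistic into an approximate p-value via the
# alternating series Q_KS(a) = 2 * sum_{j>=1} (-1)^(j-1) * exp(-2 * j^2 * a^2),
# truncated once the terms become negligible.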
# https://gist.github.com/devries/11405101
def ksprob(a):
  fac, total, termbf = 2.0, 0.0, 0.0
  a2 = -2.0 * a * a
  for j in range(1, 101):
    term = fac * math.exp(a2 * j * j)
    total += term
    if math.fabs(term) <= 0.001 * termbf or math.fabs(term) <= 1e-8 * total:
      return total
    fac = -fac
    termbf = math.fabs(term)
  return 1.0
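# kstest computes the two-sample KS statistic D = max_x |F1(x) - F2(x)| over the
# merged sorted samples, then maps it to a p-value with the Numerical Recipes
# approximation ksprob((sqrt(ne) + 0.12 + 0.11/sqrt(ne)) * D), ne = n1*n2/(n1+n2).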
def kstest(l1, l2):
  n1, n2 = len(l1), len(l2)
  l1.sort()
  l2.sort()
  j1, j2, d, fn1, fn2 = 0, 0, 0.0, 0.0, 0.0
  while j1 < n1 and j2 < n2:
    d1, d2 = l1[j1], l2[j2]
    if d1 <= d2:
      fn1 = (float(j1) + 1.0) / float(n1)
      j1 += 1
    if d2 <= d1:
      fn2 = (float(j2) + 1.0) / float(n2)
      j2 += 1
    dtemp = math.fabs(fn2 - fn1)
    if dtemp > d:
      d = dtemp
  ne = float(n1 * n2) / float(n1 + n2)
  nesq = math.sqrt(ne)
  prob = ksprob((nesq + 0.12 + 0.11 / nesq) * d)
  return prob
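# equal_distribution seeds tinygrad, torch, and numpy identically, draws two samples
# from tiny_func (once with varargs, once with a shape tuple), and accepts only if
# every KS p-value against the reference samplers clears alpha, i.e. the samples
# are statistically indistinguishable at that significance level.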
def equal_distribution(tiny_func, torch_func=None, numpy_func=None, shape=(20, 23), alpha=0.04):
  Tensor.manual_seed(1337)
  torch.manual_seed(1337)
  np.random.seed(1337)
  assert not (torch_func is None and numpy_func is None), "no function to compare with"
  x1 = tiny_func(*shape).numpy().flatten()
  x2 = tiny_func(shape).numpy().flatten()
  if numpy_func is not None: y = numpy_func(shape).flatten()
  if torch_func is not None: z = torch_func(shape).numpy().flatten()
  return (numpy_func is None or (kstest(x1, y) >= alpha and kstest(x2, y) >= alpha)) and \
         (torch_func is None or (kstest(x1, z) >= alpha and kstest(x2, z) >= alpha))

def normal_test(func, shape=(20, 23), alpha=0.05): return equal_distribution(func, numpy_func=lambda x: np.random.randn(*x), shape=shape, alpha=alpha)
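# A quick sketch of how these helpers are used (the same calls appear in the tests below):
#   normal_test(Tensor.randn)  # True: randn should match np.random.randn
#   normal_test(Tensor.rand)   # False: rand is uniform on [0, 1), not normal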
class TestRandomness(unittest.TestCase):
  def test_rand(self):
    self.assertFalse(normal_test(Tensor.rand))
    self.assertTrue(equal_distribution(Tensor.rand, torch.rand, lambda x: np.random.rand(*x)))

  @unittest.skipIf(THREEFRY.value, "broken with threefry")
  def test_rand_half(self):
    N = 128
    x = Tensor.rand((2, N, N), dtype=dtypes.half)
    assert x.dtype == dtypes.half
    x = x.numpy()
    ones = np.take(x, np.where(x == 1))
    zeros = np.take(x, np.where(x == 0))
    self.assertTrue(ones.size == 0)
    self.assertTrue(zeros.size > 0)
    equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
  @unittest.skipIf(not THREEFRY.value, "not using threefry")
  def test_threefry_against_reference(self):
    Tensor.manual_seed(1337)
    # generated using
    # (jax.extend.random.threefry_2x32((np.uint32(1337), np.uint32(0x0)), np.arange(20, dtype=np.uint32)) >> 8).astype(float) / np.float32(2**24)
    jr = np.array([0.30984968, 0.42723763, 0.92448753, 0.27268296, 0.48820806, 0.29587173, 0.3213513, 0.05805135, 0.4954177, 0.23303074,
                   0.62478125, 0.51861334, 0.24712527, 0.12718695, 0.5236074, 0.50704265, 0.9166272, 0.6918763, 0.6530086, 0.34640658])
    r = Tensor.rand(20).numpy()
    np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
  @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16 support")
  def test_rand_bfloat16(self):
    N = 128
    x = Tensor.rand((2, N, N), dtype=dtypes.bfloat16)
    assert x.dtype == dtypes.bfloat16
    # TODO: fix this property for bfloat16 random
    # x = x.numpy()
    # ones = np.take(x, np.where(x == 1))
    # zeros = np.take(x, np.where(x == 0))
    # self.assertTrue(ones.size == 0)
    # self.assertTrue(zeros.size > 0)
    equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
  def test_randn(self):
    self.assertTrue(normal_test(Tensor.randn))
    self.assertTrue(equal_distribution(Tensor.randn, torch.randn, lambda x: np.random.randn(*x)))

  @given(strat.sampled_from([dtypes.float, dtypes.float16, dtypes.bfloat16]))
  @unittest.skipIf(Device.DEFAULT in ["HSA", "AMD"], "bfloat16 local buffer broken in HSA")
  def test_randn_finite(self, default_float):
    if not is_dtype_supported(default_float): return
    old_default_float = dtypes.default_float
    # low precision can result in inf from randn
    dtypes.default_float = default_float
    t = Tensor.randn(1024, 1024)
    mx = t.max().numpy().item()
    mn = t.min().numpy().item()
    print(f"testing with {default_float=}")
    assert math.isfinite(mx), mx
    assert math.isfinite(mn), mn
    dtypes.default_float = old_default_float
  def test_randint(self):
    self.assertFalse(normal_test(Tensor.randint))
    self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
    self.assertTrue(Tensor.randint(1, device="CLANG").device=="CLANG")
    # check types of args
    with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0.1, high=3)
    with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3.5)
    with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3, dtype=dtypes.float32)
  def test_normal(self):
    self.assertTrue(normal_test(Tensor.normal))
    self.assertTrue(equal_distribution(Tensor.normal, lambda x: torch.nn.init.normal_(torch.empty(x), mean=0, std=1),
                                       lambda x: np.random.normal(loc=0, scale=1, size=x)))

  def test_uniform(self):
    self.assertFalse(normal_test(Tensor.uniform))
    self.assertTrue(equal_distribution(Tensor.uniform, lambda x: torch.nn.init.uniform_(torch.empty(x)), lambda x: np.random.uniform(size=x)))
    self.assertTrue(equal_distribution(partial(Tensor.uniform, low=-100, high=100, dtype=dtypes.int32),
                                       numpy_func=lambda x: np.random.randint(low=-100, high=100, size=x)))

  def test_scaled_uniform(self):
    self.assertFalse(normal_test(Tensor.scaled_uniform))
    self.assertTrue(equal_distribution(Tensor.scaled_uniform, lambda x: torch.nn.init.uniform_(torch.empty(x), a=-1, b=1) / math.sqrt(math.prod(x)),
                                       lambda x: np.random.uniform(-1, 1, size=x) / math.sqrt(math.prod(x))))

  def test_glorot_uniform(self):
    self.assertFalse(normal_test(Tensor.glorot_uniform))
    self.assertTrue(equal_distribution(Tensor.glorot_uniform, lambda x: torch.nn.init.xavier_uniform_(torch.empty(x)),
                                       lambda x: np.random.uniform(-1, 1, size=x) * math.sqrt(6 / (x[0] + math.prod(x[1:])))))

  def test_kaiming_uniform(self):
    for shape in [(128, 64, 3, 3), (20, 24)]:
      self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape))

  def test_kaiming_normal(self):
    for shape in [(128, 64, 3, 3), (20, 24)]:
      self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape))
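  # The multinomial tests compare the empirical distribution of tinygrad's sampled
  # indices against torch.multinomial row by row for the same weight vectors.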
  def test_multinomial(self):
    self.assertRaises(AssertionError, lambda: Tensor(2).multinomial(1, replacement=False))
    self.assertRaises(AssertionError, lambda: Tensor([1, 9]).multinomial(0, replacement=False))
    def _check_with_torch(w, num_samples, replacement):
      tiny_res = Tensor(w).multinomial(num_samples, replacement=replacement)
      torch_res = torch.tensor(w).multinomial(num_samples, replacement=replacement)
      self.assertEqual(tiny_res.shape, torch_res.shape)
      if torch_res.ndim == 1:
        tiny_res = tiny_res.unsqueeze(0)
        torch_res = torch_res.unsqueeze(0)
      for i in range(torch_res.shape[0]):
        self.assertTrue(equal_distribution(lambda *_: tiny_res[i], lambda _: torch_res[i]))
    _check_with_torch(w=[0.231, 0., 1., 0.5], num_samples=2000, replacement=True)
    _check_with_torch(w=[[0.2, 0.8]], num_samples=2000, replacement=True)  # 2D but only 1 row
    _check_with_torch(w=[[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]], num_samples=2000, replacement=True)
    # no-replacement isn't supported, unless taking only one sample
    w = [0.1, 0.9]
    self.assertRaises(AssertionError, lambda: Tensor(w).multinomial(100, replacement=False))
    @TinyJit
    def sample_one(): return Tensor(w).multinomial(1, replacement=False).realize()
    # TODO: fix mockgpu issue
    if not (CI and Device.DEFAULT == "AMD"):
      tiny_samples = [sample_one().item() for _ in range(1000)]
      torch_samples = [torch.tensor(w).multinomial(1, replacement=False).item() for _ in range(1000)]
      self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_samples), lambda _: torch.tensor(torch_samples)))

  def test_multinomial_counterexample(self):
    tiny_res = Tensor([0.3, 0.6, 0.1]).multinomial(2000, replacement=True)
    torch_res = torch.tensor([0.3, 0.6, 0.1]).multinomial(2000, replacement=True)
    self.assertTrue(equal_distribution(lambda *_: tiny_res, lambda _: torch_res))
    torch_res = torch.tensor([0.2, 0.7, 0.1]).multinomial(2000, replacement=True)
    self.assertFalse(equal_distribution(lambda *_: tiny_res, lambda _: torch_res))
  def test_conv2d_init(self):
    params = (128, 256, (3,3))
    assert equal_distribution(lambda *_: nn.Conv2d(*params).weight, lambda _: torch.nn.Conv2d(*params).weight.detach())
    assert equal_distribution(lambda *_: nn.Conv2d(*params).bias, lambda _: torch.nn.Conv2d(*params).bias.detach())

  def test_linear_init(self):
    params = (64, 64)
    assert equal_distribution(lambda *_: nn.Linear(*params).weight, lambda _: torch.nn.Linear(*params).weight.detach())
    assert equal_distribution(lambda *_: nn.Linear(*params).bias, lambda _: torch.nn.Linear(*params).bias.detach())

  def test_bn_init(self):
    params = (64,)
    assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).weight, lambda _: torch.nn.BatchNorm2d(*params).weight.detach())
    assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).bias, lambda _: torch.nn.BatchNorm2d(*params).bias.detach())
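# This file is runnable directly as a script: unittest.main() below runs TestRandomness.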
if __name__ == "__main__":
  unittest.main()