# test_jit.py — TinyJit unit tests
  1. #!/usr/bin/env python
  2. import unittest, functools
  3. import numpy as np
  4. from hypothesis import given, settings, strategies as strat
  5. from test.helpers import assert_jit_cache_len
  6. from tinygrad.tensor import Tensor
  7. from tinygrad.engine.jit import TinyJit
  8. from tinygrad.device import Device
  9. from tinygrad.helpers import CI, Context
  10. from tinygrad.dtype import dtypes
  11. from extra.models.unet import ResBlock
  12. def _simple_test(add, extract=lambda x: x, N=10):
  13. for _ in range(5):
  14. a = Tensor.randn(N, N)
  15. b = Tensor.randn(N, N)
  16. c = add(a, b)
  17. np.testing.assert_allclose(extract(c).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
  18. assert_jit_cache_len(add, 1)
  19. class TestJit(unittest.TestCase):
  20. @settings(deadline=2e4)
  21. @unittest.skipUnless(Device.DEFAULT in ["LLVM", "CLANG"], f"no support on {Device.DEFAULT}")
  22. @given(strat.sampled_from([Tensor.exp2, Tensor.log2, Tensor.sin]))
  23. def test_approx_jit_timeout(self, op):
  24. with Context(TRANSCENDENTAL=2):
  25. model = [ResBlock(16, 24, 16) for _ in range(4)]
  26. @TinyJit
  27. def fw_approx(t, t2):
  28. for l in model: t = l(t, t2)
  29. return op(t).realize()
  30. fw_approx(Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24))
  31. def test_simple_jit(self):
  32. @TinyJit
  33. def add(a, b): return (a+b).realize()
  34. _simple_test(add)
  35. def test_simple_jit_reset(self):
  36. @TinyJit
  37. def add(a, b): return (a+b).realize()
  38. _simple_test(add)
  39. add.reset()
  40. _simple_test(add, N=20)
  41. def test_simple_jit_norealize(self):
  42. @TinyJit
  43. def add(a, b): return (a+b)
  44. _simple_test(add)
  45. def test_simple_jit_norealize_list(self):
  46. @TinyJit
  47. def add(a, b): return [a+b]
  48. _simple_test(add, extract=lambda x: x[0])
  49. def test_simple_jit_norealize_dict(self):
  50. @TinyJit
  51. def add(a, b): return {"billy": a+b}
  52. _simple_test(add, extract=lambda x: x["billy"])
  53. def test_jit_multiple_outputs(self):
  54. @TinyJit
  55. def f(a, b): return (a+b).realize(), (a-b).realize(), (a*b).realize()
  56. for _ in range(5):
  57. a = Tensor.randn(10, 10)
  58. b = Tensor.randn(10, 10)
  59. c, d, e = f(a, b)
  60. np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
  61. np.testing.assert_allclose(d.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5)
  62. np.testing.assert_allclose(e.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5)
  63. assert_jit_cache_len(f, 3)
  64. def test_nothing_jitted(self):
  65. @TinyJit
  66. def add(a, b): return None
  67. with self.assertRaises(AssertionError):
  68. for _ in range(5):
  69. a = Tensor.randn(10, 10)
  70. b = Tensor.randn(10, 10)
  71. add(a, b)
  72. def test_jit_zero_does_not_jit(self):
  73. @TinyJit
  74. def add(a, b): return (a+b).realize()
  75. with Context(JIT=0):
  76. for i in range(5):
  77. a = Tensor([i])
  78. b = Tensor([i])
  79. c = add(a, b)
  80. np.testing.assert_allclose(c.numpy(), 2*i)
  81. assert_jit_cache_len(add, 0)
  82. def test_jit_not_capturing(self):
  83. @TinyJit
  84. def add(a, b):
  85. Tensor.zeros(4, 4).contiguous().realize() # no-op kernel is captured
  86. return (a+b).realize()
  87. for i in range(5):
  88. a = Tensor([i])
  89. b = Tensor([i])
  90. c = add(a, b)
  91. np.testing.assert_allclose(c.numpy(), 2*i)
  92. assert_jit_cache_len(add, 2)
  93. @TinyJit
  94. def add2(a, b):
  95. with Context(CAPTURING=0): # not captured
  96. Tensor.zeros(4, 4).contiguous().realize()
  97. return (a+b).realize()
  98. for i in range(5):
  99. a = Tensor([i])
  100. b = Tensor([i])
  101. c = add2(a, b)
  102. np.testing.assert_allclose(c.numpy(), 2*i)
  103. assert_jit_cache_len(add2, 1)
  104. def test_jit_shape_mismatch(self):
  105. @TinyJit
  106. def add(a, b): return (a+b).realize()
  107. for _ in range(5):
  108. a = Tensor.randn(10, 10)
  109. b = Tensor.randn(10, 10)
  110. add(a, b)
  111. bad = Tensor.randn(20, 20)
  112. with self.assertRaises(AssertionError):
  113. add(a, bad)
  114. def test_jit_shape_views_mismatch(self):
  115. @TinyJit
  116. def add(a): return (a+1).realize()
  117. with self.assertRaises(AssertionError):
  118. for i in range(1,5):
  119. # a has an offset that the kernel doesn't know about
  120. a = Tensor.randn(10, 10).realize()[:, i:i+2]
  121. add(a)
  122. def test_jit_duplicate_fail(self):
  123. # the jit doesn't support duplicate arguments
  124. @TinyJit
  125. def add(a, b): return (a+b).realize()
  126. a = Tensor.randn(10, 10)
  127. with self.assertRaises(AssertionError):
  128. add(a, a)
  129. def test_kwargs_jit(self):
  130. @TinyJit
  131. def add_kwargs(first, second): return (first+second).realize()
  132. for _ in range(5):
  133. a = Tensor.randn(10, 10)
  134. b = Tensor.randn(10, 10)
  135. c = add_kwargs(first=a, second=b)
  136. np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
  137. assert_jit_cache_len(add_kwargs, 1)
  138. def test_reorder_kwargs_jit(self):
  139. @TinyJit
  140. def add_kwargs(first, second): return (first/second).realize()
  141. for _ in range(2):
  142. a = Tensor.randn(10, 10)
  143. b = Tensor.randn(10, 10)
  144. c = add_kwargs(second=b, first=a)
  145. np.testing.assert_allclose(c.numpy(), a.numpy()/b.numpy(), atol=1e-4, rtol=1e-5)
  146. for _ in range(2):
  147. a = Tensor.randn(10, 10)
  148. b = Tensor.randn(10, 10)
  149. c = add_kwargs(first=a, second=b)
  150. np.testing.assert_allclose(c.numpy(), a.numpy()/b.numpy(), atol=1e-4, rtol=1e-5)
  151. assert_jit_cache_len(add_kwargs, 1)
  152. def test_array_jit(self):
  153. @TinyJit
  154. def add_array(a, arr): return (a+arr[0]).realize()
  155. for i in range(5):
  156. a = Tensor.randn(10, 10)
  157. b = Tensor.randn(10, 10)
  158. a.realize(), b.realize()
  159. c = add_array(a, [b])
  160. if i >= 2:
  161. # should fail once jitted since jit can't handle arrays
  162. np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
  163. else:
  164. np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
  165. assert_jit_cache_len(add_array, 1)
  166. def test_jit_copyin(self):
  167. @TinyJit
  168. def f(a):
  169. return a + Tensor([1,2,3])
  170. for _ in range(5):
  171. b = Tensor.randn(3)
  172. c = f(b)
  173. np.testing.assert_allclose(c.numpy(), b.numpy()+[1,2,3], atol=1e-4, rtol=1e-5)
  174. def test_method_jit(self):
  175. class Fun:
  176. def __init__(self):
  177. self.a = Tensor.randn(10, 10)
  178. @TinyJit
  179. def __call__(self, b:Tensor) -> Tensor:
  180. return (self.a+b).realize()
  181. fun = Fun()
  182. for _ in range(5):
  183. b = Tensor.randn(10, 10)
  184. c = fun(b)
  185. np.testing.assert_allclose(c.numpy(), fun.a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
  186. assert_jit_cache_len(fun.__call__.func.__self__, 1)
  187. def test_jit_size1_input(self):
  188. @TinyJit
  189. def f(a, b): return (a+b).realize()
  190. a = Tensor([1, 2, 3])
  191. for i in range(5):
  192. np.testing.assert_allclose(f(a, Tensor([i])).numpy(), (a+i).numpy(), atol=1e-4, rtol=1e-5)
  193. assert_jit_cache_len(f, 1)
  194. def test_jit_output_non_tensor_fail(self):
  195. @TinyJit
  196. def f(a, b, i): return (a+b).realize(), i
  197. output1, output2 = [], []
  198. expect1, expect2 = [], []
  199. for i in range(5):
  200. a = Tensor.randn(10, 10)
  201. b = Tensor.randn(10, 10)
  202. o1, o2 = f(a, b, i)
  203. output1.append(o1.numpy().copy())
  204. output2.append(o2)
  205. expect1.append(a.numpy().copy()+b.numpy().copy())
  206. expect2.append(i)
  207. np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5)
  208. # the jit only works with Tensor outputs
  209. assert output2 != expect2
  210. assert_jit_cache_len(f, 1)
  211. def test_jit_random_regen(self):
  212. def f(a, b):
  213. rn = Tensor.randn(*a.shape)
  214. return ((a+b)*rn).realize()
  215. a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed
  216. b = Tensor.randn(10, 10).realize()
  217. Tensor.manual_seed(1234)
  218. jf = TinyJit(f)
  219. res = set()
  220. for _ in range(5):
  221. o1 = jf(a, b)
  222. res.add(o1.numpy()[0][0])
  223. assert len(res) == 5, "All values should be different, rand works in jit."
  224. Tensor.manual_seed(1234)
  225. jf2 = TinyJit(f)
  226. res2 = set()
  227. for _ in range(5):
  228. o1 = jf2(a, b)
  229. res2.add(o1.numpy()[0][0])
  230. assert len(res2) == 5, "All values should be different, rand works in jit."
  231. assert res == res2, "Jit rand is not reproducible with the same seed"
  232. Tensor.manual_seed(3421)
  233. jf3 = TinyJit(f)
  234. res3 = set()
  235. for _ in range(5):
  236. o1 = jf3(a, b)
  237. res3.add(o1.numpy()[0][0])
  238. assert len(res3) == 5, "All values should be different, rand works in jit."
  239. assert res3 != res2, "Jit rand is diff with diff seeds"
  240. def test_jit_realization_and_sampling(self):
  241. w = Tensor.eye(5)
  242. @TinyJit
  243. def foo (x): return w.dot(x).realize()
  244. arg = [
  245. Tensor([1,2,3,4,5]),
  246. Tensor([1,3,3,4,6]),
  247. Tensor([1,2,5,4,7]),
  248. Tensor([0,2,3,1,0]),
  249. ]
  250. Y = [foo(e).numpy() for e in arg]
  251. foo(Tensor([7,7,7,7,7]))
  252. want = [[1., 2., 3., 4., 5.],
  253. [1., 3., 3., 4., 6.],
  254. [1., 2., 5., 4., 7.],
  255. [0., 2., 3., 1., 0.]]
  256. np.testing.assert_allclose(want, Y)
  257. def test_jit_buffer_behavior(self):
  258. @TinyJit
  259. def foo(x) -> Tensor: return x.sum().realize()
  260. result_1 = foo(Tensor([1] * 2))
  261. result_2 = foo(Tensor([2] * 2))
  262. result_3 = foo(Tensor([3] * 2))
  263. # expect the buffer to share underlying buffer
  264. np.testing.assert_allclose(result_1.numpy(), [2], atol=1e-4, rtol=1e-5)
  265. np.testing.assert_allclose(result_2.numpy(), [6], atol=1e-4, rtol=1e-5)
  266. np.testing.assert_allclose(result_3.numpy(), [6], atol=1e-4, rtol=1e-5)
  267. @unittest.skipIf(CI and Device.DEFAULT=="METAL", "no ICB in CI, creation of graph fails")
  268. def test_jit_batch_split(self):
  269. if Device[Device.DEFAULT].graph is None: raise unittest.SkipTest("only test graphs")
  270. # Create long jit with 83 kernels.
  271. def f(a, b, c, d, e):
  272. for _ in range(80):
  273. a = (a+b).realize()
  274. y = (a*c).realize()
  275. z = (y*d).realize()
  276. w = (z*e)
  277. return w.realize()
  278. a = Tensor.randn(10, 10).realize()
  279. b = Tensor.randn(10, 10).realize()
  280. c = Tensor.randn(10, 10).realize()
  281. d = Tensor.randn(10, 10).realize()
  282. e = Tensor.randn(10, 10).realize()
  283. jf = TinyJit(f)
  284. prev = None
  285. for _ in range(5):
  286. o = jf(a, b, c, d, e).numpy()
  287. if prev is not None: np.testing.assert_allclose(o, prev, atol=1e-4, rtol=1e-5)
  288. prev = o
  289. graph_t = Device[Device.DEFAULT].graph.func if isinstance(Device[Device.DEFAULT].graph, functools.partial) else Device[Device.DEFAULT].graph
  290. # Checking that 2 graphs are inited.
  291. assert isinstance(jf.jit_cache[0].prg, graph_t)
  292. assert isinstance(jf.jit_cache[1].prg, graph_t)
  293. def test_jit_const_inputs(self):
  294. @TinyJit
  295. def g(x,y,z): return (x+y+z).realize()
  296. for i in range(5):
  297. np.testing.assert_equal(g(Tensor([i]*3), Tensor.ones(3), Tensor.zeros(3)).numpy(), np.array([i+1]*3))
  298. @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL", "NV", "AMD"}, "no GPU CI")
  299. def test_jitted_transfers(self):
  300. d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
  301. def f(a, b):
  302. x = a.to(d1)
  303. y = b.to(d1)
  304. return x.realize(), y.realize()
  305. jf = TinyJit(f)
  306. for _ in range(5):
  307. a = Tensor.randn(10, 10, device=d0).realize()
  308. b = Tensor.randn(10, 10, device=d0).realize()
  309. xc, yc = jf(a, b)
  310. np.testing.assert_allclose(a.numpy(), xc.numpy(), atol=1e-4, rtol=1e-5)
  311. np.testing.assert_allclose(b.numpy(), yc.numpy(), atol=1e-4, rtol=1e-5)
  312. @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU/CUDA/METAL in CI, fine to run on AMD/NV")
  313. def test_jitted_view(self):
  314. d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
  315. def f(a):
  316. x1 = a.sum(axis=(1,))
  317. x = (x1 + 5).bitcast(dtypes.int32)
  318. y = x.to(d1)
  319. return y.realize()
  320. jf = TinyJit(f)
  321. for _ in range(5):
  322. a = Tensor.randn(10, 1000, device=d0).realize()
  323. xc = jf(a)
  324. np.testing.assert_allclose((a.numpy().sum(axis=(1,)) + 5).view(np.int32), xc.numpy(), atol=1e-4, rtol=1e-5)
  325. @unittest.skip("Pending multioutput implementation #3607")
  326. class TestMultioutputJit(unittest.TestCase):
  327. def _test(self, f):
  328. for _ in range(5):
  329. a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
  330. out0, out1, out2 = f(a, b)
  331. np.testing.assert_allclose(out0.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
  332. np.testing.assert_allclose(out1.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5)
  333. np.testing.assert_allclose(out2.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5)
  334. def test_jit_multioutput_realize(self):
  335. @TinyJit
  336. def fxn(a, b): return (a+b).realize(), (a-b).realize(), (a*b).realize()
  337. self._test(fxn)
  338. assert_jit_cache_len(fxn, 3)
  339. def test_jit_multioutput_norealize(self):
  340. @TinyJit
  341. def fxn(a, b): return a+b, a-b, a*b
  342. self._test(fxn)
  343. assert_jit_cache_len(fxn, 1)
  344. def test_jit_multioutput_mix(self):
  345. @TinyJit
  346. def fxn(a, b): return a+b, a-b, (a*b).realize()
  347. self._test(fxn)
  348. assert_jit_cache_len(fxn, 2)
  349. class TestJitInsideJit(unittest.TestCase):
  350. def test_jit_jit_error(self):
  351. @TinyJit
  352. def f(t): return t + 1
  353. @TinyJit
  354. def g(t): return f(t) * 3
  355. # NOTE: first does not raise
  356. g(Tensor([1])).realize()
  357. with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
  358. g(Tensor([1])).realize()
  359. if __name__ == '__main__':
  360. unittest.main()