test_symbolic_jit.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. import unittest
  2. from test.helpers import assert_jit_cache_len
  3. from tinygrad.engine.jit import TinyJit
  4. from tinygrad.shape.symbolic import Variable
  5. from tinygrad.tensor import Tensor
  6. import numpy as np
  7. class TestSymbolicJit(unittest.TestCase):
  8. def test_plus1(self):
  9. def f(a): return (a+1).realize()
  10. jf = TinyJit(f)
  11. for i in range(1, 5):
  12. vi = Variable("i", 1, 10).bind(i)
  13. a = Tensor.rand(3, i)
  14. symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy()
  15. expected = f(a).numpy()
  16. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  17. assert_jit_cache_len(jf, 1)
  18. def test_add(self):
  19. def f(a, b): return (a+b).realize()
  20. jf = TinyJit(f)
  21. for i in range(1, 5):
  22. vi = Variable("i", 1, 10).bind(i)
  23. a = Tensor.rand(3, i)
  24. b = Tensor.rand(3, i)
  25. symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy()
  26. expected = f(a, b).numpy()
  27. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  28. assert_jit_cache_len(jf, 1)
  29. def test_matmul(self):
  30. def f(a, b): return (a@b).realize()
  31. jf = TinyJit(f)
  32. for i in range(1, 5):
  33. vi = Variable("i", 1, 10).bind(i)
  34. a = Tensor.rand(3, i)
  35. b = Tensor.rand(i, 5)
  36. symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
  37. expected = f(a, b).numpy()
  38. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  39. assert_jit_cache_len(jf, 1)
  40. def test_mixed_with_no_symbol_kernel(self):
  41. def f(a, b):
  42. s = (a@b).realize()
  43. s = (s+s).realize() # this one does not have symbols in input
  44. return s
  45. jf = TinyJit(f)
  46. for i in range(1, 5):
  47. vi = Variable("i", 1, 10).bind(i)
  48. a = Tensor.rand(3, i)
  49. b = Tensor.rand(i, 5)
  50. symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
  51. expected = f(a, b).numpy()
  52. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  53. assert_jit_cache_len(jf, 2)
  54. def test_attention(self):
  55. def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
  56. jf = TinyJit(f)
  57. for i in range(1, 5):
  58. vi = Variable("i", 1, 10).bind(i)
  59. q = Tensor.rand(2, 1, 4, 8)
  60. k = Tensor.rand(2, i, 4, 8)
  61. v = Tensor.rand(2, i, 4, 8)
  62. symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
  63. expected = f(q, k, v).numpy()
  64. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  65. assert_jit_cache_len(jf, 5)
  66. def test_cat_dim0(self):
  67. def f(a, b): return a.cat(b, dim=0).realize()
  68. jf = TinyJit(f)
  69. for i in range(1, 5):
  70. vi = Variable("i", 1, 10).bind(i)
  71. a = Tensor.rand(i, 3)
  72. b = Tensor.rand(2, 3)
  73. symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).numpy()
  74. expected = f(a, b).numpy()
  75. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  76. assert_jit_cache_len(jf, 1)
  77. def test_cat_dim1(self):
  78. def f(a, b): return a.cat(b, dim=1).realize()
  79. jf = TinyJit(f)
  80. for i in range(1, 5):
  81. vi = Variable("i", 1, 10).bind(i)
  82. a = Tensor.rand(3, i)
  83. b = Tensor.rand(3, 2)
  84. symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).numpy()
  85. expected = f(a, b).numpy()
  86. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  87. assert_jit_cache_len(jf, 1)
  88. def test_cat_dim0_two_vars(self):
  89. def f(a, b): return a.cat(b, dim=0).realize()
  90. jf = TinyJit(f)
  91. for i in range(1, 5):
  92. for j in range(1, 5):
  93. vi = Variable("i", 1, 10).bind(i)
  94. vj = Variable("j", 1, 10).bind(j)
  95. a = Tensor.rand(i, 3)
  96. b = Tensor.rand(j, 3)
  97. symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()
  98. expected = f(a, b).numpy()
  99. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  100. assert_jit_cache_len(jf, 1)
  101. def test_cat_dim1_two_vars(self):
  102. def f(a, b): return a.cat(b, dim=1).realize()
  103. jf = TinyJit(f)
  104. for i in range(1, 5):
  105. for j in range(1, 5):
  106. vi = Variable("i", 1, 10).bind(i)
  107. vj = Variable("j", 1, 10).bind(j)
  108. a = Tensor.rand(3, i)
  109. b = Tensor.rand(3, j)
  110. symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()
  111. expected = f(a, b).numpy()
  112. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  113. assert_jit_cache_len(jf, 1)
  114. def test_two_vars_plus1_ij(self):
  115. def f(a, b): return (a@b+1).realize()
  116. jf = TinyJit(f)
  117. for i in range(1, 5):
  118. for j in range(1, 5):
  119. vi = Variable("i", 1, 10).bind(i)
  120. vj = Variable("j", 1, 10).bind(j)
  121. a = Tensor.rand(i, 3)
  122. b = Tensor.rand(3, j)
  123. symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()
  124. expected = f(a, b).numpy()
  125. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  126. assert_jit_cache_len(jf, 1)
  127. def test_two_vars_plus1_ji(self):
  128. def f(a, b): return (a@b+1).realize()
  129. jf = TinyJit(f)
  130. for i in range(1, 5):
  131. for j in range(1, 5):
  132. vi = Variable("i", 1, 10).bind(i)
  133. vj = Variable("j", 1, 10).bind(j)
  134. a = Tensor.rand(j, 3)
  135. b = Tensor.rand(3, i)
  136. symbolic = jf(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy()
  137. expected = f(a, b).numpy()
  138. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  139. assert_jit_cache_len(jf, 1)
  140. def test_jit_symbolic_shape_mismatch(self):
  141. @TinyJit
  142. def add(a, b): return (a+b).realize()
  143. for i in range(1, 5):
  144. vi = Variable("i", 1, 10).bind(i)
  145. a = Tensor.rand(3, i).reshape(3, vi)
  146. b = Tensor.rand(3, i).reshape(3, vi)
  147. add(a, b)
  148. vi2 = Variable("i", 1, 10).bind(7)
  149. a = Tensor.rand(3, 7).reshape(3, vi2)
  150. bad = Tensor.rand(4, 7).reshape(4, vi2)
  151. with self.assertRaises(AssertionError):
  152. add(a, bad)
  153. def test_shrink(self):
  154. # shrink is a movement, so we pair it with a simple function to test the JIT interaction
  155. def f(a): return (a+1).realize()
  156. jf = TinyJit(f)
  157. for i in range(1, 5):
  158. vi = Variable("i", 1, 10).bind(i)
  159. a = Tensor.rand(7, 11)
  160. symbolic = a.shrink(((3,5),(vi,vi+2)))
  161. symbolic = jf(symbolic).numpy()
  162. expected = f(a.shrink(((3,5),(i,i+2)))).numpy()
  163. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  164. assert_jit_cache_len(jf, 1)
  165. def test_ones_sum(self):
  166. def f(a): return a.sum().realize()
  167. jf = TinyJit(f)
  168. for i in range(1, 5):
  169. vi = Variable("i", 1, 10).bind(i)
  170. t = Tensor.ones(i)
  171. symbolic = jf(t.reshape(vi)).item()
  172. expected = f(t).item()
  173. np.testing.assert_equal(symbolic, expected)
  174. def test_mean(self):
  175. def f(a): return a.mean().realize()
  176. def f0(a): return a.mean(0).realize()
  177. def f1(a): return a.mean(1).realize()
  178. jf = TinyJit(f)
  179. jf0 = TinyJit(f0)
  180. jf1 = TinyJit(f1)
  181. for i in range(1, 5):
  182. vi = Variable("i", 1, 10).bind(i)
  183. # aixs = None
  184. a = Tensor.rand(i, 3)
  185. symbolic = jf(a.reshape(vi, 3)).numpy()
  186. expected = a.mean().numpy()
  187. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  188. # aixs = 0
  189. a = Tensor.rand(i, 3)
  190. symbolic = jf0(a.reshape(vi, 3)).numpy()
  191. expected = a.mean(0).numpy()
  192. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  193. # aixs = 1
  194. a = Tensor.rand(i, 3)
  195. symbolic = jf1(a.reshape(vi, 3)).reshape(i).numpy()
  196. expected = a.mean(1).numpy()
  197. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  198. def test_mean_2d(self):
  199. def f(a): return a.mean().realize()
  200. def f0(a): return a.mean(0).realize()
  201. def f1(a): return a.mean(1).realize()
  202. jf = TinyJit(f)
  203. jf0 = TinyJit(f0)
  204. jf1 = TinyJit(f1)
  205. for i in range(1, 5):
  206. for j in range(1, 5):
  207. vi = Variable("i", 1, 10).bind(i)
  208. vj = Variable("j", 1, 10).bind(j)
  209. # aixs = None
  210. a = Tensor.rand(i, j)
  211. symbolic = jf(a.reshape(vi, vj)).numpy()
  212. expected = a.mean().numpy()
  213. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  214. # aixs = 0
  215. a = Tensor.rand(i, j)
  216. symbolic = jf0(a.reshape(vi, vj)).reshape(j).numpy()
  217. expected = a.mean(0).numpy()
  218. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  219. # aixs = 1
  220. a = Tensor.rand(i, j)
  221. symbolic = jf1(a.reshape(vi, vj)).reshape(i).numpy()
  222. expected = a.mean(1).numpy()
  223. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  224. def test_var(self):
  225. def f(a): return a.var().realize()
  226. def f0(a): return a.var(0).realize()
  227. def f1(a): return a.var(1).realize()
  228. jf = TinyJit(f)
  229. jf0 = TinyJit(f0)
  230. jf1 = TinyJit(f1)
  231. for i in range(1, 5):
  232. vi = Variable("i", 1, 10).bind(i)
  233. # aixs = None
  234. a = Tensor.rand(i, 3)
  235. symbolic = jf(a.reshape(vi, 3)).numpy()
  236. expected = a.var().numpy()
  237. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  238. # aixs = 0
  239. a = Tensor.rand(i, 3)
  240. symbolic = jf0(a.reshape(vi, 3)).numpy()
  241. expected = a.var(0).numpy()
  242. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  243. # aixs = 1
  244. a = Tensor.rand(i, 3)
  245. symbolic = jf1(a.reshape(vi, 3)).reshape(i).numpy()
  246. expected = a.var(1).numpy()
  247. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  248. def test_var_2d(self):
  249. def f(a): return a.var().realize()
  250. def f0(a): return a.var(0).realize()
  251. def f1(a): return a.var(1).realize()
  252. jf = TinyJit(f)
  253. jf0 = TinyJit(f0)
  254. jf1 = TinyJit(f1)
  255. for i in range(1, 5):
  256. for j in range(1, 5):
  257. vi = Variable("i", 1, 10).bind(i)
  258. vj = Variable("j", 1, 10).bind(j)
  259. # aixs = None
  260. a = Tensor.rand(i, j)
  261. symbolic = jf(a.reshape(vi, vj)).numpy()
  262. expected = a.var().numpy()
  263. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  264. # aixs = 0
  265. a = Tensor.rand(i, j)
  266. symbolic = jf0(a.reshape(vi, vj)).reshape(j).numpy()
  267. expected = a.var(0).numpy()
  268. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  269. # aixs = 1
  270. a = Tensor.rand(i, j)
  271. symbolic = jf1(a.reshape(vi, vj)).reshape(i).numpy()
  272. expected = a.var(1).numpy()
  273. np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
  274. if __name__ == '__main__':
  275. unittest.main()