# test_symbolic_shapetracker.py (9.1 KB)

  1. import unittest
  2. from tinygrad.shape.shapetracker import ShapeTracker, View
  3. from tinygrad.shape.symbolic import Variable, NumNode
  4. from tinygrad.tensor import Tensor
  5. class TestSymbolic(unittest.TestCase):
  6. def test_symbolic_st(self):
  7. x = Variable("x", 1, 100)
  8. st = ShapeTracker.from_shape((x, 3))
  9. assert st.shape == (x, 3)
  10. assert st.real_strides() == (3, 1)
  11. def test_expr_idxs(self):
  12. x = Variable("x", 1, 100)
  13. st = ShapeTracker.from_shape((x, 3))
  14. idxs = [Variable("x", 0, 100), Variable("y", 0, 100)]
  15. e1, e2 = st.expr_idxs(idxs)
  16. assert e1.render() == "((x*3)+y)"
  17. assert e2.render() == "1"
  18. st = st.permute((1, 0))
  19. e1, e2 = st.expr_idxs(idxs)
  20. assert e1.render() == "((y*3)+x)"
  21. assert e2.render() == "1"
  22. def test_cat_dim0_strides(self):
  23. i = Variable("i", 1, 5).bind(3)
  24. j = Variable("j", 1, 5).bind(3)
  25. k = Variable("k", 1, 5).bind(3)
  26. t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
  27. st = t.lazydata.st
  28. assert st.shape == (i+j+k, 4)
  29. assert st.real_strides() == (4, 1)
  30. t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
  31. st = t.lazydata.st
  32. assert st.shape == (2*i+3, 3)
  33. assert st.real_strides() == (3, 1)
  34. def test_cat_dim1_strides(self):
  35. i = Variable("i", 1, 5).bind(4)
  36. j = Variable("j", 1, 5).bind(4)
  37. k = Variable("k", 1, 5).bind(4)
  38. t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
  39. st = t.lazydata.st
  40. assert st.shape == (3, i+j+k)
  41. assert st.real_strides() == (i+j+k, 1)
  42. class TestSymbolicVarVals(unittest.TestCase):
  43. def test_var_vals_empty(self):
  44. assert ShapeTracker.from_shape((3, 4, 5)).var_vals == {}
  45. def test_var_vals_shape(self):
  46. x = Variable("x", 1, 100).bind(3)
  47. assert ShapeTracker.from_shape((x, 3)).var_vals == {Variable("x", 1, 100): 3}
  48. def test_var_vals_offset(self):
  49. x = Variable("x", 1, 100).bind(3)
  50. st = ShapeTracker.from_shape((4, 3)).shrink(((x, x+1), (0, 3)))
  51. assert st.views[-1].offset == x * 3
  52. assert st.var_vals == {Variable("x", 1, 100): 3}
  53. def test_var_vals_mask(self):
  54. x = Variable("x", 1, 100).bind(3)
  55. view = View.create(shape=(3,4), strides=(4,1), offset=0, mask=((0, x), (0, 4)))
  56. st = ShapeTracker(views=(view,))
  57. assert st.var_vals == {Variable("x", 1, 100): 3}
  58. def test_var_vals_complex(self):
  59. x = Variable("x", 1, 100).bind(3)
  60. y = Variable("y", 1, 100).bind(4)
  61. z = Variable("z", 1, 100).bind(5)
  62. st = ShapeTracker.from_shape((x, 5, y)).shrink(((0, x), (z, z+1), (0, 3)))
  63. assert st.views[-1].offset == y * z
  64. assert st.var_vals == {Variable("x", 1, 100): 3, Variable("y", 1, 100):4, Variable("z", 1, 100): 5}
  65. def test_shrink_reshape(self):
  66. x = Variable("x", 1, 100).bind(3)
  67. st = ShapeTracker.from_shape((10, 10, 10)).shrink(((x, x+3), (3, 7), (2, 5)))
  68. st = st.reshape((3*4*3,))
  69. assert st.var_vals == {Variable("x", 1, 100): 3}
  70. class TestShapeTrackerUnbind(unittest.TestCase):
  71. def test_view_unbind(self):
  72. v = Variable("v", 1, 100)
  73. bv = Variable("v", 1, 100).bind(3)
  74. unbound_view, var_val = View.create(shape=(bv, 4)).unbind()
  75. assert unbound_view == View.create(shape=(v, 4))
  76. assert var_val == {v: 3}
  77. def test_reshape_unbind(self):
  78. v = Variable("v", 1, 100)
  79. bv = Variable("v", 1, 100).bind(3)
  80. t = Tensor.rand(3, 4).reshape(bv, 4)
  81. unbound_st, var_val = t.lazydata.st.unbind()
  82. assert unbound_st == ShapeTracker((View.create(shape=(v, 4)),))
  83. assert var_val == {v: 3}
  84. def test_shrink_unbind(self):
  85. v = Variable("v", 1, 100)
  86. bv = Variable("v", 1, 100).bind(2)
  87. t = Tensor.rand(3, 4).shrink(((bv, bv+1), (0, 4)))
  88. unbound_st, var_val = t.lazydata.st.unbind()
  89. assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),))
  90. assert var_val == {v: 2}
  91. class TestSymbolicReshapeFromContiguous(unittest.TestCase):
  92. def test_reshape_into_symbols_simple(self):
  93. for i in range(1, 6):
  94. vi = Variable("i", 1, 5).bind(i)
  95. t = Tensor.rand(i, 4).reshape(vi, 4)
  96. assert t.shape == (vi, 4)
  97. t = Tensor.rand(i, 6).reshape(vi, 2, 3)
  98. assert t.shape == (vi, 2, 3)
  99. def test_reshape_symbols_reshape_ints(self):
  100. for i in range(1, 6):
  101. vi = Variable("i", 1, 5).bind(i)
  102. t = Tensor.rand(i, 4).reshape(vi, 4)
  103. assert t.shape == (vi, 4)
  104. t = t.reshape(i, 4)
  105. assert t.shape == (i, 4)
  106. def test_reshape_into_symbols_bad_shape(self):
  107. vi = Variable("i", 1, 10).bind(4)
  108. # TODO: this never actually worked, it relied on lazy
  109. #with self.assertRaises(ValueError):
  110. # Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape to a different size new shape through symbolic shape
  111. with self.assertRaises(AssertionError):
  112. Tensor.rand(3, 4).reshape(3, (vi+1)) # reshape into non-Variable Node
  113. def test_two_symbol_reshape(self):
  114. for i in range(1, 6):
  115. for j in range(1, 6):
  116. vi = Variable("i", 1, 5).bind(i)
  117. vj = Variable("j", 1, 5).bind(j)
  118. t = Tensor.rand(i, j).reshape(vi, vj)
  119. assert t.shape == (vi, vj)
  120. # NOTE: this is currently not allowed
  121. # t = t.reshape(1, vi*vj)
  122. # assert t.shape == (1, vi*vj)
  123. t = t.reshape(vj, vi)
  124. assert t.shape == (vj, vi)
  125. def test_symbolic_mask(self):
  126. # taken from gpt2 single kvcache
  127. # these two caused problems in gpt2 if reshape merged views
  128. view = View(shape=(1, (NumNode(1)+Variable('start_pos', 1, 128).bind(2)), 16, 64), strides=(0, 0, 64, 1), offset=NumNode(1024), mask=((0, 1), (Variable('start_pos', 1, 128).bind(2), (NumNode(1)+Variable('start_pos', 1, 128).bind(2))), (0, 16), (0, 64)), contiguous=False) # noqa: E501
  129. new_shape = (1, 1, (NumNode(1)+Variable('start_pos', 1, 128).bind(2)), 16, 64)
  130. assert view.reshape(new_shape) is None
  131. view = View(shape=(2, 1, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64), strides=(0, 0, 1024, 64, 1), offset=131072, mask=((1, 2), (0, 1), (0, (NumNode(1)+Variable('start_pos', 1, 128))), (0, 16), (0, 64)), contiguous=False) # noqa: E501
  132. new_shape = (2, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64)
  133. assert view.reshape(new_shape) is None
  134. class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
  135. def test_reshape_from_const(self):
  136. vi = Variable("i", 1, 5).bind(4)
  137. t = Tensor.ones(3, 4).reshape(3, vi)
  138. assert t.shape == (3, vi)
  139. assert not t.lazydata.st.contiguous
  140. assert len(t.lazydata.st.views) == 1
  141. def test_reshape_not_allowed(self):
  142. vi = Variable("i", 1, 5).bind(4)
  143. with self.assertRaises(ValueError):
  144. # different shape length # TODO: cases where contractions matched might be fine
  145. Tensor.ones(3, 4, 1).reshape(3, vi)
  146. with self.assertRaises(ValueError):
  147. # size matched, but dimensions do not match
  148. Tensor.ones(4, 3).reshape(3, vi)
  149. def test_reshape_from_padded(self):
  150. vi = Variable("i", 1, 5).bind(4)
  151. t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3)))
  152. st = t.lazydata.st
  153. assert len(st.views) == 1
  154. view = st.views[0]
  155. assert view.shape == (4, 3, 2)
  156. t = t.reshape(vi, 3, 2)
  157. st2 = t.lazydata.st
  158. assert len(st2.views) == 1
  159. view2 = st2.views[0]
  160. # check only shape changed. strides, offset, mask, contiguous remained the same
  161. assert view2.shape == (vi, 3, 2)
  162. assert view.strides == view2.strides == (0, 4, 1)
  163. assert view.offset == view2.offset == 1
  164. assert view.mask == view2.mask == ((1, 3), (0, 3), (0, 2))
  165. assert not view.contiguous and not view2.contiguous
  166. class TestSymbolicExpand(unittest.TestCase):
  167. def test_expand_into_symbols(self):
  168. vi = Variable("i", 1, 5).bind(3)
  169. vj = Variable("j", 1, 5).bind(3)
  170. a = Tensor([[1], [2], [3]]).expand((3, vi))
  171. assert a.shape == (3, vi)
  172. a = a.reshape(3, vi, 1).expand((3, vi, vj))
  173. assert a.shape == (3, vi, vj)
  174. def test_plus_expands_constant(self):
  175. for i in range(1, 6):
  176. vi = Variable("i", 1, 5).bind(i)
  177. a = Tensor.rand(3, i).reshape(3, vi)
  178. a = a + 1
  179. assert a.shape == (3, vi)
  180. class TestSymbolicShrink(unittest.TestCase):
  181. def test_shrink_symbols(self):
  182. vi = Variable("i", 1, 5)
  183. t = Tensor.rand(3, 5).shrink(((0, 2), (vi, vi+1)))
  184. assert t.shape == (2, 1)
  185. class TestSymbolicPad(unittest.TestCase):
  186. def test_pad(self):
  187. v = Variable("v", 1, 100).bind(5)
  188. t = Tensor.ones(5).reshape(v).pad(((4, 0),)).reshape(9)
  189. assert t.shape == (9,)
  190. st = t.lazydata.st
  191. print(st)
  192. # TODO: fix this, required for symbolic arange
  193. with self.assertRaises(RuntimeError):
  194. st.expr_idxs()
  195. class TestSymbolicShapeExpr(unittest.TestCase):
  196. def test_symbolic_expr_idxs(self):
  197. # taken from symbolic shape llama
  198. i = Variable("i", 1, 120)
  199. gidx0 = Variable("gidx0", 0, i)
  200. lidx1 = Variable("lidx1", 0, 7)
  201. idx = (gidx0, lidx1, NumNode(1))
  202. shape = (i+1, 8, 4)
  203. strides = (1, (i*4)+4, i+1)
  204. st = ShapeTracker((View.create(shape, strides), ))
  205. idx, _valid = st.expr_idxs(idx)
  206. assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)"
  207. if __name__ == '__main__':
  208. unittest.main()