test_shapetracker_math.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import unittest
  2. from typing import List
  3. from tinygrad.helpers import prod
  4. from tinygrad.shape.view import View
  5. from tinygrad.shape.shapetracker import ShapeTracker
  6. from tinygrad.shape.symbolic import Variable, sym_infer
  7. class MultiShapeTracker:
  8. def __init__(self, sts:List[ShapeTracker]): self.sts = sts
  9. @property
  10. def shape(self): return self.sts[0].shape
  11. def reshape(self, arg): self.sts = [x.reshape(arg) for x in self.sts]
  12. def permute(self, arg): self.sts = [x.permute(arg) for x in self.sts]
  13. def expand(self, arg): self.sts = [x.expand(arg) for x in self.sts]
  14. def shrink(self, arg): self.sts = [x.shrink(arg) for x in self.sts]
  15. def stride(self, arg): self.sts = [x.stride(arg) for x in self.sts]
  16. def pad(self, arg): self.sts = [x.pad(arg) for x in self.sts]
  17. def st_equal(st1:ShapeTracker, st2:ShapeTracker) -> bool:
  18. if st1.shape != st2.shape: return False
  19. if st1 == st2: return True
  20. idx = Variable("idx", 0, prod(st1.shape)-1)
  21. st1_idx, st1_valid = st1.reshape((st1.size,)).expr_idxs([idx])
  22. st2_idx, st2_valid = st2.reshape((st2.size,)).expr_idxs([idx])
  23. for i in range(idx.min, idx.max + 1):
  24. st1_off = sym_infer(st1_idx, {idx: i})
  25. st2_off = sym_infer(st2_idx, {idx: i})
  26. st1_v = sym_infer(st1_valid, {idx: i})
  27. st2_v = sym_infer(st2_valid, {idx: i})
  28. if st1_v != st2_v or (st1_off != st2_off and st1_v):
  29. print(f"ST MISMATCH @ {i}, {st1_v=} != {st2_v=}, {st1_off=} != {st2_off=}")
  30. print(st1)
  31. print(st2)
  32. return False
  33. return True
  34. class TestShapeTrackerBasics(unittest.TestCase):
  35. def test_pad_shrink_removes_mask(self):
  36. a = ShapeTracker.from_shape((10, 10))
  37. a = a.pad(((0,2), (0,2)))
  38. a = a.shrink(((0,10), (0,10)))
  39. assert len(a.views) == 1 and a.views[-1].mask is None
  40. def test_pad_shrink_leaves_mask(self):
  41. a = ShapeTracker.from_shape((10, 10))
  42. a = a.pad(((0,2), (0,2)))
  43. a = a.shrink(((0,10), (0,11)))
  44. assert len(a.views) == 1 and a.views[-1].mask is not None
  45. def test_reshape_makes_same(self):
  46. a = ShapeTracker.from_shape((2, 5))
  47. x = a.pad( ((2, 0), (0, 0)) )
  48. x = x.reshape( (2, 2, 5) )
  49. x1 = x.reshape( (4, 5) )
  50. x1 = x1.reshape( (2, 2, 5) )
  51. assert x == x1.simplify()
  52. def test_simplify_is_correct(self):
  53. multiv = ShapeTracker(views=(View(shape=(15, 3), strides=(9, 1), offset=6, mask=None, contiguous=False),
  54. View(shape=(4, 3), strides=(12, 4), offset=0, mask=None, contiguous=False)))
  55. assert st_equal(multiv, multiv.simplify())
  56. class TestShapeTrackerAdd(unittest.TestCase):
  57. def test_simple_add_reshape(self):
  58. a = ShapeTracker.from_shape((10, 10))
  59. a = a.reshape((100,))
  60. b = ShapeTracker.from_shape((100,))
  61. assert a+b == b
  62. def test_simple_add_permute(self):
  63. a = ShapeTracker.from_shape((10, 10))
  64. a = a.permute((1,0))
  65. b = ShapeTracker.from_shape((10, 10))
  66. b = b.permute((1,0))
  67. assert a+b == ShapeTracker.from_shape((10, 10))
  68. def test_plus_real1(self):
  69. st = MultiShapeTracker([ShapeTracker.from_shape((15, 9))])
  70. st.shrink( ((0, 15), (6, 9)) )
  71. backup = st.sts[0]
  72. st.sts.append(ShapeTracker.from_shape(backup.shape))
  73. st.reshape( (45,) )
  74. st.stride( (4,) )
  75. st.reshape( (4, 3) )
  76. assert st_equal(backup + st.sts[1], st.sts[0])
  77. def test_off_by_one(self):
  78. st1 = ShapeTracker(views=(View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True),
  79. View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True)))
  80. st2 = ShapeTracker(views=(View(shape=(4,), strides=(1,), offset=0, mask=None, contiguous=True),
  81. View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True)))
  82. assert not (st_equal(st1, st2))
  83. class TestShapeTrackerAddVariable(unittest.TestCase):
  84. def test_self_add(self):
  85. j = Variable("j", 0, 20).bind(10)
  86. a = ShapeTracker.from_shape((10,10))
  87. x = a.reshape((10, j))
  88. out = x + x
  89. assert out == x
  90. def test_self_add_reshape(self):
  91. j = Variable("j", 0, 20).bind(10)
  92. a = ShapeTracker.from_shape((10,10))
  93. x = a.reshape((10, j))
  94. out = x.reshape((5, 2, j)) + x
  95. assert out == x
  96. def test_merge_symbolic_views(self):
  97. var_i = Variable('i', 1, 10)
  98. var_j = Variable('i', 1, 10)
  99. vm1 = View(shape=(var_i, var_j, 3), strides=(3, 0, 1), offset=0, mask=None, contiguous=False)
  100. vm2 = View(shape=(var_i, var_j, 3), strides=(var_j*3, 3, 1), offset=0, mask=None, contiguous=True)
  101. ShapeTracker((vm1,)) + ShapeTracker((vm2,))
  102. @unittest.skip("two vars not supported")
  103. def test_merge_symbolic_views_2(self):
  104. var_i = Variable('i', 1, 10)
  105. var_j = Variable('j', 1, 10)
  106. vm1 = View(shape=(var_i, var_j), strides=(0, 0), offset=0, mask=None, contiguous=False)
  107. vm2 = View(shape=(var_i, var_j), strides=(var_j, 1), offset=0, mask=None, contiguous=True)
  108. ret = (ShapeTracker((vm1,)) + ShapeTracker((vm2,))).reshape((var_i, var_j, 1))
  109. ret_2 = ShapeTracker((vm1,)) + ShapeTracker((vm2,)).reshape((var_i, var_j, 1))
  110. assert ret == ret_2
  111. class TestShapeTrackerInvert(unittest.TestCase):
  112. def test_invert_reshape(self):
  113. a = ShapeTracker.from_shape((10, 10))
  114. x = a.reshape((5, 20))
  115. ap = ShapeTracker.from_shape(x.shape) + x.invert(a.shape)
  116. assert ap == a, f"{ap} != {a}"
  117. def test_invert_permute(self):
  118. a = ShapeTracker.from_shape((5, 20))
  119. x = a.permute((1,0))
  120. ap = x + x.invert(a.shape)
  121. assert ap == a, f"{ap} != {a}"
  122. def test_invert_permute_3(self):
  123. a = ShapeTracker.from_shape((8, 4, 5))
  124. x = a.permute((1,2,0))
  125. ap = x + x.invert(a.shape)
  126. assert ap == a, f"{ap} != {a}"
  127. def test_invert_real1(self):
  128. a = ShapeTracker.from_shape((3, 6, 10))
  129. x = a.reshape( (3, 3, 2, 10) )
  130. x = x.permute( (2, 1, 3, 0) )
  131. ap = x + x.invert(a.shape)
  132. assert ap == a, f"{ap} != {a}"
  133. def test_cant_invert_expand(self):
  134. a = ShapeTracker.from_shape((10, 1))
  135. x = a.expand((10,10))
  136. assert x.invert(a.shape) is None
  137. def test_cant_invert_shrink(self):
  138. a = ShapeTracker.from_shape((10, 10))
  139. x = a.shrink(((0,10),(2,8)))
  140. assert x.invert(a.shape) is None
  141. def test_can_invert_flip(self):
  142. a = ShapeTracker.from_shape((20, 10))
  143. x = a.stride((-1,1))
  144. ap = x + x.invert(a.shape)
  145. assert st_equal(ap, a)
  146. def test_can_invert_flip_permute(self):
  147. a = ShapeTracker.from_shape((20, 10))
  148. x = a.permute((1,0))
  149. x = x.stride((-1,1))
  150. ap = x + x.invert(a.shape)
  151. assert st_equal(ap, a)
  152. def test_cant_invert_stride(self):
  153. a = ShapeTracker.from_shape((10, 10))
  154. x = a.stride((2,2))
  155. assert x.invert(a.shape) is None
  156. def test_invert_failure(self):
  157. a = ShapeTracker.from_shape((2, 5))
  158. x = a.pad( ((2, 0), (0, 0)) )
  159. x = x.reshape( (2, 2, 5) )
  160. x = x.reshape( (4, 5) )
  161. ap = x + x.invert(a.shape)
  162. assert st_equal(ap, a)
  163. if __name__ == '__main__':
  164. unittest.main()