test_shapetracker.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811
  1. #!/usr/bin/env python
  2. import unittest
  3. import numpy as np
  4. from tinygrad.helpers import prod, DEBUG
  5. from tinygrad.shape.shapetracker import ShapeTracker, View
  6. from tinygrad.shape.symbolic import Variable, NumNode
  7. from itertools import product
  8. def shapetracker_getitem(st, val):
  9. _locals = {"idx0": val, "valid": 1}
  10. idx, valid = st.reshape((st.size,)).expr_idxs()
  11. exec(f"valid={valid.render()};idx0={idx.render()}", None, _locals)
  12. return _locals["idx0"] if _locals["valid"] else -1
  13. class CheckingShapeTracker:
  14. def __init__(self, shape):
  15. self.st = ShapeTracker.from_shape(shape)
  16. self.t = np.arange(prod(shape), dtype=np.int32).reshape(shape)
  17. @property
  18. def shape(self):
  19. return self.t.shape
  20. def simplify(self):
  21. self.st = self.st.simplify()
  22. return self
  23. def reshape(self, new_shape):
  24. self.st = self.st.reshape(new_shape)
  25. self.t = self.t.reshape(new_shape)
  26. return self
  27. def permute(self, axis):
  28. self.st = self.st.permute(axis)
  29. self.t = np.transpose(self.t, axis)
  30. return self
  31. def expand(self, new_shape):
  32. self.st = self.st.expand(new_shape)
  33. self.t = np.broadcast_to(self.t, new_shape)
  34. return self
  35. def flip(self, axis):
  36. self.st = self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
  37. self.t = np.flip(self.t, axis)
  38. return self
  39. def shrink(self, arg):
  40. self.st = self.st.shrink(arg)
  41. self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
  42. return self
  43. def pad(self, arg):
  44. self.st = self.st.pad(arg)
  45. self.t = np.pad(self.t, arg, constant_values=-1)
  46. return self
  47. def stride(self, arg):
  48. self.st = self.st.stride(arg)
  49. self.t = self.t[tuple([slice(None, None, x) for x in arg])]
  50. return self
  51. def __getitem__(self, val):
  52. return self.t.flatten()[val]
  53. @property
  54. def views(self): return self.st.views
  55. @property
  56. def contiguous(self): return self.st.contiguous
  57. def assert_same(self):
  58. x = [shapetracker_getitem(self.st, i) for i in range(prod(self.st.shape))]
  59. y = [self[i] for i in range(prod(self.shape))]
  60. idx, valid = self.st.expr_idxs()
  61. if DEBUG >= 1: print(x, y, self.st.shape, self.shape, idx.render(), valid.render(), self.st)
  62. assert self.st.shape == self.shape
  63. assert x == y, f"mismatch shapetracker:{x} real:{y}"
  64. @unittest.skip("don't create shapetrackers with views")
  65. class TestRealIssues(unittest.TestCase):
  66. def test_reshape_doesnt_multiview(self):
  67. self.st = ShapeTracker((View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None),))
  68. self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2))
  69. assert len(self.st.views) == 1
  70. def test_reshape_stable_diffusion(self):
  71. # regression test for https://github.com/tinygrad/tinygrad/pull/2616
  72. st = ShapeTracker((View((2, 1920, 32, 32), (1310720, 1024, 32, 1), 0, ((0, 2), (0, 1280), (0, 32), (0, 32)), False),))
  73. st = st.reshape((2, 32, 240, 256))
  74. assert len(st.views) == 2
  75. def test_reshape_trailing_invalid_ones(self):
  76. st = ShapeTracker((View(shape=(1, 1, 5), strides=(0, 0, 1), offset=-5, mask=((1, 1), (0, 1), (0, 5)), contiguous=False),))
  77. st = st.reshape((5,))
  78. assert len(st.views) == 1
  79. assert st.views[0].mask == ((0,0),)
  80. class TestRealDoesntSimplify(unittest.TestCase):
  81. def tearDown(self):
  82. st = self.st.real_strides()
  83. print(st)
  84. self.st = self.st.simplify()
  85. assert len(self.st.views) != 1
  86. assert None in st
  87. def test_1(self):
  88. self.st = ShapeTracker((
  89. View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
  90. View.create((8, 6, 11), (66, 11, 1), 0, None)))
  91. assert self.st.real_strides() == (33, None, 1)
  92. def test_2(self):
  93. self.st = ShapeTracker((
  94. View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
  95. View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)))
  96. assert self.st.real_strides() == (None, 18, -3, -1)
  97. class TestRealStrides(unittest.TestCase):
  98. def test_1(self):
  99. self.st = ShapeTracker((
  100. View.create((2048,), (1,), 0, ((0, 512),)),
  101. View.create((16, 32, 4), (128, 4, 1), 0, None)))
  102. st = self.st.real_strides()
  103. print(self.st, st)
  104. assert st == (None, 4, 1)
  105. class TestRealSimplifies(unittest.TestCase):
  106. def tearDown(self):
  107. st = self.st.real_strides()
  108. self.st = self.st.simplify()
  109. assert len(self.st.views) == 1
  110. print(self.st.views[-1].strides, st)
  111. assert self.st.views[-1].strides == st
  112. def test_1(self):
  113. self.st = ShapeTracker((
  114. View.create((1, 3, 2, 11, 4, 28), (0, 308, 0, 28, 0, 1), 0, None),
  115. View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)))
  116. def test_2(self):
  117. self.st = ShapeTracker((
  118. View.create((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
  119. View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)))
  120. class TestViewMinify(unittest.TestCase):
  121. def test_minifies(self):
  122. assert len(View.create((10,10)).minify().shape) == 1
  123. assert len(View.create((10,10)).permute((1,0)).minify().shape) == 2
  124. assert len(View.create((10,10,10,10)).permute((1,0,2,3)).minify().shape) == 3
  125. class TestIndexExpressions2d(unittest.TestCase):
  126. def setUp(self):
  127. shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
  128. offsets = [0, 1, 15, 28, 10000]
  129. self.sts = [ShapeTracker.from_shape((prod(base_shape)+offset,)).shrink(((offset, offset+prod(base_shape)),)).\
  130. reshape(base_shape) for base_shape in shapes for offset in offsets]
  131. self.offset = [NumNode(offset) for base_shape in shapes for offset in offsets]
  132. self.shapes = [shape for shape in shapes for offset in offsets]
  133. self.idxs_exprs = []
  134. def tearDown(self):
  135. for st, offset, shape, idxs_expr in zip(self.sts, self.offset, self.shapes, self.idxs_exprs):
  136. numel = prod(shape)
  137. assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs(None)[0]
  138. self.check_bounds(idxs_expr(self.default_idxs(st.shape)), offset, numel)
  139. idx0s = [(0,0), (0, min(1, st.shape[0]-1)), (0, st.shape[0]-1), (min(3, st.shape[0]-1), min(6, st.shape[0]-1)), (st.shape[0]-1, st.shape[0]-1)]
  140. idx1s = [(0,0), (0, min(1, st.shape[1]-1)), (0, st.shape[1]-1), (min(3, st.shape[1]-1), min(6, st.shape[1]-1)), (st.shape[1]-1, st.shape[1]-1)]
  141. idx2s = [(0,0), (0, min(1, st.shape[2]-1)), (0, st.shape[2]-1), (min(3, st.shape[2]-1), min(6, st.shape[2]-1)),
  142. (st.shape[2]-1, st.shape[2]-1)] if len(st.shape) == 3 else [None for _ in idx0s]
  143. for idx0, idx1, idx2 in product(idx0s, idx1s, idx2s):
  144. idxs = [Variable(f"idx{i}", idx[0], idx[1]) for i, idx in enumerate((idx0, idx1, idx2)) if idx is not None]
  145. assert idxs_expr(idxs) == st.expr_idxs(idxs)[0]
  146. self.check_bounds(idxs_expr(idxs), offset, numel)
  147. def default_idx(self, shape):
  148. return Variable("idx", 0, prod(shape)-1)
  149. def default_idxs(self, shape):
  150. return [Variable(f"idx{i}", 0, d-1) for i,d in enumerate(shape)]
  151. def check_bounds(self, expr, offset, numel):
  152. assert expr.min >= offset
  153. assert expr.max <= offset + numel - 1
  154. def test_noop(self):
  155. for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
  156. self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[1] + offset)
  157. def test_permute(self):
  158. new_st = []
  159. for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
  160. st = st.permute((1, 0))
  161. self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + idxs[1]*base_shape[1] + offset)
  162. new_st.append(st)
  163. self.sts = new_st
  164. def test_reshape(self):
  165. new_st = []
  166. for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
  167. st = st.reshape((base_shape[0], 1, base_shape[1]))
  168. self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
  169. new_st.append(st)
  170. self.sts = new_st
  171. def test_reshape_expand(self):
  172. new_st = []
  173. for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
  174. st = st.reshape((base_shape[0], 1, base_shape[1]))
  175. st = st.expand((base_shape[0], base_shape[1], base_shape[1]))
  176. self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
  177. new_st.append(st)
  178. self.sts = new_st
  179. def test_permute_reshape_1(self): # This tests multiple views
  180. new_st = []
  181. for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
  182. st = st.permute((1, 0))
  183. st = st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
  184. self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[0]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + \
  185. (idxs[0]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
  186. new_st.append(st)
  187. self.sts = new_st
  188. def test_permute_reshape_2(self):
  189. new_st = []
  190. for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
  191. st = st.permute((1, 0))
  192. st = st.reshape((1, base_shape[0]//5, base_shape[1]*5))
  193. self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[1]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + \
  194. (idxs[1]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
  195. new_st.append(st)
  196. self.sts = new_st
  197. def test_reshaping_splitting(self):
  198. self.st = CheckingShapeTracker((5,10,5,10))
  199. self.st.permute((1, 0, 3, 2))
  200. self.st.pad(((0,0), (0,5), (0,0), (0,5)))
  201. self.st.reshape((10,2,5,10,2,5))
  202. assert len(self.st.views) == 1
  203. self.st.assert_same()
  204. def test_reshape_splitting_1(self):
  205. self.st = CheckingShapeTracker((1,10,1))
  206. self.st.pad(((0,4),(0,0),(1,0)))
  207. self.st.reshape((5,5,2,2))
  208. assert len(self.st.views) == 1
  209. self.st.assert_same()
  210. def test_reshape_combining_1(self):
  211. self.st = CheckingShapeTracker((2,1,10))
  212. self.st.pad(((2,6), (0,0), (0,0)))
  213. self.st.reshape((100,))
  214. assert len(self.st.views) == 1
  215. self.st.assert_same()
  216. def test_reshape_combining_2(self):
  217. self.st = CheckingShapeTracker((1,1,5))
  218. self.st.pad(((3,6), (0,0), (0,5)))
  219. self.st.reshape((100,))
  220. assert len(self.st.views) == 1
  221. self.st.assert_same()
  222. def test_reshape_combining_3(self):
  223. self.st = CheckingShapeTracker((1,1,4))
  224. self.st.pad(((3,6), (0,0), (1,5)))
  225. self.st.reshape((100,))
  226. assert len(self.st.views) == 1
  227. assert self.st.views[0].mask[0] == (31, 35)
  228. self.st.assert_same()
  229. def test_reshape_combining_4(self):
  230. # interestingly this one is quite slow
  231. self.st = CheckingShapeTracker((1,1,5,5,1,1,5))
  232. self.st.pad(((3,6), (0,0), (0,5), (0,0), (3,6), (0,0), (0,5)))
  233. self.st.reshape((100,5,100))
  234. assert len(self.st.views) == 1
  235. self.st.assert_same()
  236. def test_reshape_splitting_combining(self):
  237. self.st = CheckingShapeTracker((1,5,5))
  238. self.st.pad(((0,4), (0,5), (0,0)))
  239. self.st.reshape((10,25))
  240. assert len(self.st.views) == 1
  241. self.st.assert_same()
  242. def test_reshape_only_1s(self):
  243. self.st = CheckingShapeTracker((1, 1, 1, 4, 1, 3, 5, 1))
  244. self.st.pad(((0,4), (0,0), (0,0), (1,1), (0,0), (0,0), (0,0), (0,0)))
  245. self.st.reshape((5, 6, 3, 5))
  246. assert len(self.st.views) == 1
  247. self.st.assert_same()
  248. self.st.reshape((1, 1, 5, 6, 3, 5, 1, 1))
  249. assert len(self.st.views) == 1
  250. self.st.assert_same()
  251. self.st.reshape((1, 5, 6, 1, 3, 1, 5, 1))
  252. assert len(self.st.views) == 1
  253. self.st.assert_same()
  254. def test_zero_mask_1(self):
  255. self.st = CheckingShapeTracker((1, 3, 2))
  256. self.st.pad(((0,0), (0,3), (0,0)))
  257. self.st.shrink(((0,1), (3,6), (0,2)))
  258. self.st.reshape((3,2))
  259. assert len(self.st.views) == 1
  260. self.st.assert_same()
  261. self.st.reshape((1, 3, 1, 2, 1))
  262. assert len(self.st.views) == 1
  263. self.st.assert_same()
  264. def test_zero_mask_2(self):
  265. self.st = CheckingShapeTracker((1, 3, 2))
  266. self.st.pad(((0,2), (0,3), (0,0)))
  267. self.st.shrink(((2,3), (3,6), (0,2)))
  268. self.st.reshape((3,2))
  269. assert len(self.st.views) == 1
  270. self.st.assert_same()
  271. self.st.reshape((1, 3, 1, 2, 1))
  272. assert len(self.st.views) == 1
  273. self.st.assert_same()
  274. def test_expanded_reshaped(self):
  275. self.st = CheckingShapeTracker((1, 3, 2, 1))
  276. self.st.expand((5, 3, 2, 2))
  277. self.st.pad(((0,0), (0,3), (0,0), (0, 0)))
  278. self.st.reshape((5, 2, 3, 2, 2))
  279. assert len(self.st.views) == 1
  280. self.st.assert_same()
  281. def test_splitting_big(self):
  282. self.st = CheckingShapeTracker((1, 5, 1, 15, 1))
  283. self.st.pad(((0,0), (0,5), (0,0), (0,15), (0,0)))
  284. self.st.reshape((10, 1, 30))
  285. self.st.permute((2,1,0))
  286. self.st.reshape((2,3,5,2,5))
  287. assert len(self.st.views) == 1
  288. v = self.st.views[-1]
  289. assert v.strides == (0, 5, 1, 0, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5))
  290. self.st.assert_same()
  291. def test_combining_big(self):
  292. self.st = CheckingShapeTracker((1,3,1,5,3,1))
  293. self.st.pad(((0,0),(2,2),(0,0),(0,0),(0,0),(0,0)))
  294. self.st.reshape((1,1,1,105,1,1))
  295. assert len(self.st.views) == 1
  296. v = self.st.views[-1]
  297. assert v.strides == (0, 0, 0, 1, 0, 0) and v.mask == ((0, 1), (0, 1), (0, 1), (30, 75), (0, 1), (0, 1)) and v.offset == -30
  298. self.st.assert_same()
  299. def test_pad_reshape(self):
  300. self.st = CheckingShapeTracker((4,))
  301. self.st.pad(((2,2),))
  302. self.st.reshape((4,2))
  303. assert len(self.st.views) == 1
  304. self.st.assert_same()
  305. class TestSimplifyingShapeTracker(unittest.TestCase):
  306. def setUp(self):
  307. self.st = CheckingShapeTracker((1, 10))
  308. def tearDown(self):
  309. self.st.assert_same()
  310. # multiview simplify
  311. def test_expand_contract_simple(self):
  312. self.st = self.st.expand((10, 10))
  313. self.st = self.st.reshape((100,))
  314. print(self.st.views)
  315. assert(len(self.st.views) == 2)
  316. self.st = self.st.reshape((10, 10))
  317. print(self.st.views)
  318. self.st = self.st.simplify()
  319. print(self.st.views)
  320. assert(len(self.st.views) == 1)
  321. # multiview simplify
  322. def test_expand_contract_different_shape(self):
  323. self.st.expand((10, 10))
  324. self.st.reshape((100,))
  325. print(self.st.views)
  326. assert(len(self.st.views) == 2)
  327. self.st.reshape((2, 5, 2, 5))
  328. print(self.st.views)
  329. self.st = self.st.simplify()
  330. print(self.st.views)
  331. assert(len(self.st.views) == 1)
  332. # multiview simplify
  333. def test_expand_contract_still_complex(self):
  334. self.st.expand((10, 10))
  335. self.st.reshape((100,))
  336. print(self.st.views)
  337. assert(len(self.st.views) == 2)
  338. self.st.reshape((5, 20))
  339. self.st = self.st.simplify()
  340. print(self.st.views)
  341. assert(len(self.st.views) == 2)
  342. # Tensor.zeros(2, 4).permute(1,0).reshape(2, 4)
  343. # (d1*4 + d0%4), d1=x//4, d0=x%4 = ((x//4)*4) + (x%4)%4
  344. class TestComplexShapeTracker(unittest.TestCase):
  345. def test_add_1s(self):
  346. self.st = CheckingShapeTracker((4, 4))
  347. self.st.permute((1,0))
  348. self.st.reshape((1,4,1,4,1))
  349. assert not self.st.contiguous
  350. self.st.permute((0,3,2,1,4))
  351. assert self.st.contiguous
  352. def test_permute_1s_simple(self):
  353. self.st = CheckingShapeTracker((1, 16, 9,9))
  354. self.st.permute((1,0,2,3))
  355. assert self.st.contiguous
  356. self.st = CheckingShapeTracker((2, 16, 9,9))
  357. self.st.permute((1,0,2,3))
  358. assert not self.st.contiguous
  359. def test_remove_1s_simple(self):
  360. self.st = CheckingShapeTracker((1, 16, 1, 1))
  361. self.st.reshape((16,))
  362. assert self.st.contiguous
  363. def test_remove_1s(self):
  364. self.st = CheckingShapeTracker((1, 4, 1, 4, 1))
  365. self.st.permute((0,3,2,1,4))
  366. self.st.reshape((4,4))
  367. assert not self.st.contiguous
  368. self.st.permute((1,0))
  369. assert self.st.contiguous
  370. def test_permute_reshape(self):
  371. self.st = CheckingShapeTracker((4, 4))
  372. self.st.permute((1,0))
  373. self.st.reshape((2, 2, 2, 2))
  374. # TODO: should also be tested by test_super_complex
  375. assert len(self.st.views) == 1
  376. def test_factorize_split(self):
  377. self.st = CheckingShapeTracker((4, 4))
  378. self.st.permute((1,0))
  379. self.st.reshape((2, 2, 2, 2))
  380. self.st.permute((2,3,0,1))
  381. assert self.st.contiguous
  382. def test_factorize_combine(self):
  383. self.st = CheckingShapeTracker((4, 4, 4))
  384. self.st.permute((2, 0, 1))
  385. self.st.reshape((4, 16))
  386. self.st.permute((1, 0))
  387. assert self.st.contiguous
  388. def test_factorize_combine_add_ones(self):
  389. self.st = CheckingShapeTracker((4, 4, 4))
  390. self.st.permute((2, 0, 1))
  391. self.st.reshape((4, 16, 1, 1))
  392. self.st.permute((1, 0, 2, 3))
  393. assert self.st.contiguous
  394. def test_fancy_factorize(self):
  395. self.st = CheckingShapeTracker((32, 3, 3, 1))
  396. self.st.reshape((8, 4, 3, 3))
  397. assert len(self.st.views) == 1
  398. def test_super_complex_2_fail(self):
  399. self.st = CheckingShapeTracker((4, 4, 4))
  400. self.st.permute((2, 0, 1))
  401. self.st.reshape((16, 4))
  402. assert len(self.st.views) != 1
  403. def test_work(self):
  404. self.st = CheckingShapeTracker((64, 1024, 4))
  405. self.st.reshape((1, 64, 128, 32))
  406. self.st.permute((0, 3, 1, 2))
  407. self.st.reshape((1, 32, 1, 64, 128))
  408. self.st.permute((0, 3, 4, 1, 2))
  409. assert self.st.contiguous
  410. def test_work2(self):
  411. self.st = CheckingShapeTracker((64, 1024, 4))
  412. self.st.reshape((1, 64, 128, 32))
  413. self.st.permute((0, 3, 1, 2))
  414. self.st.reshape((1, 1, 32, 64, 128))
  415. self.st.permute((0, 3, 4, 1, 2))
  416. self.st.reshape((64, 1024, 4))
  417. print(self.st.views)
  418. assert self.st.contiguous
  419. class TestSingleShapeTracker(unittest.TestCase):
  420. def setUp(self):
  421. self.st = CheckingShapeTracker((7,4))
  422. def tearDown(self):
  423. self.st.assert_same()
  424. def test_reshape(self):
  425. self.st.reshape((7,1,4))
  426. assert self.st.contiguous
  427. def test_permute(self):
  428. self.st.permute((1,0))
  429. assert not self.st.contiguous
  430. def test_shrink(self):
  431. self.st.shrink(((1,2), (0,4)))
  432. assert not self.st.contiguous
  433. def test_double_permute(self):
  434. self.st.permute((1,0))
  435. self.st.permute((1,0))
  436. assert self.st.contiguous
  437. def test_reshape_permute(self):
  438. self.st.reshape((7,1,4))
  439. self.st.permute((0,1,2))
  440. assert self.st.contiguous
  441. def test_reshape_permute_yes(self):
  442. self.st.reshape((7,1,4))
  443. self.st.permute((0,2,1))
  444. assert self.st.contiguous
  445. def test_reshape_permute_no(self):
  446. self.st.reshape((4,7))
  447. self.st.permute((1,0))
  448. assert not self.st.contiguous
  449. class TestShapeTrackerFuzzFailures(unittest.TestCase):
  450. def setUp(self):
  451. self.st = CheckingShapeTracker((3,3,3))
  452. def tearDown(self):
  453. self.st.assert_same()
  454. def test_case_1(self):
  455. self.st.shrink(((1, 2), (1, 3), (1, 3)))
  456. self.st.reshape((1, 4))
  457. self.st.shrink(((0, 1), (1, 3)))
  458. print(self.st.st)
  459. self.st = self.st.simplify()
  460. print(self.st.st)
  461. def test_case_2(self):
  462. self.st.stride( (1, 1, -2) )
  463. self.st.reshape( (3, 6) )
  464. self.st.shrink( ((1, 2), (1, 5)) )
  465. self.st.stride( (1, -1) )
  466. def test_case_3(self):
  467. self.st.shrink( ((0, 2), (0, 2), (0, 1)) )
  468. self.st.permute( (1, 0, 2) )
  469. self.st.reshape( (4,) )
  470. self.st.shrink( ((0, 3),) )
  471. self.st.stride( (-1,) )
  472. def test_case_4(self):
  473. self.st.reshape( (3, 3, 3, 1) )
  474. self.st.pad( ((0, 0), (0, 0), (0, 0), (1, 1)) )
  475. self.st.shrink( ((0, 2), (1, 2), (0, 2), (0, 1)) )
  476. self.st.expand( (2, 1, 2, 3) )
  477. class TestMaskedShapeTracker(unittest.TestCase):
  478. def test_pad_1x1(self):
  479. self.st = CheckingShapeTracker((1,1))
  480. self.st.pad(((1,1), (1,1)))
  481. self.st.assert_same()
  482. def test_pad_2x2(self):
  483. self.st = CheckingShapeTracker((2,2))
  484. self.st.pad(((1,1), (1,1)))
  485. self.st.assert_same()
  486. def test_pad_reshape(self):
  487. st1 = CheckingShapeTracker((1, 2))
  488. st1.pad(((1, 0), (0, 1)))
  489. st1.reshape((3, 2))
  490. st1.assert_same()
  491. st2 = CheckingShapeTracker((1, 2))
  492. st2.pad(((1, 1), (0, 2)))
  493. st2.reshape((4, 3))
  494. st2.assert_same()
  495. st3 = CheckingShapeTracker((1, 1, 1, 2))
  496. st3.pad(((0, 2), (1, 2), (2, 2), (0, 4)))
  497. st3.reshape((4, 3, 6, 5))
  498. st3.assert_same()
  499. class TestShapeTracker(unittest.TestCase):
  500. def setUp(self):
  501. self.st = CheckingShapeTracker((7,4))
  502. self.apply = lambda fxn: [fxn(x) for x in [self.st]]
  503. def tearDown(self):
  504. self.st.assert_same()
  505. def test_noop(self):
  506. pass
  507. def test_simple_split(self):
  508. self.test_permute()
  509. self.apply(lambda x: x.reshape((prod(self.st.shape), )))
  510. def test_simple_pad(self):
  511. self.st.pad(((1,1), (1,1)))
  512. def test_pad_shrink(self):
  513. self.st.pad(((1,1), (1,1)))
  514. self.st.shrink(((0,4), (0,4)))
  515. def test_pad_one_sided(self):
  516. self.st.pad(((0,1), (0,0)))
  517. def test_pad_reshape(self):
  518. self.st.pad(((0,1), (0,0)))
  519. self.st.reshape((8*4,))
  520. def test_pad_pad(self):
  521. self.st.pad(((1,1), (1,1)))
  522. self.st.pad(((1,1), (1,1)))
  523. def test_pad_permute(self):
  524. self.st.pad(((1,1), (2,2)))
  525. self.st.permute((1,0))
  526. def test_pad_expand(self):
  527. self.st.reshape((7,4,1))
  528. self.st.pad(((1,1), (1,1), (0,0)))
  529. self.st.expand((9,6,4))
  530. def test_pad_expand_alt(self):
  531. self.st.pad(((1,1), (1,1)))
  532. self.st.reshape((9,6,1))
  533. self.st.expand((9,6,4))
  534. def test_pad_stride(self):
  535. self.st.pad(((1,4), (1,3)))
  536. self.st.stride((2,2))
  537. def test_pad_stride_neg(self):
  538. self.st.pad(((1,2), (1,0)))
  539. self.st.stride((-1,-1))
  540. def test_pad_stride_both(self):
  541. self.st.pad(((1,2), (1,0)))
  542. self.st.stride((-2,-2))
  543. def test_shrink_pad(self):
  544. self.st.shrink(((0,4), (0,4)))
  545. self.st.pad(((1,1), (1,1)))
  546. def test_reshape(self):
  547. new_shape = self.st.shape[::-1]
  548. self.apply(lambda x: x.reshape(new_shape))
  549. def test_permute(self):
  550. if len(self.st.shape) == 2: self.apply(lambda x: x.permute((1,0)))
  551. elif len(self.st.shape) == 3: self.apply(lambda x: x.permute((2,0,1)))
  552. def test_reshape_with_1(self):
  553. new_shape = (self.st.shape[0], 1, self.st.shape[1])
  554. self.apply(lambda x: x.reshape(new_shape))
  555. def test_expand(self):
  556. self.test_reshape_with_1()
  557. new_shape = list(self.st.shape)
  558. new_shape[1] = 2
  559. self.apply(lambda x: x.expand(tuple(new_shape)))
  560. def test_flip_0(self):
  561. self.apply(lambda x: x.flip((0,)))
  562. def test_flip_1(self):
  563. self.apply(lambda x: x.flip((1,)))
  564. def test_flip_01(self):
  565. self.apply(lambda x: x.flip((0,1)))
  566. def test_slice_0(self):
  567. self.apply(lambda x: x.shrink(((1, x.shape[0]), (0, x.shape[1]))))
  568. def test_slice_1(self):
  569. self.apply(lambda x: x.shrink(((0, x.shape[0]), (1, x.shape[1]))))
  570. def test_slice_1c1(self):
  571. self.apply(lambda x: x.shrink(((0, 1), (0, 1))))
  572. def test_slice_1c2(self):
  573. self.apply(lambda x: x.shrink(((1, 2), (1, 2))))
  574. def test_double_permute(self):
  575. self.apply(lambda x: x.permute((1, 0)))
  576. self.apply(lambda x: x.permute((1, 0)))
  577. def test_slice_permute(self):
  578. self.apply(lambda x: x.shrink(((0, 2), (2, 4))))
  579. self.apply(lambda x: x.permute((1, 0)))
  580. def test_slice_expand(self):
  581. self.apply(lambda x: x.shrink(((0, 2), (3, 4))))
  582. self.apply(lambda x: x.expand((2, 10)))
  583. def test_double_stride(self):
  584. self.apply(lambda x: x.stride((1, 2)))
  585. self.apply(lambda x: x.stride((2, 1)))
  586. def test_stride(self): self.apply(lambda x: x.stride((2,1)))
  587. def test_stride_int(self): self.apply(lambda x: x.stride((1,2)))
  588. def test_stride_2(self): self.apply(lambda x: x.stride((2,2)))
  589. def test_stride_n(self): self.apply(lambda x: x.stride((-2,1)))
  590. def test_stride_int_n(self): self.apply(lambda x: x.stride((-1,2)))
  591. def test_stride_2_n(self): self.apply(lambda x: x.stride((-2,-2)))
  592. def test_reshape_then_permute(self):
  593. self.test_reshape()
  594. self.test_permute()
  595. def test_reshape_then_expand(self):
  596. self.test_reshape()
  597. self.test_expand()
  598. def test_permute_then_reshape(self):
  599. self.test_permute()
  600. self.test_reshape()
  601. def test_expand_then_reshape(self):
  602. self.test_expand()
  603. self.test_reshape()
  604. def test_combo(self):
  605. self.test_permute()
  606. self.test_reshape()
  607. self.test_slice_1()
  608. self.test_expand()
  609. self.test_permute()
  610. class TestShapeTrackerSize(unittest.TestCase):
  611. def test_simple_size(self):
  612. st = ShapeTracker.from_shape((100, 100))
  613. self.assertEqual(st.real_size(), 100*100)
  614. def test_expand_size(self):
  615. st = ShapeTracker.from_shape((100, 100))
  616. st = st.reshape((100, 100, 1))
  617. st = st.expand((100, 100, 100))
  618. self.assertEqual(st.real_size(), 100*100)
  619. def test_expand_size_flatten(self):
  620. st = ShapeTracker.from_shape((100, 100))
  621. st = st.reshape((100, 100, 1))
  622. st = st.expand((100, 100, 100))
  623. st = st.reshape((100*100*100,))
  624. self.assertEqual(st.real_size(), 100*100)
  625. def test_shrink_size_axis_0(self):
  626. st = ShapeTracker.from_shape((100, 100))
  627. st = st.shrink(((0, 50), (0, 100)))
  628. self.assertEqual(st.real_size(), 50*100)
  629. def test_shrink_size_axis_0_variable(self):
  630. st = ShapeTracker.from_shape((100, 100))
  631. st = st.shrink(((0, Variable("a", 0, 50)), (0, 100)))
  632. self.assertEqual(st.real_size(), 50*100)
  633. def test_shrink_size_axis_1(self):
  634. st = ShapeTracker.from_shape((100, 100))
  635. st = st.shrink(((0, 100), (0, 50)))
  636. self.assertEqual(st.real_size(), 9950) # careful here
  637. class TestIdxs(unittest.TestCase):
  638. def test_check_idx_range(self):
  639. # generated from: (Tensor.rand(4096,599*64) @ Tensor.rand(599*64,1024)).realize()
  640. # TODO: use int64
  641. st = ShapeTracker(views=(View(shape=(4096, 1024, 599, 1), strides=(613376, 599, 1, 0), offset=0, mask=None, contiguous=True),))
  642. with self.assertRaises(AssertionError):
  643. st.expr_idxs()
  644. class TestConsecutive(unittest.TestCase):
  645. @classmethod
  646. def setUpClass(self):
  647. from tinygrad.tensor import Tensor # easier test setup
  648. self.t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
  649. self.const = Tensor(2)
  650. self.ones = Tensor.ones(2, 4)
  651. def test_unmodified(self):
  652. assert self.t.lazydata.st.consecutive
  653. assert self.t.reshape(4, 2).lazydata.st.consecutive
  654. assert self.t.reshape(1, 8).lazydata.st.consecutive
  655. def test_sliced(self):
  656. assert self.t[0].lazydata.st.consecutive
  657. assert self.t[0, 1:2].lazydata.st.consecutive
  658. assert self.t[1].lazydata.st.consecutive
  659. assert not self.t[:, 0].lazydata.st.consecutive
  660. assert not self.t[:, 1].lazydata.st.consecutive
  661. def test_padded(self):
  662. assert not self.t.pad(((1, 1), None)).lazydata.st.consecutive
  663. assert not self.t.pad((None, (1, 1))).lazydata.st.consecutive
  664. def test_const(self):
  665. assert self.const.lazydata.st.consecutive
  666. def test_ones(self):
  667. assert not self.ones.lazydata.st.consecutive
  668. assert not self.ones[0, :].lazydata.st.consecutive
  669. # consecutive if sliced into size 1
  670. assert self.ones[0, 0].lazydata.st.consecutive
  671. if __name__ == '__main__':
  672. unittest.main()