test_helpers.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. import unittest
  2. from PIL import Image
  3. from tinygrad.helpers import Context, ContextVar
  4. from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape
  5. from tinygrad.shape.symbolic import Variable, NumNode
  # Module-level ContextVar exercised by TestContextVars below; its default value is 0.
  VARIABLE = ContextVar("VARIABLE", 0)

  class TestContextVars(unittest.TestCase):
    """Tests for ContextVar creation/caching and Context scoping (with-statement and decorator forms)."""
    # Ensuring that the test does not modify variables outside the tests: one empty Context is
    # entered before each test and exited after it.
    ctx = Context()
    def setUp(self): TestContextVars.ctx.__enter__()
    # Pass the standard (exc_type, exc_value, traceback) triple exactly as a `with` statement would.
    def tearDown(self): TestContextVars.ctx.__exit__(None, None, None)

    def test_initial_value_is_set(self):
      _TMP = ContextVar("_TMP", 5)
      self.assertEqual(_TMP.value, 5)

    def test_multiple_creation_ignored(self):
      # Re-creating a ContextVar under an existing key keeps the first value; the second default is ignored.
      _TMP2 = ContextVar("_TMP2", 1)
      _TMP2 = ContextVar("_TMP2", 2)
      self.assertEqual(_TMP2.value, 1)

    def test_new_var_inside_context(self):
      # Creating a _new_ variable inside a context should not have any effect on its scope (?)
      with Context(VARIABLE=1):
        _TMP3 = ContextVar("_TMP3", 1)
      _TMP3 = ContextVar("_TMP3", 2)
      self.assertEqual(_TMP3.value, 1)

    def test_value_accross_modules(self):
      # Mocking module import by invoking the code but not in our globals().
      exec('from tinygrad.helpers import ContextVar;C = ContextVar("C", 13)', {}) # pylint:disable=exec-used
      # It should not matter that the first creation was in another module.
      C = ContextVar("C", 0)
      self.assertEqual(C.value, 13)

    def test_assignment_across_modules(self):
      B = ContextVar("B", 1)
      # local assignment
      B.value = 2
      self.assertEqual(B.value, 2)
      # Assignment in another module.
      exec('from tinygrad.helpers import ContextVar;B = ContextVar("B", 0);B.value = 3;', {}) # pylint:disable=exec-used
      # Assignment in another module should affect this one as well.
      self.assertEqual(B.value, 3)

    def test_context_assignment(self):
      with Context(VARIABLE=1):
        self.assertEqual(VARIABLE.value, 1)
      self.assertEqual(VARIABLE.value, 0)

    def test_unknown_param_to_context(self):
      # Context only accepts keys of already-created ContextVars; anything else raises KeyError.
      with self.assertRaises(KeyError):
        with Context(SOMETHING_ELSE=1):
          pass

    def test_inside_context_assignment(self):
      # NOTE: the assignment below makes VARIABLE a *local* name for this whole method; because
      # ContextVars are cached by key it still refers to the same object as the module-level VARIABLE.
      with Context(VARIABLE=4):
        # What you can and cannot do inside a context.
        # 1. This type of statement has no effect: the existing (context-set) value 4 is kept.
        VARIABLE = ContextVar("VARIABLE", 0)
        self.assertEqual(VARIABLE.value, 4, "ContextVars inside contextmanager may not set a new value")
        # 2. Assigning to .value, however, has a local effect.
        VARIABLE.value = 13
        self.assertEqual(VARIABLE.value, 13, "Call syntax however works inside a contextmanager.")
      # Related to 2. above. Note that VARIABLE is back to 0 again as expected.
      self.assertEqual(VARIABLE.value, 0)

    def test_new_var_inside_context_other_module(self):
      with Context(VARIABLE=1):
        _NEW2 = ContextVar("_NEW2", 0)
      _NEW2 = ContextVar("_NEW2", 1)
      self.assertEqual(_NEW2.value, 0)

      # Built line by line so the indentation the exec'd `with` body needs does not depend on how
      # this file itself is indented.
      code = ("from tinygrad.helpers import Context, ContextVar\n"
              "with Context(VARIABLE=1):\n"
              '  _NEW3 = ContextVar("_NEW3", 0)')
      exec(code, {}) # pylint:disable=exec-used
      # While _NEW3 was created in an outside scope it should still work the same as above.
      _NEW3 = ContextVar("_NEW3", 1)
      self.assertEqual(_NEW3.value, 0)

    def test_nested_context(self):
      with Context(VARIABLE=1):
        with Context(VARIABLE=2):
          with Context(VARIABLE=3):
            self.assertEqual(VARIABLE.value, 3)
          self.assertEqual(VARIABLE.value, 2)
        self.assertEqual(VARIABLE.value, 1)
      self.assertEqual(VARIABLE.value, 0)

    def test_decorator(self):
      # Context also works as a decorator: the overrides apply only while the decorated function runs.
      @Context(VARIABLE=1, DEBUG=4)
      def test():
        self.assertEqual(VARIABLE.value, 1)
      self.assertEqual(VARIABLE.value, 0)
      test()
      self.assertEqual(VARIABLE.value, 0)

    def test_context_exit_reverts_updated_values(self):
      D = ContextVar("D", 1)
      D.value = 2
      with Context(D=3):
        # Confirm the override took effect, so the check after the block is meaningful.
        self.assertEqual(D.value, 3)
      # self.assertEqual (unlike a bare `assert`) is not stripped when Python runs with -O.
      self.assertEqual(D.value, 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value.")
  class TestMergeDicts(unittest.TestCase):
    """Tests for merge_dicts: union of several dicts, rejecting a key mapped to different values."""
    def test_merge_dicts(self):
      a = {"a": 1, "b": 2}
      b = {"a": 1, "c": 3}  # shares key "a" with `a` at the same value -> allowed
      c = {}
      d = {"a": 2, "b": 2}  # maps "a" to a different value than `a` -> conflict
      # assertEqual instead of bare `assert` so the checks survive `python -O` and report diffs.
      self.assertEqual(merge_dicts([a, b]), {"a": 1, "b": 2, "c": 3})
      self.assertEqual(merge_dicts([a, c]), a)
      self.assertEqual(merge_dicts([a, b, c]), {"a": 1, "b": 2, "c": 3})
      # A conflicting value for the same key is signalled with an AssertionError.
      with self.assertRaises(AssertionError):
        merge_dicts([a, d])
  class TestStripParens(unittest.TestCase):
    """strip_parens drops one outer pair of parentheses, only when that pair encloses the whole expression."""
    def test_simple(self):
      wrapped = "(1+2)"
      self.assertEqual("1+2", strip_parens(wrapped))

    def test_nested(self):
      # Only the outermost pair is removed; inner groupings stay.
      wrapped = "(1+(2+3))"
      self.assertEqual("1+(2+3)", strip_parens(wrapped))

    def test_casted_no_strip(self):
      # The leading "(" closes after "int", so the outer characters are not a matching pair.
      casted = "(int)(1+2)"
      self.assertEqual("(int)(1+2)", strip_parens(casted))
  class TestProd(unittest.TestCase):
    """prod multiplies a sequence of ints and/or symbolic nodes; constant factors fold together."""
    def test_empty(self):
      # The product over nothing is the multiplicative identity.
      self.assertEqual(1, prod(()))

    def test_ints(self):
      self.assertEqual(30, prod((2, 3, 5)))

    def test_variable(self):
      a = Variable("a", 1, 5)
      self.assertEqual("(a*12)", prod((a, 3, 4)).render())

    def test_variable_order(self):
      # Putting the Variable last must render the same as putting it first.
      a = Variable("a", 1, 5)
      self.assertEqual("(a*12)", prod((3, 4, a)).render())

    def test_num_nodes(self):
      factors = (NumNode(2), NumNode(3))
      self.assertEqual(NumNode(6), prod(factors))
  class TestRoundUp(unittest.TestCase):
    """round_up(x, m) yields the smallest multiple of m that is >= x (also for negative x)."""
    def test_round_up(self):
      # ((value, multiple), expected) -- checked in order.
      cases = [
        ((-3, 4), 0),
        ((-4, 4), -4),
        ((6, 4), 8),
        ((8, 4), 8),
        ((232, 24984), 24984),
        ((24984, 232), 25056),
      ]
      for (value, multiple), expected in cases:
        self.assertEqual(round_up(value, multiple), expected)
  @unittest.skip("no fetch tests because they need internet")
  class TestFetch(unittest.TestCase):
    """Network tests for fetch; skipped by default because they need internet access."""
    # NOTE(review): the return value of fetch is used like a pathlib.Path below (read_bytes, parent) -- confirm.
    def test_fetch_bad_http(self):
      self.assertRaises(Exception, fetch, 'http://www.google.com/404', allow_caching=False)

    def test_fetch_small(self):
      # unittest assertions (not bare `assert`) so the checks are not stripped under `python -O`.
      self.assertGreater(len(fetch('https://google.com', allow_caching=False).read_bytes()), 0)

    def test_fetch_img(self):
      img = fetch("https://avatars.githubusercontent.com/u/132956020", allow_caching=False)
      with Image.open(img) as pimg:
        self.assertEqual(pimg.size, (77, 77))

    def test_fetch_subdir(self):
      img = fetch("https://avatars.githubusercontent.com/u/132956020", allow_caching=False, subdir="images")
      with Image.open(img) as pimg:
        self.assertEqual(pimg.size, (77, 77))
      # The file must land inside the requested subdirectory.
      self.assertEqual(img.parent.name, "images")
  class TestFullyFlatten(unittest.TestCase):
    """fully_flatten recursively flattens nested lists/tuples into one flat list; strings stay whole."""
    def test_fully_flatten(self):
      # (nested input, expected flat list) -- checked in order.
      cases = (
        ([[1, 3], [1, 2]], [1, 3, 1, 2]),
        (((1, 3), (1, 2)), [1, 3, 1, 2]),
        ([[[1], [3]], [[1], [2]]], [1, 3, 1, 2]),
        ([[[[1], 2], 3], 4], [1, 2, 3, 4]),
        ([[1, 2, [3, 4]], [5, 6], 7], [1, 2, 3, 4, 5, 6, 7]),
        ([[1, "ab"], [True, None], [3.14, [5, "b"]]], [1, "ab", True, None, 3.14, 5, "b"]),
      )
      for nested, flat in cases:
        self.assertEqual(fully_flatten(nested), flat)
  class TestMemoryview(unittest.TestCase):
    """Round-trip a memoryview through from_mv / to_mv and check the result aliases the original buffer."""
    def test_from_mv_to_mv(self):
      base = memoryview(bytearray(b"\x11\x22\x33"*40))
      ct = from_mv(base)  # presumably a ctypes object pointing at base's memory -- confirm in helpers
      mv = to_mv(ct, len(base))
      # A write through the re-wrapped view must show up in `base`, i.e. no copy was made.
      mv[0] = 2
      # assertEqual instead of bare `assert` so the check survives `python -O`.
      self.assertEqual(base[0], 2)
  class TestGetContraction(unittest.TestCase):
    """get_contraction(a, b) lists, for each axis of shape b, which axes of shape a merge into it,
    or returns None when a cannot be regrouped into b."""
    def _check_cases(self, cases):
      # Each case is (first shape, second shape, expected result); evaluated in order.
      for shape_a, shape_b, expected in cases:
        self.assertEqual(get_contraction(shape_a, shape_b), expected)

    def test_contraction(self):
      self._check_cases([
        ((1,2,3,4), (2,3,4), [[0, 1], [2], [3]]),
        ((2,1,3,4), (2,3,4), [[0], [1, 2], [3]]),
        ((1,2,3,1,4), (1,2,3,4), [[], [0, 1], [2], [3, 4]]),
        ((1,2,3,1,4,1,1), (2,3,4), [[0, 1], [2], [3, 4, 5, 6]]),
        ((1,2,3,4), (1,2,3*4), [[], [0, 1], [2, 3]]),
        ((1,2,3,4), (2,1,3,4), [[0, 1], [], [2], [3]]),
        ((1,2,3,4), (1,1,2*3*4,1), [[], [], [0,1,2,3], []]),
        ((2,1,3,4), (1,2,3,4), [[], [0], [1, 2], [3]]),
        ((1,2,3,4), (2*3*4,1,1,1), [[0, 1, 2, 3], [], [], []]),
        ((4,4,4,4), (16,1,16), [[0, 1], [], [2, 3]]),
        ((1,2,3,4,1,1,1), (2,3,4), [[0, 1], [2], [3, 4, 5, 6]]),
        ((1,2,3,4), (1,2,3,4,1), [[], [0, 1], [2], [3], []]),
        ((14,1,384,14,1,1,1,1), (1,14,384,14), [[], [0], [1,2], [3,4,5,6,7]]),
        ((14,1,384,1,14,1,1,1,1), (1,14,384,14), [[], [0], [1,2], [3,4,5,6,7,8]]),
        ((512, 512), (1, 1, 512, 1, 1, 1, 1, 512), [[], [], [0], [], [], [], [], [1]]),
        # 3*4 would have to be split into 6 and 2, which a pure regrouping cannot do.
        ((1,2,3,4), (1,2,6,2), None),
      ])

    def test_contraction_ones(self):
      # All-ones shapes: trailing ones of the first shape collect in the last output axis.
      self._check_cases([
        ((1,), (1,1,1), [[], [], [0]]),
        ((1,1), (1,1,1), [[], [], [0, 1]]),
        ((1,1,1,1), (1,), [[0,1,2,3]]),
        ((1,1,1,1), (1,1), [[], [0,1,2,3]]),
        ((1,1,1,1), (1,1,1), [[], [], [0,1,2,3]]),
        ((1,1,1,1), (1,1,1,1), [[], [], [], [0,1,2,3]]),
      ])
  class TestGetShape(unittest.TestCase):
    """get_shape infers the shape tuple of a (possibly nested) list/tuple and rejects ragged nesting."""
    def test_get_shape(self):
      # assertEqual instead of bare `assert` so the checks survive `python -O` and report diffs.
      self.assertEqual(get_shape(2), ())
      self.assertEqual(get_shape([]), (0,))
      self.assertEqual(get_shape([[]]), (1, 0))
      self.assertEqual(get_shape([[1, 2]]), (1, 2))
      # Lists and tuples may be mixed at the same nesting level.
      self.assertEqual(get_shape([[1, 2], (3, 4)]), (2, 2))

    def test_inhomogeneous_shape(self):
      # Ragged nesting has no single shape and must raise ValueError.
      with self.assertRaises(ValueError): get_shape([[], [1]])
      with self.assertRaises(ValueError): get_shape([[1, [2]], [1]])
  if __name__ == '__main__':
    # Executed directly as a script: run every TestCase defined in this module.
    unittest.main()