test_const_folding.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import unittest, math
  2. from tinygrad import Tensor, Device, dtypes
  3. from tinygrad.engine.schedule import create_schedule
  4. from tinygrad.helpers import CI
  5. from tinygrad.ops import MetaOps
  6. import numpy as np
  7. from test.helpers import is_dtype_supported
  8. def _check_ast_count(desired_count:int, t:Tensor):
  9. # NOTE: this has side effect because everything can be scheduled only once
  10. schedule = create_schedule(t.lazydata.lbs)
  11. asts = [s for s in schedule if s.ast.op is MetaOps.KERNEL]
  12. assert len(asts) == desired_count
  13. class TestUnaryOpsConstFolding(unittest.TestCase):
  14. def test_all_consts_ops(self):
  15. _check_ast_count(0, Tensor.ones(4).exp())
  16. _check_ast_count(0, Tensor.ones(4).sqrt())
  17. _check_ast_count(0, Tensor.ones(4) + Tensor.ones(4))
  18. _check_ast_count(0, Tensor.ones(4) / Tensor.ones(4))
  19. def test_cast(self):
  20. _check_ast_count(0, Tensor.ones(4).cast(dtypes.int16))
  21. _check_ast_count(0, Tensor.full(4, fill_value=-1).cast(dtypes.uint16))
  22. def test_neg_folding(self):
  23. _check_ast_count(0, Tensor([1, 2, 3]).mul(-1).neg())
  24. _check_ast_count(0, Tensor([1, 2, 3]).neg().mul(-1))
  25. _check_ast_count(0, Tensor([1, 2, 3]).neg().neg())
  26. def test_neg_realized_no_fold(self):
  27. x = Tensor.randn(32, 32)
  28. x = x.clip(0, 1).realize()
  29. _check_ast_count(1, x.neg())
  30. class TestBinaryOpsConstFolding(unittest.TestCase):
  31. def test_add_literal_zero(self):
  32. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + 0)
  33. def test_add_tensor_zero(self):
  34. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(4))
  35. def test_literal_zero_add(self):
  36. _check_ast_count(0, 0 + Tensor([1.0, 2, 3, 4]))
  37. def test_tensor_zero_add(self):
  38. _check_ast_count(0, Tensor.zeros(4) + Tensor([1.0, 2, 3, 4]))
  39. def test_sub_literal_zero(self):
  40. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) - 0)
  41. def test_sub_tensor_zero(self):
  42. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) - Tensor.zeros(4))
  43. def test_mul_literal_zero(self):
  44. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 0)
  45. def test_mul_tensor_zero(self):
  46. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.zeros(4))
  47. def test_literal_zero_mul(self):
  48. _check_ast_count(0, 0 * Tensor([1.0, 2, 3, 4]) * 0)
  49. def test_tensor_zero_mul(self):
  50. _check_ast_count(0, Tensor.zeros(4) * Tensor([1.0, 2, 3, 4]))
  51. def test_mul_literal_one(self):
  52. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 1)
  53. def test_mul_tensor_one(self):
  54. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.ones(4))
  55. def test_literal_one_mul(self):
  56. _check_ast_count(0, 1 * Tensor([1.0, 2, 3, 4]))
  57. def test_tensor_one_mul(self):
  58. _check_ast_count(0, Tensor.ones(4) * Tensor([1.0, 2, 3, 4]))
  59. def test_bool_tensor_mul_bool(self):
  60. _check_ast_count(0, Tensor([True, False]) * True)
  61. _check_ast_count(0, Tensor([True, False]) * False)
  62. def test_bool_mul_bool_tensor(self):
  63. _check_ast_count(0, True * Tensor([True, False]))
  64. _check_ast_count(0, False * Tensor([True, False]))
  65. def test_div_literal_one(self):
  66. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) / 1)
  67. def test_div_tensor_one(self):
  68. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) / Tensor.ones(4))
  69. def test_pow_literal_zero(self):
  70. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 0)
  71. def test_pow_tensor_zero(self):
  72. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.zeros(4))
  73. def test_pow_literal_one(self):
  74. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 1)
  75. def test_pow_tensor_one(self):
  76. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
  77. def test_literal_one_pow(self):
  78. _check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
  79. def test_tensor_one_pow(self):
  80. _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
  81. # folds advance indexing into basic indexing
  82. class TestIndexingConstFolding(unittest.TestCase):
  83. def test_scalar_index(self):
  84. t = Tensor.arange(16).float().reshape(1,1,4,4).realize()
  85. _check_ast_count(0, t[:,:,Tensor(1),:])
  86. _check_ast_count(0, t[:,:,Tensor(1)+2,:])
  87. _check_ast_count(0, t[:,:,Tensor(1),Tensor(0)])
  88. @unittest.expectedFailure
  89. def test_const_tensor_index(self):
  90. # TODO: implement const tensor folded indexing
  91. t = Tensor.arange(16).float().reshape(1,1,4,4).realize()
  92. _check_ast_count(0, t[:,:,Tensor.ones(2,1),:])
  93. _check_ast_count(0, t[:,:,Tensor.ones(1,2)+2,:])
  94. _check_ast_count(0, t[:,:,Tensor.ones(1,1),Tensor.zeros(2,1,2)])
  95. class TestMovedConstFolding(unittest.TestCase):
  96. def test_add_shrunk_zero(self):
  97. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(6).shrink(((1, 5),)))
  98. def test_add_padded_zero(self):
  99. # TODO: it's 1 now, this might be possible to fold
  100. _check_ast_count(1, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(2).pad(((1, 1),)))
  101. def test_mul_shrunk_one(self):
  102. _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.ones(6).shrink(((1, 5),)))
  103. def test_add_padded_one(self):
  104. _check_ast_count(1, Tensor([1.0, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),)))
  105. def test_cast_padded(self):
  106. # NOTE: this is folded due to CAST_BEFORE_VIEW
  107. _check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
  108. np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
  109. _check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
  110. np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
  111. # not folded
  112. _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
  113. np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
  114. class TestReduceOpsConstFolding(unittest.TestCase):
  115. def test_const_sum(self):
  116. _check_ast_count(0, Tensor.ones(4, 5, 6).sum())
  117. np.testing.assert_equal(Tensor.ones(4, 5, 6).sum().numpy(), 4 * 5 * 6)
  118. _check_ast_count(0, Tensor.ones(4, 5, 6).sum(axis=0))
  119. np.testing.assert_equal(Tensor.ones(4, 5, 6).sum(axis=0).numpy(), np.full((5, 6), 4))
  120. _check_ast_count(0, Tensor(4).sum())
  121. np.testing.assert_equal(Tensor(4).sum().numpy(), 4)
  122. def test_padded_const_sum(self):
  123. _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).sum())
  124. np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).sum().numpy(), 4)
  125. # NOTE: cannot just count the non-padded area because some UnaryOps f do not have f(0) = 0.
  126. _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).exp().sum())
  127. np.testing.assert_allclose(Tensor.ones(4).pad(((1, 1),)).exp().sum().numpy(), 4 * math.e + 2)
  128. def test_const_max(self):
  129. _check_ast_count(0, Tensor.ones(4, 5, 6).max())
  130. np.testing.assert_equal(Tensor.ones(4, 5, 6).max().numpy(), 1)
  131. _check_ast_count(0, Tensor(4).max())
  132. np.testing.assert_equal(Tensor(4).max().numpy(), 4)
  133. def test_sum_output_dtype(self):
  134. # sum output dtype can be different from input
  135. for dt in dtypes.fields().values():
  136. if is_dtype_supported(dt):
  137. t = Tensor.ones(16, dtype=dt).reshape(4, 4)
  138. assert t.sum().dtype == t.contiguous().sum().dtype
  139. @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
  140. class TestMultiConstFolding(unittest.TestCase):
  141. def test_multi_const_folding_literal(self):
  142. ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
  143. t = Tensor.arange(16).float().realize().to(ds)
  144. # non const folding case creates one ast on each shard
  145. _check_ast_count(4, t + 1)
  146. _check_ast_count(4, 1 + t)
  147. _check_ast_count(4, t * 2)
  148. _check_ast_count(4, 2 * t)
  149. # const folded
  150. _check_ast_count(0, t + 0)
  151. _check_ast_count(0, 0 + t)
  152. _check_ast_count(0, t * 0)
  153. _check_ast_count(0, 0 * t)
  154. _check_ast_count(0, t * 1)
  155. _check_ast_count(0, 1 * t)
  156. np.testing.assert_equal((t + 0).numpy(), np.arange(16))
  157. np.testing.assert_equal((t * 0).numpy(), [0] * 16)
  158. np.testing.assert_equal((t * 1).numpy(), np.arange(16))
  159. _check_ast_count(0, t ** 0)
  160. _check_ast_count(0, t ** 1)
  161. _check_ast_count(0, 1 ** t)
  162. def test_multi_const_folding_tensor(self):
  163. ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
  164. t = Tensor.arange(16).float().realize().to(ds)
  165. zero = Tensor.zeros(16).realize().to(ds)
  166. one = Tensor.ones(16).realize().to(ds)
  167. # const folded
  168. _check_ast_count(0, t + zero)
  169. _check_ast_count(0, zero + t)
  170. _check_ast_count(0, t * zero)
  171. _check_ast_count(0, zero * t)
  172. _check_ast_count(0, t * one)
  173. _check_ast_count(0, one * t)
  174. np.testing.assert_equal((t + zero).numpy(), np.arange(16))
  175. np.testing.assert_equal((t * zero).numpy(), [0] * 16)
  176. np.testing.assert_equal((t * one).numpy(), np.arange(16))
  177. @unittest.expectedFailure
  178. def test_multi_todo_pow(self):
  179. ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
  180. t = Tensor.arange(16).float().realize().to(ds)
  181. zero = Tensor.zeros(16).realize().to(ds)
  182. one = Tensor.ones(16).realize().to(ds)
  183. # TODO: fix pow folding
  184. _check_ast_count(0, t ** zero)
  185. _check_ast_count(0, t ** one)
  186. _check_ast_count(0, one ** t)
  187. class TestTautologicalCompare(unittest.TestCase):
  188. # without const folding, these would have triggered -Wtautological-compare in clang
  189. def test_lt_false(self):
  190. # bool < False is always false
  191. np.testing.assert_equal((Tensor([True, False]) < False).numpy(), [False, False])
  192. def test_true_lt(self):
  193. # True < bool is always false
  194. np.testing.assert_equal((True < Tensor([True, False])).numpy(), [False, False])
  195. def test_truth_table(self):
  196. np.testing.assert_equal((Tensor(False) < Tensor(False)).numpy(), False)
  197. np.testing.assert_equal((Tensor(False) < Tensor(True)).numpy(), True)
  198. np.testing.assert_equal((Tensor(True) < Tensor(False)).numpy(), False)
  199. np.testing.assert_equal((Tensor(True) < Tensor(True)).numpy(), False)
  200. @unittest.skip("not implemented yet")
  201. def test_a_eq_a(self):
  202. # self eq is always true for int or bool
  203. a = Tensor([1, 2, 3])
  204. np.testing.assert_equal((a == a).numpy(), [True, True, True])
  205. # not true for nan
  206. a = Tensor([math.nan, 1.0, 2.0])
  207. np.testing.assert_equal((a == a).numpy(), [False, True, True])
  208. @unittest.skip("not implemented yet")
  209. def test_a_ne_a(self):
  210. # self not eq is always false for int or bool
  211. a = Tensor([1, 2, 3])
  212. np.testing.assert_equal((a != a).numpy(), [False, False, False])
  213. # not true for nan
  214. a = Tensor([math.nan, 1.0, 2.0])
  215. np.testing.assert_equal((a != a).numpy(), [True, False, False])
  216. if __name__ == '__main__':
  217. unittest.main()