# test_tqdm.py (11 KB)
  1. import time, random, unittest, itertools
  2. from unittest.mock import patch
  3. from io import StringIO
  4. from collections import namedtuple
  5. from tqdm import tqdm
  6. from tinygrad.helpers import tqdm as tinytqdm, trange as tinytrange
  7. import numpy as np
  8. class TestProgressBar(unittest.TestCase):
  9. def _compare_bars(self, bar1, bar2):
  10. prefix1, prog1, suffix1 = bar1.split("|")
  11. prefix2, prog2, suffix2 = bar2.split("|")
  12. self.assertEqual(len(bar1), len(bar2))
  13. self.assertEqual(prefix1, prefix2)
  14. def parse_timer(timer): return sum(int(x) * y for x, y in zip(timer.split(':')[::-1], (1, 60, 3600)))
  15. if "?" not in suffix1 and "?" not in suffix2:
  16. # allow for few sec diff in timers (removes flakiness)
  17. timer1, rm1 = [parse_timer(timer) for timer in suffix1.split("[")[-1].split(",")[0].split("<")]
  18. timer2, rm2 = [parse_timer(timer) for timer in suffix2.split("[")[-1].split(",")[0].split("<")]
  19. np.testing.assert_allclose(timer1, timer2, atol=5, rtol=1e-2)
  20. np.testing.assert_allclose(rm1, rm2, atol=5, rtol=1e-2)
  21. # get suffix without timers
  22. suffix1 = suffix1.split("[")[0] + suffix1.split(",")[1]
  23. suffix2 = suffix2.split("[")[0] + suffix2.split(",")[1]
  24. self.assertEqual(suffix1, suffix2)
  25. else:
  26. self.assertEqual(suffix1, suffix2)
  27. diff = sum([c1 != c2 for c1, c2 in zip(prog1, prog2)]) # allow 1 char diff to be less flaky, but it should match
  28. assert diff <= 1, f"{diff=}\n{prog1=}\n{prog2=}"
  29. @patch('sys.stderr', new_callable=StringIO)
  30. @patch('shutil.get_terminal_size')
  31. def test_tqdm_output_iter(self, mock_terminal_size, mock_stderr):
  32. for _ in range(10):
  33. total, ncols = random.randint(5, 30), random.randint(80, 240)
  34. mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
  35. mock_stderr.truncate(0)
  36. # compare bars at each iteration (only when tinytqdm bar has been updated)
  37. for n in (bar := tinytqdm(range(total), desc="Test")):
  38. time.sleep(0.01)
  39. if bar.i % bar.skip != 0: continue
  40. tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
  41. iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
  42. elapsed = n/iters_per_sec if n>0 else 0
  43. tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
  44. self._compare_bars(tinytqdm_output, tqdm_output)
  45. # compare final bars
  46. tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
  47. iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
  48. elapsed = total/iters_per_sec if n>0 else 0
  49. tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
  50. self._compare_bars(tinytqdm_output, tqdm_output)
  51. @patch('sys.stderr', new_callable=StringIO)
  52. @patch('shutil.get_terminal_size')
  53. def test_unit_scale(self, mock_terminal_size, mock_stderr):
  54. for unit_scale in [True, False]:
  55. # NOTE: numpy comparison raises TypeError if exponent > 22
  56. for exponent in range(1, 22, 3):
  57. low, high = 10 ** exponent, 10 ** (exponent+1)
  58. for _ in range(3):
  59. total, ncols = random.randint(low, high), random.randint(80, 240)
  60. mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
  61. mock_stderr.truncate(0)
  62. # compare bars at each iteration (only when tinytqdm bar has been updated)
  63. for n in tinytqdm(range(total), desc="Test", total=total, unit_scale=unit_scale):
  64. time.sleep(0.01)
  65. tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
  66. iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
  67. elapsed = n/iters_per_sec if n>0 else 0
  68. tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
  69. # print(f"tiny: {tinytqdm_output}")
  70. # print(f"tqdm: {tqdm_output}")
  71. self._compare_bars(tinytqdm_output, tqdm_output)
  72. if n > 3: break
  73. @patch('sys.stderr', new_callable=StringIO)
  74. @patch('shutil.get_terminal_size')
  75. def test_set_description(self, mock_terminal_size, mock_stderr):
  76. for _ in range(10):
  77. total, ncols = random.randint(5, 30), random.randint(80, 240)
  78. mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
  79. mock_stderr.truncate(0)
  80. expected_prefix = "Test"
  81. # compare bars at each iteration (only when tinytqdm bar has been updated)
  82. for i,n in enumerate(bar := tinytqdm(range(total), desc="Test")):
  83. time.sleep(0.01)
  84. if bar.i % bar.skip != 0: continue
  85. tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
  86. iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
  87. elapsed = n/iters_per_sec if n>0 else 0
  88. tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix=expected_prefix)
  89. expected_prefix = desc = f"Test {i}" if i % 2 == 0 else ""
  90. bar.set_description(desc)
  91. self._compare_bars(tinytqdm_output, tqdm_output)
  92. # compare final bars
  93. tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
  94. iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
  95. elapsed = total/iters_per_sec if n>0 else 0
  96. tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix=expected_prefix)
  97. self._compare_bars(tinytqdm_output, tqdm_output)
  98. @patch('sys.stderr', new_callable=StringIO)
  99. @patch('shutil.get_terminal_size')
  100. def test_trange_output_iter(self, mock_terminal_size, mock_stderr):
  101. for _ in range(5):
  102. total, ncols = random.randint(5, 30), random.randint(80, 240)
  103. mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
  104. mock_stderr.truncate(0)
  105. # compare bars at each iteration (only when tinytqdm bar has been updated)
  106. for n in (bar := tinytrange(total, desc="Test")):
  107. time.sleep(0.01)
  108. if bar.i % bar.skip != 0: continue
  109. tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
  110. iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
  111. elapsed = n/iters_per_sec if n>0 else 0
  112. tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
  113. self._compare_bars(tiny_output, tqdm_output)
  114. # compare final bars
  115. tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
  116. iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
  117. elapsed = total/iters_per_sec if n>0 else 0
  118. tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
  119. self._compare_bars(tiny_output, tqdm_output)
  120. @patch('sys.stderr', new_callable=StringIO)
  121. @patch('shutil.get_terminal_size')
  122. def test_tqdm_output_custom(self, mock_terminal_size, mock_stderr):
  123. for _ in range(10):
  124. total, ncols = random.randint(10000, 100000), random.randint(80, 120)
  125. mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
  126. mock_stderr.truncate(0)
  127. # compare bars at each iteration (only when tinytqdm bar has been updated)
  128. bar = tinytqdm(total=total, desc="Test")
  129. n = 0
  130. while n < total:
  131. time.sleep(0.01)
  132. incr = (total // 10) + random.randint(0, 100)
  133. if n + incr > total: incr = total - n
  134. bar.update(incr, close=n+incr==total)
  135. n += incr
  136. if bar.i % bar.skip != 0: continue
  137. tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
  138. iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
  139. elapsed = n/iters_per_sec if n>0 else 0
  140. tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
  141. self._compare_bars(tinytqdm_output, tqdm_output)
  142. @patch('sys.stderr', new_callable=StringIO)
  143. @patch('shutil.get_terminal_size')
  144. def test_tqdm_output_custom_0_total(self, mock_terminal_size, mock_stderr):
  145. for _ in range(10):
  146. total, ncols = random.randint(10000, 100000), random.randint(80, 120)
  147. mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
  148. mock_stderr.truncate(0)
  149. # compare bars at each iteration (only when tinytqdm bar has been updated)
  150. bar = tinytqdm(total=0, desc="Test")
  151. n = 0
  152. while n < total:
  153. time.sleep(0.01)
  154. incr = (total // 10) + random.randint(0, 100)
  155. if n + incr > total: incr = total - n
  156. bar.update(incr, close=n+incr==total)
  157. n += incr
  158. if bar.i % bar.skip != 0: continue
  159. tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
  160. iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
  161. elapsed = n/iters_per_sec if n>0 else 0
  162. tqdm_output = tqdm.format_meter(n=n, total=0, elapsed=elapsed, ncols=ncols, prefix="Test")
  163. self.assertEqual(tinytqdm_output, tqdm_output)
  164. @patch('sys.stderr', new_callable=StringIO)
  165. @patch('shutil.get_terminal_size')
  166. def test_tqdm_output_custom_nolen_total(self, mock_terminal_size, mock_stderr):
  167. for unit_scale in [True, False]:
  168. for _ in range(3):
  169. gen = itertools.count(0)
  170. ncols = random.randint(80, 120)
  171. mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
  172. mock_stderr.truncate(0)
  173. # compare bars at each iteration (only when tinytqdm bar has been updated)
  174. for n,g in enumerate(tinytqdm(gen, desc="Test", unit_scale=unit_scale)):
  175. assert g == n
  176. time.sleep(0.01)
  177. tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
  178. if n:
  179. iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1])
  180. elapsed = n/iters_per_sec
  181. else:
  182. elapsed = 0
  183. tqdm_output = tqdm.format_meter(n=n, total=0, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
  184. self.assertEqual(tinytqdm_output, tqdm_output)
  185. if n > 5: break
  186. def test_tqdm_perf(self):
  187. st = time.perf_counter()
  188. for _ in tqdm(range(100)): time.sleep(0.01)
  189. tqdm_time = time.perf_counter() - st
  190. st = time.perf_counter()
  191. for _ in tinytqdm(range(100)): time.sleep(0.01)
  192. tinytqdm_time = time.perf_counter() - st
  193. assert tinytqdm_time < 2 * tqdm_time
  194. def test_tqdm_perf_high_iter(self):
  195. st = time.perf_counter()
  196. for _ in tqdm(range(10^7)): pass
  197. tqdm_time = time.perf_counter() - st
  198. st = time.perf_counter()
  199. for _ in tinytqdm(range(10^7)): pass
  200. tinytqdm_time = time.perf_counter() - st
  201. assert tinytqdm_time < 5 * tqdm_time
  202. if __name__ == '__main__':
  203. unittest.main()