| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235 |
- import time, random, unittest, itertools
- from unittest.mock import patch
- from io import StringIO
- from collections import namedtuple
- from tqdm import tqdm
- from tinygrad.helpers import tqdm as tinytqdm, trange as tinytrange
- import numpy as np
- class TestProgressBar(unittest.TestCase):
- def _compare_bars(self, bar1, bar2):
- prefix1, prog1, suffix1 = bar1.split("|")
- prefix2, prog2, suffix2 = bar2.split("|")
- self.assertEqual(len(bar1), len(bar2))
- self.assertEqual(prefix1, prefix2)
- def parse_timer(timer): return sum(int(x) * y for x, y in zip(timer.split(':')[::-1], (1, 60, 3600)))
- if "?" not in suffix1 and "?" not in suffix2:
- # allow for few sec diff in timers (removes flakiness)
- timer1, rm1 = [parse_timer(timer) for timer in suffix1.split("[")[-1].split(",")[0].split("<")]
- timer2, rm2 = [parse_timer(timer) for timer in suffix2.split("[")[-1].split(",")[0].split("<")]
- np.testing.assert_allclose(timer1, timer2, atol=5, rtol=1e-2)
- np.testing.assert_allclose(rm1, rm2, atol=5, rtol=1e-2)
- # get suffix without timers
- suffix1 = suffix1.split("[")[0] + suffix1.split(",")[1]
- suffix2 = suffix2.split("[")[0] + suffix2.split(",")[1]
- self.assertEqual(suffix1, suffix2)
- else:
- self.assertEqual(suffix1, suffix2)
- diff = sum([c1 != c2 for c1, c2 in zip(prog1, prog2)]) # allow 1 char diff to be less flaky, but it should match
- assert diff <= 1, f"{diff=}\n{prog1=}\n{prog2=}"
- @patch('sys.stderr', new_callable=StringIO)
- @patch('shutil.get_terminal_size')
- def test_tqdm_output_iter(self, mock_terminal_size, mock_stderr):
- for _ in range(10):
- total, ncols = random.randint(5, 30), random.randint(80, 240)
- mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
- mock_stderr.truncate(0)
- # compare bars at each iteration (only when tinytqdm bar has been updated)
- for n in (bar := tinytqdm(range(total), desc="Test")):
- time.sleep(0.01)
- if bar.i % bar.skip != 0: continue
- tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
- iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
- elapsed = n/iters_per_sec if n>0 else 0
- tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
- self._compare_bars(tinytqdm_output, tqdm_output)
- # compare final bars
- tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
- iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
- elapsed = total/iters_per_sec if n>0 else 0
- tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
- self._compare_bars(tinytqdm_output, tqdm_output)
- @patch('sys.stderr', new_callable=StringIO)
- @patch('shutil.get_terminal_size')
- def test_unit_scale(self, mock_terminal_size, mock_stderr):
- for unit_scale in [True, False]:
- # NOTE: numpy comparison raises TypeError if exponent > 22
- for exponent in range(1, 22, 3):
- low, high = 10 ** exponent, 10 ** (exponent+1)
- for _ in range(3):
- total, ncols = random.randint(low, high), random.randint(80, 240)
- mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
- mock_stderr.truncate(0)
- # compare bars at each iteration (only when tinytqdm bar has been updated)
- for n in tinytqdm(range(total), desc="Test", total=total, unit_scale=unit_scale):
- time.sleep(0.01)
- tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
- iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
- elapsed = n/iters_per_sec if n>0 else 0
- tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
- # print(f"tiny: {tinytqdm_output}")
- # print(f"tqdm: {tqdm_output}")
- self._compare_bars(tinytqdm_output, tqdm_output)
- if n > 3: break
- @patch('sys.stderr', new_callable=StringIO)
- @patch('shutil.get_terminal_size')
- def test_set_description(self, mock_terminal_size, mock_stderr):
- for _ in range(10):
- total, ncols = random.randint(5, 30), random.randint(80, 240)
- mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
- mock_stderr.truncate(0)
- expected_prefix = "Test"
- # compare bars at each iteration (only when tinytqdm bar has been updated)
- for i,n in enumerate(bar := tinytqdm(range(total), desc="Test")):
- time.sleep(0.01)
- if bar.i % bar.skip != 0: continue
- tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
- iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
- elapsed = n/iters_per_sec if n>0 else 0
- tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix=expected_prefix)
- expected_prefix = desc = f"Test {i}" if i % 2 == 0 else ""
- bar.set_description(desc)
- self._compare_bars(tinytqdm_output, tqdm_output)
- # compare final bars
- tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
- iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
- elapsed = total/iters_per_sec if n>0 else 0
- tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix=expected_prefix)
- self._compare_bars(tinytqdm_output, tqdm_output)
- @patch('sys.stderr', new_callable=StringIO)
- @patch('shutil.get_terminal_size')
- def test_trange_output_iter(self, mock_terminal_size, mock_stderr):
- for _ in range(5):
- total, ncols = random.randint(5, 30), random.randint(80, 240)
- mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
- mock_stderr.truncate(0)
- # compare bars at each iteration (only when tinytqdm bar has been updated)
- for n in (bar := tinytrange(total, desc="Test")):
- time.sleep(0.01)
- if bar.i % bar.skip != 0: continue
- tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
- iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
- elapsed = n/iters_per_sec if n>0 else 0
- tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
- self._compare_bars(tiny_output, tqdm_output)
- # compare final bars
- tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
- iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
- elapsed = total/iters_per_sec if n>0 else 0
- tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
- self._compare_bars(tiny_output, tqdm_output)
- @patch('sys.stderr', new_callable=StringIO)
- @patch('shutil.get_terminal_size')
- def test_tqdm_output_custom(self, mock_terminal_size, mock_stderr):
- for _ in range(10):
- total, ncols = random.randint(10000, 100000), random.randint(80, 120)
- mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
- mock_stderr.truncate(0)
- # compare bars at each iteration (only when tinytqdm bar has been updated)
- bar = tinytqdm(total=total, desc="Test")
- n = 0
- while n < total:
- time.sleep(0.01)
- incr = (total // 10) + random.randint(0, 100)
- if n + incr > total: incr = total - n
- bar.update(incr, close=n+incr==total)
- n += incr
- if bar.i % bar.skip != 0: continue
- tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
- iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
- elapsed = n/iters_per_sec if n>0 else 0
- tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
- self._compare_bars(tinytqdm_output, tqdm_output)
- @patch('sys.stderr', new_callable=StringIO)
- @patch('shutil.get_terminal_size')
- def test_tqdm_output_custom_0_total(self, mock_terminal_size, mock_stderr):
- for _ in range(10):
- total, ncols = random.randint(10000, 100000), random.randint(80, 120)
- mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
- mock_stderr.truncate(0)
- # compare bars at each iteration (only when tinytqdm bar has been updated)
- bar = tinytqdm(total=0, desc="Test")
- n = 0
- while n < total:
- time.sleep(0.01)
- incr = (total // 10) + random.randint(0, 100)
- if n + incr > total: incr = total - n
- bar.update(incr, close=n+incr==total)
- n += incr
- if bar.i % bar.skip != 0: continue
- tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
- iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
- elapsed = n/iters_per_sec if n>0 else 0
- tqdm_output = tqdm.format_meter(n=n, total=0, elapsed=elapsed, ncols=ncols, prefix="Test")
- self.assertEqual(tinytqdm_output, tqdm_output)
- @patch('sys.stderr', new_callable=StringIO)
- @patch('shutil.get_terminal_size')
- def test_tqdm_output_custom_nolen_total(self, mock_terminal_size, mock_stderr):
- for unit_scale in [True, False]:
- for _ in range(3):
- gen = itertools.count(0)
- ncols = random.randint(80, 120)
- mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
- mock_stderr.truncate(0)
- # compare bars at each iteration (only when tinytqdm bar has been updated)
- for n,g in enumerate(tinytqdm(gen, desc="Test", unit_scale=unit_scale)):
- assert g == n
- time.sleep(0.01)
- tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
- if n:
- iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1])
- elapsed = n/iters_per_sec
- else:
- elapsed = 0
- tqdm_output = tqdm.format_meter(n=n, total=0, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
- self.assertEqual(tinytqdm_output, tqdm_output)
- if n > 5: break
- def test_tqdm_perf(self):
- st = time.perf_counter()
- for _ in tqdm(range(100)): time.sleep(0.01)
- tqdm_time = time.perf_counter() - st
- st = time.perf_counter()
- for _ in tinytqdm(range(100)): time.sleep(0.01)
- tinytqdm_time = time.perf_counter() - st
- assert tinytqdm_time < 2 * tqdm_time
- def test_tqdm_perf_high_iter(self):
- st = time.perf_counter()
- for _ in tqdm(range(10^7)): pass
- tqdm_time = time.perf_counter() - st
- st = time.perf_counter()
- for _ in tinytqdm(range(10^7)): pass
- tinytqdm_time = time.perf_counter() - st
- assert tinytqdm_time < 5 * tqdm_time
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|