# test_lr_scheduler.py — compares tinygrad LR schedulers against their torch counterparts.
  1. import numpy as np
  2. import torch
  3. import unittest
  4. from tinygrad.tensor import Tensor
  5. from tinygrad.nn.state import get_parameters
  6. from tinygrad.nn.optim import Adam, SGD
  7. from tinygrad.helpers import DEBUG
  8. from extra.lr_scheduler import MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR
  9. from extra.training import train, evaluate
  10. from extra.datasets import fetch_mnist
# Fix RNG seeds so scheduler/training results are reproducible across runs.
np.random.seed(1337)
Tensor.manual_seed(1337)
# NOTE(review): MNIST is fetched at import time — intentional for this test module,
# but it means importing this file performs dataset I/O.
X_train, Y_train, X_test, Y_test = fetch_mnist()
  14. class TinyBobNet:
  15. def __init__(self):
  16. self.l1 = Tensor.scaled_uniform(784, 128)
  17. self.l2 = Tensor.scaled_uniform(128, 10)
  18. def parameters(self):
  19. return get_parameters(self)
  20. def forward(self, x):
  21. return x.dot(self.l1).relu().dot(self.l2).log_softmax()
  22. def lr_scheduler_training(sched_fn=None, args=None):
  23. model = TinyBobNet()
  24. optim = Adam(model.parameters(), lr=0.01)
  25. if sched_fn is not None: sched = sched_fn(optim, **args)
  26. for _ in range(25):
  27. train(model, X_train, Y_train, optim, 100)
  28. if sched_fn is not None:
  29. if isinstance(sched, ReduceLROnPlateau):
  30. sched.step(evaluate(model, X_test, Y_test))
  31. else:
  32. sched.step()
  33. return evaluate(model, X_test, Y_test)
  34. def current_lr(optim): return optim.param_groups[0]['lr'] if hasattr(optim, 'param_groups') else optim.lr
  35. def get_lrs(optim, sched, epochs, steps=1, accs=None):
  36. lr = current_lr(optim)
  37. if not isinstance(lr, float): lr = lr.numpy()[0]
  38. lrs = [lr]
  39. for e in range(epochs):
  40. for _ in range(steps):
  41. optim.step()
  42. sched.step() if accs is None else sched.step(accs[e])
  43. lr = current_lr(optim)
  44. if not isinstance(lr, float): lr = lr.numpy()[0]
  45. lrs.append(lr)
  46. return lrs
class TestLrScheduler(unittest.TestCase):
  """Checks that each tinygrad scheduler produces the same LR schedule as its torch counterpart."""
  def setUp(self):
    # Schedulers are exercised with training enabled; the prior flag is restored in tearDown.
    self.old_training = Tensor.training
    Tensor.training = True
  def tearDown(self):
    Tensor.training = self.old_training
  def _test_lr_scheduler(self, tinygrad_sched, torch_sched, epochs, opts, atol, rtol, adam=True):
    # Build one tinygrad and one torch optimizer over a dummy parameter, attach the
    # paired schedulers with identical options, and compare the full LR traces.
    # NOTE: 'accs' is popped out of opts because it is an argument to get_lrs,
    # not a scheduler constructor option (this mutates the caller's dict).
    accs = opts.pop('accs', None)
    test_tensor = Tensor([0.], requires_grad=True) # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr]
    test_tensor.mean().backward()
    if adam:
      tinygrad_optim, torch_optim = Adam([test_tensor], lr=0.01), torch.optim.Adam([torch.tensor([0.], requires_grad=True)], lr=0.01)
    else:
      tinygrad_optim, torch_optim = SGD([test_tensor], lr=0.01), torch.optim.SGD([torch.tensor([0.], requires_grad=True)], lr=0.01)
    tinygrad_sched, torch_sched = tinygrad_sched(tinygrad_optim, **opts), torch_sched(torch_optim, **opts)
    tinygrad_lrs = get_lrs(tinygrad_optim, tinygrad_sched, epochs, accs=accs)
    torch_lrs = get_lrs(torch_optim, torch_sched, epochs, accs=accs)
    np.testing.assert_allclose(tinygrad_lrs, torch_lrs, atol=atol, rtol=rtol)
  def _test_multisteplr(self, epochs, opts, atol, rtol, adam=True):
    self._test_lr_scheduler(MultiStepLR, torch.optim.lr_scheduler.MultiStepLR, epochs, opts, atol, rtol, adam=adam)
  def _test_reducelronplateau(self, epochs, opts, atol, rtol):
    # Feed both schedulers the same random metric sequence each epoch.
    opts['accs'] = np.random.randn(epochs)
    self._test_lr_scheduler(ReduceLROnPlateau, torch.optim.lr_scheduler.ReduceLROnPlateau, epochs, opts, atol, rtol)
  def _test_cosineannealinglr(self, epochs, opts, atol, rtol):
    opts['T_max'] = epochs
    self._test_lr_scheduler(CosineAnnealingLR, torch.optim.lr_scheduler.CosineAnnealingLR, epochs, opts, atol, rtol)
  def _test_onecyclelr(self, epochs, opts, atol, rtol):
    opts['total_steps'] = epochs
    self._test_lr_scheduler(OneCycleLR, torch.optim.lr_scheduler.OneCycleLR, epochs, opts, atol, rtol)
  def test_multisteplr(self): self._test_multisteplr(10, {'milestones': [1, 2, 7]}, 1e-6, 1e-6)
  def test_multisteplr_gamma(self): self._test_multisteplr(10, {'milestones': [1, 2, 7], 'gamma': 0.1337}, 1e-6, 1e-6)
  def test_reducelronplateau(self): self._test_reducelronplateau(100, {}, 1e-6, 1e-6)
  def test_reducelronplateau_max(self): self._test_reducelronplateau(100, {'mode': 'max'}, 1e-6, 1e-6)
  def test_reducelronplateau_factor(self): self._test_reducelronplateau(100, {'factor': 0.1337}, 1e-6, 1e-6)
  def test_reducelronplateau_patience(self): self._test_reducelronplateau(100, {'patience': 3}, 1e-6, 1e-6)
  def test_reducelronplateau_threshold(self): self._test_reducelronplateau(100, {'threshold': 1e-6}, 1e-6, 1e-6)
  def test_reducelronplateau_threshold_mode(self): self._test_reducelronplateau(100, {'threshold_mode': 'abs'}, 1e-6, 1e-6)
  def test_cosineannealinglr(self): self._test_cosineannealinglr(100, {}, 1e-6, 1e-6)
  def test_cosineannealinglr_eta_min(self): self._test_cosineannealinglr(100, {'eta_min': 0.001}, 1e-6, 1e-6)
  def test_multistep_2step(self):
    # was making this fail with LRU=1, some issue with epoch_counter
    if DEBUG>=2: print("first")
    self._test_multisteplr(1, {'milestones': [1]}, 1e-6, 1e-6, adam=False)
    if DEBUG>=2: print("second")
    self._test_multisteplr(1, {'milestones': [1], 'gamma': 0.133}, 1e-6, 1e-6, adam=False)
    if DEBUG>=2: print("third")
  def test_onecyclelr(self): self._test_onecyclelr(1000, {'pct_start': 0.3, 'anneal_strategy': 'linear',
                                                          'cycle_momentum': False, 'div_factor': 25.0,
                                                          'final_div_factor': 10000.0, 'max_lr':1e-5}, 1e-6, 1e-6)
  @unittest.skip("slow")
  def test_training(self):
    # End-to-end sanity: every scheduler should beat unscheduled training on MNIST.
    without = lr_scheduler_training()
    sched_fns = [MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR]
    argss = [{'milestones': [5, 7, 10, 15], 'gamma': 0.5}, {'factor': 0.5, 'patience': 2}, {'T_max': 25, 'eta_min': 0.001},
             {'pct_start': 0.3, 'anneal_strategy': 'linear', 'cycle_momentum': False, 'div_factor': 25.0, 'final_div_factor': 10000.0,
              'max_lr':1e-5, 'total_steps': 25}]
    for sched_fn, args in zip(sched_fns, argss):
      with_sched = lr_scheduler_training(sched_fn, args)
      assert with_sched > without
# Standard unittest entry point: run all TestLrScheduler cases when executed directly.
if __name__ == '__main__':
  unittest.main()