# lr_scheduler.py

import math
from typing import List
from tinygrad.nn.optim import Optimizer
from tinygrad.tensor import Tensor

class LR_Scheduler:
  def __init__(self, optimizer: Optimizer):
    self.optimizer = optimizer
    # the step/epoch count lives on-device as a Tensor so get_lr() can use it lazily
    self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device)

  def get_lr(self): pass

  def step(self) -> None:
    self.epoch_counter.assign(self.epoch_counter + 1).realize()
    self.optimizer.lr.assign(self.get_lr()).realize()
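
# A subclass only needs to implement get_lr(); step() then advances the
# counter and writes the result into optimizer.lr in place.  Minimal sketch
# of a custom scheduler built on this base (ExponentialLR is illustrative,
# not part of this file):
#
#   class ExponentialLR(LR_Scheduler):
#     def __init__(self, optimizer: Optimizer, gamma: float):
#       super().__init__(optimizer)
#       self.gamma = gamma
#     def get_lr(self) -> Tensor: return self.optimizer.lr * self.gamma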

class LRSchedulerGroup:
  def __init__(self, *schedulers: LR_Scheduler): self.schedulers = schedulers

  def step(self) -> None:
    for s in self.schedulers: s.step()
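
# Usage sketch: bundle several schedulers so a single call steps them all
# (`sched_a` and `sched_b` are illustrative names):
#
#   group = LRSchedulerGroup(sched_a, sched_b)
#   group.step()  # equivalent to sched_a.step(); sched_b.step()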

class MultiStepLR(LR_Scheduler):
  def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1):
    super().__init__(optimizer)
    self.milestones = milestones
    self.gamma = gamma

  def get_lr(self) -> Tensor:
    # multiply lr by gamma once each time a milestone epoch is reached
    if self.epoch_counter.numpy()[0] not in self.milestones:
      return self.optimizer.lr
    return self.optimizer.lr * self.gamma
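
# Usage sketch, assuming an SGD optimizer from tinygrad.nn.optim and a
# user-supplied `params` list (both illustrative):
#
#   opt = SGD(params, lr=0.1)
#   sched = MultiStepLR(opt, milestones=[30, 60], gamma=0.1)
#   for epoch in range(90):
#     ...  # train one epoch
#     sched.step()  # lr: 0.1 -> 0.01 at epoch 30 -> 0.001 at epoch 60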

class ReduceLROnPlateau(LR_Scheduler):
  def __init__(self, optimizer: Optimizer, mode="min", factor=0.1, patience=10, threshold=1e-4, threshold_mode="rel"):
    assert mode in ["min", "max"] and threshold_mode in ["rel", "abs"]
    super().__init__(optimizer)
    self.mode, self.factor, self.patience, self.threshold, self.threshold_mode = mode, factor, patience, threshold, threshold_mode
    self.best = float('inf') if mode == "min" else float('-inf')
    self.bad_epoch = 0
    # in "min" mode the threshold is negated so an improvement must beat best by the margin
    if mode == "min": self.threshold *= -1

  def is_better(self, current: float) -> bool:
    dynamic_threshold = self.best*(1+self.threshold) if self.threshold_mode == "rel" else self.best+self.threshold
    if self.mode == "min":
      return current < dynamic_threshold
    return current > dynamic_threshold

  def step(self, current: float) -> None:
    self.epoch_counter.assign(self.epoch_counter + 1).realize()
    if self.is_better(current):
      self.bad_epoch = 0
      self.best = current
    else:
      self.bad_epoch += 1

    if self.bad_epoch > self.patience:
      # the metric has stalled for more than `patience` epochs: shrink the lr
      self.optimizer.lr *= self.factor
      self.bad_epoch = 0
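
# Usage sketch: unlike the other schedulers, step() here takes the monitored
# metric directly (`val_loss` is an illustrative validation loss):
#
#   sched = ReduceLROnPlateau(opt, mode="min", factor=0.1, patience=10)
#   for epoch in range(epochs):
#     val_loss = ...  # evaluate
#     sched.step(val_loss)  # lr *= 0.1 after more than 10 epochs without improvement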

class CosineAnnealingLR(LR_Scheduler):
  def __init__(self, optimizer: Optimizer, T_max: int, eta_min=0):
    super().__init__(optimizer)
    self.T_max = T_max
    self.eta_min = eta_min
    self.eta_max = optimizer.lr.numpy()[0]  # the optimizer's starting lr is the ceiling

  def get_lr(self) -> Tensor:
    # follow half a cosine period from eta_max down to eta_min over T_max steps
    lr = self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))
    return Tensor([lr], device=self.optimizer.device, dtype=self.optimizer.lr.dtype)
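
# Worked values for the schedule above with eta_max=0.1, eta_min=0, T_max=100:
#   t=0   -> 0.1   (cos(0) = 1, starts at eta_max)
#   t=50  -> 0.05  (cos(pi/2) = 0, midpoint)
#   t=100 -> 0.0   (cos(pi) = -1, ends at eta_min)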

class OneCycleLR(LR_Scheduler):
  def __init__(self, optimizer: Optimizer, max_lr: float, div_factor: float, final_div_factor: float, total_steps: int, pct_start: float,
               anneal_strategy: str = 'linear', cycle_momentum: bool = False):
    super().__init__(optimizer)
    self.initial_lr = max_lr / div_factor
    self.max_lr = max_lr
    self.min_lr = self.initial_lr / final_div_factor
    self.total_steps = total_steps
    self.pct_start = pct_start
    assert anneal_strategy == 'linear', 'only linear annealing supported'
    assert not cycle_momentum, 'cycle momentum not supported'
    self.optimizer.lr.assign(self.get_lr()).realize()  # update the initial LR

  @staticmethod
  def _annealing_linear(start: float, end: float, pct: Tensor) -> Tensor: return (pct*(end-start)+start)

  def get_lr(self) -> Tensor:
    # warm up linearly to max_lr over the first pct_start fraction of steps,
    # then anneal linearly down to min_lr for the rest; stays lazy as a Tensor op
    return (self.epoch_counter < self.total_steps*self.pct_start).where(
      self._annealing_linear(self.initial_lr, self.max_lr, self.epoch_counter/(self.total_steps*self.pct_start)),
      self._annealing_linear(self.max_lr, self.min_lr, (self.epoch_counter-(self.total_steps*self.pct_start))/(self.total_steps*(1-self.pct_start)))
    ).cast(self.optimizer.lr.dtype)
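
# Usage sketch: OneCycleLR is typically stepped once per batch rather than per
# epoch.  With the illustrative values below, lr starts at
# max_lr/div_factor = 0.004, warms up linearly to 0.1 over the first 30% of
# steps, then anneals linearly down to initial_lr/final_div_factor = 4e-6:
#
#   sched = OneCycleLR(opt, max_lr=0.1, div_factor=25, final_div_factor=1000,
#                      total_steps=1000, pct_start=0.3)
#   for i in range(1000):
#     ...  # forward/backward, opt.step()
#     sched.step()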