lr_schedulers.py

from tinygrad import Tensor, dtypes
from tinygrad.nn.optim import Optimizer
from extra.lr_scheduler import LR_Scheduler

# https://github.com/mlcommons/training/blob/e237206991d10449d9675d95606459a3cb6c21ad/image_classification/tensorflow2/lars_util.py
class PolynomialDecayWithWarmup(LR_Scheduler):
  def __init__(self, optimizer: Optimizer, initial_lr, end_lr, train_steps, warmup, power=2):
    super().__init__(optimizer)
    self.epoch_counter = self.epoch_counter.cast(dtypes.float32)
    assert train_steps > 0 and warmup > 0
    self.warmup = min(warmup, train_steps)
    self.initial_lr, self.end_lr, self.epochs, self.power = initial_lr, end_lr, train_steps, power
    # set lr for first warmup step
    self.optimizer.lr.assign(self.get_lr()).realize()

  def get_lr(self):
    # LR is 0 on the first step, matching the reference.
    warmup_lr = (self.epoch_counter * (1.0 / self.warmup)) * self.initial_lr
    x = (1 - (self.epoch_counter - self.warmup) / (self.epochs - self.warmup + 1))
    return (self.epoch_counter <= self.warmup).where(warmup_lr, (self.initial_lr - self.end_lr) * x ** self.power + self.end_lr).cast(self.optimizer.lr.dtype)
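
# --- Usage sketch (illustrative, not part of the scheduler itself): how the scheduler is
# typically driven. The Linear model, SGD optimizer, and hyperparameter values below are
# assumptions for the example only; the MLPerf reference pairs this schedule with LARS.
# LR_Scheduler.step() advances epoch_counter and re-assigns optimizer.lr from get_lr(),
# so it is called once per training step.
if __name__ == "__main__":
  from tinygrad.nn import Linear
  from tinygrad.nn.state import get_parameters
  from tinygrad.nn.optim import SGD

  Tensor.training = True
  model = Linear(784, 10)                                # hypothetical model
  opt = SGD(get_parameters(model), lr=0.0)               # lr is overwritten by the scheduler
  sched = PolynomialDecayWithWarmup(opt, initial_lr=8.0, end_lr=1e-4, train_steps=100, warmup=10)
  for _ in range(100):
    x, y = Tensor.randn(32, 784), Tensor.randint(32, high=10)  # placeholder batch
    opt.zero_grad()
    loss = model(x).sparse_categorical_crossentropy(y)
    loss.backward()
    opt.step()
    sched.step()  # linear warmup for the first `warmup` steps, then polynomial decay toward end_lr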