optim.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # sorted in order of increasing complexity
  2. from typing import List
  3. from tinygrad.helpers import dedup, flatten, getenv
  4. from tinygrad.tensor import Tensor
  5. from tinygrad.dtype import dtypes, least_upper_dtype
  6. class Optimizer:
  7. """
  8. Base class for all optimizers.
  9. """
  10. def __init__(self, params: List[Tensor], lr: float):
  11. # if it's None, but being put into an optimizer, set it to True
  12. for x in params:
  13. if x.requires_grad is None: x.requires_grad = True
  14. self.params: List[Tensor] = dedup([x for x in params if x.requires_grad])
  15. assert len(self.params) != 0, "optimizer must have at least one param"
  16. self.device = self.params[0].device
  17. self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
  18. # store lr in at least float32 precision
  19. self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
  20. dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
  21. def zero_grad(self):
  22. """
  23. Zeroes the gradients of all the parameters.
  24. """
  25. for param in self.params: param.grad = None
  26. def step(self):
  27. """
  28. Performs a single optimization step.
  29. """
  30. Tensor.realize(*self.schedule_step())
  31. def schedule_step(self) -> List[Tensor]:
  32. """
  33. Returns the tensors that need to be realized to perform a single optimization step.
  34. """
  35. assert Tensor.training, (
  36. f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
  37. - help: Consider setting Tensor.training=True before calling Optimizer.step().""")
  38. return self._step()+self.params+self.buffers
  39. def _step(self) -> List[Tensor]: raise NotImplementedError
  40. class OptimizerGroup(Optimizer):
  41. """
  42. Combines multiple optimizers into one.
  43. """
  44. def __init__(self, *optimizers: Optimizer): # pylint: disable=super-init-not-called
  45. self.optimizers = optimizers
  46. self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
  47. def __getitem__(self, i): return self.optimizers[i]
  48. def zero_grad(self): [o.zero_grad() for o in self.optimizers]
  49. def _step(self) -> List[Tensor]: return [x for o in self.optimizers for x in o._step()]
  50. # LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 its just standard SGD.
  51. def SGD(params: List[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
  52. """
  53. Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.
  54. `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
  55. - Described: https://paperswithcode.com/method/sgd
  56. """
  57. return LARS(params, lr, momentum, weight_decay, nesterov, classic, tcoef=0.0)
  58. class LARS(Optimizer):
  59. """
  60. Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.
  61. - Described: https://paperswithcode.com/method/lars
  62. - Paper: https://arxiv.org/abs/1708.03888v3
  63. """
  64. def __init__(self, params:List[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
  65. super().__init__(params, lr)
  66. self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
  67. self.b = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
  68. def _step(self) -> List[Tensor]:
  69. for i, t in enumerate(self.params):
  70. assert t.grad is not None
  71. # contiguous is needed since the grads can allegedly form a "diamond"
  72. # TODO: fix this in lazy.py
  73. g = t.grad.contiguous()
  74. if self.tcoef != 0:
  75. r1 = t.detach().square().sum().sqrt()
  76. r2 = g.square().sum().sqrt()
  77. r = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
  78. else: r = 1.0
  79. g = g + self.wd * t.detach()
  80. # classic momentum does post learning rate update
  81. if self.classic: g = g * r * self.lr
  82. if self.momentum:
  83. self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
  84. g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
  85. # popular momentum does pre learning rate update
  86. if not self.classic: g = g * r * self.lr
  87. t.assign((t.detach() - g).cast(t.dtype))
  88. return self.b
  89. # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
  90. def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
  91. """
  92. AdamW optimizer with optional weight decay.
  93. - Described: https://paperswithcode.com/method/adamw
  94. - Paper: https://arxiv.org/abs/1711.05101v3
  95. """
  96. return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True)
  97. def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
  98. """
  99. Adam optimizer.
  100. - Described: https://paperswithcode.com/method/adam
  101. - Paper: https://arxiv.org/abs/1412.6980
  102. """
  103. return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)
  104. class LAMB(Optimizer):
  105. """
  106. LAMB optimizer with optional weight decay.
  107. - Described: https://paperswithcode.com/method/lamb
  108. - Paper: https://arxiv.org/abs/1904.00962
  109. """
  110. def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
  111. super().__init__(params, lr)
  112. self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
  113. self.b1_t, self.b2_t = (Tensor([1], dtype=dtypes.float32, device=self.device, requires_grad=False).realize() for _ in [b1, b2])
  114. self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
  115. self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
  116. def _step(self) -> List[Tensor]:
  117. self.b1_t *= self.b1
  118. self.b2_t *= self.b2
  119. for i, t in enumerate(self.params):
  120. assert t.grad is not None
  121. self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
  122. self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
  123. m_hat = self.m[i] / (1.0 - self.b1_t)
  124. v_hat = self.v[i] / (1.0 - self.b2_t)
  125. up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
  126. if not self.adam:
  127. r1 = t.detach().square().sum().sqrt()
  128. r2 = up.square().sum().sqrt()
  129. r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
  130. else:
  131. r = 1.0
  132. t.assign((t.detach() - self.lr * r * up).cast(t.dtype))
  133. return [self.b1_t, self.b2_t] + self.m + self.v