# lars_optimizer.py
# https://github.com/mlcommons/training/blob/e3769c8dcf88cd21e1001dd2f894b40a1513ec5d/image_classification/tensorflow2/lars_optimizer.py
# changes: don't call lr_t if it's not a schedule
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer-wise Adaptive Rate Scaling optimizer for large-batch training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
# from tf2_common.training import optimizer_v2modified
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training_ops
from tensorflow.python.ops import state_ops


# class LARSOptimizer(optimizer_v2modified.OptimizerV2Modified):
  32. class LARSOptimizer(optimizer_v2.OptimizerV2):
  33. """Layer-wise Adaptive Rate Scaling for large batch training.
  34. Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
  35. I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
  36. Implements the LARS learning rate scheme presented in the paper above. This
  37. optimizer is useful when scaling the batch size to up to 32K without
  38. significant performance degradation. It is recommended to use the optimizer
  39. in conjunction with:
  40. - Gradual learning rate warm-up
  41. - Linear learning rate scaling
  42. - Poly rule learning rate decay
  43. Note, LARS scaling is currently only enabled for dense tensors. Sparse tensors
  44. use the default momentum optimizer.
  45. """
  46. def __init__(
  47. self,
  48. learning_rate,
  49. momentum=0.9,
  50. weight_decay=0.0001,
  51. # The LARS coefficient is a hyperparameter
  52. eeta=0.001,
  53. epsilon=0.0,
  54. name="LARSOptimizer",
  55. # Enable skipping variables from LARS scaling.
  56. # TODO(sameerkm): Enable a direct mechanism to pass a
  57. # subset of variables to the optimizer.
  58. skip_list=None,
  59. use_nesterov=False,
  60. **kwargs):
  61. """Construct a new LARS Optimizer.
  62. Args:
  63. learning_rate: A `Tensor`, floating point value, or a schedule that is a
  64. `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
  65. that takes no arguments and returns the actual value to use. The
  66. learning rate.
  67. momentum: A floating point value. Momentum hyperparameter.
  68. weight_decay: A floating point value. Weight decay hyperparameter.
  69. eeta: LARS coefficient as used in the paper. Dfault set to LARS
  70. coefficient from the paper. (eeta / weight_decay) determines the highest
  71. scaling factor in LARS.
  72. epsilon: Optional epsilon parameter to be set in models that have very
  73. small gradients. Default set to 0.0.
  74. name: Optional name prefix for variables and ops created by LARSOptimizer.
  75. skip_list: List of strings to enable skipping variables from LARS scaling.
  76. If any of the strings in skip_list is a subset of var.name, variable
  77. 'var' is skipped from LARS scaling. For a typical classification model
  78. with batch normalization, the skip_list is ['batch_normalization',
  79. 'bias']
  80. use_nesterov: when set to True, nesterov momentum will be enabled
  81. **kwargs: keyword arguments.
  82. Raises:
  83. ValueError: If a hyperparameter is set to a non-sensical value.
  84. """
  85. if momentum < 0.0:
  86. raise ValueError("momentum should be positive: %s" % momentum)
  87. if weight_decay < 0.0:
  88. raise ValueError("weight_decay should be positive: %s" % weight_decay)
  89. super(LARSOptimizer, self).__init__(name=name, **kwargs)
  90. self._set_hyper("learning_rate", learning_rate)
  91. # When directly using class members, instead of
  92. # _set_hyper and _get_hyper (such as learning_rate above),
  93. # the values are fixed after __init(), and not being
  94. # updated during the training process.
  95. # This provides better performance but less flexibility.
  96. self.momentum = momentum
  97. self.weight_decay = weight_decay
  98. self.eeta = eeta
  99. self.epsilon = epsilon or backend_config.epsilon()
  100. self._skip_list = skip_list
  101. self.use_nesterov = use_nesterov
  102. def _prepare_local(self, var_device, var_dtype, apply_state):
  103. lr_t = self._get_hyper("learning_rate", var_dtype)
  104. local_step = math_ops.cast(self.iterations, var_dtype)
  105. if callable(lr_t): lr_t = math_ops.cast(lr_t(local_step), var_dtype)
  106. learning_rate_t = array_ops.identity(lr_t)
  107. apply_state[(var_device, var_dtype)].update(
  108. dict(
  109. learning_rate=learning_rate_t,
  110. ))
  111. def _create_slots(self, var_list):
  112. for v in var_list:
  113. self.add_slot(v, "momentum")
  114. def compute_lr(self, grad, var, coefficients):
  115. scaled_lr = coefficients["learning_rate"]
  116. if self._skip_list is None or not any(v in var.name
  117. for v in self._skip_list):
  118. w_norm = linalg_ops.norm(var, ord=2)
  119. g_norm = linalg_ops.norm(grad, ord=2)
  120. trust_ratio = array_ops.where(
  121. math_ops.greater(w_norm, 0),
  122. array_ops.where(
  123. math_ops.greater(g_norm, 0),
  124. (self.eeta * w_norm /
  125. (g_norm + self.weight_decay * w_norm + self.epsilon)), 1.0), 1.0)
  126. scaled_lr = coefficients["learning_rate"] * trust_ratio
  127. # Add the weight regularization gradient
  128. grad = grad + self.weight_decay * var
  129. return scaled_lr, grad
  130. def _apply_dense(self, grad, var, apply_state=None):
  131. var_device, var_dtype = var.device, var.dtype.base_dtype
  132. coefficients = ((apply_state or {}).get((var_device, var_dtype))
  133. or self._fallback_apply_state(var_device, var_dtype))
  134. scaled_lr, grad = self.compute_lr(grad, var, coefficients)
  135. mom = self.get_slot(var, "momentum")
  136. return training_ops.apply_momentum(
  137. var,
  138. mom,
  139. math_ops.cast(1.0, var.dtype.base_dtype),
  140. grad * scaled_lr,
  141. self.momentum,
  142. use_locking=False,
  143. use_nesterov=self.use_nesterov)
  144. def _resource_apply_dense(self, grad, var, apply_state=None):
  145. var_device, var_dtype = var.device, var.dtype.base_dtype
  146. coefficients = ((apply_state or {}).get((var_device, var_dtype))
  147. or self._fallback_apply_state(var_device, var_dtype))
  148. scaled_lr, grad = self.compute_lr(grad, var, coefficients)
  149. mom = self.get_slot(var, "momentum")
  150. # Use ApplyKerasMomentum instead of ApplyMomentum
  151. # training_ops.resource_apply_keras_momentum(
  152. # var.handle,
  153. # mom.handle,
  154. # scaled_lr,
  155. # grad,
  156. # coefficients["momentum"],
  157. # use_locking=False,
  158. # use_nesterov=self.use_nesterov)
  159. mom_t = mom * self.momentum - grad * scaled_lr
  160. mom_t = state_ops.assign(mom, mom_t, use_locking=False)
  161. if self.use_nesterov:
  162. var_t = var + mom_t * self.momentum - grad * scaled_lr
  163. else:
  164. var_t = var + mom_t
  165. return state_ops.assign(var, var_t, use_locking=False).op
  166. # Fallback to momentum optimizer for sparse tensors
  167. def _apply_sparse(self, grad, var, apply_state=None):
  168. var_device, var_dtype = var.device, var.dtype.base_dtype
  169. coefficients = ((apply_state or {}).get((var_device, var_dtype))
  170. or self._fallback_apply_state(var_device, var_dtype))
  171. mom = self.get_slot(var, "momentum")
  172. return training_ops.sparse_apply_momentum(
  173. var,
  174. mom,
  175. coefficients["learning_rate"],
  176. grad.values,
  177. grad.indices,
  178. self.momentum,
  179. use_locking=False,
  180. use_nesterov=self.use_nesterov)
  181. def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
  182. var_device, var_dtype = var.device, var.dtype.base_dtype
  183. coefficients = ((apply_state or {}).get((var_device, var_dtype))
  184. or self._fallback_apply_state(var_device, var_dtype))
  185. mom = self.get_slot(var, "momentum")
  186. return training_ops.resource_sparse_apply_keras_momentum(
  187. var.handle,
  188. mom.handle,
  189. coefficients["learning_rate"],
  190. grad,
  191. indices,
  192. self.momentum,
  193. use_locking=False,
  194. use_nesterov=self.use_nesterov)
  195. def get_config(self):
  196. config = super(LARSOptimizer, self).get_config()
  197. config.update({
  198. "learning_rate": self._serialize_hyperparameter("learning_rate"),
  199. "momentum": self.momentum,
  200. "weight_decay": self.weight_decay,
  201. "eeta": self.eeta,
  202. "epsilon": self.epsilon,
  203. "use_nesterov": self.use_nesterov,
  204. })
  205. return config