# lars_util.py
# https://github.com/mlcommons/training/blob/e237206991d10449d9675d95606459a3cb6c21ad/image_classification/tensorflow2/lars_util.py
# changes: commented out logging
# changes: convert_to_tensor_v2 -> convert_to_tensor
# changes: extend from tf.python.keras.optimizer_v2.learning_rate_schedule.LearningRateScheduler
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Enable Layer-wise Adaptive Rate Scaling optimizer in ResNet."""
  20. from __future__ import absolute_import
  21. from __future__ import division
  22. from __future__ import print_function
  23. from absl import flags
  24. import tensorflow as tf
  25. #from tf2_common.utils.mlp_log import mlp_log
  26. from tensorflow.python.eager import context
  27. from tensorflow.python.framework import ops
  28. from tensorflow.python.ops import math_ops
  29. from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
  30. FLAGS = flags.FLAGS
  31. def define_lars_flags():
  32. """Defines flags needed by LARS optimizer."""
  33. flags.DEFINE_float(
  34. 'end_learning_rate', default=None,
  35. help=('Polynomial decay end learning rate.'))
  36. flags.DEFINE_float(
  37. 'lars_epsilon', default=0.0,
  38. help=('Override autoselected LARS epsilon.'))
  39. flags.DEFINE_float(
  40. 'warmup_epochs', default=None,
  41. help=('Override autoselected polynomial decay warmup epochs.'))
  42. flags.DEFINE_float(
  43. 'momentum',
  44. default=0.9,
  45. help=('Momentum parameter used in the MomentumOptimizer.'))
  46. class PolynomialDecayWithWarmup(learning_rate_schedule.LearningRateSchedule):
  47. """A LearningRateSchedule that uses a polynomial decay with warmup."""
  48. def __init__(
  49. self,
  50. batch_size,
  51. steps_per_epoch,
  52. train_steps,
  53. initial_learning_rate=None,
  54. end_learning_rate=None,
  55. warmup_epochs=None,
  56. compute_lr_on_cpu=False,
  57. name=None):
  58. """Applies a polynomial decay to the learning rate with warmup."""
  59. super(PolynomialDecayWithWarmup, self).__init__()
  60. self.batch_size = batch_size
  61. self.steps_per_epoch = steps_per_epoch
  62. self.train_steps = train_steps
  63. self.name = name
  64. self.learning_rate_ops_cache = {}
  65. self.compute_lr_on_cpu = compute_lr_on_cpu
  66. if batch_size < 16384:
  67. self.initial_learning_rate = 10.0
  68. warmup_epochs_ = 5
  69. elif batch_size < 32768:
  70. self.initial_learning_rate = 25.0
  71. warmup_epochs_ = 5
  72. else:
  73. self.initial_learning_rate = 31.2
  74. warmup_epochs_ = 25
  75. # Override default poly learning rate and warmup epochs
  76. if initial_learning_rate:
  77. self.initial_learning_rate = initial_learning_rate
  78. if end_learning_rate:
  79. self.end_learning_rate = end_learning_rate
  80. else:
  81. self.end_learning_rate = 0.0001
  82. if warmup_epochs is not None:
  83. warmup_epochs_ = warmup_epochs
  84. self.warmup_epochs = warmup_epochs_
  85. """
  86. opt_name = FLAGS.optimizer.lower()
  87. mlp_log.mlperf_print('opt_name', opt_name)
  88. if opt_name == 'lars':
  89. mlp_log.mlperf_print('{}_epsilon'.format(opt_name), FLAGS.lars_epsilon)
  90. mlp_log.mlperf_print('{}_opt_weight_decay'.format(opt_name),
  91. FLAGS.weight_decay)
  92. mlp_log.mlperf_print('{}_opt_base_learning_rate'.format(opt_name),
  93. self.initial_learning_rate)
  94. mlp_log.mlperf_print('{}_opt_learning_rate_warmup_epochs'.format(opt_name),
  95. warmup_epochs_)
  96. mlp_log.mlperf_print('{}_opt_end_learning_rate'.format(opt_name),
  97. self.end_learning_rate)
  98. """
  99. warmup_steps = warmup_epochs_ * steps_per_epoch
  100. self.warmup_steps = tf.cast(warmup_steps, tf.float32)
  101. self.decay_steps = train_steps - warmup_steps + 1
  102. """
  103. mlp_log.mlperf_print('{}_opt_learning_rate_decay_steps'.format(opt_name),
  104. int(self.decay_steps))
  105. mlp_log.mlperf_print(
  106. '{}_opt_learning_rate_decay_poly_power'.format(opt_name), 2.0)
  107. mlp_log.mlperf_print('{}_opt_momentum'.format(opt_name), FLAGS.momentum)
  108. """
  109. self.poly_rate_scheduler = tf.keras.optimizers.schedules.PolynomialDecay(
  110. initial_learning_rate=self.initial_learning_rate,
  111. decay_steps=self.decay_steps,
  112. end_learning_rate=self.end_learning_rate,
  113. power=2.0)
  114. def __call__(self, step):
  115. if tf.executing_eagerly():
  116. return self._get_learning_rate(step)
  117. # In an eager function or graph, the current implementation of optimizer
  118. # repeatedly call and thus create ops for the learning rate schedule. To
  119. # avoid this, we cache the ops if not executing eagerly.
  120. graph = tf.compat.v1.get_default_graph()
  121. if graph not in self.learning_rate_ops_cache:
  122. if self.compute_lr_on_cpu:
  123. with tf.device('/device:CPU:0'):
  124. self.learning_rate_ops_cache[graph] = self._get_learning_rate(step)
  125. else:
  126. self.learning_rate_ops_cache[graph] = self._get_learning_rate(step)
  127. return self.learning_rate_ops_cache[graph]
  128. def _get_learning_rate(self, step):
  129. with ops.name_scope_v2(self.name or 'PolynomialDecayWithWarmup') as name:
  130. initial_learning_rate = ops.convert_to_tensor(
  131. self.initial_learning_rate, name='initial_learning_rate')
  132. warmup_steps = ops.convert_to_tensor(
  133. self.warmup_steps, name='warmup_steps')
  134. warmup_rate = (
  135. initial_learning_rate * step / warmup_steps)
  136. poly_steps = math_ops.subtract(step, warmup_steps)
  137. poly_rate = self.poly_rate_scheduler(poly_steps)
  138. decay_rate = tf.where(step <= warmup_steps,
  139. warmup_rate, poly_rate, name=name)
  140. return decay_rate
  141. def get_config(self):
  142. return {
  143. 'batch_size': self.batch_size,
  144. 'steps_per_epoch': self.steps_per_epoch,
  145. 'train_steps': self.train_steps,
  146. 'initial_learning_rate': self.initial_learning_rate,
  147. 'end_learning_rate': self.end_learning_rate,
  148. 'warmup_epochs': self.warmup_epochs,
  149. 'name': self.name,
  150. }