external_test_optim.py

#!/usr/bin/env python
import unittest
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.python.ops import math_ops

from extra.lr_scheduler import LRSchedulerGroup
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup

from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf

np.random.seed(1337)
x_init = np.random.randn(1,4).astype(np.float32)
W_init = np.random.randn(4,4).astype(np.float32)
m_init = np.random.randn(1,4).astype(np.float32)
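
# the same tiny network is built twice: TinyNet uses tinygrad Tensors and TinyNetTF uses
# tf.Variables, both initialized from the shared x_init/W_init/m_init arrays above so the
# two frameworks start from identical weights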
class TinyNet:
  def __init__(self):
    self.x = Tensor(x_init.copy(), requires_grad=True)
    self.W = Tensor(W_init.copy(), requires_grad=True)
    self.m = Tensor(m_init.copy())

  def forward(self):
    out = self.x.matmul(self.W).relu()
    out = out.log_softmax(1)
    out = out.mul(self.m).add(self.m).sum()
    return out

class TinyNetTF:
  def __init__(self):
    self.x = tf.Variable(x_init.copy(), trainable=True, name="x")
    self.W = tf.Variable(W_init.copy(), trainable=True, name="W")
    self.m = tf.constant(m_init.copy())

  def forward(self):
    out = tf.matmul(self.x, self.W)
    out = tf.nn.relu(out)
    out = tf.nn.log_softmax(out, axis=1)
    out = tf.multiply(out, self.m) + self.m
    out = tf.reduce_sum(out)
    return out
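
# step() runs `steps` iterations on the tinygrad net and step_tf() mirrors it on the TF net;
# each returns the learning rate observed at every step plus the final x and W values so the
# test can compare the two frameworks element-wise (with do_optim=False only the LR schedule
# advances and the weights are left untouched)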
def step(optim, steps=1, kwargs={}, scheduler=None, schedopts=None, do_optim=True):
  net = TinyNet()
  optim = optim([net.x, net.W], **kwargs)
  if scheduler is not None: scheduler = scheduler(optim, **schedopts)
  lrs = []
  for _ in range(steps):
    if do_optim:
      out = net.forward()
      optim.zero_grad()
      out.backward()
    lrs.append(optim.lr.item() if not isinstance(optim, OptimizerGroup) else optim.optimizers[0].lr.item())
    if do_optim: optim.step()
    if scheduler is not None: scheduler.step()
  return lrs, net.x.detach().numpy(), net.W.detach().numpy()

def step_tf(optim, steps=1, kwargs={}, scheduler=None, schedopts=None, do_optim=True):
  net = TinyNetTF()
  if scheduler is not None: kwargs['lr'] = scheduler(**schedopts)
  optim = optim(**kwargs)
  lrs = []
  for _ in range(steps):
    if do_optim:
      with tf.GradientTape() as tape:
        out = net.forward()
    lr_t = optim.learning_rate
    # refer to test/external/mlperf_resnet/lars_optimizer.py:_prepare_local
    if callable(lr_t): lr_t = lr_t(math_ops.cast(optim.iterations, tf.float32))
    lrs.append(lr_t)
    if do_optim:
      grads = tape.gradient(out, [net.x, net.W])
      optim.apply_gradients(zip(grads, [net.x, net.W]))
      # optim calls scheduler in tf
    else:
      optim._iterations.assign_add(1)
  return lrs, net.x.numpy(), net.W.numpy()

# skip_list skips W: it is trained with plain momentum SGD instead of LARS
def create_tiny_lars(params, lr, skip_list=False):
  if skip_list: return OptimizerGroup(LARS([params[0]], lr), SGD([params[1]], lr, classic=True, weight_decay=0., momentum=.9))
  return LARS(params, lr)
def create_tf_lars(lr, skip_list=False): return LARSOptimizer(lr, skip_list=["W"] if skip_list else None)

def create_tiny_polylr(optim, initial_lr, end_lr, train_steps, warmup, power=2, skip_list=False):
  assert power == 2
  if skip_list: return LRSchedulerGroup(
    PolynomialDecayWithWarmup(optim[0], initial_lr, end_lr, train_steps, warmup, power),
    PolynomialDecayWithWarmup(optim[1], initial_lr, end_lr, train_steps, warmup, power))
  return PolynomialDecayWithWarmup(optim, initial_lr, end_lr, train_steps, warmup, power)
def create_tf_polylr(initial_lr, end_lr, train_steps, warmup, power=2, skip_list=False):
  assert power == 2
  return PolynomialDecayWithWarmup_tf(1, 1, train_steps,
                                      initial_learning_rate=initial_lr, end_learning_rate=end_lr, warmup_epochs=warmup)
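
# each test drives both implementations with identical hyperparameters and asserts that the
# per-step learning rates and the final x/W weights agree within atol/rtol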
class ExternalTestOptim(unittest.TestCase):
  def setUp(self):
    self.old_training = Tensor.training
    Tensor.training = True
  def tearDown(self):
    Tensor.training = self.old_training

  def _test_optim(self, tinygrad_optim, tensorflow_optim, steps, opts, atol, rtol, tiny_sched=None, tf_sched=None, schedopts=None, do_optim=True):
    for x,y in zip(step(tinygrad_optim, steps=steps, kwargs=opts, scheduler=tiny_sched, schedopts=schedopts, do_optim=do_optim),
                   step_tf(tensorflow_optim, steps=steps, kwargs=opts, scheduler=tf_sched, schedopts=schedopts, do_optim=do_optim)):
      np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

  def _test_lamb(self, steps, opts, atol, rtol): self._test_optim(LAMB, tfa.optimizers.LAMB, steps, opts, atol, rtol)
  def _test_lars(self, steps, opts, atol, rtol): self._test_optim(create_tiny_lars, create_tf_lars, steps, opts, atol, rtol)
  def _test_lars_polylr(self, steps, opts, schedopts, atol, rtol, do_optim=True):
    self._test_optim(create_tiny_lars, create_tf_lars, steps, opts, atol, rtol,
                     tiny_sched=create_tiny_polylr, tf_sched=create_tf_polylr, schedopts=schedopts, do_optim=do_optim)

  def test_lamb(self): self._test_lamb(1, {'lr': 0.001}, 1e-5, 0)
  def test_lamb_high_lr(self): self._test_lamb(1, {'lr': 10}, 1e-5, 1e-5)
  def test_multistep_lamb(self): self._test_lamb(10, {'lr': 0.001}, 1e-5, 0)
  def test_multistep_lamb_high_lr(self): self._test_lamb(10, {'lr': 10}, 1e-5, 3e-4)

  def test_lars(self): self._test_lars(1, {'lr': 0.01}, 1e-5, 0)
  def test_lars_high_lr(self): self._test_lars(1, {'lr': 10}, 1e-5, 1e-5)
  def test_multistep_lars(self): self._test_lars(10, {'lr': 0.001}, 1e-5, 0)
  def test_multistep_lars_high_lr(self): self._test_lars(10, {'lr': 10}, 1e-5, 3e-4)

  def test_lars_skip(self): self._test_lars(10, {'lr': 10, 'skip_list': True}, 1e-5, 3e-4)
  def test_lars_skip_high_lr(self): self._test_lars(1, {'lr': 10, 'skip_list': True}, 1e-5, 1e-5)
  def test_lars_skip_multistep(self): self._test_lars(10, {'lr': 0.001, 'skip_list': True}, 1e-5, 0)
  def test_lars_skip_multistep_high_lr(self): self._test_lars(10, {'lr': 10, 'skip_list': True}, 1e-5, 3e-4)

  def test_lars_polylr(self):
    self._test_lars_polylr(10, {'lr': 1.0}, {
      'initial_lr': 1.0,
      'end_lr': 1e-4,
      'train_steps': 10,
      'warmup': 3
    }, 1e-5, 1e-5)
  def test_lars_polylr_large(self):
    self._test_lars_polylr(100, {'lr': 10.0}, {
      'initial_lr': 10.0,
      'end_lr': 1e-5,
      'train_steps': 100,
      'warmup': 43
    }, 1e-5, 1e-5, do_optim=False)
  def test_lars_polylr_skip(self):
    self._test_lars_polylr(10, {'lr': 1.0, 'skip_list': True}, {
      'initial_lr': 1.0,
      'end_lr': 1e-4,
      'train_steps': 10,
      'warmup': 3,
      'skip_list': True
    }, 1e-5, 1e-5)

  @unittest.skip("slow, but you can run this locally to check")
  def test_lars_polylr_resnet(self):
    train_files = 1_281_167
    BS = 624
    steps_per_epoch = train_files // BS
    epochs = 45
    warmup_epochs = 5
    self._test_lars_polylr(steps_per_epoch * epochs, {'lr': 10.4}, {
      'initial_lr': 10.4,
      'end_lr': 1e-4,
      # step counts for BS=624 EPOCHS=45 resnet
      'train_steps': steps_per_epoch * epochs,
      'warmup': steps_per_epoch * warmup_epochs,
    }, 1e-5, 1e-5, do_optim=False)
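
# note: running this file needs tensorflow, tensorflow_addons and the mlperf_resnet reference
# code importable as test.external.mlperf_resnet (paths assumed from the imports above)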
if __name__ == '__main__':
  unittest.main()