test_optim.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import numpy as np
  2. import torch
  3. import unittest
  4. from tinygrad import Tensor, Device, dtypes
  5. from tinygrad.nn.optim import Adam, SGD, AdamW
  6. from tinygrad.helpers import CI
  7. from test.helpers import is_dtype_supported
# Fixed seed so tinygrad and torch optimizers start from identical weights,
# making the step-by-step comparisons below deterministic.
np.random.seed(1337)
x_init = np.random.randn(1,4).astype(np.float32)  # trainable input vector
W_init = np.random.randn(4,4).astype(np.float32)  # trainable weight matrix
m_init = np.random.randn(1,4).astype(np.float32)  # fixed (non-trained) tensor used by TinyNet
  12. class TeenyNet:
  13. def __init__(self, tensor):
  14. self.x = tensor(x_init.copy(), requires_grad=True)
  15. self.W = tensor(W_init.copy(), requires_grad=True)
  16. def forward(self):
  17. return (self.x * self.W).sum()
  18. class TinyNet:
  19. def __init__(self, tensor):
  20. self.x = tensor(x_init.copy(), requires_grad=True)
  21. self.W = tensor(W_init.copy(), requires_grad=True)
  22. self.m = tensor(m_init.copy())
  23. def forward(self):
  24. out = self.x.matmul(self.W).relu()
  25. # print(out.detach().numpy())
  26. out = out.log_softmax(1)
  27. out = out.mul(self.m).add(self.m).sum()
  28. return out
  29. def step(tensor, optim, steps=1, teeny=False, **kwargs):
  30. net = TeenyNet(tensor) if teeny else TinyNet(tensor)
  31. optim = optim([net.x, net.W], **kwargs)
  32. for _ in range(steps):
  33. out = net.forward()
  34. optim.zero_grad()
  35. out.backward()
  36. optim.step()
  37. return net.x.detach().numpy(), net.W.detach().numpy()
  38. @unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
  39. class TestOptim(unittest.TestCase):
  40. def setUp(self):
  41. self.old_training = Tensor.training
  42. Tensor.training = True
  43. def tearDown(self):
  44. Tensor.training = self.old_training
  45. def _test_optim(self, tinygrad_optim, torch_optim, steps, opts, atol, rtol):
  46. for x,y in zip(step(Tensor, tinygrad_optim, steps, **opts),
  47. step(torch.tensor, torch_optim, steps, **opts)):
  48. np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
  49. def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)
  50. def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)
  51. def _test_adamw(self, steps, opts, atol, rtol): self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol)
  52. def test_multistep_sgd_high_lr_teeny(self): self._test_sgd(2, {'lr': 1.1, 'teeny': True}, 1e-6, 1e-5)
  53. def test_multistep_adam_high_lr_teeny(self): self._test_adam(2, {'lr': 1.1, 'teeny': True}, 2e-4, 5e-4)
  54. def test_sgd(self): self._test_sgd(1, {'lr': 0.001}, 1e-6, 0)
  55. def test_sgd_high_lr(self): self._test_sgd(1, {'lr': 10}, 1e-6, 1e-5)
  56. def test_sgd_wd(self): self._test_sgd(1, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
  57. def test_sgd_high_lr_wd(self): self._test_sgd(1, {'lr': 10, 'weight_decay': 0.1}, 1e-6, 1e-5)
  58. def test_multistep_sgd(self): self._test_sgd(10, {'lr': 0.001}, 1e-6, 0)
  59. def test_multistep_sgd_high_lr(self): self._test_sgd(10, {'lr': 10}, 1e-6, 3e-4)
  60. def test_multistep_sgd_wd(self): self._test_sgd(10, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
  61. def test_multistep_sgd_high_lr_wd(self): self._test_sgd(10, {'lr': 9, 'weight_decay': 0.1}, 1e-6, 3e-4)
  62. def test_multistep_sgd_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9}, 1e-6, 0)
  63. def test_multistep_sgd_high_lr_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9}, 1e-5, 3e-4)
  64. def test_multistep_sgd_momentum_wd(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0.1}, 1e-6, 0)
  65. def test_multistep_sgd_high_lr_momentum_wd(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'weight_decay': 0.1}, 1e-5, 3e-4)
  66. def test_multistep_sgd_nesterov_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True}, 1e-5, 0)
  67. def test_multistep_sgd_high_lr_nesterov_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'nesterov': True}, 1e-5, 3e-4)
  68. def test_multistep_sgd_nesterov_momentum_wd(self):
  69. self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 0)
  70. def test_multistep_sgd_high_lr_nesterov_momentum_wd(self):
  71. self._test_sgd(10, {'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 3e-4)
  72. def test_adam(self): self._test_adam(1, {'lr': 0.001}, 1e-5, 0)
  73. def test_adam_high_lr(self): self._test_adam(1, {'lr': 10}, 1e-4, 1e-4)
  74. def test_adamw(self): self._test_adamw(1, {'lr': 0.001}, 1e-5, 0)
  75. def test_adamw_high_lr(self): self._test_adamw(1, {'lr': 10}, 1e-4, 1e-4)
  76. def test_multistep_adam(self): self._test_adam(10, {'lr': 0.001}, 1e-5, 0)
  77. def test_multistep_adam_high_lr(self): self._test_adam(10, {'lr': 10}, 2e-3, 5e-4)
  78. def test_multistep_adamw(self): self._test_adamw(10, {'lr': 0.001}, 1e-5, 0)
  79. def test_multistep_adamw_high_lr(self): self._test_adamw(10, {'lr': 10}, 5e-4, 2e-3)
  80. def test_duped_weights(self):
  81. for Opt in [Adam, AdamW, SGD]:
  82. losses = []
  83. for i in range(2):
  84. w = Tensor(x_init.copy())
  85. opt = Opt([w], lr=0.1) if i == 0 else Opt([w, w], lr=0.1)
  86. loss = None
  87. for _ in range(3):
  88. loss = w.sum()
  89. opt.zero_grad()
  90. loss.backward()
  91. opt.step()
  92. losses.append(loss.numpy())
  93. np.testing.assert_allclose(losses[0], losses[1], atol=1e-4, rtol=0)
  94. @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  95. def test_mixed_precision(self):
  96. old_default_float, dtypes.default_float = dtypes.default_float, dtypes.half
  97. # weight update would overflow without upcasting
  98. self._test_sgd(10, {'lr': 1e10}, 1e-6, 3e-4)
  99. self._test_adam(1, {'lr': 1e10}, 1e-4, 1e-4)
  100. self._test_adamw(1, {'lr': 1e10}, 1e-4, 1e-4)
  101. dtypes.default_float = old_default_float
  102. def test_assert_tensor_train(self):
  103. t = Tensor.ones((1,1), requires_grad=True)
  104. optimizer = Adam([t])
  105. optimizer.zero_grad()
  106. old_state = Tensor.training
  107. t.sum().backward()
  108. Tensor.training = False
  109. self.assertRaises(AssertionError, optimizer.step)
  110. Tensor.training = True
  111. optimizer.step()
  112. Tensor.training = old_state
# Allow running this test file directly: `python test_optim.py`
if __name__ == '__main__':
  unittest.main()