# test_end2end.py — end-to-end training-step parity tests between tinygrad and PyTorch
import unittest

import numpy as np
import torch
from torch import nn

from extra.datasets import fetch_mnist
from tinygrad.helpers import CI
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.tensor import Tensor
  10. def compare_tiny_torch(model, model_torch, X, Y):
  11. with Tensor.train():
  12. model_torch.train()
  13. model_state_dict = get_state_dict(model)
  14. for k,v in model_torch.named_parameters():
  15. if not CI: print(f"initting {k} from torch")
  16. model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()
  17. optimizer = optim.SGD(get_parameters(model), lr=0.001)
  18. optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.001)
  19. Xt = torch.Tensor(X.numpy())
  20. np.testing.assert_allclose(X.numpy(), Xt.detach().numpy())
  21. out = model(X)
  22. loss = (out * Y).mean()
  23. if not CI: print(loss.realize().numpy())
  24. out_torch = model_torch(torch.Tensor(X.numpy()))
  25. loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean()
  26. if not CI: print(loss_torch.detach().numpy())
  27. # assert losses match
  28. np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4)
  29. # zero and backward
  30. optimizer.zero_grad()
  31. loss.backward()
  32. optimizer_torch.zero_grad()
  33. loss_torch.backward()
  34. for k,v in list(model_torch.named_parameters())[::-1]:
  35. g = model_state_dict[k].grad.numpy()
  36. gt = v.grad.detach().numpy()
  37. if not CI: print("testing grads", k, model_state_dict[k].grad.dtype)
  38. np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f'grad mismatch {k}')
  39. # take the steps
  40. optimizer.step()
  41. optimizer_torch.step()
  42. # assert weights match
  43. for k,v in model_torch.named_parameters():
  44. if not CI: print("testing weight", k, model_state_dict[k].dtype)
  45. np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}')
  46. def get_mnist_data():
  47. _X_train, _Y_train, X_test, Y_test = fetch_mnist()
  48. BS = 32
  49. num_classes = 10
  50. X = Tensor(X_test[0:BS].astype(np.float32))
  51. Y = np.zeros((BS, num_classes), np.float32)
  52. Y[range(BS),Y_test[0:BS]] = -1.0*num_classes
  53. return X, Tensor(Y)
  54. class TestEnd2End(unittest.TestCase):
  55. @classmethod
  56. def setUpClass(cls):
  57. cls.X, cls.Y = get_mnist_data()
  58. def setUp(self):
  59. torch.manual_seed(123)
  60. def test_linear_mnist(self):
  61. class LinTiny:
  62. def __init__(self, bias=False):
  63. self.l1 = Linear(784, 128, bias=bias)
  64. self.l2 = Linear(128, 10, bias=bias)
  65. def __call__(self, x):
  66. return self.l2(self.l1(x).relu()).log_softmax(-1)
  67. class LinTorch(nn.Module):
  68. def __init__(self, bias=False):
  69. super().__init__()
  70. self.l1 = nn.Linear(784, 128, bias=bias)
  71. self.l2 = nn.Linear(128, 10, bias=bias)
  72. def forward(self, x):
  73. return self.l2(self.l1(x).relu()).log_softmax(-1)
  74. compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y)
  75. def test_bn_mnist(self):
  76. class LinTiny:
  77. def __init__(self):
  78. self.l1 = Linear(784, 128)
  79. self.l2 = Linear(128, 10)
  80. self.bn1 = BatchNorm2d(128)
  81. def __call__(self, x):
  82. return self.l2(self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)).reshape(x.shape[0], -1).relu()).log_softmax(-1)
  83. class LinTorch(nn.Module):
  84. def __init__(self):
  85. super().__init__()
  86. self.l1 = nn.Linear(784, 128)
  87. self.l2 = nn.Linear(128, 10)
  88. self.bn1 = nn.BatchNorm2d(128)
  89. def forward(self, x):
  90. return self.l2(self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)).reshape(x.shape[0], -1).relu()).log_softmax(-1)
  91. compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y)
  92. def test_bn_alone(self):
  93. np.random.seed(1337)
  94. X = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32))
  95. Y = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32))
  96. compare_tiny_torch(BatchNorm2d(10), nn.BatchNorm2d(10), X, Y)
  97. def test_bn_linear(self):
  98. BS, K = 2, 1
  99. eps = 0
  100. X = Tensor([1,0]).reshape(BS, K, 1, 1)
  101. Y = Tensor([-1,0]).reshape(BS, K, 1, 1)
  102. class LinTiny:
  103. def __init__(self):
  104. self.l1 = Conv2d(K, K, 1, bias=False)
  105. self.bn1 = BatchNorm2d(K, affine=False, track_running_stats=False, eps=eps)
  106. def __call__(self, x): return self.bn1(self.l1(x))
  107. class LinTorch(nn.Module):
  108. def __init__(self):
  109. super().__init__()
  110. self.l1 = nn.Conv2d(K, K, 1, bias=False)
  111. self.bn1 = nn.BatchNorm2d(K, affine=False, track_running_stats=False, eps=eps)
  112. def forward(self, x): return self.bn1(self.l1(x))
  113. model_torch = LinTorch()
  114. with torch.no_grad():
  115. model_torch.l1.weight[:] = 1.
  116. compare_tiny_torch(LinTiny(), model_torch, X, Y)
  117. def test_conv_mnist(self):
  118. class LinTiny:
  119. def __init__(self, has_batchnorm=False):
  120. self.c1 = Conv2d(1, 8, 3, stride=2)
  121. self.c2 = Conv2d(8, 16, 3, stride=2)
  122. self.l1 = Linear(16*6*6, 10)
  123. if has_batchnorm:
  124. self.bn1, self.bn2 = BatchNorm2d(8), BatchNorm2d(16)
  125. else:
  126. self.bn1, self.bn2 = lambda x: x, lambda x: x
  127. def __call__(self, x):
  128. return self.l1(self.bn2(self.c2(self.bn1(self.c1(x)).relu())).relu().reshape(x.shape[0], -1)).log_softmax(-1)
  129. class LinTorch(nn.Module):
  130. def __init__(self, has_batchnorm=False):
  131. super().__init__()
  132. self.c1 = nn.Conv2d(1, 8, 3, stride=2)
  133. self.c2 = nn.Conv2d(8, 16, 3, stride=2)
  134. self.l1 = nn.Linear(16*6*6, 10)
  135. if has_batchnorm:
  136. self.bn1, self.bn2 = nn.BatchNorm2d(8), nn.BatchNorm2d(16)
  137. else:
  138. self.bn1, self.bn2 = lambda x: x, lambda x: x
  139. def forward(self, x):
  140. return self.l1(self.bn2(self.c2(self.bn1(self.c1(x)).relu())).relu().reshape(x.shape[0], -1)).log_softmax(-1)
  141. for has_batchnorm in [False, True]:
  142. with self.subTest(has_batchnorm=has_batchnorm):
  143. compare_tiny_torch(LinTiny(has_batchnorm), LinTorch(has_batchnorm), self.X.reshape((-1, 1, 28, 28)), self.Y)
  144. if __name__ == "__main__":
  145. unittest.main()