# beautiful_mnist_torch.py
  1. from tinygrad import dtypes
  2. from tinygrad.helpers import trange
  3. from tinygrad.nn.datasets import mnist
  4. import torch
  5. from torch import nn, optim
  6. class Model(nn.Module):
  7. def __init__(self):
  8. super().__init__()
  9. self.c1 = nn.Conv2d(1, 32, 5)
  10. self.c2 = nn.Conv2d(32, 32, 5)
  11. self.bn1 = nn.BatchNorm2d(32)
  12. self.m1 = nn.MaxPool2d(2)
  13. self.c3 = nn.Conv2d(32, 64, 3)
  14. self.c4 = nn.Conv2d(64, 64, 3)
  15. self.bn2 = nn.BatchNorm2d(64)
  16. self.m2 = nn.MaxPool2d(2)
  17. self.lin = nn.Linear(576, 10)
  18. def forward(self, x):
  19. x = nn.functional.relu(self.c1(x))
  20. x = nn.functional.relu(self.c2(x), 0)
  21. x = self.m1(self.bn1(x))
  22. x = nn.functional.relu(self.c3(x), 0)
  23. x = nn.functional.relu(self.c4(x), 0)
  24. x = self.m2(self.bn2(x))
  25. return self.lin(torch.flatten(x, 1))
  26. if __name__ == "__main__":
  27. mps_device = torch.device("mps")
  28. X_train, Y_train, X_test, Y_test = mnist()
  29. X_train = torch.tensor(X_train.float().numpy(), device=mps_device)
  30. Y_train = torch.tensor(Y_train.cast(dtypes.int64).numpy(), device=mps_device)
  31. X_test = torch.tensor(X_test.float().numpy(), device=mps_device)
  32. Y_test = torch.tensor(Y_test.cast(dtypes.int64).numpy(), device=mps_device)
  33. model = Model().to(mps_device)
  34. optimizer = optim.Adam(model.parameters(), 1e-3)
  35. loss_fn = nn.CrossEntropyLoss()
  36. #@torch.compile
  37. def step(samples):
  38. X,Y = X_train[samples], Y_train[samples]
  39. out = model(X)
  40. loss = loss_fn(out, Y)
  41. optimizer.zero_grad()
  42. loss.backward()
  43. optimizer.step()
  44. return loss
  45. test_acc = float('nan')
  46. for i in (t:=trange(70)):
  47. samples = torch.randint(0, X_train.shape[0], (512,)) # putting this in JIT didn't work well
  48. loss = step(samples)
  49. if i%10 == 9: test_acc = ((model(X_test).argmax(axis=-1) == Y_test).sum() * 100 / X_test.shape[0]).item()
  50. t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")