train_resnet.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. #!/usr/bin/env python3
  2. import numpy as np
  3. from PIL import Image
  4. from tinygrad.nn.state import get_parameters
  5. from tinygrad.nn import optim
  6. from tinygrad.helpers import getenv
  7. from extra.training import train, evaluate
  8. from extra.models.resnet import ResNet
  9. from extra.datasets import fetch_mnist
  10. class ComposeTransforms:
  11. def __init__(self, trans):
  12. self.trans = trans
  13. def __call__(self, x):
  14. for t in self.trans:
  15. x = t(x)
  16. return x
  17. if __name__ == "__main__":
  18. X_train, Y_train, X_test, Y_test = fetch_mnist()
  19. X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
  20. X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
  21. classes = 10
  22. TRANSFER = getenv('TRANSFER')
  23. model = ResNet(getenv('NUM', 18), num_classes=classes)
  24. if TRANSFER:
  25. model.load_from_pretrained()
  26. lr = 5e-3
  27. transform = ComposeTransforms([
  28. lambda x: [Image.fromarray(xx, mode='L').resize((64, 64)) for xx in x],
  29. lambda x: np.stack([np.asarray(xx) for xx in x], 0),
  30. lambda x: x / 255.0,
  31. lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32),
  32. ])
  33. for _ in range(5):
  34. optimizer = optim.SGD(get_parameters(model), lr=lr, momentum=0.9)
  35. train(model, X_train, Y_train, optimizer, 100, BS=32, transform=transform)
  36. evaluate(model, X_test, Y_test, num_classes=classes, transform=transform)
  37. lr /= 1.2
  38. print(f'reducing lr to {lr:.7f}')