# beautiful_mnist_multigpu.py (2.2 KB) — listing recovered from a web scrape; line-number gutter removed.
  1. # model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
  2. from typing import List, Callable
  3. from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
  4. from tinygrad.helpers import getenv, colored, trange
  5. from tinygrad.nn.datasets import mnist
  6. GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))]
  7. class Model:
  8. def __init__(self):
  9. self.layers: List[Callable[[Tensor], Tensor]] = [
  10. nn.Conv2d(1, 32, 5), Tensor.relu,
  11. nn.Conv2d(32, 32, 5), Tensor.relu,
  12. nn.BatchNorm2d(32), Tensor.max_pool2d,
  13. nn.Conv2d(32, 64, 3), Tensor.relu,
  14. nn.Conv2d(64, 64, 3), Tensor.relu,
  15. nn.BatchNorm2d(64), Tensor.max_pool2d,
  16. lambda x: x.flatten(1), nn.Linear(576, 10)]
  17. def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
  18. if __name__ == "__main__":
  19. X_train, Y_train, X_test, Y_test = mnist()
  20. # we shard the test data on axis 0
  21. X_test.shard_(GPUS, axis=0)
  22. Y_test.shard_(GPUS, axis=0)
  23. model = Model()
  24. for k, x in nn.state.get_state_dict(model).items(): x.to_(GPUS) # we put a copy of the model on every GPU
  25. opt = nn.optim.Adam(nn.state.get_parameters(model))
  26. @TinyJit
  27. def train_step() -> Tensor:
  28. with Tensor.train():
  29. opt.zero_grad()
  30. samples = Tensor.randint(512, high=X_train.shape[0])
  31. Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0
  32. # TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
  33. loss = model(Xt).sparse_categorical_crossentropy(Yt).backward()
  34. opt.step()
  35. return loss
  36. @TinyJit
  37. def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
  38. test_acc = float('nan')
  39. for i in (t:=trange(70)):
  40. GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
  41. loss = train_step()
  42. if i%10 == 9: test_acc = get_test_acc().item()
  43. t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
  44. # verify eval acc
  45. if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
  46. if test_acc >= target: print(colored(f"{test_acc=} >= {target}", "green"))
  47. else: raise ValueError(colored(f"{test_acc=} < {target}", "red"))