# abstractions3.py — tinygrad tutorial: a training step walked front-to-back
# (lazy graph -> schedule -> lowered ExecItems -> run)
  1. # abstractions2 goes from back to front, here we will go from front to back
  2. from typing import List
  3. from tqdm import tqdm
  4. from tinygrad.helpers import DEBUG
  5. # *****
  6. # 0. Load mnist on the device
  7. from tinygrad.nn.datasets import mnist
  8. X_train, Y_train, _, _ = mnist()
  9. X_train = X_train.float()
  10. X_train -= X_train.mean()
  11. # *****
  12. # 1. Define an MNIST model.
  13. from tinygrad import Tensor
  14. l1 = Tensor.kaiming_uniform(128, 784)
  15. l2 = Tensor.kaiming_uniform(10, 128)
  16. def model(x): return x.flatten(1).dot(l1.T).relu().dot(l2.T)
  17. l1n, l2n = l1.numpy(), l2.numpy()
  18. # *****
  19. # 2. Choose a batch for training and do the backward pass.
  20. from tinygrad.nn.optim import SGD
  21. optim = SGD([l1, l2])
  22. X, Y = X_train[(samples:=Tensor.randint(128, high=X_train.shape[0]))], Y_train[samples]
  23. optim.zero_grad()
  24. model(X).sparse_categorical_crossentropy(Y).backward()
  25. optim._step() # this will step the optimizer without running realize
  26. # *****
  27. # 3. Create a schedule.
  28. # The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point
  29. # l1.lazydata and l2.lazydata define a computation graph
  30. from tinygrad.engine.schedule import ScheduleItem
  31. schedule: List[ScheduleItem] = Tensor.schedule(l1, l2)
  32. print(f"The schedule contains {len(schedule)} items.")
  33. for si in schedule: print(str(si)[:80])
  34. # *****
  35. # 4. Lower a schedule.
  36. from tinygrad.engine.realize import lower_schedule_item, ExecItem
  37. lowered: List[ExecItem] = [ExecItem(lower_schedule_item(si).prg, list(si.bufs)) for si in tqdm(schedule)]
  38. # *****
  39. # 5. Run the schedule
  40. for ei in tqdm(lowered): ei.run()
  41. # *****
  42. # 6. Print the weight change
  43. print("first weight change\n", l1.numpy()-l1n)
  44. print("second weight change\n", l2.numpy()-l2n)