training.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import numpy as np
  2. from tinygrad.tensor import Tensor
  3. from tinygrad.helpers import CI, trange
  4. from tinygrad.engine.jit import TinyJit
  5. def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
  6. transform=lambda x: x, target_transform=lambda x: x, noloss=False, allow_jit=True):
  7. def train_step(x, y):
  8. # network
  9. out = model.forward(x) if hasattr(model, 'forward') else model(x)
  10. loss = lossfn(out, y)
  11. optim.zero_grad()
  12. loss.backward()
  13. if noloss: del loss
  14. optim.step()
  15. if noloss: return (None, None)
  16. cat = out.argmax(axis=-1)
  17. accuracy = (cat == y).mean()
  18. return loss.realize(), accuracy.realize()
  19. if allow_jit: train_step = TinyJit(train_step)
  20. with Tensor.train():
  21. losses, accuracies = [], []
  22. for i in (t := trange(steps, disable=CI)):
  23. samp = np.random.randint(0, X_train.shape[0], size=(BS))
  24. x = Tensor(transform(X_train[samp]), requires_grad=False)
  25. y = Tensor(target_transform(Y_train[samp]))
  26. loss, accuracy = train_step(x, y)
  27. # printing
  28. if not noloss:
  29. loss, accuracy = loss.numpy(), accuracy.numpy()
  30. losses.append(loss)
  31. accuracies.append(accuracy)
  32. t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
  33. return [losses, accuracies]
  34. def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=False, transform=lambda x: x,
  35. target_transform=lambda y: y):
  36. Tensor.training = False
  37. def numpy_eval(Y_test, num_classes):
  38. Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
  39. for i in trange((len(Y_test)-1)//BS+1, disable=CI):
  40. x = Tensor(transform(X_test[i*BS:(i+1)*BS]))
  41. out = model.forward(x) if hasattr(model, 'forward') else model(x)
  42. Y_test_preds_out[i*BS:(i+1)*BS] = out.numpy()
  43. Y_test_preds = np.argmax(Y_test_preds_out, axis=-1)
  44. Y_test = target_transform(Y_test)
  45. return (Y_test == Y_test_preds).mean(), Y_test_preds
  46. if num_classes is None: num_classes = Y_test.max().astype(int)+1
  47. acc, Y_test_pred = numpy_eval(Y_test, num_classes)
  48. print("test set accuracy is %f" % acc)
  49. return (acc, Y_test_pred) if return_predict else acc