# train_efficientnet.py — train an EfficientNet (or a tiny fallback CNN) on
# CIFAR-10 or ImageNet with tinygrad. Configured via env vars (TINY, TRANSFER,
# IMAGENET, NUM, BS, STEPS).
  1. import traceback
  2. import time
  3. from multiprocessing import Process, Queue
  4. import numpy as np
  5. from tinygrad.nn.state import get_parameters
  6. from tinygrad.nn import optim
  7. from tinygrad.helpers import getenv, trange
  8. from tinygrad.tensor import Tensor
  9. from extra.datasets import fetch_cifar
  10. from extra.models.efficientnet import EfficientNet
  11. class TinyConvNet:
  12. def __init__(self, classes=10):
  13. conv = 3
  14. inter_chan, out_chan = 8, 16 # for speed
  15. self.c1 = Tensor.uniform(inter_chan,3,conv,conv)
  16. self.c2 = Tensor.uniform(out_chan,inter_chan,conv,conv)
  17. self.l1 = Tensor.uniform(out_chan*6*6, classes)
  18. def forward(self, x):
  19. x = x.conv2d(self.c1).relu().max_pool2d()
  20. x = x.conv2d(self.c2).relu().max_pool2d()
  21. x = x.reshape(shape=[x.shape[0], -1])
  22. return x.dot(self.l1)
  23. if __name__ == "__main__":
  24. IMAGENET = getenv("IMAGENET")
  25. classes = 1000 if IMAGENET else 10
  26. TINY = getenv("TINY")
  27. TRANSFER = getenv("TRANSFER")
  28. if TINY:
  29. model = TinyConvNet(classes)
  30. elif TRANSFER:
  31. model = EfficientNet(getenv("NUM", 0), classes, has_se=True)
  32. model.load_from_pretrained()
  33. else:
  34. model = EfficientNet(getenv("NUM", 0), classes, has_se=False)
  35. parameters = get_parameters(model)
  36. print("parameter count", len(parameters))
  37. optimizer = optim.Adam(parameters, lr=0.001)
  38. BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048)
  39. print(f"training with batch size {BS} for {steps} steps")
  40. if IMAGENET:
  41. from extra.datasets.imagenet import fetch_batch
  42. def loader(q):
  43. while 1:
  44. try:
  45. q.put(fetch_batch(BS))
  46. except Exception:
  47. traceback.print_exc()
  48. q = Queue(16)
  49. for i in range(2):
  50. p = Process(target=loader, args=(q,))
  51. p.daemon = True
  52. p.start()
  53. else:
  54. X_train, Y_train, _, _ = fetch_cifar()
  55. X_train = X_train.reshape((-1, 3, 32, 32))
  56. Y_train = Y_train.reshape((-1,))
  57. with Tensor.train():
  58. for i in (t := trange(steps)):
  59. if IMAGENET:
  60. X, Y = q.get(True)
  61. else:
  62. samp = np.random.randint(0, X_train.shape[0], size=(BS))
  63. X, Y = X_train.numpy()[samp], Y_train.numpy()[samp]
  64. st = time.time()
  65. out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
  66. fp_time = (time.time()-st)*1000.0
  67. y = np.zeros((BS,classes), np.float32)
  68. y[range(y.shape[0]),Y] = -classes
  69. y = Tensor(y, requires_grad=False)
  70. loss = out.log_softmax().mul(y).mean()
  71. optimizer.zero_grad()
  72. st = time.time()
  73. loss.backward()
  74. bp_time = (time.time()-st)*1000.0
  75. st = time.time()
  76. optimizer.step()
  77. opt_time = (time.time()-st)*1000.0
  78. st = time.time()
  79. loss = loss.numpy()
  80. cat = out.argmax(axis=1).numpy()
  81. accuracy = (cat == Y).mean()
  82. finish_time = (time.time()-st)*1000.0
  83. # printing
  84. t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
  85. (loss, accuracy,
  86. fp_time, bp_time, opt_time, finish_time,
  87. fp_time + bp_time + opt_time + finish_time))
  88. del out, y, loss