# efficientnet.py
  1. # load weights from
  2. # https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
  3. # a rough copy of
  4. # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
  5. import sys
  6. import ast
  7. import time
  8. import numpy as np
  9. from PIL import Image
  10. from tinygrad.tensor import Tensor
  11. from tinygrad.helpers import getenv, fetch, Timing
  12. from tinygrad.engine.jit import TinyJit
  13. from extra.models.efficientnet import EfficientNet
# print logits as plain decimals (no scientific notation)
np.set_printoptions(suppress=True)
# TODO: you should be able to put these in the jitted function
# standard ImageNet per-channel normalization constants (RGB mean and std)
bias = Tensor([0.485, 0.456, 0.406])
scale = Tensor([0.229, 0.224, 0.225])
  18. @TinyJit
  19. def _infer(model, img):
  20. img = img.permute((2,0,1))
  21. img = img / 255.0
  22. img = img - bias.reshape((1,-1,1,1))
  23. img = img / scale.reshape((1,-1,1,1))
  24. return model.forward(img).realize()
  25. def infer(model, img):
  26. # preprocess image
  27. aspect_ratio = img.size[0] / img.size[1]
  28. img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
  29. img = np.array(img)
  30. y0,x0=(np.asarray(img.shape)[:2]-224)//2
  31. retimg = img = img[y0:y0+224, x0:x0+224]
  32. # if you want to look at the image
  33. """
  34. import matplotlib.pyplot as plt
  35. plt.imshow(img)
  36. plt.show()
  37. """
  38. # run the net
  39. out = _infer(model, Tensor(img.astype("float32"))).numpy()
  40. # if you want to look at the outputs
  41. """
  42. import matplotlib.pyplot as plt
  43. plt.plot(out[0])
  44. plt.show()
  45. """
  46. return out, retimg
if __name__ == "__main__":
  # instantiate my net
  # NUM selects the EfficientNet variant (0 -> b0, matching the pretrained weights above)
  model = EfficientNet(getenv("NUM", 0))
  model.load_from_pretrained()
  # category labels: ImageNet class-index -> human-readable name dict
  lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
  # load image and preprocess
  url = sys.argv[1] if len(sys.argv) >= 2 else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
  if url == 'webcam':
    # live demo: classify webcam frames in a loop until 'q' is pressed
    import cv2
    cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
    while 1:
      _ = cap.grab() # discard one frame to circumvent capture buffering
      ret, frame = cap.read()
      # OpenCV delivers BGR; reorder channels to RGB for PIL
      img = Image.fromarray(frame[:, :, [2,1,0]])
      lt = time.monotonic_ns()
      out, retimg = infer(model, img)
      # per-frame latency in ms, top class index, its score, and its label
      print(f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms", np.argmax(out), np.max(out), lbls[np.argmax(out)])
      # upscale the 224x224 crop for a viewable preview window
      SCALE = 3
      simg = cv2.resize(retimg, (224*SCALE, 224*SCALE))
      # back to BGR for cv2.imshow
      retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
      cv2.imshow('capture', retimg)
      if cv2.waitKey(1) & 0xFF == ord('q'):
        break
    cap.release()
    cv2.destroyAllWindows()
  else:
    # single-image path: fetch (or read cached) image, run CNT timed inferences
    img = Image.open(fetch(url))
    for i in range(getenv("CNT", 1)):
      with Timing("did inference in "):
        out, _ = infer(model, img)
        print(np.argmax(out), np.max(out), lbls[np.argmax(out)])