"""vit.py: classify one image with a pretrained tinygrad ViT.

Builds a ViT from extra.models.vit (a tiny config by default, a larger one with
embed_dim=768 when the LARGE env var is 1), loads pretrained weights, downloads
an ImageNet class-index -> label table and a test image, preprocesses the image
into a normalized (1, 3, 224, 224) float32 array, runs the model and prints the
output shape, the top-1 class index, its raw score and its label.
"""
import ast
import numpy as np
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, fetch
from extra.models.vit import ViT

# One-off snippet (not executed; needs tensorflow) that copied the tiny
# checkpoint from Google Cloud Storage into the local cache/ directory:
#
#   fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
#   import tensorflow as tf
#   with tf.io.gfile.GFile(fn, "rb") as f:
#     dat = f.read()
#   with open("cache/"+ fn.rsplit("/", 1)[1], "wb") as g:
#     g.write(dat)

Tensor.training = False  # inference only

if getenv("LARGE", 0) == 1:
  m = ViT(embed_dim=768, num_heads=12)
else:
  # tiny
  m = ViT(embed_dim=192, num_heads=3)
m.load_from_pretrained()

# category labels: a dict literal mapping ImageNet class index -> label string.
# ast.literal_eval only parses Python literals, so the downloaded text cannot
# execute code.
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())

#url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
# junk  -- NOTE(review): presumably the expected top-1 label for this image; confirm

# Force 3-channel RGB: grayscale/palette images would otherwise give a 2-D array
# (breaking the channel transpose below) and CMYK would give wrong colours.
# The context manager closes the fetched file; convert() returns a loaded copy.
with Image.open(fetch(url)) as im:
  img = im.convert("RGB")

# Resize so the shorter side is 224 px, preserving the aspect ratio.
aspect_ratio = img.size[0] / img.size[1]
img = img.resize((int(224*max(aspect_ratio, 1.0)), int(224*max(1.0/aspect_ratio, 1.0))))

# Center-crop a 224x224 window from the (H, W, 3) array.
img = np.array(img)
y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
img = img[y0:y0+224, x0:x0+224]

# HWC -> CHW, float32, add a leading batch dimension of 1.
img = img.transpose(2, 0, 1).astype(np.float32).reshape(1, 3, 224, 224)

# Scale pixels from [0, 255] to [-1, 1]: (x/255 - 0.5) / 0.5.
img /= 255.0
img -= 0.5
img /= 0.5

out = m.forward(Tensor(img))
# Batch size is 1, so flatten and take the top-1 class; outnp[choice] is the
# model's raw output value (no softmax is applied here).
outnp = out.numpy().ravel()
choice = outnp.argmax()
print(out.shape, choice, outnp[choice], lbls[choice])