test_efficientnet.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import ast
  2. import pathlib
  3. import unittest
  4. import numpy as np
  5. from PIL import Image
  6. from tinygrad.helpers import getenv
  7. from tinygrad.tensor import Tensor
  8. from extra.models.efficientnet import EfficientNet
  9. from extra.models.vit import ViT
  10. from extra.models.resnet import ResNet50
  11. def _load_labels():
  12. labels_filename = pathlib.Path(__file__).parent / 'efficientnet/imagenet1000_clsidx_to_labels.txt'
  13. return ast.literal_eval(labels_filename.read_text())
# Label lookup table indexed by `_infer` with the arg-max class index.
# Loaded once, at import time, via `_load_labels` (presumably a
# {class_index: label} dict -- confirm against the data file).
_LABELS = _load_labels()
  15. def preprocess(img, new=False):
  16. # preprocess image
  17. aspect_ratio = img.size[0] / img.size[1]
  18. img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
  19. img = np.array(img)
  20. y0, x0 =(np.asarray(img.shape)[:2] - 224) // 2
  21. img = img[y0: y0 + 224, x0: x0 + 224]
  22. # low level preprocess
  23. if new:
  24. img = img.astype(np.float32)
  25. img -= [127.0, 127.0, 127.0]
  26. img /= [128.0, 128.0, 128.0]
  27. img = img[None]
  28. else:
  29. img = np.moveaxis(img, [2, 0, 1], [0, 1, 2])
  30. img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224)
  31. img /= 255.0
  32. img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1))
  33. img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1))
  34. return img
  35. def _infer(model: EfficientNet, img, bs=1):
  36. old_training = Tensor.training
  37. Tensor.training = False
  38. img = preprocess(img)
  39. # run the net
  40. if bs > 1: img = img.repeat(bs, axis=0)
  41. out = model.forward(Tensor(img))
  42. Tensor.training = old_training
  43. return _LABELS[np.argmax(out.numpy()[0])]
  44. chicken_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg')
  45. car_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg')
  46. class TestEfficientNet(unittest.TestCase):
  47. @classmethod
  48. def setUpClass(cls):
  49. cls.model = EfficientNet(number=getenv("NUM"))
  50. cls.model.load_from_pretrained()
  51. @classmethod
  52. def tearDownClass(cls):
  53. del cls.model
  54. def test_chicken(self):
  55. label = _infer(self.model, chicken_img)
  56. self.assertEqual(label, "hen")
  57. def test_chicken_bigbatch(self):
  58. label = _infer(self.model, chicken_img, 2)
  59. self.assertEqual(label, "hen")
  60. def test_car(self):
  61. label = _infer(self.model, car_img)
  62. self.assertEqual(label, "sports car, sport car")
  63. class TestViT(unittest.TestCase):
  64. @classmethod
  65. def setUpClass(cls):
  66. cls.model = ViT()
  67. cls.model.load_from_pretrained()
  68. @classmethod
  69. def tearDownClass(cls):
  70. del cls.model
  71. def test_chicken(self):
  72. label = _infer(self.model, chicken_img)
  73. self.assertEqual(label, "cock")
  74. def test_car(self):
  75. label = _infer(self.model, car_img)
  76. self.assertEqual(label, "racer, race car, racing car")
  77. class TestResNet(unittest.TestCase):
  78. @classmethod
  79. def setUpClass(cls):
  80. cls.model = ResNet50()
  81. cls.model.load_from_pretrained()
  82. @classmethod
  83. def tearDownClass(cls):
  84. del cls.model
  85. def test_chicken(self):
  86. label = _infer(self.model, chicken_img)
  87. self.assertEqual(label, "hen")
  88. def test_car(self):
  89. label = _infer(self.model, car_img)
  90. self.assertEqual(label, "sports car, sport car")
  91. if __name__ == '__main__':
  92. unittest.main()