| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- import ast
- import pathlib
- import unittest
- import numpy as np
- from PIL import Image
- from tinygrad.helpers import getenv
- from tinygrad.tensor import Tensor
- from extra.models.efficientnet import EfficientNet
- from extra.models.vit import ViT
- from extra.models.resnet import ResNet50
- def _load_labels():
- labels_filename = pathlib.Path(__file__).parent / 'efficientnet/imagenet1000_clsidx_to_labels.txt'
- return ast.literal_eval(labels_filename.read_text())
- _LABELS = _load_labels()
- def preprocess(img, new=False):
- # preprocess image
- aspect_ratio = img.size[0] / img.size[1]
- img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
- img = np.array(img)
- y0, x0 =(np.asarray(img.shape)[:2] - 224) // 2
- img = img[y0: y0 + 224, x0: x0 + 224]
- # low level preprocess
- if new:
- img = img.astype(np.float32)
- img -= [127.0, 127.0, 127.0]
- img /= [128.0, 128.0, 128.0]
- img = img[None]
- else:
- img = np.moveaxis(img, [2, 0, 1], [0, 1, 2])
- img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224)
- img /= 255.0
- img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1))
- img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1))
- return img
- def _infer(model: EfficientNet, img, bs=1):
- old_training = Tensor.training
- Tensor.training = False
- img = preprocess(img)
- # run the net
- if bs > 1: img = img.repeat(bs, axis=0)
- out = model.forward(Tensor(img))
- Tensor.training = old_training
- return _LABELS[np.argmax(out.numpy()[0])]
- chicken_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg')
- car_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg')
- class TestEfficientNet(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- cls.model = EfficientNet(number=getenv("NUM"))
- cls.model.load_from_pretrained()
- @classmethod
- def tearDownClass(cls):
- del cls.model
- def test_chicken(self):
- label = _infer(self.model, chicken_img)
- self.assertEqual(label, "hen")
- def test_chicken_bigbatch(self):
- label = _infer(self.model, chicken_img, 2)
- self.assertEqual(label, "hen")
- def test_car(self):
- label = _infer(self.model, car_img)
- self.assertEqual(label, "sports car, sport car")
- class TestViT(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- cls.model = ViT()
- cls.model.load_from_pretrained()
- @classmethod
- def tearDownClass(cls):
- del cls.model
- def test_chicken(self):
- label = _infer(self.model, chicken_img)
- self.assertEqual(label, "cock")
- def test_car(self):
- label = _infer(self.model, car_img)
- self.assertEqual(label, "racer, race car, racing car")
- class TestResNet(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- cls.model = ResNet50()
- cls.model.load_from_pretrained()
- @classmethod
- def tearDownClass(cls):
- del cls.model
- def test_chicken(self):
- label = _infer(self.model, chicken_img)
- self.assertEqual(label, "hen")
- def test_car(self):
- label = _infer(self.model, car_img)
- self.assertEqual(label, "sports car, sport car")
- if __name__ == '__main__':
- unittest.main()
|