12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- # load each model here, quick benchmark
- from tinygrad import Tensor, GlobalCounters
- from tinygrad.helpers import getenv
- import numpy as np
def test_model(model, *inputs):
    """Run ``model`` once on ``inputs`` and print op count and summed time.

    The global counters are reset before the call so the printed figures
    cover this invocation only. Nothing is returned; the result of the
    model is discarded after the optional ``.numpy()`` conversion.
    """
    GlobalCounters.reset()
    result = model(*inputs)
    if isinstance(result, Tensor):
        # Convert a Tensor result to numpy before reading the counters
        # (presumably this is what forces the computation to run -- confirm).
        result = result.numpy()
    # TODO: return event future to still get the time_sum_s without DEBUG=2
    giga_ops = GlobalCounters.global_ops * 1e-9
    elapsed_ms = GlobalCounters.time_sum_s * 1000
    print(f"{giga_ops:.2f} GOPS, {elapsed_ms:.2f} ms")
def spec_resnet():
    """Benchmark ResNet50 (v1.5) on one random input of shape (1, 3, 224, 224)."""
    # Imported inside the function so the model module is only loaded
    # when this spec actually runs.
    from extra.models.resnet import ResNet50

    # The model is constructed before the random input, as arguments are
    # evaluated left to right.
    test_model(ResNet50(), Tensor.randn(1, 3, 224, 224))
def spec_retinanet():
    """Benchmark RetinaNet on a ResNet50 backbone with one random (1, 3, 224, 224) input."""
    from extra.models.resnet import ResNet50
    from extra.models.retinanet import RetinaNet

    backbone = ResNet50()
    detector = RetinaNet(backbone, num_classes=91, num_anchors=9)
    sample = Tensor.randn(1, 3, 224, 224)
    test_model(detector, sample)
def spec_unet3d():
    """Benchmark the 3D UNet on one random input of shape (1, 1, 128, 128, 128)."""
    from extra.models.unet3d import UNet3D

    # Pretrained weights are not loaded here: load_from_pretrained() is
    # deliberately not called, so the model is used as constructed.
    net = UNet3D()
    volume = Tensor.randn(1, 1, 128, 128, 128)
    test_model(net, volume)
def spec_rnnt():
    """Benchmark RNNT on two random inputs of shapes (220, 1, 240) and (1, 220)."""
    from extra.models.rnnt import RNNT

    # Pretrained weights are not loaded here: load_from_pretrained() is
    # deliberately not called, so the model is used as constructed.
    net = RNNT()
    first_input = Tensor.randn(220, 1, 240)
    second_input = Tensor.randn(1, 220)
    test_model(net, first_input, second_input)
def spec_bert():
    """Benchmark BertForQuestionAnswering on three inputs of shape (1, 384).

    The first two inputs are random-normal tensors; the third holds random
    0/1 values stored as float32.
    """
    from extra.models.bert import BertForQuestionAnswering

    # Pretrained weights are not loaded here: load_from_pretrained() is
    # deliberately not called, so the model is used as constructed.
    net = BertForQuestionAnswering()
    shape = (1, 384)
    # NOTE(review): presumably input ids, attention mask and token type ids,
    # in that order -- confirm against the model's call signature.
    first_input = Tensor.randn(*shape)
    second_input = Tensor.randn(*shape)
    binary_values = np.random.randint(0, 2, shape).astype(np.float32)
    third_input = Tensor(binary_values)
    test_model(net, first_input, second_input, third_input)
def spec_mrcnn():
    """Benchmark MaskRCNN on a ResNet-50 backbone with one random (3, 224, 224) image."""
    from extra.models.mask_rcnn import MaskRCNN, ResNet

    backbone = ResNet(50, num_classes=None, stride_in_1x1=True)
    # Pretrained weights are not loaded here: load_from_pretrained() is
    # deliberately not called, so the model is used as constructed.
    net = MaskRCNN(backbone)
    image = Tensor.randn(3, 224, 224)
    # The model is called with a list holding the single image tensor.
    images = [image]
    test_model(net, images)
if __name__ == "__main__":
    # inference only for now
    Tensor.training = False
    Tensor.no_grad = True

    # MODEL is a comma-separated list of model names; each one maps to a
    # module-level function called "spec_<name>".
    # NOTE: the names bound here are module globals, so none of them may
    # start with "spec_" or they could shadow the lookup below.
    requested_models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(",")
    for model_name in requested_models:
        fn_name = "spec_" + model_name
        if fn_name not in globals():
            # Names without a matching spec function are skipped silently.
            continue
        print(f"testing {model_name}")
        globals()[fn_name]()
|