model_spec.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # load each model here, quick benchmark
  2. from tinygrad import Tensor, GlobalCounters
  3. from tinygrad.helpers import getenv
  4. import numpy as np
  5. def test_model(model, *inputs):
  6. GlobalCounters.reset()
  7. out = model(*inputs)
  8. if isinstance(out, Tensor): out = out.numpy()
  9. # TODO: return event future to still get the time_sum_s without DEBUG=2
  10. print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms")
  11. def spec_resnet():
  12. # Resnet50-v1.5
  13. from extra.models.resnet import ResNet50
  14. mdl = ResNet50()
  15. img = Tensor.randn(1, 3, 224, 224)
  16. test_model(mdl, img)
  17. def spec_retinanet():
  18. # Retinanet with ResNet backbone
  19. from extra.models.resnet import ResNet50
  20. from extra.models.retinanet import RetinaNet
  21. mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9)
  22. img = Tensor.randn(1, 3, 224, 224)
  23. test_model(mdl, img)
  24. def spec_unet3d():
  25. # 3D UNET
  26. from extra.models.unet3d import UNet3D
  27. mdl = UNet3D()
  28. #mdl.load_from_pretrained()
  29. img = Tensor.randn(1, 1, 128, 128, 128)
  30. test_model(mdl, img)
  31. def spec_rnnt():
  32. from extra.models.rnnt import RNNT
  33. mdl = RNNT()
  34. #mdl.load_from_pretrained()
  35. x = Tensor.randn(220, 1, 240)
  36. y = Tensor.randn(1, 220)
  37. test_model(mdl, x, y)
  38. def spec_bert():
  39. from extra.models.bert import BertForQuestionAnswering
  40. mdl = BertForQuestionAnswering()
  41. #mdl.load_from_pretrained()
  42. x = Tensor.randn(1, 384)
  43. am = Tensor.randn(1, 384)
  44. tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32))
  45. test_model(mdl, x, am, tt)
  46. def spec_mrcnn():
  47. from extra.models.mask_rcnn import MaskRCNN, ResNet
  48. mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
  49. #mdl.load_from_pretrained()
  50. x = Tensor.randn(3, 224, 224)
  51. test_model(mdl, [x])
  52. if __name__ == "__main__":
  53. # inference only for now
  54. Tensor.training = False
  55. Tensor.no_grad = True
  56. for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(","):
  57. nm = f"spec_{m}"
  58. if nm in globals():
  59. print(f"testing {m}")
  60. globals()[nm]()