# external_model_benchmark.py
# Benchmarks ONNX model inference speed across tinygrad, torch, and onnxruntime.
import csv
import pathlib
import time
from os import getenv

import numpy as np
import torch
torch.set_num_threads(1)  # single torch thread for a fair CPU comparison
import onnx
from onnx.helper import tensor_dtype_to_np_dtype
import onnxruntime as ort
from onnx2torch import convert

from extra.onnx import get_run_onnx
from tinygrad import Tensor, Device
from tinygrad.device import CompileError
from tinygrad.helpers import OSX, DEBUG, fetch
  13. MODELS = {
  14. "resnet50": "https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-caffe2-v1-9.onnx",
  15. "openpilot": "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx",
  16. "efficientnet": "https://github.com/onnx/models/raw/main/validated/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx",
  17. "shufflenet": "https://github.com/onnx/models/raw/main/validated/vision/classification/shufflenet/model/shufflenet-9.onnx",
  18. "commavq": "https://huggingface.co/commaai/commavq-gpt2m/resolve/main/gpt2m.onnx",
  19. "dm": "https://github.com/commaai/openpilot/raw/ba7f840a06dbc8ae3c45b3b4976c88a21895aed0/selfdrive/modeld/models/dmonitoring_model.onnx",
  20. # broken in torch MPS
  21. # "zfnet": "https://github.com/onnx/models/raw/main/archive/vision/classification/zfnet-512/model/zfnet512-9.onnx",
  22. # TypeError: BatchNormalization() got an unexpected keyword argument 'is_test'
  23. # "densenet": "https://github.com/onnx/models/raw/main/archive/vision/classification/densenet-121/model/densenet-3.onnx",
  24. # AssertionError: only onnx version >= 10 supported for slice
  25. # "bert": "https://github.com/onnx/models/raw/main/archive/text/machine_comprehension/bert-squad/model/bertsquad-8.onnx",
  26. # really slow
  27. # "resnet18": "https://github.com/onnx/models/raw/main/archive/vision/classification/resnet/model/resnet18-v2-7.onnx",
  28. }
  29. CSV = {}
  30. open_csv = None
  31. def benchmark(mnm, nm, fxn):
  32. tms = []
  33. for _ in range(3):
  34. st = time.perf_counter_ns()
  35. ret = fxn()
  36. tms.append(time.perf_counter_ns() - st)
  37. print(f"{mnm:15s} {nm:25s} {min(tms)*1e-6:7.2f} ms")
  38. CSV[nm] = min(tms)*1e-6
  39. return min(tms), ret
  40. #BASE = pathlib.Path(__file__).parents[2] / "weights" / "onnx"
  41. BASE = pathlib.Path("/tmp/onnx")
def benchmark_model(m, devices, validate_outs=False):
  """Benchmark model `m` (a key of MODELS) across frameworks and append one CSV row.

  Times tinygrad (jitless + jitted) on every device in `devices`, then torch
  (cpu + mps/cuda), then onnxruntime (CPU + CUDA/CoreML where available).
  If `validate_outs`, also compares tinygrad outputs against onnxruntime
  (CPU provider) on each device.
  """
  torch.manual_seed(1)
  global open_csv, CSV
  CSV = {"model": m}
  fn = fetch(MODELS[m])
  onnx_model = onnx.load(fn)
  output_names = [out.name for out in onnx_model.graph.output]
  # initializers (weights) also appear in graph.input; exclude them to keep only real runtime inputs
  excluded = {inp.name for inp in onnx_model.graph.initializer}
  # dim_value == 0 marks a dynamic dimension; default it to 1
  input_shapes = {inp.name:tuple(x.dim_value if x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded} # noqa: E501
  input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input if inp.name not in excluded}
  #input_types = {k:v if v!=np.float16 else np.float32 for k,v in input_types.items()} # cast
  # random inputs seeded by torch.manual_seed above, shared by all frameworks below
  np_inputs = {k:torch.randn(shp).numpy().astype(input_types[k]) for k,shp in input_shapes.items()}
  assert len(input_shapes) < 30, f"too many input shapes {len(input_shapes)}"
  # print input names
  if DEBUG >= 2: print([inp.name for inp in onnx_model.graph.input if inp.name not in excluded])
  # tinygrad: jitless first, then TinyJit-compiled, on each requested device
  for device in devices:
    try:
      Device.DEFAULT = device
      inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
      tinygrad_model = get_run_onnx(onnx_model)
      benchmark(m, f"tinygrad_{device.lower()}_jitless", lambda: {k:v.numpy() for k,v in tinygrad_model(inputs).items()})
      from tinygrad.engine.jit import TinyJit
      tinygrad_jitted_model = TinyJit(lambda **kwargs: {k:v.realize() for k,v in tinygrad_model(kwargs).items()})
      # warm up: the jit captures kernels over the first few calls
      for _ in range(3): {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}
      benchmark(m, f"tinygrad_{device.lower()}_jit", lambda: {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}) # noqa: F821
      del inputs, tinygrad_model, tinygrad_jitted_model
    except CompileError as e:
      # METAL fails with buffer count limit
      if m == "dm" and device == "METAL": return
      raise e
  # convert model to torch
  try:
    torch_model = convert(onnx_model)
  except Exception as e:
    # model conversion failed
    print(f"{m:16s}onnx2torch {type(e).__name__:>25}")
  else:
    torch_inputs = [torch.tensor(x) for x in np_inputs.values()]
    try: benchmark(m, "torch_cpu", lambda: torch_model(*torch_inputs))
    except Exception as e: print(f"{m:16s}torch_cpu {type(e).__name__:>25}")
    # torch accelerator backend: mps on macOS, cuda elsewhere
    torch_device = "mps" if OSX else "cuda"
    torch_mps_model = torch_model.to(torch_device)
    torch_mps_inputs = [x.to(torch_device) for x in torch_inputs]
    try: benchmark(m, f"torch_{torch_device}", lambda: torch_mps_model(*torch_mps_inputs))
    except Exception as e: print(f"{m:16s}torch_{torch_device} {type(e).__name__:>25}")
  # bench onnxruntime
  ort_options = ort.SessionOptions()
  ort_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
  ort_options.log_severity_level = 3 # no warnings
  for backend in ["CPU", "CUDA" if not OSX else "CoreML"]: # https://onnxruntime.ai/docs/execution-providers/
    provider = backend+"ExecutionProvider"
    # skip providers this onnxruntime build doesn't ship
    if provider not in ort.get_available_providers(): continue
    ort_sess = ort.InferenceSession(str(fn), ort_options, [provider])
    try:
      benchmark(m, f"onnxruntime_{backend.lower()}", lambda: ort_sess.run(output_names, np_inputs))
    except Exception as e: print(f"{m:16s}onnxruntime_{backend.lower()} {type(e).__name__:>25}")
    del ort_sess
  # compare tinygrad outputs to onnxruntime's CPU-provider outputs on every device
  if validate_outs:
    for device in devices:
      rtol, atol = 2e-3, 2e-3 # tolerance for fp16 models
      Device.DEFAULT = device
      inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
      tinygrad_model = get_run_onnx(onnx_model)
      tinygrad_out = tinygrad_model(inputs)
      ort_sess = ort.InferenceSession(str(fn), ort_options, ["CPUExecutionProvider"])
      onnx_out = ort_sess.run(output_names, np_inputs)
      onnx_out = dict([*list(zip(output_names, onnx_out))])
      assert_allclose(tinygrad_out, onnx_out, rtol=rtol, atol=atol)
      print(f"{m:16s}outputs validated on {device=} with rtol={rtol:.1e}, atol={atol:.1e}")
  # first row creates the writer and emits the header; later rows reuse it
  if open_csv is None:
    open_csv = csv.DictWriter(open('onnx_inference_speed.csv', 'w', newline=''), fieldnames=list(CSV.keys()))
    open_csv.writeheader()
  open_csv.writerow(CSV)
  115. def assert_allclose(tiny_out:dict, onnx_out:dict, rtol=1e-5, atol=1e-5):
  116. assert len(tiny_out) == len(onnx_out) and tiny_out.keys() == onnx_out.keys()
  117. for k in tiny_out.keys():
  118. tiny_v, onnx_v = tiny_out[k], onnx_out[k]
  119. if tiny_v is None: assert tiny_v == onnx_v
  120. else: np.testing.assert_allclose(tiny_v.numpy(), onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tiny_out.keys()}")
  121. if __name__ == "__main__":
  122. devices = [Device.DEFAULT] if getenv("NOCLANG") else [Device.DEFAULT, "CLANG"]
  123. if getenv("MODEL", "") != "": benchmark_model(getenv("MODEL", ""), devices, True)
  124. else:
  125. for m in MODELS: benchmark_model(m, devices, True)