external_test_onnx_backend.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import unittest
  2. from typing import Any, Tuple
  3. from onnx.backend.base import Backend, BackendRep
  4. import onnx.backend.test
  5. import numpy as np
  6. from tinygrad import Tensor, Device, dtypes
  7. from tinygrad.helpers import getenv, OSX
  8. from test.helpers import is_dtype_supported
  9. # pip3 install tabulate
  10. pytest_plugins = 'onnx.backend.test.report',
  11. from extra.onnx import get_run_onnx
  12. class TinygradModel(BackendRep):
  13. def __init__(self, run_onnx, input_names):
  14. super().__init__()
  15. self.fxn = run_onnx
  16. self.input_names = input_names
  17. def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]:
  18. real_inputs = dict(zip(self.input_names, inputs))
  19. ret = self.fxn(real_inputs, debug=True)
  20. return tuple(x.numpy() if isinstance(x, Tensor) else [i.numpy() for i in x] if isinstance(x, list) else np.array(x) for x in ret.values())
  21. class TinygradBackend(Backend):
  22. @classmethod
  23. def prepare(cls, model, device):
  24. input_all = [x.name for x in model.graph.input]
  25. input_initializer = [x.name for x in model.graph.initializer]
  26. net_feed_input = [x for x in input_all if x not in input_initializer]
  27. print("prepare", cls, device, net_feed_input)
  28. run_onnx = get_run_onnx(model)
  29. return TinygradModel(run_onnx, net_feed_input)
  30. @classmethod
  31. def supports_device(cls, device: str) -> bool:
  32. # NOTE: this is onnx CPU
  33. return device == "CPU"
  34. backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
  35. # no support for reduce with multiply (needs llop)
  36. backend_test.exclude('test_reduce_prod_*')
  37. # TODO figure out why it's returning wrong values, geohotstan's uneducated guess is it's due to imprecision from float64 (double) -> float32
  38. # see Type Constraints: https://onnx.ai/onnx/operators/onnx_aionnxpreviewtraining_Adam.html#type-constraints
  39. backend_test.exclude('test_adam_multiple_cpu')
  40. backend_test.exclude('test_nesterov_momentum_cpu')
  41. # about different dtypes
  42. if not is_dtype_supported(dtypes.float64):
  43. backend_test.exclude('float64')
  44. backend_test.exclude('DOUBLE')
  45. # these have float64 inputs
  46. backend_test.exclude('test_eyelike_with_dtype_cpu')
  47. backend_test.exclude('test_reduce_log_sum_exp*')
  48. backend_test.exclude('test_operator_add*')
  49. backend_test.exclude('test_einsum_*')
  50. backend_test.exclude('test_cumsum_*')
  51. if not is_dtype_supported(dtypes.float16):
  52. backend_test.exclude('float16')
  53. backend_test.exclude('FLOAT16')
  54. # dtype cast
  55. backend_test.exclude('STRING')
  56. backend_test.exclude('FLOAT8')
  57. backend_test.exclude('INT4')
  58. backend_test.exclude('UINT4')
  59. backend_test.exclude('BFLOAT16') # not supported in numpy
  60. # TODO: fix these with true onnx float16
  61. backend_test.exclude('to_FLOAT16')
  62. backend_test.exclude('cast_no_saturate')
  63. backend_test.exclude('test_pow_types_int*')
  64. backend_test.exclude('test_convinteger_*')
  65. backend_test.exclude('test_matmulinteger_*')
  66. # we don't support indexes
  67. backend_test.exclude('test_nonzero_*')
  68. # no support for mod
  69. backend_test.exclude('test_mod_*')
  70. # no boolean ops (2d, 3d, 4d)
  71. backend_test.exclude('test_bitshift_*')
  72. # no string ops
  73. backend_test.exclude('string')
  74. backend_test.exclude('test_strnorm_*')
  75. backend_test.exclude('test_regex_*')
  76. # no scatternd gathernd
  77. backend_test.exclude('test_gathernd_*')
  78. backend_test.exclude('test_scatternd_*')
  79. # no quantize
  80. backend_test.exclude('test_dynamicquantizelinear_*')
  81. backend_test.exclude('test_qlinearmatmul_*')
  82. backend_test.exclude('test_qlinearconv_*')
  83. backend_test.exclude('test_quantizelinear_*')
  84. # no rnn
  85. backend_test.exclude('test_gru_*')
  86. backend_test.exclude('test_rnn_*')
  87. backend_test.exclude('test_lstm_*')
  88. backend_test.exclude('test_simple_rnn_*')
  89. # no control flow
  90. # control flow uses AttributeProto.GRAPH
  91. backend_test.exclude('test_if_*')
  92. backend_test.exclude('test_loop*')
  93. backend_test.exclude('test_range_float_type_positive_delta_expanded_cpu') # requires loop
  94. backend_test.exclude('test_affine_grid_2d_align_corners_expanded_cpu')
  95. backend_test.exclude('test_affine_grid_2d_expanded_cpu')
  96. backend_test.exclude('test_affine_grid_3d_align_corners_expanded_cpu')
  97. backend_test.exclude('test_affine_grid_3d_expanded_cpu')
  98. backend_test.exclude('test_range_int32_type_negative_delta_expanded_cpu')
  99. # unsupported (strange) ops
  100. backend_test.exclude('test_bitwise_*')
  101. backend_test.exclude('test_blackmanwindow_*')
  102. backend_test.exclude('test_bernoulli_*')
  103. backend_test.exclude('test_det_*')
  104. backend_test.exclude('test_col2im_*')
  105. backend_test.exclude('test_hammingwindow_*')
  106. backend_test.exclude('test_hannwindow_*')
  107. backend_test.exclude('test_hardmax_*')
  108. backend_test.exclude('test_gridsample_*')
  109. backend_test.exclude('test_dft_*')
  110. backend_test.exclude('test_einsum_batch_diagonal_cpu*') # TODO: equation = '...ii ->...i'
  111. backend_test.exclude('test_einsum_inner_prod_cpu*') # TODO: equation = 'i,i'
  112. backend_test.exclude('test_unique_*')
  113. backend_test.exclude('test_sequence_*')
  114. backend_test.exclude('test_nonmaxsuppression_*')
  115. backend_test.exclude('test_reversesequence_*')
  116. backend_test.exclude('test_roialign_*')
  117. backend_test.exclude('test_top_k_*')
  118. backend_test.exclude('test_tfidfvectorizer_*')
  119. backend_test.exclude('test_stft_*')
  120. backend_test.exclude('test_melweightmatrix_*')
  121. # more strange ops
  122. backend_test.exclude('test_basic_deform_conv_*')
  123. backend_test.exclude('test_deform_conv_*')
  124. backend_test.exclude('test_lppool_*')
  125. backend_test.exclude('test_scan*')
  126. backend_test.exclude('test_split_to_sequence_*')
  127. backend_test.exclude('test_resize_downsample_scales_cubic_*') # unsure how to implement cubic
  128. backend_test.exclude('test_resize_downsample_sizes_cubic_*') # unsure how to implement cubic
  129. backend_test.exclude('test_resize_upsample_scales_cubic_*') # unsure how to implement cubic
  130. backend_test.exclude('test_resize_upsample_sizes_cubic_*') # unsure how to implement cubic
  131. backend_test.exclude('test_ai_onnx_ml_tree_ensemble_*') # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/aionnxml/op_tree_ensemble.py#L121
  132. # rest of the failing tests
  133. backend_test.exclude('test_resize_downsample_scales_linear_antialias_cpu') # antialias not implemented
  134. backend_test.exclude('test_resize_downsample_sizes_linear_antialias_cpu') # antialias not implemented
  135. backend_test.exclude('test_resize_tf_crop_and_resize_cpu') # unsure about fill value after clip
  136. backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cpu') # bad data type string
  137. backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string
  138. backend_test.exclude('test_group_normalization_*') # numerical inaccuracy problem. Current Group Normalization OP fails test
  139. if Device.DEFAULT in ['GPU', 'METAL']:
  140. backend_test.exclude('test_resize_upsample_sizes_nearest_axes_2_3_cpu')
  141. backend_test.exclude('test_resize_upsample_sizes_nearest_axes_3_2_cpu')
  142. backend_test.exclude('test_resize_upsample_sizes_nearest_cpu')
  143. if Device.DEFAULT == "METAL" or (OSX and Device.DEFAULT == "GPU"):
  144. # numerical inaccuracy
  145. backend_test.exclude('test_mish_cpu')
  146. backend_test.exclude('test_mish_expanded_cpu')
  147. # disable model tests for now since they are slow
  148. if not getenv("MODELTESTS"):
  149. for x in backend_test.test_suite:
  150. if 'OnnxBackendRealModelTest' in str(type(x)):
  151. backend_test.exclude(str(x).split(" ")[0])
  152. else:
  153. # model tests all pass!
  154. backend_test.include('test_resnet50')
  155. backend_test.include('test_inception_v1')
  156. backend_test.include('test_inception_v2')
  157. backend_test.include('test_densenet121')
  158. backend_test.include('test_shufflenet')
  159. backend_test.include('test_squeezenet')
  160. backend_test.include('test_bvlc_alexnet')
  161. backend_test.include('test_zfnet512')
  162. backend_test.include('test_vgg19')
  163. globals().update(backend_test.enable_report().test_cases)
  164. if __name__ == '__main__':
  165. unittest.main()