external_test_yolov8.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import numpy as np
  2. from examples.yolov8 import YOLOv8, get_variant_multiples, preprocess, postprocess, label_predictions
  3. import unittest
  4. import io, cv2
  5. import onnxruntime as ort
  6. import ultralytics
  7. from tinygrad.nn.state import safe_load, load_state_dict
  8. from tinygrad.helpers import fetch
  9. class TestYOLOv8(unittest.TestCase):
  10. def test_all_load_weights(self):
  11. for variant in ['n', 's', 'm', 'l', 'x']:
  12. depth, width, ratio = get_variant_multiples(variant)
  13. TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
  14. state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors'))
  15. load_state_dict(TinyYolov8, state_dict)
  16. print(f'successfully loaded weights for yolov{variant}')
  17. def test_predictions(self):
  18. test_image_urls = ['https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg', 'https://www.aljazeera.com/wp-content/uploads/2022/10/2022-04-28T192650Z_1186456067_UP1EI4S1I0P14_RTRMADP_3_SOCCER-ENGLAND-MUN-CHE-REPORT.jpg']
  19. variant = 'n'
  20. depth, width, ratio = get_variant_multiples(variant)
  21. TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
  22. state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors'))
  23. load_state_dict(TinyYolov8, state_dict)
  24. for i in range(len(test_image_urls)):
  25. img = cv2.imdecode(np.frombuffer(fetch(test_image_urls[i]).read_bytes(), np.uint8), 1)
  26. test_image = preprocess([img])
  27. predictions = TinyYolov8(test_image)
  28. post_predictions = postprocess(preds=predictions, img=test_image, orig_imgs=[img])
  29. labels = label_predictions(post_predictions)
  30. assert labels == {5: 1, 0: 4, 11: 1} if i == 0 else labels == {0: 13, 29: 1, 32: 1}
  31. def test_forward_pass_torch_onnx(self):
  32. variant = 'n'
  33. weights_location = fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors')
  34. weights_location_pt = fetch(f'https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8{variant}.pt', name=f"yolov8{variant}.pt") # it needs the pt extension # noqa: E501
  35. weights_location_onnx = weights_location_pt.parent / f"yolov8{variant}.onnx"
  36. # the ultralytics export prints a lot of unneccesary things
  37. if not weights_location_onnx.is_file():
  38. model = ultralytics.YOLO(model=weights_location_pt, task='Detect')
  39. model.export(format="onnx",imgsz=[640, 480])
  40. depth, width, ratio = get_variant_multiples(variant)
  41. TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
  42. state_dict = safe_load(weights_location)
  43. load_state_dict(TinyYolov8, state_dict)
  44. image_location = [np.frombuffer(io.BytesIO(fetch('https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg').read_bytes()).read(), np.uint8)] # noqa: E501
  45. orig_image = [cv2.imdecode(image_location[0], 1)]
  46. input_image = preprocess(orig_image)
  47. onnx_session = ort.InferenceSession(weights_location_onnx)
  48. onnx_input_name = onnx_session.get_inputs()[0].name
  49. onnx_output_name = onnx_session.get_outputs()[0].name
  50. onnx_output = onnx_session.run([onnx_output_name], {onnx_input_name: input_image.numpy()})
  51. tiny_output = TinyYolov8(input_image)
  52. # currently rtol is 0.025 because there is a 1-2% difference in our predictions
  53. # because of the zero padding in SPPF module (line 280) maxpooling layers rather than the -infinity in torch.
  54. # This difference does not make a difference "visually".
  55. np.testing.assert_allclose(onnx_output[0], tiny_output.numpy(), atol=5e-4, rtol=0.025)
  56. if __name__ == '__main__':
  57. unittest.main()