compile.py 1.1 KB

1234567891011121314151617181920212223
  1. from pathlib import Path
  2. from examples.yolov8 import YOLOv8
  3. from tinygrad.tensor import Tensor
  4. from tinygrad.nn.state import safe_save
  5. from extra.export_model import export_model
  6. from tinygrad.helpers import fetch
  7. from tinygrad.helpers import getenv
  8. from tinygrad.device import Device
  9. from tinygrad.nn.state import safe_load, load_state_dict
  10. if __name__ == "__main__":
  11. Device.DEFAULT = "WEBGL"
  12. yolo_variant = 'n'
  13. yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
  14. weights_location = Path(__file__).parents[1] / "weights" / f'yolov8{yolo_variant}.safetensors'
  15. fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors', weights_location)
  16. state_dict = safe_load(weights_location)
  17. load_state_dict(yolo_infer, state_dict)
  18. prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,640,640))
  19. dirname = Path(__file__).parent
  20. safe_save(state, (dirname / "net.safetensors").as_posix())
  21. with open(dirname / f"net.js", "w") as text_file:
  22. text_file.write(prg)