Ver código fonte

fix model import path

Alex Cheema 1 ano atrás
pai
commit
a04974168e
1 arquivos alterados com 1 adições e 1 exclusões
  1. 1 1
      exo/inference/mlx/sharded_utils.py

+ 1 - 1
exo/inference/mlx/sharded_utils.py

@@ -42,7 +42,7 @@ def _get_classes(config: dict):
     model_type = config["model_type"]
     model_type = MODEL_REMAPPING.get(model_type, model_type)
     try:
-        arch = importlib.import_module(f"inference.mlx.models.{model_type}")
+        arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
     except ImportError:
         msg = f"Model type {model_type} not supported."
         logging.error(msg)