Browse Source

fix model import path

Alex Cheema 1 year ago
parent
commit
a04974168e
1 changed files with 1 additions and 1 deletions
  1. 1 1
      exo/inference/mlx/sharded_utils.py

+ 1 - 1
exo/inference/mlx/sharded_utils.py

@@ -42,7 +42,7 @@ def _get_classes(config: dict):
     model_type = config["model_type"]
     model_type = MODEL_REMAPPING.get(model_type, model_type)
     try:
-        arch = importlib.import_module(f"inference.mlx.models.{model_type}")
+        arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
     except ImportError:
         msg = f"Model type {model_type} not supported."
         logging.error(msg)