|
@@ -42,7 +42,7 @@ def _get_classes(config: dict):
|
|
|
model_type = config["model_type"]
|
|
|
model_type = MODEL_REMAPPING.get(model_type, model_type)
|
|
|
try:
|
|
|
- arch = importlib.import_module(f"inference.mlx.models.{model_type}")
|
|
|
+ arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
|
|
|
except ImportError:
|
|
|
msg = f"Model type {model_type} not supported."
|
|
|
logging.error(msg)
|