Răsfoiți Sursa

processor load

Varshith 1 an în urmă
părinte
comite
2849128d6a
1 a modificat fișierele cu 6 adăugiri și 6 ștergeri
  1. 6 6
      exo/inference/mlx/sharded_utils.py

+ 6 - 6
exo/inference/mlx/sharded_utils.py

@@ -235,10 +235,10 @@ async def load_shard(
         model = apply_lora_layers(model, adapter_path)
         model.eval()
 
-    # TODO: figure out a better way
-    if "llama" in str(model_path):
-        tokenizer = load_tokenizer(model_path, tokenizer_config)
-        return model, tokenizer
-    elif "llava" in str(model_path):
+    # TODO: figure out a generic solution
+    if model.model_type == "llava":
         processor = AutoProcessor.from_pretrained(model_path)
-        return model, processor
+        return model, processor
+    else:
+        tokenizer = load_tokenizer(model_path, tokenizer_config)
+        return model, tokenizer