소스 검색

processor load

Varshith 1 년 전
부모
커밋
2849128d6a
1개의 변경된 파일6개의 추가작업 그리고 6개의 파일을 삭제
  1. 6 6
      exo/inference/mlx/sharded_utils.py

+ 6 - 6
exo/inference/mlx/sharded_utils.py

@@ -235,10 +235,10 @@ async def load_shard(
         model = apply_lora_layers(model, adapter_path)
         model.eval()
 
-    # TODO: figure out a better way
-    if "llama" in str(model_path):
-        tokenizer = load_tokenizer(model_path, tokenizer_config)
-        return model, tokenizer
-    elif "llava" in str(model_path):
+    # TODO: figure out a generic solution
+    if model.model_type == "llava":
         processor = AutoProcessor.from_pretrained(model_path)
-        return model, processor
+        return model, processor
+    else:
+        tokenizer = load_tokenizer(model_path, tokenizer_config)
+        return model, tokenizer