|
@@ -235,10 +235,10 @@ async def load_shard(
|
|
|
model = apply_lora_layers(model, adapter_path)
|
|
|
model.eval()
|
|
|
|
|
|
- # TODO: figure out a better way
|
|
|
- if "llama" in str(model_path):
|
|
|
- tokenizer = load_tokenizer(model_path, tokenizer_config)
|
|
|
- return model, tokenizer
|
|
|
- elif "llava" in str(model_path):
|
|
|
+ # TODO: figure out a generic solution
|
|
|
+ if model.model_type == "llava":
|
|
|
processor = AutoProcessor.from_pretrained(model_path)
|
|
|
- return model, processor
|
|
|
+ return model, processor
|
|
|
+ else:
|
|
|
+ tokenizer = load_tokenizer(model_path, tokenizer_config)
|
|
|
+ return model, tokenizer
|