Переглянути джерело

fix Mistral-Large special case when we pass in a path

Alex Cheema 8 місяців тому
батько
коміт
01cc6a4c9d
1 змінених файлів з 5 додано та 2 видалено
  1. 5 2
      exo/inference/tokenizers.py

+ 5 - 2
exo/inference/tokenizers.py

@@ -1,5 +1,8 @@
 import traceback
 from aiofiles import os as aios
+from os import PathLike
+from pathlib import Path
+from typing import Union
 from transformers import AutoTokenizer, AutoProcessor
 from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
@@ -16,10 +19,10 @@ async def resolve_tokenizer(model_id: str):
     if DEBUG >= 5: traceback.print_exc()
   return await _resolve_tokenizer(model_id)
 
-async def _resolve_tokenizer(model_id_or_local_path: str):
+async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
   try:
     if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
-    processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in model_id_or_local_path else False)
+    processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False)
     if not hasattr(processor, 'eos_token_id'):
       processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
     if not hasattr(processor, 'encode'):