@@ -1,12 +1,23 @@
 import traceback
+from aiofiles import os as aios
 from transformers import AutoTokenizer, AutoProcessor
+from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
 
-
 async def resolve_tokenizer(model_id: str):
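+  # Prefer a local HF snapshot when one exists; otherwise fall through to normal resolution.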
+  local_path = await get_local_snapshot_dir(model_id)
+  if await aios.path.exists(local_path):
+    if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
+    return await _resolve_tokenizer(local_path)
+  if DEBUG >= 2: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
+  return await _resolve_tokenizer(model_id)
+
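+# Shared resolver: accepts either a Hugging Face model id or a local snapshot path.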
+async def _resolve_tokenizer(model_id_or_local_path: str):
   try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
-    processor = AutoProcessor.from_pretrained(model_id, use_fast=True if "Mistral-Large" in model_id else False)
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
+    processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in model_id_or_local_path else False)
     if not hasattr(processor, 'eos_token_id'):
       processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
     if not hasattr(processor, 'encode'):
@@ -15,14 +26,15 @@ async def resolve_tokenizer(model_id: str):
       processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())
 
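+  # AutoProcessor failed above; fall back to a plain AutoTokenizer.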
   try:
-    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
-    return AutoTokenizer.from_pretrained(model_id)
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
+    return AutoTokenizer.from_pretrained(model_id_or_local_path)
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())
 
-  raise ValueError(f"[TODO] Unsupported model: {model_id}")
+  raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
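
Usage sketch (illustrative only, not part of this patch; the module path and model id below are assumptions):

  import asyncio
  from exo.inference.tokenizers import resolve_tokenizer  # assumed module path

  async def main():
    # Resolves from the local HF snapshot when present, otherwise resolves the model id normally.
    tokenizer = await resolve_tokenizer("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
    ids = tokenizer.encode("hello world")
    print(tokenizer.decode(ids))

  asyncio.run(main())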