
First try loading the tokenizer from a local path instead of always going to the internet. Significant speed-ups.

Alex Cheema, 1 year ago
commit 59c4393d95

exo/download/hf/hf_helpers.py (+10 -0)

@@ -17,6 +17,16 @@ from aiofiles import os as aios
 
 T = TypeVar("T")
 
+async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
+  refs_dir = get_repo_root(repo_id)/"refs"
+  refs_file = refs_dir/revision
+  if await aios.path.exists(refs_file):
+    async with aiofiles.open(refs_file, 'r') as f:
+      commit_hash = (await f.read()).strip()
+      snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
+      return snapshot_dir
+  return None
+
 
 def filter_repo_objects(
   items: Iterable[T],

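The new helper reads the standard Hugging Face hub cache layout: refs/<revision> is a small text file holding a commit hash, and snapshots/<commit_hash>/ contains the downloaded files for that revision. A minimal usage sketch (the model id is only illustrative, and the exact cache root comes from get_repo_root):

    # refs/main              -> text file containing a commit hash, e.g. "abc123..."
    # snapshots/abc123.../   -> config.json, tokenizer files, weights, ...
    snapshot_dir = await get_local_snapshot_dir("mlx-community/Meta-Llama-3-8B-Instruct-4bit")
    if snapshot_dir is not None and await aios.path.exists(snapshot_dir):
      # everything needed is already on disk; no network access required
      load_from(snapshot_dir)  # hypothetical consumer, not part of this commit
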
exo/download/hf/hf_shard_download.py (+4 -2)

@@ -24,8 +24,10 @@ class HFShardDownloader(ShardDownloader):
       repo_root = get_repo_root(shard.model_id)
       snapshots_dir = repo_root/"snapshots"
       if snapshots_dir.exists():
-        most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
-        return most_recent_dir
+        visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
+        if visible_dirs:
+          most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime)
+          return most_recent_dir
 
     # If a download on this shard is already in progress, keep that one
     for active_shard in self.active_downloads:

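Without the filter, snapshots_dir.iterdir() can surface hidden entries (for example a stray .DS_Store or a dot-prefixed partial directory; these names are hypothetical), and the bare max(...) would return one of them as the snapshot path. With the guard, an empty or hidden-only snapshots directory falls through to the normal download path instead. A small sketch of the behaviour with made-up directory contents:

    # Suppose snapshots/ contains only ".tmp_partial" (made-up name):
    visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
    # visible_dirs == [], so ensure_shard continues to the download logic
    # rather than returning a hidden entry as the model path.
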
exo/inference/tinygrad/inference.py (+2 -1)

@@ -3,6 +3,7 @@ import json
 import os
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from exo.inference.shard import Shard
+from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
 from tinygrad import Tensor, dtypes, nn, Context
 from transformers import AutoTokenizer
@@ -90,5 +91,5 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
 
     model_path = await self.shard_downloader.ensure_shard(shard)
     self.model = build_transformer(model_path, shard, model_size="8B" if "8b" in shard.model_id.lower() else "70B")
-    self.tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
+    self.tokenizer = await resolve_tokenizer(str((model_path if model_path.is_dir() else model_path.parent)))
     self.shard = shard

exo/inference/tokenizers.py (+17 -8)

@@ -1,12 +1,21 @@
 import traceback
+from aiofiles import os as aios
 from transformers import AutoTokenizer, AutoProcessor
+from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
 
-
 async def resolve_tokenizer(model_id: str):
+  local_path = await get_local_snapshot_dir(model_id)
+  if local_path and await aios.path.exists(local_path):
+    if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
+    return await _resolve_tokenizer(local_path)
+  if DEBUG >= 2: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
+  return await _resolve_tokenizer(model_id)
+
+async def _resolve_tokenizer(model_id_or_local_path: str):
   try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
-    processor = AutoProcessor.from_pretrained(model_id, use_fast=True if "Mistral-Large" in model_id else False)
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
+    processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in model_id_or_local_path else False)
     if not hasattr(processor, 'eos_token_id'):
       processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
     if not hasattr(processor, 'encode'):
@@ -15,14 +24,14 @@ async def resolve_tokenizer(model_id: str):
       processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())
 
   try:
-    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
-    return AutoTokenizer.from_pretrained(model_id)
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
+    return AutoTokenizer.from_pretrained(model_id_or_local_path)
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())
 
-  raise ValueError(f"[TODO] Unsupported model: {model_id}")
+  raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
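
Putting it together, resolve_tokenizer now checks the local snapshot directory before falling back to the original network-based lookup. A hedged usage sketch (the model id is only an example and may not be cached on a given machine):

    import asyncio
    from exo.inference.tokenizers import resolve_tokenizer

    async def main():
      # Uses the cached snapshot if one exists; otherwise resolves as before.
      tokenizer = await resolve_tokenizer("mlx-community/Meta-Llama-3-8B-Instruct-4bit")
      print(tokenizer.encode("hello"))

    asyncio.run(main())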