
use resolve_tokenizer consistently

Alex Cheema, 8 months ago
commit 914ffa8d35
2 changed files with 5 additions and 12 deletions:
  1. exo/download/hf/hf_helpers.py (+2, -2)
  2. exo/inference/mlx/sharded_utils.py (+3, -10)

exo/download/hf/hf_helpers.py (+2, -2)

@@ -97,9 +97,9 @@ async def get_auth_headers():
  return {}


-def get_repo_root(repo_id: str) -> Path:
+def get_repo_root(repo_id: str | Path) -> Path:
  """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = repo_id.replace("/", "--")
+  sanitized_repo_id = str(repo_id).replace("/", "--")
  return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"

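Note: the widened str | Path signature (Python 3.10+ union syntax) lets callers pass a model path directly. A minimal standalone sketch of the resulting behavior; the get_hf_home default and the repo id below are illustrative assumptions, not taken from this diff:

from pathlib import Path

def get_hf_home() -> Path:
  # Assumption: the default Hugging Face cache location.
  return Path.home()/".cache"/"huggingface"

def get_repo_root(repo_id: str | Path) -> Path:
  """Get the root directory for a given repo ID in the Hugging Face cache."""
  # str() makes .replace() work whether the caller passes a str or a Path.
  sanitized_repo_id = str(repo_id).replace("/", "--")
  return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"

# Both call styles now resolve to the same directory (on POSIX, where
# str(Path("org/model")) preserves the forward slash):
get_repo_root("org/model")
get_repo_root(Path("org/model"))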

exo/inference/mlx/sharded_utils.py (+3, -10)

@@ -15,12 +15,12 @@ import base64

 import mlx.core as mx
 import mlx.nn as nn
-from transformers import AutoProcessor

 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tuner.utils import apply_lora_layers

 from exo import DEBUG
+from exo.inference.tokenizers import resolve_tokenizer
 from ..shard import Shard

@@ -171,15 +171,8 @@ async def load_shard(
     model = apply_lora_layers(model, adapter_path)
     model.eval()

-  # TODO: figure out a generic solution
-  if model.model_type == "llava":
-    processor = AutoProcessor.from_pretrained(model_path)
-    processor.eos_token_id = processor.tokenizer.eos_token_id
-    processor.encode = processor.tokenizer.encode
-    return model, processor
-  else:
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
-    return model, tokenizer
+  tokenizer = await resolve_tokenizer(model_path)
+  return model, tokenizer


 async def get_image_from_str(_image_str: str):
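Note: the llava special case moves behind resolve_tokenizer, so load_shard no longer branches on model type. The real implementation lives in exo/inference/tokenizers.py and is not part of this diff; the sketch below is only an assumed shape, an async wrapper that prefers a multimodal processor and falls back to a plain tokenizer:

import asyncio
from pathlib import Path
from transformers import AutoProcessor, AutoTokenizer

async def resolve_tokenizer(model_path: Path):
  # Tokenizer loading is blocking I/O, so keep it off the event loop.
  return await asyncio.to_thread(_resolve_tokenizer, model_path)

def _resolve_tokenizer(model_path: Path):
  try:
    # Multimodal models (e.g. llava) ship a processor wrapping a tokenizer;
    # expose the tokenizer's interface the way the removed branch did.
    processor = AutoProcessor.from_pretrained(model_path)
    if hasattr(processor, "tokenizer"):
      processor.eos_token_id = processor.tokenizer.eos_token_id
      processor.encode = processor.tokenizer.encode
    return processor
  except Exception:
    # Text-only models: fall back to the regular tokenizer.
    return AutoTokenizer.from_pretrained(model_path)

Under this assumed shape, awaiting the call in load_shard keeps the blocking load off the event loop, which would explain why the new call site is async.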