tokenizers.py

import traceback
from os import PathLike
from typing import Union

import numpy as np
from aiofiles import os as aios
from transformers import AutoTokenizer, AutoProcessor

from exo.download.hf.hf_helpers import get_local_snapshot_dir
from exo.helpers import DEBUG


class DummyTokenizer:
  """Minimal stand-in tokenizer returned for the "dummy" model id (used in tests)."""
  def __init__(self):
    self.eos_token_id = 69
    self.vocab_size = 1000

  def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
    return "dummy_tokenized_prompt"

  def encode(self, text):
    return np.array([1])

  def decode(self, tokens):
    return "dummy" * len(tokens)


async def resolve_tokenizer(model_id: str):
  if model_id == "dummy":
    return DummyTokenizer()
  # Prefer a locally downloaded snapshot; otherwise resolve by model id.
  local_path = await get_local_snapshot_dir(model_id)
  if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
  try:
    if local_path and await aios.path.exists(local_path):
      if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
      return await _resolve_tokenizer(local_path)
  except Exception:
    if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
    if DEBUG >= 5: traceback.print_exc()
  return await _resolve_tokenizer(model_id)


async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
  try:
    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
    processor = AutoProcessor.from_pretrained(
      model_id_or_local_path,
      use_fast="Mistral-Large" in f"{model_id_or_local_path}",  # fast tokenizer only for Mistral-Large
      trust_remote_code=True,
    )
    # An AutoProcessor wraps an inner tokenizer. Forward the attributes callers
    # expect so a processor can be used interchangeably with a plain tokenizer.
    inner = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor))
    if not hasattr(processor, 'eos_token_id'):
      processor.eos_token_id = inner.eos_token_id
    if not hasattr(processor, 'encode'):
      processor.encode = inner.encode
    if not hasattr(processor, 'decode'):
      processor.decode = inner.decode
    return processor
  except Exception as e:
    if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
    if DEBUG >= 4: print(traceback.format_exc())

  try:
    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
    return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
  except Exception as e:
    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
    if DEBUG >= 4: print(traceback.format_exc())

  raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
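
For reference, a minimal usage sketch (not part of the file): resolve_tokenizer is a coroutine and must be awaited. The import path exo.inference.tokenizers is assumed from the repo layout; the "dummy" id short-circuits to DummyTokenizer, so the example runs without any model download.

import asyncio

from exo.inference.tokenizers import resolve_tokenizer  # import path assumed

async def main():
  # "dummy" returns DummyTokenizer without touching the network; a real
  # Hugging Face model id would resolve an AutoProcessor/AutoTokenizer instead.
  tokenizer = await resolve_tokenizer("dummy")
  tokens = tokenizer.encode("hello")  # -> np.array([1])
  print(tokenizer.eos_token_id)       # -> 69
  print(tokenizer.decode(tokens))     # -> "dummy"

asyncio.run(main())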