
standardise tinygrad models/tokenizers so it can handle mlx hf

Alex Cheema 9 months ago
parent commit 55bcad98e3
2 changed files with 31 additions and 59 deletions
  1. exo/inference/tinygrad/inference.py (+23 -56)
  2. exo/inference/tinygrad/models/llama.py (+8 -3)
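In short: the tinygrad engine drops its hand-rolled tiktoken Tokenizer class and the SFR-Iterative-DPO weight downloads, fetches the same Hugging Face style artifacts (config.json, tokenizer.json, tokenizer_config.json, special_tokens_map.json, sharded safetensors) that the mlx-community repos ship, loads the tokenizer with transformers.AutoTokenizer, and detects end-of-sequence via eos_token_id. A minimal sketch of the new tokenizer path (not the literal engine code; model_path is an illustrative local directory containing the downloaded files, and transformers is assumed to be installed):

    from pathlib import Path
    from transformers import AutoTokenizer

    # illustrative path; in the engine this is the shard's download directory
    model_path = Path("downloads") / "llama3-8b-sfr"

    # AutoTokenizer picks up tokenizer.json / tokenizer_config.json from the directory,
    # replacing the previous tiktoken-based Tokenizer class
    tokenizer = AutoTokenizer.from_pretrained(str(model_path))

    toks = tokenizer.encode("hello world")
    print(tokenizer.decode(toks))

    # stop detection now compares a single generated token against eos_token_id
    # (the old code checked membership in a stop_tokens set)
    finished = toks[-1] == tokenizer.eos_token_id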

+ 23 - 56
exo/inference/tinygrad/inference.py

@@ -44,42 +44,6 @@ MODEL_PARAMS = {
 }


-class Tokenizer:
-  pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
-
-  def __init__(self, model_path: str):
-    mergeable_ranks = load_tiktoken_bpe(model_path)
-    self.num_base_tokens = len(mergeable_ranks)
-    special_tokens = [
-      "<|begin_of_text|>",
-      "<|end_of_text|>",
-      "<|reserved_special_token_0|>",
-      "<|reserved_special_token_1|>",
-      "<|reserved_special_token_2|>",
-      "<|reserved_special_token_3|>",
-      "<|start_header_id|>",
-      "<|end_header_id|>",
-      "<|reserved_special_token_4|>",
-      "<|eot_id|>",
-    ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
-    self.special_tokens = {token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}
-
-    self.model = tiktoken.Encoding(name=model_path, pat_str=self.pat_str, mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens)
-
-  @property
-  def bos_id(self):
-    return self.special_tokens["<|begin_of_text|>"]
-
-  @property
-  def stop_tokens(self):
-    return {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]}
-
-  def decode(self, toks):
-    return self.model.decode([t for t in toks if t < self.num_base_tokens])
-
-  def encode(self, text, allow_special=False):
-    return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())
-

 # **** helper functions ****
 async def fetch_async(
@@ -214,7 +178,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     return (
       output_data,
       json.dumps({"start_pos": start_pos}),
-      output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens,
+      output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id],
     )

   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
@@ -228,7 +192,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     return (
       output_data,
       json.dumps({"start_pos": start_pos}),
-      output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens,
+      output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id],
     )

   async def ensure_shard(self, shard: Shard):
@@ -239,40 +203,42 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     models_dir = Path(_cache_dir) / "tinygrad" / "downloads"
     model_path = models_dir / shard.model_id
     size = "8B"
-    if Path(model_path / "model.safetensors.index.json").exists():
+    if Path(model_path / "tokenizer_config.json").exists():
       model = model_path
     else:

       if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
       if shard.model_id.lower().find("llama3-8b-sfr") != -1:
+        num_files = 4
+        for i in range(num_files):
+          await fetch_async(
+            f"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/model-{(i+1):05d}-of-{num_files:05d}.safetensors",
+            f"model-{(i+1):05d}-of-{num_files:05d}.safetensors",
+            subdir=shard.model_id,
+          )
         await fetch_async(
-          "https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model",
-          "tokenizer.model",
+          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/config.json",
+          "config.json",
           subdir=shard.model_id,
         )
-        await fetch_async(
-          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors",
-          "model-00001-of-00004.safetensors",
+        model = await fetch_async(
+          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/raw/main/model.safetensors.index.json",
+          "model.safetensors.index.json",
           subdir=shard.model_id,
         )
         await fetch_async(
-          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors",
-          "model-00002-of-00004.safetensors",
+          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/special_tokens_map.json",
+          "special_tokens_map.json",
           subdir=shard.model_id,
         )
         await fetch_async(
-          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors",
-          "model-00003-of-00004.safetensors",
+          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json",
+          "tokenizer.json",
           subdir=shard.model_id,
         )
         await fetch_async(
-          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors",
-          "model-00004-of-00004.safetensors",
-          subdir=shard.model_id,
-        )
-        model = await fetch_async(
-          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json",
-          "model.safetensors.index.json",
+          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer_config.json",
+          "tokenizer_config.json",
           subdir=shard.model_id,
         )
         size = "8B"
@@ -289,7 +255,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
         raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}")

     model = build_transformer(model_path, shard=shard, model_size=size)
-    tokenizer = Tokenizer(str((model_path if model_path.is_dir() else model_path.parent) / "tokenizer.model"))
+    from transformers import AutoTokenizer
+    tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
 
 
     self.shard = shard
     self.model = model
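For the llama3-8b-sfr shard, ensure_shard now pulls the weights and tokenizer files from mlx-community/Meta-Llama-3-8B-Instruct in standard Hugging Face layout: four zero-padded safetensors shards plus the JSON metadata files. A small sketch of the file list that download loop produces (illustrative only; in the diff the index json is fetched from raw/main rather than resolve/main):

    repo = "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct"
    num_files = 4

    # sharded weights: model-00001-of-00004.safetensors ... model-00004-of-00004.safetensors
    weight_files = [f"model-{i + 1:05d}-of-{num_files:05d}.safetensors" for i in range(num_files)]

    # HF-format metadata the tinygrad engine now expects alongside the weights
    aux_files = ["config.json", "model.safetensors.index.json",
                 "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"]

    for name in weight_files + aux_files:
      print(f"{repo}/resolve/main/{name}")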

+ 8 - 3
exo/inference/tinygrad/models/llama.py

@@ -214,10 +214,8 @@ class Transformer:
       h = self.tok_embeddings(h)
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos + 1).realize() if seqlen > 1 else None

-    for i, layer in enumerate(self.layers):
+    for layer in self.layers:
       h = layer(h, start_pos, freqs_cis, mask)
-      # if i == 0 or i == len(self.layers) - 1:
-      #   print(f"layer {i}: {str(h.numpy())[:60]}")

     if self.shard.is_last_layer():
       logits = self.output(self.norm(h)).float()[:, -1, :]
@@ -257,10 +255,17 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
     "model.embed_tokens.weight": "tok_embeddings.weight",
     **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
     **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
+    **{f"model.layers.{l}.self_attn.{x}_proj.biases": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
+    **{f"model.layers.{l}.self_attn.{x}_proj.scales": f"layers.{l}.attention.w{x}.scale" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
     **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
+    **{f"model.layers.{l}.post_attention_layernorm.biases": f"layers.{l}.ffn_norm.bias" for l in range(len(model.layers))},
     **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
+    **{f"model.layers.{l}.mlp.{x}_proj.biases": f"layers.{l}.feed_forward.w{y}.bias" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
+    **{f"model.layers.{l}.mlp.{x}_proj.scales": f"layers.{l}.feed_forward.w{y}.scale" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
     "model.norm.weight": "norm.weight",
     "lm_head.weight": "output.weight",
+    "lm_head.biases": "output.bias",
+    "lm_head.scales": "output.scale",
   }
   sd = {}
   for k, v in weights.items():
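The extra biases/scales entries exist because the mlx-community checkpoints can be quantized, storing each linear layer as a weight plus group-wise scales and biases; the keymap renames those tensors into the corresponding tinygrad-side scale/bias slots. A tiny sketch of the renaming step, using an illustrative subset of the mapping and placeholder strings in place of tensors:

    # illustrative subset of the keymap in convert_from_huggingface
    keymap = {
      "model.layers.0.self_attn.q_proj.weight": "layers.0.attention.wq.weight",
      "model.layers.0.self_attn.q_proj.scales": "layers.0.attention.wq.scale",
      "model.layers.0.self_attn.q_proj.biases": "layers.0.attention.wq.bias",
      "lm_head.weight": "output.weight",
      "lm_head.scales": "output.scale",
      "lm_head.biases": "output.bias",
    }

    weights = {k: f"<tensor for {k}>" for k in keymap}  # placeholder values
    sd = {keymap[k]: v for k, v in weights.items()}
    print(sorted(sd))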