async model downloading fixes #30

Alex Cheema, 1 year ago
parent commit: a4cc667754
4 changed files with 37 additions and 51 deletions
  1. exo/inference/mlx/sharded_inference_engine.py (+1 -27)
  2. exo/inference/mlx/sharded_utils.py (+10 -4)
  3. exo/inference/test_inference_engine.py (+10 -10)
  4. exo/inference/tinygrad/inference.py (+16 -10)
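
The recurring change in every file below is the same: a blocking download call is wrapped in a small async helper that binds its arguments with functools.partial and hands the call to the event loop's default thread-pool executor, so model downloads no longer block the event loop. A minimal, self-contained sketch of that pattern (blocking_download is a stand-in for huggingface_hub.snapshot_download or tinygrad's fetch; none of these names come from the repo):

import asyncio
import time
from functools import partial
from typing import Optional

def blocking_download(repo_id: str, revision: Optional[str] = None) -> str:
    # Stand-in for a blocking call such as huggingface_hub.snapshot_download.
    time.sleep(1)
    return f"/models/{repo_id}"

async def download_async(repo_id: str, revision: Optional[str] = None) -> str:
    # Bind the arguments up front, then run the call in the default
    # ThreadPoolExecutor so the event loop stays responsive meanwhile.
    func = partial(blocking_download, repo_id, revision=revision)
    return await asyncio.get_event_loop().run_in_executor(None, func)

async def main() -> None:
    path = await download_async("mlx-community/Meta-Llama-3-8B-Instruct-4bit")
    print(path)

if __name__ == "__main__":
    asyncio.run(main())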

exo/inference/mlx/sharded_inference_engine.py (+1 -27)

@@ -6,32 +6,6 @@ from .sharded_utils import load_shard
 from ..shard import Shard
 from typing import Optional
 
-class MLXFixedShardInferenceEngine(InferenceEngine):
-    def __init__(self, model_path: str, shard: Shard):
-        self.shard = shard
-        model_shard, self.tokenizer = load_shard(model_path, shard)
-        self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
-
-    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-        if shard != self.shard:
-            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
-
-        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
-        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
-
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, str, bool):
-        if shard != self.shard:
-            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
-
-        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
-        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
-
-    async def reset_shard(self, shard: Shard):
-        if shard != self.shard:
-            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
-
-        self.stateful_sharded_model.reset()
-
 class MLXDynamicShardInferenceEngine(InferenceEngine):
     def __init__(self):
         self.shard = None
@@ -54,6 +28,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
         if self.shard == shard:
             return
 
-        model_shard, self.tokenizer = load_shard(shard.model_id, shard)
+        model_shard, self.tokenizer = await load_shard(shard.model_id, shard)
         self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
         self.shard = shard
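
Because load_shard is now a coroutine, ensure_shard has to await it, and every caller above it must itself be async. A small illustrative sketch of that propagation, using stub names rather than the real exo classes:

import asyncio

async def load_shard_stub(model_id: str) -> str:
    # Stand-in for the now-async load_shard in sharded_utils.py.
    await asyncio.sleep(0)
    return f"loaded {model_id}"

class EngineStub:
    def __init__(self):
        self.shard = None

    async def ensure_shard(self, model_id: str) -> None:
        # Declared async so it can await the async loader; a plain `def`
        # here would force a blocking wait and defeat the change.
        if self.shard == model_id:
            return
        self.shard = await load_shard_stub(model_id)

asyncio.run(EngineStub().ensure_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit"))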

exo/inference/mlx/sharded_utils.py (+10 -4)

@@ -4,6 +4,8 @@ import glob
 import importlib
 import json
 import logging
+import asyncio
+from functools import partial
 from pathlib import Path
 from typing import Optional, Tuple
 
@@ -151,7 +153,11 @@ def load_model_shard(
     model.eval()
     return model
 
-def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
+async def snapshot_download_async(*args, **kwargs):
+    func = partial(snapshot_download, *args, **kwargs)
+    return await asyncio.get_event_loop().run_in_executor(None, func)
+
+async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
     """
     Ensures the model is available locally. If the path does not exist locally,
     it is downloaded from the Hugging Face Hub.
@@ -167,7 +173,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
     if not model_path.exists():
         try:
             model_path = Path(
-                snapshot_download(
+                await snapshot_download_async(
                     repo_id=path_or_hf_repo,
                     revision=revision,
                     allow_patterns=[
@@ -191,7 +197,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
     return model_path
 
 
-def load_shard(
+async def load_shard(
     path_or_hf_repo: str,
     shard: Shard,
     tokenizer_config={},
@@ -220,7 +226,7 @@ def load_shard(
         FileNotFoundError: If config file or safetensors are not found.
         ValueError: If model class or args class are not found.
     """
-    model_path = get_model_path(path_or_hf_repo)
+    model_path = await get_model_path(path_or_hf_repo)
 
     model = load_model_shard(model_path, shard, lazy, model_config)
     if adapter_path is not None:
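
The snapshot_download_async helper above applies the run_in_executor pattern to the Hugging Face snapshot download. As an aside (not what the commit uses), on Python 3.9+ asyncio.to_thread achieves the same offloading with slightly less ceremony; a hedged sketch, assuming huggingface_hub is installed as it already is for sharded_utils.py:

import asyncio
from huggingface_hub import snapshot_download

async def snapshot_download_to_thread(repo_id: str, **kwargs) -> str:
    # Same effect as the run_in_executor wrapper: the blocking download runs
    # in a worker thread and this coroutine suspends until it finishes.
    return await asyncio.to_thread(snapshot_download, repo_id, **kwargs)

async def main() -> None:
    # allow_patterns values here are illustrative, not copied from the repo.
    path = await snapshot_download_to_thread(
        "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
        allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
    )
    print(path)

if __name__ == "__main__":
    asyncio.run(main())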

exo/inference/test_inference_engine.py (+10 -10)

@@ -24,15 +24,15 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
     assert np.array_equal(resp_full, resp2)
     assert np.array_equal(next_resp_full, resp4)
 
-asyncio.run(test_inference_engine(
-    MLXDynamicShardInferenceEngine(),
-    MLXDynamicShardInferenceEngine(),
-    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-))
-
-# TODO: Waiting on https://github.com/tinygrad/tinygrad/issues/5549
 # asyncio.run(test_inference_engine(
-#     TinygradDynamicShardInferenceEngine(),
-#     TinygradDynamicShardInferenceEngine(),
-#     "llama3-8b-sfr",
+#     MLXDynamicShardInferenceEngine(),
+#     MLXDynamicShardInferenceEngine(),
+#     "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
 # ))
+
+# TODO: Waiting on https://github.com/tinygrad/tinygrad/issues/5549
+asyncio.run(test_inference_engine(
+    TinygradDynamicShardInferenceEngine(),
+    TinygradDynamicShardInferenceEngine(),
+    "llama3-8b-sfr",
+))

exo/inference/tinygrad/inference.py (+16 -10)

@@ -1,16 +1,18 @@
-
+import asyncio
+from functools import partial
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Union
 import json, argparse, random, time
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
 from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
-from tinygrad.helpers import DEBUG, tqdm, _cache_dir
+from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
 import numpy as np
+import os
 
 MODEL_PARAMS = {
   "8B": {
@@ -58,6 +60,11 @@ class Tokenizer:
     return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())
 
 # **** helper functions ****
+async def fetch_async(url: str, name: Optional[Union[Path, str]] = None, subdir: Optional[str] = None,
+                      allow_caching=not os.getenv("DISABLE_HTTP_CACHE")) -> Path:
+    func = partial(fetch, url, name, subdir, allow_caching)
+    return await asyncio.get_event_loop().run_in_executor(None, func)
+
 def concat_weights(models, device=None):
   def convert(name) -> Tensor:
     disk_tensors: List[Tensor] = [model[name] for model in models]
@@ -176,16 +183,15 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
         if Path(model_path / "model.safetensors.index.json").exists():
             model = model_path
         else:
-            from tinygrad.helpers import fetch
 
             if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
             if shard.model_id.lower().find("llama3-8b-sfr") != -1:
-                fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
-                fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
-                fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
-                fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
-                fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
-                model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
+                await fetch_async("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
+                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
+                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
+                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
+                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
+                model = await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
                 size = "8B"
             elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
                 raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
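
Even after this change, ensure_shard still awaits the five fetch_async calls one at a time. Since each call runs in its own worker thread, a natural follow-up (not part of this commit) would be to overlap the shard downloads with asyncio.gather. A self-contained sketch with a stub in place of tinygrad's fetch and an illustrative URL:

import asyncio
import time
from functools import partial

def fetch_blocking(url: str) -> str:
    # Stub for tinygrad.helpers.fetch: pretends to download and returns a name.
    time.sleep(0.5)
    return url.rsplit("/", 1)[-1]

async def fetch_async(url: str) -> str:
    return await asyncio.get_event_loop().run_in_executor(None, partial(fetch_blocking, url))

async def main() -> None:
    # gather schedules all four downloads at once; each occupies its own
    # executor thread, so they overlap instead of running back to back.
    names = await asyncio.gather(*(
        fetch_async(f"https://example.com/model-0000{i}-of-00004.safetensors")
        for i in range(1, 5)
    ))
    print(names)

if __name__ == "__main__":
    asyncio.run(main())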