async model downloading fixes #30

Alex Cheema, 1 year ago
Parent
Current commit
a4cc667754

+ 1 - 27
exo/inference/mlx/sharded_inference_engine.py

@@ -6,32 +6,6 @@ from .sharded_utils import load_shard
 from ..shard import Shard
 from typing import Optional
 
-class MLXFixedShardInferenceEngine(InferenceEngine):
-    def __init__(self, model_path: str, shard: Shard):
-        self.shard = shard
-        model_shard, self.tokenizer = load_shard(model_path, shard)
-        self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
-
-    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-        if shard != self.shard:
-            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
-
-        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
-        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
-
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, str, bool):
-        if shard != self.shard:
-            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
-
-        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
-        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
-
-    async def reset_shard(self, shard: Shard):
-        if shard != self.shard:
-            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
-
-        self.stateful_sharded_model.reset()
-
 class MLXDynamicShardInferenceEngine(InferenceEngine):
     def __init__(self):
         self.shard = None
@@ -54,6 +28,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
         if self.shard == shard:
             return
 
-        model_shard, self.tokenizer = load_shard(shard.model_id, shard)
+        model_shard, self.tokenizer = await load_shard(shard.model_id, shard)
         self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
         self.shard = shard
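
With this change `load_shard` is a coroutine, so `ensure_shard` awaits it and the model download no longer blocks the event loop. A minimal sketch of driving the engine from asyncio, assuming the dynamic engine's `infer_prompt` keeps the `(shard, prompt)` signature of the removed fixed-shard class and that `Shard` takes `model_id`, `start_layer`, `end_layer`, and `n_layers` (field names assumed here, not confirmed by the diff):

```python
import asyncio

from exo.inference.shard import Shard
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine

async def main():
    # Hypothetical shard covering all layers of the model; field names are assumed.
    shard = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit",
                  start_layer=0, end_layer=31, n_layers=32)
    engine = MLXDynamicShardInferenceEngine()
    # The first inference call triggers ensure_shard, which now awaits load_shard,
    # so the (possibly multi-GB) download runs without blocking other coroutines.
    output, state, is_finished = await engine.infer_prompt(shard, "Hello")
    print(output.shape, is_finished)

asyncio.run(main())
```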

+ 10 - 4
exo/inference/mlx/sharded_utils.py

@@ -4,6 +4,8 @@ import glob
 import importlib
 import json
 import logging
+import asyncio
+from functools import partial
 from pathlib import Path
 from typing import Optional, Tuple
 
@@ -151,7 +153,11 @@ def load_model_shard(
     model.eval()
     return model
 
-def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
+async def snapshot_download_async(*args, **kwargs):
+    func = partial(snapshot_download, *args, **kwargs)
+    return await asyncio.get_event_loop().run_in_executor(None, func)
+
+async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
     """
     """
     Ensures the model is available locally. If the path does not exist locally,
     Ensures the model is available locally. If the path does not exist locally,
     it is downloaded from the Hugging Face Hub.
     it is downloaded from the Hugging Face Hub.
@@ -167,7 +173,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
     if not model_path.exists():
         try:
             model_path = Path(
-                snapshot_download(
+                await snapshot_download_async(
                     repo_id=path_or_hf_repo,
                     revision=revision,
                     allow_patterns=[
@@ -191,7 +197,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
     return model_path
 
 
-def load_shard(
+async def load_shard(
     path_or_hf_repo: str,
     shard: Shard,
     tokenizer_config={},
@@ -220,7 +226,7 @@ def load_shard(
         FileNotFoundError: If config file or safetensors are not found.
         ValueError: If model class or args class are not found.
     """
-    model_path = get_model_path(path_or_hf_repo)
+    model_path = await get_model_path(path_or_hf_repo)
 
     model = load_model_shard(model_path, shard, lazy, model_config)
     if adapter_path is not None:
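
`snapshot_download_async` shows the pattern used throughout this commit: the blocking call is bound with `functools.partial` and handed to the default thread-pool executor so the event loop stays responsive during the download. A standalone sketch of that pattern, using a stand-in blocking function (not the real `snapshot_download`) and illustrative repo names:

```python
import asyncio
import time
from functools import partial

def blocking_download(repo_id: str, revision: str = "main") -> str:
    # Stand-in for a blocking library call such as huggingface_hub.snapshot_download.
    time.sleep(1)
    return f"/models/{repo_id}@{revision}"

async def download_async(repo_id: str, revision: str = "main") -> str:
    # Bind the arguments first, then run the call in the default executor;
    # this mirrors snapshot_download_async above.
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, partial(blocking_download, repo_id, revision))

async def main():
    # Each download runs in a worker thread, so several can proceed
    # concurrently while the event loop keeps serving other tasks.
    paths = await asyncio.gather(
        download_async("mlx-community/Meta-Llama-3-8B-Instruct-4bit"),
        download_async("mlx-community/Meta-Llama-3-70B-Instruct-4bit"),
    )
    print(paths)

asyncio.run(main())
```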

+ 10 - 10
exo/inference/test_inference_engine.py

@@ -24,15 +24,15 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
     assert np.array_equal(resp_full, resp2)
     assert np.array_equal(next_resp_full, resp4)
 
-asyncio.run(test_inference_engine(
-    MLXDynamicShardInferenceEngine(),
-    MLXDynamicShardInferenceEngine(),
-    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-))
-
-# TODO: Waiting on https://github.com/tinygrad/tinygrad/issues/5549
 # asyncio.run(test_inference_engine(
-#     TinygradDynamicShardInferenceEngine(),
-#     TinygradDynamicShardInferenceEngine(),
-#     "llama3-8b-sfr",
+#     MLXDynamicShardInferenceEngine(),
+#     MLXDynamicShardInferenceEngine(),
+#     "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
 # ))
+
+# TODO: Waiting on https://github.com/tinygrad/tinygrad/issues/5549
+asyncio.run(test_inference_engine(
+    TinygradDynamicShardInferenceEngine(),
+    TinygradDynamicShardInferenceEngine(),
+    "llama3-8b-sfr",
+))

+ 16 - 10
exo/inference/tinygrad/inference.py

@@ -1,16 +1,18 @@
-
+import asyncio
+from functools import partial
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Union
 import json, argparse, random, time
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
 from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
-from tinygrad.helpers import DEBUG, tqdm, _cache_dir
+from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
 import numpy as np
+import os
 
 MODEL_PARAMS = {
   "8B": {
@@ -58,6 +60,11 @@ class Tokenizer:
     return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())
 
 # **** helper functions ****
+async def fetch_async(url: str, name: Optional[Union[Path, str]] = None, subdir: Optional[str] = None,
+                      allow_caching=not os.getenv("DISABLE_HTTP_CACHE")) -> Path:
+    func = partial(fetch, url, name, subdir, allow_caching)
+    return await asyncio.get_event_loop().run_in_executor(None, func)
+
 def concat_weights(models, device=None):
   def convert(name) -> Tensor:
     disk_tensors: List[Tensor] = [model[name] for model in models]
@@ -176,16 +183,15 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
         if Path(model_path / "model.safetensors.index.json").exists():
             model = model_path
         else:
-            from tinygrad.helpers import fetch
 
             if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
             if shard.model_id.lower().find("llama3-8b-sfr") != -1:
-                fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
-                fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
-                fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
-                fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
-                fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
-                model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
+                await fetch_async("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
+                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
+                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
+                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
+                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
+                model = await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
                 size = "8B"
                 size = "8B"
             elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
             elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
                 raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
                 raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
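
`fetch_async` wraps tinygrad's blocking `fetch` the same way, and each weight file is awaited in turn, so the downloads run off the event loop but still one after another. A small usage sketch; the URL is taken from the llama3-8b-sfr branch above, while the `subdir` value is illustrative (in the engine it is `shard.model_id`):

```python
import asyncio

from exo.inference.tinygrad.inference import fetch_async

async def main():
    # Downloads the tokenizer into tinygrad's cache directory without
    # blocking the event loop; the subdir name here is an assumption.
    tokenizer_path = await fetch_async(
        "https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model",
        "tokenizer.model",
        subdir="llama3-8b-sfr",
    )
    print(tokenizer_path)

asyncio.run(main())
```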