
Merge pull request #203 from exo-explore/non_blocking

Non blocking
Alex Cheema, 8 months ago
commit 87e08f89f1
3 changed files with 57 additions and 14 deletions
  1. exo/inference/mlx/sharded_inference_engine.py (+16, -8)
  2. exo/inference/tinygrad/inference.py (+15, -6)
  3. test/test_hf.py (+26, -0)
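
This PR gives both inference engines a single-worker ThreadPoolExecutor and dispatches the blocking tokenizer and model calls to it with run_in_executor, so they no longer stall the asyncio event loop; a standalone test for get_weight_map is added under test/.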

+ 16 - 8
exo/inference/mlx/sharded_inference_engine.py

@@ -6,28 +6,32 @@ from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
 from typing import Optional
 from exo.download.shard_download import ShardDownloader
-
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
+    self.executor = ThreadPoolExecutor(max_workers=1)
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
+    loop = asyncio.get_running_loop()
     if image_str:
       image = await get_image_from_str(image_str)
-      inputs = self.tokenizer(prompt, image, return_tensors="np")
+      inputs = await loop.run_in_executor(self.executor, lambda: self.tokenizer(prompt, image, return_tensors="np"))
       pixel_values = mx.array(inputs["pixel_values"])
       input_ids = mx.array(inputs["input_ids"])
-      output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, input_ids, pixel_values))
+      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
     else:
-      output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt))))
+      input_ids = mx.array(await loop.run_in_executor(self.executor, self.tokenizer.encode, prompt))
+      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
-    output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(input_data)))
+    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data)))
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def ensure_shard(self, shard: Shard):
@@ -35,6 +39,10 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
       return
 
     model_path = await self.shard_downloader.ensure_shard(shard)
-    model_shard, self.tokenizer = await load_shard(model_path, shard)
-    self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
-    self.shard = shard
+
+    if self.shard != shard:
+      loop = asyncio.get_running_loop()
+      def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
+      model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
+      self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
+      self.shard = shard
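
The pattern throughout this file: each blocking call (tokenization, model step, shard loading) is handed to the single-worker executor with loop.run_in_executor, so the coroutine yields the event loop while the work runs and the lone worker thread serializes access to the model. Two details worth noting: run_in_executor forwards only positional arguments, so keyword arguments such as return_tensors="np" must be bundled in a lambda or functools.partial; and load_shard_wrapper wraps the async load_shard in asyncio.run because the executor thread has no event loop of its own. A minimal, self-contained sketch of the pattern, where slow_step is a hypothetical stand-in for a blocking call like stateful_sharded_model.step:

import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)  # one worker serializes model access

def slow_step(request_id, tokens):
  # Hypothetical stand-in for blocking work such as a model forward pass.
  return [t + 1 for t in tokens]

async def infer(request_id, tokens):
  loop = asyncio.get_running_loop()
  # The coroutine suspends here; other asyncio tasks keep running
  # while the worker thread executes slow_step.
  return await loop.run_in_executor(executor, slow_step, request_id, tokens)

async def main():
  # Two concurrent requests interleave on the event loop but run
  # one at a time on the single worker thread.
  print(await asyncio.gather(infer("a", [1, 2]), infer("b", [3, 4])))

if __name__ == "__main__":
  asyncio.run(main())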

+ 15 - 6
exo/inference/tinygrad/inference.py

@@ -12,6 +12,10 @@ from typing import Optional, Tuple
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
+from concurrent.futures import ThreadPoolExecutor
+import asyncio
+import threading
+from functools import partial
 
 Tensor.no_grad = True
 # default settings
@@ -52,14 +56,15 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
+    self.executor = ThreadPoolExecutor(max_workers=1)
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
-    toks = self.tokenizer.encode(prompt)
-    h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()
+    toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize())
 
     if h.shape == (1,):
       start_pos += len(toks)
@@ -75,7 +80,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
-    h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
+    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize())
 
     if h.shape == (1,):
       start_pos += n_captured_toks
@@ -90,6 +95,10 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
       return
 
     model_path = await self.shard_downloader.ensure_shard(shard)
-    self.model = build_transformer(model_path, shard, model_size="8B" if "8b" in shard.model_id.lower() else "70B")
-    self.tokenizer = await resolve_tokenizer(str((model_path if model_path.is_dir() else model_path.parent)))
-    self.shard = shard
+
+    if self.shard != shard:
+      self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
+
+      tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
+      self.tokenizer = await resolve_tokenizer(tokenizer_path)
+      self.shard = shard
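
The tinygrad engine follows the same executor pattern, with two small differences: these call sites use asyncio.get_event_loop() (inside a running coroutine this returns the running loop, though asyncio.get_running_loop() is the preferred spelling on modern Python), and the model call is wrapped in a lambda to carry its extra arguments, since run_in_executor forwards only positional ones. functools.partial, imported above, achieves the same thing; a rough sketch, with run_model as a hypothetical stand-in for self.model(...).realize():

import asyncio
from functools import partial
from concurrent.futures import ThreadPoolExecutor

def run_model(tokens, start_pos, temperature):
  # Hypothetical stand-in for the blocking tinygrad forward pass.
  return len(tokens) + start_pos

async def infer(executor, tokens, start_pos, temperature):
  loop = asyncio.get_running_loop()
  # partial pre-binds the arguments, equivalent to the lambda in the diff.
  return await loop.run_in_executor(executor, partial(run_model, tokens, start_pos, temperature))

async def main():
  executor = ThreadPoolExecutor(max_workers=1)
  print(await infer(executor, [1, 2, 3], 0, 0.0))

if __name__ == "__main__":
  asyncio.run(main())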

+ 26 - 0
test/test_hf.py

@@ -0,0 +1,26 @@
+import os
+import sys
+
+# Add the project root to the Python path
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.insert(0, project_root)
+
+import asyncio
+from exo.download.hf.hf_helpers import get_weight_map
+
+async def test_get_weight_map():
+  repo_ids = [
+    "mlx-community/quantized-gemma-2b",
+    "mlx-community/Meta-Llama-3.1-8B-4bit",
+    "mlx-community/Meta-Llama-3.1-70B-4bit",
+    "mlx-community/Meta-Llama-3.1-405B-4bit",
+  ]
+  for repo_id in repo_ids:
+    weight_map = await get_weight_map(repo_id)
+    assert weight_map is not None, "Weight map should not be None"
+    assert isinstance(weight_map, dict), "Weight map should be a dictionary"
+    assert len(weight_map) > 0, "Weight map should not be empty"
+    print(f"OK: {repo_id}")
+
+if __name__ == "__main__":
+  asyncio.run(test_get_weight_map())
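
The script runs standalone with python test/test_hf.py; the sys.path insertion at the top makes the exo package importable from a source checkout without an install step.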