
simplify non-blocking mlx inference

Alex Cheema 11 months ago
parent
commit
6881722b72
1 changed file with 34 additions and 44 deletions

exo/inference/mlx/sharded_inference_engine.py (+34 -44)

@@ -6,64 +6,39 @@ from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
 from typing import Optional
 from exo.download.shard_download import ShardDownloader
-import threading
-from concurrent.futures import ThreadPoolExecutor
 import asyncio
+from concurrent.futures import ThreadPoolExecutor
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
-    self.model_lock = threading.Lock()
     self.executor = ThreadPoolExecutor(max_workers=1)
-    self.inference_queue = asyncio.Queue()
-    self._worker_task = None
-
-  async def _ensure_worker(self):
-    if self._worker_task is None:
-      self._worker_task = asyncio.create_task(self._inference_worker())
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
-    await self._ensure_worker()
-
+    loop = asyncio.get_running_loop()
     if image_str:
       image = await get_image_from_str(image_str)
-      inputs = self.tokenizer(prompt, image, return_tensors="np")
+      inputs = await loop.run_in_executor(self.executor, self.tokenizer, prompt, image, return_tensors="np")
       pixel_values = mx.array(inputs["pixel_values"])
       input_ids = mx.array(inputs["input_ids"])
-      output_data = await self._queue_inference(request_id, input_ids, pixel_values)
+      output_data = await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values)
     else:
-      input_ids = mx.array(self.tokenizer.encode(prompt))
-      output_data = await self._queue_inference(request_id, input_ids)
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+      input_ids = await loop.run_in_executor(self.executor, lambda: mx.array(self.tokenizer.encode(prompt)))
+      output_data = await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids)
+    return np.array(output_data), "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
-    await self._ensure_worker()
-
     input_tensor = mx.array(input_data)
-    output_data = await self._queue_inference(request_id, input_tensor)
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
-
-  async def _queue_inference(self, request_id: str, *args):
-    future = asyncio.get_running_loop().create_future()
-    await self.inference_queue.put((future, request_id, args))
-    return await future
-
-  async def _inference_worker(self):
-    while True:
-      future, request_id, args = await self.inference_queue.get()
-      try:
-        result = await asyncio.get_running_loop().run_in_executor(
-          self.executor,
-          lambda: np.array(self.stateful_sharded_model.step(request_id, *args))
-        )
-        future.set_result(result)
-      except Exception as e:
-        future.set_exception(e)
-      finally:
-        self.inference_queue.task_done()
+    output_data = await asyncio.get_running_loop().run_in_executor(
+      self.executor,
+      self.stateful_sharded_model.step,
+      request_id,
+      input_tensor
+    )
+    return np.array(output_data), "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
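
The first hunk replaces the hand-rolled `asyncio.Queue` and `_inference_worker` task with direct `run_in_executor` calls. Because the executor is created with `max_workers=1`, blocking MLX calls are still serialized in submission order, but without the extra queue machinery, and the event loop stays free while they run. A minimal sketch of the pattern, with a hypothetical `slow_step` standing in for `stateful_sharded_model.step`:

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

# A single worker serializes blocking calls in submission order
# while the event loop stays free to schedule other tasks.
executor = ThreadPoolExecutor(max_workers=1)

def slow_step(request_id: str, x: int) -> int:
    # Hypothetical stand-in for stateful_sharded_model.step:
    # a compute-bound call that must not run on the event loop.
    time.sleep(0.1)
    return x * 2

async def infer(request_id: str, x: int) -> int:
    loop = asyncio.get_running_loop()
    # The coroutine suspends here; the blocking work happens on
    # the executor thread instead of the event-loop thread.
    return await loop.run_in_executor(executor, slow_step, request_id, x)

async def main():
    # Both requests are submitted concurrently but execute one
    # at a time on the single executor thread.
    print(await asyncio.gather(infer("a", 1), infer("b", 2)))  # [2, 4]

asyncio.run(main())
```
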
@@ -71,8 +46,23 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
 
     model_path = await self.shard_downloader.ensure_shard(shard)
 
-    with self.model_lock:
-      if self.shard != shard:
-        model_shard, self.tokenizer = await load_shard(model_path, shard)
-        self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
-        self.shard = shard
+    if self.shard != shard:
+      loop = asyncio.get_running_loop()
+
+      # Run load_shard in a separate thread
+      def load_shard_wrapper():
+        return asyncio.run(load_shard(model_path, shard))
+
+      model_shard, self.tokenizer = await loop.run_in_executor(
+        self.executor,
+        load_shard_wrapper
+      )
+
+      # Create StatefulShardedModel in the executor
+      self.stateful_sharded_model = await loop.run_in_executor(
+        self.executor,
+        StatefulShardedModel,
+        shard,
+        model_shard
+      )
+      self.shard = shard
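
The second hunk moves `load_shard` onto the same executor. Since `run_in_executor` only accepts plain callables, the async `load_shard` is wrapped in `asyncio.run`, which drives it to completion on a fresh event loop inside the executor thread. A sketch of that wrapper pattern, with a hypothetical `load_model` in place of `load_shard`:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)

async def load_model(path: str) -> str:
    # Hypothetical stand-in for load_shard: an async loader whose
    # heavy deserialization work would otherwise block the main loop.
    await asyncio.sleep(0.1)
    return f"model@{path}"

async def ensure_model(path: str) -> str:
    loop = asyncio.get_running_loop()
    # asyncio.run spins up a private event loop on the executor
    # thread; this is safe there because no loop is running on it.
    return await loop.run_in_executor(
        executor, lambda: asyncio.run(load_model(path))
    )

print(asyncio.run(ensure_model("/tmp/shard")))  # model@/tmp/shard
```

Dropping `model_lock` appears safe under this design: the `self.shard` check and assignment now happen only on the event-loop thread, while all heavy work funnels through the single-worker executor, so no two loads can interleave.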