
use a queue for non-blocking mlx inference

Alex Cheema 8 months ago
parent commit 9db16f8dca
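The change fixes a pattern that could stall the event loop: the old `_run_inference` acquired `self.model_lock` (a synchronous `threading.Lock`) inside a coroutine, so a second concurrent request would block the entire loop while waiting for the lock. The new code instead has callers enqueue work and await a future, while a single worker task drains the queue and runs each step on the one-thread executor. Below is a minimal standalone sketch of that producer/consumer pattern; the `QueuedRunner` name and `submit` API are illustrative, not exo's:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

class QueuedRunner:
  """Illustrative sketch: serialize blocking calls through one worker task."""

  def __init__(self):
    self.executor = ThreadPoolExecutor(max_workers=1)
    self.queue = asyncio.Queue()
    self._worker_task = None

  async def submit(self, fn, *args):
    # Producer side: lazily start the worker, park a future on the
    # queue, then await the result without holding any lock.
    if self._worker_task is None:
      self._worker_task = asyncio.create_task(self._worker())
    future = asyncio.get_running_loop().create_future()
    await self.queue.put((future, fn, args))
    return await future

  async def _worker(self):
    # Consumer side: handle one item at a time, so jobs never contend
    # for the executor and the event loop stays free while each runs.
    while True:
      future, fn, args = await self.queue.get()
      try:
        result = await asyncio.get_running_loop().run_in_executor(self.executor, fn, *args)
        future.set_result(result)
      except Exception as e:
        future.set_exception(e)
      finally:
        self.queue.task_done()
```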
1 changed file with 31 additions and 9 deletions

+ 31 - 9  exo/inference/mlx/sharded_inference_engine.py

@@ -16,32 +16,54 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.model_lock = threading.Lock()
     self.executor = ThreadPoolExecutor(max_workers=1)
+    self.inference_queue = asyncio.Queue()
+    self._worker_task = None
+
+  async def _ensure_worker(self):
+    if self._worker_task is None:
+      self._worker_task = asyncio.create_task(self._inference_worker())
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
+    await self._ensure_worker()
+
     if image_str:
       image = await get_image_from_str(image_str)
       inputs = self.tokenizer(prompt, image, return_tensors="np")
       pixel_values = mx.array(inputs["pixel_values"])
       input_ids = mx.array(inputs["input_ids"])
-      output_data = await self._run_inference(request_id, input_ids, pixel_values)
+      output_data = await self._queue_inference(request_id, input_ids, pixel_values)
     else:
       input_ids = mx.array(self.tokenizer.encode(prompt))
-      output_data = await self._run_inference(request_id, input_ids)
+      output_data = await self._queue_inference(request_id, input_ids)
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
+    await self._ensure_worker()
+
     input_tensor = mx.array(input_data)
-    output_data = await self._run_inference(request_id, input_tensor)
+    output_data = await self._queue_inference(request_id, input_tensor)
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
-  async def _run_inference(self, request_id: str, *args):
-    with self.model_lock:
-      return await asyncio.get_event_loop().run_in_executor(
-        self.executor,
-        lambda: np.array(self.stateful_sharded_model.step(request_id, *args))
-      )
+  async def _queue_inference(self, request_id: str, *args):
+    future = asyncio.get_running_loop().create_future()
+    await self.inference_queue.put((future, request_id, args))
+    return await future
+
+  async def _inference_worker(self):
+    while True:
+      future, request_id, args = await self.inference_queue.get()
+      try:
+        result = await asyncio.get_running_loop().run_in_executor(
+          self.executor,
+          lambda: np.array(self.stateful_sharded_model.step(request_id, *args))
+        )
+        future.set_result(result)
+      except Exception as e:
+        future.set_exception(e)
+      finally:
+        self.inference_queue.task_done()
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
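As a quick check that the sketch behaves the way the commit intends, two concurrent submissions run back to back on the single worker while the event loop stays responsive (again assuming the illustrative `QueuedRunner` above, not exo's real entry points):

```python
async def main():
  runner = QueuedRunner()
  # Submitted concurrently, executed one at a time by the single worker.
  results = await asyncio.gather(
    runner.submit(sum, [1, 2, 3]),
    runner.submit(sum, [4, 5, 6]),
  )
  print(results)  # [6, 15]

asyncio.run(main())
```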