Просмотр исходного кода

closely match prev impl mlx non blocking

Alex Cheema 8 месяцев назад
Родитель
Commit
9345684b38
2 измененных файлов с 7 добавлено и 29 удалено
  1. 6 26
      exo/inference/mlx/sharded_inference_engine.py
  2. 1 3
      exo/inference/tinygrad/inference.py

+ 6 - 26
exo/inference/mlx/sharded_inference_engine.py

@@ -23,21 +23,15 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
       inputs = await loop.run_in_executor(self.executor, self.tokenizer, prompt, image, return_tensors="np")
       pixel_values = mx.array(inputs["pixel_values"])
       input_ids = mx.array(inputs["input_ids"])
-      output_data = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
+      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
     else:
       input_ids = await loop.run_in_executor(self.executor, lambda: mx.array(self.tokenizer.encode(prompt)))
-      output_data = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
+      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
-    input_tensor = mx.array(input_data)
-    output_data = np.array(await asyncio.get_running_loop().run_in_executor(
-      self.executor,
-      self.stateful_sharded_model.step,
-      request_id,
-      input_tensor
-    ))
+    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data)))
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def ensure_shard(self, shard: Shard):
@@ -48,21 +42,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
 
     if self.shard != shard:
       loop = asyncio.get_running_loop()
-
-      # Run load_shard in a separate thread
-      def load_shard_wrapper():
-        return asyncio.run(load_shard(model_path, shard))
-
-      model_shard, self.tokenizer = await loop.run_in_executor(
-        self.executor,
-        load_shard_wrapper
-      )
-
-      # Create StatefulShardedModel in the executor
-      self.stateful_sharded_model = await loop.run_in_executor(
-        self.executor,
-        StatefulShardedModel,
-        shard,
-        model_shard
-      )
+      def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
+      model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
+      self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
       self.shard = shard

+ 1 - 3
exo/inference/tinygrad/inference.py

@@ -64,9 +64,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
     toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    input_tensor = Tensor([toks])
-
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, input_tensor, start_pos, TEMPERATURE)
+    h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, Tensor([toks]), start_pos, TEMPERATURE)
 
     if h.shape == (1,):
       start_pos += len(toks)