
match previous impl with np.array in mlx

Alex Cheema 9 months ago
parent commit d6e661fd69
1 changed file with 6 additions and 6 deletions

exo/inference/mlx/sharded_inference_engine.py (+6 -6)

@@ -23,22 +23,22 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
       inputs = await loop.run_in_executor(self.executor, self.tokenizer, prompt, image, return_tensors="np")
       pixel_values = mx.array(inputs["pixel_values"])
       input_ids = mx.array(inputs["input_ids"])
-      output_data = await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values)
+      output_data = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
     else:
       input_ids = await loop.run_in_executor(self.executor, lambda: mx.array(self.tokenizer.encode(prompt)))
-      output_data = await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids)
-    return np.array(output_data), "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+      output_data = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
+    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
     input_tensor = mx.array(input_data)
-    output_data = await asyncio.get_running_loop().run_in_executor(
+    output_data = np.array(await asyncio.get_running_loop().run_in_executor(
       self.executor,
       self.stateful_sharded_model.step,
       request_id,
       input_tensor
-    )
-    return np.array(output_data), "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+    ))
+    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
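
Note: a minimal sketch (not part of the commit) of what the np.array(...) wrapping achieves, assuming mlx and numpy are installed; eos_token_id below is a hypothetical stand-in for self.tokenizer.eos_token_id:

import mlx.core as mx
import numpy as np

# Stand-in for the output of stateful_sharded_model.step(...), which is an mlx array.
mlx_out = mx.array([2])

# Eagerly convert to a plain numpy ndarray, as the previous implementation did,
# so the method returns the np.ndarray its signature declares.
output_data = np.array(mlx_out)

eos_token_id = 2  # hypothetical stand-in for self.tokenizer.eos_token_id
is_finished = output_data.size == 1 and output_data.item() == eos_token_id
print(type(output_data), is_finished)  # <class 'numpy.ndarray'> True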