瀏覽代碼

make sure mlx stuff is on separate thread non blocking

Alex Cheema 3 月之前
父節點
當前提交
4a5b80a958
共有 1 個文件被更改,包括 5 次插入和 1 次删除
  1. 5 1
      exo/inference/mlx/sharded_inference_engine.py

+ 5 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -93,7 +93,11 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
       )
       output_data, inference_state = result
 
-    output_data = np.array(output_data, copy=False)
+    await self._eval_mlx(output_data)
+    output_data = await asyncio.get_running_loop().run_in_executor(
+      self._mlx_thread,
+      lambda: np.array(output_data, copy=False)
+    )
     return output_data, inference_state
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):