|
@@ -93,7 +93,11 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
)
|
|
|
output_data, inference_state = result
|
|
|
|
|
|
- output_data = np.array(output_data, copy=False)
|
|
|
+ await self._eval_mlx(output_data)
|
|
|
+ output_data = await asyncio.get_running_loop().run_in_executor(
|
|
|
+ self._mlx_thread,
|
|
|
+ lambda: np.array(output_data, copy=False)
|
|
|
+ )
|
|
|
return output_data, inference_state
|
|
|
|
|
|
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
|