|
@@ -77,15 +77,15 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
await self.ensure_shard(shard)
|
|
|
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
|
|
|
|
|
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
    """Run the model shard on *input_data* and return (output array, inference state).

    Args:
        request_id: identifier used to look up per-request state via poll_state.
        shard: model shard to ensure is loaded before inference.
        input_data: input converted to an mx.array for the model call.
        inference_state: optional extra keyword arguments forwarded to the model;
            None is treated as an empty mapping.

    Returns:
        A (numpy output, inference_state) pair. For the StableDiffusionPipeline
        model type the state is whatever the model call returns; otherwise the
        caller-supplied state is passed back unchanged.
    """
    await self.ensure_shard(shard)
    loop = asyncio.get_running_loop()

    is_sd_pipeline = self.model.model_type == 'StableDiffusionPipeline'
    # Per-request state is only polled for non-SD models; the SD pipeline gets none.
    state = {} if is_sd_pipeline else await self.poll_state(request_id)
    x = mx.array(input_data)
    # Evaluate the optional caller-supplied kwargs once for both branches.
    extra_kwargs = inference_state or {}

    if is_sd_pipeline:
        # The SD pipeline returns its own carried-over state alongside the output.
        output_data, inference_state = await loop.run_in_executor(
            self.executor, lambda: self.model(x, **state, **extra_kwargs)
        )
    else:
        output_data = await loop.run_in_executor(
            self.executor, lambda: self.model(x, **state, **extra_kwargs)
        )

    return np.array(output_data), inference_state
|
|
|
|