|
@@ -77,7 +77,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
await self.ensure_shard(shard)
|
|
|
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
|
|
|
|
|
|
- async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
|
|
|
+ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> np.ndarray:
|
|
|
await self.ensure_shard(shard)
|
|
|
loop = asyncio.get_running_loop()
|
|
|
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
|