瀏覽代碼

handle inference_state properly

Alex Cheema 3 月之前
父節點
當前提交
2aed3f3518
共有 2 個文件被更改,包括 2 次插入和 2 次刪除
  1. 1 1
      exo/inference/inference_engine.py
  2. 1 1
      exo/inference/mlx/sharded_inference_engine.py

+ 1 - 1
exo/inference/inference_engine.py

@@ -23,7 +23,7 @@ class InferenceEngine(ABC):
     pass
 
   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
     pass
 
   @abstractmethod

+ 1 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -77,7 +77,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     await self.ensure_shard(shard)
     await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
     
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> np.ndarray:
     await self.ensure_shard(shard)
     loop = asyncio.get_running_loop()
     state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}