Просмотр исходного кода

fix regression introduced by image_str for tinygrad

Alex Cheema 1 год назад
Родитель
Commit
76766253cd
Изменено 2 файла: 3 добавлено и 3 удалено
  1. 2 2
      exo/inference/inference_engine.py
  2. 1 1
      exo/inference/tinygrad/inference.py

+ 2 - 2
exo/inference/inference_engine.py

@@ -7,11 +7,11 @@ from .shard import Shard
 
 class InferenceEngine(ABC):
   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     pass
 
   @abstractmethod
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
 
   @abstractmethod

+ 1 - 1
exo/inference/tinygrad/inference.py

@@ -198,7 +198,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self):
     self.shard = None
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     # TODO: we need to refactor models/llama to handle per-request-kv-cache. right now it's shared between requests.
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0