
Since infer_prompt is a thin wrapper that works the same for all inference engines, we can de-abstract it

Nel Nibcord · 8 months ago
commit b9d0fb6825
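
In effect this is the template-method pattern: every engine implemented infer_prompt identically as encode-then-infer_tensor, so the base class can own that composition and keep only the engine-specific primitives abstract. A minimal sketch of the resulting shape (signatures taken from the diff below; the Shard annotations are omitted so the snippet stands alone):

from abc import ABC, abstractmethod
from typing import Optional
import numpy as np

class InferenceEngine(ABC):
  @abstractmethod
  async def encode(self, shard, prompt: str) -> np.ndarray:
    ...

  @abstractmethod
  async def decode(self, shard, tokens: np.ndarray) -> str:
    ...

  @abstractmethod
  async def infer_tensor(self, request_id: str, shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
    ...

  # Concrete default: tokenize the prompt, then reuse the tensor path.
  # Subclasses no longer override this.
  async def infer_prompt(self, request_id: str, shard, prompt: str, inference_state: Optional[str] = None) -> np.ndarray:
    tokens = await self.encode(shard, prompt)
    return await self.infer_tensor(request_id, shard, tokens, inference_state)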

exo/inference/inference_engine.py  +6 -4

@@ -20,13 +20,15 @@ class InferenceEngine(ABC):
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     pass
 
-  @abstractmethod
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> np.ndarray:
-    pass
-
   @abstractmethod
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
     pass
+  
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> np.ndarray:
+    tokens = await self.encode(shard, prompt)
+    output_data = await self.infer_tensor(request_id, shard, tokens, inference_state)
+    return output_data 
+
 
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
   if DEBUG >= 2:

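With the override gone, a concrete engine only supplies the three primitives and inherits infer_prompt from the base class. A toy illustration (EchoEngine is hypothetical, invented here; it builds on the InferenceEngine sketch above):

import asyncio
import numpy as np

# Hypothetical engine, for illustration only: the "model" is the identity
# function, and the tokenizer maps the prompt to its UTF-8 bytes.
class EchoEngine(InferenceEngine):
  async def encode(self, shard, prompt: str) -> np.ndarray:
    return np.frombuffer(prompt.encode("utf-8"), dtype=np.uint8)

  async def decode(self, shard, tokens: np.ndarray) -> str:
    return tokens.tobytes().decode("utf-8")

  async def infer_tensor(self, request_id: str, shard, input_data: np.ndarray, inference_state=None) -> np.ndarray:
    return input_data

# infer_prompt is the inherited base-class default:
print(asyncio.run(EchoEngine().infer_prompt("req-1", None, "hi")))  # [104 105]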
exo/inference/mlx/sharded_inference_engine.py  +0 -5

@@ -53,11 +53,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
     return tokens
     
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, bool):
-    tokens = await self.encode(shard, prompt)
-    output_data = await self.infer_tensor(request_id, shard, tokens, inference_state)
-    return output_data 
-
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, bool):
     await self.ensure_shard(shard)
     output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))

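The surviving infer_tensor keeps the same structure as the decode method above it: the blocking model/tokenizer call runs on an executor so the event loop never stalls. A standalone sketch of that pattern (blocking_forward is a stand-in for the real MLX forward pass):

import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)

def blocking_forward(x: int) -> int:
  # Stand-in for a synchronous model call (e.g. an MLX forward pass).
  return x * 2

async def infer(x: int) -> int:
  # Offload the blocking call to the pool; awaiting it keeps the loop responsive.
  loop = asyncio.get_running_loop()
  return await loop.run_in_executor(executor, blocking_forward, x)

print(asyncio.run(infer(21)))  # 42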
exo/inference/tinygrad/inference.py  +0 -5

@@ -82,11 +82,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
     return tokens
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> np.ndarray:
-    tokens = await self.encode(shard, prompt)
-    output_data = await self.infer_tensor(request_id, shard, tokens, inference_state)
-    return output_data 
-
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> tuple[np.ndarray, str, bool]:
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
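The tinygrad hunk ends on the line that parses inference_state: it is a plain JSON string, and start_pos falls back to 0 when no state is passed. A sketch of the round-trip a caller would perform (the "start_pos" key comes from the line above; the value 42 is arbitrary):

import json

# What a caller might thread through between inference steps.
inference_state = json.dumps({"start_pos": 42})
start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
assert start_pos == 42

# With no state, the same expression falls back to 0.
assert json.loads(None or "{}").get("start_pos", 0) == 0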