Pārlūkot izejas kodu

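fix inference engine: hoist the shared infer_tensor call out of the if/else and make the tinygrad infer_tensor accept and return inference_state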

Pranav Veldurthi 7 mēneši atpakaļ
vecāks
revīzija
b13e368368

+ 1 - 2
exo/inference/inference_engine.py

@@ -43,10 +43,9 @@ class InferenceEngine(ABC):
     tokens = await self.encode(shard, prompt)
     if shard.model_id != 'stable-diffusion-2-1-base':
       x = tokens.reshape(1, -1)
-      output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
     else:
       x = tokens
-      output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
+    output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
 
     return output_data, inference_state
 

+ 3 - 3
exo/inference/tinygrad/inference.py

@@ -15,7 +15,7 @@ from .stateful_model import make_prompt_state
 from .losses import length_masked_ce_loss
 from collections import OrderedDict
 import asyncio
-
+from typing import Optional
 Tensor.no_grad = True 
 # default settings
 TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
@@ -104,7 +104,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
     safe_save(state_dict, path) 
   
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
     await self.ensure_shard(shard)
     def wrap_infer():
       x = Tensor(input_data)
@@ -114,7 +114,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
       self.states[request_id].start += x.shape[1]
       return out.realize()
     output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
-    return output_data.numpy()
+    return output_data.numpy(), inference_state
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
     def step(x, y, l):