Alex Cheema committed 5 months ago
commit ce5041ee1b

+ 6 - 6
exo/inference/debug_inference_engine.py

@@ -16,25 +16,25 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
-  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+  resp_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
   token_full = await inference_engine_1.sample(resp_full)
 
-  next_resp_full = await inference_engine_1.infer_tensor(
+  next_resp_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
     input_data=token_full,
   )
 
-  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-  resp2 = await inference_engine_2.infer_tensor(
+  resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+  resp2, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp1,
   )
   token2 = await inference_engine_2.sample(resp2)
-  resp3 = await inference_engine_1.infer_tensor(
+  resp3, _ = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
     input_data=token2,
   )
-  resp4 = await inference_engine_2.infer_tensor(
+  resp4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,
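
The test now unpacks the (output, inference_state) pair that infer_prompt and infer_tensor return, discarding the state with _ since plain transformer shards carry none. A minimal sketch of the post-commit calling convention (engine setup and model_id elided; names are illustrative):

  # Every inference call now yields (output, state); LLM shards return no state.
  output, _ = await engine.infer_prompt(
    "A",
    shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
    prompt="hello",
  )
  token = await engine.sample(output)  # sampling is unchanged
  next_output, _ = await engine.infer_tensor(
    "A",
    shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
    input_data=token,
  )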

+ 2 - 2
exo/inference/dummy_inference_engine.py

@@ -25,7 +25,7 @@ class DummyInferenceEngine(InferenceEngine):
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     return self.tokenizer.decode(tokens)
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
-    return input_data + 1 if self.shard.is_last_layer() else input_data
+    return (input_data + 1 if self.shard.is_last_layer() else input_data), None
 
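Note the dummy engine defaults inference_state to a mutable {} rather than the base class's Optional[dict] = None; that is safe only while no implementation mutates the default in place. A safer, equivalent idiom (illustrative only, not part of this commit):

  async def infer_tensor(self, request_id, shard, input_data, inference_state=None):
    inference_state = inference_state or {}  # fresh dict per call, safe to **-splat
    ...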

+ 2 - 2
exo/inference/inference_engine.py

@@ -23,7 +23,7 @@ class InferenceEngine(ABC):
     pass
 
   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     pass
 
   @abstractmethod
@@ -39,7 +39,7 @@ class InferenceEngine(ABC):
   async def clear_session(self):
     self.session.empty()
   
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     tokens = await self.encode(shard, prompt)
     if shard.model_id != 'stable-diffusion-2-1-base':
       x = tokens.reshape(1, -1)
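
Both abstract methods now return an (output, inference_state) pair, so engine-specific state (e.g. diffusion progress) can travel across shard boundaries alongside the activations. The wrapper's body is truncated above; a sketch of its overall shape under the new contract (the delegation to infer_tensor is inferred from the visible context, not shown in full here):

  async def infer_prompt(self, request_id, shard, prompt, inference_state=None):
    tokens = await self.encode(shard, prompt)
    x = tokens.reshape(1, -1)  # non-stable-diffusion path, per the context above
    output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
    return output_data, inference_state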

+ 2 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -77,13 +77,14 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     await self.ensure_shard(shard)
     await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
     
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
     loop = asyncio.get_running_loop()
     state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
     x = mx.array(input_data)
     if self.model.model_type != 'StableDiffusionPipeline':
       output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
+      inference_state = {}
     else:
       output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
     output_data = np.array(output_data)
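
Only the StableDiffusionPipeline branch produces state worth forwarding (denoising step, conditioning, and so on); regular transformer shards keep their KV cache node-local in the state dict fetched by poll_state, so the hunk resets inference_state to {} instead of echoing the caller's dict downstream. The resulting return shapes (variable names illustrative):

  logits, state = await engine.infer_tensor(rid, llm_shard, x)         # state == {}
  latents, state = await engine.infer_tensor(rid, sd_shard, x, state)  # carries pipeline progress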

+ 1 - 1
exo/inference/tinygrad/inference.py

@@ -104,7 +104,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
     safe_save(state_dict, path) 
   
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
     def wrap_infer():
       x = Tensor(input_data)
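
The tinygrad engine gets the same signature change; the wrap_infer body is truncated above, so its tail has to be inferred. A hedged sketch of the minimal conforming ending, assuming the executor still produces a bare array and this engine has no cross-node state to report:

  output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
  return output_data, None  # no engine state for plain transformer shards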

+ 2 - 2
exo/orchestration/node.py

@@ -206,7 +206,7 @@ class Node:
       return None
     else:
       self.outstanding_requests[request_id] = "processing"
-      result,inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
+      result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
       ret = await self.process_inference_result(shard, result, request_id, inference_state)
       return result
 
@@ -336,7 +336,7 @@ class Node:
           loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
         else:
           self.outstanding_requests[request_id] = "preprocessing"
-          step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
           self.outstanding_requests[request_id] = "waiting"
           loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
         self.outstanding_requests.pop(request_id)
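
On the orchestration side, the node now threads the returned state into result processing, while the training path at the bottom deliberately discards it (step, _ = ...), since training only runs transformer shards. A sketch of how the state continues downstream (the internals of process_inference_result are not part of this diff; the forwarding helper name is an assumption):

  async def process_inference_result(self, shard, result, request_id, inference_state=None):
    if shard.is_last_layer():
      # end of the pipeline: sample and emit a token
      return await self.inference_engine.sample(result)
    # otherwise hand activations *and* state to the node owning the next partition
    return await self.forward_tensor(shard, result, request_id, inference_state)  # helper name assumed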