Browse Source

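Corrected type annotations (also removes the DummyInferenceEngine.infer_prompt override and makes the MLX encode return np.array(tokens))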

Nel Nibcord 8 months ago
parent
commit
c06b5f3b56

+ 1 - 5
exo/inference/dummy_inference_engine.py

@@ -28,11 +28,7 @@ class DummyInferenceEngine(InferenceEngine):
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     return ' '.join([random_string(np.random.randint(1, 34)) for token in tokens])
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None):
-    output_data = await self.infer_tensor(request_id, shard, await self.encode(shard, prompt), inference_state)
-    return output_data 
-
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
     await self.ensure_shard(shard)
     sequence_length = input_data.shape[0 if self.shard.is_first_layer() else 1]
     output = np.random.random(size=(1, sequence_length, self.vocab_size if self.shard.is_last_layer() else self.hidden_size))

+ 5 - 5
exo/inference/mlx/sharded_inference_engine.py

@@ -37,23 +37,23 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
 
-  async def sample(self, x):
+  async def sample(self, x) -> np.ndarray:
     y = mx.array(x)
     logits = y[:, -1, :]
     out = np.array(sample_logits(logits))
     return out
 
-  async def encode(self, shard: Shard, prompt: str):
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    return tokens
+    return np.array(tokens)
 
-  async def decode(self, shard: Shard, tokens):
+  async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
     return tokens
     
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, bool):
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
     await self.ensure_shard(shard)
     output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
     return output_data

+ 4 - 4
exo/inference/tinygrad/inference.py

@@ -65,24 +65,24 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
 
-  async def sample(self, x: np.ndarray):
+  async def sample(self, x: np.ndarray) -> np.ndarray:
     logits = x[:, -1, :]
     def sample_wrapper():
       return sample_logits(Tensor(x).flatten(), TEMPERATURE, 0, 0.8, 0.0, 0.0).realize()
     out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
     return out.numpy()
 
-  async def encode(self, shard: Shard, prompt: str):
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
     return np.array(tokens)
   
-  async def decode(self, shard: Shard, tokens):
+  async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
     return tokens
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> tuple[np.ndarray, str, bool]:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     output_data = await asyncio.get_running_loop().run_in_executor(self.executor, self.model, Tensor(input_data), start_pos)