|
@@ -53,11 +53,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
|
|
|
return tokens
|
|
|
|
|
|
- async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, bool):
|
|
|
- tokens = await self.encode(shard, prompt)
|
|
|
- output_data = await self.infer_tensor(request_id, shard, tokens, inference_state)
|
|
|
- return output_data
|
|
|
-
|
|
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, bool):
|
|
|
await self.ensure_shard(shard)
|
|
|
output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
|