|
@@ -11,7 +11,7 @@ import numpy as np
|
|
|
from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
|
|
|
from exo.download.shard_download import ShardDownloader
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
-from .stateful_model import StatefulModel
|
|
|
+from .stateful_model import StatefulModel, make_prompt_state
|
|
|
from .losses import length_masked_ce_loss
|
|
|
import asyncio
|
|
|
|
|
@@ -63,8 +63,15 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
self.shard = None
|
|
|
self.shard_downloader = shard_downloader
|
|
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
|
|
+ self.states = OrderedDict()
|
|
|
self.session = {}
|
|
|
|
|
|
async def poll_state(self, request_id: str, max_caches=2):
  """Return the forward-pass state kwargs stored for ``request_id``.

  ``self.states`` is an ``OrderedDict`` keyed by request id. Polling an id
  that is already stored moves it to the most-recent end, so the entry
  dropped by eviction is always the least recently polled one and never
  the state that is currently being requested.

  Args:
    request_id: Key of the per-request state to look up.
    max_caches: Number of stored states at which the oldest one is evicted
      to make room for a request that is not stored yet.

  Returns:
    A dict ``{"start_pos": state.start, "cache": state.cache}``.

  Raises:
    KeyError: If no state is stored under ``request_id``.
  """
  if request_id in self.states:
    # Refresh recency instead of evicting: the state in use must survive.
    self.states.move_to_end(request_id)
  elif len(self.states) >= max_caches:
    # Unknown request and the store is full: drop the oldest entry.
    self.states.popitem(last=False)
  # NOTE(review): nothing visible in this method creates the state for an
  # unseen request, so the lookup below raises KeyError for one. The state
  # constructor's signature is not visible here - confirm where new states
  # are meant to be inserted.
  state = self.states[request_id]
  return {"start_pos": state.start, "cache": state.cache}
|
|
|
+
|
|
|
async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
|
|
|
logits = x[:, -1, :]
|
|
|
def sample_wrapper():
|
|
@@ -83,7 +90,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
|
|
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
  """Run the model on ``input_data`` using the state stored for ``request_id``.

  Args:
    request_id: Key of the per-request state in ``self.states``.
    shard: Shard passed to ``self.ensure_shard`` before inference.
    input_data: Array wrapped in a ``Tensor`` and fed to ``self.model``.

  Returns:
    The realized model output converted with ``.numpy()``.
  """
  await self.ensure_shard(shard)
  # poll_state is a coroutine function; it must be awaited so that ``**state``
  # below unpacks the returned dict rather than a coroutine object.
  state = await self.poll_state(request_id)
  output_data = await asyncio.get_running_loop().run_in_executor(
    self.executor,
    lambda: self.model(Tensor(input_data), **state).realize(),
  )
  # Advance the stored start position by the size of axis 1 of the input
  # (presumably the sequence length - confirm against the callers).
  self.states[request_id].start += input_data.shape[1]
  return output_data.numpy()
|
|
|
|
|
|
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
|