|
@@ -13,6 +13,7 @@ from exo.download.shard_download import ShardDownloader
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
from .stateful_model import StatefulModel, make_prompt_state
|
|
|
from .losses import length_masked_ce_loss
|
|
|
+from collections import OrderedDict
|
|
|
import asyncio
|
|
|
|
|
|
Tensor.no_grad = False
|
|
@@ -66,14 +67,14 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
self.states = OrderedDict()
|
|
|
self.session = {}
|
|
|
|
|
|
- async def poll_state(self, request_id: str, max_caches=2):
|
|
|
+ def poll_state(self, x, request_id: str, max_states=2):
|
|
|
if request_id not in self.states:
|
|
|
- if len(self.states) >= self.max_states:
|
|
|
+ if len(self.states) >= max_states:
|
|
|
self.states.popitem(last=False)
|
|
|
- make_prompt_state(
|
|
|
+ self.states[request_id] = make_prompt_state(x, self.model, self.shard)
|
|
|
else:
|
|
|
self.states.move_to_end(request_id)
|
|
|
- state = self.state[request_id]
|
|
|
+ state = self.states[request_id]
|
|
|
return {"start_pos": state.start, "cache": state.cache}
|
|
|
|
|
|
async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
|
|
@@ -94,13 +95,13 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
|
|
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
|
|
await self.ensure_shard(shard)
|
|
|
- state = self.poll_state(request_id)
|
|
|
- def wrap_infer(data):
|
|
|
- x = Tensor(data)
|
|
|
- state = poll_state(request_id, x)
|
|
|
- self.model(x, **state)
|
|
|
- self.state[request_id].start += x.shape[1]
|
|
|
- output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: wrap_infer(input_data).realize())
|
|
|
+ def wrap_infer():
|
|
|
+ x = Tensor(input_data)
|
|
|
+ state = self.poll_state(x, request_id)
|
|
|
+ out = self.model(x, **state)
|
|
|
+ self.states[request_id].start += x.shape[1]
|
|
|
+ return out.realize()
|
|
|
+ output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
|
|
|
return output_data.numpy()
|
|
|
|
|
|
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
|