|
@@ -67,8 +67,12 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
self.session = {}
|
|
|
|
|
|
async def poll_state(self, request_id: str, max_caches=2):
|
|
|
- if len(self.states) >= self.max_states:
|
|
|
- self.states.popitem(last=False)
|
|
|
+ if request_id not in self.states:
|
|
|
+ if len(self.states) >= self.max_states:
|
|
|
+ self.states.popitem(last=False)
|
|
|
+ make_prompt_state(
|
|
|
+ else:
|
|
|
+ self.states.move_to_end(request_id)
|
|
|
state = self.state[request_id]
|
|
|
return {"start_pos": state.start, "cache": state.cache}
|
|
|
|
|
@@ -91,8 +95,12 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
|
|
await self.ensure_shard(shard)
|
|
|
state = self.poll_state(request_id)
|
|
|
- output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), **state).realize())
|
|
|
- self.state[request_id].start += input_data.shape[1]
|
|
|
+ def wrap_infer(data):
|
|
|
+ x = Tensor(data)
|
|
|
+ state = poll_state(request_id, x)
|
|
|
+ self.model(x, **state)
|
|
|
+ self.state[request_id].start += x.shape[1]
|
|
|
+ output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: wrap_infer(input_data).realize())
|
|
|
return output_data.numpy()
|
|
|
|
|
|
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
|