|
@@ -15,7 +15,7 @@ from .stateful_model import make_prompt_state
|
|
|
from .losses import length_masked_ce_loss
|
|
|
from collections import OrderedDict
|
|
|
import asyncio
|
|
|
-
|
|
|
+from typing import Optional
|
|
|
# Module-level side effect: sets the no_grad attribute on the Tensor class at import time.
# NOTE(review): presumably this turns off gradient tracking for inference — confirm against the Tensor implementation.
Tensor.no_grad = True
|
|
|
# default settings
|
|
|
# Default temperature setting, overridable through the TEMPERATURE environment variable.
# Parsed with float(): int() truncated the 0.85 fallback to 0 and raised ValueError
# for fractional overrides such as TEMPERATURE="0.85".
TEMPERATURE = float(os.getenv("TEMPERATURE", 0.85))
|
|
@@ -104,7 +104,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
|
|
|
safe_save(state_dict, path)
|
|
|
|
|
|
- async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
|
|
+ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
|
|
|
await self.ensure_shard(shard)
|
|
|
def wrap_infer():
|
|
|
x = Tensor(input_data)
|
|
@@ -114,7 +114,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
self.states[request_id].start += x.shape[1]
|
|
|
return out.realize()
|
|
|
output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
|
|
|
- return output_data.numpy()
|
|
|
+ return output_data.numpy(), inference_state
|
|
|
|
|
|
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
|
|
|
def step(x, y, l):
|