|
@@ -55,7 +55,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
self.shard_downloader = shard_downloader
|
|
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
|
|
|
|
|
- async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
|
|
|
+ async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> tuple[np.ndarray, str, bool]:
|
|
|
await self.ensure_shard(shard)
|
|
|
start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
|
|
|
n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
|
|
@@ -72,7 +72,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
n_captured_toks = len(toks)
|
|
|
return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
|
|
|
|
|
|
- async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
|
|
|
+ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> tuple[np.ndarray, str, bool]:
|
|
|
await self.ensure_shard(shard)
|
|
|
start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
|
|
|
n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
|