|
@@ -47,7 +47,7 @@ class Node:
|
|
|
self.max_generate_tokens = max_generate_tokens
|
|
|
self.topology_viz = topology_viz
|
|
|
self.default_sample_temperature = default_sample_temperature
|
|
|
- self._on_token = AsyncCallbackSystem[str, Tuple[str, int, bool]]()
|
|
|
+ self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
|
|
|
self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
|
|
|
self._on_opaque_status.register("node_status").on_next(self.on_node_status)
|
|
|
self.node_download_progress: Dict[str, RepoProgressEvent] = {}
|
|
@@ -130,9 +130,8 @@ class Node:
|
|
|
self.buffered_token_output[request_id][0].append(token.item())
|
|
|
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
|
|
|
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
|
|
|
- asyncio.create_task(self.broadcast_result(request_id, [self.buffered_token_output[request_id][0][-1]], is_finished))
|
|
|
forward = token.reshape(1, -1)
|
|
|
- intermediate_result = self.buffered_token_output[request_id][0][-1]
|
|
|
+ intermediate_result = [self.buffered_token_output[request_id][0][-1]]
|
|
|
else:
|
|
|
forward = result
|
|
|
else:
|
|
@@ -575,16 +574,16 @@ class Node:
|
|
|
return self.topology
|
|
|
|
|
|
@property
|
|
|
- def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, int, bool]]:
|
|
|
+ def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
|
|
|
return self._on_token
|
|
|
|
|
|
@property
|
|
|
def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
|
|
|
return self._on_opaque_status
|
|
|
|
|
|
- def trigger_on_token_callbacks(self, request_id: str, token: int, is_finished: bool) -> None:
|
|
|
- if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {token=} {is_finished=}")
|
|
|
- self.on_token.trigger_all(request_id, token, is_finished)
|
|
|
+ def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
|
|
|
+ if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {tokens=} {is_finished=}")
|
|
|
+ self.on_token.trigger_all(request_id, tokens, is_finished)
|
|
|
|
|
|
async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
|
|
|
async def send_result_to_peer(peer):
|