|
@@ -115,10 +115,11 @@ class StandardNode(Node):
|
|
|
token = await self.inference_engine.sample(result)
|
|
|
await self.inference_engine.ensure_shard(shard)
|
|
|
self.buffered_token_output[request_id][0].append(token.item())
|
|
|
- self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
|
|
|
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
|
|
|
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
|
|
|
forward = token.reshape(1, -1)
|
|
|
+ self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
|
|
|
+ asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
|
|
|
else:
|
|
|
forward = result
|
|
|
|