瀏覽代碼

Fix end-of-request behaviour and add back broadcasting of tokens to other nodes

Alex Cheema 5 月之前
父節點
當前提交
72c3fdab46
共有 1 個文件被更改,包括 2 次插入 和 1 次删除
  1. 2 1
      exo/orchestration/standard_node.py

+ 2 - 1
exo/orchestration/standard_node.py

@@ -115,10 +115,11 @@ class StandardNode(Node):
       token = await self.inference_engine.sample(result)
       token = await self.inference_engine.sample(result)
       await self.inference_engine.ensure_shard(shard)
       await self.inference_engine.ensure_shard(shard)
       self.buffered_token_output[request_id][0].append(token.item())
       self.buffered_token_output[request_id][0].append(token.item())
-      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
       if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
       if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
       is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
       is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
       forward = token.reshape(1, -1)
       forward = token.reshape(1, -1)
+      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
+      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
     else:
     else:
       forward = result
       forward = result