@@ -39,8 +39,8 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-    self.buffered_inputs: Dict[str, Tuple[List[np.ndarray], bool]] = {}
-    self.buffered_logits: Dict[str, Tuple[List[np.ndarray], bool]] = {}
+    self.buffered_logits: Dict[str, List[np.ndarray]] = {}
+    self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
@@ -117,13 +117,13 @@ class StandardNode(Node):
       self.buffered_token_output[request_id] = ([], False)
 
     if request_id not in self.buffered_logits:
-      self.buffered_logits[request_id] = ([], False)
+      self.buffered_logits[request_id] = []
 
-    self.buffered_logits[request_id][0] += [i for i in np.reshape(result, (-1, 1, result.shape[-1]))]
+    self.buffered_logits[request_id] += [i for i in np.reshape(result, (-1, 1, result.shape[-1]))]
 
     if shard.is_last_layer():
       result = await self.inference_engine.sample(result)
-      inference_state = json.dumps({"start_pos": len(self.buffered_logits[request_id][0]) + 1})
+      inference_state = json.dumps({"start_pos": len(self.buffered_logits[request_id]) + 1})
 
     await self.inference_engine.ensure_shard(shard)
     is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
@@ -138,7 +138,6 @@ class StandardNode(Node):
 
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-      self.buffered_logits[request_id] = (self.buffered_logits[request_id][0], True)
     else:
      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
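
For reference, a minimal standalone sketch of the new buffering pattern (the `buffer_result` function and its wiring are illustrative stand-ins, not part of this PR): per-request logits are now a plain `List[np.ndarray]`, so `start_pos` falls straight out of the list length, and no `(list, finished)` tuple has to be unpacked and re-packed — the finished flag lives only in `buffered_token_output`.

import json
from typing import Dict, List

import numpy as np

buffered_logits: Dict[str, List[np.ndarray]] = {}

def buffer_result(request_id: str, result: np.ndarray) -> str:
  # Split the (..., seq_len, vocab_size) logits into one row per token
  # and extend the request's buffer, mirroring the diff above.
  if request_id not in buffered_logits:
    buffered_logits[request_id] = []
  buffered_logits[request_id] += [i for i in np.reshape(result, (-1, 1, result.shape[-1]))]
  # start_pos is derived from the buffer length alone.
  return json.dumps({"start_pos": len(buffered_logits[request_id]) + 1})

# e.g. a fake 3-token logits tensor over a 5-token vocabulary:
print(buffer_result("req-1", np.zeros((1, 3, 5))))  # -> {"start_pos": 4}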