@@ -37,7 +37,7 @@ class StandardNode(Node):
     await self.discovery.stop()
     await self.server.stop()
 
-  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     if request_id not in self.buffered_token_output:
@@ -49,7 +49,7 @@ class StandardNode(Node):
       await self.forward_to_next_shard(shard, prompt, request_id)
       return
 
-    result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
+    result, inference_state, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt, inference_state=inference_state)
     is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -61,11 +61,11 @@ class StandardNode(Node):
     if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
     if not is_finished:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
+      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
 
     return np.array(self.buffered_token_output[request_id]) if len(self.buffered_token_output[request_id]) > 0 else None
 
-  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     if request_id not in self.buffered_token_output:
@@ -73,7 +73,7 @@ class StandardNode(Node):
 
     try:
       if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-      result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
+      result, inference_state, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor, inference_state=inference_state)
       is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
       if is_finished:
         self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -84,7 +84,7 @@ class StandardNode(Node):
       if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
       if not is_finished:
-        asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
+        asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
 
       return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
     except Exception as e:
@@ -93,7 +93,7 @@ class StandardNode(Node):
       traceback.print_exc()
       return None
 
-  async def forward_to_next_shard(self, shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str) -> None:
+  async def forward_to_next_shard(self, shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str, inference_state: Optional[str] = None) -> None:
     if not self.partitioning_strategy:
       if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
       return
@@ -109,9 +109,9 @@ class StandardNode(Node):
     if next_partition:
       if next_partition.node_id == self.id:
         if isinstance(tensor_or_prompt, np.ndarray):
-          await self.process_tensor(shard, tensor_or_prompt, request_id)
+          await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
         else:
-          await self.process_prompt(shard, tensor_or_prompt, request_id)
+          await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state)
         return
 
     target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
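
Every hunk above makes the same change: it threads a new inference_state parameter through the node pipeline. The state is an opaque Optional[str] that each node passes into its inference engine, receives back updated, and forwards to the next shard alongside the result, so stateful engines can resume generation across hops. Concretely, infer_prompt and infer_tensor change from returning (result, is_finished) to (result, inference_state, is_finished). A minimal sketch of the engine-side contract these call sites assume follows; the class name and the JSON encoding of the state are illustrative assumptions, not part of the diff.

from typing import Optional, Tuple
import json
import numpy as np

class StatefulInferenceEngine:
  """Illustrative engine matching the (result, inference_state, is_finished) contract."""

  async def infer_prompt(self, shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    # Decode the opaque state string; an absent state starts a fresh request.
    # JSON is one plausible encoding; the diff only requires a string.
    state = json.loads(inference_state) if inference_state else {"step": 0}
    # ... run the prompt through this shard's layers here ...
    state["step"] += 1
    output = np.zeros((1, 1))  # stand-in for real logits/activations
    return output, json.dumps(state), False

  async def infer_tensor(self, shard, tensor: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    # Same contract for intermediate activations arriving from an upstream shard.
    state = json.loads(inference_state) if inference_state else {"step": 0}
    state["step"] += 1
    output = np.zeros((1, 1))
    return output, json.dumps(state), False

Because forward_to_next_shard now carries the state on every hop, a request that routes back to the same node (the next_partition.node_id == self.id branch) resumes from the forwarded state rather than starting over.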