@@ -57,12 +57,13 @@ class StandardNode(Node):
await self.server.stop()
 
async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
- asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_prompt", "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id})))
+ shard = self.get_current_shard(base_shard)
+ asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_prompt", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id})))
start_time = time.perf_counter_ns()
resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
- asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_prompt", "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
+ asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_prompt", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
return resp
 
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
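# The hunk above brackets _process_prompt with start_/end_ node_status broadcasts that
# now carry both base_shard and the resolved shard, plus elapsed_time_ns taken from
# time.perf_counter_ns(). A hedged sketch of decoding such a payload on the receiving
# side (summarize_node_status is a hypothetical helper; the field names come from the
# JSON built in this diff, but the repo's actual consumer of broadcast_opaque_status
# may differ):
import json

def summarize_node_status(opaque_status: str) -> dict:
  status = json.loads(opaque_status)
  if status.get("type") != "node_status":
    return {}  # not a node_status message; ignore
  summary = {"node_id": status["node_id"], "status": status["status"]}
  if "elapsed_time_ns" in status:
    # perf_counter_ns() deltas are nanoseconds; convert to milliseconds for display.
    summary["elapsed_ms"] = status["elapsed_time_ns"] / 1e6
  if "result_size" in status:
    summary["result_size"] = status["result_size"]
  return summary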
@@ -95,13 +96,14 @@ class StandardNode(Node):
 
return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
 
- async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
- asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_tensor", "shard": shard.to_dict(), "tensor_size": tensor.size, "tensor_shape": tensor.shape, "request_id": request_id, "inference_state": inference_state})))
+ async def process_tensor(self, base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+ shard = self.get_current_shard(base_shard)
+ asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_tensor", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "tensor_size": tensor.size, "tensor_shape": tensor.shape, "request_id": request_id, "inference_state": inference_state})))
start_time = time.perf_counter_ns()
resp = await self._process_tensor(shard, tensor, request_id, inference_state)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
- asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_tensor", "shard": shard.to_dict(), "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
+ asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_tensor", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
return resp
 
async def _process_tensor(self, base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
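# Both process_prompt and process_tensor above now resolve the incoming base_shard to
# this node's own layer range via get_current_shard before reporting status or running
# inference. A minimal sketch of that resolution, assuming map_partitions_to_shards
# returns one Shard per partition in partition order (the same assumption
# forward_to_next_shard below relies on); the full method in the repo may differ:
def get_current_shard(self, base_shard: Shard) -> Shard:
  partitions = self.partitioning_strategy.partition(self.topology)
  shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
  current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
  if current_partition_index is None:
    raise ValueError(f"No partition found for node: {self.id}")
  return shards[current_partition_index]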
@@ -134,41 +136,39 @@ class StandardNode(Node):
traceback.print_exc()
return None
 
- async def forward_to_next_shard(self, shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str, inference_state: Optional[str] = None) -> None:
+ async def forward_to_next_shard(self, base_shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str, inference_state: Optional[str] = None) -> None:
if not self.partitioning_strategy:
if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
return
+ shard = self.get_current_shard(base_shard)
 
partitions = self.partitioning_strategy.partition(self.topology)
+ shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id)
current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
if current_partition_index is not None:
next_partition_index = (current_partition_index + 1) % len(partitions)
next_partition: Partition = partitions[next_partition_index]
+ next_shard = shards[next_partition_index]
if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
 
- if next_partition:
- if next_partition.node_id == self.id:
- if isinstance(tensor_or_prompt, np.ndarray):
- await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
- else:
- await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state)
- return
-
- target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
- if not target_peer:
- raise ValueError(f"Peer for {next_partition} not found")
+ if next_partition.node_id == self.id:
+ if isinstance(tensor_or_prompt, np.ndarray):
+ await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+ else:
+ await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+ return
 
- start_layer = int(next_partition.start * shard.n_layers)
- end_layer = int(next_partition.end * shard.n_layers) - 1
- next_shard = Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
+ target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
+ if not target_peer:
+ raise ValueError(f"Peer for {next_partition} not found")
 
- if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
+ if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
 
- if isinstance(tensor_or_prompt, np.ndarray):
- await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id)
- else:
- await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id)
+ if isinstance(tensor_or_prompt, np.ndarray):
+ await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id)
+ else:
+ await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id)
 
def get_current_shard(self, base_shard: Shard) -> Shard:
partitions = self.partitioning_strategy.partition(self.topology)
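# forward_to_next_shard and get_current_shard both lean on map_partitions_to_shards,
# which replaces the removed inline start_layer/end_layer arithmetic. A minimal sketch
# under the same convention as the removed lines (each partition's fractional
# [start, end) range scaled to layer indices); the real helper may adjust the rounding
# so that every layer is covered exactly once. Shard and Partition are the repo's
# existing types used throughout this diff.
from typing import List

def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
  shards = []
  for partition in partitions:
    start_layer = int(partition.start * num_layers)
    end_layer = int(partition.end * num_layers) - 1
    shards.append(Shard(model_id, start_layer, end_layer, num_layers))
  return shards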