
fix opaque broadcast

Alex Cheema · 1 year ago · commit 5b8f127bf4
1 changed file with 25 additions and 25 deletions

exo/orchestration/standard_node.py (+25 -25)

@@ -57,12 +57,13 @@ class StandardNode(Node):
         await self.server.stop()
 
     async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_prompt", "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id})))
+        shard = self.get_current_shard(base_shard)
+        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_prompt", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id})))
         start_time = time.perf_counter_ns()
         resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
         end_time = time.perf_counter_ns()
         elapsed_time_ns = end_time - start_time
-        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_prompt", "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
+        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_prompt", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
         return resp
 
     async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
@@ -95,13 +96,14 @@ class StandardNode(Node):
 
         return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
 
-    async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_tensor", "shard": shard.to_dict(), "tensor_size": tensor.size, "tensor_shape": tensor.shape, "request_id": request_id, "inference_state": inference_state})))
+    async def process_tensor(self, base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+        shard = self.get_current_shard(base_shard)
+        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_tensor", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "tensor_size": tensor.size, "tensor_shape": tensor.shape, "request_id": request_id, "inference_state": inference_state})))
         start_time = time.perf_counter_ns()
         resp = await self._process_tensor(shard, tensor, request_id, inference_state)
         end_time = time.perf_counter_ns()
         elapsed_time_ns = end_time - start_time
-        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_tensor", "shard": shard.to_dict(), "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
+        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_tensor", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
         return resp
 
     async def _process_tensor(self, base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
@@ -134,41 +136,39 @@ class StandardNode(Node):
             traceback.print_exc()
             return None
 
-    async def forward_to_next_shard(self, shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str, inference_state: Optional[str] = None) -> None:
+    async def forward_to_next_shard(self, base_shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str, inference_state: Optional[str] = None) -> None:
         if not self.partitioning_strategy:
             if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
             return
+        shard = self.get_current_shard(base_shard)
 
         partitions = self.partitioning_strategy.partition(self.topology)
+        shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id)
         current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
         if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
         if current_partition_index is not None:
             next_partition_index = (current_partition_index + 1) % len(partitions)
             next_partition: Partition = partitions[next_partition_index]
+            next_shard = shards[next_partition_index]
             if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
 
-            if next_partition:
-                if next_partition.node_id == self.id:
-                    if isinstance(tensor_or_prompt, np.ndarray):
-                        await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
-                    else:
-                        await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state)
-                    return
-
-                target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
-                if not target_peer:
-                    raise ValueError(f"Peer for {next_partition} not found")
+            if next_partition.node_id == self.id:
+                if isinstance(tensor_or_prompt, np.ndarray):
+                    await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+                else:
+                    await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+                return
 
-                start_layer = int(next_partition.start * shard.n_layers)
-                end_layer = int(next_partition.end * shard.n_layers) - 1
-                next_shard = Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
+            target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
+            if not target_peer:
+                raise ValueError(f"Peer for {next_partition} not found")
 
-                if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
+            if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
 
-                if isinstance(tensor_or_prompt, np.ndarray):
-                    await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id)
-                else:
-                    await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id)
+            if isinstance(tensor_or_prompt, np.ndarray):
+                await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id)
+            else:
+                await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id)
 
     def get_current_shard(self, base_shard: Shard) -> Shard:
         partitions = self.partitioning_strategy.partition(self.topology)
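
For context on the helper calls introduced here, the following is a minimal sketch, assuming map_partitions_to_shards simply generalizes the per-partition arithmetic that the removed inline code used (start_layer = int(p.start * n_layers), end_layer = int(p.end * n_layers) - 1). The Partition and Shard fields mirror how the diff uses them; the real exo signatures may differ.

    # Minimal sketch (assumption): how the partition-to-shard mapping behind
    # get_current_shard / map_partitions_to_shards could work, based on the
    # inline arithmetic removed in this commit. Not the exo implementation.
    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Partition:
        node_id: str
        start: float  # fractional start of this node's layer range (0.0..1.0)
        end: float    # fractional end of this node's layer range (0.0..1.0)

    @dataclass
    class Shard:
        model_id: str
        start_layer: int
        end_layer: int
        n_layers: int

    def map_partitions_to_shards(partitions: List[Partition], n_layers: int, model_id: str) -> List[Shard]:
        # Turn each fractional partition into a concrete, inclusive layer range.
        return [
            Shard(model_id, int(p.start * n_layers), int(p.end * n_layers) - 1, n_layers)
            for p in partitions
        ]

    def get_current_shard(node_id: str, partitions: List[Partition], base_shard: Shard) -> Shard:
        # Resolve the shard this node owns, as process_prompt and process_tensor
        # now do before broadcasting their node_status messages.
        shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
        index = next(i for i, p in enumerate(partitions) if p.node_id == node_id)
        return shards[index]

Under these assumptions, two nodes splitting a 32-layer model evenly (partitions 0.0..0.5 and 0.5..1.0) would get layers 0..15 and 16..31 respectively.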