@@ -1,5 +1,9 @@
-from typing import List, Dict, Optional, Callable, Tuple, Union
 import numpy as np
+import json
+import asyncio
+import uuid
+import time
+from typing import List, Dict, Optional, Callable, Tuple, Union
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
 from .node import Node
@@ -9,8 +13,7 @@ from exo.topology.partitioning_strategy import PartitioningStrategy
 from exo.topology.partitioning_strategy import Partition
 from exo import DEBUG
 from exo.helpers import AsyncCallback, AsyncCallbackSystem
-import asyncio
-import uuid
+from exo.viz.topology_viz import TopologyViz
 
 class StandardNode(Node):
     def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256):
@@ -23,8 +26,24 @@ class StandardNode(Node):
         self.topology: Topology = Topology()
         self.device_capabilities = device_capabilities()
         self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-        self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
+        self.topology_viz = TopologyViz()
         self.max_generate_tokens = max_generate_tokens
+        self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
+        self._on_opaque_status = AsyncCallbackSystem[str, str]()
+        self._on_opaque_status.register("node_status").on_next(self.on_node_status)
+
+    def on_node_status(self, request_id, opaque_status):
+        try:
+            status_data = json.loads(opaque_status)
+            if status_data.get("type", "") == "node_status":
+                if status_data.get("status", "").startswith("start_"):
+                    self.current_topology.active_node_id = status_data.get("node_id")
+                elif status_data.get("status", "").startswith("end_"):
+                    if status_data.get("node_id") == self.current_topology.active_node_id:
+                        self.current_topology.active_node_id = None
+            self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
+        except json.JSONDecodeError:
+            pass
 
     async def start(self, wait_for_peers: int = 0) -> None:
         await self.server.start()
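
The node_status messages above form a small JSON protocol: any status beginning with start_ marks the sending node as the topology's active node, and an end_ status from that same node clears the marker again. A minimal standalone sketch of that state machine, using a stubbed Topology in place of exo's real class:

import json

class StubTopology:
    def __init__(self):
        self.active_node_id = None

topology = StubTopology()

def on_node_status(opaque_status: str) -> None:
    status_data = json.loads(opaque_status)
    if status_data.get("type", "") != "node_status":
        return
    status = status_data.get("status", "")
    if status.startswith("start_"):
        # Any start_* status makes the sender the currently active node.
        topology.active_node_id = status_data.get("node_id")
    elif status.startswith("end_") and status_data.get("node_id") == topology.active_node_id:
        # Only the currently active node can clear the marker.
        topology.active_node_id = None

on_node_status(json.dumps({"type": "node_status", "node_id": "a", "status": "start_process_prompt"}))
assert topology.active_node_id == "a"
on_node_status(json.dumps({"type": "node_status", "node_id": "a", "status": "end_process_prompt"}))
assert topology.active_node_id is None
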
@@ -39,6 +58,15 @@ class StandardNode(Node):
         await self.server.stop()
 
     async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_prompt", "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id})))
+        start_time = time.perf_counter_ns()
+        resp = await self._process_prompt(shard, prompt, request_id, inference_state)
+        end_time = time.perf_counter_ns()
+        elapsed_time_ns = end_time - start_time
+        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_prompt", "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
+        return resp
+
+    async def _process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
         if request_id is None:
             request_id = str(uuid.uuid4())
         if request_id not in self.buffered_token_output:
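
process_prompt here and process_tensor in the next hunk share the same shape: fire off a start_* status without awaiting it, time the private _process_* implementation with time.perf_counter_ns(), then fire off an end_* status carrying elapsed_time_ns and the result size. If this pattern spreads to more methods it could be factored into a decorator; a hedged sketch follows (timed_status is a hypothetical helper, not part of this PR, and it drops the per-method payload fields such as prompt and tensor_shape for brevity):

import asyncio
import functools
import json
import time

def timed_status(name: str):
    def decorator(fn):
        @functools.wraps(fn)
        async def wrapper(self, *args, request_id=None, **kwargs):
            # Fire-and-forget so status broadcasting never blocks inference.
            asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps(
                {"type": "node_status", "node_id": self.id, "status": f"start_{name}", "request_id": request_id})))
            start = time.perf_counter_ns()
            resp = await fn(self, *args, request_id=request_id, **kwargs)
            asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps(
                {"type": "node_status", "node_id": self.id, "status": f"end_{name}", "request_id": request_id,
                 "elapsed_time_ns": time.perf_counter_ns() - start,
                 "result_size": resp.size if resp is not None else 0})))
            return resp
        return wrapper
    return decorator
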
@@ -68,6 +96,15 @@ class StandardNode(Node):
         return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
 
     async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_tensor", "shard": shard.to_dict(), "tensor_size": tensor.size, "tensor_shape": tensor.shape, "request_id": request_id, "inference_state": inference_state})))
+        start_time = time.perf_counter_ns()
+        resp = await self._process_tensor(shard, tensor, request_id, inference_state)
+        end_time = time.perf_counter_ns()
+        elapsed_time_ns = end_time - start_time
+        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_tensor", "shard": shard.to_dict(), "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
+        return resp
+
+    async def _process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
         if request_id is None:
             request_id = str(uuid.uuid4())
         if request_id not in self.buffered_token_output:
@@ -206,7 +243,9 @@ class StandardNode(Node):
             except Exception as e:
                 print(f"Error collecting topology from {peer.id()}: {e}")
 
+        next_topology.active_node_id = self.topology.active_node_id # this is not so clean.
         self.topology = next_topology
+        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
         return next_topology
 
     # TODO: unify this and collect_topology as global actions
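
The carryover line (which the author flags with "this is not so clean") is needed because collect_topology assembles a fresh Topology on every refresh; without it, the transient active_node_id set by on_node_status would be wiped each time the topology is rebuilt. A stub illustration of the failure mode it prevents:

class StubTopology:
    def __init__(self):
        self.active_node_id = None

old, new = StubTopology(), StubTopology()
old.active_node_id = "node-a"            # set earlier by a start_* status message
new.active_node_id = old.active_node_id  # the carryover performed above
assert new.active_node_id == "node-a"    # survives the topology refresh
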
@@ -237,6 +276,10 @@ class StandardNode(Node):
     def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
         return self._on_token
 
+    @property
+    def on_opaque_status(self) -> AsyncCallbackSystem[str, str]:
+        return self._on_opaque_status
+
     def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
         if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
         self.on_token.trigger_all(request_id, tokens, is_finished)
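
The new on_opaque_status property exposes the same AsyncCallbackSystem the node subscribes to in __init__, so other components (an API layer, a logger) can listen with the identical register(...).on_next(...) idiom. A usage sketch, assuming a constructed StandardNode instance named node; the key "status_logger" is an arbitrary subscriber name:

def log_status(request_id: str, opaque_status: str) -> None:
    print(f"[{request_id}] {opaque_status}")

node.on_opaque_status.register("status_logger").on_next(log_status)
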
@@ -253,4 +296,12 @@ class StandardNode(Node):
                 traceback.print_exc()
 
         print(f"Broadcast result: {request_id=} {result=} {is_finished=}")
-        await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
+        await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
+
+    async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
+        for peer in self.peers:
+            await peer.send_opaque_status(request_id, status)
+
+    @property
+    def current_topology(self) -> Topology:
+        return self.topology
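
One design note on the new broadcast_opaque_status: it awaits each peer in turn, so a slow peer delays delivery to the rest, whereas broadcast_result just above fans out with asyncio.gather(..., return_exceptions=True). A concurrent variant mirroring that existing pattern would look like this (a sketch, not what the PR ships):

import asyncio

async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
    # Fan out to all peers at once; return_exceptions keeps one failing
    # peer from aborting delivery to the others.
    await asyncio.gather(
        *[peer.send_opaque_status(request_id, status) for peer in self.peers],
        return_exceptions=True,
    )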