Selaa lähdekoodia

exo topology visualisation that shows the topology of the network, device capabilities and the currently active node using opaque statuses. fixes #36. ready for #33

Alex Cheema 9 kuukautta sitten
vanhempi
commit
4b592f9d45

+ 8 - 0
exo/inference/shard.py

@@ -12,3 +12,11 @@ class Shard:
 
     def is_last_layer(self) -> bool:
         return self.end_layer == self.n_layers - 1
+
+    def to_dict(self) -> dict:
+        """Return this shard's fields as a plain dictionary.
+
+        The result maps "model_id", "start_layer", "end_layer" and
+        "n_layers" (in that order) to the attributes of the same name.
+        """
+        field_names = ("model_id", "start_layer", "end_layer", "n_layers")
+        return {field_name: getattr(self, field_name) for field_name in field_names}

+ 9 - 6
exo/networking/grpc/grpc_discovery.py

@@ -30,7 +30,7 @@ class GRPCDiscovery(Discovery):
         self.listen_port = listen_port
         self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
         self.broadcast_interval = broadcast_interval
-        self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float]] = {}
+        self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {}
         self.broadcast_task = None
         self.listen_task = None
         self.cleanup_task = None
@@ -73,7 +73,7 @@ class GRPCDiscovery(Discovery):
                     if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
                     break  # No new peers found in the grace period, we are done
 
-        return [peer_handle for peer_handle, _ in self.known_peers.values()]
+        return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
 
     async def task_broadcast_presence(self):
         transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
@@ -109,9 +109,9 @@ class GRPCDiscovery(Discovery):
             peer_port = message['grpc_port']
             device_capabilities = DeviceCapabilities(**message['device_capabilities'])
             if peer_id not in self.known_peers:
-                self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time())
+                self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time(), time.time())
                 if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
-            self.known_peers[peer_id] = (self.known_peers[peer_id][0], time.time())
+            self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
 
     async def task_listen_for_peers(self):
         await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
@@ -122,8 +122,11 @@ class GRPCDiscovery(Discovery):
             try:
                 current_time = time.time()
                 timeout = 15 * self.broadcast_interval
-                peers_to_remove = [peer_handle.id() for peer_handle, last_seen in self.known_peers.values() if not await peer_handle.is_connected() or current_time - last_seen > timeout]
-                if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, last_seen={last_seen}" for peer_handle, last_seen in self.known_peers.values()})
+                peers_to_remove = [
+                    peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values() if
+                    (not await peer_handle.is_connected() and current_time - connected_at > timeout) or current_time - last_seen > timeout
+                ]
+                if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
                 if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: print(f"Cleaning up peers: {peers_to_remove}")
                 for peer_id in peers_to_remove:
                     if peer_id in self.known_peers: del self.known_peers[peer_id]

+ 4 - 0
exo/networking/grpc/grpc_peer_handle.py

@@ -96,3 +96,7 @@ class GRPCPeerHandle(PeerHandle):
     async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
         request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
         await self.stub.SendResult(request)
+
+    async def send_opaque_status(self, request_id: str, status: str) -> None:
+        """Send ``status`` for ``request_id`` to this peer through the SendOpaqueStatus RPC."""
+        await self.stub.SendOpaqueStatus(
+            node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
+        )

+ 7 - 0
exo/networking/grpc/grpc_server.py

@@ -90,3 +90,10 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         if DEBUG >= 2: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
         self.node.on_token.trigger_all(request_id, result, is_finished)
         return node_service_pb2.Empty()
+
+    async def SendOpaqueStatus(self, request, context):
+        """Handle a SendOpaqueStatus RPC.
+
+        Reads ``request_id`` and ``status`` from the request, passes both to
+        ``self.node.on_opaque_status.trigger_all`` and replies with an empty
+        message. ``context`` is not used.
+        """
+        request_id = request.request_id
+        status = request.status
+        # The ``{name=}`` specifiers print the local variable names, so these locals double as log labels.
+        if DEBUG >= 2: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
+        self.node.on_opaque_status.trigger_all(request_id, status)
+        return node_service_pb2.Empty()

+ 6 - 0
exo/networking/grpc/node_service.proto

@@ -10,6 +10,7 @@ service NodeService {
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
   rpc GlobalReset (GlobalResetRequest) returns (Empty) {}
   rpc SendResult (SendResultRequest) returns (Empty) {}
+  rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
 }
 
 message Shard {
@@ -91,4 +92,9 @@ message SendResultRequest {
   bool is_finished = 3;
 }
 
+// Request payload of the SendOpaqueStatus RPC.
+message SendOpaqueStatusRequest {
+  // Identifier of the request the status refers to.
+  string request_id = 1;
+  // Status text; the service does not define its structure.
+  string status = 2;
+}
+
 message Empty {}

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 43 - 0
exo/networking/grpc/node_service_pb2_grpc.py

@@ -74,6 +74,11 @@ class NodeServiceStub(object):
                 request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
+        self.SendOpaqueStatus = channel.unary_unary(
+                '/node_service.NodeService/SendOpaqueStatus',
+                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
 
 
 class NodeServiceServicer(object):
@@ -121,6 +126,12 @@ class NodeServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
+    def SendOpaqueStatus(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
 
 def add_NodeServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
@@ -159,6 +170,11 @@ def add_NodeServiceServicer_to_server(servicer, server):
                     request_deserializer=node__service__pb2.SendResultRequest.FromString,
                     response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
+            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendOpaqueStatus,
+                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
             'node_service.NodeService', rpc_method_handlers)
@@ -358,3 +374,30 @@ class NodeService(object):
             timeout,
             metadata,
             _registered_method=True)
+
+    @staticmethod
+    def SendOpaqueStatus(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendOpaqueStatus',
+            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)

+ 10 - 0
exo/orchestration/node.py

@@ -38,7 +38,17 @@ class Node(ABC):
     async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:
         pass
 
+    @property
+    @abstractmethod
+    def current_topology(self) -> Topology:
+        """Topology currently held by this node; concrete nodes must implement this property."""
+        pass
+
     @property
     @abstractmethod
     def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
         pass
+
+    @property
+    @abstractmethod
+    def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
+        """Callback system for opaque status messages; concrete nodes must implement this property."""
+        pass

+ 56 - 5
exo/orchestration/standard_node.py

@@ -1,5 +1,9 @@
-from typing import List, Dict, Optional, Callable, Tuple, Union
 import numpy as np
+import json
+import asyncio
+import uuid
+import time
+from typing import List, Dict, Optional, Callable, Tuple, Union
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
 from .node import Node
@@ -9,8 +13,7 @@ from exo.topology.partitioning_strategy import PartitioningStrategy
 from exo.topology.partitioning_strategy import Partition
 from exo import DEBUG
 from exo.helpers import AsyncCallback, AsyncCallbackSystem
-import asyncio
-import uuid
+from exo.viz.topology_viz import TopologyViz
 
 class StandardNode(Node):
     def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256):
@@ -23,8 +26,24 @@ class StandardNode(Node):
         self.topology: Topology = Topology()
         self.device_capabilities = device_capabilities()
         self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-        self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
+        self.topology_viz = TopologyViz()
         self.max_generate_tokens = max_generate_tokens
+        self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
+        self._on_opaque_status = AsyncCallbackSystem[str, str]()
+        self._on_opaque_status.register("node_status").on_next(self.on_node_status)
+
+    def on_node_status(self, request_id, opaque_status):
+        """Track the active node from a "node_status" message and redraw the topology view.
+
+        ``opaque_status`` is parsed as JSON. When it is an object whose
+        "type" is "node_status", a "status" starting with "start_" marks its
+        "node_id" as the active node, and a "status" starting with "end_"
+        clears the active node if it is that same node. The visualization is
+        refreshed after every JSON object, whatever its type.
+
+        Payloads that are not valid JSON, or that are valid JSON but not an
+        object, are ignored. ``request_id`` is not used.
+        """
+        try:
+            status_data = json.loads(opaque_status)
+        except json.JSONDecodeError:
+            return
+        # Valid JSON can still be a list, number or string, none of which has .get().
+        if not isinstance(status_data, dict):
+            return
+
+        if status_data.get("type", "") == "node_status":
+            status = status_data.get("status", "")
+            if not isinstance(status, str):
+                # A non-string status cannot match either prefix.
+                status = ""
+            node_id = status_data.get("node_id")
+            if status.startswith("start_"):
+                self.current_topology.active_node_id = node_id
+            elif status.startswith("end_"):
+                if node_id == self.current_topology.active_node_id:
+                    self.current_topology.active_node_id = None
+        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
 
     async def start(self, wait_for_peers: int = 0) -> None:
         await self.server.start()
@@ -39,6 +58,15 @@ class StandardNode(Node):
         await self.server.stop()
 
     async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+        """Run ``_process_prompt`` and announce its start and end to peers as opaque statuses.
+
+        A "start_process_prompt" status is broadcast before the work and an
+        "end_process_prompt" status, carrying the elapsed time in nanoseconds
+        and the result size, after it. Both broadcasts are scheduled as
+        background tasks and are not awaited here.
+
+        Args:
+            shard: Shard to process; its ``to_dict()`` form is included in the statuses.
+            prompt: Prompt text, forwarded to ``_process_prompt`` and included in the statuses.
+            request_id: Request identifier; a new UUID string is generated when None.
+            inference_state: Forwarded to ``_process_prompt`` and included in the statuses.
+
+        Returns:
+            Whatever ``_process_prompt`` returns.
+        """
+        # Resolve the id here: otherwise the statuses go out with request_id=None
+        # while _process_prompt generates a different uuid for the actual work.
+        if request_id is None:
+            request_id = str(uuid.uuid4())
+        shard_dict = shard.to_dict()
+        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({
+            "type": "node_status",
+            "node_id": self.id,
+            "status": "start_process_prompt",
+            "shard": shard_dict,
+            "prompt": prompt,
+            "inference_state": inference_state,
+            "request_id": request_id,
+        })))
+        start_time = time.perf_counter_ns()
+        resp = await self._process_prompt(shard, prompt, request_id, inference_state)
+        end_time = time.perf_counter_ns()
+        elapsed_time_ns = end_time - start_time
+        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({
+            "type": "node_status",
+            "node_id": self.id,
+            "status": "end_process_prompt",
+            "shard": shard_dict,
+            "prompt": prompt,
+            "inference_state": inference_state,
+            "request_id": request_id,
+            "elapsed_time_ns": elapsed_time_ns,
+            "result_size": resp.size if resp is not None else 0,
+        })))
+        return resp
+
+    async def _process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
         if request_id is None:
             request_id = str(uuid.uuid4())
         if request_id not in self.buffered_token_output:
@@ -68,6 +96,15 @@ class StandardNode(Node):
         return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
 
     async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+        """Run ``_process_tensor`` and announce its start and end to peers as opaque statuses.
+
+        A "start_process_tensor" status (with the tensor's size and shape) is
+        broadcast before the work and an "end_process_tensor" status, carrying
+        the elapsed time in nanoseconds and the result size, after it. Both
+        broadcasts are scheduled as background tasks and are not awaited here.
+
+        Args:
+            shard: Shard to process; its ``to_dict()`` form is included in the statuses.
+            tensor: Input tensor, forwarded to ``_process_tensor``.
+            request_id: Request identifier; a new UUID string is generated when None.
+            inference_state: Forwarded to ``_process_tensor`` and included in the start status.
+
+        Returns:
+            Whatever ``_process_tensor`` returns.
+        """
+        # Resolve the id here: otherwise the statuses go out with request_id=None
+        # while _process_tensor generates a different uuid for the actual work.
+        if request_id is None:
+            request_id = str(uuid.uuid4())
+        shard_dict = shard.to_dict()
+        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({
+            "type": "node_status",
+            "node_id": self.id,
+            "status": "start_process_tensor",
+            "shard": shard_dict,
+            "tensor_size": tensor.size,
+            "tensor_shape": tensor.shape,
+            "request_id": request_id,
+            "inference_state": inference_state,
+        })))
+        start_time = time.perf_counter_ns()
+        resp = await self._process_tensor(shard, tensor, request_id, inference_state)
+        end_time = time.perf_counter_ns()
+        elapsed_time_ns = end_time - start_time
+        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({
+            "type": "node_status",
+            "node_id": self.id,
+            "status": "end_process_tensor",
+            "shard": shard_dict,
+            "request_id": request_id,
+            "elapsed_time_ns": elapsed_time_ns,
+            "result_size": resp.size if resp is not None else 0,
+        })))
+        return resp
+
+    async def _process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
         if request_id is None:
             request_id = str(uuid.uuid4())
         if request_id not in self.buffered_token_output:
@@ -206,7 +243,9 @@ class StandardNode(Node):
             except Exception as e:
                 print(f"Error collecting topology from {peer.id()}: {e}")
 
+        next_topology.active_node_id = self.topology.active_node_id # this is not so clean.
         self.topology = next_topology
+        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
         return next_topology
 
     # TODO: unify this and collect_topology as global actions
@@ -237,6 +276,10 @@ class StandardNode(Node):
     def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
         return self._on_token
 
+    @property
+    def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
+        """Callback system for opaque status messages.
+
+        Callbacks receive ``(request_id, status)``; the annotation matches the
+        abstract declaration on ``Node``.
+        """
+        return self._on_opaque_status
+
     def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
         if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
         self.on_token.trigger_all(request_id, tokens, is_finished)
@@ -253,4 +296,12 @@ class StandardNode(Node):
                 traceback.print_exc()
 
         print(f"Broadcast result: {request_id=} {result=} {is_finished=}")
-        await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
+        await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
+
+    async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
+        """Send an opaque status to every known peer concurrently.
+
+        Delivery is best-effort: a peer that fails is reported on stdout and
+        does not prevent the status from reaching the remaining peers.
+
+        Args:
+            request_id: Identifier of the request the status refers to.
+            status: Status text, passed to each peer unchanged.
+        """
+        # Snapshot the peers so the results can be paired with them afterwards.
+        peers = list(self.peers)
+        results = await asyncio.gather(
+            *[peer.send_opaque_status(request_id, status) for peer in peers],
+            return_exceptions=True,
+        )
+        for peer, result in zip(peers, results):
+            if isinstance(result, Exception):
+                print(f"Error sending opaque status to {peer.id()}: {result}")
+
+    @property
+    def current_topology(self) -> Topology:
+        """Return the node's current ``self.topology`` object (not a copy)."""
+        return self.topology

+ 2 - 1
exo/topology/topology.py

@@ -1,10 +1,11 @@
 from .device_capabilities import DeviceCapabilities
-from typing import Dict, Set
+from typing import Dict, Set, Optional
 
 class Topology:
     def __init__(self):
         self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
         self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph
+        self.active_node_id: Optional[str] = None
 
     def update_node(self, node_id: str, device_capabilities: DeviceCapabilities):
         self.nodes[node_id] = device_capabilities

+ 43 - 0
exo/viz/test_topology_viz.py

@@ -0,0 +1,43 @@
+import asyncio
+import unittest
+from exo.viz.topology_viz import TopologyViz
+from exo.topology.topology import Topology
+from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
+from exo.topology.partitioning_strategy import Partition
+from exo.helpers import AsyncCallbackSystem
+
+class TestNodeViz(unittest.IsolatedAsyncioTestCase):
+    """Visual smoke test for TopologyViz.
+
+    Renders a four-node topology with two different partitionings so the
+    output can be inspected by eye; it makes no assertions beyond the calls
+    not raising.
+    """
+
+    async def asyncSetUp(self):
+        self.topology = Topology()
+        self.topology.update_node("node1", DeviceCapabilities(model="ModelA", chip="ChipA", memory=8*1024, flops=DeviceFlops(fp32=1.0,fp16=2.0,int8=4.0)))
+        self.topology.update_node("node2", DeviceCapabilities(model="ModelB", chip="ChipB", memory=16*1024, flops=DeviceFlops(fp32=2.0,fp16=4.0,int8=8.0)))
+        self.topology.update_node("node3", DeviceCapabilities(model="ModelC", chip="ChipC", memory=32*1024, flops=DeviceFlops(fp32=4.0,fp16=8.0,int8=16.0)))
+        self.topology.update_node("node4", DeviceCapabilities(model="ModelD", chip="ChipD", memory=64*1024, flops=DeviceFlops(fp32=8.0,fp16=16.0,int8=32.0)))
+
+        self.top_viz = TopologyViz()
+        await asyncio.sleep(2)  # Simulate running for a short time
+
+    async def asyncTearDown(self):
+        # TopologyViz starts a live display in its constructor; stop it so the
+        # terminal is released once the test finishes.
+        self.top_viz.live_panel.stop()
+
+    async def test_layout_generation(self):
+        self.top_viz._generate_layout()
+        self.top_viz.refresh()
+        # Pauses keep each frame on screen; asyncio.sleep avoids blocking the event loop.
+        await asyncio.sleep(2)
+        self.top_viz.update_visualization(self.topology, [
+            Partition("node1", 0, 0.2),
+            Partition("node4", 0.2, 0.4),
+            Partition("node2", 0.4, 0.8),
+            Partition("node3", 0.8, 1),
+        ])
+        await asyncio.sleep(2)
+        self.topology.active_node_id = "node3"
+        self.top_viz.update_visualization(self.topology, [
+            Partition("node1", 0, 0.3),
+            Partition("node2", 0.3, 0.7),
+            Partition("node4", 0.7, 0.9),
+            Partition("node3", 0.9, 1),
+        ])
+        await asyncio.sleep(2)
+
+
+if __name__ == "__main__":
+    unittest.main()

+ 110 - 0
exo/viz/topology_viz.py

@@ -0,0 +1,110 @@
+import math
+from typing import Dict, List
+from exo.helpers import exo_text
+from exo.orchestration.node import Node
+from exo.topology.topology import Topology
+from exo.topology.partitioning_strategy import Partition
+from rich.console import Console
+from rich.panel import Panel
+from rich.text import Text
+from rich.live import Live
+from rich.style import Style
+from exo.topology.device_capabilities import DeviceCapabilities, UNKNOWN_DEVICE_CAPABILITIES
+
+class TopologyViz:
+    """Terminal view of the cluster: partitions drawn as nodes on a ring inside a rich panel.
+
+    The drawing is a 90x45 character grid rendered into a ``Panel`` that is
+    shown through a ``Live`` display. The live display is started in the
+    constructor; this class has no method that stops it.
+    """
+
+    def __init__(self):
+        """Create the console, an empty topology and partition list, and start the live panel."""
+        self.console = Console()
+        self.topology = Topology()
+        self.partitions: List[Partition] = []
+        self.panel = Panel(self._generate_layout(), title=f"Exo Cluster (0 nodes)", border_style="bright_yellow")
+        # auto_refresh is off, so the display only redraws when refresh() pushes an update.
+        self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
+        self.live_panel.start()
+
+    def update_visualization(self, topology: Topology, partitions: List[Partition]):
+        """Store the given topology and partitions, then redraw the panel."""
+        self.topology = topology
+        self.partitions = partitions
+        self.refresh()
+
+    def refresh(self):
+        """Regenerate the layout and panel title from current state and push them to the live display."""
+        self.panel.renderable = self._generate_layout()
+        # Update the panel title with the number of nodes and partitions
+        node_count = len(self.topology.nodes)
+        self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
+        self.live_panel.update(self.panel, refresh=True)
+
+    def _generate_layout(self) -> str:
+        """Build the grid as a newline-joined string.
+
+        The exo banner is written into the top rows. Each partition is then
+        placed at an equal angular step on a circle, with an info block next
+        to it and a line of '-' characters towards the next partition. Later
+        writes overwrite earlier ones in the grid, so a connecting line can
+        replace info text drawn before it.
+        """
+        # Calculate visualization parameters
+        num_partitions = len(self.partitions)
+        radius = 12  # Reduced radius
+        center_x, center_y = 45, 25  # Adjusted center_x to center the visualization
+
+        # Generate visualization
+        visualization = [[' ' for _ in range(90)] for _ in range(45)]  # Increased width to 90
+
+        # Add exo_text at the top in bright yellow
+        exo_lines = exo_text.split('\n')
+        yellow_style = Style(color="bright_yellow")
+        max_line_length = max(len(line) for line in exo_lines)
+        for i, line in enumerate(exo_lines):
+            centered_line = line.center(max_line_length)
+            start_x = (90 - max_line_length) // 2  # Calculate starting x position to center the text
+            colored_line = Text(centered_line, style=yellow_style)
+            # NOTE(review): only str(colored_line) reaches the grid; presumably that is the plain text without the style - confirm.
+            for j, char in enumerate(str(colored_line)):
+                if 0 <= start_x + j < 90 and i < len(visualization):  # Ensure we don't exceed the width and height
+                    visualization[i][start_x + j] = char
+
+        for i, partition in enumerate(self.partitions):
+            # Nodes missing from the topology fall back to UNKNOWN_DEVICE_CAPABILITIES.
+            device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
+
+            # Partitions are spread evenly around the circle in list order.
+            angle = 2 * math.pi * i / num_partitions
+            x = int(center_x + radius * math.cos(angle))
+            y = int(center_y + radius * math.sin(angle))
+
+            # Place node with different color for active node
+            if partition.node_id == self.topology.active_node_id:
+                visualization[y][x] = '🔴'  # Red circle for active node
+            else:
+                visualization[y][x] = '🔵'  # Blue circle for inactive nodes
+
+            # Place node info (model, memory, fp16 FLOPS, partition range)
+            # NOTE(review): memory // 1024 is labelled GB, which assumes memory is stored in MB - confirm.
+            node_info = [
+                f"Model: {device_capabilities.model}",
+                f"Mem: {device_capabilities.memory // 1024}GB",
+                f"FLOPS: {device_capabilities.flops.fp16}T",
+                f"Part: {partition.start:.2f}-{partition.end:.2f}"
+            ]
+
+            # Calculate info position based on angle
+            info_distance = radius + 3  # Reduced distance
+            info_x = int(center_x + info_distance * math.cos(angle))
+            info_y = int(center_y + info_distance * math.sin(angle))
+
+            # Adjust text position to avoid overwriting the node icon
+            if info_x < x:  # Text is to the left of the node
+                info_x = max(0, x - len(max(node_info, key=len)) - 1)
+            elif info_x > x:  # Text is to the right of the node
+                info_x = min(89 - len(max(node_info, key=len)), info_x)
+
+            for j, line in enumerate(node_info):
+                for k, char in enumerate(line):
+                    if 0 <= info_y + j < 45 and 0 <= info_x + k < 90:  # Updated width check
+                        # Ensure we're not overwriting the node icon
+                        if info_y + j != y or info_x + k != x:
+                            visualization[info_y + j][info_x + k] = char
+
+            # Draw line to next node
+            next_i = (i + 1) % num_partitions
+            next_angle = 2 * math.pi * next_i / num_partitions
+            next_x = int(center_x + radius * math.cos(next_angle))
+            next_y = int(center_y + radius * math.sin(next_angle))
+
+            # Simple line drawing
+            # With a single partition the next node is itself, so steps is 0 and nothing is drawn.
+            steps = max(abs(next_x - x), abs(next_y - y))
+            for step in range(1, steps):
+                line_x = int(x + (next_x - x) * step / steps)
+                line_y = int(y + (next_y - y) * step / steps)
+                if 0 <= line_y < 45 and 0 <= line_x < 90:  # Updated width check
+                    visualization[line_y][line_x] = '-'
+
+        # Convert to string
+        return '\n'.join(''.join(str(char) for char in row) for row in visualization)

+ 4 - 0
main.py

@@ -5,8 +5,10 @@ import uuid
 import platform
 import psutil
 import os
+import json
 from typing import List
 from exo.orchestration.standard_node import StandardNode
+from exo.viz.topology_viz import TopologyViz
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
@@ -59,7 +61,9 @@ server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 api = ChatGPTAPI(node, inference_engine.__class__.__name__)
 
+topology_viz = TopologyViz()
 node.on_token.register("main_log").on_next(lambda _, tokens , __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+node.on_opaque_status.register("main_log").on_next(lambda request_id, status: print(f"!!! [{request_id}] Opaque Status: {status}"))
 
 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""

+ 1 - 0
setup.py

@@ -15,6 +15,7 @@ install_requires = [
     "psutil==6.0.0",
     "pynvml==11.5.3",
     "requests==2.32.3",
+    "rich==13.7.1",
     "safetensors==0.4.3",
     "tiktoken==0.7.0",
     "tokenizers==0.19.1",

Kaikkia tiedostoja ei voida näyttää, sillä liian monta tiedostoa muuttui tässä diffissä