Browse Source

add gpu poor/rich bar in panel. fixes #33

Alex Cheema 1 year ago
parent
commit
9fa0cb1ab1
3 changed files with 70 additions and 27 deletions
  1. 0 0
      exo/viz/__init__.py
  2. 3 4
      exo/viz/test_topology_viz.py
  3. 67 23
      exo/viz/topology_viz.py

+ 0 - 0
exo/viz/__init__.py


+ 3 - 4
exo/viz/test_topology_viz.py

@@ -4,7 +4,6 @@ from exo.viz.topology_viz import TopologyViz
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.partitioning_strategy import Partition
-from exo.helpers import AsyncCallbackSystem
 
 class TestNodeViz(unittest.IsolatedAsyncioTestCase):
     async def asyncSetUp(self):
@@ -26,15 +25,15 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
             Partition("node1", 0, 0.2),
             Partition("node4", 0.2, 0.4),
             Partition("node2", 0.4, 0.8),
-            Partition("node3", 0.8, 1),
+            Partition("node3", 0.8, 0.9),
         ])
         time.sleep(2)
         self.topology.active_node_id = "node3"
         self.top_viz.update_visualization(self.topology, [
             Partition("node1", 0, 0.3),
-            Partition("node2", 0.3, 0.7),
+            Partition("node5", 0.3, 0.5),
+            Partition("node2", 0.5, 0.7),
             Partition("node4", 0.7, 0.9),
-            Partition("node3", 0.9, 1),
         ])
         time.sleep(2)
 

+ 67 - 23
exo/viz/topology_viz.py

@@ -1,7 +1,6 @@
 import math
-from typing import Dict, List
+from typing import List
 from exo.helpers import exo_text
-from exo.orchestration.node import Node
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
 from rich.console import Console
@@ -9,7 +8,8 @@ from rich.panel import Panel
 from rich.text import Text
 from rich.live import Live
 from rich.style import Style
-from exo.topology.device_capabilities import DeviceCapabilities, UNKNOWN_DEVICE_CAPABILITIES
+from rich.color import Color
+from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
 
 class TopologyViz:
     def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
@@ -38,11 +38,12 @@ class TopologyViz:
     def _generate_layout(self) -> str:
         # Calculate visualization parameters
         num_partitions = len(self.partitions)
-        radius = 12  # Reduced radius
-        center_x, center_y = 45, 25  # Adjusted center_x to center the visualization
+        radius_x = 30  # Increased horizontal radius
+        radius_y = 12  # Decreased vertical radius
+        center_x, center_y = 50, 28  # Centered horizontally and moved up slightly
 
         # Generate visualization
-        visualization = [[' ' for _ in range(90)] for _ in range(45)]  # Increased width to 90
+        visualization = [[' ' for _ in range(100)] for _ in range(55)]  # Decreased height
 
         # Add exo_text at the top in bright yellow
         exo_lines = exo_text.split('\n')
@@ -50,10 +51,10 @@ class TopologyViz:
         max_line_length = max(len(line) for line in exo_lines)
         for i, line in enumerate(exo_lines):
             centered_line = line.center(max_line_length)
-            start_x = (90 - max_line_length) // 2  # Calculate starting x position to center the text
+            start_x = (100 - max_line_length) // 2 + 15 # Center the text plus empirical adjustment of 15
             colored_line = Text(centered_line, style=yellow_style)
             for j, char in enumerate(str(colored_line)):
-                if 0 <= start_x + j < 90 and i < len(visualization):  # Ensure we don't exceed the width and height
+                if 0 <= start_x + j < 100 and i < len(visualization):
                     visualization[i][start_x + j] = char
 
         # Display chatgpt_api_endpoint and web_chat_url if set
@@ -63,18 +64,53 @@ class TopologyViz:
         if self.chatgpt_api_endpoint:
             info_lines.append(f"ChatGPT API endpoint: {self.chatgpt_api_endpoint}")
 
+        info_start_y = len(exo_lines) + 1
         for i, line in enumerate(info_lines):
-            start_x = 0
+            start_x = (100 - len(line)) // 2 + 15 # Center the info lines plus empirical adjustment of 15
             for j, char in enumerate(line):
-                if j < 90 and i + len(exo_lines) < 45:  # Ensure we don't exceed the width and height
-                    visualization[i + len(exo_lines)][j] = char
+                if 0 <= start_x + j < 100 and info_start_y + i < 55:
+                    visualization[info_start_y + i][start_x + j] = char
+
+        # Calculate total FLOPS and position on the bar
+        total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions)
+        bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
+        print(f"{bar_pos=}")
+
+        # Add GPU poor/rich bar
+        bar_width = 30  # Increased bar width
+        bar_start_x = (100 - bar_width) // 2  # Center the bar
+        bar_y = info_start_y + len(info_lines) + 4  # Position the bar below the info section with two cells of space
+        
+        # Create a gradient bar using emojis
+        gradient_bar = Text()
+        emojis = ['🟥', '🟧', '🟨', '🟩']  # Red, Orange, Yellow, Green
+        for i in range(bar_width):
+            emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
+            gradient_bar.append(emojis[emoji_index])
+
+        # Add the gradient bar to the visualization
+        visualization[bar_y][bar_start_x - 1] = '['
+        visualization[bar_y][bar_start_x + bar_width] = ']'
+        for i, segment in enumerate(str(gradient_bar)):
+            visualization[bar_y][bar_start_x + i] = segment
+        
+        # Add labels
+        visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = 'GPU poor'
+        visualization[bar_y - 1][bar_start_x + bar_width*2 + 2:bar_start_x + bar_width*2 + 11] = 'GPU rich'
+        
+        # Add position indicator and FLOPS value
+        pos_x = bar_start_x + int(bar_pos * bar_width)
+        flops_str = f"{total_flops:.2f} TFLOPS"
+        visualization[bar_y - 1][pos_x] = '▼'
+        visualization[bar_y + 1][pos_x - len(flops_str)//2:pos_x + len(flops_str)//2 + len(flops_str)%2] = flops_str
+        visualization[bar_y + 2][pos_x] = '▲'
 
         for i, partition in enumerate(self.partitions):
             device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
 
             angle = 2 * math.pi * i / num_partitions
-            x = int(center_x + radius * math.cos(angle))
-            y = int(center_y + radius * math.sin(angle))
+            x = int(center_x + radius_x * math.cos(angle))
+            y = int(center_y + radius_y * math.sin(angle))
 
             # Place node with different color for active node
             if partition.node_id == self.topology.active_node_id:
@@ -91,19 +127,27 @@ class TopologyViz:
             ]
 
             # Calculate info position based on angle
-            info_distance = radius + 3  # Reduced distance
-            info_x = int(center_x + info_distance * math.cos(angle))
-            info_y = int(center_y + info_distance * math.sin(angle))
+            info_distance_x = radius_x + 6  # Increased horizontal distance
+            info_distance_y = radius_y + 3  # Decreased vertical distance
+            info_x = int(center_x + info_distance_x * math.cos(angle))
+            info_y = int(center_y + info_distance_y * math.sin(angle))
 
-            # Adjust text position to avoid overwriting the node icon
+            # Adjust text position to avoid overwriting the node icon and prevent cutoff
             if info_x < x:  # Text is to the left of the node
                 info_x = max(0, x - len(max(node_info, key=len)) - 1)
             elif info_x > x:  # Text is to the right of the node
-                info_x = min(89 - len(max(node_info, key=len)), info_x)
+                info_x = min(99 - len(max(node_info, key=len)), info_x)
+            
+            # Adjust for top and bottom nodes
+            if 5*math.pi/4 < angle < 7*math.pi/4:  # Node is near the top
+                info_x += 4  # Shift text slightly to the right
+            elif math.pi/4 < angle < 3*math.pi/4:  # Node is near the bottom
+                info_x += 3  # Shift text slightly to the right
+                info_y -= 2  # Move text up by two cells
 
             for j, line in enumerate(node_info):
                 for k, char in enumerate(line):
-                    if 0 <= info_y + j < 45 and 0 <= info_x + k < 90:  # Updated width check
+                    if 0 <= info_y + j < 55 and 0 <= info_x + k < 100:  # Updated height check
                         # Ensure we're not overwriting the node icon
                         if info_y + j != y or info_x + k != x:
                             visualization[info_y + j][info_x + k] = char
@@ -111,16 +155,16 @@ class TopologyViz:
             # Draw line to next node
             next_i = (i + 1) % num_partitions
             next_angle = 2 * math.pi * next_i / num_partitions
-            next_x = int(center_x + radius * math.cos(next_angle))
-            next_y = int(center_y + radius * math.sin(next_angle))
+            next_x = int(center_x + radius_x * math.cos(next_angle))
+            next_y = int(center_y + radius_y * math.sin(next_angle))
 
             # Simple line drawing
             steps = max(abs(next_x - x), abs(next_y - y))
             for step in range(1, steps):
                 line_x = int(x + (next_x - x) * step / steps)
                 line_y = int(y + (next_y - y) * step / steps)
-                if 0 <= line_y < 45 and 0 <= line_x < 90:  # Updated width check
+                if 0 <= line_y < 55 and 0 <= line_x < 100:  # Updated height check
                     visualization[line_y][line_x] = '-'
 
         # Convert to string
-        return '\n'.join(''.join(str(char) for char in row) for row in visualization)
+        return '\n'.join(''.join(str(char) for char in row) for row in visualization)