# topology_viz.py

import math
from typing import List, Optional

from rich.color import Color
from rich.console import Console
from rich.live import Live
from rich.panel import Panel
from rich.style import Style
from rich.text import Text

from exo.helpers import exo_text
from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
from exo.topology.partitioning_strategy import Partition
from exo.topology.topology import Topology
  13. class TopologyViz:
  14. def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
  15. self.chatgpt_api_endpoint = chatgpt_api_endpoint
  16. self.web_chat_url = web_chat_url
  17. self.topology = Topology()
  18. self.partitions: List[Partition] = []
  19. self.console = Console()
  20. self.panel = Panel(self._generate_layout(), title=f"Exo Cluster (0 nodes)", border_style="bright_yellow")
  21. self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
  22. self.live_panel.start()
  23. def update_visualization(self, topology: Topology, partitions: List[Partition]):
  24. self.topology = topology
  25. self.partitions = partitions
  26. self.refresh()
  27. def refresh(self):
  28. self.panel.renderable = self._generate_layout()
  29. # Update the panel title with the number of nodes and partitions
  30. node_count = len(self.topology.nodes)
  31. self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
  32. self.live_panel.update(self.panel, refresh=True)
  33. def _generate_layout(self) -> str:
  34. # Calculate visualization parameters
  35. num_partitions = len(self.partitions)
  36. radius_x = 30 # Increased horizontal radius
  37. radius_y = 12 # Decreased vertical radius
  38. center_x, center_y = 50, 28 # Centered horizontally and moved up slightly
  39. # Generate visualization
  40. visualization = [[' ' for _ in range(100)] for _ in range(55)] # Decreased height
  41. # Add exo_text at the top in bright yellow
  42. exo_lines = exo_text.split('\n')
  43. yellow_style = Style(color="bright_yellow")
  44. max_line_length = max(len(line) for line in exo_lines)
  45. for i, line in enumerate(exo_lines):
  46. centered_line = line.center(max_line_length)
  47. start_x = (100 - max_line_length) // 2 + 15 # Center the text plus empirical adjustment of 15
  48. colored_line = Text(centered_line, style=yellow_style)
  49. for j, char in enumerate(str(colored_line)):
  50. if 0 <= start_x + j < 100 and i < len(visualization):
  51. visualization[i][start_x + j] = char
  52. # Display chatgpt_api_endpoint and web_chat_url if set
  53. info_lines = []
  54. if self.web_chat_url:
  55. info_lines.append(f"Web Chat URL (tinychat): {self.web_chat_url}")
  56. if self.chatgpt_api_endpoint:
  57. info_lines.append(f"ChatGPT API endpoint: {self.chatgpt_api_endpoint}")
  58. info_start_y = len(exo_lines) + 1
  59. for i, line in enumerate(info_lines):
  60. start_x = (100 - len(line)) // 2 + 15 # Center the info lines plus empirical adjustment of 15
  61. for j, char in enumerate(line):
  62. if 0 <= start_x + j < 100 and info_start_y + i < 55:
  63. visualization[info_start_y + i][start_x + j] = char
  64. # Calculate total FLOPS and position on the bar
  65. total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions)
  66. bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
  67. print(f"{bar_pos=}")
  68. # Add GPU poor/rich bar
  69. bar_width = 30 # Increased bar width
  70. bar_start_x = (100 - bar_width) // 2 # Center the bar
  71. bar_y = info_start_y + len(info_lines) + 4 # Position the bar below the info section with two cells of space
  72. # Create a gradient bar using emojis
  73. gradient_bar = Text()
  74. emojis = ['🟥', '🟧', '🟨', '🟩'] # Red, Orange, Yellow, Green
  75. for i in range(bar_width):
  76. emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
  77. gradient_bar.append(emojis[emoji_index])
  78. # Add the gradient bar to the visualization
  79. visualization[bar_y][bar_start_x - 1] = '['
  80. visualization[bar_y][bar_start_x + bar_width] = ']'
  81. for i, segment in enumerate(str(gradient_bar)):
  82. visualization[bar_y][bar_start_x + i] = segment
  83. # Add labels
  84. visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = 'GPU poor'
  85. visualization[bar_y - 1][bar_start_x + bar_width*2 + 2:bar_start_x + bar_width*2 + 11] = 'GPU rich'
  86. # Add position indicator and FLOPS value
  87. pos_x = bar_start_x + int(bar_pos * bar_width)
  88. flops_str = f"{total_flops:.2f} TFLOPS"
  89. visualization[bar_y - 1][pos_x] = '▼'
  90. visualization[bar_y + 1][pos_x - len(flops_str)//2:pos_x + len(flops_str)//2 + len(flops_str)%2] = flops_str
  91. visualization[bar_y + 2][pos_x] = '▲'
  92. for i, partition in enumerate(self.partitions):
  93. device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
  94. angle = 2 * math.pi * i / num_partitions
  95. x = int(center_x + radius_x * math.cos(angle))
  96. y = int(center_y + radius_y * math.sin(angle))
  97. # Place node with different color for active node
  98. if partition.node_id == self.topology.active_node_id:
  99. visualization[y][x] = '🔴' # Red circle for active node
  100. else:
  101. visualization[y][x] = '🔵' # Blue circle for inactive nodes
  102. # Place node info (ID, start_layer, end_layer)
  103. node_info = [
  104. f"Model: {device_capabilities.model}",
  105. f"Mem: {device_capabilities.memory // 1024}GB",
  106. f"FLOPS: {device_capabilities.flops.fp16}T",
  107. f"Part: {partition.start:.2f}-{partition.end:.2f}"
  108. ]
  109. # Calculate info position based on angle
  110. info_distance_x = radius_x + 6 # Increased horizontal distance
  111. info_distance_y = radius_y + 3 # Decreased vertical distance
  112. info_x = int(center_x + info_distance_x * math.cos(angle))
  113. info_y = int(center_y + info_distance_y * math.sin(angle))
  114. # Adjust text position to avoid overwriting the node icon and prevent cutoff
  115. if info_x < x: # Text is to the left of the node
  116. info_x = max(0, x - len(max(node_info, key=len)) - 1)
  117. elif info_x > x: # Text is to the right of the node
  118. info_x = min(99 - len(max(node_info, key=len)), info_x)
  119. # Adjust for top and bottom nodes
  120. if 5*math.pi/4 < angle < 7*math.pi/4: # Node is near the top
  121. info_x += 4 # Shift text slightly to the right
  122. elif math.pi/4 < angle < 3*math.pi/4: # Node is near the bottom
  123. info_x += 3 # Shift text slightly to the right
  124. info_y -= 2 # Move text up by two cells
  125. for j, line in enumerate(node_info):
  126. for k, char in enumerate(line):
  127. if 0 <= info_y + j < 55 and 0 <= info_x + k < 100: # Updated height check
  128. # Ensure we're not overwriting the node icon
  129. if info_y + j != y or info_x + k != x:
  130. visualization[info_y + j][info_x + k] = char
  131. # Draw line to next node
  132. next_i = (i + 1) % num_partitions
  133. next_angle = 2 * math.pi * next_i / num_partitions
  134. next_x = int(center_x + radius_x * math.cos(next_angle))
  135. next_y = int(center_y + radius_y * math.sin(next_angle))
  136. # Simple line drawing
  137. steps = max(abs(next_x - x), abs(next_y - y))
  138. for step in range(1, steps):
  139. line_x = int(x + (next_x - x) * step / steps)
  140. line_y = int(y + (next_y - y) * step / steps)
  141. if 0 <= line_y < 55 and 0 <= line_x < 100: # Updated height check
  142. visualization[line_y][line_x] = '-'
  143. # Convert to string
  144. return '\n'.join(''.join(str(char) for char in row) for row in visualization)