|
@@ -1,17 +1,20 @@
|
|
|
import math
|
|
|
+from collections import OrderedDict
|
|
|
from typing import List, Optional, Tuple, Dict
|
|
|
from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
|
|
|
from exo.topology.topology import Topology
|
|
|
from exo.topology.partitioning_strategy import Partition
|
|
|
from exo.download.hf.hf_helpers import RepoProgressEvent
|
|
|
-from rich.console import Console
|
|
|
-from rich.panel import Panel
|
|
|
+from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
|
|
|
+from rich.console import Console, Group
|
|
|
from rich.text import Text
|
|
|
from rich.live import Live
|
|
|
from rich.style import Style
|
|
|
from rich.table import Table
|
|
|
from rich.layout import Layout
|
|
|
-from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
|
|
|
+from rich.syntax import Syntax
|
|
|
+from rich.panel import Panel
|
|
|
+from rich.markdown import Markdown
|
|
|
|
|
|
class TopologyViz:
|
|
|
def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []):
|
|
@@ -21,17 +24,24 @@ class TopologyViz:
|
|
|
self.partitions: List[Partition] = []
|
|
|
self.node_id = None
|
|
|
self.node_download_progress: Dict[str, RepoProgressEvent] = {}
|
|
|
+ self.requests: OrderedDict[str, Tuple[str, str]] = {}
|
|
|
|
|
|
self.console = Console()
|
|
|
self.layout = Layout()
|
|
|
self.layout.split(
|
|
|
Layout(name="main"),
|
|
|
+ Layout(name="prompt_output", size=15),
|
|
|
Layout(name="download", size=25)
|
|
|
)
|
|
|
self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
|
|
|
+ self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green")
|
|
|
self.download_panel = Panel("", title="Download Progress", border_style="cyan")
|
|
|
self.layout["main"].update(self.main_panel)
|
|
|
+ self.layout["prompt_output"].update(self.prompt_output_panel)
|
|
|
self.layout["download"].update(self.download_panel)
|
|
|
+
|
|
|
+ # Initially hide the prompt_output panel
|
|
|
+ self.layout["prompt_output"].visible = False
|
|
|
self.live_panel = Live(self.layout, auto_refresh=False, console=self.console)
|
|
|
self.live_panel.start()
|
|
|
|
|
@@ -43,12 +53,34 @@ class TopologyViz:
|
|
|
self.node_download_progress = node_download_progress
|
|
|
self.refresh()
|
|
|
|
|
|
+ def update_prompt(self, request_id: str, prompt: Optional[str] = None):
|
|
|
+ if request_id in self.requests:
|
|
|
+ self.requests[request_id] = [prompt, self.requests[request_id][1]]
|
|
|
+ else:
|
|
|
+ self.requests[request_id] = [prompt, ""]
|
|
|
+ self.refresh()
|
|
|
+
|
|
|
+ def update_prompt_output(self, request_id: str, output: Optional[str] = None):
|
|
|
+ if request_id in self.requests:
|
|
|
+ self.requests[request_id] = [self.requests[request_id][0], output]
|
|
|
+ else:
|
|
|
+ self.requests[request_id] = ["", output]
|
|
|
+ self.refresh()
|
|
|
+
|
|
|
def refresh(self):
|
|
|
self.main_panel.renderable = self._generate_main_layout()
|
|
|
# Update the panel title with the number of nodes and partitions
|
|
|
node_count = len(self.topology.nodes)
|
|
|
self.main_panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
|
|
|
|
|
|
+ # Update and show/hide prompt and output panel
|
|
|
+ if any(r[0] or r[1] for r in self.requests.values()):
|
|
|
+ self.prompt_output_panel = self._generate_prompt_output_layout()
|
|
|
+ self.layout["prompt_output"].update(self.prompt_output_panel)
|
|
|
+ self.layout["prompt_output"].visible = True
|
|
|
+ else:
|
|
|
+ self.layout["prompt_output"].visible = False
|
|
|
+
|
|
|
# Only show download_panel if there are in-progress downloads
|
|
|
if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
|
|
|
self.download_panel.renderable = self._generate_download_layout()
|
|
@@ -58,6 +90,42 @@ class TopologyViz:
|
|
|
|
|
|
self.live_panel.update(self.layout, refresh=True)
|
|
|
|
|
|
+ def _generate_prompt_output_layout(self) -> Panel:
|
|
|
+ content = []
|
|
|
+ requests = list(self.requests.values())[-3:] # Get the 3 most recent requests
|
|
|
+ max_width = self.console.width - 6 # Full width minus padding and icon
|
|
|
+ max_lines = 13 # Maximum number of lines for the entire panel content
|
|
|
+
|
|
|
+ for (prompt, output) in reversed(requests):
|
|
|
+ prompt_icon, output_icon = "💬️", "🤖"
|
|
|
+
|
|
|
+ # Process prompt
|
|
|
+ prompt_lines = prompt.split('\n')
|
|
|
+ if len(prompt_lines) > max_lines // 2:
|
|
|
+ prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
|
|
|
+ prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
|
|
|
+ prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
|
|
|
+
|
|
|
+ # Process output
|
|
|
+ output_lines = output.split('\n')
|
|
|
+ remaining_lines = max_lines - len(prompt_lines) - 2 # -2 for spacing
|
|
|
+ if len(output_lines) > remaining_lines:
|
|
|
+ output_lines = output_lines[:remaining_lines - 1] + ['...']
|
|
|
+ output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
|
|
|
+ output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
|
|
|
+
|
|
|
+ content.append(prompt_text)
|
|
|
+ content.append(output_text)
|
|
|
+ content.append(Text()) # Empty line between entries
|
|
|
+
|
|
|
+ return Panel(
|
|
|
+ Group(*content),
|
|
|
+ title="",
|
|
|
+ border_style="cyan",
|
|
|
+ height=15, # Increased height to accommodate multiple lines
|
|
|
+ expand=True # Allow the panel to expand to full width
|
|
|
+ )
|
|
|
+
|
|
|
def _generate_main_layout(self) -> str:
|
|
|
# Calculate visualization parameters
|
|
|
num_partitions = len(self.partitions)
|