Browse Source

show prompts and outputs in tui

Alex Cheema 8 tháng trước cách đây
mục cha
commit
b95916e0b5
4 tập tin đã thay đổi với 102 bổ sung23 xóa
  1. 8 1
      exo/api/chatgpt_api.py
  2. 14 13
      exo/orchestration/standard_node.py
  3. 71 3
      exo/viz/topology_viz.py
  4. 9 6
      main.py

+ 8 - 1
exo/api/chatgpt_api.py

@@ -14,6 +14,7 @@ from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import model_base_shards
+from typing import Callable
 
 class Message:
     def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -155,10 +156,11 @@ class PromptSession:
     self.prompt = prompt
 
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout_secs = response_timeout_secs
+    self.on_chat_completion_request = on_chat_completion_request
     self.app = web.Application(client_max_size=100 * 1024 * 1024)  # 100MB to support image upload
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
@@ -219,6 +221,11 @@ class ChatGPTAPI:
 
     prompt, image_str = build_prompt(tokenizer, chat_request.messages)
     request_id = str(uuid.uuid4())
+    if self.on_chat_completion_request:
+      try:
+        self.on_chat_completion_request(request_id, chat_request, prompt)
+      except Exception as e:
+        if DEBUG >= 2: traceback.print_exc()
     # request_id = None
     # match = self.prompts.find_longest_prefix(prompt)
     # if match and len(prompt) > len(match[1].prompt):

+ 14 - 13
exo/orchestration/standard_node.py

@@ -29,6 +29,7 @@ class StandardNode(Node):
     chatgpt_api_endpoints: List[str] = [],
     web_chat_urls: List[str] = [],
     disable_tui: Optional[bool] = False,
+    topology_viz: Optional[TopologyViz] = None,
   ):
     self.id = _id
     self.inference_engine = inference_engine
@@ -39,13 +40,25 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-    self.topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not disable_tui else None
     self.max_generate_tokens = max_generate_tokens
+    self.topology_viz = topology_viz
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
 
+  async def start(self, wait_for_peers: int = 0) -> None:
+    await self.server.start()
+    await self.discovery.start()
+    await self.update_peers(wait_for_peers)
+    await self.collect_topology()
+    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
+    asyncio.create_task(self.periodic_topology_collection(5))
+
+  async def stop(self) -> None:
+    await self.discovery.stop()
+    await self.server.stop()
+
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
@@ -66,18 +79,6 @@ class StandardNode(Node):
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
       if DEBUG >= 1: traceback.print_exc()
 
-  async def start(self, wait_for_peers: int = 0) -> None:
-    await self.server.start()
-    await self.discovery.start()
-    await self.update_peers(wait_for_peers)
-    await self.collect_topology()
-    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
-    asyncio.create_task(self.periodic_topology_collection(5))
-
-  async def stop(self) -> None:
-    await self.discovery.stop()
-    await self.server.stop()
-
   async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(

+ 71 - 3
exo/viz/topology_viz.py

@@ -1,17 +1,20 @@
 import math
+from collections import OrderedDict
 from typing import List, Optional, Tuple, Dict
 from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
 from exo.download.hf.hf_helpers import RepoProgressEvent
-from rich.console import Console
-from rich.panel import Panel
+from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
+from rich.console import Console, Group
 from rich.text import Text
 from rich.live import Live
 from rich.style import Style
 from rich.table import Table
 from rich.layout import Layout
-from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
+from rich.syntax import Syntax
+from rich.panel import Panel
+from rich.markdown import Markdown
 
 class TopologyViz:
   def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []):
@@ -21,17 +24,24 @@ class TopologyViz:
     self.partitions: List[Partition] = []
     self.node_id = None
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
+    self.requests: OrderedDict[str, Tuple[str, str]] = {}
 
     self.console = Console()
     self.layout = Layout()
     self.layout.split(
       Layout(name="main"),
+      Layout(name="prompt_output", size=15),
       Layout(name="download", size=25)
     )
     self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green")
     self.download_panel = Panel("", title="Download Progress", border_style="cyan")
     self.layout["main"].update(self.main_panel)
+    self.layout["prompt_output"].update(self.prompt_output_panel)
     self.layout["download"].update(self.download_panel)
+
+    # Initially hide the prompt_output panel
+    self.layout["prompt_output"].visible = False
     self.live_panel = Live(self.layout, auto_refresh=False, console=self.console)
     self.live_panel.start()
 
@@ -43,12 +53,34 @@ class TopologyViz:
       self.node_download_progress = node_download_progress
     self.refresh()
 
+  def update_prompt(self, request_id: str, prompt: Optional[str] = None):
+    if request_id in self.requests:
+      self.requests[request_id] = [prompt, self.requests[request_id][1]]
+    else:
+      self.requests[request_id] = [prompt, ""]
+    self.refresh()
+
+  def update_prompt_output(self, request_id: str, output: Optional[str] = None):
+    if request_id in self.requests:
+      self.requests[request_id] = [self.requests[request_id][0], output]
+    else:
+      self.requests[request_id] = ["", output]
+    self.refresh()
+
   def refresh(self):
     self.main_panel.renderable = self._generate_main_layout()
     # Update the panel title with the number of nodes and partitions
     node_count = len(self.topology.nodes)
     self.main_panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
 
+    # Update and show/hide prompt and output panel
+    if any(r[0] or r[1] for r in self.requests.values()):
+        self.prompt_output_panel = self._generate_prompt_output_layout()
+        self.layout["prompt_output"].update(self.prompt_output_panel)
+        self.layout["prompt_output"].visible = True
+    else:
+        self.layout["prompt_output"].visible = False
+
     # Only show download_panel if there are in-progress downloads
     if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
       self.download_panel.renderable = self._generate_download_layout()
@@ -58,6 +90,42 @@ class TopologyViz:
 
     self.live_panel.update(self.layout, refresh=True)
 
+  def _generate_prompt_output_layout(self) -> Panel:
+    content = []
+    requests = list(self.requests.values())[-3:]  # Get the 3 most recent requests
+    max_width = self.console.width - 6  # Full width minus padding and icon
+    max_lines = 13  # Maximum number of lines for the entire panel content
+
+    for (prompt, output) in reversed(requests):
+        prompt_icon, output_icon = "💬️", "🤖"
+
+        # Process prompt
+        prompt_lines = prompt.split('\n')
+        if len(prompt_lines) > max_lines // 2:
+            prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
+        prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
+        prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
+
+        # Process output
+        output_lines = output.split('\n')
+        remaining_lines = max_lines - len(prompt_lines) - 2  # -2 for spacing
+        if len(output_lines) > remaining_lines:
+            output_lines = output_lines[:remaining_lines - 1] + ['...']
+        output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
+        output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
+
+        content.append(prompt_text)
+        content.append(output_text)
+        content.append(Text())  # Empty line between entries
+
+    return Panel(
+        Group(*content),
+        title="",
+        border_style="cyan",
+        height=15,  # Increased height to accommodate multiple lines
+        expand=True  # Allow the panel to expand to full width
+    )
+
   def _generate_main_layout(self) -> str:
     # Calculate visualization parameters
     num_partitions = len(self.partitions)

+ 9 - 6
main.py

@@ -17,6 +17,7 @@ from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.models import model_base_shards
+from exo.viz.topology_viz import TopologyViz
 import uuid
 
 # parse args
@@ -37,7 +38,7 @@ parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max t
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
-parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model")
+parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 args = parser.parse_args()
 
 print_yellow_exo()
@@ -65,6 +66,7 @@ if DEBUG >= 0:
     print("ChatGPT API endpoint served at:")
     for chatgpt_api_endpoint in chatgpt_api_endpoints:
         print(f" - {terminal_link(chatgpt_api_endpoint)}")
+topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
     args.node_id,
     None,
@@ -75,11 +77,14 @@ node = StandardNode(
     partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
     disable_tui=args.disable_tui,
     max_generate_tokens=args.max_generate_tokens,
+    topology_viz=topology_viz
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
-api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
-node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs, on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt))
+node.on_token.register("update_topology_viz").on_next(
+    lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens)
+)
 def preemptively_start_download(request_id: str, opaque_status: str):
     try:
         status = json.loads(opaque_status)
@@ -126,6 +131,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
     request_id = str(uuid.uuid4())
     callback_id = f"cli-wait-response-{request_id}"
     callback = node.on_token.register(callback_id)
+    topology_viz.update_prompt(request_id, prompt)
     prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
 
     try:
@@ -158,9 +164,6 @@ async def main():
     await node.start(wait_for_peers=args.wait_for_peers)
 
     if args.run_model:
-        if not args.prompt:
-            print("Error: --prompt is required when using --run-model")
-            return
         await run_model_cli(node, inference_engine, args.run_model, args.prompt)
     else:
         asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task