Browse source

show prompts and outputs in tui

Alex Cheema 8 months ago
parent commit b95916e0b5
4 changed files with 102 additions and 23 deletions
  1. exo/api/chatgpt_api.py (+8 -1)
  2. exo/orchestration/standard_node.py (+14 -13)
  3. exo/viz/topology_viz.py (+71 -3)
  4. main.py (+9 -6)

exo/api/chatgpt_api.py (+8 -1)

@@ -14,6 +14,7 @@ from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import model_base_shards
+from typing import Callable

 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -155,10 +156,11 @@ class PromptSession:
     self.prompt = prompt

 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout_secs = response_timeout_secs
+    self.on_chat_completion_request = on_chat_completion_request
     self.app = web.Application(client_max_size=100 * 1024 * 1024)  # 100MB to support image upload
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
@@ -219,6 +221,11 @@ class ChatGPTAPI:
 
     prompt, image_str = build_prompt(tokenizer, chat_request.messages)
     request_id = str(uuid.uuid4())
+    if self.on_chat_completion_request:
+      try:
+        self.on_chat_completion_request(request_id, chat_request, prompt)
+      except Exception as e:
+        if DEBUG >= 2: traceback.print_exc()
     # request_id = None
     # match = self.prompts.find_longest_prefix(prompt)
     # if match and len(prompt) > len(match[1].prompt):
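The new on_chat_completion_request hook lets a caller observe every incoming chat completion request without coupling the API server to the TUI. A minimal sketch of wiring a custom callback (the log_request function is hypothetical, and node and inference_engine are assumed to already exist, as in main.py below):

from exo.api.chatgpt_api import ChatGPTAPI

# Hypothetical observer: called with (request_id, chat_request, prompt) for each
# incoming request. Exceptions raised here are swallowed by ChatGPTAPI (only
# printed when DEBUG >= 2), so a misbehaving callback cannot break request handling.
def log_request(request_id: str, chat_request, prompt: str) -> None:
  print(f"chat request {request_id}: {prompt[:80]!r}")

api = ChatGPTAPI(
  node,                                 # assumed: an already-constructed StandardNode
  inference_engine.__class__.__name__,  # assumed: the active inference engine
  response_timeout_secs=90,
  on_chat_completion_request=log_request,
)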

exo/orchestration/standard_node.py (+14 -13)

@@ -29,6 +29,7 @@ class StandardNode(Node):
     chatgpt_api_endpoints: List[str] = [],
     web_chat_urls: List[str] = [],
     disable_tui: Optional[bool] = False,
+    topology_viz: Optional[TopologyViz] = None,
   ):
     self.id = _id
     self.inference_engine = inference_engine
@@ -39,13 +40,25 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-    self.topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not disable_tui else None
     self.max_generate_tokens = max_generate_tokens
+    self.topology_viz = topology_viz
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}

+  async def start(self, wait_for_peers: int = 0) -> None:
+    await self.server.start()
+    await self.discovery.start()
+    await self.update_peers(wait_for_peers)
+    await self.collect_topology()
+    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
+    asyncio.create_task(self.periodic_topology_collection(5))
+
+  async def stop(self) -> None:
+    await self.discovery.stop()
+    await self.server.stop()
+
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
@@ -66,18 +79,6 @@ class StandardNode(Node):
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
       if DEBUG >= 1: traceback.print_exc()
       if DEBUG >= 1: traceback.print_exc()
 
 
-  async def start(self, wait_for_peers: int = 0) -> None:
-    await self.server.start()
-    await self.discovery.start()
-    await self.update_peers(wait_for_peers)
-    await self.collect_topology()
-    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
-    asyncio.create_task(self.periodic_topology_collection(5))
-
-  async def stop(self) -> None:
-    await self.discovery.stop()
-    await self.server.stop()
-
   async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
   async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
     asyncio.create_task(

exo/viz/topology_viz.py (+71 -3)

@@ -1,17 +1,20 @@
 import math
+from collections import OrderedDict
 from typing import List, Optional, Tuple, Dict
 from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
 from exo.download.hf.hf_helpers import RepoProgressEvent
-from rich.console import Console
-from rich.panel import Panel
+from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
+from rich.console import Console, Group
 from rich.text import Text
 from rich.live import Live
 from rich.style import Style
 from rich.table import Table
 from rich.layout import Layout
-from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
+from rich.syntax import Syntax
+from rich.panel import Panel
+from rich.markdown import Markdown

 class TopologyViz:
   def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []):
@@ -21,17 +24,24 @@ class TopologyViz:
     self.partitions: List[Partition] = []
     self.node_id = None
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
+    self.requests: OrderedDict[str, Tuple[str, str]] = {}

     self.console = Console()
     self.layout = Layout()
     self.layout.split(
       Layout(name="main"),
+      Layout(name="prompt_output", size=15),
       Layout(name="download", size=25)
     )
     self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green")
     self.download_panel = Panel("", title="Download Progress", border_style="cyan")
     self.layout["main"].update(self.main_panel)
+    self.layout["prompt_output"].update(self.prompt_output_panel)
     self.layout["download"].update(self.download_panel)
+
+    # Initially hide the prompt_output panel
+    self.layout["prompt_output"].visible = False
     self.live_panel = Live(self.layout, auto_refresh=False, console=self.console)
     self.live_panel.start()
 
@@ -43,12 +53,34 @@ class TopologyViz:
       self.node_download_progress = node_download_progress
     self.refresh()

+  def update_prompt(self, request_id: str, prompt: Optional[str] = None):
+    if request_id in self.requests:
+      self.requests[request_id] = [prompt, self.requests[request_id][1]]
+    else:
+      self.requests[request_id] = [prompt, ""]
+    self.refresh()
+
+  def update_prompt_output(self, request_id: str, output: Optional[str] = None):
+    if request_id in self.requests:
+      self.requests[request_id] = [self.requests[request_id][0], output]
+    else:
+      self.requests[request_id] = ["", output]
+    self.refresh()
+
   def refresh(self):
     self.main_panel.renderable = self._generate_main_layout()
     # Update the panel title with the number of nodes and partitions
     node_count = len(self.topology.nodes)
     self.main_panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"

+    # Update and show/hide prompt and output panel
+    if any(r[0] or r[1] for r in self.requests.values()):
+        self.prompt_output_panel = self._generate_prompt_output_layout()
+        self.layout["prompt_output"].update(self.prompt_output_panel)
+        self.layout["prompt_output"].visible = True
+    else:
+        self.layout["prompt_output"].visible = False
+
     # Only show download_panel if there are in-progress downloads
     if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
       self.download_panel.renderable = self._generate_download_layout()
@@ -58,6 +90,42 @@ class TopologyViz:
 
     self.live_panel.update(self.layout, refresh=True)

+  def _generate_prompt_output_layout(self) -> Panel:
+    content = []
+    requests = list(self.requests.values())[-3:]  # Get the 3 most recent requests
+    max_width = self.console.width - 6  # Full width minus padding and icon
+    max_lines = 13  # Maximum number of lines for the entire panel content
+
+    for (prompt, output) in reversed(requests):
+        prompt_icon, output_icon = "💬️", "🤖"
+
+        # Process prompt
+        prompt_lines = prompt.split('\n')
+        if len(prompt_lines) > max_lines // 2:
+            prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
+        prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
+        prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
+
+        # Process output
+        output_lines = output.split('\n')
+        remaining_lines = max_lines - len(prompt_lines) - 2  # -2 for spacing
+        if len(output_lines) > remaining_lines:
+            output_lines = output_lines[:remaining_lines - 1] + ['...']
+        output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
+        output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
+
+        content.append(prompt_text)
+        content.append(output_text)
+        content.append(Text())  # Empty line between entries
+
+    return Panel(
+        Group(*content),
+        title="",
+        border_style="cyan",
+        height=15,  # Increased height to accommodate multiple lines
+        expand=True  # Allow the panel to expand to full width
+    )
+
   def _generate_main_layout(self) -> str:
     # Calculate visualization parameters
     num_partitions = len(self.partitions)
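update_prompt and update_prompt_output are the only entry points the rest of the system needs: the first records the prompt for a request id, the second records the decoded output so far, and each call triggers refresh(), which shows the panel only while at least one request has a non-empty prompt or output. A minimal sketch of driving the panel directly, with illustrative values:

import uuid
from exo.viz.topology_viz import TopologyViz

viz = TopologyViz()  # endpoint lists are optional; the constructor starts the Live layout
request_id = str(uuid.uuid4())

# Record the prompt, then stream in output text as it gets decoded.
viz.update_prompt(request_id, "Who are you?")
viz.update_prompt_output(request_id, "I am an example response rendered in the new panel.")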

main.py (+9 -6)

@@ -17,6 +17,7 @@ from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.models import model_base_shards
+from exo.viz.topology_viz import TopologyViz
 import uuid

 # parse args
@@ -37,7 +38,7 @@ parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max t
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
-parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model")
+parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 args = parser.parse_args()
 args = parser.parse_args()
 
 
 print_yellow_exo()
 print_yellow_exo()
@@ -65,6 +66,7 @@ if DEBUG >= 0:
     print("ChatGPT API endpoint served at:")
     print("ChatGPT API endpoint served at:")
     for chatgpt_api_endpoint in chatgpt_api_endpoints:
     for chatgpt_api_endpoint in chatgpt_api_endpoints:
         print(f" - {terminal_link(chatgpt_api_endpoint)}")
         print(f" - {terminal_link(chatgpt_api_endpoint)}")
+topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
 node = StandardNode(
     args.node_id,
     args.node_id,
     None,
     None,
@@ -75,11 +77,14 @@ node = StandardNode(
     partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
     disable_tui=args.disable_tui,
     max_generate_tokens=args.max_generate_tokens,
+    topology_viz=topology_viz
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
-api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
-node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs, on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt))
+node.on_token.register("update_topology_viz").on_next(
+    lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens)
+)
 def preemptively_start_download(request_id: str, opaque_status: str):
     try:
         status = json.loads(opaque_status)
@@ -126,6 +131,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
     request_id = str(uuid.uuid4())
     callback_id = f"cli-wait-response-{request_id}"
     callback = node.on_token.register(callback_id)
+    topology_viz.update_prompt(request_id, prompt)
     prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)

     try:
@@ -158,9 +164,6 @@ async def main():
     await node.start(wait_for_peers=args.wait_for_peers)

     if args.run_model:
-        if not args.prompt:
-            print("Error: --prompt is required when using --run-model")
-            return
         await run_model_cli(node, inference_engine, args.run_model, args.prompt)
     else:
         asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task