Bläddra i källkod

Merge pull request #161 from exo-explore/cli

add support for running exo in cli with --run-model <model> --prompt <prompt>
Alex Cheema 11 månader sedan
förälder
incheckning
dfa3fdcf08

+ 17 - 76
exo/api/chatgpt_api.py

@@ -3,55 +3,18 @@ import time
 import asyncio
 import json
 from pathlib import Path
-from transformers import AutoTokenizer, AutoProcessor
+from transformers import AutoTokenizer
 from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
 from exo import DEBUG, VERSION
-from exo.helpers import terminal_link, PrefixDict
+from exo.helpers import PrefixDict
 from exo.inference.shard import Shard
+from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-
-shard_mappings = {
-  ### llama
-  "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "llama-3.1-405b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
-  },
-  "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
-  },
-  ### mistral
-  "mistral-nemo": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
-  },
-  "mistral-large": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
-  },
-  ### deepseek v2
-  "deepseek-coder-v2-lite": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
-  },
-  ### llava
-  "llava-1.5-7b-hf": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
-  },
-}
-
-
+from exo.models import model_base_shards
+from typing import Callable
 
 class Message:
     def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -64,7 +27,6 @@ class Message:
             "content": self.content
         }
 
-
 class ChatCompletionRequest:
     def __init__(self, model: str, messages: List[Message], temperature: float):
         self.model = model
@@ -78,33 +40,6 @@ class ChatCompletionRequest:
             "temperature": self.temperature
         }
 
-
-
-async def resolve_tokenizer(model_id: str):
-  try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
-    processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
-    if not hasattr(processor, 'eos_token_id'):
-      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
-    if not hasattr(processor, 'encode'):
-      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
-    if not hasattr(processor, 'decode'):
-      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
-    return processor
-  except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
-    if DEBUG >= 4: print(traceback.format_exc())
-
-  try:
-    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
-    return AutoTokenizer.from_pretrained(model_id)
-  except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
-    if DEBUG >= 4: print(traceback.format_exc())
-
-  raise ValueError(f"[TODO] Unsupported model: {model_id}")
-
-
 def generate_completion(
   chat_request: ChatCompletionRequest,
   tokenizer,
@@ -221,10 +156,11 @@ class PromptSession:
     self.prompt = prompt
 
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout_secs = response_timeout_secs
+    self.on_chat_completion_request = on_chat_completion_request
     self.app = web.Application(client_max_size=100 * 1024 * 1024)  # 100MB to support image upload
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
@@ -257,7 +193,7 @@ class ChatGPTAPI:
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
-    shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
+    shard = model_base_shards.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(shard.model_id)
     return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
@@ -269,12 +205,12 @@ class ChatGPTAPI:
     chat_request = parse_chat_request(data)
     if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
       chat_request.model = "llama-3.1-8b"
-    if not chat_request.model or chat_request.model not in shard_mappings:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
+    if not chat_request.model or chat_request.model not in model_base_shards:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to llama-3.1-8b")
       chat_request.model = "llama-3.1-8b"
-    shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
+    shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None)
     if not shard:
-      supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
+      supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines]
       return web.json_response(
         {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         status=400,
@@ -285,6 +221,11 @@ class ChatGPTAPI:
 
     prompt, image_str = build_prompt(tokenizer, chat_request.messages)
     request_id = str(uuid.uuid4())
+    if self.on_chat_completion_request:
+      try:
+        self.on_chat_completion_request(request_id, chat_request, prompt)
+      except Exception as e:
+        if DEBUG >= 2: traceback.print_exc()
     # request_id = None
     # match = self.prompts.find_longest_prefix(prompt)
     # if match and len(prompt) > len(match[1].prompt):

+ 0 - 15
exo/helpers.py

@@ -32,21 +32,6 @@ def get_system_info():
   return "Non-Mac, non-Linux system"
 
 
-def get_inference_engine(inference_engine_name, shard_downloader: 'ShardDownloader'):
-  if inference_engine_name == "mlx":
-    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-
-    return MLXDynamicShardInferenceEngine(shard_downloader)
-  elif inference_engine_name == "tinygrad":
-    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-    import tinygrad.helpers
-    tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-
-    return TinygradDynamicShardInferenceEngine(shard_downloader)
-  else:
-    raise ValueError(f"Inference engine {inference_engine_name} not supported")
-
-
 def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
   used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".exo_used_ports")
 

+ 16 - 0
exo/inference/inference_engine.py

@@ -1,4 +1,5 @@
 import numpy as np
+import os
 
 from typing import Tuple, Optional
 from abc import ABC, abstractmethod
@@ -12,3 +13,18 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
+
+
+def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+  if inference_engine_name == "mlx":
+    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+
+    return MLXDynamicShardInferenceEngine(shard_downloader)
+  elif inference_engine_name == "tinygrad":
+    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+    import tinygrad.helpers
+    tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
+
+    return TinygradDynamicShardInferenceEngine(shard_downloader)
+  else:
+    raise ValueError(f"Inference engine {inference_engine_name} not supported")

+ 27 - 0
exo/inference/tokenizers.py

@@ -0,0 +1,27 @@
+import traceback
+from transformers import AutoTokenizer, AutoProcessor
+from exo.helpers import DEBUG
+
+async def resolve_tokenizer(model_id: str):
+  try:
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
+    processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
+    if not hasattr(processor, 'eos_token_id'):
+      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
+    if not hasattr(processor, 'encode'):
+      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
+    if not hasattr(processor, 'decode'):
+      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
+    return processor
+  except Exception as e:
+    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())
+
+  try:
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
+    return AutoTokenizer.from_pretrained(model_id)
+  except Exception as e:
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())
+
+  raise ValueError(f"[TODO] Unsupported model: {model_id}")

+ 40 - 0
exo/models.py

@@ -0,0 +1,40 @@
+from exo.inference.shard import Shard
+
+model_base_shards = {
+  ### llama
+  "llama-3.1-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3.1-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
+  },
+  "llama-3.1-405b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
+  },
+  "llama-3-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
+  },
+  ### mistral
+  "mistral-nemo": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
+  },
+  "mistral-large": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
+  },
+  ### deepseek v2
+  "deepseek-coder-v2-lite": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
+  },
+  ### llava
+  "llava-1.5-7b-hf": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
+  },
+}
+

+ 14 - 13
exo/orchestration/standard_node.py

@@ -29,6 +29,7 @@ class StandardNode(Node):
     chatgpt_api_endpoints: List[str] = [],
     web_chat_urls: List[str] = [],
     disable_tui: Optional[bool] = False,
+    topology_viz: Optional[TopologyViz] = None,
   ):
     self.id = _id
     self.inference_engine = inference_engine
@@ -39,13 +40,25 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-    self.topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not disable_tui else None
     self.max_generate_tokens = max_generate_tokens
+    self.topology_viz = topology_viz
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
 
+  async def start(self, wait_for_peers: int = 0) -> None:
+    await self.server.start()
+    await self.discovery.start()
+    await self.update_peers(wait_for_peers)
+    await self.collect_topology()
+    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
+    asyncio.create_task(self.periodic_topology_collection(5))
+
+  async def stop(self) -> None:
+    await self.discovery.stop()
+    await self.server.stop()
+
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
@@ -66,18 +79,6 @@ class StandardNode(Node):
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
       if DEBUG >= 1: traceback.print_exc()
 
-  async def start(self, wait_for_peers: int = 0) -> None:
-    await self.server.start()
-    await self.discovery.start()
-    await self.update_peers(wait_for_peers)
-    await self.collect_topology()
-    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
-    asyncio.create_task(self.periodic_topology_collection(5))
-
-  async def stop(self) -> None:
-    await self.discovery.stop()
-    await self.server.stop()
-
   async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(

+ 71 - 3
exo/viz/topology_viz.py

@@ -1,17 +1,20 @@
 import math
+from collections import OrderedDict
 from typing import List, Optional, Tuple, Dict
 from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
 from exo.download.hf.hf_helpers import RepoProgressEvent
-from rich.console import Console
-from rich.panel import Panel
+from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
+from rich.console import Console, Group
 from rich.text import Text
 from rich.live import Live
 from rich.style import Style
 from rich.table import Table
 from rich.layout import Layout
-from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
+from rich.syntax import Syntax
+from rich.panel import Panel
+from rich.markdown import Markdown
 
 class TopologyViz:
   def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []):
@@ -21,17 +24,24 @@ class TopologyViz:
     self.partitions: List[Partition] = []
     self.node_id = None
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
+    self.requests: OrderedDict[str, Tuple[str, str]] = {}
 
     self.console = Console()
     self.layout = Layout()
     self.layout.split(
       Layout(name="main"),
+      Layout(name="prompt_output", size=15),
       Layout(name="download", size=25)
     )
     self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green")
     self.download_panel = Panel("", title="Download Progress", border_style="cyan")
     self.layout["main"].update(self.main_panel)
+    self.layout["prompt_output"].update(self.prompt_output_panel)
     self.layout["download"].update(self.download_panel)
+
+    # Initially hide the prompt_output panel
+    self.layout["prompt_output"].visible = False
     self.live_panel = Live(self.layout, auto_refresh=False, console=self.console)
     self.live_panel.start()
 
@@ -43,12 +53,34 @@ class TopologyViz:
       self.node_download_progress = node_download_progress
     self.refresh()
 
+  def update_prompt(self, request_id: str, prompt: Optional[str] = None):
+    if request_id in self.requests:
+      self.requests[request_id] = [prompt, self.requests[request_id][1]]
+    else:
+      self.requests[request_id] = [prompt, ""]
+    self.refresh()
+
+  def update_prompt_output(self, request_id: str, output: Optional[str] = None):
+    if request_id in self.requests:
+      self.requests[request_id] = [self.requests[request_id][0], output]
+    else:
+      self.requests[request_id] = ["", output]
+    self.refresh()
+
   def refresh(self):
     self.main_panel.renderable = self._generate_main_layout()
     # Update the panel title with the number of nodes and partitions
     node_count = len(self.topology.nodes)
     self.main_panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
 
+    # Update and show/hide prompt and output panel
+    if any(r[0] or r[1] for r in self.requests.values()):
+        self.prompt_output_panel = self._generate_prompt_output_layout()
+        self.layout["prompt_output"].update(self.prompt_output_panel)
+        self.layout["prompt_output"].visible = True
+    else:
+        self.layout["prompt_output"].visible = False
+
     # Only show download_panel if there are in-progress downloads
     if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
       self.download_panel.renderable = self._generate_download_layout()
@@ -58,6 +90,42 @@ class TopologyViz:
 
     self.live_panel.update(self.layout, refresh=True)
 
+  def _generate_prompt_output_layout(self) -> Panel:
+    content = []
+    requests = list(self.requests.values())[-3:]  # Get the 3 most recent requests
+    max_width = self.console.width - 6  # Full width minus padding and icon
+    max_lines = 13  # Maximum number of lines for the entire panel content
+
+    for (prompt, output) in reversed(requests):
+        prompt_icon, output_icon = "💬️", "🤖"
+
+        # Process prompt
+        prompt_lines = prompt.split('\n')
+        if len(prompt_lines) > max_lines // 2:
+            prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
+        prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
+        prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
+
+        # Process output
+        output_lines = output.split('\n')
+        remaining_lines = max_lines - len(prompt_lines) - 2  # -2 for spacing
+        if len(output_lines) > remaining_lines:
+            output_lines = output_lines[:remaining_lines - 1] + ['...']
+        output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
+        output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
+
+        content.append(prompt_text)
+        content.append(output_text)
+        content.append(Text())  # Empty line between entries
+
+    return Panel(
+        Group(*content),
+        title="",
+        border_style="cyan",
+        height=15,  # Increased height to accommodate multiple lines
+        expand=True  # Allow the panel to expand to full width
+    )
+
   def _generate_main_layout(self) -> str:
     # Calculate visualization parameters
     num_partitions = len(self.partitions)

+ 52 - 6
main.py

@@ -4,6 +4,8 @@ import signal
 import json
 import time
 import traceback
+import uuid
+from asyncio import CancelledError
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
@@ -11,8 +13,13 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
 from exo.inference.shard import Shard
+from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.inference.tokenizers import resolve_tokenizer
+from exo.orchestration.node import Node
+from exo.models import model_base_shards
+from exo.viz.topology_viz import TopologyViz
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -31,6 +38,8 @@ parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90,
 parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
+parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
+parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 args = parser.parse_args()
 
 print_yellow_exo()
@@ -58,6 +67,7 @@ if DEBUG >= 0:
     print("ChatGPT API endpoint served at:")
     for chatgpt_api_endpoint in chatgpt_api_endpoints:
         print(f" - {terminal_link(chatgpt_api_endpoint)}")
+topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
     args.node_id,
     None,
@@ -68,11 +78,14 @@ node = StandardNode(
     partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
     disable_tui=args.disable_tui,
     max_generate_tokens=args.max_generate_tokens,
+    topology_viz=topology_viz
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
-api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
-node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs, on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None)
+node.on_token.register("update_topology_viz").on_next(
+    lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens) if topology_viz else None
+)
 def preemptively_start_download(request_id: str, opaque_status: str):
     try:
         status = json.loads(opaque_status)
@@ -110,6 +123,36 @@ async def shutdown(signal, loop):
     await server.stop()
     loop.stop()
 
+async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
+    shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
+    if not shard:
+        print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+        return
+    tokenizer = await resolve_tokenizer(shard.model_id)
+    request_id = str(uuid.uuid4())
+    callback_id = f"cli-wait-response-{request_id}"
+    callback = node.on_token.register(callback_id)
+    if topology_viz:
+        topology_viz.update_prompt(request_id, prompt)
+    prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
+
+    try:
+        print(f"Processing prompt: {prompt}")
+        await node.process_prompt(shard, prompt, None, request_id=request_id)
+
+        _, tokens, _ = await callback.wait(
+            lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
+            timeout=90
+        )
+
+        print("\nGenerated response:")
+        print(tokenizer.decode(tokens))
+    except Exception as e:
+        print(f"Error processing prompt: {str(e)}")
+        traceback.print_exc()
+    finally:
+        node.on_token.deregister(callback_id)
+
 async def main():
     loop = asyncio.get_running_loop()
 
@@ -121,9 +164,12 @@ async def main():
         loop.add_signal_handler(s, handle_exit)
 
     await node.start(wait_for_peers=args.wait_for_peers)
-    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
 
-    await asyncio.Event().wait()
+    if args.run_model:
+        await run_model_cli(node, inference_engine, args.run_model, args.prompt)
+    else:
+        asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
+        await asyncio.Event().wait()
 
 if __name__ == "__main__":
     loop = asyncio.new_event_loop()
@@ -134,4 +180,4 @@ if __name__ == "__main__":
         print("Received keyboard interrupt. Shutting down...")
     finally:
         loop.run_until_complete(shutdown(signal.SIGTERM, loop))
-        loop.close()
+        loop.close()