
Merge pull request #161 from exo-explore/cli

Add support for running exo in the CLI with --run-model <model> --prompt <prompt>
Alex Cheema, 11 months ago
commit dfa3fdcf08
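Example of the new CLI mode added in this change (a sketch: the entry-point command and model name are illustrative; the model must be one of the keys in exo/models.py's model_base_shards):

  python main.py --run-model llama-3.1-8b --prompt "Who are you?"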

+ 17 - 76
exo/api/chatgpt_api.py

@@ -3,55 +3,18 @@ import time
 import asyncio
 import json
 from pathlib import Path
-from transformers import AutoTokenizer, AutoProcessor
+from transformers import AutoTokenizer
 from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
 from exo import DEBUG, VERSION
-from exo.helpers import terminal_link, PrefixDict
+from exo.helpers import PrefixDict
 from exo.inference.shard import Shard
+from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-
-shard_mappings = {
-  ### llama
-  "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "llama-3.1-405b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
-  },
-  "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
-  },
-  ### mistral
-  "mistral-nemo": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
-  },
-  "mistral-large": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
-  },
-  ### deepseek v2
-  "deepseek-coder-v2-lite": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
-  },
-  ### llava
-  "llava-1.5-7b-hf": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
-  },
-}
-
-
+from exo.models import model_base_shards
+from typing import Callable

 class Message:
     def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -64,7 +27,6 @@ class Message:
             "content": self.content
             "content": self.content
         }
         }
 
 
-
 class ChatCompletionRequest:
 class ChatCompletionRequest:
     def __init__(self, model: str, messages: List[Message], temperature: float):
     def __init__(self, model: str, messages: List[Message], temperature: float):
         self.model = model
         self.model = model
@@ -78,33 +40,6 @@ class ChatCompletionRequest:
             "temperature": self.temperature
             "temperature": self.temperature
         }
         }
 
 
-
-
-async def resolve_tokenizer(model_id: str):
-  try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
-    processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
-    if not hasattr(processor, 'eos_token_id'):
-      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
-    if not hasattr(processor, 'encode'):
-      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
-    if not hasattr(processor, 'decode'):
-      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
-    return processor
-  except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
-    if DEBUG >= 4: print(traceback.format_exc())
-
-  try:
-    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
-    return AutoTokenizer.from_pretrained(model_id)
-  except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
-    if DEBUG >= 4: print(traceback.format_exc())
-
-  raise ValueError(f"[TODO] Unsupported model: {model_id}")
-
-
 def generate_completion(
   chat_request: ChatCompletionRequest,
   tokenizer,
@@ -221,10 +156,11 @@ class PromptSession:
     self.prompt = prompt

 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout_secs = response_timeout_secs
+    self.on_chat_completion_request = on_chat_completion_request
     self.app = web.Application(client_max_size=100 * 1024 * 1024)  # 100MB to support image upload
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
@@ -257,7 +193,7 @@ class ChatGPTAPI:

   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
-    shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
+    shard = model_base_shards.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(shard.model_id)
     return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
@@ -269,12 +205,12 @@ class ChatGPTAPI:
     chat_request = parse_chat_request(data)
     if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
       chat_request.model = "llama-3.1-8b"
-    if not chat_request.model or chat_request.model not in shard_mappings:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
+    if not chat_request.model or chat_request.model not in model_base_shards:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to llama-3.1-8b")
       chat_request.model = "llama-3.1-8b"
       chat_request.model = "llama-3.1-8b"
-    shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
+    shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None)
     if not shard:
     if not shard:
-      supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
+      supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines]
       return web.json_response(
         {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         status=400,
@@ -285,6 +221,11 @@ class ChatGPTAPI:

     prompt, image_str = build_prompt(tokenizer, chat_request.messages)
     request_id = str(uuid.uuid4())
+    if self.on_chat_completion_request:
+      try:
+        self.on_chat_completion_request(request_id, chat_request, prompt)
+      except Exception as e:
+        if DEBUG >= 2: traceback.print_exc()
     # request_id = None
     # match = self.prompts.find_longest_prefix(prompt)
     # if match and len(prompt) > len(match[1].prompt):

+ 0 - 15
exo/helpers.py

@@ -32,21 +32,6 @@ def get_system_info():
   return "Non-Mac, non-Linux system"
   return "Non-Mac, non-Linux system"
 
 
 
 
-def get_inference_engine(inference_engine_name, shard_downloader: 'ShardDownloader'):
-  if inference_engine_name == "mlx":
-    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-
-    return MLXDynamicShardInferenceEngine(shard_downloader)
-  elif inference_engine_name == "tinygrad":
-    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-    import tinygrad.helpers
-    tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-
-    return TinygradDynamicShardInferenceEngine(shard_downloader)
-  else:
-    raise ValueError(f"Inference engine {inference_engine_name} not supported")
-
-
 def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
   used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".exo_used_ports")


+ 16 - 0
exo/inference/inference_engine.py

@@ -1,4 +1,5 @@
 import numpy as np
+import os

 from typing import Tuple, Optional
 from abc import ABC, abstractmethod
@@ -12,3 +13,18 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
+
+
+def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+  if inference_engine_name == "mlx":
+    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+
+    return MLXDynamicShardInferenceEngine(shard_downloader)
+  elif inference_engine_name == "tinygrad":
+    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+    import tinygrad.helpers
+    tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
+
+    return TinygradDynamicShardInferenceEngine(shard_downloader)
+  else:
+    raise ValueError(f"Inference engine {inference_engine_name} not supported")
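A minimal usage sketch for the relocated get_inference_engine factory, wired the same way main.py does it (assumes HFShardDownloader can be constructed with its defaults; use "mlx" on Apple Silicon, "tinygrad" elsewhere):

  from exo.download.hf.hf_shard_download import HFShardDownloader
  from exo.inference.inference_engine import get_inference_engine

  shard_downloader = HFShardDownloader()  # assumption: default constructor arguments suffice
  inference_engine = get_inference_engine("tinygrad", shard_downloader)  # or "mlx"; anything else raises ValueError
  print(inference_engine.__class__.__name__)  # this classname is the key used to look up shards in model_base_shards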

+ 27 - 0
exo/inference/tokenizers.py

@@ -0,0 +1,27 @@
+import traceback
+from transformers import AutoTokenizer, AutoProcessor
+from exo.helpers import DEBUG
+
+async def resolve_tokenizer(model_id: str):
+  try:
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
+    processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
+    if not hasattr(processor, 'eos_token_id'):
+      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
+    if not hasattr(processor, 'encode'):
+      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
+    if not hasattr(processor, 'decode'):
+      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
+    return processor
+  except Exception as e:
+    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())
+
+  try:
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
+    return AutoTokenizer.from_pretrained(model_id)
+  except Exception as e:
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())
+
+  raise ValueError(f"[TODO] Unsupported model: {model_id}")
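A small usage sketch for the extracted resolve_tokenizer helper, following the same pattern as run_model_cli in main.py below (the model id is one of the MLX shards listed in exo/models.py and is only illustrative):

  import asyncio
  from exo.inference.tokenizers import resolve_tokenizer

  async def demo():
    tokenizer = await resolve_tokenizer("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
    # Build a chat-formatted prompt the same way run_model_cli does
    prompt = tokenizer.apply_chat_template([{"role": "user", "content": "Who are you?"}], tokenize=False, add_generation_prompt=True)
    print(prompt)

  asyncio.run(demo())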

+ 40 - 0
exo/models.py

@@ -0,0 +1,40 @@
+from exo.inference.shard import Shard
+
+model_base_shards = {
+  ### llama
+  "llama-3.1-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3.1-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
+  },
+  "llama-3.1-405b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
+  },
+  "llama-3-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
+  },
+  ### mistral
+  "mistral-nemo": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
+  },
+  "mistral-large": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
+  },
+  ### deepseek v2
+  "deepseek-coder-v2-lite": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
+  },
+  ### llava
+  "llava-1.5-7b-hf": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
+  },
+}
+

+ 14 - 13
exo/orchestration/standard_node.py

@@ -29,6 +29,7 @@ class StandardNode(Node):
     chatgpt_api_endpoints: List[str] = [],
     web_chat_urls: List[str] = [],
     disable_tui: Optional[bool] = False,
+    topology_viz: Optional[TopologyViz] = None,
   ):
     self.id = _id
     self.inference_engine = inference_engine
@@ -39,13 +40,25 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-    self.topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not disable_tui else None
     self.max_generate_tokens = max_generate_tokens
+    self.topology_viz = topology_viz
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}

+  async def start(self, wait_for_peers: int = 0) -> None:
+    await self.server.start()
+    await self.discovery.start()
+    await self.update_peers(wait_for_peers)
+    await self.collect_topology()
+    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
+    asyncio.create_task(self.periodic_topology_collection(5))
+
+  async def stop(self) -> None:
+    await self.discovery.stop()
+    await self.server.stop()
+
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
@@ -66,18 +79,6 @@ class StandardNode(Node):
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
       if DEBUG >= 1: traceback.print_exc()

-  async def start(self, wait_for_peers: int = 0) -> None:
-    await self.server.start()
-    await self.discovery.start()
-    await self.update_peers(wait_for_peers)
-    await self.collect_topology()
-    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
-    asyncio.create_task(self.periodic_topology_collection(5))
-
-  async def stop(self) -> None:
-    await self.discovery.stop()
-    await self.server.stop()
-
   async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(

+ 71 - 3
exo/viz/topology_viz.py

@@ -1,17 +1,20 @@
 import math
+from collections import OrderedDict
 from typing import List, Optional, Tuple, Dict
 from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
 from exo.download.hf.hf_helpers import RepoProgressEvent
-from rich.console import Console
-from rich.panel import Panel
+from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
+from rich.console import Console, Group
 from rich.text import Text
 from rich.live import Live
 from rich.style import Style
 from rich.table import Table
 from rich.layout import Layout
-from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
+from rich.syntax import Syntax
+from rich.panel import Panel
+from rich.markdown import Markdown

 class TopologyViz:
   def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []):
@@ -21,17 +24,24 @@ class TopologyViz:
     self.partitions: List[Partition] = []
     self.node_id = None
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
+    self.requests: OrderedDict[str, Tuple[str, str]] = {}

     self.console = Console()
     self.layout = Layout()
     self.layout.split(
       Layout(name="main"),
+      Layout(name="prompt_output", size=15),
       Layout(name="download", size=25)
       Layout(name="download", size=25)
     )
     )
     self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
     self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green")
     self.download_panel = Panel("", title="Download Progress", border_style="cyan")
     self.layout["main"].update(self.main_panel)
+    self.layout["prompt_output"].update(self.prompt_output_panel)
     self.layout["download"].update(self.download_panel)
+
+    # Initially hide the prompt_output panel
+    self.layout["prompt_output"].visible = False
     self.live_panel = Live(self.layout, auto_refresh=False, console=self.console)
     self.live_panel.start()

@@ -43,12 +53,34 @@ class TopologyViz:
       self.node_download_progress = node_download_progress
     self.refresh()

+  def update_prompt(self, request_id: str, prompt: Optional[str] = None):
+    if request_id in self.requests:
+      self.requests[request_id] = [prompt, self.requests[request_id][1]]
+    else:
+      self.requests[request_id] = [prompt, ""]
+    self.refresh()
+
+  def update_prompt_output(self, request_id: str, output: Optional[str] = None):
+    if request_id in self.requests:
+      self.requests[request_id] = [self.requests[request_id][0], output]
+    else:
+      self.requests[request_id] = ["", output]
+    self.refresh()
+
   def refresh(self):
     self.main_panel.renderable = self._generate_main_layout()
     # Update the panel title with the number of nodes and partitions
     node_count = len(self.topology.nodes)
     self.main_panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"

+    # Update and show/hide prompt and output panel
+    if any(r[0] or r[1] for r in self.requests.values()):
+        self.prompt_output_panel = self._generate_prompt_output_layout()
+        self.layout["prompt_output"].update(self.prompt_output_panel)
+        self.layout["prompt_output"].visible = True
+    else:
+        self.layout["prompt_output"].visible = False
+
     # Only show download_panel if there are in-progress downloads
     if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
       self.download_panel.renderable = self._generate_download_layout()
@@ -58,6 +90,42 @@ class TopologyViz:

     self.live_panel.update(self.layout, refresh=True)

+  def _generate_prompt_output_layout(self) -> Panel:
+    content = []
+    requests = list(self.requests.values())[-3:]  # Get the 3 most recent requests
+    max_width = self.console.width - 6  # Full width minus padding and icon
+    max_lines = 13  # Maximum number of lines for the entire panel content
+
+    for (prompt, output) in reversed(requests):
+        prompt_icon, output_icon = "💬️", "🤖"
+
+        # Process prompt
+        prompt_lines = prompt.split('\n')
+        if len(prompt_lines) > max_lines // 2:
+            prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
+        prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
+        prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
+
+        # Process output
+        output_lines = output.split('\n')
+        remaining_lines = max_lines - len(prompt_lines) - 2  # -2 for spacing
+        if len(output_lines) > remaining_lines:
+            output_lines = output_lines[:remaining_lines - 1] + ['...']
+        output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
+        output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
+
+        content.append(prompt_text)
+        content.append(output_text)
+        content.append(Text())  # Empty line between entries
+
+    return Panel(
+        Group(*content),
+        title="",
+        border_style="cyan",
+        height=15,  # Increased height to accommodate multiple lines
+        expand=True  # Allow the panel to expand to full width
+    )
+
   def _generate_main_layout(self) -> str:
     # Calculate visualization parameters
     num_partitions = len(self.partitions)

+ 52 - 6
main.py

@@ -4,6 +4,8 @@ import signal
 import json
 import time
 import traceback
+import uuid
+from asyncio import CancelledError
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
@@ -11,8 +13,13 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
 from exo.inference.shard import Shard
+from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.inference.tokenizers import resolve_tokenizer
+from exo.orchestration.node import Node
+from exo.models import model_base_shards
+from exo.viz.topology_viz import TopologyViz

 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -31,6 +38,8 @@ parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90,
 parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request")
 parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
+parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
+parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 args = parser.parse_args()

 print_yellow_exo()
@@ -58,6 +67,7 @@ if DEBUG >= 0:
     print("ChatGPT API endpoint served at:")
     print("ChatGPT API endpoint served at:")
     for chatgpt_api_endpoint in chatgpt_api_endpoints:
     for chatgpt_api_endpoint in chatgpt_api_endpoints:
         print(f" - {terminal_link(chatgpt_api_endpoint)}")
         print(f" - {terminal_link(chatgpt_api_endpoint)}")
+topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
     args.node_id,
     None,
@@ -68,11 +78,14 @@ node = StandardNode(
     partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
     disable_tui=args.disable_tui,
     max_generate_tokens=args.max_generate_tokens,
+    topology_viz=topology_viz
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
-api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
-node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs, on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None)
+node.on_token.register("update_topology_viz").on_next(
+    lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens) if topology_viz else None
+)
 def preemptively_start_download(request_id: str, opaque_status: str):
     try:
         status = json.loads(opaque_status)
@@ -110,6 +123,36 @@ async def shutdown(signal, loop):
     await server.stop()
     loop.stop()

+async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
+    shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
+    if not shard:
+        print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+        return
+    tokenizer = await resolve_tokenizer(shard.model_id)
+    request_id = str(uuid.uuid4())
+    callback_id = f"cli-wait-response-{request_id}"
+    callback = node.on_token.register(callback_id)
+    if topology_viz:
+        topology_viz.update_prompt(request_id, prompt)
+    prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
+
+    try:
+        print(f"Processing prompt: {prompt}")
+        await node.process_prompt(shard, prompt, None, request_id=request_id)
+
+        _, tokens, _ = await callback.wait(
+            lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
+            timeout=90
+        )
+
+        print("\nGenerated response:")
+        print(tokenizer.decode(tokens))
+    except Exception as e:
+        print(f"Error processing prompt: {str(e)}")
+        traceback.print_exc()
+    finally:
+        node.on_token.deregister(callback_id)
+
 async def main():
     loop = asyncio.get_running_loop()

@@ -121,9 +164,12 @@ async def main():
         loop.add_signal_handler(s, handle_exit)

     await node.start(wait_for_peers=args.wait_for_peers)
-    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task

-    await asyncio.Event().wait()
+    if args.run_model:
+        await run_model_cli(node, inference_engine, args.run_model, args.prompt)
+    else:
+        asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
+        await asyncio.Event().wait()

 if __name__ == "__main__":
     loop = asyncio.new_event_loop()
@@ -134,4 +180,4 @@ if __name__ == "__main__":
         print("Received keyboard interrupt. Shutting down...")
         print("Received keyboard interrupt. Shutting down...")
     finally:
     finally:
         loop.run_until_complete(shutdown(signal.SIGTERM, loop))
         loop.run_until_complete(shutdown(signal.SIGTERM, loop))
-        loop.close()
+        loop.close()