
propagate prompts to other nodes so they can display them; cleaner prompt/output display

Alex Cheema 3 months ago
parent
commit
af171f06fa
4 changed files with 60 additions and 65 deletions
  1. exo/inference/mlx/sharded_inference_engine.py (+2 -2)
  2. exo/main.py (+38 -51)
  3. exo/orchestration/node.py (+1 -0)
  4. exo/viz/topology_viz.py (+19 -12)

+ 2 - 2
exo/inference/mlx/sharded_inference_engine.py

@@ -1,10 +1,10 @@
 import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_lm.sample_utils import top_p_sampling, make_sampler
+from mlx_lm.sample_utils import make_sampler
 import mlx.optimizers as optim
 from ..inference_engine import InferenceEngine
-from .sharded_utils import load_shard, load_model_shard, resolve_tokenizer
+from .sharded_utils import load_model_shard, resolve_tokenizer
 from .losses import loss_fns
 from ..shard import Shard
 from typing import Dict, Optional, Tuple

+ 38 - 51
exo/main.py

@@ -3,19 +3,15 @@ import asyncio
 import atexit
 import signal
 import json
-import logging
 import platform
 import os
-import sys
 import time
 import traceback
 import uuid
 import numpy as np
-from functools import partial
 from tqdm import tqdm
-from exo.train.dataset import load_dataset, iterate_batches, compose
+from exo.train.dataset import load_dataset, iterate_batches
 from exo.networking.manual.manual_discovery import ManualDiscovery
-from exo.networking.manual.network_topology_config import NetworkTopology
 from exo.orchestration.node import Node
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.udp.udp_discovery import UDPDiscovery
@@ -28,12 +24,11 @@ from exo.download.download_progress import RepoProgressEvent
 from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, exo_home, seed_models
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
 from exo.inference.shard import Shard
-from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.inference.inference_engine import get_inference_engine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
 import uvloop
-from contextlib import asynccontextmanager
 import concurrent.futures
 import resource
 import psutil
@@ -45,31 +40,19 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
 # Configure uvloop for maximum performance
 def configure_uvloop():
-    # Install uvloop as event loop policy
     uvloop.install()
-
-    # Create new event loop
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
 
     # Increase file descriptor limits on Unix systems
     if not psutil.WINDOWS:
       soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-      try:
-          resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
+      try: resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
       except ValueError:
-        try:
-          resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
-        except ValueError:
-          pass
-
-    # Configure thread pool for blocking operations
-    loop.set_default_executor(
-      concurrent.futures.ThreadPoolExecutor(
-        max_workers=min(32, (os.cpu_count() or 1) * 4)
-      )
-    )
+        try: resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
+        except ValueError: pass
 
+    loop.set_default_executor(concurrent.futures.ThreadPoolExecutor(max_workers=min(32, (os.cpu_count() or 1) * 4)))
     return loop
 
 # parse args
@@ -183,7 +166,7 @@ server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 api = ChatGPTAPI(
   node,
-  inference_engine.__class__.__name__,
+  node.inference_engine.__class__.__name__,
   response_timeout=args.chatgpt_api_response_timeout,
   on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
   default_model=args.default_model,
@@ -192,21 +175,31 @@ api = ChatGPTAPI(
 buffered_token_output = {}
 def update_topology_viz(req_id, tokens, __):
   if not topology_viz: return
-  if not inference_engine.shard: return
-  if inference_engine.shard.model_id == 'stable-diffusion-2-1-base': return
-
+  if not node.inference_engine.shard: return
+  if node.inference_engine.shard.model_id == 'stable-diffusion-2-1-base': return
   if req_id in buffered_token_output: buffered_token_output[req_id].extend(tokens)
   else: buffered_token_output[req_id] = tokens
-  topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(buffered_token_output[req_id]))
+  topology_viz.update_prompt_output(req_id, node.inference_engine.tokenizer.decode(buffered_token_output[req_id]))
 node.on_token.register("update_topology_viz").on_next(update_topology_viz)
+def update_prompt_viz(request_id, opaque_status: str):
+  if not topology_viz: return
+  try:
+    status = json.loads(opaque_status)
+    if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
+    topology_viz.update_prompt(request_id, status.get("prompt", "corrupted prompt (this should never happen)"))
+  except Exception as e:
+    if DEBUG >= 2:
+      print(f"Failed to update prompt viz: {e}")
+      traceback.print_exc()
+node.on_opaque_status.register("update_prompt_viz").on_next(update_prompt_viz)
 
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
     status = json.loads(opaque_status)
-    if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
-      current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
-      if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-      asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
+    if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
+    current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
+    if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
+    asyncio.create_task(shard_downloader.ensure_shard(current_shard, node.inference_engine.__class__.__name__))
   except Exception as e:
     if DEBUG >= 2:
       print(f"Failed to preemptively start download: {e}")
@@ -229,11 +222,11 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
 
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
-async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
-  inference_class = inference_engine.__class__.__name__
+async def run_model_cli(node: Node, model_name: str, prompt: str):
+  inference_class = node.inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
   if not shard:
-    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
     return
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
   request_id = str(uuid.uuid4())
@@ -284,11 +277,11 @@ async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
 
   return total_loss, total_tokens
 
-async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
-  inference_class = inference_engine.__class__.__name__
+async def eval_model_cli(node: Node, model_name, dataloader, batch_size, num_batches=-1):
+  inference_class = node.inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
   if not shard:
-    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
     return
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
   train, val, test = dataloader(tokenizer.encode)
@@ -298,11 +291,11 @@ async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_na
   print("Waiting for outstanding tasks")
   await hold_outstanding(node)
 
-async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, iters, save_interval=0, checkpoint_dir=None):
-  inference_class = inference_engine.__class__.__name__
+async def train_model_cli(node: Node, model_name, dataloader, batch_size, iters, save_interval=0, checkpoint_dir=None):
+  inference_class = node.inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
   if not shard:
-    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
     return
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
   train, val, test = dataloader(tokenizer.encode)
@@ -362,7 +355,7 @@ async def main():
     if not model_name:
       print("Error: Model name is required when using 'run' command or --run-model")
       return
-    await run_model_cli(node, inference_engine, model_name, args.prompt)
+    await run_model_cli(node, model_name, args.prompt)
   elif args.command == "eval" or args.command == 'train':
     model_name = args.model_name
     dataloader = lambda tok: load_dataset(args.data, preprocess=lambda item: tok(item)
@@ -371,12 +364,12 @@ async def main():
       if not model_name:
         print("Error: Much like a human, I can't evaluate anything without a model")
         return
-      await eval_model_cli(node, inference_engine, model_name, dataloader, args.batch_size)
+      await eval_model_cli(node, model_name, dataloader, args.batch_size)
     else:
       if not model_name:
         print("Error: This train ain't leaving the station without a model")
         return
-      await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
+      await train_model_cli(node, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
 
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
@@ -387,11 +380,6 @@ async def main():
     for i in tqdm(range(50)):
       await asyncio.sleep(.1)
 
-@asynccontextmanager
-async def setup_node(args):
-    # Rest of setup_node implementation...
-    pass
-
 def run():
     loop = None
     try:
@@ -400,8 +388,7 @@ def run():
     except KeyboardInterrupt:
         print("\nShutdown requested... exiting")
     finally:
-        if loop:
-            loop.close()
+        if loop: loop.close()
 
 if __name__ == "__main__":
   run()
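
Note on the main.py change: the new update_prompt_viz handler is what lets non-origin nodes display prompts. It listens on node.on_opaque_status and only reacts to "node_status"/"start_process_prompt" payloads, pulling the prompt out of the broadcast JSON. A minimal sketch of such a payload and how the handler consumes it (the request id and prompt text are made up for illustration; only the field names come from the diff above):

  import json

  # Payload shape checked by update_prompt_viz; preemptively_start_download
  # additionally reads a "shard" dict from the same status message.
  opaque_status = json.dumps({
    "type": "node_status",
    "status": "start_process_prompt",
    "prompt": "What is the capital of France?",
  })

  # Anything that is not a start_process_prompt node_status is ignored;
  # otherwise the prompt is forwarded to topology_viz.update_prompt(request_id, prompt).
  update_prompt_viz("example-request-id", opaque_status)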

+ 1 - 0
exo/orchestration/node.py

@@ -586,6 +586,7 @@ class Node:
     self.on_token.trigger_all(request_id, tokens, is_finished)
   
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    if DEBUG >= 2: print(f"Broadcasting result: {request_id=} {result=} {is_finished=}")
     async def send_result_to_peer(peer):
       try:
         await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)

+ 19 - 12
exo/viz/topology_viz.py

@@ -51,17 +51,11 @@ class TopologyViz:
     self.refresh()
 
   def update_prompt(self, request_id: str, prompt: Optional[str] = None):
-    if request_id in self.requests:
-      self.requests[request_id] = [prompt, self.requests[request_id][1]]
-    else:
-      self.requests[request_id] = [prompt, ""]
+    self.requests[request_id] = [prompt, self.requests.get(request_id, ["", ""])[1]]
     self.refresh()
 
   def update_prompt_output(self, request_id: str, output: Optional[str] = None):
-    if request_id in self.requests:
-      self.requests[request_id] = [self.requests[request_id][0], output]
-    else:
-      self.requests[request_id] = ["", output]
+    self.requests[request_id] = [self.requests.get(request_id, ["", ""])[0], output]
     self.refresh()
 
   def refresh(self):
@@ -101,10 +95,10 @@ class TopologyViz:
       prompt_icon, output_icon = "💬️", "🤖"
 
       # Calculate max lines for prompt and output
-      max_prompt_lines = lines_per_entry // 3  # Allocate 1/3 for prompt
+      max_prompt_lines = max(3, lines_per_entry // 2)  # Ensure at least 3 lines for prompt
       max_output_lines = lines_per_entry - max_prompt_lines - 1  # Remaining space minus spacing
 
-      # Process prompt
+      # Process prompt with more generous line allocation
       prompt_lines = []
       for line in prompt.split('\n'):
         words = line.split()
@@ -124,8 +118,15 @@ class TopologyViz:
         if current_line:
           prompt_lines.append(' '.join(current_line))
 
+      # Show more prompt content and append ellipses to last line if needed
       if len(prompt_lines) > max_prompt_lines:
-        prompt_lines = prompt_lines[:max_prompt_lines - 1] + ['...']
+        prompt_lines = prompt_lines[:max_prompt_lines]
+        # Append ellipses to last line if there's room, otherwise truncate last line
+        last_line = prompt_lines[-1]
+        if len(last_line) + 4 <= max_width:  # +4 for " ..."
+          prompt_lines[-1] = last_line + " ..."
+        else:
+          prompt_lines[-1] = last_line[:max_width-4] + " ..."
 
       prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
       prompt_text.append('\n'.join(prompt_lines), style="white")
@@ -151,7 +152,13 @@ class TopologyViz:
           output_lines.append(' '.join(current_line))
 
       if len(output_lines) > max_output_lines:
-        output_lines = output_lines[:max_output_lines - 1] + ['...']
+        output_lines = output_lines[:max_output_lines]
+        last_line = output_lines[-1] if output_lines else None
+        if last_line:
+          if len(last_line) + 4 <= max_width:
+            output_lines[-1] = last_line + " ..."
+          else:
+            output_lines[-1] = last_line[:max_width-4] + " ..."
 
       output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
       output_text.append('\n'.join(output_lines), style="white")
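
Note on the topology_viz.py change: instead of dropping the last allotted line and replacing it with a bare "...", the prompt and output now keep every line in the budget and mark overflow with a trailing " ..." (trimming the last line only if the suffix would not fit). A standalone sketch of that behavior, with a hypothetical helper name and parameters matching the variables in the diff:

  from typing import List

  def truncate_lines(lines: List[str], max_lines: int, max_width: int) -> List[str]:
    # Keep up to max_lines already-wrapped lines; flag overflow with " ...".
    if len(lines) <= max_lines:
      return lines
    kept = lines[:max_lines]
    last = kept[-1]
    # Append " ..." if it fits within max_width, otherwise cut the line to make room (+4 covers " ...").
    kept[-1] = last + " ..." if len(last) + 4 <= max_width else last[:max_width - 4] + " ..."
    return kept

  # Example: three wrapped lines squeezed into a two-line budget of width 20.
  print(truncate_lines(["first line", "second line", "third line"], max_lines=2, max_width=20))
  # -> ['first line', 'second line ...']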