Merge branch 'main' into circleci

thenatlog, 1 year ago
parent commit 2a2b0f4f37

+ 172 - 11
exo/api/chatgpt_api.py

@@ -2,6 +2,7 @@ import uuid
 import time
 import asyncio
 import json
+import os
 from pathlib import Path
 from transformers import AutoTokenizer
 from typing import List, Literal, Union, Dict
@@ -14,10 +15,12 @@ from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict, shutdown
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
-from exo.apputil import create_animation_mp4
+from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable, Optional
-import tempfile
+from exo.download.hf.hf_shard_download import HFShardDownloader
+import shutil
+from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
+from exo.apputil import create_animation_mp4
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -175,7 +178,11 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
     cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
+    cors.add(self.app.router.add_delete("/models/{model_name}", self.handle_delete_model), {"*": cors_options})
+    cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
     cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
+    cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
+    cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
 
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
@@ -215,22 +222,79 @@ class ChatGPTAPI:
     return web.json_response({"status": "ok"})
 
   async def handle_model_support(self, request):
-    return web.json_response({
-      "model pool": {
-        model_name: pretty_name.get(model_name, model_name)
-        for model_name in get_supported_models(self.node.topology_inference_engines_pool)
-      }
-    })
+    try:
+      response = web.StreamResponse(
+        status=200,
+        reason='OK',
+        headers={
+          'Content-Type': 'text/event-stream',
+          'Cache-Control': 'no-cache',
+          'Connection': 'keep-alive',
+        }
+      )
+      await response.prepare(request)
+
+      for model_name, pretty in pretty_name.items():
+        if model_name in model_cards:
+          model_info = model_cards[model_name]
+
+          if self.inference_engine_classname in model_info.get("repo", {}):
+            shard = build_base_shard(model_name, self.inference_engine_classname)
+            if shard:
+              downloader = HFShardDownloader(quick_check=True)
+              downloader.current_shard = shard
+              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+              status = await downloader.get_shard_download_status()
+
+              download_percentage = status.get("overall") if status else None
+              total_size = status.get("total_size") if status else None
+              total_downloaded = status.get("total_downloaded") if status else None
+
+              model_data = {
+                model_name: {
+                  "name": pretty,
+                  "downloaded": download_percentage == 100 if download_percentage is not None else False,
+                  "download_percentage": download_percentage,
+                  "total_size": total_size,
+                  "total_downloaded": total_downloaded
+                }
+              }
+
+              await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
+
+      await response.write(b"data: [DONE]\n\n")
+      return response
+
+    except Exception as e:
+      print(f"Error in handle_model_support: {str(e)}")
+      traceback.print_exc()
+      return web.json_response(
+        {"detail": f"Server error: {str(e)}"},
+        status=500
+      )
 
   async def handle_get_models(self, request):
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
-    shard = build_base_shard(self.default_model, self.inference_engine_classname)
+    model = data.get("model", self.default_model)
+    if model and model.startswith("gpt-"):  # Handle gpt- model requests
+      model = self.default_model
+    if not model or model not in model_cards:
+      if DEBUG >= 1: print(f"Invalid model: {model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
+      model = self.default_model
+    shard = build_base_shard(model, self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
+    prompt = build_prompt(tokenizer, messages)
+    tokens = tokenizer.encode(prompt)
+    return web.json_response({
+      "length": len(prompt),
+      "num_tokens": len(tokens),
+      "encoded_tokens": tokens,
+      "encoded_prompt": prompt,
+    })
 
   async def handle_get_download_progress(self, request):
     progress_data = {}
@@ -371,6 +435,71 @@ class ChatGPTAPI:
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
 
+  async def handle_delete_model(self, request):
+    try:
+      model_name = request.match_info.get('model_name')
+      if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
+
+      if not model_name or model_name not in model_cards:
+        return web.json_response(
+          {"detail": f"Invalid model name: {model_name}"},
+          status=400
+        )
+
+      shard = build_base_shard(model_name, self.inference_engine_classname)
+      if not shard:
+        return web.json_response(
+          {"detail": "Could not build shard for model"},
+          status=400
+        )
+
+      repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+      if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
+
+      # Get the HF cache directory using the helper function
+      hf_home = get_hf_home()
+      cache_dir = get_repo_root(repo_id)
+
+      if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
+
+      if os.path.exists(cache_dir):
+        if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
+        try:
+          shutil.rmtree(cache_dir)
+          return web.json_response({
+            "status": "success",
+            "message": f"Model {model_name} deleted successfully",
+            "path": str(cache_dir)
+          })
+        except Exception as e:
+          return web.json_response({
+            "detail": f"Failed to delete model files: {str(e)}"
+          }, status=500)
+      else:
+        return web.json_response({
+          "detail": f"Model files not found at {cache_dir}"
+        }, status=404)
+
+    except Exception as e:
+      print(f"Error in handle_delete_model: {str(e)}")
+      traceback.print_exc()
+      return web.json_response({
+        "detail": f"Server error: {str(e)}"
+      }, status=500)
+
+  async def handle_get_initial_models(self, request):
+    model_data = {}
+    for model_name, pretty in pretty_name.items():
+      model_data[model_name] = {
+        "name": pretty,
+        "downloaded": None,  # unknown until the first status check completes
+        "download_percentage": None,  # null rather than 0 so the UI can tell "unknown" from "not started"
+        "total_size": None,
+        "total_downloaded": None,
+        "loading": True  # tells the UI to show a loading indicator
+      }
+    return web.json_response(model_data)
+    return web.json_response(model_data)
+
   async def handle_create_animation(self, request):
     try:
       data = await request.json()
@@ -410,6 +539,38 @@ class ChatGPTAPI:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"error": str(e)}, status=500)
 
+  async def handle_post_download(self, request):
+    try:
+      data = await request.json()
+      model_name = data.get("model")
+      if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
+      if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
+      shard = build_base_shard(model_name, self.inference_engine_classname)
+      if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
+      asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
+
+      return web.json_response({
+        "status": "success",
+        "message": f"Download started for model: {model_name}"
+      })
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"error": str(e)}, status=500)
+
+  async def handle_get_topology(self, request):
+    try:
+      topology = self.node.current_topology
+      if topology:
+        return web.json_response(topology.to_json())
+      else:
+        return web.json_response({})
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response(
+        {"detail": f"Error getting topology: {str(e)}"},
+        status=500
+      )
+
   async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()

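For reference, the new endpoints can be exercised with a small aiohttp client. A minimal sketch, assuming a node on localhost with the default API port from run() above; the model name is illustrative:

import asyncio
import json
import aiohttp

API = "http://localhost:52415"  # assumed: default ChatGPTAPI port per run() above

async def main():
  async with aiohttp.ClientSession() as session:
    # /modelpool now streams one SSE event per model instead of returning a single JSON blob
    async with session.get(f"{API}/modelpool") as resp:
      async for raw in resp.content:
        line = raw.decode().strip()
        if not line.startswith("data: "): continue
        payload = line[len("data: "):]
        if payload == "[DONE]": break
        print(json.loads(payload))  # {model_name: {"name", "downloaded", "download_percentage", ...}}
    # kick off a background download (model name is illustrative)
    async with session.post(f"{API}/download", json={"model": "llama-3.2-1b"}) as resp:
      print(await resp.json())

asyncio.run(main())
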
+ 62 - 2
exo/download/hf/hf_helpers.py

@@ -166,10 +166,18 @@ async def download_file(
     downloaded_size = local_file_size
     downloaded_this_session = 0
     mode = 'ab' if use_range_request else 'wb'
-    if downloaded_size == total_size:
+    percentage = await get_file_download_percentage(
+      session,
+      repo_id,
+      revision,
+      file_path,
+      Path(save_directory)
+    )
+
+    if percentage == 100:
       if DEBUG >= 2: print(f"File already downloaded: {file_path}")
       if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
       return
 
     if response.status == 200:
@@ -432,6 +440,57 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
 
+async def get_file_download_percentage(
+    session: aiohttp.ClientSession,
+    repo_id: str,
+    revision: str,
+    file_path: str,
+    snapshot_dir: Path,
+) -> float:
+  """
+    Calculate the download percentage for a file by comparing local and remote sizes.
+    """
+  try:
+    local_path = snapshot_dir / file_path
+    if not await aios.path.exists(local_path):
+      return 0
+
+    # Get local file size first
+    local_size = await aios.path.getsize(local_path)
+    if local_size == 0:
+      return 0
+
+    # Check remote size
+    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
+    url = urljoin(base_url, file_path)
+    headers = await get_auth_headers()
+
+    # Use HEAD request with redirect following for all files
+    async with session.head(url, headers=headers, allow_redirects=True) as response:
+      if response.status != 200:
+        if DEBUG >= 2:
+          print(f"Failed to get remote file info for {file_path}: {response.status}")
+        return 0
+
+      remote_size = int(response.headers.get('Content-Length', 0))
+
+      if remote_size == 0:
+        if DEBUG >= 2:
+          print(f"Remote size is 0 for {file_path}")
+        return 0
+
+      # Only return 100% if sizes match exactly
+      if local_size == remote_size:
+        return 100.0
+
+      # Calculate percentage based on sizes
+      return (local_size / remote_size) * 100 if remote_size > 0 else 0
+
+  except Exception as e:
+    if DEBUG >= 2:
+      print(f"Error checking file download status for {file_path}: {e}")
+    return 0
+
 async def has_hf_home_read_access() -> bool:
   hf_home = get_hf_home()
   try: return await aios.access(hf_home, os.R_OK)
@@ -441,3 +500,4 @@ async def has_hf_home_write_access() -> bool:
   hf_home = get_hf_home()
   try: return await aios.access(hf_home, os.W_OK)
   except OSError: return False
+

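A minimal usage sketch for the new helper, assuming a repo that is already partially downloaded (repo id and file name are illustrative; get_local_snapshot_dir is the same helper the shard downloader below uses to locate the snapshot):

import asyncio
import aiohttp
from exo.download.hf.hf_helpers import get_file_download_percentage, get_local_snapshot_dir

async def main():
  repo_id, revision = "mlx-community/Llama-3.2-1B-Instruct-4bit", "main"  # illustrative
  snapshot_dir = await get_local_snapshot_dir(repo_id, revision)
  if not snapshot_dir:
    print("nothing downloaded yet")
    return
  async with aiohttp.ClientSession() as session:
    pct = await get_file_download_percentage(session, repo_id, revision, "model.safetensors", snapshot_dir)
    print(f"model.safetensors: {pct:.1f}%")

asyncio.run(main())

Note the helper returns 0 on any error or missing local file, and 100.0 only when local and remote sizes match exactly.
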
+ 90 - 2
exo/download/hf/hf_shard_download.py

@@ -1,13 +1,20 @@
 import asyncio
 import traceback
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional, Union
 from exo.inference.shard import Shard
 from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
-from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
+from exo.download.hf.hf_helpers import (
+    download_repo_files, RepoProgressEvent, get_weight_map, 
+    get_allow_patterns, get_repo_root, fetch_file_list, 
+    get_local_snapshot_dir, get_file_download_percentage,
+    filter_repo_objects
+)
 from exo.helpers import AsyncCallbackSystem, DEBUG
 from exo.models import model_cards, get_repo
+import aiohttp
+from aiofiles import os as aios
 
 
 class HFShardDownloader(ShardDownloader):
@@ -17,8 +24,13 @@ class HFShardDownloader(ShardDownloader):
     self.active_downloads: Dict[Shard, asyncio.Task] = {}
     self.completed_downloads: Dict[Shard, Path] = {}
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+    self.current_shard: Optional[Shard] = None
+    self.current_repo_id: Optional[str] = None
+    self.revision: str = "main"
 
   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    self.current_shard = shard
+    self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
     repo_name = get_repo(shard.model_id, inference_engine_name)
     if shard in self.completed_downloads:
       return self.completed_downloads[shard]
@@ -77,3 +89,79 @@ class HFShardDownloader(ShardDownloader):
   @property
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return self._on_progress
+
+  async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
+    if not self.current_shard or not self.current_repo_id:
+      if DEBUG >= 2:
+        print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
+      return None
+
+    try:
+      # If no snapshot directory exists, return None - no need to check remote files
+      snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
+      if not snapshot_dir:
+        if DEBUG >= 2:
+          print(f"No snapshot directory found for {self.current_repo_id}")
+        return None
+
+      # Get the weight map to know what files we need
+      weight_map = await get_weight_map(self.current_repo_id, self.revision)
+      if not weight_map:
+        if DEBUG >= 2:
+          print(f"No weight map found for {self.current_repo_id}")
+        return None
+
+      # Get all files needed for this shard
+      patterns = get_allow_patterns(weight_map, self.current_shard)
+
+      # Check download status for all relevant files
+      status = {}
+      total_bytes = 0
+      downloaded_bytes = 0
+
+      async with aiohttp.ClientSession() as session:
+        file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
+        relevant_files = list(
+          filter_repo_objects(file_list, allow_patterns=patterns, key=lambda x: x["path"])
+        )
+
+        for file in relevant_files:
+          file_size = file["size"]
+          total_bytes += file_size
+
+          percentage = await get_file_download_percentage(
+              session,
+              self.current_repo_id,
+              self.revision,
+              file["path"],
+              snapshot_dir,
+          )
+          status[file["path"]] = percentage
+          downloaded_bytes += (file_size * (percentage / 100))
+
+        # Add overall progress weighted by file size
+        if total_bytes > 0:
+          status["overall"] = (downloaded_bytes / total_bytes) * 100
+        else:
+          status["overall"] = 0
+
+        # Add the total size in bytes, and the downloaded
+        # byte count while the download is incomplete
+        status["total_size"] = total_bytes
+        if status["overall"] != 100:
+          status["total_downloaded"] = downloaded_bytes
+
+        if DEBUG >= 2:
+          print(f"Download calculation for {self.current_repo_id}:")
+          print(f"Total bytes: {total_bytes}")
+          print(f"Downloaded bytes: {downloaded_bytes}")
+          for file in relevant_files:
+            print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
+
+      return status
+
+    except Exception as e:
+      if DEBUG >= 2:
+        print(f"Error getting shard download status: {e}")
+        traceback.print_exc()
+      return None

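The "overall" figure is size-weighted rather than a plain per-file average, which matters when shard files differ greatly in size. A worked example of the arithmetic, with made-up sizes:

# (size_bytes, percentage) per file -- values are illustrative
files = {
  "model-00001-of-00002.safetensors": (4_000_000_000, 100.0),
  "model-00002-of-00002.safetensors": (1_000_000_000, 40.0),
}
total_bytes = sum(size for size, _ in files.values())                     # 5,000,000,000
downloaded_bytes = sum(size * pct / 100 for size, pct in files.values())  # 4,400,000,000
overall = downloaded_bytes / total_bytes * 100                            # 88.0, not the naive (100 + 40) / 2 = 70
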
+ 14 - 1
exo/download/shard_download.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Dict
 from pathlib import Path
 from exo.inference.shard import Shard
 from exo.download.download_progress import RepoProgressEvent
@@ -26,6 +26,16 @@ class ShardDownloader(ABC):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     pass
 
+  @abstractmethod
+  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+    """Get the download status of the current shard.
+
+    Returns:
+        Optional[Dict[str, float]]: A dictionary mapping file paths to their download percentage (0-100),
+        plus aggregate keys ("overall", "total_size", "total_downloaded"), or None if unknown.
+    """
+    pass
+
 
 class NoopShardDownloader(ShardDownloader):
   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
@@ -34,3 +44,6 @@ class NoopShardDownloader(ShardDownloader):
   @property
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return AsyncCallbackSystem()
+
+  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+    return None

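Given the HFShardDownloader implementation above, a concrete return value mixes per-file percentages with the aggregate keys; an illustrative example (file names and sizes made up):

{
  "model-00001-of-00002.safetensors": 100.0,
  "model-00002-of-00002.safetensors": 42.5,
  "overall": 88.5,                 # size-weighted percentage
  "total_size": 5000000000,        # bytes required for this shard
  "total_downloaded": 4425000000,  # only present while overall != 100
}
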
+ 76 - 3
exo/helpers.py

@@ -8,8 +8,11 @@ import platform
 import psutil
 import uuid
 import netifaces
+import subprocess
 from pathlib import Path
 import tempfile
+import json
+from concurrent.futures import ThreadPoolExecutor
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -22,6 +25,9 @@ exo_text = r"""
  \___/_/\_\___/ 
     """
 
+# Single shared thread pool for subprocess operations
+subprocess_pool = ThreadPoolExecutor(max_workers=4, thread_name_prefix="subprocess_worker")
+
 
 def get_system_info():
   if psutil.MACOS:
@@ -222,7 +228,7 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
     return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
 
 
-def get_all_ip_addresses():
+def get_all_ip_addresses_and_interfaces():
   try:
     ip_addresses = []
     for interface in netifaces.interfaces():
@@ -230,12 +236,79 @@ def get_all_ip_addresses():
       if netifaces.AF_INET in ifaddresses:
         for link in ifaddresses[netifaces.AF_INET]:
           ip = link['addr']
-          ip_addresses.append(ip)
+          ip_addresses.append((ip, interface))
     return list(set(ip_addresses))
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
-    return ["localhost"]
+    return [("localhost", "lo")]
+
+async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
+  try:
+    # Use the shared subprocess_pool
+    output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run(
+      ['system_profiler', 'SPNetworkDataType', '-json'],
+      capture_output=True,
+      text=True,
+      close_fds=True
+    ).stdout)
+
+    data = json.loads(output)
+
+    for interface in data.get('SPNetworkDataType', []):
+      if interface.get('interface') == ifname:
+        hardware = interface.get('hardware', '').lower()
+        type_name = interface.get('type', '').lower()
+        name = interface.get('_name', '').lower()
+
+        if 'thunderbolt' in name:
+          return (5, "Thunderbolt")
+        if hardware == 'ethernet' or type_name == 'ethernet':
+          if 'usb' in name:
+            return (4, "Ethernet [USB]")
+          return (4, "Ethernet")
+        if hardware == 'airport' or type_name == 'airport' or 'wi-fi' in name:
+          return (3, "WiFi")
+        if type_name == 'vpn':
+          return (1, "External Virtual")
+
+  except Exception as e:
+    if DEBUG >= 2: print(f"Error detecting macOS interface type: {e}")
+
+  return None
+
+async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
+  # On macOS, try to get the interface type using system_profiler
+  if psutil.MACOS:
+    macos_type = await get_macos_interface_type(ifname)
+    if macos_type is not None: return macos_type
+
+  # Local container/virtual interfaces
+  if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
+    'bridge' in ifname):
+    return (7, "Container Virtual")
+
+  # Loopback interface
+  if ifname.startswith('lo'):
+    return (6, "Loopback")
+
+  # Traditional detection for non-macOS systems or fallback
+  if ifname.startswith(('tb', 'nx', 'ten')):
+    return (5, "Thunderbolt")
+
+  # Regular ethernet detection (en0/en1 are excluded here; on Macs they are usually WiFi, handled below)
+  if ifname.startswith(('eth', 'en')) and not ifname.startswith(('en1', 'en0')):
+    return (4, "Ethernet")
+
+  # WiFi detection
+  if ifname.startswith(('wlan', 'wifi', 'wl')) or ifname in ['en0', 'en1']:
+    return (3, "WiFi")
+
+  # Non-local virtual interfaces (VPNs, tunnels)
+  if ifname.startswith(('tun', 'tap', 'vtun', 'utun', 'gif', 'stf', 'awdl', 'llw')):
+    return (1, "External Virtual")
 
+  # Other physical interfaces
+  return (2, "Other")
 
 async def shutdown(signal, loop, server):
   """Gracefully shutdown the server and close the asyncio loop."""

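A quick sketch of how the new priorities rank candidate interfaces on a non-macOS host (interface names are illustrative; the UDP discovery below keeps the peer handle with the higher priority):

import asyncio
from exo.helpers import get_interface_priority_and_type

async def main():
  for ifname in ["docker0", "lo0", "tb0", "eth0", "en0", "utun3"]:  # illustrative names
    print(ifname, await get_interface_priority_and_type(ifname))

asyncio.run(main())
# docker0 (7, 'Container Virtual') > lo0 (6, 'Loopback') > tb0 (5, 'Thunderbolt')
# > eth0 (4, 'Ethernet') > en0 (3, 'WiFi') > utun3 (1, 'External Virtual')
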
+ 6 - 6
exo/main.py

@@ -21,7 +21,7 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
@@ -81,8 +81,8 @@ if args.node_port is None:
   if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
 args.node_id = args.node_id or get_or_create_node_id()
-chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
-web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
+chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip, _ in get_all_ip_addresses_and_interfaces()]
+web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip, _ in get_all_ip_addresses_and_interfaces()]
 if DEBUG >= 0:
   print("Chat interface started:")
   for web_chat_url in web_chat_urls:
@@ -100,7 +100,7 @@ if args.discovery_module == "udp":
     args.node_port,
     args.listen_port,
     args.broadcast_port,
-    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
     discovery_timeout=args.discovery_timeout,
     allowed_node_ids=allowed_node_ids
   )
@@ -108,7 +108,7 @@ elif args.discovery_module == "tailscale":
   discovery = TailscaleDiscovery(
     args.node_id,
     args.node_port,
-    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
     discovery_timeout=args.discovery_timeout,
     tailscale_api_key=args.tailscale_api_key,
     tailnet=args.tailnet_name,
@@ -117,7 +117,7 @@ elif args.discovery_module == "tailscale":
 elif args.discovery_module == "manual":
   if not args.discovery_config_path:
     raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
-  discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+  discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
   args.node_id,

+ 8 - 0
exo/models.py

@@ -3,6 +3,13 @@ from typing import Optional, List
 
 model_cards = {
   ### llama
+  "llama-3.3-70b": {
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.3-70B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.3-70B-Instruct",
+    },
+  },
   "llama-3.2-1b": {
     "layers": 16,
     "repo": {
@@ -85,6 +92,7 @@ model_cards = {
 }
 
 pretty_name = {
+  "llama-3.3-70b": "Llama 3.3 70B",
   "llama-3.2-1b": "Llama 3.2 1B",
   "llama-3.2-3b": "Llama 3.2 3B",
   "llama-3.1-8b": "Llama 3.1 8B",

+ 12 - 5
exo/networking/grpc/grpc_peer_handle.py

@@ -14,9 +14,10 @@ from exo.helpers import DEBUG
 
 
 class GRPCPeerHandle(PeerHandle):
-  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
+  def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
     self._id = _id
     self.address = address
+    self.desc = desc
     self._device_capabilities = device_capabilities
     self.channel = None
     self.stub = None
@@ -27,6 +28,9 @@ class GRPCPeerHandle(PeerHandle):
   def addr(self) -> str:
     return self.address
 
+  def description(self) -> str:
+    return self.desc
+
   def device_capabilities(self) -> DeviceCapabilities:
     return self._device_capabilities
 
@@ -119,12 +123,15 @@ class GRPCPeerHandle(PeerHandle):
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
       device_capabilities = DeviceCapabilities(
-        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+        model=capabilities.model,
+        chip=capabilities.chip,
+        memory=capabilities.memory,
+        flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
       )
       topology.update_node(node_id, device_capabilities)
-    for node_id, peers in response.peer_graph.items():
-      for peer_id in peers.peer_ids:
-        topology.add_edge(node_id, peer_id)
+    for node_id, peer_connections in response.peer_graph.items():
+      for conn in peer_connections.connections:
+        topology.add_edge(node_id, conn.to_id, conn.description)
     return topology
 
   async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:

+ 10 - 2
exo/networking/grpc/grpc_server.py

@@ -85,7 +85,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
   async def CollectTopology(self, request, context):
     max_depth = request.max_depth
     visited = set(request.visited)
-    topology = await self.node.collect_topology(visited, max_depth)
+    topology = self.node.current_topology
     nodes = {
       node_id:
         node_service_pb2.DeviceCapabilities(
@@ -96,7 +96,15 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         )
       for node_id, cap in topology.nodes.items()
     }
-    peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
+    peer_graph = {
+      node_id: node_service_pb2.PeerConnections(
+        connections=[
+          node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
+          for conn in connections
+        ]
+      )
+      for node_id, connections in topology.peer_graph.items()
+    }
     if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
     return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
 

+ 11 - 6
exo/networking/grpc/node_service.proto

@@ -53,17 +53,22 @@ message CollectTopologyRequest {
 
 message Topology {
   map<string, DeviceCapabilities> nodes = 1;
-  map<string, Peers> peer_graph = 2;
+  map<string, PeerConnections> peer_graph = 2;
 }
 
-message Peers {
-    repeated string peer_ids = 1;
+message PeerConnection {
+  string to_id = 1;
+  optional string description = 2;
+}
+
+message PeerConnections {
+  repeated PeerConnection connections = 1;
 }
 
 message DeviceFlops {
-  float fp32 = 1;
-  float fp16 = 2;
-  float int8 = 3;
+  double fp32 = 1;
+  double fp16 = 2;
+  double int8 = 3;
 }
 
 message DeviceCapabilities {

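Consistent with the server changes above, the regenerated classes attach an optional description to each edge. A sketch of building the new message shapes from Python (node ids and description are illustrative):

from exo.networking.grpc import node_service_pb2

conn = node_service_pb2.PeerConnection(to_id="node2", description="Ethernet (eth0)")
topology = node_service_pb2.Topology(
  peer_graph={"node1": node_service_pb2.PeerConnections(connections=[conn])}
)

Note the float -> double change on DeviceFlops alters the wire encoding of those fields, so peers must run the regenerated code; old Peers payloads are likewise not compatible with PeerConnections.
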
File diff suppressed because it is too large
+ 11 - 1
exo/networking/grpc/node_service_pb2.py


+ 2 - 7
exo/networking/grpc/node_service_pb2_grpc.py

@@ -5,10 +5,8 @@ import warnings
 
 from . import node_service_pb2 as node__service__pb2
 
-GRPC_GENERATED_VERSION = '1.64.1'
+GRPC_GENERATED_VERSION = '1.68.0'
 GRPC_VERSION = grpc.__version__
-EXPECTED_ERROR_RELEASE = '1.65.0'
-SCHEDULED_RELEASE_DATE = 'June 25, 2024'
 _version_not_supported = False
 
 try:
@@ -18,15 +16,12 @@ except ImportError:
     _version_not_supported = True
 
 if _version_not_supported:
-    warnings.warn(
+    raise RuntimeError(
         f'The grpc package installed is at version {GRPC_VERSION},'
         + f' but the generated code in node_service_pb2_grpc.py depends on'
         + f' grpcio>={GRPC_GENERATED_VERSION}.'
         + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
         + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
-        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
-        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
-        RuntimeWarning
     )
 
 

+ 1 - 1
exo/networking/manual/manual_discovery.py

@@ -53,7 +53,7 @@ class ManualDiscovery(Discovery):
           peer = self.known_peers.get(peer_id)
           if not peer:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
-            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
+            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities)
           is_healthy = await peer.health_check()
           if is_healthy:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")

+ 5 - 5
exo/networking/manual/test_manual_discovery.py

@@ -14,7 +14,7 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
     self.peer1 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
     _ = self.discovery1.start()
 
   async def asyncTearDown(self):
@@ -33,8 +33,8 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
     self.peer2 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
     self.peer2.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
     await self.discovery1.start()
     await self.discovery2.start()
 
@@ -63,8 +63,8 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
     await self.server1.start()
     await self.server2.start()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
     await self.discovery1.start()
     await self.discovery2.start()
 

+ 4 - 0
exo/networking/peer_handle.py

@@ -15,6 +15,10 @@ class PeerHandle(ABC):
   def addr(self) -> str:
     pass
 
+  @abstractmethod
+  def description(self) -> str:
+    pass
+
   @abstractmethod
   def device_capabilities(self) -> DeviceCapabilities:
     pass

+ 2 - 2
exo/networking/tailscale/tailscale_discovery.py

@@ -14,7 +14,7 @@ class TailscaleDiscovery(Discovery):
     self,
     node_id: str,
     node_port: int,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
     discovery_interval: int = 5,
     discovery_timeout: int = 30,
     update_interval: int = 15,
@@ -91,7 +91,7 @@ class TailscaleDiscovery(Discovery):
             continue
 
           if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
-            new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+            new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", "TS", device_capabilities)
             if not await new_peer_handle.health_check():
               if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
               continue

+ 1 - 1
exo/networking/tailscale/test_tailscale_discovery.py

@@ -13,7 +13,7 @@ class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
     self.discovery = TailscaleDiscovery(
       node_id="test_node",
       node_port=50051,
-      create_peer_handle=lambda peer_id, address, device_capabilities: unittest.mock.Mock(spec=PeerHandle, id=lambda: peer_id),
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: unittest.mock.Mock(spec=PeerHandle, id=lambda: peer_id),
       tailscale_api_key=self.tailscale_api_key,
       tailnet=self.tailnet
     )

+ 4 - 4
exo/networking/udp/test_udp_discovery.py

@@ -13,8 +13,8 @@ class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
     self.peer2 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
     self.peer2.connect = mock.AsyncMock()
-    self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
-    self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
+    self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
     await self.discovery1.start()
     await self.discovery2.start()
 
@@ -41,8 +41,8 @@ class TestUDPDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.server2 = GRPCServer(self.node2, "localhost", 50054)
     await self.server1.start()
     await self.server2.start()
-    self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
-    self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
+    self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
     await self.discovery1.start()
     await self.discovery2.start()
 

+ 15 - 12
exo/networking/udp/udp_discovery.py

@@ -7,7 +7,7 @@ from typing import List, Dict, Callable, Tuple, Coroutine
 from exo.networking.discovery import Discovery
 from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
-from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
+from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses_and_interfaces, get_interface_priority_and_type
 
 
 class ListenProtocol(asyncio.DatagramProtocol):
@@ -41,8 +41,8 @@ class UDPDiscovery(Discovery):
     node_port: int,
     listen_port: int,
     broadcast_port: int,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
-    broadcast_interval: int = 1,
+    create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
+    broadcast_interval: float = 2.5,
     discovery_timeout: int = 30,
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
     allowed_node_ids: List[str] = None,
@@ -87,27 +87,28 @@ class UDPDiscovery(Discovery):
     while True:
       # Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
       # the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
-      for addr in get_all_ip_addresses():
+      for addr, interface_name in get_all_ip_addresses_and_interfaces():
+        interface_priority, interface_type = await get_interface_priority_and_type(interface_name)
         message = json.dumps({
           "type": "discovery",
           "node_id": self.node_id,
           "grpc_port": self.node_port,
           "device_capabilities": self.device_capabilities.to_dict(),
-          "priority": 1,  # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "priority": interface_priority, # TODO: Prioritise interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "interface_name": interface_name,
+          "interface_type": interface_type,
         })
-        if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")
+        if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority}): {message}")
 
         transport = None
         try:
           transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
-          if DEBUG_DISCOVERY >= 3:
-            print(f"Broadcasting presence at ({addr})")
+          if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority})")
         except Exception as e:
-          print(f"Error in broadcast presence ({addr}): {e}")
+          print(f"Error in broadcast presence ({addr} - {interface_name} - {interface_priority}): {e}")
         finally:
           if transport:
-            try:
-              transport.close()
+            try: transport.close()
             except Exception as e:
               if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
               if DEBUG_DISCOVERY >= 2: traceback.print_exc()
@@ -144,6 +145,8 @@ class UDPDiscovery(Discovery):
       peer_host = addr[0]
       peer_port = message["grpc_port"]
       peer_prio = message["priority"]
+      peer_interface_name = message["interface_name"]
+      peer_interface_type = message["interface_type"]
       device_capabilities = DeviceCapabilities(**message["device_capabilities"])
 
       if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
@@ -153,7 +156,7 @@ class UDPDiscovery(Discovery):
             if DEBUG >= 1:
               print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
             return
-        new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+        new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", f"{peer_interface_type} ({peer_interface_name})", device_capabilities)
         if not await new_peer_handle.health_check():
           if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
           return

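With these changes each broadcast datagram carries the sender's interface metadata; a representative payload (values illustrative). The receiving side uses it to label the peer handle as f"{interface_type} ({interface_name})" and to prefer higher-priority interfaces:

{"type": "discovery", "node_id": "node-a1b2", "grpc_port": 50051,
 "device_capabilities": {...}, "priority": 4,
 "interface_name": "eth0", "interface_type": "Ethernet"}
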
+ 10 - 10
exo/orchestration/standard_node.py

@@ -56,9 +56,9 @@ class StandardNode(Node):
     await self.server.start()
     await self.discovery.start()
     await self.update_peers(wait_for_peers)
-    await self.collect_topology()
+    await self.collect_topology(set())
     if DEBUG >= 2: print(f"Collected topology: {self.topology}")
-    asyncio.create_task(self.periodic_topology_collection(1.0))
+    asyncio.create_task(self.periodic_topology_collection(2.0))
 
   async def stop(self) -> None:
     await self.discovery.stop()
@@ -83,7 +83,7 @@ class StandardNode(Node):
         download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
         self.node_download_progress[status_data.get('node_id')] = download_progress
       if self.topology_viz:
-        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id, self.node_download_progress)
+        self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id, self.node_download_progress)
     except Exception as e:
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
       if DEBUG >= 1: traceback.print_exc()
@@ -374,8 +374,8 @@ class StandardNode(Node):
       try:
         did_peers_change = await self.update_peers()
         if DEBUG >= 2: print(f"{did_peers_change=}")
+        await self.collect_topology(set())
         if did_peers_change:
-          await self.collect_topology()
           await self.select_best_inference_engine()
       except Exception as e:
         print(f"Error collecting topology: {e}")
@@ -386,7 +386,7 @@ class StandardNode(Node):
       return None, False
     return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
 
-  async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
+  async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topology:
     next_topology = Topology()
     next_topology.update_node(self.id, self.device_capabilities)
 
@@ -398,7 +398,7 @@ class StandardNode(Node):
 
     for peer in self.peers:
       next_topology.update_node(peer.id(), peer.device_capabilities())
-      next_topology.add_edge(self.id, peer.id())
+      next_topology.add_edge(self.id, peer.id(), peer.description())
 
       if peer.id() in prev_visited:
         continue
@@ -410,16 +410,16 @@ class StandardNode(Node):
       try:
         other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0)
         if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
-        self.topology.merge(other_topology)
+        next_topology.merge(peer.id(), other_topology)
       except Exception as e:
         print(f"Error collecting topology from {peer.id()}: {e}")
         traceback.print_exc()
 
-    next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
+    next_topology.active_node_id = self.topology.active_node_id
     self.topology = next_topology
     if self.topology_viz:
-      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id)
-    return next_topology
+      self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id)
+    return self.topology
 
   @property
   def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:

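The collect_topology signature change also removes a classic Python pitfall: a mutable default argument. With visited: set[str] = set(), a single set object is created at definition time and shared across calls, so nodes visited in one collection round leak into the next. A minimal demonstration:

def collect(visited=set()):  # one set object, created once at def time
  visited.add("node-a")
  return visited

first = collect()
second = collect()
print(first is second)  # True -- every call mutates the same shared set
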
+ 202 - 29
exo/tinychat/index.css

@@ -21,8 +21,8 @@ main {
 .home {
   width: 100%;
   height: 90%;
-
   margin-bottom: 10rem;
+  padding-top: 2rem;
 }
 
 .title {
@@ -129,8 +129,9 @@ main {
   flex-direction: column;
   gap: 1rem;
   align-items: center;
-  padding-top: 1rem;
+  padding-top: 4rem;
   padding-bottom: 11rem;
+  margin: 0 auto;
 }
 
 .message {
@@ -149,10 +150,17 @@ main {
   color: #000;
 }
 .download-progress {
-  margin-bottom: 12em;
+  position: fixed;
+  bottom: 11rem;
+  left: 50%;
+  transform: translateX(-50%);
+  margin-left: 125px;
+  width: 100%;
+  max-width: 1200px;
   overflow-y: auto;
   min-height: 350px;
   padding: 2rem;
+  z-index: 998;
 }
 .message > pre {
   white-space: pre-wrap;
@@ -271,23 +279,24 @@ main {
 }
 
 .input-container {
-  position: absolute;
+  position: fixed;
   bottom: 0;
-
-  /* linear gradient from background-color to transparent on the top */
-  background: linear-gradient(
-    0deg,
-    var(--primary-bg-color) 55%,
-    transparent 100%
-  );
-
-  width: 100%;
+  width: calc(100% - 250px);
   max-width: 1200px;
   display: flex;
   flex-direction: column;
   justify-content: center;
   align-items: center;
   z-index: 999;
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 55%,
+    transparent 100%
+  );
+  left: 50%;
+  transform: translateX(-50%);
+  margin-left: 125px;
 }
 
 .input-performance {
@@ -372,22 +381,7 @@ p {
 }
 
 .model-selector {
-  display: flex;
-  justify-content: center;
-  padding: 20px 0;
-}
-.model-selector select {
-  padding: 10px 20px;
-  font-size: 16px;
-  border: 1px solid #ccc;
-  border-radius: 5px;
-  background-color: #f8f8f8;
-  cursor: pointer;
-}
-.model-selector select:focus {
-  outline: none;
-  border-color: #007bff;
-  box-shadow: 0 0 0 2px rgba(0,123,255,.25);
+  display: none;
 }
 
 /* Image upload button styles */
@@ -481,4 +475,183 @@ p {
 
 .clear-history-button i {
   font-size: 14px;
+}
+
+/* Sidebar styles */
+.sidebar {
+  position: fixed;
+  left: 0;
+  top: 0;
+  bottom: 0;
+  width: 250px;
+  background-color: var(--secondary-color);
+  padding: 20px;
+  overflow-y: auto;
+  z-index: 1000;
+}
+
+.model-option {
+  padding: 12px;
+  margin: 8px 0;
+  border-radius: 8px;
+  background-color: var(--primary-bg-color);
+  cursor: pointer;
+  transition: all 0.2s ease;
+}
+
+.model-option:hover {
+  transform: translateX(5px);
+}
+
+.model-option.selected {
+  border-left: 3px solid var(--primary-color);
+  background-color: var(--secondary-color-transparent);
+}
+
+.model-name {
+  font-weight: bold;
+  margin-bottom: 4px;
+}
+
+.model-progress {
+  font-size: 0.9em;
+  color: var(--secondary-color-transparent);
+  display: flex;
+  flex-direction: column;
+  gap: 0.5rem;
+}
+
+.model-progress-info {
+  display: flex;
+  flex-direction: column;
+  gap: 0.5rem;
+}
+
+.model-progress i {
+  font-size: 0.9em;
+  color: var(--primary-color);
+}
+
+/* Adjust main content to accommodate sidebar */
+main {
+  margin-left: 250px;
+  width: calc(100% - 250px);
+}
+
+/* Back button styles */
+.back-button {
+  position: fixed;
+  top: 1rem;
+  left: calc(250px + 1rem); /* Sidebar width + padding */
+  background-color: var(--secondary-color);
+  color: var(--foreground-color);
+  padding: 0.5rem 1rem;
+  border-radius: 8px;
+  border: none;
+  cursor: pointer;
+  display: flex;
+  align-items: center;
+  gap: 0.5rem;
+  z-index: 1000;
+  transition: all 0.2s ease;
+}
+
+.back-button:hover {
+  transform: translateX(-5px);
+  background-color: var(--secondary-color-transparent);
+}
+
+.model-info {
+  display: flex;
+  flex-direction: column;
+  gap: 4px;
+}
+
+.model-size {
+  font-size: 0.8em;
+  color: var(--secondary-color-transparent);
+  opacity: 0.8;
+}
+
+.model-header {
+  display: flex;
+  justify-content: space-between;
+  align-items: center;
+  margin-bottom: 4px;
+}
+
+.model-delete-button {
+  background: none;
+  border: none;
+  color: var(--red-color);
+  padding: 4px 8px;
+  cursor: pointer;
+  transition: all 0.2s ease;
+  opacity: 0.7;
+}
+
+.model-delete-button:hover {
+  opacity: 1;
+  transform: scale(1.1);
+}
+
+.model-option:hover .model-delete-button {
+  opacity: 1;
+}
+
+.loading-container {
+  display: flex;
+  flex-direction: column;
+  align-items: center;
+  gap: 10px;
+  padding: 20px;
+  color: var(--secondary-color-transparent);
+}
+
+.loading-container i {
+  font-size: 24px;
+}
+
+.loading-container span {
+  font-size: 14px;
+}
+
+/* Spinner animation for the Font Awesome loading icons */
+.fa-spin {
+  animation: fa-spin 2s infinite linear;
+}
+
+@keyframes fa-spin {
+  0% {
+    transform: rotate(0deg);
+  }
+  100% {
+    transform: rotate(360deg);
+  }
+}
+
+.model-download-button {
+  background: none;
+  border: none;
+  color: var(--primary-color);
+  padding: 4px 8px;
+  border-radius: 4px;
+  cursor: pointer;
+  transition: all 0.2s ease;
+  display: inline-flex;
+  align-items: center;
+  gap: 6px;
+  background-color: var(--primary-bg-color);
+  font-size: 0.9em;
+  width: fit-content;
+  align-self: flex-start;
+}
+
+.model-download-button:hover {
+  transform: scale(1.05);
+  background-color: var(--secondary-color-transparent);
+}
+
+.model-download-button i {
+  font-size: 0.9em;
 }

+ 119 - 33
exo/tinychat/index.html

@@ -25,14 +25,73 @@
 </head>
 <body>
 <main x-data="state" x-init="console.log(endpoint)">
-     <!-- Error Toast -->
-    <div x-show="errorMessage" x-transition.opacity class="toast">
+  <div class="sidebar">
+    <h2 class="megrim-regular" style="margin-bottom: 20px;">Models</h2>
+    
+    <!-- Loading indicator -->
+    <div class="loading-container" x-show="Object.keys(models).length === 0">
+        <i class="fas fa-spinner fa-spin"></i>
+        <span>Loading models...</span>
+    </div>
+    
+    <template x-for="(model, key) in models" :key="key">
+        <div class="model-option" 
+             :class="{ 'selected': cstate.selectedModel === key }"
+             @click="cstate.selectedModel = key">
+            <div class="model-header">
+                <div class="model-name" x-text="model.name"></div>
+                <button 
+                    @click.stop="deleteModel(key, model)"
+                    class="model-delete-button"
+                    x-show="model.download_percentage > 0">
+                    <i class="fas fa-trash"></i>
+                </button>
+            </div>
+            <div class="model-info">
+                <div class="model-progress">
+                    <template x-if="model.loading">
+                        <span><i class="fas fa-spinner fa-spin"></i> Checking download status...</span>
+                    </template>
+                    <div class="model-progress-info">
+                        <template x-if="!model.loading && model.download_percentage != null">
+                            <span>
+                                <!-- Check if there's an active download for this model -->
+                                <template x-if="downloadProgress?.some(p => 
+                                    p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
+                                )">
+                                    <i class="fas fa-circle-notch fa-spin"></i>
+                                </template>
+                                <span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
+                            </span>
+                        </template>
+                        <template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
+                            <button 
+                                @click.stop="handleDownload(key)"
+                                class="model-download-button">
+                                <i class="fas fa-download"></i>
+                                <span x-text="(model.download_percentage > 0 && model.download_percentage < 100) ? 'Continue Downloading' : 'Download'"></span>
+                            </button>
+                        </template>
+                    </div>
+                </div>
+                <template x-if="model.total_size">
+                    <div class="model-size" x-text="model.total_downloaded ? 
+                        `${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` : 
+                        formatBytes(model.total_size)">
+                    </div>
+                </template>
+            </div>
+        </div>
+    </template>
+  </div> 
+    <!-- Error Toast -->
+    <div x-show="errorMessage !== null" x-transition.opacity class="toast">
         <div class="toast-header">
-            <span class="toast-error-message" x-text="errorMessage.basic"></span>
+            <span class="toast-error-message" x-text="errorMessage?.basic || ''"></span>
             <div class="toast-header-buttons">
                 <button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }" 
                         class="toast-expand-button" 
-                        x-show="errorMessage.stack">
+                        x-show="errorMessage?.stack">
                     <span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
                 </button>
                 <button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
@@ -41,13 +100,10 @@
             </div>
         </div>
         <div class="toast-content" x-show="errorExpanded" x-transition>
-            <span x-text="errorMessage.stack"></span>
+            <span x-text="errorMessage?.stack || ''"></span>
         </div>
     </div>
-<div class="model-selector">
-  <select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
-  </select>
-</div>
+
 <div @popstate.window="
       if (home === 2) {
         home = -1;
@@ -79,10 +135,8 @@
 <template x-for="_state in histories.toSorted((a, b) =&gt; b.time - a.time)">
 <div @click="
             cstate = _state;
-            if (cstate) cstate.selectedModel = document.querySelector('.model-selector select').value
-            // updateTotalTokens(cstate.messages);
-            home = 1;
-            // ensure that going back in history will go back to home
+            if (!cstate.selectedModel) cstate.selectedModel = 'llama-3.2-1b';
+            home = 2;
             window.history.pushState({}, '', '/');
           " @touchend="
             if (Math.abs($event.changedTouches[0].clientX - otx) &gt; trigger) removeHistory(_state);
@@ -108,6 +162,19 @@
 </template>
 </div>
 </div>
+<button 
+    @click="
+        home = 0;
+        cstate = { time: null, messages: [], selectedModel: cstate.selectedModel };
+        time_till_first = 0;
+        tokens_per_second = 0;
+        total_tokens = 0;
+    " 
+    class="back-button"
+    x-show="home === 2">
+    <i class="fas fa-arrow-left"></i>
+    Back to Chats
+</button>
 <div class="messages" x-init="
       $watch('cstate', value =&gt; {
         $el.innerHTML = '';
@@ -209,27 +276,46 @@
 <i class="fas fa-times"></i>
 </button>
 </div>
-<textarea :disabled="generating" :placeholder="generating ? 'Generating...' : 'Say something'" @input="
-            home = (home === 0) ? 1 : home
-            if (cstate.messages.length === 0 &amp;&amp; $el.value === '') home = -1;
+<textarea 
+    :disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))" 
+    :placeholder="
+        generating ? 'Generating...' : 
+        (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete)) ? 'Download in progress...' :
+        'Say something'
+    "
+    @input="
+        home = (home === 0) ? 1 : home
+        if (cstate.messages.length === 0 && $el.value === '') home = -1;
 
-            if ($el.value !== '') {
-              const messages = [...cstate.messages];
-              messages.push({ role: 'user', content: $el.value });
-              // updateTotalTokens(messages);
-            } else {
-              if (cstate.messages.length === 0) total_tokens = 0;
-              // else updateTotalTokens(cstate.messages);
-            }
-          " @keydown.enter="await handleEnter($event)" @keydown.escape.window="$focus.focus($el)" autofocus="" class="input-form" id="input-form" rows="1" x-autosize="" x-effect="
-            console.log(generating);
-            if (!generating) $nextTick(() =&gt; {
-              $el.focus();
-              setTimeout(() =&gt; $refs.messages.scrollTo({ top: $refs.messages.scrollHeight, behavior: 'smooth' }), 100);
-            });
-          " x-ref="inputForm"></textarea>
-<button :disabled="generating" @click="await handleSend()" class="input-button">
-<i :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'" class="fas"></i>
+        if ($el.value !== '') {
+            const messages = [...cstate.messages];
+            messages.push({ role: 'user', content: $el.value });
+            // updateTotalTokens(messages);
+        } else {
+            if (cstate.messages.length === 0) total_tokens = 0;
+            // else updateTotalTokens(cstate.messages);
+        }
+    "
+    @keydown.enter="await handleEnter($event)"
+    @keydown.escape.window="$focus.focus($el)"
+    autofocus=""
+    class="input-form"
+    id="input-form"
+    rows="1"
+    x-autosize=""
+    x-effect="
+        console.log(generating);
+        if (!generating) $nextTick(() => {
+            $el.focus();
+            setTimeout(() => $refs.messages.scrollTo({ top: $refs.messages.scrollHeight, behavior: 'smooth' }), 100);
+        });
+    "
+    x-ref="inputForm"></textarea>
+<button 
+    :disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))" 
+    @click="await handleSend()" 
+    class="input-button">
+    <i :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'" class="fas"></i>
 </button>
 </div>
 </div>
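Taken together, the index.html changes replace the static select-based model picker with the model cards above, each exposing inline download and delete controls driven by the new models state, and they gate the chat textarea and send button on downloadProgress so that a prompt cannot be sent while any download is incomplete.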

+ 173 - 96
exo/tinychat/index.js

@@ -13,6 +13,8 @@ document.addEventListener("alpine:init", () => {
     home: 0,
     generating: false,
     endpoint: `${window.location.origin}/v1`,
+    
+    // Initialize error message structure
     errorMessage: null,
     errorExpanded: false,
     errorTimeout: null,
@@ -32,12 +34,81 @@ document.addEventListener("alpine:init", () => {
     // Pending message storage
     pendingMessage: null,
 
+    modelPoolInterval: null,
+
+    // Add models state alongside existing state
+    models: {},
+
     init() {
       // Clean up any pending messages
       localStorage.removeItem("pendingMessage");
 
+      // Get initial model list
+      this.fetchInitialModels();
+
       // Start polling for download progress
       this.startDownloadProgressPolling();
+      
+      // Start model polling with the new pattern
+      this.startModelPolling();
+    },
+
+    async fetchInitialModels() {
+      try {
+        const response = await fetch(`${window.location.origin}/initial_models`);
+        if (response.ok) {
+          const initialModels = await response.json();
+          this.models = initialModels;
+        }
+      } catch (error) {
+        console.error('Error fetching initial models:', error);
+      }
+    },
+
+    async startModelPolling() {
+      while (true) {
+        try {
+          await this.populateSelector();
+          // Wait 5 seconds before next poll
+          await new Promise(resolve => setTimeout(resolve, 5000));
+        } catch (error) {
+          console.error('Model polling error:', error);
+          // If there's an error, wait before retrying
+          await new Promise(resolve => setTimeout(resolve, 5000));
+        }
+      }
+    },
+
+    async populateSelector() {
+      return new Promise((resolve, reject) => {
+        const evtSource = new EventSource(`${window.location.origin}/modelpool`);
+        
+        evtSource.onmessage = (event) => {
+          if (event.data === "[DONE]") {
+            evtSource.close();
+            resolve();
+            return;
+          }
+          
+          const modelData = JSON.parse(event.data);
+          // Update existing model data while preserving other properties
+          Object.entries(modelData).forEach(([modelName, data]) => {
+            if (this.models[modelName]) {
+              this.models[modelName] = {
+                ...this.models[modelName],
+                ...data,
+                loading: false
+              };
+            }
+          });
+        };
+        
+        evtSource.onerror = (error) => {
+          console.error('EventSource failed:', error);
+          evtSource.close();
+          reject(error);
+        };
+      });
     },
 
     removeHistory(cstate) {
@@ -74,56 +145,6 @@ document.addEventListener("alpine:init", () => {
       return `${s}s`;
     },
 
-    async populateSelector() {
-      try {
-        const response = await fetch(`${window.location.origin}/modelpool`);
-        const responseText = await response.text(); // Get raw response text first
-        
-        if (!response.ok) {
-          throw new Error(`HTTP error! status: ${response.status}`);
-        }
-        
-        // Try to parse the response text
-        let responseJson;
-        try {
-          responseJson = JSON.parse(responseText);
-        } catch (parseError) {
-          console.error('Failed to parse JSON:', parseError);
-          throw new Error(`Invalid JSON response: ${responseText}`);
-        }
-
-        const sel = document.querySelector(".model-select");
-        if (!sel) {
-          throw new Error("Could not find model selector element");
-        }
-
-        // Clear the current options and add new ones
-        sel.innerHTML = '';
-          
-        const modelDict = responseJson["model pool"];
-        if (!modelDict) {
-          throw new Error("Response missing 'model pool' property");
-        }
-
-        Object.entries(modelDict).forEach(([key, value]) => {
-          const opt = document.createElement("option");
-          opt.value = key;
-          opt.textContent = value;
-          sel.appendChild(opt);
-        });
-
-        // Set initial value to the first model
-        const firstKey = Object.keys(modelDict)[0];
-        if (firstKey) {
-          sel.value = firstKey;
-          this.cstate.selectedModel = firstKey;
-        }
-      } catch (error) {
-        console.error("Error populating model selector:", error);
-        this.errorMessage = `Failed to load models: ${error.message}`;
-      }
-    },
-
     async handleImageUpload(event) {
       const file = event.target.files[0];
       if (file) {
@@ -169,29 +190,7 @@ document.addEventListener("alpine:init", () => {
         this.processMessage(value);
       } catch (error) {
         console.error('error', error);
-        const errorDetails = {
-            message: error.message || 'Unknown error',
-            stack: error.stack,
-            name: error.name || 'Error'
-        };
-        
-        this.errorMessage = {
-            basic: `${errorDetails.name}: ${errorDetails.message}`,
-            stack: errorDetails.stack
-        };
-
-        // Clear any existing timeout
-        if (this.errorTimeout) {
-            clearTimeout(this.errorTimeout);
-        }
-
-        // Only set the timeout if the error details aren't expanded
-        if (!this.errorExpanded) {
-            this.errorTimeout = setTimeout(() => {
-                this.errorMessage = null;
-                this.errorExpanded = false;
-            }, 30 * 1000);
-        }
+        this.setError(error);
         this.generating = false;
       }
     },
@@ -309,29 +308,7 @@ document.addEventListener("alpine:init", () => {
         }
       } catch (error) {
         console.error('error', error);
-        const errorDetails = {
-            message: error.message || 'Unknown error',
-            stack: error.stack,
-            name: error.name || 'Error'
-        };
-        
-        this.errorMessage = {
-            basic: `${errorDetails.name}: ${errorDetails.message}`,
-            stack: errorDetails.stack
-        };
-
-        // Clear any existing timeout
-        if (this.errorTimeout) {
-            clearTimeout(this.errorTimeout);
-        }
-
-        // Only set the timeout if the error details aren't expanded
-        if (!this.errorExpanded) {
-            this.errorTimeout = setTimeout(() => {
-                this.errorMessage = null;
-                this.errorExpanded = false;
-            }, 30 * 1000);
-        }
+        this.setError(error);
       } finally {
         this.generating = false;
       }
@@ -467,6 +444,106 @@ document.addEventListener("alpine:init", () => {
         this.fetchDownloadProgress();
       }, 1000); // Poll every second
     },
+
+    // Add a helper method to set errors consistently
+    setError(error) {
+      this.errorMessage = {
+        basic: error.message || "An unknown error occurred",
+        stack: error.stack || ""
+      };
+      this.errorExpanded = false;
+      
+      if (this.errorTimeout) {
+        clearTimeout(this.errorTimeout);
+      }
+
+      // errorExpanded was reset above, so the auto-dismiss timer is always
+      // armed; the toast's expand button clears it if the user opens details
+      this.errorTimeout = setTimeout(() => {
+        this.errorMessage = null;
+        this.errorExpanded = false;
+      }, 30 * 1000);
+    },
+
+    async deleteModel(modelName, model) {
+      const downloadedSize = model.total_downloaded || 0;
+      const sizeMessage = downloadedSize > 0 ? 
+        `This will free up ${this.formatBytes(downloadedSize)} of space.` :
+        'This will remove any partially downloaded files.';
+      
+      if (!confirm(`Are you sure you want to delete ${model.name}? ${sizeMessage}`)) {
+        return;
+      }
+
+      try {
+        const response = await fetch(`${window.location.origin}/models/${modelName}`, {
+          method: 'DELETE',
+          headers: {
+            'Content-Type': 'application/json'
+          }
+        });
+
+        const data = await response.json();
+        
+        if (!response.ok) {
+          throw new Error(data.detail || 'Failed to delete model');
+        }
+
+        // Update the model status in the UI
+        if (this.models[modelName]) {
+          this.models[modelName].downloaded = false;
+          this.models[modelName].download_percentage = 0;
+          this.models[modelName].total_downloaded = 0;
+        }
+
+        // If this was the selected model, switch to a different one
+        if (this.cstate.selectedModel === modelName) {
+          const availableModel = Object.keys(this.models).find(key => this.models[key].downloaded);
+          this.cstate.selectedModel = availableModel || 'llama-3.2-1b';
+        }
+
+        // Show success message
+        console.log(`Model deleted successfully from: ${data.path}`);
+
+        // Refresh the model list
+        await this.populateSelector();
+      } catch (error) {
+        console.error('Error deleting model:', error);
+        this.setError(error);
+      }
+    },
+
+    async handleDownload(modelName) {
+      try {
+        const response = await fetch(`${window.location.origin}/download`, {
+          method: 'POST',
+          headers: {
+            'Content-Type': 'application/json'
+          },
+          body: JSON.stringify({
+            model: modelName
+          })
+        });
+
+        const data = await response.json();
+
+        if (!response.ok) {
+          throw new Error(data.error || 'Failed to start download');
+        }
+
+        // Update the model's status immediately when download starts
+        if (this.models[modelName]) {
+          this.models[modelName] = {
+            ...this.models[modelName],
+            loading: true
+          };
+        }
+
+      } catch (error) {
+        console.error('Error starting download:', error);
+        this.setError(error);
+      }
+    }
   }));
 });
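populateSelector now consumes /modelpool as a server-sent-event stream, opening a fresh EventSource on each five-second cycle of startModelPolling, instead of the old one-shot JSON fetch. For reference, a standalone test client for that stream might look like the sketch below; the "data: <json>" framing and the "[DONE]" terminator are taken from the client code above, while the port (52415) and the download_percentage field are assumptions to adjust for your deployment.

# A rough test client for the /modelpool SSE stream (a sketch, not part of
# this patch). Assumes "data: <json>" framing with a "data: [DONE]"
# terminator, mirroring what populateSelector expects, and the default
# API port 52415 (an assumption; adjust to your setup).
import asyncio
import json

import aiohttp


async def read_model_pool(base_url: str = "http://localhost:52415"):
  async with aiohttp.ClientSession() as session:
    async with session.get(f"{base_url}/modelpool") as resp:
      async for raw_line in resp.content:  # the body reader yields one line at a time
        line = raw_line.decode("utf-8").strip()
        if not line.startswith("data: "):
          continue  # skip the blank separator lines between events
        payload = line[len("data: "):]
        if payload == "[DONE]":
          break
        for name, info in json.loads(payload).items():
          print(name, info.get("download_percentage"))


if __name__ == "__main__":
  asyncio.run(read_model_pool())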
 

+ 52 - 26
exo/topology/topology.py

@@ -1,11 +1,27 @@
 from .device_capabilities import DeviceCapabilities
 from typing import Dict, Set, Optional
+from dataclasses import dataclass
 
+@dataclass
+class PeerConnection:
+  from_id: str
+  to_id: str
+  description: Optional[str] = None
+
+  def __hash__(self):
+    # Use both from_id and to_id for uniqueness in sets
+    return hash((self.from_id, self.to_id))
+
+  def __eq__(self, other):
+    if not isinstance(other, PeerConnection):
+      return False
+    # Compare both from_id and to_id for equality
+    return self.from_id == other.from_id and self.to_id == other.to_id
 
 class Topology:
   def __init__(self):
-    self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
-    self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph
+    self.nodes: Dict[str, DeviceCapabilities] = {}
+    self.peer_graph: Dict[str, Set[PeerConnection]] = {}
     self.active_node_id: Optional[str] = None
 
   def update_node(self, node_id: str, device_capabilities: DeviceCapabilities):
@@ -17,33 +33,43 @@ class Topology:
   def all_nodes(self):
     return self.nodes.items()
 
-  def add_edge(self, node1_id: str, node2_id: str):
-    if node1_id not in self.peer_graph:
-      self.peer_graph[node1_id] = set()
-    if node2_id not in self.peer_graph:
-      self.peer_graph[node2_id] = set()
-    self.peer_graph[node1_id].add(node2_id)
-    self.peer_graph[node2_id].add(node1_id)
-
-  def get_neighbors(self, node_id: str) -> Set[str]:
-    return self.peer_graph.get(node_id, set())
-
-  def all_edges(self):
-    edges = []
-    for node, neighbors in self.peer_graph.items():
-      for neighbor in neighbors:
-        if (neighbor, node) not in edges:  # Avoid duplicate edges
-          edges.append((node, neighbor))
-    return edges
-
-  def merge(self, other: "Topology"):
+  def add_edge(self, from_id: str, to_id: str, description: Optional[str] = None):
+    if from_id not in self.peer_graph:
+      self.peer_graph[from_id] = set()
+    conn = PeerConnection(from_id, to_id, description)
+    self.peer_graph[from_id].add(conn)
+
+  def merge(self, peer_node_id: str, other: "Topology"):
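+    # Import only the reporting peer's own capabilities and its outgoing
+    # edges; anything it relays about third-party nodes is ignored.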
     for node_id, capabilities in other.nodes.items():
+      if node_id != peer_node_id: continue
       self.update_node(node_id, capabilities)
-    for node_id, neighbors in other.peer_graph.items():
-      for neighbor in neighbors:
-        self.add_edge(node_id, neighbor)
+    for node_id, connections in other.peer_graph.items():
+      for conn in connections:
+        if conn.from_id != peer_node_id: continue
+        self.add_edge(conn.from_id, conn.to_id, conn.description)
 
   def __str__(self):
     nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items())
-    edges_str = ", ".join(f"{node}: {neighbors}" for node, neighbors in self.peer_graph.items())
+    edges_str = ", ".join(f"{node}: {[f'{c.to_id}({c.description})' for c in conns]}"
+                         for node, conns in self.peer_graph.items())
     return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})"
+
+  def to_json(self):
+    return {
+      "nodes": {
+        node_id: capabilities.to_dict()
+        for node_id, capabilities in self.nodes.items()
+      },
+      "peer_graph": {
+        node_id: [
+          {
+            "from_id": conn.from_id,
+            "to_id": conn.to_id,
+            "description": conn.description
+          }
+          for conn in connections
+        ]
+        for node_id, connections in self.peer_graph.items()
+      },
+      "active_node_id": self.active_node_id
+    }
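peer_graph now maps a from_id to a set of PeerConnection objects, so edges are directed and no longer mirrored automatically, and the old get_neighbors/all_edges helpers are gone. A minimal sketch of the new semantics (node IDs and descriptions are illustrative):

# Directed-edge semantics of the new Topology (illustrative IDs).
from exo.topology.topology import Topology

t = Topology()
t.add_edge("node-a", "node-b", description="ethernet")
t.add_edge("node-b", "node-a", description="wifi")  # the reverse edge is a separate entry

# Each direction lives under its own from_id.
assert {c.to_id for c in t.peer_graph["node-a"]} == {"node-b"}

# PeerConnection hashes and compares on (from_id, to_id) only, so re-adding
# the same pair is a no-op and the original description is kept.
t.add_edge("node-a", "node-b", description="thunderbolt")
assert len(t.peer_graph["node-a"]) == 1

One consequence worth flagging: because __eq__ ignores description, a connection whose link type later changes keeps its stale description until the entry is rebuilt.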

+ 17 - 1
exo/viz/topology_viz.py

@@ -242,12 +242,19 @@ class TopologyViz:
             if info_y + j != y or info_x + k != x:
               visualization[info_y + j][info_x + k] = char
 
-      # Draw line to next node
+      # Draw line to next node and add connection description
       next_i = (i+1) % num_partitions
       next_angle = 2*math.pi*next_i/num_partitions
       next_x = int(center_x + radius_x*math.cos(next_angle))
       next_y = int(center_y + radius_y*math.sin(next_angle))
 
+      # Get connection descriptions
+      conn1 = self.topology.peer_graph.get(partition.node_id, set())
+      conn2 = self.topology.peer_graph.get(self.partitions[next_i].node_id, set())
+      description1 = next((c.description for c in conn1 if c.to_id == self.partitions[next_i].node_id), "")
+      description2 = next((c.description for c in conn2 if c.to_id == partition.node_id), "")
+      connection_description = f"{description1}/{description2}" if (description1 or description2) else ""
+
       # Simple line drawing
       steps = max(abs(next_x - x), abs(next_y - y))
       for step in range(1, steps):
@@ -256,6 +263,15 @@ class TopologyViz:
         if 0 <= line_y < 48 and 0 <= line_x < 100:
           visualization[line_y][line_x] = "-"
 
+      # Add connection description near the midpoint of the line
+      mid_x = (x + next_x) // 2
+      mid_y = (y + next_y) // 2
+      # Center the description text around the midpoint
+      desc_start_x = mid_x - len(connection_description) // 2
+      for j, char in enumerate(connection_description):
+        if 0 <= mid_y < 48 and 0 <= desc_start_x + j < 100:
+          visualization[mid_y][desc_start_x + j] = char
+
     # Convert to string
     return "\n".join("".join(str(char) for char in row) for row in visualization)
 

+ 1 - 1
scripts/compile_grpc.sh

@@ -2,6 +2,6 @@
 source ./install.sh
 pushd exo/networking/grpc
 python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto
-sed -i "s/import\ node_service_pb2/from . &/" node_service_pb2_grpc.py
+sed -i '' "s/import\ node_service_pb2/from . &/" node_service_pb2_grpc.py
 popd
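One portability caveat with the compile_grpc.sh change: sed -i '' is the BSD/macOS spelling of in-place editing with no backup suffix. GNU sed on Linux parses the empty string as its script argument, so this line now fails there; branching on uname or writing to a temporary file would keep the script cross-platform.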
 

Some files were not shown because too many files changed in this diff