Selaa lähdekoodia

Merge branch 'main' into downloadedModelsV2

Caden MacKenzie 8 kuukautta sitten
vanhempi
commit
2cdd55d297

+ 1 - 0
.circleci/config.yml

@@ -126,6 +126,7 @@ jobs:
             METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
             echo "Running tokenizer tests..."
             python3 ./test/test_tokenizers.py
+            python3 ./test/test_model_helpers.py
 
   discovery_integration_test:
     macos:

+ 1 - 1
.gitignore

@@ -4,6 +4,7 @@ test_weights.npz
 .exo_used_ports
 .exo_node_id
 .idea
+.DS_Store
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -15,7 +16,6 @@ __pycache__/
 
 # Distribution / packaging
 /.Python
-/build/
 /develop-eggs/
 /dist/
 /downloads/

+ 5 - 5
README.md

@@ -121,14 +121,14 @@ exo
 
 That's it! No configuration required - exo will automatically discover the other device(s).
 
-exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:8000
+exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:52415
 
-For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Examples with curl:
+For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:52415/v1/chat/completions. Examples with curl:
 
 #### Llama 3.2 3B:
 
 ```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
      "model": "llama-3.2-3b",
@@ -140,7 +140,7 @@ curl http://localhost:8000/v1/chat/completions \
 #### Llama 3.1 405B:
 
 ```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
      "model": "llama-3.1-405b",
@@ -152,7 +152,7 @@ curl http://localhost:8000/v1/chat/completions \
 #### Llava 1.5 7B (Vision Language Model):
 
 ```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
      "model": "llava-1.5-7b-hf",

+ 18 - 3
configure_mlx.sh

@@ -1,3 +1,18 @@
-# this needs to be dynamic
-# sudo sysctl iogpu.wired_lwm_mb=400000
-# sudo sysctl iogpu.wired_limit_mb=180000
+#!/bin/bash
+
+# Get the total memory in MB
+TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
+
+# Set WIRED_LIMIT_MB to 80%
+WIRED_LIMIT_MB=$(($TOTAL_MEM_MB * 80 / 100))
+# Set  WIRED_LWM_MB to 70%
+WIRED_LWM_MB=$(($TOTAL_MEM_MB * 70 / 100))
+
+# Display the calculated values
+echo "Total memory: $TOTAL_MEM_MB MB"
+echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
+echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"
+
+# Apply the values with sysctl
+sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
+sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB

BIN
docs/exo-rounded.png


+ 1 - 1
examples/astra/astra/ContentView.swift

@@ -148,7 +148,7 @@ struct ContentView: View {
     @State private var voiceActivityThreshold: Float = 0.40
     @State private var silenceTimeThreshold = 1.0
     @State private var debugText = ""
-    @State private var apiEndpoint = "http://192.168.212.74:8000/v1/chat/completions"
+    @State private var apiEndpoint = "http://192.168.212.74:52415/v1/chat/completions"
     @State private var audioBuffer: [Float] = []
     @State private var bufferDuration: Double = 0.5 // 0.5 seconds buffer
     @State private var isInitialTranscription = true

+ 1 - 1
examples/chatgpt_api.sh

@@ -3,7 +3,7 @@
 # This works the same in a single-node set up and in a multi-node setup.
 # You need to start exo before running this by running `python3 main.py`.
 
-API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):8000}"
+API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):52415}"
 MODEL="llama-3.1-8b"
 PROMPT="What is the meaning of exo?"
 TEMPERATURE=0.7

+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 25 - 13
exo/api/chatgpt_api.py

@@ -9,18 +9,17 @@ from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
+import os
+import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict
-from exo.inference.inference_engine import inference_engine_classes
-from exo.inference.shard import Shard
+from exo.helpers import PrefixDict, shutdown
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable
 from exo.download.hf.hf_shard_download import HFShardDownloader
 
-
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
     self.role = role
@@ -30,6 +29,7 @@ class Message:
     return {"role": self.role, "content": self.content}
 
 
+
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float):
     self.model = model
@@ -147,9 +147,8 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt
 
-
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -158,7 +157,7 @@ class ChatGPTAPI:
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
-    self.default_model = "llama-3.2-1b"
+    self.default_model = default_model or "llama-3.2-1b"
 
     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
@@ -175,13 +174,23 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
+    cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
+    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
 
-    self.static_dir = Path(__file__).parent.parent/"tinychat"
-    self.app.router.add_get("/", self.handle_root)
-    self.app.router.add_static("/", self.static_dir, name="static")
+    if "__compiled__" not in globals():
+      self.static_dir = Path(__file__).parent.parent/"tinychat"
+      self.app.router.add_get("/", self.handle_root)
+      self.app.router.add_static("/", self.static_dir, name="static")
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
+  
+  async def handle_quit(self, request):
+    if DEBUG>=1: print("Received quit signal")
+    response = web.json_response({"detail": "Quit signal received"}, status=200)
+    await response.prepare(request)
+    await response.write_eof()
+    await shutdown(signal.SIGINT, asyncio.get_event_loop())
 
   async def timeout_middleware(self, app, handler):
     async def middleware(request):
@@ -202,6 +211,9 @@ class ChatGPTAPI:
   async def handle_root(self, request):
     return web.FileResponse(self.static_dir/"index.html")
 
+  async def handle_healthcheck(self, request):
+    return web.json_response({"status": "ok"})
+
   async def handle_model_support(self, request):
     try:
       model_pool = {}
@@ -274,8 +286,8 @@ class ChatGPTAPI:
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
     chat_request = parse_chat_request(data, self.default_model)
-    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
-      chat_request.model = self.default_model if self.default_model.startswith("llama") else "llama-3.2-1b"
+    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
+      chat_request.model = self.default_model
     if not chat_request.model or chat_request.model not in model_cards:
       if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
       chat_request.model = self.default_model
@@ -399,7 +411,7 @@ class ChatGPTAPI:
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
 
-  async def run(self, host: str = "0.0.0.0", port: int = 8000):
+  async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()
     site = web.TCPSite(runner, host, port)

+ 31 - 5
exo/download/hf/hf_helpers.py

@@ -1,7 +1,11 @@
+import aiofiles.os as aios
+from typing import Union
 import asyncio
 import aiohttp
 import json
 import os
+import sys
+import shutil
 from urllib.parse import urljoin
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from datetime import datetime, timedelta
@@ -9,7 +13,7 @@ from fnmatch import fnmatch
 from pathlib import Path
 from typing import Generator, Iterable, TypeVar, TypedDict
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG
+from exo.helpers import DEBUG, is_frozen
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
 import aiofiles
@@ -17,7 +21,6 @@ from aiofiles import os as aios
 
 T = TypeVar("T")
 
-
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision
@@ -99,9 +102,22 @@ async def get_auth_headers():
 
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = repo_id.replace("/", "--")
+  sanitized_repo_id = str(repo_id).replace("/", "--")
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
 
+async def move_models_to_hf(seed_dir: Union[str, Path]):
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(seed_dir)
+  dest_dir = get_hf_home()/"hub"
+  await aios.makedirs(dest_dir, exist_ok=True)
+  async for path in source_dir.iterdir():
+    if path.is_dir() and path.startswith("models--"):
+      dest_path = dest_dir / path.name
+      if dest_path.exists():
+        if DEBUG>=1: print(f"skipping moving {dest_path}. File already exists")
+      else:
+        await aios.rename(str(path), str(dest_path))
+        
 
 async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
@@ -417,11 +433,10 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
     elif shard.is_last_layer():
       shard_specific_patterns.add(sorted_file_names[-1])
   else:
-    shard_specific_patterns = set("*.safetensors")
+    shard_specific_patterns = set(["*.safetensors"])
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
 
-
 async def get_file_download_percentage(
     session: aiohttp.ClientSession,
     repo_id: str,
@@ -472,3 +487,14 @@ async def get_file_download_percentage(
     if DEBUG >= 2:
       print(f"Error checking file download status for {file_path}: {e}")
     return 0
+
+async def has_hf_home_read_access() -> bool:
+  hf_home = get_hf_home()
+  try: return await aios.access(hf_home, os.R_OK)
+  except OSError: return False
+
+async def has_hf_home_write_access() -> bool:
+  hf_home = get_hf_home()
+  try: return await aios.access(hf_home, os.W_OK)
+  except OSError: return False
+

+ 20 - 0
exo/helpers.py

@@ -1,4 +1,5 @@
 import os
+import sys
 import asyncio
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
@@ -234,3 +235,22 @@ def get_all_ip_addresses():
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return ["localhost"]
+
+
+async def shutdown(signal, loop):
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
+
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)

+ 2 - 1
exo/inference/inference_engine.py

@@ -26,7 +26,8 @@ class InferenceEngine(ABC):
   
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
     tokens = await self.encode(shard, prompt)
-    output_data = await self.infer_tensor(request_id, shard, tokens)
+    x = tokens.reshape(1, -1)
+    output_data = await self.infer_tensor(request_id, shard, x)
     return output_data 
 
 inference_engine_classes = {

+ 3 - 2
exo/inference/mlx/sharded_utils.py

@@ -21,6 +21,7 @@ from transformers import AutoProcessor
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 
 from exo import DEBUG
+from exo.inference.tokenizers import resolve_tokenizer
 from ..shard import Shard
 
 
@@ -136,7 +137,7 @@ def load_model_shard(
       self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
 
     def __call__(self, x, *args, **kwargs):
-      y = super().__call__(x[None] if self.shard.is_first_layer() else x, *args, **kwargs)
+      y = super().__call__(x, *args, **kwargs)
       return y
 
   model_args = model_args_class.from_dict(config)
@@ -183,7 +184,7 @@ async def load_shard(
     processor.encode = processor.tokenizer.encode
     return model, processor
   else:
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    tokenizer = await resolve_tokenizer(model_path)
     return model, tokenizer
 
 

+ 5 - 9
exo/inference/tinygrad/inference.py

@@ -7,7 +7,6 @@ from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
 from tinygrad import Tensor, nn, Context
 from exo.inference.inference_engine import InferenceEngine
-from typing import Optional, Tuple
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
@@ -68,24 +67,21 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
     logits = x[:, -1, :]
     def sample_wrapper():
-      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize()
-    out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
-    return out.numpy().astype(int)
+      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    return np.array(tokens)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)
   
   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
-    return tokens
+    return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize())
-    return output_data.numpy()
+    return await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize().numpy())
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:

+ 32 - 22
exo/main.py

@@ -3,6 +3,9 @@ import asyncio
 import signal
 import json
 import logging
+import platform
+import os
+import sys
 import time
 import traceback
 import uuid
@@ -17,22 +20,24 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
-from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
+from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
 parser.add_argument("model_name", nargs="?", help="Model name to run")
+parser.add_argument("--default-model", type=str, default=None, help="Default model")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
+parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
@@ -42,7 +47,7 @@ parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale",
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
 parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
-parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
+parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
@@ -121,20 +126,20 @@ api = ChatGPTAPI(
   node,
   inference_engine.__class__.__name__,
   response_timeout=args.chatgpt_api_response_timeout,
-  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
+  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
+  default_model=args.default_model
 )
 node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )
 
-
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
     status = json.loads(opaque_status)
     if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
       current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
       if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-      asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+      asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
   except Exception as e:
     if DEBUG >= 2:
       print(f"Failed to preemptively start download: {e}")
@@ -160,20 +165,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
 
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
-
-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
-
-
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
@@ -206,12 +197,31 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
 async def main():
   loop = asyncio.get_running_loop()
 
+  # Check HuggingFace directory permissions
+  hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
+  if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
+  print(f"{has_read=}, {has_write=}")
+  if not has_read or not has_write:
+    print(f"""
+          WARNING: Limited permissions for model storage directory: {hf_home}.
+          This may prevent model downloads from working correctly.
+          {"❌ No read access" if not has_read else ""}
+          {"❌ No write access" if not has_write else ""}
+          """)
+    
+  if not args.models_seed_dir is None:
+    try:
+      await move_models_to_hf(args.models_seed_dir)
+    except Exception as e:
+      print(f"Error moving models to .cache/huggingface: {e}")
+
   # Use a more direct approach to handle signals
   def handle_exit():
     asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
 
-  for s in [signal.SIGINT, signal.SIGTERM]:
-    loop.add_signal_handler(s, handle_exit)
+  if platform.system() != "Windows":
+    for s in [signal.SIGINT, signal.SIGTERM]:
+      loop.add_signal_handler(s, handle_exit)
 
   await node.start(wait_for_peers=args.wait_for_peers)
 

+ 25 - 3
exo/models.py

@@ -1,11 +1,11 @@
 from exo.inference.shard import Shard
-from typing import Optional
+from typing import Optional, List
 
 model_cards = {
   ### llama
   "llama-3.2-1b": {
     "layers": 16,
-    "repo": { 
+    "repo": {
       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
     },
@@ -63,6 +63,7 @@ model_cards = {
   ### llava
   "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
   ### qwen
+  "qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
   "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
   "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
   "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
@@ -123,4 +124,25 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
   if repo is None or n_layers < 1:
     return None
   return Shard(model_id, 0, 0, n_layers)
-  
+
+def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
+  if not supported_inference_engine_lists:
+    return list(model_cards.keys())
+
+  from exo.inference.inference_engine import inference_engine_classes
+  supported_inference_engine_lists = [
+    [inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
+    for engine_list in supported_inference_engine_lists
+  ]
+
+  def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
+    return any(engine in model_info.get("repo", {}) for engine in engine_list)
+
+  def supports_all_engine_lists(model_info: dict) -> bool:
+    return all(has_any_engine(model_info, engine_list)
+              for engine_list in supported_inference_engine_lists)
+
+  return [
+    model_id for model_id, model_info in model_cards.items()
+    if supports_all_engine_lists(model_info)
+  ]

+ 55 - 56
exo/orchestration/standard_node.py

@@ -102,11 +102,7 @@ class StandardNode(Node):
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool
   
-  async def encode_prompt(self, shard: Shard, prompt):
-    toks = await self.inference_engine.encode(shard, prompt)
-    return toks
-  
-  async def process_result(
+  async def process_inference_result(
     self,
     shard,
     result: np.ndarray,
@@ -114,32 +110,24 @@ class StandardNode(Node):
   ):
     if request_id not in self.buffered_token_output:
       self.buffered_token_output[request_id] = ([], False)
-    
-    if request_id not in self.buffered_logits:
-      self.buffered_logits[request_id] = []
-
-    self.buffered_logits[request_id] += [i for i in np.reshape(result, (-1, 1, result.shape[-1]))]
-
-    if shard.is_last_layer():
-      result = await self.inference_engine.sample(result)
-    
-    await self.inference_engine.ensure_shard(shard)
-    is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-
-    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-    if result.size == 1:  # we got a new token out
-      self.buffered_token_output[request_id][0].append(result.item())
+    is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+    if shard.is_last_layer() and not is_finished:
+      token = await self.inference_engine.sample(result)
+      await self.inference_engine.ensure_shard(shard)
+      self.buffered_token_output[request_id][0].append(token.item())
       self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-    
-    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
+      forward = token.reshape(1, -1)
+    else:
+      forward = result
 
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
     else:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
+      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
 
-    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+    return np.array(self.buffered_token_output[request_id][0])
 
   async def process_prompt(
     self,
@@ -190,13 +178,13 @@ class StandardNode(Node):
     shard = self.get_current_shard(base_shard)
 
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
-    if shard.start_layer != 0:
+    if not shard.is_first_layer():
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      await self.forward_to_next_shard(shard, prompt, request_id)
+      resp = await self.forward_prompt(shard, prompt, request_id, 0)
       return None
     else:
       result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
-      ret = await self.process_result(shard, result, request_id) 
+      ret = await self.process_inference_result(shard, result, request_id) 
       return result
 
   async def process_tensor(
@@ -255,46 +243,57 @@ class StandardNode(Node):
     if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
       result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
-      ret = await self.process_result(shard, result, request_id) 
+      ret = await self.process_inference_result(shard, result, request_id) 
       return ret
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
       return None
 
-  async def forward_to_next_shard(
+  async def forward_prompt(
     self,
     base_shard: Shard,
-    tensor_or_prompt: Union[np.ndarray, str],
+    prompt: str,
     request_id: str,
+    target_index: int,
   ) -> None:
-    if not self.partitioning_strategy:
-      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
-      return
-
-    next_partition_index = self.get_partition_index(offset = 1)
-    if DEBUG >= 1: print(f"Next partition index: {next_partition_index}")
-    if next_partition_index is not None:
-      target_id = self.partitioning_strategy.partition(self.topology)[next_partition_index].node_id
-      next_shard = self.get_current_shard(base_shard, next_partition_index)
-      if DEBUG >= 2: print(f"Computed next from: {base_shard} {next_partition_index}, {self.topology}. Next shard: {next_shard}")
-      is_tensor = isinstance(tensor_or_prompt, np.ndarray)
-      if target_id == self.id:
-        if is_tensor:
-          await self.process_tensor(next_shard, tensor_or_prompt, request_id)
-        else:
-          await self.process_prompt(next_shard, tensor_or_prompt, request_id)
-      else:
-        target_peer = next((p for p in self.peers if p.id() == target_id), None)
-        if not target_peer:
-          raise ValueError(f"Peer for {next_partition_index} not found")
-        if is_tensor:
-          if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor_or_prompt}")
-          await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id)
-        else:
-          await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id)
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {target_shard}")
+    if target_id == self.id:
+      await self.process_prompt(next_shard, prompt, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
+      await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+  
+  async def forward_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: str,
+    target_index: int,
+  ) -> None:
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {target_shard}")
+    if target_id == self.id:
+      await self.process_tensor(next_shard, tensor, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
+      await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
 
   def get_partition_index(self, offset: int = 0):
+    if not self.partitioning_strategy:
+      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
+      return None
     partitions = self.partitioning_strategy.partition(self.topology)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     if current_partition_index is None:

+ 2 - 0
exo/tinychat/index.css

@@ -167,6 +167,8 @@ main {
 .download-progress {
   margin-bottom: 12em;
   overflow-y: auto;
+  min-height: 350px;
+  padding: 2rem;
 }
 .message > pre {
   white-space: pre-wrap;

+ 1 - 1
extra/start_openwebui.sh

@@ -1,3 +1,3 @@
-API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):8000}"
+API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):52415}"
 echo "Using API_ENDPOINT=${API_ENDPOINT}"
 docker run -d -p 3000:8080 -e OPENAI_API_BASE_URL="${API_ENDPOINT}" -e OPENAI_API_KEY=your_secret_key -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main

+ 60 - 0
scripts/build_exo.py

@@ -0,0 +1,60 @@
+import site
+import subprocess
+import sys
+import os 
+import pkgutil
+
+def run():
+    site_packages = site.getsitepackages()[0]
+    command = [
+        f"{sys.executable}", "-m", "nuitka", "exo/main.py",
+        "--company-name=exolabs",
+        "--product-name=exo",
+        "--output-dir=dist",
+        "--follow-imports",
+        "--standalone",
+        "--output-filename=exo",
+        "--onefile",
+        "--python-flag=no_site"
+    ]
+
+    if sys.platform == "darwin": 
+        command.extend([
+            "--macos-app-name=exo",
+            "--macos-app-mode=gui",
+            "--macos-app-version=0.0.1",
+            "--macos-signed-app-name=com.exolabs.exo",
+            "--macos-sign-identity=auto",
+            "--macos-sign-notarization",
+            "--include-distribution-meta=mlx",
+            "--include-module=mlx._reprlib_fix",
+            "--include-module=mlx._os_warning",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=mlx/lib/mlx.metallib",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=./mlx.metallib",
+            "--include-distribution-meta=pygments",
+            "--nofollow-import-to=tinygrad"
+        ])
+        inference_modules = [
+            name for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models'])
+        ]
+        for module in inference_modules:
+            command.append(f"--include-module=exo.inference.mlx.models.{module}")
+    elif sys.platform == "win32":  
+        command.extend([
+            "--windows-icon-from-ico=docs/exo-logo-win.ico",
+            "--file-version=0.0.1",
+            "--product-version=0.0.1"
+        ])
+    elif sys.platform.startswith("linux"):  
+        command.extend([
+            "--include-distribution-metadata=pygments",
+            "--linux-icon=docs/exo-rounded.png"
+        ])
+    try:
+        subprocess.run(command, check=True)
+        print("Build completed!")
+    except subprocess.CalledProcessError as e:
+        print(f"An error occurred: {e}")
+
+if __name__ == "__main__":
+    run()

+ 5 - 5
setup.py

@@ -5,14 +5,15 @@ from setuptools import find_packages, setup
 
 # Base requirements for all platforms
 install_requires = [
-  "aiohttp==3.10.2",
+  "aiohttp==3.10.11",
   "aiohttp_cors==0.7.0",
   "aiofiles==24.1.0",
-  "grpcio==1.64.1",
-  "grpcio-tools==1.64.1",
+  "grpcio==1.68.0",
+  "grpcio-tools==1.68.0",
   "Jinja2==3.1.4",
   "netifaces==0.11.0",
   "numpy==2.0.0",
+  "nuitka==2.4.10",
   "nvidia-ml-py==12.560.30",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",
@@ -21,10 +22,9 @@ install_requires = [
   "pydantic==2.9.2",
   "requests==2.32.3",
   "rich==13.7.1",
-  "safetensors==0.4.3",
   "tenacity==9.0.0",
   "tqdm==4.66.4",
-  "transformers==4.43.3",
+  "transformers==4.46.3",
   "uuid==1.30",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
 ]

+ 1 - 1
test/reconnect.sh

@@ -1,7 +1,7 @@
 #!/bin/bash
 
 echo "Starting node 1"
-DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
+DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
 PID1=$!
 echo "Started node 1 PID: $PID1"
 echo "Starting node 2"

+ 121 - 0
test/test_model_helpers.py

@@ -0,0 +1,121 @@
+import unittest
+from exo.models import get_supported_models, model_cards
+from exo.inference.inference_engine import inference_engine_classes
+from typing import NamedTuple
+
+class TestCase(NamedTuple):
+  name: str
+  engine_lists: list  # Will contain short names, will be mapped to class names
+  expected_models_contains: list
+  min_count: int | None
+  exact_count: int | None
+  max_count: int | None
+
+# Helper function to map short names to class names
+def expand_engine_lists(engine_lists):
+  def map_engine(engine):
+    return inference_engine_classes.get(engine, engine)  # Return original name if not found
+
+  return [[map_engine(engine) for engine in sublist]
+          for sublist in engine_lists]
+
+test_cases = [
+  TestCase(
+    name="single_mlx_engine",
+    engine_lists=[["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="single_tinygrad_engine",
+    engine_lists=[["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="multiple_engines_or",
+    engine_lists=[["mlx", "tinygrad"], ["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="multiple_engines_all",
+    engine_lists=[["mlx", "tinygrad"], ["mlx", "tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="distinct_engine_lists",
+    engine_lists=[["mlx"], ["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="no_engines",
+    engine_lists=[],
+    expected_models_contains=None,
+    min_count=None,
+    exact_count=len(model_cards),
+    max_count=None
+  ),
+  TestCase(
+    name="nonexistent_engine",
+    engine_lists=[["NonexistentEngine"]],
+    expected_models_contains=[],
+    min_count=None,
+    exact_count=0,
+    max_count=None
+  ),
+  TestCase(
+    name="dummy_engine",
+    engine_lists=[["dummy"]],
+    expected_models_contains=["dummy"],
+    min_count=None,
+    exact_count=1,
+    max_count=None
+  ),
+]
+
+class TestModelHelpers(unittest.TestCase):
+  def test_get_supported_models(self):
+    for case in test_cases:
+      with self.subTest(f"{case.name}_short_names"):
+        result = get_supported_models(case.engine_lists)
+        self._verify_results(case, result)
+
+      with self.subTest(f"{case.name}_class_names"):
+        class_name_lists = expand_engine_lists(case.engine_lists)
+        result = get_supported_models(class_name_lists)
+        self._verify_results(case, result)
+
+  def _verify_results(self, case, result):
+    if case.expected_models_contains:
+      for model in case.expected_models_contains:
+        self.assertIn(model, result)
+
+    if case.min_count:
+      self.assertGreater(len(result), case.min_count)
+
+    if case.exact_count is not None:
+      self.assertEqual(len(result), case.exact_count)
+
+    # Special case for distinct lists test
+    if case.name == "distinct_engine_lists":
+      self.assertLess(len(result), 10)
+      self.assertNotIn("mistral-nemo", result)
+
+    if case.max_count:
+      self.assertLess(len(result), case.max_count)
+
+if __name__ == '__main__':
+  unittest.main()