
Merge branch 'main' into downloadedModelsV2

Caden MacKenzie, 8 months ago
Parent
Current commit
2cdd55d297

+ 1 - 0
.circleci/config.yml

@@ -126,6 +126,7 @@ jobs:
            METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
            echo "Running tokenizer tests..."
            python3 ./test/test_tokenizers.py
+            python3 ./test/test_model_helpers.py

   discovery_integration_test:
     macos:

+ 1 - 1
.gitignore

@@ -4,6 +4,7 @@ test_weights.npz
 .exo_used_ports
 .exo_node_id
 .idea
+.DS_Store

 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -15,7 +16,6 @@ __pycache__/

 # Distribution / packaging
 /.Python
-/build/
 /develop-eggs/
 /dist/
 /downloads/

+ 5 - 5
README.md

@@ -121,14 +121,14 @@ exo

 That's it! No configuration required - exo will automatically discover the other device(s).

-exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:8000
+exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:52415

-For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Examples with curl:
+For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:52415/v1/chat/completions. Examples with curl:

 #### Llama 3.2 3B:

 ```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
      "model": "llama-3.2-3b",
@@ -140,7 +140,7 @@ curl http://localhost:8000/v1/chat/completions \
 #### Llama 3.1 405B:

 ```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
      "model": "llama-3.1-405b",
@@ -152,7 +152,7 @@ curl http://localhost:8000/v1/chat/completions \
 #### Llava 1.5 7B (Vision Language Model):

 ```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
      "model": "llava-1.5-7b-hf",

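The same request can be issued from Python. A minimal sketch, assuming the standard OpenAI-style request/response shape that this ChatGPT-compatible endpoint exposes and the new 52415 default port:

```python
# Sketch: Python equivalent of the curl examples above (payload fields assumed
# to follow the usual OpenAI chat-completions schema).
import json
import urllib.request

payload = {
    "model": "llama-3.2-3b",
    "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
    "temperature": 0.7,
}
req = urllib.request.Request(
    "http://localhost:52415/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0]["message"]["content"])
```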
+ 18 - 3
configure_mlx.sh

@@ -1,3 +1,18 @@
-# this needs to be dynamic
-# sudo sysctl iogpu.wired_lwm_mb=400000
-# sudo sysctl iogpu.wired_limit_mb=180000
+#!/bin/bash
+
+# Get the total memory in MB
+TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
+
+# Set WIRED_LIMIT_MB to 80%
+WIRED_LIMIT_MB=$(($TOTAL_MEM_MB * 80 / 100))
+# Set  WIRED_LWM_MB to 70%
+WIRED_LWM_MB=$(($TOTAL_MEM_MB * 70 / 100))
+
+# Display the calculated values
+echo "Total memory: $TOTAL_MEM_MB MB"
+echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
+echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"
+
+# Apply the values with sysctl
+sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
+sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
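The rewritten script sizes the Metal wired-memory sysctls from the machine's installed RAM (80% and 70%) instead of hard-coded values. A rough Python sketch of the same arithmetic, assuming macOS where `sysctl -n hw.memsize` reports total memory in bytes; applying the values still requires sudo:

```python
# Sketch of the calculation configure_mlx.sh performs (macOS only).
import subprocess

total_mem_mb = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"])) // (1024 * 1024)
wired_limit_mb = total_mem_mb * 80 // 100  # iogpu.wired_limit_mb: 80% of RAM
wired_lwm_mb = total_mem_mb * 70 // 100    # iogpu.wired_lwm_mb: 70% of RAM

print(f"Total memory: {total_mem_mb} MB")
print(f"iogpu.wired_limit_mb -> {wired_limit_mb} MB")
print(f"iogpu.wired_lwm_mb  -> {wired_lwm_mb} MB")
# e.g. subprocess.run(["sudo", "sysctl", "-w", f"iogpu.wired_limit_mb={wired_limit_mb}"], check=True)
```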

Binary
docs/exo-rounded.png


+ 1 - 1
examples/astra/astra/ContentView.swift

@@ -148,7 +148,7 @@ struct ContentView: View {
     @State private var voiceActivityThreshold: Float = 0.40
     @State private var silenceTimeThreshold = 1.0
     @State private var debugText = ""
-    @State private var apiEndpoint = "http://192.168.212.74:8000/v1/chat/completions"
+    @State private var apiEndpoint = "http://192.168.212.74:52415/v1/chat/completions"
     @State private var audioBuffer: [Float] = []
     @State private var bufferDuration: Double = 0.5 // 0.5 seconds buffer
     @State private var isInitialTranscription = true

+ 1 - 1
examples/chatgpt_api.sh

@@ -3,7 +3,7 @@
 # This works the same in a single-node set up and in a multi-node setup.
 # You need to start exo before running this by running `python3 main.py`.

-API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):8000}"
+API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):52415}"
 MODEL="llama-3.1-8b"
 PROMPT="What is the meaning of exo?"
 TEMPERATURE=0.7

+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 25 - 13
exo/api/chatgpt_api.py

@@ -9,18 +9,17 @@ from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
+import os
+import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict
-from exo.inference.inference_engine import inference_engine_classes
-from exo.inference.shard import Shard
+from exo.helpers import PrefixDict, shutdown
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable
 from exo.download.hf.hf_shard_download import HFShardDownloader

-
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
     self.role = role
@@ -30,6 +29,7 @@ class Message:
     return {"role": self.role, "content": self.content}


+
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float):
     self.model = model
@@ -147,9 +147,8 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt

-
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -158,7 +157,7 @@ class ChatGPTAPI:
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
-    self.default_model = "llama-3.2-1b"
+    self.default_model = default_model or "llama-3.2-1b"

     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
@@ -175,13 +174,23 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
+    cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
+    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})

-    self.static_dir = Path(__file__).parent.parent/"tinychat"
-    self.app.router.add_get("/", self.handle_root)
-    self.app.router.add_static("/", self.static_dir, name="static")
+    if "__compiled__" not in globals():
+      self.static_dir = Path(__file__).parent.parent/"tinychat"
+      self.app.router.add_get("/", self.handle_root)
+      self.app.router.add_static("/", self.static_dir, name="static")

     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
+
+  async def handle_quit(self, request):
+    if DEBUG>=1: print("Received quit signal")
+    response = web.json_response({"detail": "Quit signal received"}, status=200)
+    await response.prepare(request)
+    await response.write_eof()
+    await shutdown(signal.SIGINT, asyncio.get_event_loop())

   async def timeout_middleware(self, app, handler):
     async def middleware(request):
@@ -202,6 +211,9 @@ class ChatGPTAPI:
   async def handle_root(self, request):
     return web.FileResponse(self.static_dir/"index.html")

+  async def handle_healthcheck(self, request):
+    return web.json_response({"status": "ok"})
+
   async def handle_model_support(self, request):
     try:
       model_pool = {}
@@ -274,8 +286,8 @@ class ChatGPTAPI:
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
     chat_request = parse_chat_request(data, self.default_model)
-    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
-      chat_request.model = self.default_model if self.default_model.startswith("llama") else "llama-3.2-1b"
+    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
+      chat_request.model = self.default_model
     if not chat_request.model or chat_request.model not in model_cards:
       if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
       chat_request.model = self.default_model
@@ -399,7 +411,7 @@ class ChatGPTAPI:
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")

-  async def run(self, host: str = "0.0.0.0", port: int = 8000):
+  async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()
     site = web.TCPSite(runner, host, port)
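The new /healthcheck and /quit routes give external tooling a way to probe and stop a running node. A polling sketch against /healthcheck, assuming the new 52415 default port:

```python
# Sketch: wait until a local exo node answers /healthcheck with {"status": "ok"}.
import json
import time
import urllib.request

def wait_for_node(base_url="http://localhost:52415", attempts=30, delay=1.0):
    for _ in range(attempts):
        try:
            with urllib.request.urlopen(f"{base_url}/healthcheck", timeout=2) as resp:
                if json.load(resp).get("status") == "ok":
                    return True
        except OSError:
            pass  # node not listening yet
        time.sleep(delay)
    return False

if __name__ == "__main__":
    print("node healthy:", wait_for_node())
```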

+ 31 - 5
exo/download/hf/hf_helpers.py

@@ -1,7 +1,11 @@
+import aiofiles.os as aios
+from typing import Union
 import asyncio
 import aiohttp
 import json
 import os
+import sys
+import shutil
 from urllib.parse import urljoin
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from datetime import datetime, timedelta
@@ -9,7 +13,7 @@ from fnmatch import fnmatch
 from pathlib import Path
 from typing import Generator, Iterable, TypeVar, TypedDict
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG
+from exo.helpers import DEBUG, is_frozen
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
 import aiofiles
@@ -17,7 +21,6 @@ from aiofiles import os as aios

 T = TypeVar("T")

-
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision
@@ -99,9 +102,22 @@ async def get_auth_headers():

 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = repo_id.replace("/", "--")
+  sanitized_repo_id = str(repo_id).replace("/", "--")
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"

+async def move_models_to_hf(seed_dir: Union[str, Path]):
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(seed_dir)
+  dest_dir = get_hf_home()/"hub"
+  await aios.makedirs(dest_dir, exist_ok=True)
+  async for path in source_dir.iterdir():
+    if path.is_dir() and path.startswith("models--"):
+      dest_path = dest_dir / path.name
+      if dest_path.exists():
+        if DEBUG>=1: print(f"skipping moving {dest_path}. File already exists")
+      else:
+        await aios.rename(str(path), str(dest_path))
+

 async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
@@ -417,11 +433,10 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
     elif shard.is_last_layer():
       shard_specific_patterns.add(sorted_file_names[-1])
   else:
-    shard_specific_patterns = set("*.safetensors")
+    shard_specific_patterns = set(["*.safetensors"])
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)

-
 async def get_file_download_percentage(
     session: aiohttp.ClientSession,
     repo_id: str,
@@ -472,3 +487,14 @@ async def get_file_download_percentage(
     if DEBUG >= 2:
       print(f"Error checking file download status for {file_path}: {e}")
     return 0
+
+async def has_hf_home_read_access() -> bool:
+  hf_home = get_hf_home()
+  try: return await aios.access(hf_home, os.R_OK)
+  except OSError: return False
+
+async def has_hf_home_write_access() -> bool:
+  hf_home = get_hf_home()
+  try: return await aios.access(hf_home, os.W_OK)
+  except OSError: return False
+
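These access helpers are what exo/main.py (below) now calls at startup before attempting downloads; a condensed sketch of that usage:

```python
# Sketch: gate model downloads on the new HF-home permission helpers,
# mirroring the startup check added to exo/main.py.
import asyncio
from exo.download.hf.hf_helpers import get_hf_home, has_hf_home_read_access, has_hf_home_write_access

async def check_model_storage() -> bool:
    hf_home = get_hf_home()
    has_read, has_write = await has_hf_home_read_access(), await has_hf_home_write_access()
    if not (has_read and has_write):
        print(f"WARNING: limited permissions for {hf_home}; model downloads may fail")
    return has_read and has_write

if __name__ == "__main__":
    asyncio.run(check_model_storage())
```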

+ 20 - 0
exo/helpers.py

@@ -1,4 +1,5 @@
 import os
+import sys
 import asyncio
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
@@ -234,3 +235,22 @@ def get_all_ip_addresses():
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return ["localhost"]
+
+
+async def shutdown(signal, loop):
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
+
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
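The relocated shutdown() helper is intended to be wired to POSIX signals by its callers. A minimal sketch mirroring the registration exo/main.py performs (the platform guard exists because add_signal_handler is POSIX-only):

```python
# Sketch: register the shared shutdown() helper for SIGINT/SIGTERM, as main.py does.
import asyncio
import platform
import signal
from exo.helpers import shutdown

async def main():
    loop = asyncio.get_running_loop()

    def handle_exit():
        asyncio.ensure_future(shutdown(signal.SIGTERM, loop))

    if platform.system() != "Windows":
        for s in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(s, handle_exit)

    # ... start the node/API here, then wait until shutdown() stops the loop ...
    await asyncio.Event().wait()

if __name__ == "__main__":
    asyncio.run(main())
```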

+ 2 - 1
exo/inference/inference_engine.py

@@ -26,7 +26,8 @@ class InferenceEngine(ABC):
   
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
     tokens = await self.encode(shard, prompt)
-    output_data = await self.infer_tensor(request_id, shard, tokens)
+    x = tokens.reshape(1, -1)
+    output_data = await self.infer_tensor(request_id, shard, x)
     return output_data 

 inference_engine_classes = {
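infer_prompt now adds a batch dimension before handing tokens to infer_tensor; a tiny numpy illustration of the reshape:

```python
# Illustration of the (1, seq_len) batch reshape applied in infer_prompt.
import numpy as np

tokens = np.array([101, 2023, 2003, 102])  # 1-D token ids from encode()
x = tokens.reshape(1, -1)                  # batched input: shape (1, 4)
print(tokens.shape, "->", x.shape)
```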

+ 3 - 2
exo/inference/mlx/sharded_utils.py

@@ -21,6 +21,7 @@ from transformers import AutoProcessor
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper

 from exo import DEBUG
+from exo.inference.tokenizers import resolve_tokenizer
 from ..shard import Shard


@@ -136,7 +137,7 @@ def load_model_shard(
       self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)

     def __call__(self, x, *args, **kwargs):
-      y = super().__call__(x[None] if self.shard.is_first_layer() else x, *args, **kwargs)
+      y = super().__call__(x, *args, **kwargs)
       return y

   model_args = model_args_class.from_dict(config)
@@ -183,7 +184,7 @@ async def load_shard(
     processor.encode = processor.tokenizer.encode
     return model, processor
   else:
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    tokenizer = await resolve_tokenizer(model_path)
     return model, tokenizer



+ 5 - 9
exo/inference/tinygrad/inference.py

@@ -7,7 +7,6 @@ from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
 from tinygrad import Tensor, nn, Context
 from exo.inference.inference_engine import InferenceEngine
-from typing import Optional, Tuple
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
@@ -68,24 +67,21 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
     logits = x[:, -1, :]
     def sample_wrapper():
-      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize()
-    out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
-    return out.numpy().astype(int)
+      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)

   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    return np.array(tokens)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)
   
   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
-    return tokens
+    return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)

   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize())
-    return output_data.numpy()
+    return await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize().numpy())

   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
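Each method now runs its blocking tinygrad/tokenizer work inside the engine's thread-pool executor and returns plain numpy values, keeping the event loop responsive. The general pattern, sketched with a hypothetical blocking_decode stand-in:

```python
# Sketch of the run_in_executor pattern used above: push blocking work onto a
# ThreadPoolExecutor so the asyncio event loop stays free.
import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)

def blocking_decode(token_ids):  # hypothetical stand-in for tokenizer.decode
    return " ".join(str(t) for t in token_ids)

async def decode_async(token_ids):
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, blocking_decode, token_ids)

print(asyncio.run(decode_async([1, 2, 3])))
```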

+ 32 - 22
exo/main.py

@@ -3,6 +3,9 @@ import asyncio
 import signal
 import json
 import logging
+import platform
+import os
+import sys
 import time
 import traceback
 import uuid
@@ -17,22 +20,24 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
-from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
+from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf

 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
 parser.add_argument("model_name", nargs="?", help="Model name to run")
+parser.add_argument("--default-model", type=str, default=None, help="Default model")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
+parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
@@ -42,7 +47,7 @@ parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale",
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
 parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
-parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
+parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
@@ -121,20 +126,20 @@ api = ChatGPTAPI(
   node,
   inference_engine.__class__.__name__,
   response_timeout=args.chatgpt_api_response_timeout,
-  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
+  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
+  default_model=args.default_model
 )
 node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )

-
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
     status = json.loads(opaque_status)
     if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
       current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
       if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-      asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+      asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
   except Exception as e:
     if DEBUG >= 2:
       print(f"Failed to preemptively start download: {e}")
@@ -160,20 +165,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):

 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)

-
-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
-
-
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
@@ -206,12 +197,31 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
 async def main():
   loop = asyncio.get_running_loop()

+  # Check HuggingFace directory permissions
+  hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
+  if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
+  print(f"{has_read=}, {has_write=}")
+  if not has_read or not has_write:
+    print(f"""
+          WARNING: Limited permissions for model storage directory: {hf_home}.
+          This may prevent model downloads from working correctly.
+          {"❌ No read access" if not has_read else ""}
+          {"❌ No write access" if not has_write else ""}
+          """)
+
+  if not args.models_seed_dir is None:
+    try:
+      await move_models_to_hf(args.models_seed_dir)
+    except Exception as e:
+      print(f"Error moving models to .cache/huggingface: {e}")
+
   # Use a more direct approach to handle signals
   def handle_exit():
     asyncio.ensure_future(shutdown(signal.SIGTERM, loop))

-  for s in [signal.SIGINT, signal.SIGTERM]:
-    loop.add_signal_handler(s, handle_exit)
+  if platform.system() != "Windows":
+    for s in [signal.SIGINT, signal.SIGTERM]:
+      loop.add_signal_handler(s, handle_exit)

   await node.start(wait_for_peers=args.wait_for_peers)


+ 25 - 3
exo/models.py

@@ -1,11 +1,11 @@
 from exo.inference.shard import Shard
-from typing import Optional
+from typing import Optional, List

 model_cards = {
   ### llama
   "llama-3.2-1b": {
     "layers": 16,
-    "repo": { 
+    "repo": {
       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
     },
@@ -63,6 +63,7 @@ model_cards = {
   ### llava
   "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
   ### qwen
+  "qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
   "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
   "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
   "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
@@ -123,4 +124,25 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
   if repo is None or n_layers < 1:
     return None
   return Shard(model_id, 0, 0, n_layers)
-  
+
+def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
+  if not supported_inference_engine_lists:
+    return list(model_cards.keys())
+
+  from exo.inference.inference_engine import inference_engine_classes
+  supported_inference_engine_lists = [
+    [inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
+    for engine_list in supported_inference_engine_lists
+  ]
+
+  def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
+    return any(engine in model_info.get("repo", {}) for engine in engine_list)
+
+  def supports_all_engine_lists(model_info: dict) -> bool:
+    return all(has_any_engine(model_info, engine_list)
+              for engine_list in supported_inference_engine_lists)
+
+  return [
+    model_id for model_id, model_info in model_cards.items()
+    if supports_all_engine_lists(model_info)
+  ]
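get_supported_models() keeps a model only when every node's engine list contains at least one engine that has a repo for it (short engine names are mapped through inference_engine_classes). A usage sketch:

```python
# Sketch: which models can run on every node of a two-node topology?
from exo.models import get_supported_models

topology_engines = [
    ["MLXDynamicShardInferenceEngine", "TinygradDynamicShardInferenceEngine"],  # node 1
    ["TinygradDynamicShardInferenceEngine"],                                    # node 2
]
print(get_supported_models(topology_engines))  # models with a repo usable on each node

print(get_supported_models([]))  # no constraints: every model in model_cards
```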

+ 55 - 56
exo/orchestration/standard_node.py

@@ -102,11 +102,7 @@ class StandardNode(Node):
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool
   
-  async def encode_prompt(self, shard: Shard, prompt):
-    toks = await self.inference_engine.encode(shard, prompt)
-    return toks
-  
-  async def process_result(
+  async def process_inference_result(
     self,
     shard,
     result: np.ndarray,
@@ -114,32 +110,24 @@ class StandardNode(Node):
   ):
     if request_id not in self.buffered_token_output:
       self.buffered_token_output[request_id] = ([], False)
-    
-    if request_id not in self.buffered_logits:
-      self.buffered_logits[request_id] = []
-
-    self.buffered_logits[request_id] += [i for i in np.reshape(result, (-1, 1, result.shape[-1]))]
-
-    if shard.is_last_layer():
-      result = await self.inference_engine.sample(result)
-    
-    await self.inference_engine.ensure_shard(shard)
-    is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-
-    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-    if result.size == 1:  # we got a new token out
-      self.buffered_token_output[request_id][0].append(result.item())
+    is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+    if shard.is_last_layer() and not is_finished:
+      token = await self.inference_engine.sample(result)
+      await self.inference_engine.ensure_shard(shard)
+      self.buffered_token_output[request_id][0].append(token.item())
       self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-    
-    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
+      forward = token.reshape(1, -1)
+    else:
+      forward = result

     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
     else:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
+      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))

-    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+    return np.array(self.buffered_token_output[request_id][0])

   async def process_prompt(
     self,
@@ -190,13 +178,13 @@ class StandardNode(Node):
     shard = self.get_current_shard(base_shard)

     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
-    if shard.start_layer != 0:
+    if not shard.is_first_layer():
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      await self.forward_to_next_shard(shard, prompt, request_id)
+      resp = await self.forward_prompt(shard, prompt, request_id, 0)
       return None
     else:
       result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
-      ret = await self.process_result(shard, result, request_id) 
+      ret = await self.process_inference_result(shard, result, request_id) 
       return result

   async def process_tensor(
@@ -255,46 +243,57 @@ class StandardNode(Node):
     if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
       result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
-      ret = await self.process_result(shard, result, request_id) 
+      ret = await self.process_inference_result(shard, result, request_id) 
       return ret
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
       return None

-  async def forward_to_next_shard(
+  async def forward_prompt(
     self,
     base_shard: Shard,
-    tensor_or_prompt: Union[np.ndarray, str],
+    prompt: str,
     request_id: str,
+    target_index: int,
   ) -> None:
-    if not self.partitioning_strategy:
-      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
-      return
-
-    next_partition_index = self.get_partition_index(offset = 1)
-    if DEBUG >= 1: print(f"Next partition index: {next_partition_index}")
-    if next_partition_index is not None:
-      target_id = self.partitioning_strategy.partition(self.topology)[next_partition_index].node_id
-      next_shard = self.get_current_shard(base_shard, next_partition_index)
-      if DEBUG >= 2: print(f"Computed next from: {base_shard} {next_partition_index}, {self.topology}. Next shard: {next_shard}")
-      is_tensor = isinstance(tensor_or_prompt, np.ndarray)
-      if target_id == self.id:
-        if is_tensor:
-          await self.process_tensor(next_shard, tensor_or_prompt, request_id)
-        else:
-          await self.process_prompt(next_shard, tensor_or_prompt, request_id)
-      else:
-        target_peer = next((p for p in self.peers if p.id() == target_id), None)
-        if not target_peer:
-          raise ValueError(f"Peer for {next_partition_index} not found")
-        if is_tensor:
-          if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor_or_prompt}")
-          await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id)
-        else:
-          await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id)
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {target_shard}")
+    if target_id == self.id:
+      await self.process_prompt(next_shard, prompt, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
+      await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+  
+  async def forward_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: str,
+    target_index: int,
+  ) -> None:
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {target_shard}")
+    if target_id == self.id:
+      await self.process_tensor(next_shard, tensor, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
+      await target_peer.send_tensor(next_shard, tensor, request_id=request_id)

   def get_partition_index(self, offset: int = 0):
+    if not self.partitioning_strategy:
+      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
+      return None
     partitions = self.partitioning_strategy.partition(self.topology)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     if current_partition_index is None:

+ 2 - 0
exo/tinychat/index.css

@@ -167,6 +167,8 @@ main {
 .download-progress {
   margin-bottom: 12em;
   overflow-y: auto;
+  min-height: 350px;
+  padding: 2rem;
 }
 .message > pre {
   white-space: pre-wrap;

+ 1 - 1
extra/start_openwebui.sh

@@ -1,3 +1,3 @@
-API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):8000}"
+API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):52415}"
 echo "Using API_ENDPOINT=${API_ENDPOINT}"
 docker run -d -p 3000:8080 -e OPENAI_API_BASE_URL="${API_ENDPOINT}" -e OPENAI_API_KEY=your_secret_key -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main

+ 60 - 0
scripts/build_exo.py

@@ -0,0 +1,60 @@
+import site
+import subprocess
+import sys
+import os 
+import pkgutil
+
+def run():
+    site_packages = site.getsitepackages()[0]
+    command = [
+        f"{sys.executable}", "-m", "nuitka", "exo/main.py",
+        "--company-name=exolabs",
+        "--product-name=exo",
+        "--output-dir=dist",
+        "--follow-imports",
+        "--standalone",
+        "--output-filename=exo",
+        "--onefile",
+        "--python-flag=no_site"
+    ]
+
+    if sys.platform == "darwin": 
+        command.extend([
+            "--macos-app-name=exo",
+            "--macos-app-mode=gui",
+            "--macos-app-version=0.0.1",
+            "--macos-signed-app-name=com.exolabs.exo",
+            "--macos-sign-identity=auto",
+            "--macos-sign-notarization",
+            "--include-distribution-meta=mlx",
+            "--include-module=mlx._reprlib_fix",
+            "--include-module=mlx._os_warning",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=mlx/lib/mlx.metallib",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=./mlx.metallib",
+            "--include-distribution-meta=pygments",
+            "--nofollow-import-to=tinygrad"
+        ])
+        inference_modules = [
+            name for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models'])
+        ]
+        for module in inference_modules:
+            command.append(f"--include-module=exo.inference.mlx.models.{module}")
+    elif sys.platform == "win32":  
+        command.extend([
+            "--windows-icon-from-ico=docs/exo-logo-win.ico",
+            "--file-version=0.0.1",
+            "--product-version=0.0.1"
+        ])
+    elif sys.platform.startswith("linux"):  
+        command.extend([
+            "--include-distribution-metadata=pygments",
+            "--linux-icon=docs/exo-rounded.png"
+        ])
+    try:
+        subprocess.run(command, check=True)
+        print("Build completed!")
+    except subprocess.CalledProcessError as e:
+        print(f"An error occurred: {e}")
+
+if __name__ == "__main__":
+    run()

+ 5 - 5
setup.py

@@ -5,14 +5,15 @@ from setuptools import find_packages, setup

 # Base requirements for all platforms
 install_requires = [
-  "aiohttp==3.10.2",
+  "aiohttp==3.10.11",
   "aiohttp_cors==0.7.0",
   "aiofiles==24.1.0",
-  "grpcio==1.64.1",
-  "grpcio-tools==1.64.1",
+  "grpcio==1.68.0",
+  "grpcio-tools==1.68.0",
   "Jinja2==3.1.4",
   "netifaces==0.11.0",
   "numpy==2.0.0",
+  "nuitka==2.4.10",
   "nvidia-ml-py==12.560.30",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",
@@ -21,10 +22,9 @@ install_requires = [
   "pydantic==2.9.2",
   "requests==2.32.3",
   "rich==13.7.1",
-  "safetensors==0.4.3",
   "tenacity==9.0.0",
   "tqdm==4.66.4",
-  "transformers==4.43.3",
+  "transformers==4.46.3",
   "uuid==1.30",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
 ]

+ 1 - 1
test/reconnect.sh

@@ -1,7 +1,7 @@
 #!/bin/bash

 echo "Starting node 1"
-DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
+DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
 PID1=$!
 echo "Started node 1 PID: $PID1"
 echo "Starting node 2"

+ 121 - 0
test/test_model_helpers.py

@@ -0,0 +1,121 @@
+import unittest
+from exo.models import get_supported_models, model_cards
+from exo.inference.inference_engine import inference_engine_classes
+from typing import NamedTuple
+
+class TestCase(NamedTuple):
+  name: str
+  engine_lists: list  # Will contain short names, will be mapped to class names
+  expected_models_contains: list
+  min_count: int | None
+  exact_count: int | None
+  max_count: int | None
+
+# Helper function to map short names to class names
+def expand_engine_lists(engine_lists):
+  def map_engine(engine):
+    return inference_engine_classes.get(engine, engine)  # Return original name if not found
+
+  return [[map_engine(engine) for engine in sublist]
+          for sublist in engine_lists]
+
+test_cases = [
+  TestCase(
+    name="single_mlx_engine",
+    engine_lists=[["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="single_tinygrad_engine",
+    engine_lists=[["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="multiple_engines_or",
+    engine_lists=[["mlx", "tinygrad"], ["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="multiple_engines_all",
+    engine_lists=[["mlx", "tinygrad"], ["mlx", "tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="distinct_engine_lists",
+    engine_lists=[["mlx"], ["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="no_engines",
+    engine_lists=[],
+    expected_models_contains=None,
+    min_count=None,
+    exact_count=len(model_cards),
+    max_count=None
+  ),
+  TestCase(
+    name="nonexistent_engine",
+    engine_lists=[["NonexistentEngine"]],
+    expected_models_contains=[],
+    min_count=None,
+    exact_count=0,
+    max_count=None
+  ),
+  TestCase(
+    name="dummy_engine",
+    engine_lists=[["dummy"]],
+    expected_models_contains=["dummy"],
+    min_count=None,
+    exact_count=1,
+    max_count=None
+  ),
+]
+
+class TestModelHelpers(unittest.TestCase):
+  def test_get_supported_models(self):
+    for case in test_cases:
+      with self.subTest(f"{case.name}_short_names"):
+        result = get_supported_models(case.engine_lists)
+        self._verify_results(case, result)
+
+      with self.subTest(f"{case.name}_class_names"):
+        class_name_lists = expand_engine_lists(case.engine_lists)
+        result = get_supported_models(class_name_lists)
+        self._verify_results(case, result)
+
+  def _verify_results(self, case, result):
+    if case.expected_models_contains:
+      for model in case.expected_models_contains:
+        self.assertIn(model, result)
+
+    if case.min_count:
+      self.assertGreater(len(result), case.min_count)
+
+    if case.exact_count is not None:
+      self.assertEqual(len(result), case.exact_count)
+
+    # Special case for distinct lists test
+    if case.name == "distinct_engine_lists":
+      self.assertLess(len(result), 10)
+      self.assertNotIn("mistral-nemo", result)
+
+    if case.max_count:
+      self.assertLess(len(result), case.max_count)
+
+if __name__ == '__main__':
+  unittest.main()