@@ -3,6 +3,9 @@ import asyncio
import signal
import json
import logging
+import platform
+import os
+import sys
import time
import traceback
import uuid
@@ -17,22 +20,24 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
from exo.api import ChatGPTAPI
from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
from exo.inference.shard import Shard
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
-from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration.node import Node
from exo.models import build_base_shard, get_repo
from exo.viz.topology_viz import TopologyViz
+from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf

# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
parser.add_argument("model_name", nargs="?", help="Model name to run")
+parser.add_argument("--default-model", type=str, default=None, help="Default model")
parser.add_argument("--node-id", type=str, default=None, help="Node ID")
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
parser.add_argument("--node-port", type=int, default=None, help="Node port")
+parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
@@ -42,7 +47,7 @@ parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale",
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
-parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
+parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port")
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
@@ -121,20 +126,20 @@ api = ChatGPTAPI(
  node,
  inference_engine.__class__.__name__,
  response_timeout=args.chatgpt_api_response_timeout,
-  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
+  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
+  default_model=args.default_model
)
node.on_token.register("update_topology_viz").on_next(
  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
)

-
def preemptively_start_download(request_id: str, opaque_status: str):
  try:
    status = json.loads(opaque_status)
    if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
      current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
      if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-      asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+      asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
  except Exception as e:
    if DEBUG >= 2:
      print(f"Failed to preemptively start download: {e}")
@@ -160,20 +165,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):

shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)

-
-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
-
-
async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
  inference_class = inference_engine.__class__.__name__
  shard = build_base_shard(model_name, inference_class)
@@ -206,12 +197,31 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
async def main():
  loop = asyncio.get_running_loop()

+  # Check HuggingFace directory permissions
+  hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
+  if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
+  print(f"{has_read=}, {has_write=}")
+  if not has_read or not has_write:
+    print(f"""
+          WARNING: Limited permissions for model storage directory: {hf_home}.
+          This may prevent model downloads from working correctly.
+          {"❌ No read access" if not has_read else ""}
+          {"❌ No write access" if not has_write else ""}
+          """)
+
+  if args.models_seed_dir is not None:
+    try:
+      await move_models_to_hf(args.models_seed_dir)
+    except Exception as e:
+      print(f"Error moving models to .cache/huggingface: {e}")
+
  # Use a more direct approach to handle signals
  def handle_exit():
    asyncio.ensure_future(shutdown(signal.SIGTERM, loop))

-  for s in [signal.SIGINT, signal.SIGTERM]:
-    loop.add_signal_handler(s, handle_exit)
+  if platform.system() != "Windows":
+    for s in [signal.SIGINT, signal.SIGTERM]:
+      loop.add_signal_handler(s, handle_exit)

  await node.start(wait_for_peers=args.wait_for_peers)