@@ -3,6 +3,9 @@ import asyncio
 import signal
 import json
 import logging
+import platform
+import os
+import sys
 import time
 import traceback
 import uuid
@@ -17,14 +20,14 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
-from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home
+from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf

 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -34,6 +37,7 @@ parser.add_argument("--default-model", type=str, default=None, help="Default model")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
+parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
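
The new `--models-seed-dir` flag lets a node populate its Hugging Face cache from a directory of pre-downloaded models instead of fetching them over the network; `main()` wires it to `move_models_to_hf` in a later hunk. The helper's implementation is not part of this diff, so the following is only a sketch of the kind of thing it plausibly does, assuming the seed directory holds `models--*` snapshot folders in the standard HF hub layout (the function name is real, everything else here is an assumption):

```python
# Hypothetical sketch only -- not exo's actual move_models_to_hf
# (imported from exo.download.hf.hf_helpers above). Assumes the seed
# dir contains HF-hub-style "models--org--name" folders.
import shutil
from pathlib import Path

async def move_models_to_hf_sketch(seed_dir: str) -> None:
  # exo would resolve this via get_hf_home() rather than hardcoding it
  hub = Path.home() / ".cache" / "huggingface" / "hub"
  hub.mkdir(parents=True, exist_ok=True)
  for entry in Path(seed_dir).iterdir():
    if entry.is_dir() and entry.name.startswith("models--"):
      target = hub / entry.name
      if not target.exists():  # never clobber an existing snapshot
        shutil.move(str(entry), str(target))
```

An air-gapped node could then start with something like `exo --models-seed-dir /mnt/usb/models` (invocation shown for illustration).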
@@ -129,7 +133,6 @@ node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )

-
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
     status = json.loads(opaque_status)
@@ -162,20 +165,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):

 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)

-
-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
-
-
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
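
The local `shutdown` coroutine is deleted here because it now lives in `exo.helpers` (see the updated `from exo.helpers import ...` line above). One wrinkle: the deleted version called `await server.stop()` on `main.py`'s module-level `server`, which a helper in another module cannot close over. A minimal sketch of the relocated helper, assuming it keeps the old body but takes the server handle as an optional parameter; the exact signature in `exo/helpers.py` is an assumption, and the call site in the next hunk still passes only the signal and the loop:

```python
# Sketch of the relocated helper, assuming exo/helpers.py (where
# print_yellow_exo is defined) keeps the old body. The optional `server`
# parameter is an assumption, replacing the old closure over main.py's
# module-level `server`.
import asyncio

async def shutdown(signal, loop, server=None):
  """Gracefully shutdown the server and close the asyncio loop."""
  print(f"Received exit signal {signal.name}...")
  print("Thank you for using exo.")
  print_yellow_exo()
  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
  for task in server_tasks:  # plain loop instead of a throwaway list comprehension
    task.cancel()
  print(f"Cancelling {len(server_tasks)} outstanding tasks")
  await asyncio.gather(*server_tasks, return_exceptions=True)
  if server is not None:
    await server.stop()
  loop.stop()
```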
@@ -219,13 +208,20 @@ async def main():
     {"❌ No read access" if not has_read else ""}
     {"❌ No write access" if not has_write else ""}
     """)
+
+  if args.models_seed_dir is not None:
+    try:
+      await move_models_to_hf(args.models_seed_dir)
+    except Exception as e:
+      print(f"Error moving models to .cache/huggingface: {e}")

   # Use a more direct approach to handle signals
   def handle_exit():
     asyncio.ensure_future(shutdown(signal.SIGTERM, loop))

-  for s in [signal.SIGINT, signal.SIGTERM]:
-    loop.add_signal_handler(s, handle_exit)
+  if platform.system() != "Windows":
+    for s in [signal.SIGINT, signal.SIGTERM]:
+      loop.add_signal_handler(s, handle_exit)

   await node.start(wait_for_peers=args.wait_for_peers)
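
The platform guard around the signal handlers is what makes this file usable on Windows: asyncio's Windows event loops do not implement `loop.add_signal_handler`, so registering `SIGINT`/`SIGTERM` there raises `NotImplementedError` at startup. With the guard in place, Windows simply keeps the default Ctrl+C behavior (a `KeyboardInterrupt`). If explicit handling were wanted on Windows too, a fallback along these lines could work; it is shown purely as an illustration and is not part of this diff:

```python
# Illustrative fallback, not part of this PR: route Ctrl+C on Windows
# through the stdlib signal module, which supports SIGINT there. The
# handler runs outside the event loop, so it hands off via
# call_soon_threadsafe.
import asyncio
import platform
import signal

def install_exit_handlers(loop: asyncio.AbstractEventLoop, handle_exit) -> None:
  if platform.system() != "Windows":
    for s in (signal.SIGINT, signal.SIGTERM):
      loop.add_signal_handler(s, handle_exit)
  else:
    # Windows never delivers SIGTERM; Ctrl+C arrives as SIGINT.
    signal.signal(signal.SIGINT, lambda signum, frame: loop.call_soon_threadsafe(handle_exit))
```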