|
@@ -20,7 +20,7 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
|
|
|
from exo.api import ChatGPTAPI
|
|
|
from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
|
|
|
from exo.download.hf.hf_shard_download import HFShardDownloader
|
|
|
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
|
|
|
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
|
|
|
from exo.inference.shard import Shard
|
|
|
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
|
|
|
from exo.inference.dummy_inference_engine import DummyInferenceEngine
|
|
@@ -163,20 +163,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
|
|
|
|
|
|
shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
|
|
|
|
|
|
-
|
|
|
-async def shutdown(signal, loop):
|
|
|
- """Gracefully shutdown the server and close the asyncio loop."""
|
|
|
- print(f"Received exit signal {signal.name}...")
|
|
|
- print("Thank you for using exo.")
|
|
|
- print_yellow_exo()
|
|
|
- server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
|
|
- [task.cancel() for task in server_tasks]
|
|
|
- print(f"Cancelling {len(server_tasks)} outstanding tasks")
|
|
|
- await asyncio.gather(*server_tasks, return_exceptions=True)
|
|
|
- await server.stop()
|
|
|
- loop.stop()
|
|
|
-
|
|
|
-
|
|
|
async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
|
|
|
inference_class = inference_engine.__class__.__name__
|
|
|
shard = build_base_shard(model_name, inference_class)
|