|
@@ -5,6 +5,7 @@ import json
|
|
|
import time
|
|
|
import traceback
|
|
|
import uuid
|
|
|
+import sys
|
|
|
from exo.orchestration.standard_node import StandardNode
|
|
|
from exo.networking.grpc.grpc_server import GRPCServer
|
|
|
from exo.networking.udp.udp_discovery import UDPDiscovery
|
|
@@ -24,6 +25,8 @@ from exo.viz.topology_viz import TopologyViz
|
|
|
|
|
|
# parse args
|
|
|
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
|
|
|
+parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
|
|
|
+parser.add_argument("model_name", nargs="?", help="Model name to run")
|
|
|
parser.add_argument("--node-id", type=str, default=None, help="Node ID")
|
|
|
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
|
|
|
parser.add_argument("--node-port", type=int, default=None, help="Node port")
|
|
@@ -179,8 +182,12 @@ async def main():
|
|
|
|
|
|
await node.start(wait_for_peers=args.wait_for_peers)
|
|
|
|
|
|
- if args.run_model:
|
|
|
- await run_model_cli(node, inference_engine, args.run_model, args.prompt)
|
|
|
+ if args.command == "run" or args.run_model:
|
|
|
+ model_name = args.model_name or args.run_model
|
|
|
+ if not model_name:
|
|
|
+ print("Error: Model name is required when using 'run' command or --run-model")
|
|
|
+ return
|
|
|
+ await run_model_cli(node, inference_engine, model_name, args.prompt)
|
|
|
else:
|
|
|
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
|
|
|
await asyncio.Event().wait()
|