Explorar o código

add exo run command. usage: exo run <model-name> e.g. exo run llama-3.1-8b

Alex Cheema hai 7 meses
pai
achega
8ae59b7019
Modificouse 1 ficheiro con 9 adicións e 2 borrados
  1. exo/main.py: +9 −2
      exo/main.py

+ 9 - 2
exo/main.py

@@ -5,6 +5,7 @@ import json
 import time
 import traceback
 import uuid
+import sys
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.udp.udp_discovery import UDPDiscovery
@@ -24,6 +25,8 @@ from exo.viz.topology_viz import TopologyViz
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
+parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
+parser.add_argument("model_name", nargs="?", help="Model name to run")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
@@ -179,8 +182,12 @@ async def main():
 
   await node.start(wait_for_peers=args.wait_for_peers)
 
-  if args.run_model:
-    await run_model_cli(node, inference_engine, args.run_model, args.prompt)
+  if args.command == "run" or args.run_model:
+    model_name = args.model_name or args.run_model
+    if not model_name:
+      print("Error: Model name is required when using 'run' command or --run-model")
+      return
+    await run_model_cli(node, inference_engine, model_name, args.prompt)
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
     await asyncio.Event().wait()