Pārlūkot izejas kodu

clean up unused, formatting

Alex Cheema 11 mēneši atpakaļ
vecāks
revīzija
581856897a
2 mainītis faili ar 3 papildinājumiem un 13 dzēšanām
  1. 0 3
      exo/orchestration/standard_node.py
  2. 3 10
      main.py

+ 0 - 3
exo/orchestration/standard_node.py

@@ -26,9 +26,6 @@ class StandardNode(Node):
     discovery: Discovery,
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
-    chatgpt_api_endpoints: List[str] = [],
-    web_chat_urls: List[str] = [],
-    disable_tui: Optional[bool] = False,
     topology_viz: Optional[TopologyViz] = None,
   ):
     self.id = _id

+ 3 - 10
main.py

@@ -5,7 +5,6 @@ import json
 import time
 import traceback
 import uuid
-from asyncio import CancelledError
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
@@ -57,7 +56,6 @@ if args.node_port is None:
   if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
 args.node_id = args.node_id or get_or_create_node_id()
-discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
 chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
 web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
 if DEBUG >= 0:
@@ -67,16 +65,15 @@ if DEBUG >= 0:
   print("ChatGPT API endpoint served at:")
   for chatgpt_api_endpoint in chatgpt_api_endpoints:
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
+
+discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
   args.node_id,
   None,
   inference_engine,
   discovery,
-  chatgpt_api_endpoints=chatgpt_api_endpoints,
-  web_chat_urls=web_chat_urls,
   partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
-  disable_tui=args.disable_tui,
   max_generate_tokens=args.max_generate_tokens,
   topology_viz=topology_viz
 )
@@ -91,8 +88,6 @@ api = ChatGPTAPI(
 node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )
-
-
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
     status = json.loads(opaque_status)
@@ -104,16 +99,14 @@ def preemptively_start_download(request_id: str, opaque_status: str):
     if DEBUG >= 2:
       print(f"Failed to preemptively start download: {e}")
       traceback.print_exc()
-
-
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
+
 if args.prometheus_client_port:
   from exo.stats.metrics import start_metrics_server
   start_metrics_server(node, args.prometheus_client_port)
 
 last_broadcast_time = 0
 
-
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
   global last_broadcast_time
   current_time = time.time()