Browse source code

Add manual discovery flags (--discovery-module manual, --discovery-config-path) to exo/main.py

Ian Paul, 6 months ago
parent
commit
ad389363bc

+ 8 - 2
exo/main.py

@@ -5,7 +5,8 @@ import json
 import time
 import traceback
 import uuid
-import sys
+from exo.networking.manual.manual_discovery import ManualDiscovery
+from exo.networking.manual.network_topology_config import NetworkTopology
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.udp.udp_discovery import UDPDiscovery
@@ -35,8 +36,9 @@ parser.add_argument("--download-quick-check", action="store_true", help="Quick c
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
 parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
-parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale"], default="udp", help="Discovery module to use")
+parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
+parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
@@ -78,6 +80,10 @@ if args.discovery_module == "udp":
   discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
 elif args.discovery_module == "tailscale":
   discovery = TailscaleDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, tailscale_api_key=args.tailscale_api_key, tailnet=args.tailnet_name)
+elif args.discovery_module == "manual":
+  if not args.discovery_config_path:
+    raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
+  discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
   args.node_id,

+ 1 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -56,7 +56,7 @@ class GRPCPeerHandle(PeerHandle):
       return response.is_healthy
     except asyncio.TimeoutError:
       return False
-    except:
+    except Exception:
       if DEBUG >= 4:
         print(f"Health check failed for {self._id}@{self.address}.")
         import traceback

+ 1 - 0
exo/networking/udp/test_udp_discovery.py

@@ -6,6 +6,7 @@ from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.orchestration.node import Node
 
+
 class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
     self.peer1 = mock.AsyncMock()

+ 1 - 1
exo/networking/udp/udp_discovery.py

@@ -205,4 +205,4 @@ class UDPDiscovery(Discovery):
       (current_time - last_seen > self.discovery_timeout) or
       (not health_ok)
     )
-    return should_remove
+    return should_remove