Quellcode durchsuchen

Merge pull request #368 from ianpaul10/feat/manual-disc-0

Manual networking with configuration files
Alex Cheema vor 6 Monaten
Ursprung
Commit
03621b9962

+ 1 - 0
.gitignore

@@ -170,3 +170,4 @@ cython_debug/
 #.idea/
 
 **/*.xcodeproj/*
+.aider*

+ 8 - 2
exo/main.py

@@ -6,7 +6,8 @@ import logging
 import time
 import traceback
 import uuid
-import sys
+from exo.networking.manual.manual_discovery import ManualDiscovery
+from exo.networking.manual.network_topology_config import NetworkTopology
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.udp.udp_discovery import UDPDiscovery
@@ -36,8 +37,9 @@ parser.add_argument("--download-quick-check", action="store_true", help="Quick c
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
 parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
-parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale"], default="udp", help="Discovery module to use")
+parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
+parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
@@ -80,6 +82,10 @@ if args.discovery_module == "udp":
   discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
 elif args.discovery_module == "tailscale":
   discovery = TailscaleDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, tailscale_api_key=args.tailscale_api_key, tailnet=args.tailnet_name)
+elif args.discovery_module == "manual":
+  if not args.discovery_config_path:
+    raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
+  discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
   args.node_id,

+ 1 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -56,7 +56,7 @@ class GRPCPeerHandle(PeerHandle):
       return response.is_healthy
     except asyncio.TimeoutError:
       return False
-    except:
+    except Exception:
       if DEBUG >= 4:
         print(f"Health check failed for {self._id}@{self.address}.")
         import traceback

+ 0 - 0
exo/networking/manual/__init__.py


+ 81 - 0
exo/networking/manual/manual_discovery.py

@@ -0,0 +1,81 @@
+import asyncio
+from exo.networking.discovery import Discovery
+from typing import Dict, List, Callable
+
+from exo.topology.device_capabilities import DeviceCapabilities
+from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
+from exo.helpers import DEBUG_DISCOVERY
+from exo.networking.peer_handle import PeerHandle
+
+
+class ManualDiscovery(Discovery):
+  """Discovery backend driven by a static network topology config file.
+
+  Every peer listed in the config (other than this node) is health-checked by a
+  background task; only peers whose latest health check succeeded are kept in
+  `known_peers` and returned by `discover_peers`.
+  """
+
+  def __init__(
+    self,
+    network_config_path: str,
+    node_id: str,
+    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    discovery_timeout: int = 30,
+  ):
+    """Load the topology and split it into this node and the remaining peers.
+
+    Args:
+      network_config_path: Path of the JSON config passed to `NetworkTopology.from_path`.
+      node_id: Key of this node in the config's `peers` mapping.
+      create_peer_handle: Factory called with (peer_id, "address:port", device_capabilities).
+      discovery_timeout: Seconds to sleep after each peer check.
+
+    Raises:
+      KeyError: If `node_id` is not a key of the config's `peers` mapping.
+      Errors raised by `NetworkTopology.from_path` propagate unchanged.
+    """
+    self.topology = NetworkTopology.from_path(network_config_path)
+    self.node_id = node_id
+    self.create_peer_handle = create_peer_handle
+    self.discovery_timeout = discovery_timeout
+
+    try:
+      self.node = self.topology.peers[node_id]
+    except KeyError:
+      # List the dict's keys directly: unpacking each key string as a (k, _) pair
+      # would raise ValueError here and hide both this message and the KeyError.
+      print(f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {list(self.topology.peers)}")
+      raise
+
+    self.node_port = self.node.port
+
+    self.listen_task = None
+    self.cleanup_task = None
+
+    self.known_peers: Dict[str, PeerHandle] = {}
+    # Same dict object as self.topology.peers; popping this node leaves only the other peers in it.
+    self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
+    self.node_config = self.peers_in_network.pop(node_id)
+
+  async def start(self) -> None:
+    """Schedule the background task that health-checks the configured peers."""
+    self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
+
+  async def stop(self) -> None:
+    """Cancel the background task if it was started."""
+    if self.listen_task:
+      self.listen_task.cancel()
+
+  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+    """Return the currently known healthy peers.
+
+    When `wait_for_peers` is positive, poll every 0.1s (without a timeout) until
+    at least that many peers are known.
+    """
+    if DEBUG_DISCOVERY >= 2: print("Starting discovery...")
+    if wait_for_peers > 0:
+      while len(self.known_peers) < wait_for_peers:
+        if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
+        await asyncio.sleep(0.1)
+    return list(self.known_peers.values())
+
+
+  async def task_find_peers_from_config(self):
+    """Loop forever, health-checking each configured peer and updating `known_peers`."""
+    if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
+    while True:
+      if not self.peers_in_network:
+        # No other peers configured: the for-loop below never awaits, so yield here
+        # instead of spinning and starving the event loop.
+        await asyncio.sleep(self.discovery_timeout)
+      for peer_id, peer_config in self.peers_in_network.items():
+        try:
+          if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
+          peer = self.known_peers.get(peer_id)
+          if not peer:
+            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
+            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
+          is_healthy = await peer.health_check()
+          if is_healthy:
+            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
+            self.known_peers[peer_id] = peer
+          else:
+            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
+            # The peer may never have been added; pop with a default is then a no-op.
+            self.known_peers.pop(peer_id, None)
+        except Exception as e:
+          if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
+        # Deliberately not in a `finally`: a cancellation arriving during the health
+        # check must propagate immediately instead of first sleeping a full timeout.
+        await asyncio.sleep(self.discovery_timeout)
+
+      if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
+

+ 32 - 0
exo/networking/manual/network_topology_config.py

@@ -0,0 +1,32 @@
+from typing import Dict
+from pydantic import BaseModel, ValidationError
+
+from exo.topology.device_capabilities import DeviceCapabilities
+
+
+class PeerConfig(BaseModel):
+  # One entry of the manual network topology config: where a peer can be reached and what hardware it has.
+  # Host of the peer; the test configs use "localhost".
+  address: str
+  # Port of the peer; ManualDiscovery joins it with `address` as "address:port".
+  port: int
+  # Capabilities passed through unchanged to the peer-handle factory.
+  device_capabilities: DeviceCapabilities
+
+
+class NetworkTopology(BaseModel):
+  """Configuration of the network. A collection outlining all nodes in the network, including the node this is running from."""
+
+  peers: Dict[str, PeerConfig]
+  """
+  node_id to PeerConfig. The node_id is used to identify the peer in the discovery process. The node that this is running from should be included in this dict.
+  """
+
+  @classmethod
+  def from_path(cls, path: str) -> "NetworkTopology":
+    """Read the JSON file at `path` and validate it into a NetworkTopology.
+
+    Raises:
+      FileNotFoundError: If no file exists at `path`.
+      ValueError: If the content is not valid JSON or does not match the model
+        (a file that is not valid UTF-8 raises UnicodeDecodeError, a ValueError subclass).
+    """
+    try:
+      # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
+      with open(path, "r", encoding="utf-8") as f:
+        config_data = f.read()
+    except FileNotFoundError as e:
+      raise FileNotFoundError(f"Config file not found at {path}") from e
+
+    try:
+      return cls.model_validate_json(config_data)
+    except ValidationError as e:
+      raise ValueError(f"Error validating network topology config from {path}: {e}") from e

+ 17 - 0
exo/networking/manual/test_data/invalid_config.json

@@ -0,0 +1,17 @@
+{
+  "peers": {
+    "node1": {
+      "address": "localhost",
+      "device_capabilities": {
+        "model": "Unknown Model",
+        "chip": "Unknown Chip",
+        "memory": 0,
+        "flops": {
+          "fp32": 0,
+          "fp16": 0,
+          "int8": 0
+        }
+      }
+    }
+  }
+}

+ 0 - 0
exo/networking/manual/test_data/invalid_json.json


+ 32 - 0
exo/networking/manual/test_data/test_config.json

@@ -0,0 +1,32 @@
+{
+  "peers": {
+    "node1": {
+      "address": "localhost",
+      "port": 50051,
+      "device_capabilities": {
+        "model": "Unknown Model",
+        "chip": "Unknown Chip",
+        "memory": 0,
+        "flops": {
+          "fp32": 0,
+          "fp16": 0,
+          "int8": 0
+        }
+      }
+    },
+    "node2": {
+      "address": "localhost",
+      "port": 50052,
+      "device_capabilities": {
+        "model": "Unknown Model",
+        "chip": "Unknown Chip",
+        "memory": 0,
+        "flops": {
+          "fp32": 0,
+          "fp16": 0,
+          "int8": 0
+        }
+      }
+    }
+  }
+}

+ 18 - 0
exo/networking/manual/test_data/test_config_single_node.json

@@ -0,0 +1,18 @@
+{
+  "peers": {
+    "node1": {
+      "address": "localhost",
+      "port": 50051,
+      "device_capabilities": {
+        "model": "Unknown Model",
+        "chip": "Unknown Chip",
+        "memory": 0,
+        "flops": {
+          "fp32": 0,
+          "fp16": 0,
+          "int8": 0
+        }
+      }
+    }
+  }
+}

+ 103 - 0
exo/networking/manual/test_manual_discovery.py

@@ -0,0 +1,103 @@
+import asyncio
+import unittest
+from unittest import mock
+from exo.networking.manual.manual_discovery import ManualDiscovery
+from exo.networking.manual.network_topology_config import NetworkTopology
+from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
+from exo.networking.grpc.grpc_server import GRPCServer
+from exo.orchestration.node import Node
+
+root_path = "./exo/networking/manual/test_data/test_config.json"
+
+
+class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
+  # Checks that discover_peers(wait_for_peers=0) returns no peers and that the mocked peer handle is never connected.
+  async def asyncSetUp(self):
+    self.peer1 = mock.AsyncMock()
+    self.peer1.connect = mock.AsyncMock()
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    # NOTE(review): ManualDiscovery.start is `async def`; calling it without `await` only creates a
+    # coroutine object, so the peer-finding task is never scheduled in this test.
+    # NOTE(review): root_path points at the two-node test_config.json, not test_config_single_node.json —
+    # confirm which setup this "single node" test is meant to exercise before changing either line.
+    _ = self.discovery1.start()
+
+  async def asyncTearDown(self):
+    # stop() only cancels listen_task when one exists, so this is safe even though no task was started.
+    await self.discovery1.stop()
+
+  async def test_discovery(self):
+    # wait_for_peers=0 returns the current known peers immediately, without polling.
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=0)
+    assert len(peers1) == 0
+
+    self.peer1.connect.assert_not_called()
+
+
+class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
+  """Two ManualDiscovery instances sharing one config, each backed by a mocked peer handle."""
+
+  async def asyncSetUp(self):
+    # Each discovery's factory always hands back the same AsyncMock, whatever peer is requested.
+    self.peer1 = mock.AsyncMock()
+    self.peer2 = mock.AsyncMock()
+    self.peer1.connect = mock.AsyncMock()
+    self.peer2.connect = mock.AsyncMock()
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+
+  async def test_discovery(self):
+    # assertEqual instead of bare assert: not stripped under -O and reports both values on failure.
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    self.assertEqual(len(peers1), 1)
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    self.assertEqual(len(peers2), 1)
+
+    # connect has to be explicitly called after discovery
+    self.peer1.connect.assert_not_called()
+    self.peer2.connect.assert_not_called()
+
+
+class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
+  """Runs manual discovery against real GRPCServer instances using GRPCPeerHandle peers."""
+
+  async def asyncSetUp(self):
+    # Servers are started on the addresses/ports listed in the same config the discoveries read.
+    config = NetworkTopology.from_path(root_path)
+
+    self.node1 = mock.AsyncMock(spec=Node)
+    self.node2 = mock.AsyncMock(spec=Node)
+    self.server1 = GRPCServer(self.node1, config.peers["node1"].address, config.peers["node1"].port)
+    self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
+    await self.server1.start()
+    await self.server2.start()
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+    # NOTE(review): test_grpc_discovery already stops both servers, so these are repeat calls —
+    # presumably GRPCServer.stop tolerates that; confirm.
+    await self.server1.stop()
+    await self.server2.stop()
+
+  async def test_grpc_discovery(self):
+    # assertEqual instead of bare assert, consistent with the assertTrue/assertFalse calls below.
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    self.assertEqual(len(peers1), 1)
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    self.assertEqual(len(peers2), 1)
+
+    # Connect
+    await peers1[0].connect()
+    await peers2[0].connect()
+    self.assertTrue(await peers1[0].is_connected())
+    self.assertTrue(await peers2[0].is_connected())
+
+    # Kill server1: discovery2's peer (node1) must drop, discovery1's peer (node2) must stay connected.
+    await self.server1.stop()
+
+    self.assertTrue(await peers1[0].is_connected())
+    self.assertFalse(await peers2[0].is_connected())
+
+    # Kill server2: now neither peer handle may report a connection.
+    await self.server2.stop()
+
+    self.assertFalse(await peers1[0].is_connected())
+    self.assertFalse(await peers2[0].is_connected())
+
+
+if __name__ == "__main__":
+  # unittest.main() runs the tests itself and exits; it returns no coroutine, so it must not be
+  # wrapped in asyncio.run(). IsolatedAsyncioTestCase manages its own event loop per test.
+  unittest.main()

+ 49 - 0
exo/networking/manual/test_network_topology_config.py

@@ -0,0 +1,49 @@
+import unittest
+
+from exo.networking.manual.network_topology_config import NetworkTopology
+
+root_path = "./exo/networking/manual/test_data/"
+
+
+class TestNetworkTopologyConfig(unittest.TestCase):
+  # Exercises NetworkTopology.from_path against the fixture files under root_path.
+
+  def test_from_path_invalid_path(self):
+    # A missing file is reported as FileNotFoundError carrying the path.
+    with self.assertRaises(FileNotFoundError) as ctx:
+      NetworkTopology.from_path("invalid_path")
+    self.assertEqual(str(ctx.exception), "Config file not found at invalid_path")
+
+  def test_from_path_invalid_json(self):
+    # An empty file is not valid JSON and is reported as ValueError.
+    with self.assertRaises(ValueError) as ctx:
+      NetworkTopology.from_path(root_path + "invalid_json.json")
+    message = str(ctx.exception)
+    self.assertIn("Error validating network topology config from", message)
+    self.assertIn("1 validation error for NetworkTopology\n  Invalid JSON: EOF while parsing a value at line 1 column 0", message)
+
+  def test_from_path_invalid_config(self):
+    # The fixture omits the required `port` field.
+    with self.assertRaises(ValueError) as ctx:
+      NetworkTopology.from_path(root_path + "invalid_config.json")
+    message = str(ctx.exception)
+    self.assertIn("Error validating network topology config from", message)
+    self.assertIn("port\n  Field required", message)
+
+  def test_from_path_valid(self):
+    # Both nodes share every value except the port, so check them in one loop.
+    config = NetworkTopology.from_path(root_path + "test_config.json")
+
+    for node_id, expected_port in (("node1", 50051), ("node2", 50052)):
+      peer = config.peers[node_id]
+      self.assertEqual(peer.port, expected_port)
+      self.assertEqual(peer.device_capabilities.model, "Unknown Model")
+      self.assertEqual(peer.address, "localhost")
+      self.assertEqual(peer.device_capabilities.chip, "Unknown Chip")
+      self.assertEqual(peer.device_capabilities.memory, 0)
+      self.assertEqual(peer.device_capabilities.flops.fp32, 0)
+      self.assertEqual(peer.device_capabilities.flops.fp16, 0)
+      self.assertEqual(peer.device_capabilities.flops.int8, 0)
+
+
+if __name__ == "__main__":
+  unittest.main()

+ 1 - 0
exo/networking/udp/test_udp_discovery.py

@@ -6,6 +6,7 @@ from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.orchestration.node import Node
 
+
 class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
     self.peer1 = mock.AsyncMock()

+ 1 - 1
exo/networking/udp/udp_discovery.py

@@ -205,4 +205,4 @@ class UDPDiscovery(Discovery):
       (current_time - last_seen > self.discovery_timeout) or
       (not health_ok)
     )
-    return should_remove
+    return should_remove

+ 6 - 7
exo/topology/device_capabilities.py

@@ -1,13 +1,13 @@
+from typing import Any
+from pydantic import BaseModel
 from exo import DEBUG
-from dataclasses import dataclass, asdict
 import subprocess
 import psutil
 
 TFLOPS = 1.00
 
 
-@dataclass
-class DeviceFlops:
+class DeviceFlops(BaseModel):
   # units of TFLOPS
   fp32: float
   fp16: float
@@ -17,11 +17,10 @@ class DeviceFlops:
     return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS"
 
   def to_dict(self):
-    return asdict(self)
+    return self.model_dump()
 
 
-@dataclass
-class DeviceCapabilities:
+class DeviceCapabilities(BaseModel):
   model: str
   chip: str
   memory: int
@@ -30,7 +29,7 @@ class DeviceCapabilities:
   def __str__(self):
     return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}"
 
-  def __post_init__(self):
+  def model_post_init(self, __context: Any) -> None:
     if isinstance(self.flops, dict):
       self.flops = DeviceFlops(**self.flops)
 

+ 1 - 0
setup.py

@@ -18,6 +18,7 @@ install_requires = [
   "prometheus-client==0.20.0",
   "protobuf==5.27.1",
   "psutil==6.0.0",
+  "pydantic==2.9.2",
   "requests==2.32.3",
   "rich==13.7.1",
   "safetensors==0.4.3",