
implement tailscale discovery module

Alex Cheema 7 months ago
parent
commit
93224799e5

+ 104 - 0
exo/networking/tailscale_discovery.py

@@ -0,0 +1,104 @@
+import asyncio
+import time
+import traceback
+from typing import List, Dict, Callable, Tuple, Optional
+from tailscale import Tailscale, Device
+from .discovery import Discovery
+from .peer_handle import PeerHandle
+from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
+from exo.helpers import DEBUG, DEBUG_DISCOVERY
+
+class TailscaleDiscovery(Discovery):
+  def __init__(
+    self,
+    node_id: str,
+    node_port: int,
+    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    discovery_interval: int = 10,
+    discovery_timeout: int = 30,
+    device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
+    tailscale_api_key: Optional[str] = None,
+    tailnet: Optional[str] = None
+  ):
+    self.node_id = node_id
+    self.node_port = node_port
+    self.create_peer_handle = create_peer_handle
+    self.discovery_interval = discovery_interval
+    self.discovery_timeout = discovery_timeout
+    self.device_capabilities = device_capabilities
+    self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
+    self.discovery_task = None
+    self.cleanup_task = None
+    self.tailscale = Tailscale(api_key=tailscale_api_key, tailnet=tailnet)
+
+  async def start(self):
+    self.device_capabilities = device_capabilities()
+    self.discovery_task = asyncio.create_task(self.task_discover_peers())
+    self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
+
+  async def stop(self):
+    if self.discovery_task:
+      self.discovery_task.cancel()
+    if self.cleanup_task:
+      self.cleanup_task.cancel()
+    if self.discovery_task and self.cleanup_task:
+      await asyncio.gather(self.discovery_task, self.cleanup_task, return_exceptions=True)
+
+  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+    if wait_for_peers > 0:
+      while len(self.known_peers) < wait_for_peers:
+        if DEBUG_DISCOVERY >= 2:
+          print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
+        await asyncio.sleep(0.1)
+    return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
+
+  async def task_discover_peers(self):
+    while True:
+      try:
+        devices: dict[str, Device] = await self.tailscale.devices()
+        if DEBUG_DISCOVERY >= 2: print("Tailscale devices:", devices)
+        current_time = time.time()
+
+        for device in devices.values():
+          if device.name != self.node_id:
+            peer_id = device.name
+            peer_host = device.addresses[0]  # Assuming the first address is the one we want
+            peer_port = self.node_port  # Assuming all peers use the same port
+
+            if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
+              if DEBUG >= 1:
+                print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
+              self.known_peers[peer_id] = (
+                self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", self.device_capabilities),
+                current_time,
+                current_time,
+              )
+            else:
+              self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], current_time)
+
+      except Exception as e:
+        print(f"Error in discover peers: {e}")
+        print(traceback.format_exc())
+      finally:
+        await asyncio.sleep(self.discovery_interval)
+
+  async def task_cleanup_peers(self):
+    while True:
+      try:
+        current_time = time.time()
+        peers_to_remove = [
+          peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
+          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
+        ]
+        if DEBUG_DISCOVERY >= 2:
+          print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
+        for peer_id in peers_to_remove:
+          if peer_id in self.known_peers:
+            del self.known_peers[peer_id]
+          if DEBUG_DISCOVERY >= 2:
+            print(f"Removed peer {peer_id} due to inactivity.")
+      except Exception as e:
+        print(f"Error in cleanup peers: {e}")
+        print(traceback.format_exc())
+      finally:
+        await asyncio.sleep(self.discovery_interval)

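For context, here is a minimal sketch of how this module is meant to be driven, mirroring the wiring added to main.py further down in this commit; the node id, port, API key and tailnet below are placeholders, not values from the commit:

import asyncio
from exo.networking.tailscale_discovery import TailscaleDiscovery
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle

async def run_discovery():
  # Placeholder node id/port and Tailscale credentials, for illustration only.
  discovery = TailscaleDiscovery(
    node_id="my-node",
    node_port=50051,
    create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
    tailscale_api_key="tskey-api-...",
    tailnet="example-tailnet",
  )
  await discovery.start()
  # Blocks until at least one peer on the tailnet has been discovered.
  peers = await discovery.discover_peers(wait_for_peers=1)
  print([peer.id() for peer in peers])
  await discovery.stop()

asyncio.run(run_discovery())
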
+ 41 - 0
exo/networking/test_tailscale_discovery.py

@@ -0,0 +1,41 @@
+import os
+import asyncio
+import unittest
+from unittest import mock
+from exo.networking.tailscale_discovery import TailscaleDiscovery
+from exo.networking.peer_handle import PeerHandle
+
+class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.tailscale_api_key = os.environ.get("TAILSCALE_API_KEY", "")
+    self.tailnet = os.environ.get("TAILSCALE_TAILNET", "")
+    self.discovery = TailscaleDiscovery(
+      node_id="test_node",
+      node_port=50051,
+      create_peer_handle=lambda peer_id, address, device_capabilities: mock.Mock(spec=PeerHandle, id=lambda: peer_id),
+      tailscale_api_key=self.tailscale_api_key,
+      tailnet=self.tailnet
+    )
+    await self.discovery.start()
+
+  async def asyncTearDown(self):
+    await self.discovery.stop()
+
+  async def test_discovery(self):
+    # Wait for a short period to allow discovery to happen
+    await asyncio.sleep(15)
+
+    # Get discovered peers
+    peers = await self.discovery.discover_peers()
+
+    # Check if any peers were discovered
+    self.assertGreater(len(peers), 0, "No peers were discovered")
+
+    # Print discovered peers for debugging
+    print(f"Discovered peers: {[peer.id() for peer in peers]}")
+
+    # Print the raw peer handles (mocks from create_peer_handle) for debugging
+    print(peers)
+
+if __name__ == '__main__':
+  unittest.main()

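Note that this is effectively an integration test rather than a unit test: it talks to the real Tailscale API using the TAILSCALE_API_KEY and TAILSCALE_TAILNET environment variables (both default to empty strings) and only passes when at least one other device is visible on the tailnet.
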
+ 8 - 1
main.py

@@ -8,6 +8,7 @@ import uuid
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.udp_discovery import UDPDiscovery
+from exo.networking.tailscale_discovery import TailscaleDiscovery
 from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
@@ -31,6 +32,7 @@ parser.add_argument("--download-quick-check", action="store_true", help="Quick c
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
 parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
+parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale"], default="udp", help="Discovery module to use")
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
@@ -40,6 +42,8 @@ parser.add_argument("--inference-engine", type=str, default=None, help="Inferenc
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
+parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
+parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 args = parser.parse_args()
 
 print_yellow_exo()
@@ -67,7 +71,10 @@ if DEBUG >= 0:
   for chatgpt_api_endpoint in chatgpt_api_endpoints:
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
 
-discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
+if args.discovery_module == "udp":
+  discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
+elif args.discovery_module == "tailscale":
+  discovery = TailscaleDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, tailscale_api_key=args.tailscale_api_key, tailnet=args.tailnet_name)
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
   args.node_id,

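With these flags in place, a Tailscale-backed node would be launched with something like python main.py --discovery-module tailscale --tailscale-api-key <api-key> --tailnet-name <tailnet> (placeholders shown in angle brackets); omitting --discovery-module keeps the default UDP discovery.
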
+ 1 - 0
setup.py

@@ -23,6 +23,7 @@ install_requires = [
   "requests==2.32.3",
   "rich==13.7.1",
   "safetensors==0.4.3",
+  "tailscale==0.6.1",
   "tenacity==9.0.0",
   "tiktoken==0.7.0",
   "tokenizers==0.19.1",