
Merge pull request #229 from exo-explore/tailscale

Tailscale
Alex Cheema, 7 months ago
parent
commit
008247ab61
5 files changed, 271 insertions and 1 deletion
  1. exo/networking/tailscale_discovery.py (+125, -0)
  2. exo/networking/tailscale_helpers.py (+96, -0)
  3. exo/networking/test_tailscale_discovery.py (+41, -0)
  4. main.py (+8, -1)
  5. setup.py (+1, -0)

exo/networking/tailscale_discovery.py (+125, -0)

@@ -0,0 +1,125 @@
+import asyncio
+import time
+import traceback
+from datetime import datetime, timezone
+from typing import List, Dict, Callable, Tuple
+from tailscale import Tailscale, Device
+from .discovery import Discovery
+from .peer_handle import PeerHandle
+from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
+from exo.helpers import DEBUG, DEBUG_DISCOVERY
+from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes
+
+class TailscaleDiscovery(Discovery):
+  def __init__(
+    self,
+    node_id: str,
+    node_port: int,
+    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    discovery_interval: int = 10,
+    discovery_timeout: int = 30,
+    device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
+    tailscale_api_key: str = None,
+    tailnet: str = None
+  ):
+    self.node_id = node_id
+    self.node_port = node_port
+    self.create_peer_handle = create_peer_handle
+    self.discovery_interval = discovery_interval
+    self.discovery_timeout = discovery_timeout
+    self.device_capabilities = device_capabilities
+    self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
+    self.discovery_task = None
+    self.cleanup_task = None
+    self.tailscale = Tailscale(api_key=tailscale_api_key, tailnet=tailnet)
+    self._device_id = None
+
+  async def start(self):
+    self.device_capabilities = device_capabilities()
+    await self.update_device_posture_attributes()  # Fetch and update device posture attributes
+    self.discovery_task = asyncio.create_task(self.task_discover_peers())
+    self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
+
+  async def get_device_id(self):
+    if self._device_id:
+      return self._device_id
+    self._device_id = await get_device_id()
+    return self._device_id
+
+  async def update_device_posture_attributes(self):
+    await update_device_attributes(await self.get_device_id(), self.tailscale.api_key, self.node_id, self.node_port, self.device_capabilities)
+
+  async def task_discover_peers(self):
+    while True:
+      try:
+        devices: dict[str, Device] = await self.tailscale.devices()
+        current_time = datetime.now(timezone.utc).timestamp()
+
+        active_devices = {
+          name: device for name, device in devices.items()
+          if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30
+        }
+
+        if DEBUG_DISCOVERY >= 4: print(f"Found tailscale devices: {devices}")
+        if DEBUG_DISCOVERY >= 2: print(f"Active tailscale devices: {len(active_devices)}/{len(devices)}")
+        if DEBUG_DISCOVERY >= 2: print("Time since last seen tailscale devices", [(current_time  - device.last_seen.timestamp()) for device in devices.values()])
+
+        for device in active_devices.values():
+          if device.name != self.node_id:
+            peer_host = device.addresses[0]
+            peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale.api_key)
+            if not peer_id:
+              if DEBUG_DISCOVERY >= 4: print(f"{device.device_id} does not have exo node attributes. skipping.")
+              continue
+            if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
+              if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
+              self.known_peers[peer_id] = (
+                self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
+                current_time,
+                current_time,
+              )
+            else:
+              self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], current_time)
+
+      except Exception as e:
+        print(f"Error in discover peers: {e}")
+        print(traceback.format_exc())
+      finally:
+        await asyncio.sleep(self.discovery_interval)
+
+  async def stop(self):
+    if self.discovery_task:
+      self.discovery_task.cancel()
+    if self.cleanup_task:
+      self.cleanup_task.cancel()
+    if self.discovery_task or self.cleanup_task:
+      await asyncio.gather(self.discovery_task, self.cleanup_task, return_exceptions=True)
+
+  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+    if wait_for_peers > 0:
+      while len(self.known_peers) < wait_for_peers:
+        if DEBUG_DISCOVERY >= 2:
+          print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
+        await asyncio.sleep(0.1)
+    return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
+
+  async def task_cleanup_peers(self):
+    while True:
+      try:
+        current_time = time.time()
+        peers_to_remove = [
+          peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
+          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
+        ]
+        if DEBUG_DISCOVERY >= 2:
+          print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
+        for peer_id in peers_to_remove:
+          if peer_id in self.known_peers:
+            del self.known_peers[peer_id]
+          if DEBUG_DISCOVERY >= 2:
+            print(f"Removed peer {peer_id} due to inactivity.")
+      except Exception as e:
+        print(f"Error in cleanup peers: {e}")
+        print(traceback.format_exc())
+      finally:
+        await asyncio.sleep(self.discovery_interval)
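
For reference, a minimal sketch (not part of this PR) of driving TailscaleDiscovery on its own: it reads the API key and tailnet from the same TAILSCALE_API_KEY / TAILSCALE_TAILNET environment variables the test below uses, and reuses GRPCPeerHandle as the peer-handle factory, mirroring how main.py wires it up.

import asyncio
import os
from exo.networking.tailscale_discovery import TailscaleDiscovery
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle

async def run_discovery():
  discovery = TailscaleDiscovery(
    node_id="my-node",
    node_port=50051,
    create_peer_handle=lambda peer_id, address, caps: GRPCPeerHandle(peer_id, address, caps),
    tailscale_api_key=os.environ["TAILSCALE_API_KEY"],
    tailnet=os.environ["TAILSCALE_TAILNET"],
  )
  await discovery.start()  # posts this node's posture attributes, then starts polling the tailnet
  peers = await discovery.discover_peers(wait_for_peers=1)  # block until at least one peer is known
  print([peer.id() for peer in peers])
  await discovery.stop()

asyncio.run(run_discovery())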

exo/networking/tailscale_helpers.py (+96, -0)

@@ -0,0 +1,96 @@
+import json
+import asyncio
+import aiohttp
+import re
+from typing import Dict, Any, Tuple
+from exo.helpers import DEBUG_DISCOVERY
+from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
+
+async def get_device_id() -> str:
+  try:
+    process = await asyncio.create_subprocess_exec(
+      'tailscale', 'status', '--json',
+      stdout=asyncio.subprocess.PIPE,
+      stderr=asyncio.subprocess.PIPE
+    )
+    stdout, stderr = await process.communicate()
+    if process.returncode != 0:
+      raise Exception(f"Command failed with exit code {process.returncode}: {stderr.decode().strip()}.")
+    if DEBUG_DISCOVERY >= 4: print(f"tailscale status: {stdout.decode()}")
+    data = json.loads(stdout.decode())
+    return data['Self']['ID']
+  except Exception as e:
+    raise Exception(f"{str(e)} Do you have the tailscale cli installed? See: https://tailscale.com/kb/1080/cli")
+
+async def update_device_attributes(device_id: str, api_key: str, node_id: str, node_port: int, device_capabilities: DeviceCapabilities):
+  async with aiohttp.ClientSession() as session:
+    base_url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
+    headers = {
+      'Authorization': f'Bearer {api_key}',
+      'Content-Type': 'application/json'
+    }
+
+    attributes = {
+      "custom:exo_node_id": node_id.replace('-', '_'),
+      "custom:exo_node_port": node_port,
+      "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip),
+      "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model),
+      "custom:exo_device_capability_memory": str(device_capabilities.memory),
+      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16),
+      "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
+      "custom:exo_device_capability_flops_int8": str(device_capabilities.flops.int8)
+    }
+
+    for attr_name, attr_value in attributes.items():
+      url = f"{base_url}/{attr_name}"
+      data = {"value": str(attr_value).replace(' ', '_')}  # Ensure all values are strings for JSON
+      async with session.post(url, headers=headers, json=data) as response:
+        if response.status == 200:
+          if DEBUG_DISCOVERY >= 1: print(f"Updated device posture attribute {attr_name} for device {device_id}")
+        else:
+          print(f"Failed to update device posture attribute {attr_name}: {response.status} {await response.text()}")
+
+async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int, DeviceCapabilities]:
+  async with aiohttp.ClientSession() as session:
+    url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
+    headers = {
+      'Authorization': f'Bearer {api_key}'
+    }
+    async with session.get(url, headers=headers) as response:
+      if response.status == 200:
+        data = await response.json()
+        attributes = data.get("attributes", {})
+        node_id = attributes.get("custom:exo_node_id", "").replace('_', '-')
+        node_port = int(attributes.get("custom:exo_node_port", 0))
+        device_capabilities = DeviceCapabilities(
+          model=attributes.get("custom:exo_device_capability_model", "").replace('_', ' '),
+          chip=attributes.get("custom:exo_device_capability_chip", "").replace('_', ' '),
+          memory=int(attributes.get("custom:exo_device_capability_memory", 0)),
+          flops=DeviceFlops(
+            fp16=float(attributes.get("custom:exo_device_capability_flops_fp16", 0)),
+            fp32=float(attributes.get("custom:exo_device_capability_flops_fp32", 0)),
+            int8=float(attributes.get("custom:exo_device_capability_flops_int8", 0))
+          )
+        )
+        return node_id, node_port, device_capabilities
+      else:
+        print(f"Failed to fetch posture attributes for {device_id}: {response.status}")
+        return "", 0, DeviceCapabilities(model="", chip="", memory=0, flops=DeviceFlops(fp16=0, fp32=0, int8=0))
+
+def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
+  result = {}
+  prefix = "custom:exo_"
+  for key, value in data.items():
+    if key.startswith(prefix):
+      attr_name = key.replace(prefix, "")
+      if attr_name in ["node_id", "node_port", "device_capability_chip", "device_capability_model"]:
+        result[attr_name] = value.replace('_', ' ')
+      elif attr_name in ["device_capability_memory", "device_capability_flops_fp16", "device_capability_flops_fp32", "device_capability_flops_int8"]:
+        result[attr_name] = float(value)
+  return result
+
+def sanitize_attribute(value: str) -> str:
+  # Replace invalid characters with underscores
+  sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
+  # Truncate to 50 characters
+  return sanitized_value[:50]
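
A small sketch of the attribute encoding these helpers use, with made-up values, runnable locally without touching the Tailscale API: sanitize_attribute replaces any character outside [a-zA-Z0-9_.] with an underscore, and parse_device_attributes strips the "custom:exo_" prefix from a dict shaped like the one update_device_attributes posts.

from exo.networking.tailscale_helpers import sanitize_attribute, parse_device_attributes

print(sanitize_attribute("Apple M2 Max"))  # -> "Apple_M2_Max"

# A dict mirroring the posted posture attributes; the values are illustrative.
attrs = {
  "custom:exo_node_id": "my_node",
  "custom:exo_node_port": "50051",
  "custom:exo_device_capability_chip": "Apple_M2_Max",
  "custom:exo_device_capability_memory": "65536",
  "custom:exo_device_capability_flops_fp16": "26.98",
}
print(parse_device_attributes(attrs))
# -> {'node_id': 'my node', 'node_port': '50051', 'device_capability_chip': 'Apple M2 Max',
#     'device_capability_memory': 65536.0, 'device_capability_flops_fp16': 26.98}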

exo/networking/test_tailscale_discovery.py (+41, -0)

@@ -0,0 +1,41 @@
+import os
+import asyncio
+import unittest
+from unittest import mock
+from exo.networking.tailscale_discovery import TailscaleDiscovery
+from exo.networking.peer_handle import PeerHandle
+
+class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.tailscale_api_key = os.environ.get("TAILSCALE_API_KEY", "")
+    self.tailnet = os.environ.get("TAILSCALE_TAILNET", "")
+    self.discovery = TailscaleDiscovery(
+      node_id="test_node",
+      node_port=50051,
+      create_peer_handle=lambda peer_id, address, device_capabilities: unittest.mock.Mock(spec=PeerHandle, id=lambda: peer_id),
+      tailscale_api_key=self.tailscale_api_key,
+      tailnet=self.tailnet
+    )
+    await self.discovery.start()
+
+  async def asyncTearDown(self):
+    await self.discovery.stop()
+
+  async def test_discovery(self):
+    # Wait for a short period to allow discovery to happen
+    await asyncio.sleep(15)
+
+    # Get discovered peers
+    peers = await self.discovery.discover_peers()
+
+    # Check if any peers were discovered
+    self.assertGreater(len(peers), 0, "No peers were discovered")
+
+    # Print discovered peers for debugging
+    print(f"Discovered peers: {[peer.id() for peer in peers]}")
+
+    # Print the raw peer handle objects as well
+    print(peers)
+
+if __name__ == '__main__':
+  unittest.main()
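
Note that this is an integration test against the live Tailscale API rather than a unit test: it needs TAILSCALE_API_KEY and TAILSCALE_TAILNET set in the environment and at least one other exo node visible on the tailnet, so it will fail in an offline CI run. Something along the lines of python -m unittest exo.networking.test_tailscale_discovery should exercise it (invocation shown for illustration).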

main.py (+8, -1)

@@ -8,6 +8,7 @@ import uuid
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.udp_discovery import UDPDiscovery
+from exo.networking.tailscale_discovery import TailscaleDiscovery
 from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
@@ -31,6 +32,7 @@ parser.add_argument("--download-quick-check", action="store_true", help="Quick c
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
 parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
+parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale"], default="udp", help="Discovery module to use")
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
@@ -40,6 +42,8 @@ parser.add_argument("--inference-engine", type=str, default=None, help="Inferenc
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
+parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
+parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 args = parser.parse_args()
 
 print_yellow_exo()
@@ -67,7 +71,10 @@ if DEBUG >= 0:
   for chatgpt_api_endpoint in chatgpt_api_endpoints:
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
 
-discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
+if args.discovery_module == "udp":
+  discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
+elif args.discovery_module == "tailscale":
+  discovery = TailscaleDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, tailscale_api_key=args.tailscale_api_key, tailnet=args.tailnet_name)
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
   args.node_id,
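
With these flags wired in, a Tailscale-discovered node would be launched with something like python main.py --discovery-module tailscale --tailscale-api-key <api-key> --tailnet-name <tailnet> (shown for illustration); the key must be able to list devices on the tailnet and set custom posture attributes, since that is what the discovery code does.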

setup.py (+1, -0)

@@ -23,6 +23,7 @@ install_requires = [
   "requests==2.32.3",
   "rich==13.7.1",
   "safetensors==0.4.3",
+  "tailscale==0.6.1",
   "tenacity==9.0.0",
   "tiktoken==0.7.0",
   "tokenizers==0.19.1",