Jelajahi Sumber

replace tailscale.devices with good old http, removing the need for tailscale dependency

Alex Cheema 7 bulan lalu
induk
melakukan
e8a8702377

+ 6 - 6
exo/networking/tailscale/tailscale_discovery.py

@@ -2,12 +2,11 @@ import asyncio
 import time
 import traceback
 from typing import List, Dict, Callable, Tuple
-from tailscale import Tailscale, Device
 from exo.networking.discovery import Discovery
 from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.helpers import DEBUG, DEBUG_DISCOVERY
-from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, update_device_attributes
+from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, update_device_attributes, get_tailscale_devices, Device
 
 class TailscaleDiscovery(Discovery):
   def __init__(
@@ -32,7 +31,8 @@ class TailscaleDiscovery(Discovery):
     self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
     self.discovery_task = None
     self.cleanup_task = None
-    self.tailscale = Tailscale(api_key=tailscale_api_key, tailnet=tailnet)
+    self.tailscale_api_key = tailscale_api_key
+    self.tailnet = tailnet
     self._device_id = None
     self.update_task = None
 
@@ -61,12 +61,12 @@ class TailscaleDiscovery(Discovery):
     return self._device_id
 
   async def update_device_posture_attributes(self):
-    await update_device_attributes(await self.get_device_id(), self.tailscale.api_key, self.node_id, self.node_port, self.device_capabilities)
+    await update_device_attributes(await self.get_device_id(), self.tailscale_api_key, self.node_id, self.node_port, self.device_capabilities)
 
   async def task_discover_peers(self):
     while True:
       try:
-        devices: dict[str, Device] = await self.tailscale.devices()
+        devices: dict[str, Device] = await get_tailscale_devices(self.tailscale_api_key, self.tailnet)
         current_time = time.time()
 
         active_devices = {
@@ -81,7 +81,7 @@ class TailscaleDiscovery(Discovery):
         for device in active_devices.values():
           if device.name == self.node_id: continue
           peer_host = device.addresses[0]
-          peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale.api_key)
+          peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale_api_key)
           if not peer_id:
             if DEBUG_DISCOVERY >= 4: print(f"{device.device_id} does not have exo node attributes. skipping.")
             continue

+ 41 - 1
exo/networking/tailscale/tailscale_helpers.py

@@ -2,9 +2,32 @@ import json
 import asyncio
 import aiohttp
 import re
-from typing import Dict, Any, Tuple
+from typing import Dict, Any, Tuple, List, Optional
 from exo.helpers import DEBUG_DISCOVERY
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
+from datetime import datetime, timezone
+
+class Device:
+  def __init__(self, device_id: str, name: str, addresses: List[str], last_seen: Optional[datetime] = None):
+    self.device_id = device_id
+    self.name = name
+    self.addresses = addresses
+    self.last_seen = last_seen
+
+  @classmethod
+  def from_dict(cls, data: Dict[str, Any]) -> 'Device':
+    return cls(
+      device_id=data.get('id', ''),
+      name=data.get('name', ''),
+      addresses=data.get('addresses', []),
+      last_seen=cls.parse_datetime(data.get('lastSeen'))
+    )
+
+  @staticmethod
+  def parse_datetime(date_string: Optional[str]) -> Optional[datetime]:
+    if not date_string:
+      return None
+    return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
 
 async def get_device_id() -> str:
   try:
@@ -94,3 +117,20 @@ def sanitize_attribute(value: str) -> str:
   sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
   # Truncate to 50 characters
   return sanitized_value[:50]
+
+async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]:
+  async with aiohttp.ClientSession() as session:
+    url = f"https://api.tailscale.com/api/v2/tailnet/{tailnet}/devices"
+    headers = {"Authorization": f"Bearer {api_key}"}
+
+    async with session.get(url, headers=headers) as response:
+      response.raise_for_status()
+      data = await response.json()
+
+      devices = {}
+      for device_data in data.get("devices", []):
+        print("Device data: ", device_data)
+        device = Device.from_dict(device_data)
+        devices[device.name] = device
+
+      return devices

+ 0 - 1
setup.py

@@ -20,7 +20,6 @@ install_requires = [
   "requests==2.32.3",
   "rich==13.7.1",
   "safetensors==0.4.3",
-  "tailscale==0.6.1",
   "tenacity==9.0.0",
   "tqdm==4.66.4",
   "transformers==4.43.3",