Browse Source

periodically update exo_updated_at attribute for tailscale

Alex Cheema 9 months ago
parent
commit
15a2165d78
2 changed files with 33 additions and 10 deletions
  1. 24 5
      exo/networking/tailscale_discovery.py
  2. 9 5
      exo/networking/tailscale_helpers.py

+ 24 - 5
exo/networking/tailscale_discovery.py

@@ -20,7 +20,7 @@ class TailscaleDiscovery(Discovery):
     discovery_timeout: int = 30,
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
     tailscale_api_key: str = None,
-    tailnet: str = None
+    tailnet: str = None,
   ):
     self.node_id = node_id
     self.node_port = node_port
@@ -33,12 +33,25 @@ class TailscaleDiscovery(Discovery):
     self.cleanup_task = None
     self.tailscale = Tailscale(api_key=tailscale_api_key, tailnet=tailnet)
     self._device_id = None
+    self.update_task = None
 
   async def start(self):
     self.device_capabilities = device_capabilities()
-    await self.update_device_posture_attributes()  # Fetch and update device posture attributes
     self.discovery_task = asyncio.create_task(self.task_discover_peers())
     self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
+    self.update_task = asyncio.create_task(self.task_update_device_posture_attributes())
+
+  async def task_update_device_posture_attributes(self):
+    while True:
+      try:
+        await self.update_device_posture_attributes()
+        if DEBUG_DISCOVERY >= 2:
+          print(f"Updated device posture attributes")
+      except Exception as e:
+        print(f"Error updating device posture attributes: {e}")
+        print(traceback.format_exc())
+      finally:
+        await asyncio.sleep(self.discovery_interval)
 
   async def get_device_id(self):
     if self._device_id:
@@ -67,10 +80,14 @@ class TailscaleDiscovery(Discovery):
         for device in active_devices.values():
           if device.name != self.node_id:
             peer_host = device.addresses[0]
-            peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale.api_key)
+            peer_id, peer_port, device_capabilities, updated_at = await get_device_attributes(device.device_id, self.tailscale.api_key)
             if not peer_id:
               if DEBUG_DISCOVERY >= 4: print(f"{device.device_id} does not have exo node attributes. skipping.")
               continue
+            if current_time - updated_at > self.discovery_timeout:
+              if DEBUG_DISCOVERY >= 3: print(f"{device.device_id} has outdated exo node attributes. skipping.")
+              continue
+
             if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
               if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
               self.known_peers[peer_id] = (
@@ -92,8 +109,10 @@ class TailscaleDiscovery(Discovery):
       self.discovery_task.cancel()
     if self.cleanup_task:
       self.cleanup_task.cancel()
-    if self.discovery_task or self.cleanup_task:
-      await asyncio.gather(self.discovery_task, self.cleanup_task, return_exceptions=True)
+    if self.update_task:
+      self.update_task.cancel()
+    if self.discovery_task or self.cleanup_task or self.update_task:
+      await asyncio.gather(self.discovery_task, self.cleanup_task, self.update_task, return_exceptions=True)
 
   async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
     if wait_for_peers > 0:

+ 9 - 5
exo/networking/tailscale_helpers.py

@@ -2,9 +2,10 @@ import json
 import asyncio
 import aiohttp
 import re
-from typing import Dict, Any, Tuple
+from typing import Dict, Any, Tuple, Optional
 from exo.helpers import DEBUG_DISCOVERY
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
+from datetime import datetime, timezone
 
 async def get_device_id() -> str:
   try:
@@ -38,7 +39,8 @@ async def update_device_attributes(device_id: str, api_key: str, node_id: str, n
       "custom:exo_device_capability_memory": str(device_capabilities.memory),
       "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16),
       "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
-      "custom:exo_device_capability_flops_int8": str(device_capabilities.flops.int8)
+      "custom:exo_device_capability_flops_int8": str(device_capabilities.flops.int8),
+      "custom:exo_updated_at": str(int(datetime.now(timezone.utc).timestamp()))
     }
 
     for attr_name, attr_value in attributes.items():
@@ -50,7 +52,7 @@ async def update_device_attributes(device_id: str, api_key: str, node_id: str, n
         else:
           print(f"Failed to update device posture attribute {attr_name}: {response.status} {await response.text()}")
 
-async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int, DeviceCapabilities]:
+async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int, DeviceCapabilities, int]:
   async with aiohttp.ClientSession() as session:
     url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
     headers = {
@@ -72,10 +74,12 @@ async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int,
             int8=float(attributes.get("custom:exo_device_capability_flops_int8", 0))
           )
         )
-        return node_id, node_port, device_capabilities
+        updated_at_str = attributes.get("custom:exo_updated_at")
+        updated_at = int(updated_at_str) if updated_at_str else 0
+        return node_id, node_port, device_capabilities, updated_at
       else:
         print(f"Failed to fetch posture attributes for {device_id}: {response.status}")
-        return "", 0, DeviceCapabilities(model="", chip="", memory=0, flops=DeviceFlops(fp16=0, fp32=0, int8=0))
+        return "", 0, DeviceCapabilities(model="", chip="", memory=0, flops=DeviceFlops(fp16=0, fp32=0, int8=0)), 0
 
 def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
   result = {}