Преглед изворни кода

support node_id, node_port and device_capabilities with tailscale attributes

Alex Cheema пре 8 месеци
родитељ
комит
1798fc073f
2 измењених фајлова са 119 додато и 19 уклоњено
  1. 30 19
      exo/networking/tailscale_discovery.py
  2. 89 0
      exo/networking/tailscale_helpers.py

+ 30 - 19
exo/networking/tailscale_discovery.py

@@ -8,6 +8,7 @@ from .discovery import Discovery
 from .peer_handle import PeerHandle
 from .peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.helpers import DEBUG, DEBUG_DISCOVERY
 from exo.helpers import DEBUG, DEBUG_DISCOVERY
+from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, update_device_attributes
 
 
 class TailscaleDiscovery(Discovery):
 class TailscaleDiscovery(Discovery):
   def __init__(
   def __init__(
@@ -31,27 +32,22 @@ class TailscaleDiscovery(Discovery):
     self.discovery_task = None
     self.discovery_task = None
     self.cleanup_task = None
     self.cleanup_task = None
     self.tailscale = Tailscale(api_key=tailscale_api_key, tailnet=tailnet)
     self.tailscale = Tailscale(api_key=tailscale_api_key, tailnet=tailnet)
+    self._device_id = None
 
 
   async def start(self):
   async def start(self):
     self.device_capabilities = device_capabilities()
     self.device_capabilities = device_capabilities()
+    await self.update_device_posture_attributes()  # Fetch and update device posture attributes
     self.discovery_task = asyncio.create_task(self.task_discover_peers())
     self.discovery_task = asyncio.create_task(self.task_discover_peers())
     self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
     self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
 
 
-  async def stop(self):
-    if self.discovery_task:
-      self.discovery_task.cancel()
-    if self.cleanup_task:
-      self.cleanup_task.cancel()
-    if self.discovery_task or self.cleanup_task:
-      await asyncio.gather(self.discovery_task, self.cleanup_task, return_exceptions=True)
+  async def get_device_id(self):
+    """Return this machine's Tailscale device id, looking it up only once.
+
+    The first successful lookup is stored in `self._device_id`; later calls
+    return the stored value without calling the helper again.
+    """
+    if not self._device_id:
+      # Nothing cached yet (or the cached value is falsy): ask the helper.
+      self._device_id = await get_device_id()
+    return self._device_id
 
 
-  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-    if wait_for_peers > 0:
-      while len(self.known_peers) < wait_for_peers:
-        if DEBUG_DISCOVERY >= 2:
-          print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
-        await asyncio.sleep(0.1)
-    return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
+  async def update_device_posture_attributes(self):
+    """Publish this node's id, port and device capabilities as attributes of its own Tailscale device."""
+    own_device_id = await self.get_device_id()
+    await update_device_attributes(
+      own_device_id,
+      self.tailscale.api_key,
+      self.node_id,
+      self.node_port,
+      self.device_capabilities,
+    )
 
 
   async def task_discover_peers(self):
   async def task_discover_peers(self):
     while True:
     while True:
@@ -59,7 +55,6 @@ class TailscaleDiscovery(Discovery):
         devices: dict[str, Device] = await self.tailscale.devices()
         devices: dict[str, Device] = await self.tailscale.devices()
         current_time = datetime.now(timezone.utc).timestamp()
         current_time = datetime.now(timezone.utc).timestamp()
 
 
-        # Filter out devices last seen more than 1 minute ago
         active_devices = {
         active_devices = {
           name: device for name, device in devices.items()
           name: device for name, device in devices.items()
           if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30
           if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30
@@ -71,14 +66,14 @@ class TailscaleDiscovery(Discovery):
 
 
         for device in active_devices.values():
         for device in active_devices.values():
           if device.name != self.node_id:
           if device.name != self.node_id:
-            peer_id = device.name
-            peer_host = device.addresses[0]  # Assuming the first address is the one we want
-            peer_port = self.node_port  # Assuming all peers use the same port
+            peer_host = device.addresses[0]
+            peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale.api_key)
+            print("retrieved attributes", peer_id, peer_host, peer_port, device_capabilities)
 
 
             if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
             if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
               if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
               if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
               self.known_peers[peer_id] = (
               self.known_peers[peer_id] = (
-                self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", self.device_capabilities),
+                self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
                 current_time,
                 current_time,
                 current_time,
                 current_time,
               )
               )
@@ -91,6 +86,22 @@ class TailscaleDiscovery(Discovery):
       finally:
       finally:
         await asyncio.sleep(self.discovery_interval)
         await asyncio.sleep(self.discovery_interval)
 
 
+  async def stop(self):
+    """Cancel the background discovery and cleanup tasks and wait for them to finish.
+
+    Safe to call when neither task, or only one of them, has been created.
+    """
+    # Only real tasks may be handed to gather(): a None argument makes it raise TypeError.
+    tasks = [task for task in (self.discovery_task, self.cleanup_task) if task]
+    for task in tasks:
+      task.cancel()
+    if tasks:
+      # return_exceptions=True collects each task's CancelledError instead of raising it here.
+      await asyncio.gather(*tasks, return_exceptions=True)
+
+  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+    """Return the handles of all currently known peers.
+
+    When `wait_for_peers` is positive, first poll every 0.1 s until at least
+    that many peers are present in `self.known_peers`.
+    """
+    while wait_for_peers > 0 and len(self.known_peers) < wait_for_peers:
+      if DEBUG_DISCOVERY >= 2:
+        print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
+      await asyncio.sleep(0.1)
+    # Each known_peers value is a 3-tuple whose first item is the peer handle.
+    return [handle for handle, _, _ in self.known_peers.values()]
+
   async def task_cleanup_peers(self):
   async def task_cleanup_peers(self):
     while True:
     while True:
       try:
       try:

+ 89 - 0
exo/networking/tailscale_helpers.py

@@ -0,0 +1,89 @@
+import json
+import asyncio
+import aiohttp
+from typing import Dict, Any, Tuple
+from exo.helpers import DEBUG_DISCOVERY
+from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
+
+async def get_device_id() -> str:
+  """Return the Tailscale device id of the local machine.
+
+  Runs `tailscale status --json` and reads `Self.ID` from its JSON output.
+
+  Raises:
+    Exception: if the command cannot be started, exits with a non-zero code,
+      or its output is not JSON containing `Self.ID`. The underlying error
+      is attached as the exception's cause.
+  """
+  try:
+    process = await asyncio.create_subprocess_exec(
+      'tailscale', 'status', '--json',
+      stdout=asyncio.subprocess.PIPE,
+      stderr=asyncio.subprocess.PIPE
+    )
+    stdout, stderr = await process.communicate()
+    if process.returncode != 0:
+      raise Exception(f"Command failed with exit code {process.returncode}: {stderr.decode().strip()}.")
+    status_json = stdout.decode()
+    if DEBUG_DISCOVERY >= 4: print(f"tailscale status: {status_json}")
+    return json.loads(status_json)['Self']['ID']
+  except Exception as e:
+    # Re-raise with an installation hint, chaining the original error so the
+    # real failure (missing binary, bad JSON, missing key) stays in the traceback.
+    raise Exception(f"{str(e)} Do you have the tailscale cli installed? See: https://tailscale.com/kb/1080/cli") from e
+
+async def update_device_attributes(device_id: str, api_key: str, node_id: str, node_port: int, device_capabilities: DeviceCapabilities):
+  """Publish a node's identity and capabilities as custom attributes of a Tailscale device.
+
+  One POST is sent per attribute to `.../device/<device_id>/attributes/<name>`.
+  A non-200 response is printed and the remaining attributes are still sent;
+  nothing is raised for it.
+  """
+  async with aiohttp.ClientSession() as session:
+    endpoint = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
+    request_headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
+
+    # '-' in the node id and ' ' in chip/model names are replaced by '_' before sending.
+    posture = {
+      "custom:exo_node_id": node_id.replace('-', '_'),
+      "custom:exo_node_port": node_port,
+      "custom:exo_device_capability_chip": device_capabilities.chip.replace(' ', '_'),
+      "custom:exo_device_capability_model": device_capabilities.model.replace(' ', '_'),
+      "custom:exo_device_capability_memory": str(device_capabilities.memory),
+      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16),
+      "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
+      "custom:exo_device_capability_flops_int8": str(device_capabilities.flops.int8),
+    }
+
+    for attr_name, raw_value in posture.items():
+      # Every value is sent as a string with spaces replaced by underscores.
+      payload = {"value": str(raw_value).replace(' ', '_')}
+      async with session.post(f"{endpoint}/{attr_name}", headers=request_headers, json=payload) as response:
+        if response.status != 200:
+          print(f"Failed to update device posture attribute {attr_name}: {response.status} {await response.text()}")
+          continue
+        if DEBUG_DISCOVERY >= 1: print(f"Updated device posture attribute {attr_name} for device {device_id}")
+
+async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int, DeviceCapabilities]:
+  """Fetch the exo attributes stored on a Tailscale device.
+
+  Returns:
+    `(node_id, node_port, device_capabilities)`. Attributes that are missing
+    default to an empty string or zero. On a non-200 response the failure is
+    printed and an all-empty result is returned.
+  """
+  endpoint = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
+  request_headers = {'Authorization': f'Bearer {api_key}'}
+  async with aiohttp.ClientSession() as session:
+    async with session.get(endpoint, headers=request_headers) as response:
+      if response.status != 200:
+        print(f"Failed to fetch posture attributes for {device_id}: {response.status}")
+        return "", 0, DeviceCapabilities(model="", chip="", memory=0, flops=DeviceFlops(fp16=0, fp32=0, int8=0))
+
+      payload = await response.json()
+      attrs = payload.get("attributes", {})
+      # Turn '_' back into '-' (node id) or ' ' (model/chip) when reading the stored values.
+      node_id = attrs.get("custom:exo_node_id", "").replace('_', '-')
+      node_port = int(attrs.get("custom:exo_node_port", 0))
+      model = attrs.get("custom:exo_device_capability_model", "").replace('_', ' ')
+      chip = attrs.get("custom:exo_device_capability_chip", "").replace('_', ' ')
+      memory = int(attrs.get("custom:exo_device_capability_memory", 0))
+      fp16 = float(attrs.get("custom:exo_device_capability_flops_fp16", 0))
+      fp32 = float(attrs.get("custom:exo_device_capability_flops_fp32", 0))
+      int8 = float(attrs.get("custom:exo_device_capability_flops_int8", 0))
+      capabilities = DeviceCapabilities(
+        model=model,
+        chip=chip,
+        memory=memory,
+        flops=DeviceFlops(fp16=fp16, fp32=fp32, int8=int8),
+      )
+      return node_id, node_port, capabilities
+
+def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
+  """Decode the `custom:exo_*` entries of a raw attribute mapping.
+
+  Keys without the `custom:exo_` prefix, and prefixed keys that are not one of
+  the known attribute names, are ignored. Result keys have the prefix removed.
+
+  Decoding agrees with `update_device_attributes` / `get_device_attributes`:
+    - `node_id`: underscores are turned back into dashes.
+    - `node_port`: parsed as `int`.
+    - chip / model: underscores are turned back into spaces.
+    - memory / flops: parsed as `float`.
+
+  Raises:
+    ValueError: if a port, memory or flops value is not numeric.
+  """
+  prefix = "custom:exo_"
+  text_fields = ("device_capability_chip", "device_capability_model")
+  numeric_fields = (
+    "device_capability_memory",
+    "device_capability_flops_fp16",
+    "device_capability_flops_fp32",
+    "device_capability_flops_int8",
+  )
+  result = {}
+  for key, value in data.items():
+    if not key.startswith(prefix):
+      continue
+    # Strip only the leading prefix; str.replace would also remove later occurrences.
+    attr_name = key[len(prefix):]
+    if attr_name == "node_id":
+      result[attr_name] = value.replace('_', '-')
+    elif attr_name == "node_port":
+      result[attr_name] = int(value)
+    elif attr_name in text_fields:
+      result[attr_name] = value.replace('_', ' ')
+    elif attr_name in numeric_fields:
+      result[attr_name] = float(value)
+  return result