|
@@ -2,15 +2,28 @@ import asyncio
|
|
|
import json
|
|
|
import socket
|
|
|
import time
|
|
|
-from typing import List, Dict
|
|
|
+from typing import List, Dict, Callable, Tuple, Coroutine
|
|
|
from ..discovery import Discovery
|
|
|
from ..peer_handle import PeerHandle
|
|
|
from .grpc_peer_handle import GRPCPeerHandle
|
|
|
-from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities
|
|
|
+from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
|
|
|
from exo import DEBUG_DISCOVERY
|
|
|
|
|
|
+class ListenProtocol(asyncio.DatagramProtocol):
|
|
|
+ def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
|
|
|
+ super().__init__()
|
|
|
+ self.on_message = on_message
|
|
|
+ self.loop = asyncio.get_event_loop()
|
|
|
+
|
|
|
+ def connection_made(self, transport):
|
|
|
+ self.transport = transport
|
|
|
+
|
|
|
+ def datagram_received(self, data, addr):
|
|
|
+ asyncio.create_task(self.on_message(data, addr))
|
|
|
+
|
|
|
+
|
|
|
class GRPCDiscovery(Discovery):
|
|
|
- def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities=None):
|
|
|
+ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES):
|
|
|
self.node_id = node_id
|
|
|
self.node_port = node_port
|
|
|
self.device_capabilities = device_capabilities
|
|
@@ -24,9 +37,10 @@ class GRPCDiscovery(Discovery):
|
|
|
self.cleanup_task = None
|
|
|
|
|
|
async def start(self):
|
|
|
- self.broadcast_task = asyncio.create_task(self._broadcast_presence())
|
|
|
- self.listen_task = asyncio.create_task(self._listen_for_peers())
|
|
|
- self.cleanup_task = asyncio.create_task(self._cleanup_peers())
|
|
|
+ self.device_capabilities = device_capabilities()
|
|
|
+ self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
|
|
|
+ self.listen_task = asyncio.create_task(self.task_listen_for_peers())
|
|
|
+ self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
|
|
|
|
|
|
async def stop(self):
|
|
|
if self.broadcast_task:
|
|
@@ -62,54 +76,49 @@ class GRPCDiscovery(Discovery):
|
|
|
|
|
|
return list(self.known_peers.values())
|
|
|
|
|
|
- async def _broadcast_presence(self):
|
|
|
- if not self.device_capabilities:
|
|
|
- self.device_capabilities = device_capabilities()
|
|
|
-
|
|
|
- sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
|
|
|
+ async def task_broadcast_presence(self):
|
|
|
+ transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
|
|
|
+ lambda: asyncio.DatagramProtocol(),
|
|
|
+ local_addr=('0.0.0.0', 0),
|
|
|
+ family=socket.AF_INET)
|
|
|
+ sock = transport.get_extra_info('socket')
|
|
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
|
|
- sock.settimeout(0.5)
|
|
|
+
|
|
|
message = json.dumps({
|
|
|
"type": "discovery",
|
|
|
"node_id": self.node_id,
|
|
|
"grpc_port": self.node_port,
|
|
|
- "device_capabilities": {
|
|
|
- "model": self.device_capabilities.model,
|
|
|
- "chip": self.device_capabilities.chip,
|
|
|
- "memory": self.device_capabilities.memory
|
|
|
- }
|
|
|
+ "device_capabilities": self.device_capabilities.to_dict()
|
|
|
}).encode('utf-8')
|
|
|
|
|
|
- while True:
|
|
|
- sock.sendto(message, ('<broadcast>', self.broadcast_port))
|
|
|
- await asyncio.sleep(self.broadcast_interval)
|
|
|
-
|
|
|
- async def _listen_for_peers(self):
|
|
|
- sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
|
- sock.bind(('', self.listen_port))
|
|
|
- sock.setblocking(False)
|
|
|
-
|
|
|
while True:
|
|
|
try:
|
|
|
- data, addr = await asyncio.get_event_loop().sock_recvfrom(sock, 1024)
|
|
|
- message = json.loads(data.decode('utf-8'))
|
|
|
- if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")
|
|
|
- if message['type'] == 'discovery' and message['node_id'] != self.node_id:
|
|
|
- peer_id = message['node_id']
|
|
|
- peer_host = addr[0]
|
|
|
- peer_port = message['grpc_port']
|
|
|
- device_capabilities = DeviceCapabilities(**message['device_capabilities'])
|
|
|
- if peer_id not in self.known_peers:
|
|
|
- self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
|
|
|
- if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
|
|
|
- self.peer_last_seen[peer_id] = time.time()
|
|
|
+ if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
|
|
|
+ transport.sendto(message, ('<broadcast>', self.broadcast_port))
|
|
|
+ await asyncio.sleep(self.broadcast_interval)
|
|
|
except Exception as e:
|
|
|
- print(f"Error in peer discovery: {e}")
|
|
|
+ print(f"Error in broadcast presence: {e}")
|
|
|
import traceback
|
|
|
print(traceback.format_exc())
|
|
|
- await asyncio.sleep(self.broadcast_interval / 2)
|
|
|
|
|
|
- async def _cleanup_peers(self):
|
|
|
+ async def on_listen_message(self, data, addr):
|
|
|
+ message = json.loads(data.decode('utf-8'))
|
|
|
+ if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")
|
|
|
+ if message['type'] == 'discovery' and message['node_id'] != self.node_id:
|
|
|
+ peer_id = message['node_id']
|
|
|
+ peer_host = addr[0]
|
|
|
+ peer_port = message['grpc_port']
|
|
|
+ device_capabilities = DeviceCapabilities(**message['device_capabilities'])
|
|
|
+ if peer_id not in self.known_peers:
|
|
|
+ self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
|
|
|
+ if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
|
|
|
+ self.peer_last_seen[peer_id] = time.time()
|
|
|
+
|
|
|
+ async def task_listen_for_peers(self):
|
|
|
+ await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
|
|
|
+ if DEBUG_DISCOVERY >= 2: print("Started listen task")
|
|
|
+
|
|
|
+ async def task_cleanup_peers(self):
|
|
|
while True:
|
|
|
current_time = time.time()
|
|
|
timeout = 15 * self.broadcast_interval
|