1
0
Эх сурвалжийг харах

By default, find an ephemeral node port (fixes #35); more robust topology updates (fixes #15 and #14).

Alex Cheema 1 жил өмнө
parent
commit
35177690bd

+ 1 - 0
.gitignore

@@ -1,6 +1,7 @@
 __pycache__/
 __pycache__/
 .venv
 .venv
 test_weights.npz
 test_weights.npz
+.exo_used_ports
 
 
 # Byte-compiled / optimized / DLL files
 # Byte-compiled / optimized / DLL files
 __pycache__/
 __pycache__/

+ 0 - 3
exo/api/chatgpt_api.py

@@ -187,9 +187,6 @@ class ChatGPTAPI:
                     headers={
                     headers={
                         "Content-Type": "application/json",
                         "Content-Type": "application/json",
                         "Cache-Control": "no-cache",
                         "Cache-Control": "no-cache",
-                        # "Access-Control-Allow-Origin": "*",
-                        # "Access-Control-Allow-Methods": "*",
-                        # "Access-Control-Allow-Headers": "*",
                     }
                     }
                 )
                 )
                 await response.prepare(request)
                 await response.prepare(request)

+ 33 - 1
exo/helpers.py

@@ -1,6 +1,8 @@
 import os
 import os
 import asyncio
 import asyncio
 from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
 from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
+import socket
+import random
 
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -13,6 +15,36 @@ exo_text = """
  \___/_/\_\___/ 
  \___/_/\_\___/ 
     """
     """
 
 
def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
    """Find a free TCP port in ``[min_port, max_port]`` not recently handed out.

    Recently allocated ports are tracked in a ``.exo_used_ports`` file next to
    this module (the last 20 entries), so repeated startups avoid re-picking a
    port that may still be in use or lingering in TIME_WAIT.

    Args:
        host: Host/interface to test the bind against ('' binds all interfaces).
        min_port: Lower bound of the search range (default is the start of the
            IANA dynamic/ephemeral port range).
        max_port: Upper bound of the search range (inclusive).

    Returns:
        A port number that accepted a test bind.

    Raises:
        RuntimeError: If no port in the range could be bound.
    """
    used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')

    def read_used_ports():
        # Best-effort parse: skip any line that is not a bare integer.
        if os.path.exists(used_ports_file):
            with open(used_ports_file, 'r') as f:
                return [int(line.strip()) for line in f if line.strip().isdigit()]
        return []

    def write_used_port(port, used_ports):
        # Keep only the 19 most recent ports plus the new one (20 total),
        # so the file stays bounded.
        with open(used_ports_file, 'w') as f:
            for p in used_ports[-19:] + [port]:
                f.write(f"{p}\n")

    used_ports = read_used_ports()
    available_ports = set(range(min_port, max_port + 1)) - set(used_ports)

    while available_ports:
        port = random.choice(list(available_ports))
        if DEBUG >= 2: print(f"Trying to find available port {port=}")
        try:
            # A successful bind (socket is closed immediately by the context
            # manager) is the availability check.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind((host, port))
            write_used_port(port, used_ports)
            return port
        except socket.error:
            available_ports.remove(port)

    raise RuntimeError("No available ports in the specified range")
 
 
 def print_exo():
 def print_exo():
     print(exo_text)
     print(exo_text)
@@ -81,4 +113,4 @@ class AsyncCallbackSystem(Generic[K, T]):
 
 
     def trigger_all(self, *args: T) -> None:
     def trigger_all(self, *args: T) -> None:
         for callback in self.callbacks.values():
         for callback in self.callbacks.values():
-            callback.set(*args)
+            callback.set(*args)

+ 18 - 13
exo/networking/grpc/grpc_discovery.py

@@ -30,8 +30,7 @@ class GRPCDiscovery(Discovery):
         self.listen_port = listen_port
         self.listen_port = listen_port
         self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
         self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
         self.broadcast_interval = broadcast_interval
         self.broadcast_interval = broadcast_interval
-        self.known_peers: Dict[str, GRPCPeerHandle] = {}
-        self.peer_last_seen: Dict[str, float] = {}
+        self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float]] = {}
         self.broadcast_task = None
         self.broadcast_task = None
         self.listen_task = None
         self.listen_task = None
         self.cleanup_task = None
         self.cleanup_task = None
@@ -74,7 +73,7 @@ class GRPCDiscovery(Discovery):
                     if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
                     if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
                     break  # No new peers found in the grace period, we are done
                     break  # No new peers found in the grace period, we are done
 
 
-        return list(self.known_peers.values())
+        return [peer_handle for peer_handle, _ in self.known_peers.values()]
 
 
     async def task_broadcast_presence(self):
     async def task_broadcast_presence(self):
         transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
         transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
@@ -110,9 +109,9 @@ class GRPCDiscovery(Discovery):
             peer_port = message['grpc_port']
             peer_port = message['grpc_port']
             device_capabilities = DeviceCapabilities(**message['device_capabilities'])
             device_capabilities = DeviceCapabilities(**message['device_capabilities'])
             if peer_id not in self.known_peers:
             if peer_id not in self.known_peers:
-                self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+                self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time())
                 if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
                 if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
-            self.peer_last_seen[peer_id] = time.time()
+            self.known_peers[peer_id] = (self.known_peers[peer_id][0], time.time())
 
 
     async def task_listen_for_peers(self):
     async def task_listen_for_peers(self):
         await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
         await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
@@ -120,11 +119,17 @@ class GRPCDiscovery(Discovery):
 
 
     async def task_cleanup_peers(self):
     async def task_cleanup_peers(self):
         while True:
         while True:
-            current_time = time.time()
-            timeout = 15 * self.broadcast_interval
-            peers_to_remove = [peer_id for peer_id, last_seen in self.peer_last_seen.items() if current_time - last_seen > timeout]
-            for peer_id in peers_to_remove:
-                del self.known_peers[peer_id]
-                del self.peer_last_seen[peer_id]
-                if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
-            await asyncio.sleep(self.broadcast_interval)
+            try:
+                current_time = time.time()
+                timeout = 15 * self.broadcast_interval
+                peers_to_remove = [peer_handle.id() for peer_handle, last_seen in self.known_peers.values() if not await peer_handle.is_connected() or current_time - last_seen > timeout]
+                if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, last_seen={last_seen}" for peer_handle, last_seen in self.known_peers.values()})
+                if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: print(f"Cleaning up peers: {peers_to_remove}")
+                for peer_id in peers_to_remove:
+                    if peer_id in self.known_peers: del self.known_peers[peer_id]
+                    if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
+                await asyncio.sleep(self.broadcast_interval)
+            except Exception as e:
+                print(f"Error in cleanup peers: {e}")
+                import traceback
+                print(traceback.format_exc())

+ 3 - 1
exo/networking/grpc/grpc_server.py

@@ -17,7 +17,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
 
 
     async def start(self) -> None:
     async def start(self) -> None:
         self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10), options=[
         self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10), options=[
-            ('grpc.max_metadata_size', 32*1024*1024)
+            ('grpc.max_metadata_size', 32*1024*1024),
+            ('grpc.max_send_message_length', 128 * 1024 * 1024),
+            ('grpc.max_receive_message_length', 128 * 1024 * 1024),
         ])
         ])
         node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
         node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
         listen_addr = f'{self.host}:{self.port}'
         listen_addr = f'{self.host}:{self.port}'

+ 7 - 5
exo/orchestration/standard_node.py

@@ -84,7 +84,7 @@ class StandardNode(Node):
             if result.size == 1:  # we got a new token out
             if result.size == 1:  # we got a new token out
                 self.buffered_token_output[request_id][0].append(result.item())
                 self.buffered_token_output[request_id][0].append(result.item())
                 self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
                 self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-            if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
+            if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
 
 
             if not is_finished:
             if not is_finished:
                 asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
                 asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
@@ -179,7 +179,8 @@ class StandardNode(Node):
         return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
         return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
 
 
     async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
     async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
-        self.topology.update_node(self.id, self.device_capabilities)
+        next_topology = Topology()
+        next_topology.update_node(self.id, self.device_capabilities)
 
 
         if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")
         if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")
 
 
@@ -187,8 +188,8 @@ class StandardNode(Node):
         visited.update(p.id() for p in self.peers)
         visited.update(p.id() for p in self.peers)
 
 
         for peer in self.peers:
         for peer in self.peers:
-            self.topology.update_node(peer.id(), peer.device_capabilities())
-            self.topology.add_edge(self.id, peer.id())
+            next_topology.update_node(peer.id(), peer.device_capabilities())
+            next_topology.add_edge(self.id, peer.id())
 
 
             if peer.id() in prev_visited:
             if peer.id() in prev_visited:
                 if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
                 if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
@@ -205,7 +206,8 @@ class StandardNode(Node):
             except Exception as e:
             except Exception as e:
                 print(f"Error collecting topology from {peer.id()}: {e}")
                 print(f"Error collecting topology from {peer.id()}: {e}")
 
 
-        return self.topology
+        self.topology = next_topology
+        return next_topology
 
 
     # TODO: unify this and collect_topology as global actions
     # TODO: unify this and collect_topology as global actions
     async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:
     async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:

+ 6 - 2
main.py

@@ -11,13 +11,13 @@ from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
 from exo.api import ChatGPTAPI
-from exo.helpers import print_yellow_exo
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG
 
 
 # parse args
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
 parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
-parser.add_argument("--node-port", type=int, default=8080, help="Node port")
+parser.add_argument("--node-port", type=int, default=None, help="Node port")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
@@ -49,6 +49,10 @@ else:
         raise ValueError(f"Inference engine {args.inference_engine} not supported")
         raise ValueError(f"Inference engine {args.inference_engine} not supported")
 print(f"Using inference engine {inference_engine.__class__.__name__}")
 print(f"Using inference engine {inference_engine.__class__.__name__}")
 
 
+
+if args.node_port is None:
+    args.node_port = find_available_port(args.node_host)
+    if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy())
 node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy())
 server = GRPCServer(node, args.node_host, args.node_port)
 server = GRPCServer(node, args.node_host, args.node_port)