Jelajahi Sumber

switch to uvloop (faster asyncio event loop) and optimise grpc settings

Alex Cheema 4 bulan lalu
induk
melakukan
0a07223074

+ 4 - 4
exo/inference/mlx/sharded_inference_engine.py

@@ -5,7 +5,7 @@ from mlx_lm.sample_utils import top_p_sampling, make_sampler
 import mlx.optimizers as optim
 from ..inference_engine import InferenceEngine
 from .sharded_utils import load_shard, get_image_from_str
-from .losses import loss_fns 
+from .losses import loss_fns
 from ..shard import Shard
 from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
@@ -56,7 +56,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
   async def load_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
     self.model.load_weights(path)
-    
+
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
     state = await self.poll_state(request_id)
@@ -102,7 +102,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
 
     score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
     #print(f"{score=}")
-      
+
     layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
     #print(layers[0])
 
@@ -117,7 +117,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     if self.shard != shard:
       model_shard, self.tokenizer = await load_shard(model_path, shard)
       self.shard = shard
-      self.model = model_shard 
+      self.model = model_shard
       self.caches = OrderedDict()
       self.session = {}
 

+ 54 - 18
exo/main.py

@@ -13,7 +13,6 @@ import uuid
 import numpy as np
 from functools import partial
 from tqdm import tqdm
-from tqdm.asyncio import tqdm_asyncio
 from exo.train.dataset import load_dataset, iterate_batches, compose
 from exo.networking.manual.manual_discovery import ManualDiscovery
 from exo.networking.manual.network_topology_config import NetworkTopology
@@ -33,6 +32,41 @@ from exo.inference.tokenizers import resolve_tokenizer
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
 from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
+import uvloop
+from contextlib import asynccontextmanager
+import concurrent.futures
+import socket
+import resource
+import psutil
+
+# Configure uvloop for maximum performance
+def configure_uvloop():
+    # Install uvloop as event loop policy
+    uvloop.install()
+
+    # Create new event loop
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+
+    # Increase file descriptor limits on Unix systems
+    if not psutil.WINDOWS:
+      soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+      try:
+          resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
+      except ValueError:
+        try:
+          resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
+        except ValueError:
+          pass
+
+    # Configure thread pool for blocking operations
+    loop.set_default_executor(
+      concurrent.futures.ThreadPoolExecutor(
+        max_workers=min(32, (os.cpu_count() or 1) * 4)
+      )
+    )
+
+    return loop
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -223,7 +257,7 @@ def clean_path(path):
 async def hold_outstanding(node: Node):
   while node.outstanding_requests:
     await asyncio.sleep(.5)
-  return 
+  return
 
 async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
   losses = []
@@ -234,7 +268,7 @@ async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
     tokens.append(np.sum(lengths))
   total_tokens = np.sum(tokens)
   total_loss = np.sum(losses) / total_tokens
-  
+
   return total_loss, total_tokens
 
 async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
@@ -270,7 +304,7 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
       await hold_outstanding(node)
   await hold_outstanding(node)
 
-  
+
 async def main():
   loop = asyncio.get_running_loop()
 
@@ -285,7 +319,7 @@ async def main():
           {"❌ No read access" if not has_read else ""}
           {"❌ No write access" if not has_write else ""}
           """)
-    
+
   if not args.models_seed_dir is None:
     try:
       models_seed_dir = clean_path(args.models_seed_dir)
@@ -330,29 +364,31 @@ async def main():
         print("Error: This train ain't leaving the station without a model")
         return
       await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
-    
+
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
     await asyncio.Event().wait()
-  
+
   if args.wait_for_peers > 0:
     print("Cooldown to allow peers to exit gracefully")
     for i in tqdm(range(50)):
       await asyncio.sleep(.1)
 
+@asynccontextmanager
+async def setup_node(args):
+    # Rest of setup_node implementation...
+    pass
 
 def run():
-  loop = asyncio.new_event_loop()
-  asyncio.set_event_loop(loop)
-  try:
-    loop.run_until_complete(main())
-      
-  except KeyboardInterrupt:
-    print("Received keyboard interrupt. Shutting down...")
-  finally:
-    loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
-    loop.close()
-
+    loop = None
+    try:
+        loop = configure_uvloop()
+        loop.run_until_complete(main())
+    except KeyboardInterrupt:
+        print("\nShutdown requested... exiting")
+    finally:
+        if loop:
+            loop.close()
 
 if __name__ == "__main__":
   run()

+ 25 - 6
exo/networking/grpc/grpc_peer_handle.py

@@ -21,6 +21,19 @@ class GRPCPeerHandle(PeerHandle):
     self._device_capabilities = device_capabilities
     self.channel = None
     self.stub = None
+    self.channel_options = [
+      ("grpc.max_metadata_size", 64 * 1024 * 1024),
+      ("grpc.max_receive_message_length", 256 * 1024 * 1024),
+      ("grpc.max_send_message_length", 256 * 1024 * 1024),
+      ("grpc.max_concurrent_streams", 100),
+      ("grpc.http2.min_time_between_pings_ms", 10000),
+      ("grpc.keepalive_time_ms", 20000),
+      ("grpc.keepalive_timeout_ms", 10000),
+      ("grpc.keepalive_permit_without_calls", 1),
+      ("grpc.http2.max_pings_without_data", 0),
+      ("grpc.tcp_nodelay", 1),
+      ("grpc.optimization_target", "throughput"),
+    ]
 
   def id(self) -> str:
     return self._id
@@ -36,11 +49,11 @@ class GRPCPeerHandle(PeerHandle):
 
   async def connect(self):
     if self.channel is None:
-      self.channel = grpc.aio.insecure_channel(self.address, options=[
-        ("grpc.max_metadata_size", 32*1024*1024),
-        ('grpc.max_receive_message_length', 32*1024*1024),
-        ('grpc.max_send_message_length', 32*1024*1024)
-      ])
+      self.channel = grpc.aio.insecure_channel(
+        self.address,
+        options=self.channel_options,
+        compression=grpc.Compression.Gzip
+      )
       self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
     await self.channel.channel_ready()
 
@@ -54,7 +67,13 @@ class GRPCPeerHandle(PeerHandle):
     self.stub = None
 
   async def _ensure_connected(self):
-    if not await self.is_connected(): await asyncio.wait_for(self.connect(), timeout=5)
+    if not await self.is_connected():
+      try:
+        await asyncio.wait_for(self.connect(), timeout=10.0)
+      except asyncio.TimeoutError:
+        if DEBUG >= 2: print(f"Connection timeout for {self._id}@{self.address}")
+        await self.disconnect()
+        raise
 
   async def health_check(self) -> bool:
     try:

+ 13 - 10
exo/networking/udp/udp_discovery.py

@@ -31,7 +31,7 @@ class BroadcastProtocol(asyncio.DatagramProtocol):
   def connection_made(self, transport):
     sock = transport.get_extra_info("socket")
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
-    transport.sendto(self.message.encode("utf-8"), ("<broadcast>", self.broadcast_port))
+    transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port))
 
 
 class UDPDiscovery(Discovery):
@@ -84,11 +84,7 @@ class UDPDiscovery(Discovery):
     return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]
 
   async def task_broadcast_presence(self):
-    if DEBUG_DISCOVERY >= 2: print("Starting task_broadcast_presence...")
-
     while True:
-      # Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
-      # the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
       for addr, interface_name in get_all_ip_addresses_and_interfaces():
         interface_priority, interface_type = await get_interface_priority_and_type(interface_name)
         message = json.dumps({
@@ -96,16 +92,23 @@ class UDPDiscovery(Discovery):
           "node_id": self.node_id,
           "grpc_port": self.node_port,
           "device_capabilities": self.device_capabilities.to_dict(),
-          "priority": interface_priority, # TODO: Prioritise interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "priority": interface_priority,
           "interface_name": interface_name,
           "interface_type": interface_type,
         })
-        if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority}): {message}")
 
         transport = None
         try:
-          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
-          if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority})")
+          # Create socket with explicit broadcast permission
+          sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+          sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+          sock.bind((addr, 0))
+          
+          # Create transport with the pre-configured socket
+          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
+            lambda: BroadcastProtocol(message, self.broadcast_port),
+            sock=sock
+          )
         except Exception as e:
           print(f"Error in broadcast presence ({addr} - {interface_name} - {interface_priority}): {e}")
         finally:
@@ -113,7 +116,7 @@ class UDPDiscovery(Discovery):
             try: transport.close()
             except Exception as e:
               if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
-              if DEBUG_DISCOVERY >= 2: traceback.print_exc()
+
       await asyncio.sleep(self.broadcast_interval)
 
   async def on_listen_message(self, data, addr):

+ 1 - 0
setup.py

@@ -27,6 +27,7 @@ install_requires = [
   "tqdm==4.66.4",
   "transformers==4.46.3",
   "uuid==1.30",
+  "uvloop==0.21.0",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
 ]