
logs for file filtering, grpc_discovery -> udp_discovery

Alex Cheema 9 months ago
parent
commit
65e0488ebe

+ 2 - 0
exo/download/hf/hf_helpers.py

@@ -235,6 +235,7 @@ async def download_repo_files(repo_id: str, revision: str = "main", progress_cal
             if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
 
         filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
+        if DEBUG >= 2: print(f"Filtered file list {allow_patterns=} {ignore_patterns=}\noriginal: {file_list}\nfiltered: {filtered_file_list}")
         total_files = len(filtered_file_list)
         total_bytes = sum(file["size"] for file in filtered_file_list)
         file_progress: Dict[str, RepoFileProgressEvent] = {file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
@@ -353,4 +354,5 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
             shard_specific_patterns.append(sorted_file_names[-1])
     else:
         shard_specific_patterns = ["*.safetensors"]
+    if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
     return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates

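For context on what the new DEBUG >= 2 line prints, here is a minimal standalone sketch of the same filtering step, assuming huggingface_hub's filter_repo_objects helper (which the diff calls with key=lambda x: x["path"]). The file entries and patterns below are invented for illustration, not taken from the repo.

# Hypothetical sketch of the filtering step that the new debug line logs.
from huggingface_hub.utils import filter_repo_objects

file_list = [
  {"path": "config.json", "size": 700},
  {"path": "model-00001-of-00002.safetensors", "size": 5_000_000_000},
  {"path": "model-00002-of-00002.safetensors", "size": 5_000_000_000},
  {"path": "pytorch_model.bin", "size": 10_000_000_000},
]
allow_patterns = ["*.json", "model-00001-of-00002.safetensors"]

filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
print(f"Filtered file list {allow_patterns=}\noriginal: {file_list}\nfiltered: {filtered_file_list}")
# Keeps config.json and the first safetensors shard; pytorch_model.bin is dropped.
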
+ 1 - 1
exo/download/hf/hf_shard_download.py

@@ -41,7 +41,7 @@ class HFShardDownloader(ShardDownloader):
             try:
                 await task
             except asyncio.CancelledError:
-                pass  # This is expected when cancelling a task
+                pass
             except Exception as e:
                 if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
                 traceback.print_exc()

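The comment removed above ("This is expected when cancelling a task") refers to a standard asyncio idiom: awaiting a task after cancel() raises CancelledError in the awaiter, so the handler swallows it. A self-contained sketch of that idiom, not taken from the repo's download code:

import asyncio

async def long_download():
  await asyncio.sleep(3600)  # stand-in for a real download

async def main():
  task = asyncio.create_task(long_download())
  await asyncio.sleep(0)  # let the task start
  task.cancel()
  try:
    await task
  except asyncio.CancelledError:
    pass  # expected when cancelling a task
  except Exception as e:
    print(f"Error in cancelling download: {e}")

asyncio.run(main())
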
+ 4 - 4
exo/networking/grpc/test_grpc_discovery.py

@@ -1,12 +1,12 @@
 import asyncio
 import unittest
-from .grpc_discovery import GRPCDiscovery
+from ..udp_discovery import UDPDiscovery
 
 
-class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase):
+class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
-    self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
-    self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)
+    self.node1 = UDPDiscovery("node1", 50051, 5678, 5679)
+    self.node2 = UDPDiscovery("node2", 50052, 5679, 5678)
     await self.node1.start()
     await self.node2.start()
 

+ 3 - 4
exo/networking/peer_handle.py

@@ -5,7 +5,6 @@ from exo.inference.shard import Shard
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.topology.topology import Topology
 
-
 class PeerHandle(ABC):
   @abstractmethod
   def id(self) -> str:
@@ -36,13 +35,13 @@ class PeerHandle(ABC):
     pass
 
   @abstractmethod
-  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
     pass
 
   @abstractmethod
-  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
     pass
 
   @abstractmethod
-  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
     pass

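The peer_handle.py hunks only reorder the abstract method declarations; no signatures change. A hypothetical caller-side sketch using the three methods in their new order (signatures come from the diff; the peer object, request id, and values are invented):

async def poll_peer(peer) -> None:  # "peer" is any concrete PeerHandle implementation
  await peer.send_result("req-123", result=[42, 7], is_finished=True)
  tensor, is_finished = await peer.get_inference_result("req-123")
  topology = await peer.collect_topology(visited=set(), max_depth=4)
  print(tensor, is_finished, topology)
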
+ 5 - 8
exo/networking/grpc/grpc_discovery.py → exo/networking/udp_discovery.py

@@ -2,10 +2,11 @@ import asyncio
 import json
 import socket
 import time
+import traceback
 from typing import List, Dict, Callable, Tuple, Coroutine
-from ..discovery import Discovery
-from ..peer_handle import PeerHandle
-from .grpc_peer_handle import GRPCPeerHandle
+from .discovery import Discovery
+from .peer_handle import PeerHandle
+from .grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo import DEBUG_DISCOVERY
 
@@ -23,7 +24,7 @@ class ListenProtocol(asyncio.DatagramProtocol):
     asyncio.create_task(self.on_message(data, addr))
 
 
-class GRPCDiscovery(Discovery):
+class UDPDiscovery(Discovery):
   def __init__(
     self,
     node_id: str,
@@ -114,8 +115,6 @@ class GRPCDiscovery(Discovery):
         await asyncio.sleep(self.broadcast_interval)
       except Exception as e:
         print(f"Error in broadcast presence: {e}")
-        import traceback
-
         print(traceback.format_exc())
 
   async def on_listen_message(self, data, addr):
@@ -185,6 +184,4 @@ class GRPCDiscovery(Discovery):
         await asyncio.sleep(self.broadcast_interval)
       except Exception as e:
         print(f"Error in cleanup peers: {e}")
-        import traceback
-
         print(traceback.format_exc())

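After the rename, the class is importable from exo.networking.udp_discovery. A minimal usage sketch, mirroring the constructor arguments seen in the updated test and in main.py; stop() is assumed to exist alongside start(), and the five-second sleep is an arbitrary window for the UDP broadcasts:

import asyncio
from exo.networking.udp_discovery import UDPDiscovery

async def main():
  # Mirrored listen/broadcast ports so each node hears the other's presence messages.
  node1 = UDPDiscovery("node1", 50051, 5678, 5679)
  node2 = UDPDiscovery("node2", 50052, 5679, 5678)
  await node1.start()
  await node2.start()
  await asyncio.sleep(5)  # arbitrary window for broadcasts to arrive
  await node1.stop()      # assumed counterpart to start()
  await node2.stop()

asyncio.run(main())
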
+ 2 - 2
main.py

@@ -6,7 +6,7 @@ import time
 import traceback
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
-from exo.networking.grpc.grpc_discovery import GRPCDiscovery
+from exo.networking.udp_discovery import UDPDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent
@@ -48,7 +48,7 @@ if args.node_port is None:
     if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
 args.node_id = args.node_id or get_or_create_node_id()
-discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
+discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
 chatgpt_api_endpoints=[f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
 web_chat_urls=[f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
 if DEBUG >= 0: