Browse Source

remove some logs, move get_allow_patterns out of the class

Alex Cheema 1 year ago
parent
commit
357331c55f

+ 25 - 0
exo/download/hf/hf_helpers.py

@@ -11,6 +11,7 @@ from typing import Generator, Iterable, TypeVar, TypedDict
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
 from exo.helpers import DEBUG
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
+from exo.inference.shard import Shard
 
 T = TypeVar("T")
 def filter_repo_objects(
@@ -278,3 +279,27 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
         if part.isdigit():
             return int(part)
     return None
+
+
+def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
+    default_patterns = [
+        "*.json",
+        "*.py",
+        "tokenizer.model",
+        "*.tiktoken",
+        "*.txt",
+    ]
+    shard_specific_patterns = []
+    if weight_map:
+        for tensor_name, filename in weight_map.items():
+            layer_num = extract_layer_num(tensor_name)
+            if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
+                shard_specific_patterns.append(filename)
+        sorted_file_names = sorted(weight_map.values())
+        if shard.is_first_layer():
+            shard_specific_patterns.append(sorted_file_names[0])
+        elif shard.is_last_layer():
+            shard_specific_patterns.append(sorted_file_names[-1])
+    else:
+        shard_specific_patterns = ["*.safetensors"]
+    return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates

+ 2 - 21
exo/download/hf/hf_shard_download.py

@@ -5,7 +5,7 @@ from typing import Dict, List, Tuple
 from exo.inference.shard import Shard
 from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
-from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_repo_root, get_weight_map, extract_layer_num
+from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns
 from exo.helpers import AsyncCallbackSystem, DEBUG
 
 class HFShardDownloader(ShardDownloader):
@@ -57,7 +57,7 @@ class HFShardDownloader(ShardDownloader):
             self._on_progress.trigger_all(shard, event)
 
         weight_map = await get_weight_map(shard.model_id)
-        allow_patterns = self._get_allow_patterns(weight_map, shard.start_layer, shard.end_layer)
+        allow_patterns = get_allow_patterns(weight_map, shard)
 
         return await download_repo_files(
             repo_id=shard.model_id,
@@ -65,25 +65,6 @@ class HFShardDownloader(ShardDownloader):
             allow_patterns=allow_patterns
         )
 
-    @staticmethod
-    def _get_allow_patterns(weight_map: Dict[str, str], start_layer: int, end_layer: int) -> List[str]:
-        default_patterns = [
-            "*.json",
-            "*.py",
-            "tokenizer.model",
-            "*.tiktoken",
-            "*.txt",
-        ]
-        shard_specific_patterns = []
-        if weight_map:
-            for tensor_name, filename in weight_map.items():
-                layer_num = extract_layer_num(tensor_name)
-                if layer_num is not None and start_layer <= layer_num <= end_layer:
-                    shard_specific_patterns.append(filename)
-        else:
-            shard_specific_patterns = ["*.safetensors"]
-        return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates
-
     @property
     def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
         return self._on_progress

+ 2 - 7
exo/orchestration/standard_node.py

@@ -289,15 +289,14 @@ class StandardNode(Node):
 
   async def update_peers(self, wait_for_peers: int = 0) -> None:
     self.peers = await self.discovery.discover_peers(wait_for_peers)
-    if DEBUG >= 2: print(f"Starting with the following peers: {self.peers}")
-    if DEBUG >= 2: print("Connecting to new peers...")
     for peer in self.peers:
       is_connected = await peer.is_connected()
       if DEBUG >= 2 and is_connected:
         print(f"Already connected to {peer.id()}: {is_connected}")
       if not is_connected:
+        if DEBUG >= 2: print(f"Connecting to {peer.id()}...")
         await peer.connect()
-        if DEBUG >= 0: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})")
+        if DEBUG >= 1: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})")
 
   async def periodic_topology_collection(self, interval: int):
     while True:
@@ -308,9 +307,6 @@ class StandardNode(Node):
       except Exception as e:
         print(f"Error collecting topology: {e}")
 
-      if DEBUG >= 2: print("Topology collection task executed.")
-      if DEBUG >= 2: print(f"Current topology: {self.topology}")
-
   async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
     if request_id not in self.buffered_token_output:
       return None, False
@@ -330,7 +326,6 @@ class StandardNode(Node):
       next_topology.add_edge(self.id, peer.id())
 
       if peer.id() in prev_visited:
-        if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
         continue
 
       if max_depth <= 0: