Browse Source

use set for shard specific patterns

Alex Cheema 9 months ago
parent
commit
11dd952d26
1 changed file with 6 additions and 12 deletions
  1. 6 12
      exo/download/hf/hf_helpers.py

+ 6 - 12
exo/download/hf/hf_helpers.py

@@ -391,25 +391,19 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
 
 
 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
-  default_patterns = [
-    "*.json",
-    "*.py",
-    "tokenizer.model",
-    "*.tiktoken",
-    "*.txt",
-  ]
-  shard_specific_patterns = []
+  default_patterns = {"*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"}
+  shard_specific_patterns = set()
   if weight_map:
     for tensor_name, filename in weight_map.items():
       layer_num = extract_layer_num(tensor_name)
       if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
-        shard_specific_patterns.append(filename)
+        shard_specific_patterns.add(filename)
     sorted_file_names = sorted(weight_map.values())
     if shard.is_first_layer():
-      shard_specific_patterns.append(sorted_file_names[0])
+      shard_specific_patterns.add(sorted_file_names[0])
     elif shard.is_last_layer():
-      shard_specific_patterns.append(sorted_file_names[-1])
+      shard_specific_patterns.add(sorted_file_names[-1])
   else:
     shard_specific_patterns = ["*.safetensors"]
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
-  return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates
+  return list(default_patterns | set(shard_specific_patterns))  # set(): the no-weight_map fallback is a list, and set | list raises TypeError