@@ -391,25 +391,19 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
 
 
 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
-  default_patterns = [
-    "*.json",
-    "*.py",
-    "tokenizer.model",
-    "*.tiktoken",
-    "*.txt",
-  ]
-  shard_specific_patterns = []
+  default_patterns = set(["*.json","*.py","tokenizer.model","*.tiktoken","*.txt"])
+  shard_specific_patterns = set()
   if weight_map:
     for tensor_name, filename in weight_map.items():
       layer_num = extract_layer_num(tensor_name)
       if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
-        shard_specific_patterns.append(filename)
+        shard_specific_patterns.add(filename)
     sorted_file_names = sorted(weight_map.values())
     if shard.is_first_layer():
-      shard_specific_patterns.append(sorted_file_names[0])
+      shard_specific_patterns.add(sorted_file_names[0])
     elif shard.is_last_layer():
-      shard_specific_patterns.append(sorted_file_names[-1])
+      shard_specific_patterns.add(sorted_file_names[-1])
   else:
-    shard_specific_patterns = ["*.safetensors"]
+    shard_specific_patterns = set(["*.safetensors"])
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
-  return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates
+  return list(default_patterns | shard_specific_patterns)
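
For context, a minimal sketch of the duplicate case the sets now absorb (the shard filename below is hypothetical): the same weight file can be added once per matching tensor and then again by the is_first_layer()/is_last_layer() branch, which is what the removed list(set(...)) pass at the end used to clean up.

    # Sketch only: "model-00001-of-00002.safetensors" is an illustrative filename.
    shard_specific_patterns = set()
    shard_specific_patterns.add("model-00001-of-00002.safetensors")  # added for a tensor in the shard's layer range
    shard_specific_patterns.add("model-00001-of-00002.safetensors")  # added again by the is_first_layer() branch
    default_patterns = set(["*.json","*.py","tokenizer.model","*.tiktoken","*.txt"])
    print(sorted(default_patterns | shard_specific_patterns))  # the filename appears once; the union deduplicates on its own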