|
@@ -404,8 +404,6 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
|
|
|
elif shard.is_last_layer():
|
|
|
shard_specific_patterns.add(sorted_file_names[-1])
|
|
|
else:
|
|
|
- shard_specific_patterns = ["*.safetensors"]
|
|
|
+ shard_specific_patterns = set("*.safetensors")
|
|
|
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
|
|
|
- allowed_patterns = list(default_patterns)
|
|
|
- allowed_patterns.extend(shard_specific_patterns)
|
|
|
- return list(set(allowed_patterns))
|
|
|
+ return list(default_patterns | shard_specific_patterns)
|