|
@@ -404,6 +404,6 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
|
|
elif shard.is_last_layer():
|
|
elif shard.is_last_layer():
|
|
shard_specific_patterns.add(sorted_file_names[-1])
|
|
shard_specific_patterns.add(sorted_file_names[-1])
|
|
else:
|
|
else:
|
|
- shard_specific_patterns = ["*.safetensors"]
|
|
|
|
|
|
+ shard_specific_patterns = set("*.safetensors")
|
|
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
|
|
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
|
|
return list(default_patterns | shard_specific_patterns)
|
|
return list(default_patterns | shard_specific_patterns)
|