|
@@ -5,6 +5,8 @@ import json
|
|
|
from typing import List
|
|
|
from exo.inference.shard import Shard
|
|
|
from exo.helpers import DEBUG
|
|
|
+from exo.download.hf.hf_helpers import get_allow_patterns
|
|
|
+from fnmatch import fnmatch
|
|
|
|
|
|
# **** helper functions ****
|
|
|
def concat_weights(models, device=None):
|
|
@@ -22,7 +24,10 @@ def load(fn:str, shard: Shard):
|
|
|
with open(fn) as fp: weight_map = json.load(fp)['weight_map']
|
|
|
parts = {}
|
|
|
filtered_weight_map = {}
|
|
|
+ allow_patterns = get_allow_patterns(weight_map, shard)
|
|
|
for k, n in weight_map.items():
|
|
|
+ if allow_patterns is not None and not any(fnmatch(n, r) for r in allow_patterns):
|
|
|
+ continue
|
|
|
if k.startswith("model.layers."):
|
|
|
layer_num = int(k.split('.')[2])
|
|
|
if layer_num < shard.start_layer or layer_num > shard.end_layer:
|