Browse Source

ignore files that dont match allow patterns

Alex Cheema 8 months ago
parent
commit
3bd5a116df
1 changed files with 5 additions and 0 deletions
  1. 5 0
      exo/inference/tinygrad/tinygrad_helpers.py

+ 5 - 0
exo/inference/tinygrad/tinygrad_helpers.py

@@ -5,6 +5,8 @@ import json
 from typing import List
 from typing import List
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.helpers import DEBUG
 from exo.helpers import DEBUG
+from exo.download.hf.hf_helpers import get_allow_patterns
+from fnmatch import fnmatch
 
 
 # **** helper functions ****
 # **** helper functions ****
 def concat_weights(models, device=None):
 def concat_weights(models, device=None):
@@ -22,7 +24,10 @@ def load(fn:str, shard: Shard):
     with open(fn) as fp: weight_map = json.load(fp)['weight_map']
     with open(fn) as fp: weight_map = json.load(fp)['weight_map']
     parts = {}
     parts = {}
     filtered_weight_map = {}
     filtered_weight_map = {}
+    allow_patterns = get_allow_patterns(weight_map, shard)
     for k, n in weight_map.items():
     for k, n in weight_map.items():
+      if allow_patterns is not None and not any(fnmatch(n, r) for r in allow_patterns):
+        continue
       if k.startswith("model.layers."):
       if k.startswith("model.layers."):
         layer_num = int(k.split('.')[2])
         layer_num = int(k.split('.')[2])
         if layer_num < shard.start_layer or layer_num > shard.end_layer:
         if layer_num < shard.start_layer or layer_num > shard.end_layer: