Browse Source

Move shard layer-filtering of safetensors weights from inference.py into the load() helper in tinygrad_helpers.py

Rory Clear 7 months ago
parent
commit
1d1fa8c608
2 changed files with 6 additions and 5 deletions
  1. 0 4
      exo/inference/tinygrad/inference.py
  2. 6 1
      exo/inference/tinygrad/tinygrad_helpers.py

+ 0 - 4
exo/inference/tinygrad/inference.py

@@ -13,7 +13,6 @@ from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
 from .stateful_model import StatefulModel
 import asyncio
-import re
 
 Tensor.no_grad = True
 # default settings
@@ -52,9 +51,6 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
     weights = load(str(model_path), shard)
   weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
   weights = fix_bf16(weights)
-  for k in list(weights):
-    if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
-        del weights[k]
   with Context(BEAM=0):
     # replace weights in model
     load_state_dict(model, weights, strict=False, consume=False)  # consume=True

+ 6 - 1
exo/inference/tinygrad/tinygrad_helpers.py

@@ -7,6 +7,7 @@ from exo.inference.shard import Shard
 from exo.helpers import DEBUG
 from exo.download.hf.hf_helpers import get_allow_patterns
 from fnmatch import fnmatch
+import re
 
 
 # **** helper functions ****
@@ -42,6 +43,10 @@ def load(fn: str, shard: Shard):
     if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
     return {k: parts[n][k] for k, n in filtered_weight_map.items()}
   elif fn.endswith(".safetensors"):
-    return safe_load(fn)
+    weight_map = safe_load(fn)
+    for k in list(weight_map):
+      if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
+          del weight_map[k]
+    return weight_map
   else:
     return torch_load(fn)