Bläddra i källkod

remove unused layers

Rory Clear 8 månader sedan
förälder
incheckning
2fdda5174d
1 ändrade filer med 4 tillägg och 1 borttagningar
  1. 4 1
      exo/inference/tinygrad/inference.py

+ 4 - 1
exo/inference/tinygrad/inference.py

@@ -13,6 +13,7 @@ from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
 from .stateful_model import StatefulModel
 import asyncio
+import re
 
 Tensor.no_grad = True
 # default settings
@@ -51,7 +52,9 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
     weights = load(str(model_path), shard)
   weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
   weights = fix_bf16(weights)
-
+  for k in list(weights):
+    if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
+        del weights[k]
   with Context(BEAM=0):
     # replace weights in model
     load_state_dict(model, weights, strict=False, consume=False)  # consume=True