@@ -13,7 +13,6 @@ from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
 from .stateful_model import StatefulModel
 import asyncio
-import re
 
 Tensor.no_grad = True
 # default settings
@@ -52,9 +51,6 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
   weights = load(str(model_path), shard)
   weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
   weights = fix_bf16(weights)
-  for k in list(weights):
-    if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
-      del weights[k]
   with Context(BEAM=0):
     # replace weights in model
     load_state_dict(model, weights, strict=False, consume=False) # consume=True
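
For context on the second hunk: the dropped loop was a regex-based post-filter that deleted every weight whose layer index fell outside the shard's [start_layer, end_layer] range; presumably it is redundant here because load(str(model_path), shard) already receives the shard and can skip out-of-range layers at load time. Below is a minimal standalone sketch of the removed filtering step; the helper name and the plain-dict signature are illustrative, not part of the codebase.

import re

# Illustrative re-implementation of the removed filter. A key such as
# "model.layers.17.self_attn.q_proj.weight" carries its layer index in
# the first ".<digits>." group; keys with no such group (e.g. embeddings,
# final norm, output head) are always kept.
def filter_weights_to_shard(weights, start_layer, end_layer):
  for k in list(weights):
    n = re.search(r"\.(\d+)\.", k)
    if n and not (start_layer <= int(n.group(1)) <= end_layer):
      del weights[k]  # weight belongs to a layer outside this shard
  return weights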