@@ -51,6 +51,7 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   weights = load(str(model_path), shard)
   weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
   weights = fix_bf16(weights)
+
   with Context(BEAM=0):
     # replace weights in model
     load_state_dict(model, weights, strict=False, consume=False) # consume=True