ソースを参照

fix layer calculation for sharded llama

Alex Cheema 1 年間 前
コミット
6ee0547eff
1 ファイル変更1 行追加1 行削除
  1. 1 1
      inference/mlx/models/sharded_llama.py

+ 1 - 1
inference/mlx/models/sharded_llama.py

@@ -166,7 +166,7 @@ class LlamaModel(nn.Module):
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.layers = [
-            TransformerBlock(args=args) for _ in range(args.shard.n_layers)
+            TransformerBlock(args=args) for _ in range(args.shard.end_layer - args.shard.start_layer + 1)
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)