Explorar o código

fix layer calculation for sharded llama

Alex Cheema hai 1 ano
pai
achega
6ee0547eff
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      inference/mlx/models/sharded_llama.py

+ 1 - 1
inference/mlx/models/sharded_llama.py

@@ -166,7 +166,7 @@ class LlamaModel(nn.Module):
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.layers = [
-            TransformerBlock(args=args) for _ in range(args.shard.n_layers)
+            TransformerBlock(args=args) for _ in range(args.shard.end_layer - args.shard.start_layer + 1)
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)