@@ -166,7 +166,7 @@ class LlamaModel(nn.Module):
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.layers = [
-            TransformerBlock(args=args) for _ in range(args.shard.n_layers)
+            TransformerBlock(args=args) for _ in range(args.shard.end_layer - args.shard.start_layer + 1)
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 
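For context: the old expression built `args.shard.n_layers` TransformerBlocks on every node (presumably the full model's layer count), while the new one builds only the layers this shard actually serves; since `start_layer` and `end_layer` are inclusive bounds, the count needs the `+ 1`. A minimal runnable sketch of that arithmetic, assuming a hypothetical `Shard` whose fields are inferred from the attribute names in the diff:

# Minimal sketch of the inclusive layer-count arithmetic above.
# This Shard definition is an assumption inferred from the attribute
# names in the diff, not the project's actual class.
from dataclasses import dataclass

@dataclass
class Shard:
    start_layer: int  # first layer held by this node, inclusive
    end_layer: int    # last layer held by this node, inclusive
    n_layers: int     # total layers in the full model

shard = Shard(start_layer=8, end_layer=15, n_layers=32)

# Old behaviour: every node built n_layers blocks (the whole model).
assert shard.n_layers == 32
# New behaviour: a node builds only its own slice; inclusive bounds
# mean the count is end - start + 1, here 8 blocks instead of 32.
assert shard.end_layer - shard.start_layer + 1 == 8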