@@ -32,15 +32,15 @@ class LlamaModel(nn.Module):
     self.vocab_size = args.vocab_size
     self.num_hidden_layers = args.num_hidden_layers
     assert self.vocab_size > 0
-    if self.args.shard.is_first_layer():
+    if args.shard.is_first_layer() or (args.shard.is_last_layer() and args.tie_word_embeddings):
       self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
     self.layers = []
     for i in range(self.num_hidden_layers):
-      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+      if args.shard.start_layer <= i <= args.shard.end_layer:
         self.layers.append(TransformerBlock(args=args))
       else:
         self.layers.append(IdentityBlock())
-    if self.args.shard.is_last_layer():
+    if args.shard.is_last_layer():
       self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

   def __call__(
@@ -74,7 +74,7 @@ class Model(nn.Module):
     self.args = args
     self.model_type = args.model_type
     self.model = LlamaModel(args)
-    if self.args.shard.is_last_layer():
+    if args.shard.is_last_layer():
       if not args.tie_word_embeddings:
         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

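For context on the conditions above: each shard only instantiates the transformer blocks in its `[start_layer, end_layer]` range and fills the rest with `IdentityBlock` placeholders, while the first shard owns the embedding and the last shard owns the final norm and output projection. The sketch below is a minimal, hypothetical version of the shard helpers this diff relies on (the real `Shard` class in exo may differ in fields and details); it is only meant to illustrate what `is_first_layer()` / `is_last_layer()` decide.

```python
# Hypothetical sketch of the shard helpers referenced in the diff -- not exo's actual Shard class.
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
  start_layer: int  # first transformer block held by this shard (inclusive)
  end_layer: int    # last transformer block held by this shard (inclusive)
  n_layers: int     # total number of transformer blocks in the full model

  def is_first_layer(self) -> bool:
    # This shard should instantiate the token embedding.
    return self.start_layer == 0

  def is_last_layer(self) -> bool:
    # This shard should instantiate the final norm and the output projection.
    return self.end_layer == self.n_layers - 1
```

The reason the first `+` line also creates `embed_tokens` on the last shard is the second hunk: when `tie_word_embeddings` is set, no separate `lm_head` is built, so the last shard has to project hidden states back to vocabulary logits through the embedding weights themselves (e.g. via `embed_tokens.as_linear(...)` in MLX), and that requires `embed_tokens` to exist there.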