|
@@ -31,7 +31,7 @@ class Qwen2Model(nn.Module):
|
|
|
self.num_hidden_layers = args.num_hidden_layers
|
|
|
assert self.vocab_size > 0
|
|
|
|
|
|
- if self.args.shard.is_first_layer():
|
|
|
+ if self.args.shard.is_first_layer() or (self.args.shard.is_last_layer() and args.tie_word_embeddings):
|
|
|
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
|
|
|
|
|
self.layers = []
|