
fix embed_tokens for last layer in qwen models

Alex Cheema, 3 months ago
commit 9c1bea97e8
1 changed file with 1 addition and 1 deletion

+ 1 - 1
exo/inference/mlx/models/qwen2.py

@@ -31,7 +31,7 @@ class Qwen2Model(nn.Module):
     self.num_hidden_layers = args.num_hidden_layers
     assert self.vocab_size > 0
 
-    if self.args.shard.is_first_layer():
+    if self.args.shard.is_first_layer() or (self.args.shard.is_last_layer() and args.tie_word_embeddings):
       self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
 
     self.layers = []
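
Context for the change: when tie_word_embeddings is set, the model reuses the embedding matrix as the output projection instead of a separate lm_head, so a shard that only holds the final layers still needs embed_tokens instantiated to produce logits. Below is a minimal sketch of how a sharded forward pass might use this; the __call__ wiring and the lm_head fallback are illustrative assumptions in the style of mlx_lm, not code copied from the repo.

import mlx.core as mx
import mlx.nn as nn

class ShardedQwen2Sketch(nn.Module):
  def __init__(self, args, shard):
    super().__init__()
    self.args = args
    self.shard = shard  # shard descriptor exposing is_first_layer()/is_last_layer()

    # Embedding table is needed on the first shard (token lookup) and,
    # when weights are tied, on the last shard as well (output projection).
    if shard.is_first_layer() or (shard.is_last_layer() and args.tie_word_embeddings):
      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)

    # Without tied weights, the last shard needs a separate output head.
    if shard.is_last_layer() and not args.tie_word_embeddings:
      self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

  def __call__(self, x: mx.array) -> mx.array:
    if self.shard.is_first_layer():
      x = self.embed_tokens(x)  # token ids -> hidden states

    # ... the transformer layers owned by this shard would run here ...

    if self.shard.is_last_layer():
      if self.args.tie_word_embeddings:
        # Reuse the embedding matrix as the output projection.
        x = self.embed_tokens.as_linear(x)
      else:
        x = self.lm_head(x)
    return x

Before this fix, a last-layer shard with tie_word_embeddings enabled would have no embed_tokens module to project hidden states back to vocabulary logits.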