Explorar o código

fix embed_tokens for last layer in qwen models

Alex Cheema hai 7 meses
pai
achega
9c1bea97e8
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      exo/inference/mlx/models/qwen2.py

+ 1 - 1
exo/inference/mlx/models/qwen2.py

@@ -31,7 +31,7 @@ class Qwen2Model(nn.Module):
     self.num_hidden_layers = args.num_hidden_layers
     assert self.vocab_size > 0
 
-    if self.args.shard.is_first_layer():
+    if self.args.shard.is_first_layer() or (self.args.shard.is_last_layer() and args.tie_word_embeddings):
       self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
 
     self.layers = []