
Also initialize embed_tokens if this is the last layer and tie_word_embeddings is true

Alex Cheema, 9 months ago
commit ad09b4b3d9
1 changed file with 4 additions and 4 deletions:
  exo/inference/mlx/models/llama.py

--- a/exo/inference/mlx/models/llama.py
+++ b/exo/inference/mlx/models/llama.py

@@ -32,15 +32,15 @@ class LlamaModel(nn.Module):
     self.vocab_size = args.vocab_size
     self.num_hidden_layers = args.num_hidden_layers
     assert self.vocab_size > 0
-    if self.args.shard.is_first_layer():
+    if args.shard.is_first_layer() or (args.shard.is_last_layer() and args.tie_word_embeddings):
       self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
     self.layers = []
     for i in range(self.num_hidden_layers):
-      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+      if args.shard.start_layer <= i <= args.shard.end_layer:
         self.layers.append(TransformerBlock(args=args))
       else:
         self.layers.append(IdentityBlock())
-    if self.args.shard.is_last_layer():
+    if args.shard.is_last_layer():
       self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 
   def __call__(
@@ -74,7 +74,7 @@ class Model(nn.Module):
     self.args = args
     self.model_type = args.model_type
     self.model = LlamaModel(args)
-    if self.args.shard.is_last_layer():
+    if args.shard.is_last_layer():
       if not args.tie_word_embeddings:
         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
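
For context on why the last shard also needs embed_tokens when tie_word_embeddings is true: with tied weights there is no separate lm_head, and the output projection reuses the embedding matrix. Below is a minimal sketch of that output path, assuming Model.__call__ follows the usual mlx-lm Llama pattern (this method body is not part of the diff above):

  # Sketch of the output path that motivates this change; assumed to mirror
  # mlx-lm's Llama Model.__call__ and not taken from this commit's diff.
  def __call__(self, inputs, cache=None):
    out = self.model(inputs, cache)
    if self.args.shard.is_last_layer():
      if self.args.tie_word_embeddings:
        # With tied weights the embedding matrix doubles as the output
        # projection, so the last shard must have created
        # self.model.embed_tokens in __init__ as well.
        out = self.model.embed_tokens.as_linear(out)
      else:
        out = self.lm_head(out)
    return out

Without the added condition in LlamaModel.__init__, a last-layer shard with tie_word_embeddings enabled would have neither lm_head nor embed_tokens, and the call to as_linear would fail.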