@@ -3,7 +3,7 @@ from dataclasses import dataclass, field
 import mlx.core as mx
 import mlx.nn as nn
 
-from mlx_lm.models.base import create_additive_causal_mask
+from mlx_lm.models.base import create_attention_mask
 from mlx_lm.models.llama import TransformerBlock, ModelArgs
 
 from ...shard import Shard
@@ -24,7 +24,6 @@ class ModelArgs(ModelArgs):
 
     self.shard = Shard(**self.shard)
 
-
 class LlamaModel(nn.Module):
   def __init__(self, args: ModelArgs):
     super().__init__()
@@ -40,7 +39,6 @@ class LlamaModel(nn.Module):
         self.layers.append(TransformerBlock(args=args))
       else:
         self.layers.append(IdentityBlock())
-
     if self.args.shard.is_last_layer():
       self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 
@@ -56,8 +54,7 @@ class LlamaModel(nn.Module):
 
     mask = None
     if h.shape[1] > 1:
-      mask = create_additive_causal_mask(h.shape[1], cache[0].offset if cache is not None else 0)
-      mask = mask.astype(h.dtype)
+      mask = create_attention_mask(h, cache)
 
     if cache is None:
       cache = [None] * len(self.layers)
@@ -69,13 +66,11 @@ class LlamaModel(nn.Module):
       h = self.norm(h)
     return h
 
-
 class Model(nn.Module):
   def __init__(self, args: ModelArgs):
     super().__init__()
     self.args = args
     self.model_type = args.model_type
-
     self.model = LlamaModel(args)
     if self.args.shard.is_last_layer():
       if not args.tie_word_embeddings:
@@ -121,7 +116,9 @@ class Model(nn.Module):
 
   @property
   def head_dim(self):
-    return self.args.hidden_size // self.args.num_attention_heads
+    return (
+      self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
+    )
 
   @property
   def n_kv_heads(self):
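
The two behavioural changes in this patch are the mask-helper migration and the head_dim fallback; the rest is whitespace cleanup. Below is a minimal sketch of how the updated pieces behave. It assumes an mlx_lm version that exposes create_attention_mask in mlx_lm.models.base; the _Args dataclass and the literal dimensions are illustrative placeholders, not part of the patch.

# Minimal sketch (illustrative only), 2-space indent to match the repo style.
import mlx.core as mx
from dataclasses import dataclass
from typing import Optional
from mlx_lm.models.base import create_attention_mask

# 1) Mask creation: the newer helper takes the hidden states and the KV cache and
#    derives the sequence length, cache offset, and dtype itself, replacing the old
#    create_additive_causal_mask(h.shape[1], offset) followed by mask.astype(h.dtype).
h = mx.zeros((1, 8, 4096))  # (batch, seq_len, hidden_size) placeholder activations
mask = create_attention_mask(h, None) if h.shape[1] > 1 else None

# 2) head_dim: the property now prefers an explicit head_dim from the config and only
#    falls back to hidden_size // num_attention_heads when head_dim is unset.
@dataclass
class _Args:  # hypothetical stand-in for ModelArgs
  hidden_size: int = 4096
  num_attention_heads: int = 32
  head_dim: Optional[int] = None

args = _Args()
head_dim = args.head_dim or args.hidden_size // args.num_attention_heads  # -> 128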