
fix mistral nemo

Alex Cheema, 11 months ago
commit 417114fae4
2 files changed, 19 insertions, 12 deletions
  1. exo/inference/mlx/models/llama.py (+5 −8)
  2. exo/inference/mlx/sharded_model.py (+14 −4)

exo/inference/mlx/models/llama.py (+5 −8)

@@ -3,7 +3,7 @@ from dataclasses import dataclass, field
 import mlx.core as mx
 import mlx.nn as nn
 
-from mlx_lm.models.base import create_additive_causal_mask
+from mlx_lm.models.base import create_attention_mask
 from mlx_lm.models.llama import TransformerBlock, ModelArgs
 
 from ...shard import Shard
@@ -24,7 +24,6 @@ class ModelArgs(ModelArgs):
 
     self.shard = Shard(**self.shard)
 
-
 class LlamaModel(nn.Module):
   def __init__(self, args: ModelArgs):
     super().__init__()
@@ -40,7 +39,6 @@ class LlamaModel(nn.Module):
         self.layers.append(TransformerBlock(args=args))
       else:
         self.layers.append(IdentityBlock())
-
     if self.args.shard.is_last_layer():
       self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 
@@ -56,8 +54,7 @@ class LlamaModel(nn.Module):
 
     mask = None
     if h.shape[1] > 1:
-      mask = create_additive_causal_mask(h.shape[1], cache[0].offset if cache is not None else 0)
-      mask = mask.astype(h.dtype)
+      mask = create_attention_mask(h, cache)
 
     if cache is None:
       cache = [None] * len(self.layers)
@@ -69,13 +66,11 @@ class LlamaModel(nn.Module):
       h = self.norm(h)
     return h
 
-
 class Model(nn.Module):
   def __init__(self, args: ModelArgs):
     super().__init__()
     self.args = args
     self.model_type = args.model_type
-
     self.model = LlamaModel(args)
     if self.args.shard.is_last_layer():
       if not args.tie_word_embeddings:
@@ -121,7 +116,9 @@ class Model(nn.Module):
 
   @property
   def head_dim(self):
-    return self.args.hidden_size // self.args.num_attention_heads
+    return (
+      self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
+    )
 
   @property
   def n_kv_heads(self):

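In llama.py, the hand-rolled mask construction is replaced by mlx_lm's create_attention_mask, which reads the offset from the cache and casts to the hidden state's dtype itself, so the caller no longer touches cache[0].offset or calls .astype by hand. A minimal sketch of the new call, assuming the helper behaves like the two lines it replaces (the dummy shapes are illustrative):

import mlx.core as mx
from mlx_lm.models.base import create_attention_mask

# Dummy hidden states: (batch, seq_len, hidden); the values are irrelevant here.
h = mx.zeros((1, 8, 64), dtype=mx.float16)

# With no cache the helper builds a plain causal mask over the 8 new tokens;
# with a cache it folds in cache[0].offset, which the old code did by hand.
mask = create_attention_mask(h, cache=None)
assert mask.dtype == h.dtype  # the old explicit astype is now internal
assert mask.shape == (8, 8)   # causal mask over the seq_len positions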
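The head_dim property change is the actual Mistral Nemo fix: Nemo's config pins head_dim explicitly, and it does not equal hidden_size // num_attention_heads, so deriving it produced wrongly shaped KV caches. A quick arithmetic check using the published Nemo config values (quoted here from memory, so treat them as illustrative):

# Config values as published for Mistral Nemo (12B); illustrative only.
hidden_size = 5120
num_attention_heads = 32
head_dim = 128  # set explicitly in the config

derived = hidden_size // num_attention_heads  # 160: the old property's answer
effective = head_dim or derived               # 128: the new property's answer
assert (derived, effective) == (160, 128)

Models whose configs leave head_dim unset (None) still fall through to the derived value via the `or`.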
exo/inference/mlx/sharded_model.py (+14 −4)

@@ -3,7 +3,7 @@ from collections import OrderedDict
 
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_lm.models.base import RotatingKVCache
+from mlx_lm.models.base import KVCache, RotatingKVCache
 from mlx_lm.sample_utils import top_p_sampling
 
 from ..shard import Shard
@@ -74,10 +74,20 @@ class StatefulShardedModel:
     return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
 
   def init_cache(self, request_id: str):
-    kv_heads = [self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads
-    new_cache = [RotatingKVCache(self.model.head_dim, n, self.max_kv_size, keep=4) for n in kv_heads]
+    kv_heads = (
+      [self.model.n_kv_heads] * len(self.model.layers)
+      if isinstance(self.model.n_kv_heads, int)
+      else self.model.n_kv_heads
+    )
+    if self.max_kv_size is not None:
+      cache = [
+        RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4)
+        for n in kv_heads
+      ]
+    else:
+      cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
 
     if len(self.caches) >= self.max_caches:
       self.caches.popitem(last=False)
 
-    self.caches[request_id] = new_cache
+    self.caches[request_id] = cache
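
init_cache previously forced a RotatingKVCache even when no size bound was set; it now reserves the rotating (sliding-window) cache for the case where max_kv_size was actually requested, and otherwise uses the plain, unbounded KVCache. A standalone sketch of the same selection logic, using a hypothetical make_cache helper and assuming the mlx_lm of this era, where both classes still live in mlx_lm.models.base:

from mlx_lm.models.base import KVCache, RotatingKVCache

def make_cache(model, max_kv_size=None):
  # One cache entry per layer; n_kv_heads may be a single int or a per-layer list.
  kv_heads = (
    [model.n_kv_heads] * len(model.layers)
    if isinstance(model.n_kv_heads, int)
    else model.n_kv_heads
  )
  if max_kv_size is not None:
    # Bounded memory: a sliding window of max_kv_size entries that always
    # keeps the first 4 tokens (e.g. BOS / the start of the system prompt).
    return [
      RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
      for n in kv_heads
    ]
  # Unbounded: grows with the sequence but never evicts earlier context.
  return [KVCache(model.head_dim, n) for n in kv_heads]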