Kaynağa Gözat

keep 4 in RotatingKVCache

Alex Cheema 11 ay önce
ebeveyn
işleme
5101f03369
1 değiştirilmiş dosya ile 1 ekleme ve 1 silme
  1. 1 1
      exo/inference/mlx/sharded_model.py

+ 1 - 1
exo/inference/mlx/sharded_model.py

@@ -75,7 +75,7 @@ class StatefulShardedModel:
 
   def init_cache(self, request_id: str):
     kv_heads = [self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads
-    new_cache = [RotatingKVCache(self.model.head_dim, n, self.max_kv_size) for n in kv_heads]
+    new_cache = [RotatingKVCache(self.model.head_dim, n, self.max_kv_size, keep=4) for n in kv_heads]
 
     if len(self.caches) >= self.max_caches:
       self.caches.popitem(last=False)