瀏覽代碼

keep 4 in RotatingKVCache

Alex Cheema 11 月之前
父節點
當前提交
5101f03369
共有 1 個文件被更改,包括 1 次插入和 1 次刪除
  1. 1 1
      exo/inference/mlx/sharded_model.py

+ 1 - 1
exo/inference/mlx/sharded_model.py

@@ -75,7 +75,7 @@ class StatefulShardedModel:
 
   def init_cache(self, request_id: str):
     kv_heads = [self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads
-    new_cache = [RotatingKVCache(self.model.head_dim, n, self.max_kv_size) for n in kv_heads]
+    new_cache = [RotatingKVCache(self.model.head_dim, n, self.max_kv_size, keep=4) for n in kv_heads]
 
     if len(self.caches) >= self.max_caches:
       self.caches.popitem(last=False)