|
@@ -3,7 +3,7 @@ from collections import OrderedDict
|
|
|
|
|
|
import mlx.core as mx
|
|
|
import mlx.nn as nn
|
|
|
-from mlx_lm.models.base import KVCache, RotatingKVCache
|
|
|
+from mlx_lm.models.cache import make_prompt_cache
|
|
|
from mlx_lm.sample_utils import top_p_sampling
|
|
|
|
|
|
from ..shard import Shard
|
|
@@ -76,11 +76,7 @@ class StatefulShardedModel:
   def init_cache(self, request_id: str):
     kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
-    if self.max_kv_size is not None:
-      # cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
-      cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
-    else:
-      cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    cache = make_prompt_cache(self.model)
|
|
|
|
|
|
if len(self.caches) >= self.max_caches:
|
|
|
self.caches.popitem(last=False)
|