|
@@ -75,7 +75,7 @@ class StatefulShardedModel:
|
|
|
|
|
|
def init_cache(self, request_id: str):
|
|
|
kv_heads = [self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads
|
|
|
- new_cache = [RotatingKVCache(self.model.head_dim, n, self.max_kv_size) for n in kv_heads]
|
|
|
+ new_cache = [RotatingKVCache(self.model.head_dim, n, self.max_kv_size, keep=4) for n in kv_heads]
|
|
|
|
|
|
if len(self.caches) >= self.max_caches:
|
|
|
self.caches.popitem(last=False)
|