|
@@ -77,7 +77,8 @@ class StatefulShardedModel:
|
|
|
def init_cache(self, request_id: str):
|
|
|
kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
|
|
|
if self.max_kv_size is not None:
|
|
|
- cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
|
|
|
+ # cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
|
|
|
+ cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
|
|
|
else:
|
|
|
cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
|
|
|
|