|
@@ -1,7 +1,5 @@
|
|
|
-from typing import Dict, Tuple
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
-import mlx.core as mx
|
|
|
import mlx.nn as nn
|
|
|
from mlx_lm.models.cache import make_prompt_cache
|
|
|
|
|
@@ -16,12 +14,6 @@ class StatefulModel(nn.Module):
|
|
|
self.caches = OrderedDict()
|
|
|
|
|
|
def init_cache(self, request_id: str):
|
|
|
- kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
|
|
|
- # if self.max_kv_size is not None:
|
|
|
- # cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
|
|
|
- # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
|
|
|
- # else:
|
|
|
- # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
|
|
|
cache = make_prompt_cache(self.model)
|
|
|
|
|
|
if len(self.caches) >= self.max_caches:
|
|
@@ -39,4 +31,3 @@ class StatefulModel(nn.Module):
|
|
|
|
|
|
y = self.model(x, cache=cache)
|
|
|
return y
|
|
|
-
|