@@ -1,18 +1,21 @@
 from typing import Dict, Generator, Optional, Tuple
+from collections import OrderedDict
 
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.base import RotatingKVCache
 from mlx_lm.sample_utils import top_p_sampling
 
 from ..shard import Shard
 
 
 class StatefulShardedModel:
-  def __init__(self, shard: Shard, model: nn.Module):
+  def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
     self.shard = shard
     self.model = model
-    self.request_cache: Dict[str, Tuple[str, KVCache]] = {}
+    self.max_kv_size = max_kv_size
+    self.max_caches = max_caches
+    self.caches = OrderedDict()
 
   def step(
     self,
@@ -41,13 +44,17 @@ class StatefulShardedModel:
 
     y = x
 
-    if request_id not in self.request_cache:
+    if request_id not in self.caches:
       self.init_cache(request_id)
+    else:
+      self.caches.move_to_end(request_id)
+
+    cache = self.caches[request_id]
 
     if pixel_values is None:
-      output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.request_cache[request_id])
+      output = self.model(y[None] if self.shard.is_first_layer() else y, cache=cache)
     else:
-      output = self.model(y, pixel_values=pixel_values, cache=self.request_cache[request_id])
+      output = self.model(y, pixel_values=pixel_values, cache=cache)
 
     if self.shard.is_last_layer():
       logits = output[:, -1, :]
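
The cache lookup in step() doubles as an LRU touch: move_to_end() marks the request's cache as most recently used, so the popitem(last=False) call in init_cache() (next hunk) always evicts the stalest request first. A minimal sketch of that OrderedDict LRU pattern in isolation, with illustrative names not taken from this diff:

from collections import OrderedDict

MAX_CACHES = 2  # mirrors the max_caches default above
caches = OrderedDict()

def get_or_create(request_id):
  if request_id not in caches:
    if len(caches) >= MAX_CACHES:
      caches.popitem(last=False)  # drop the least recently used entry
    caches[request_id] = []  # stand-in for a real per-layer KV cache list
  else:
    caches.move_to_end(request_id)  # mark as most recently used
  return caches[request_id]

get_or_create("a"); get_or_create("b")
get_or_create("a")  # touch "a", so "b" becomes the eviction candidate
get_or_create("c")
assert list(caches) == ["a", "c"]  # "b" was evicted
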
@@ -58,13 +65,19 @@ class StatefulShardedModel:
 
   def __call__(
     self,
+    request_id: str,
     x,
     temp: float = 0.0,
     top_p: float = 1.0,
     logit_bias: Optional[Dict[int, float]] = None,
   ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    return self.step(x, temp, top_p, logit_bias)
+    return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
 
   def init_cache(self, request_id: str):
     kv_heads = [self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads
-    self.request_cache[request_id] = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    new_cache = [RotatingKVCache(self.model.head_dim, n, self.max_kv_size) for n in kv_heads]
+
+    if len(self.caches) >= self.max_caches:
+      self.caches.popitem(last=False)
+
+    self.caches[request_id] = new_cache
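
Together the two new knobs bound KV memory per node: RotatingKVCache caps each request at max_kv_size positions (older entries are rotated out once the buffer fills), and max_caches caps how many request caches stay alive at once. A back-of-envelope worst case, assuming fp16 keys/values and hypothetical model dimensions chosen only for illustration:

# Hypothetical shapes for illustration; substitute the real model config.
max_caches = 2      # default from __init__ above
max_kv_size = 1024  # positions retained per RotatingKVCache
n_layers, n_kv_heads, head_dim = 32, 8, 128
bytes_per_element = 2  # fp16

# Keys and values each occupy (max_kv_size, n_kv_heads, head_dim) per layer.
per_request = 2 * n_layers * n_kv_heads * head_dim * max_kv_size * bytes_per_element
print(f"worst case: {max_caches * per_request / 2**20:.0f} MiB")  # 256 MiB for these shapes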
|