Преглед изворног кода

update mlx-lm to 0.17.0; use LRU caches for kv_cache with RotatingKVCache to optimise memory. Fixes #158

Alex Cheema пре 8 месеци
родитељ
комит
cea9b48d24
2 измењена фајла са 22 додато и 9 уклоњено
  1. 21 8
      exo/inference/mlx/sharded_model.py
  2. 1 1
      setup.py

+ 21 - 8
exo/inference/mlx/sharded_model.py

@@ -1,18 +1,21 @@
 from typing import Dict, Generator, Optional, Tuple
+from collections import OrderedDict
 
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.base import RotatingKVCache
 from mlx_lm.sample_utils import top_p_sampling
 
 from ..shard import Shard
 
 
 class StatefulShardedModel:
-  def __init__(self, shard: Shard, model: nn.Module):
+  def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
     self.shard = shard
     self.model = model
-    self.request_cache: Dict[str, Tuple[str, KVCache]] = {}
+    self.max_kv_size = max_kv_size
+    self.max_caches = max_caches
+    self.caches = OrderedDict()
 
   def step(
     self,
@@ -41,13 +44,17 @@ class StatefulShardedModel:
 
     y = x
 
-    if request_id not in self.request_cache:
+    if request_id not in self.caches:
       self.init_cache(request_id)
+    else:
+      self.caches.move_to_end(request_id)
+
+    cache = self.caches[request_id]
 
     if pixel_values is None:
-      output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.request_cache[request_id])
+      output = self.model(y[None] if self.shard.is_first_layer() else y, cache=cache)
     else:
-      output = self.model(y, pixel_values=pixel_values, cache=self.request_cache[request_id])
+      output = self.model(y, pixel_values=pixel_values, cache=cache)
 
     if self.shard.is_last_layer():
       logits = output[:, -1, :]
@@ -58,13 +65,19 @@ class StatefulShardedModel:
 
   def __call__(
     self,
+    request_id: str,
     x,
     temp: float = 0.0,
     top_p: float = 1.0,
     logit_bias: Optional[Dict[int, float]] = None,
   ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    return self.step(x, temp, top_p, logit_bias)
+    return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
 
   def init_cache(self, request_id: str):
     kv_heads = [self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads
-    self.request_cache[request_id] = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    new_cache = [RotatingKVCache(self.model.head_dim, n, self.max_kv_size) for n in kv_heads]
+
+    if len(self.caches) >= self.max_caches:
+      self.caches.popitem(last=False)
+
+    self.caches[request_id] = new_cache

+ 1 - 1
setup.py

@@ -37,7 +37,7 @@ if sys.platform.startswith("darwin"):
     install_requires.extend(
         [
             "mlx==0.16.3",
-            "mlx-lm==0.16.1",
+            "mlx-lm==0.17.0",
         ]
     )