
update mlx to 0.19.3, mlx-lm to 0.19.2

Alex Cheema 6 months ago
parent
commit
bc7acfd37b

+ 1 - 1
exo/inference/mlx/models/base.py

@@ -1,7 +1,7 @@
 from typing import Optional
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.cache import KVCache
 
 
 class IdentityBlock(nn.Module):
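
Side note: in mlx-lm 0.19.x the KV cache classes were moved out of mlx_lm.models.base into a dedicated cache module, which is why every import below changes the same way. A minimal sketch of the new import path (assuming the mlx-lm 0.19.2 pin from setup.py at the bottom of this diff):

  # mlx-lm >= 0.19: cache classes live in mlx_lm.models.cache
  from mlx_lm.models.cache import KVCache, RotatingKVCache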

+ 1 - 1
exo/inference/mlx/models/deepseek_v2.py

@@ -4,7 +4,7 @@ from typing import Optional
 import mlx.core as mx
 import mlx.nn as nn
 
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.cache import KVCache
 from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
 from .base import IdentityBlock
 from exo.inference.shard import Shard

+ 6 - 5
exo/inference/mlx/sharded_model.py

@@ -3,7 +3,7 @@ from collections import OrderedDict
 
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_lm.models.base import KVCache, RotatingKVCache
+from mlx_lm.models.cache import make_prompt_cache
 from mlx_lm.sample_utils import top_p_sampling
 
 from ..shard import Shard
@@ -76,11 +76,12 @@ class StatefulShardedModel:
 
   def init_cache(self, request_id: str):
     kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
-    if self.max_kv_size is not None:
+    # if self.max_kv_size is not None:
       # cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
-      cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
-    else:
-      cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    # else:
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    cache = make_prompt_cache(self.model)
 
     if len(self.caches) >= self.max_caches:
       self.caches.popitem(last=False)
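
The hand-rolled per-layer KVCache construction is replaced rather than ported because the cache constructors changed signature in this mlx-lm release; make_prompt_cache builds the per-layer cache list directly from the model. A rough sketch of the helper as used here (the optional max_kv_size argument is my recollection of the 0.19.x API, so treat it as an assumption):

  from mlx_lm.models.cache import make_prompt_cache

  # One cache object per transformer layer, inferred from the model itself.
  cache = make_prompt_cache(self.model)
  # Presumed bounded variant, if a size cap is still wanted:
  # cache = make_prompt_cache(self.model, max_kv_size=self.max_kv_size)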

+ 2 - 4
exo/inference/mlx/sharded_utils.py

@@ -12,13 +12,13 @@ from typing import Optional, Tuple, Union, List, Callable
 from PIL import Image
 from io import BytesIO
 import base64
+import traceback
 
 import mlx.core as mx
 import mlx.nn as nn
 from transformers import AutoProcessor
 
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
-from mlx_lm.tuner.utils import apply_lora_layers
 
 from exo import DEBUG
 from ..shard import Shard
@@ -53,6 +53,7 @@ def _get_classes(config: dict):
   except ImportError:
     msg = f"Model type {model_type} not supported."
     logging.error(msg)
+    traceback.print_exc()
     raise ValueError(msg)
 
   return arch.Model, arch.ModelArgs
@@ -167,9 +168,6 @@ async def load_shard(
   lazy: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
   model = load_model_shard(model_path, shard, lazy, model_config)
-  if adapter_path is not None:
-    model = apply_lora_layers(model, adapter_path)
-    model.eval()
 
   # TODO: figure out a generic solution
   if model.model_type == "llava":
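
The adapter handling is dropped outright because apply_lora_layers no longer exists in mlx_lm.tuner.utils at this version. If adapter support were restored, a hypothetical replacement (load_adapters is the helper newer mlx-lm releases ship, but verify against the pinned 0.19.2) might look like:

  from mlx_lm.tuner.utils import load_adapters

  if adapter_path is not None:
    model = load_adapters(model, adapter_path)  # hypothetical stand-in for the removed apply_lora_layers call
    model.eval()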

+ 1 - 1
exo/inference/mlx/test_sharded_llava.py

@@ -5,7 +5,7 @@ from PIL import Image
 from io import BytesIO
 
 import mlx.core as mx
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.cache import KVCache
 
 from exo.inference.mlx.sharded_model import StatefulShardedModel
 from exo.inference.mlx.sharded_utils import load_shard

+ 2 - 2
setup.py

@@ -34,8 +34,8 @@ extras_require = {
     "yapf==0.40.2",
   ],
   "apple_silicon": [
-    "mlx==0.18.0",
-    "mlx-lm==0.18.2",
+    "mlx==0.19.3",
+    "mlx-lm==0.19.2",
   ],
 }
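
With the extras_require block above, installing the apple_silicon extra (for example, pip install -e ".[apple_silicon]" from the repo root) pulls in the new mlx==0.19.3 and mlx-lm==0.19.2 pins.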