
Revert "add max-caches option"

This reverts commit 9e3ae2b4b4e7ed9316ed10fda58df3a350c4d138.
Alex Cheema 6 months ago
parent
commit
ab91f20296

+ 2 - 2
exo/inference/inference_engine.py

@@ -17,13 +17,13 @@ class InferenceEngine(ABC):
     pass
 
 
-def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader', max_caches: int = 2):
+def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
   if DEBUG >= 2:
     print(f"get_inference_engine called with: {inference_engine_name}")
   if inference_engine_name == "mlx":
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 
-    return MLXDynamicShardInferenceEngine(shard_downloader, max_caches=max_caches)
+    return MLXDynamicShardInferenceEngine(shard_downloader)
   elif inference_engine_name == "tinygrad":
     from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
     import tinygrad.helpers
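
A minimal usage sketch of the reverted signature, assuming the HFShardDownloader import path matches exo/main.py below and that its constructor arguments are optional; callers now pass only the engine name and the shard downloader, with no max_caches keyword:

    from exo.download.hf.hf_shard_download import HFShardDownloader  # import path assumed
    from exo.inference.inference_engine import get_inference_engine

    shard_downloader = HFShardDownloader()
    # passing max_caches=... here would now raise a TypeError
    engine = get_inference_engine("mlx", shard_downloader)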

+ 2 - 3
exo/inference/mlx/sharded_inference_engine.py

@@ -12,11 +12,10 @@ from functools import partial
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self, shard_downloader: ShardDownloader, max_caches: int = 2):
+  def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
-    self.max_caches = max_caches
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
@@ -51,5 +50,5 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
         return asyncio.run(load_shard(model_path, shard))
 
       model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
-      self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard, self.max_caches)
+      self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
       self.shard = shard
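
For context on the call shape above: loop.run_in_executor accepts a callable plus positional arguments only, which is why the revert simply drops self.max_caches from the argument list. A hedged sketch of both variants (partial comes from functools, already imported in this file per the hunk header):

    import asyncio
    from functools import partial
    from exo.inference.mlx.sharded_model import StatefulShardedModel

    async def construct_off_loop(executor, shard, model_shard):
        loop = asyncio.get_running_loop()
        # positional args, as in the reverted line above; the constructor
        # runs on the worker thread, not the event loop
        model = await loop.run_in_executor(executor, StatefulShardedModel, shard, model_shard)
        # hypothetical alternative if keyword arguments were ever needed:
        # model = await loop.run_in_executor(
        #     executor, partial(StatefulShardedModel, shard, model_shard, max_kv_size=1024))
        return model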

+ 8 - 1
exo/inference/mlx/sharded_model.py

@@ -11,9 +11,10 @@ from ..shard import Shard
 
 # TODO: support a speculative model so we can parallelise compute across devices
 class StatefulShardedModel:
-  def __init__(self, shard: Shard, model: nn.Module, max_caches: int = 2):
+  def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
     self.shard = shard
     self.model = model
+    self.max_kv_size = max_kv_size
     self.max_caches = max_caches
     self.caches = OrderedDict()
 
@@ -74,6 +75,12 @@ class StatefulShardedModel:
     return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
 
   def init_cache(self, request_id: str):
+    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
+    # if self.max_kv_size is not None:
+      # cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    # else:
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
     cache = make_prompt_cache(self.model)
 
     if len(self.caches) >= self.max_caches:
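
The hunk above ends mid-branch; a self-contained sketch of the eviction scheme init_cache relies on, assuming self.caches is an OrderedDict keyed by request_id (per __init__ above) and that the oldest entry is dropped once max_caches is reached:

    from collections import OrderedDict

    class CachePool:
        """Bounded pool of per-request KV caches (illustrative names)."""
        def __init__(self, max_caches: int = 2):
            self.max_caches = max_caches
            self.caches = OrderedDict()

        def add(self, request_id: str, cache) -> None:
            if len(self.caches) >= self.max_caches:
                self.caches.popitem(last=False)  # evict the oldest entry first
            self.caches[request_id] = cache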

+ 1 - 2
exo/main.py

@@ -36,7 +36,6 @@ parser.add_argument("--node-port", type=int, default=None, help="Node port")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
-parser.add_argument("--max-caches", type=int, default=2, help="Max caches to keep in memory at once.")
 parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
@@ -65,7 +64,7 @@ shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 print(f"Inference engine name after selection: {inference_engine_name}")
 
-inference_engine = get_inference_engine(inference_engine_name, shard_downloader, max_caches=args.max_caches)
+inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 
 if args.node_port is None:
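
One observable effect of the main.py change, assuming standard argparse behaviour: the --max-caches flag is no longer registered, so existing invocations that pass it will exit with an "unrecognized arguments" error.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--node-port", type=int, default=None, help="Node port")
    # --max-caches was removed by this revert, so the following now fails:
    parser.parse_args(["--max-caches", "4"])  # SystemExit: unrecognized arguments: --max-caches 4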