
make a --max-kv-size cli arg

Alex Cheema, 10 months ago
commit e06c506e1a

3 changed files with 8 additions and 6 deletions
  1. exo/inference/inference_engine.py  +2 -2
  2. exo/inference/mlx/sharded_inference_engine.py  +3 -2
  3. main.py  +3 -2

+ 2 - 2
exo/inference/inference_engine.py

@@ -16,11 +16,11 @@ class InferenceEngine(ABC):
     pass
 
 
-def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader', max_kv_size: int = 1024):
   if inference_engine_name == "mlx":
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 
-    return MLXDynamicShardInferenceEngine(shard_downloader)
+    return MLXDynamicShardInferenceEngine(shard_downloader, max_kv_size=max_kv_size)
   elif inference_engine_name == "tinygrad":
     from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
     import tinygrad.helpers
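
A minimal sketch, not part of this commit, of how another caller might thread the new parameter through the factory. build_engine and kv_limit are hypothetical names, 2048 is just an example value, and the ShardDownloader import path is the one already used by the MLX engine module below.

  # Sketch only: select an engine and bound its KV cache via the new parameter.
  from exo.inference.inference_engine import get_inference_engine
  from exo.download.shard_download import ShardDownloader  # same import as in the MLX engine module

  def build_engine(downloader: ShardDownloader, name: str = "mlx", kv_limit: int = 2048):
    # kv_limit is an arbitrary example; the engine's own default stays at 1024
    return get_inference_engine(name, downloader, max_kv_size=kv_limit)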

+ 3 - 2
exo/inference/mlx/sharded_inference_engine.py

@@ -9,9 +9,10 @@ from exo.download.shard_download import ShardDownloader
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self, shard_downloader: ShardDownloader):
+  def __init__(self, shard_downloader: ShardDownloader, max_kv_size: int = 1024):
     self.shard = None
     self.shard_downloader = shard_downloader
+    self.max_kv_size = max_kv_size
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
@@ -36,5 +37,5 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
 
     model_path = await self.shard_downloader.ensure_shard(shard)
     model_shard, self.tokenizer = await load_shard(model_path, shard)
-    self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
+    self.stateful_sharded_model = StatefulShardedModel(shard, model_shard, max_kv_size=self.max_kv_size)
     self.shard = shard
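
Inside the MLX engine, max_kv_size is stored and then forwarded to StatefulShardedModel, which is expected to cap the KV cache at that many token positions. The toy class below only illustrates the general idea behind such a limit; it is not exo's or mlx-lm's implementation. Once the cache is full, the oldest entries are evicted so memory stays flat during long generations at the cost of losing distant context.

  # Illustrative sketch of a bounded KV cache (hypothetical, not exo's code).
  from collections import deque

  class BoundedKVCache:
    def __init__(self, max_kv_size: int = 1024):
      # deques with maxlen silently drop the oldest entry when a new one is appended
      self.keys = deque(maxlen=max_kv_size)
      self.values = deque(maxlen=max_kv_size)

    def update(self, k, v):
      # cache one (key, value) pair per generated token position
      self.keys.append(k)
      self.values.append(v)

    def __len__(self):
      return len(self.keys)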

+ 3 - 2
main.py

@@ -39,6 +39,7 @@ parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT
 parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
+parser.add_argument("--max-kv-size", type=int, default=1024, help="Max KV size for inference engine")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
@@ -52,7 +53,7 @@ print(f"Detected system: {system_info}")
 
 shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
-inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
+inference_engine = get_inference_engine(inference_engine_name, shard_downloader, max_kv_size=args.max_kv_size)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 
 if args.node_port is None:
@@ -156,7 +157,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
   prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
 
   try:
-    print(f"Processing prompt (len=${len(prompt)}): {prompt}")
+    print(f"Processing prompt (len={len(prompt)}): {prompt}")
     await node.process_prompt(shard, prompt, None, request_id=request_id)
 
     _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
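
End to end, the cache limit is now settable from the CLI. Both flags below already exist in main.py; 2048 is just an example value.

  python main.py --inference-engine mlx --max-kv-size 2048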