@@ -39,6 +39,7 @@ parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT
 parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
+parser.add_argument("--max-kv-size", type=int, default=1024, help="Max KV size for inference engine")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
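The new --max-kv-size flag bounds how many key/value entries the attention cache may hold per request. As a rough illustration of the idea only, not the engine's actual implementation, a capped KV cache behaves like a fixed-length buffer that evicts its oldest entries once the cap is reached:

# Toy sketch of a bounded KV cache; real engines keep keys/values as device
# tensors, but the eviction behaviour implied by a max_kv_size cap is the same.
from collections import deque

class BoundedKVCache:
  def __init__(self, max_kv_size: int):
    # deque(maxlen=...) silently drops the oldest entry once full
    self.keys = deque(maxlen=max_kv_size)
    self.values = deque(maxlen=max_kv_size)

  def update(self, k, v):
    self.keys.append(k)
    self.values.append(v)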
@@ -52,7 +53,7 @@ print(f"Detected system: {system_info}")
 
 shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
-inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
+inference_engine = get_inference_engine(inference_engine_name, shard_downloader, max_kv_size=args.max_kv_size)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 
 if args.node_port is None:
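For this call to work, get_inference_engine needs a matching keyword parameter. A hypothetical sketch of the forwarding, assuming exo's engine class and module names; whether each engine actually accepts the cap is an assumption, not shown in this patch:

# Hypothetical sketch, not the patch itself: accept max_kv_size and forward it
# to engines that support a KV cap, leaving the others unchanged.
def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader, max_kv_size: int = 1024):
  if inference_engine_name == "mlx":
    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
    return MLXDynamicShardInferenceEngine(shard_downloader, max_kv_size=max_kv_size)
  elif inference_engine_name == "tinygrad":
    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
    return TinygradDynamicShardInferenceEngine(shard_downloader)
  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")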
@@ -156,7 +157,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
   prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
 
   try:
-    print(f"Processing prompt (len=${len(prompt)}): {prompt}")
+    print(f"Processing prompt (len={len(prompt)}): {prompt}")
     await node.process_prompt(shard, prompt, None, request_id=request_id)
 
     _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
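The removed $ is JavaScript template-literal syntax; inside a Python f-string it is just a literal character, so the old line printed a stray dollar sign rather than raising an error:

prompt = "Who are you?"
print(f"len=${len(prompt)}")  # before the fix: prints "len=$12"
print(f"len={len(prompt)}")   # after the fix:  prints "len=12"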