
prefill in batches to prevent OOM on very long prompts

Alex Cheema, 8 months ago
commit f5764f3756
3 changed files with 30 additions and 10 deletions:

  1. exo/inference/mlx/sharded_model.py (+13, -2)
  2. main.py (+14, -3)
  3. setup.py (+3, -5)

exo/inference/mlx/sharded_model.py (+13, -2)

@@ -4,7 +4,7 @@ from collections import OrderedDict
 import mlx.core as mx
 import mlx.nn as nn
 from mlx_lm.models.base import KVCache, RotatingKVCache
-from mlx_lm.sample_utils import top_p_sampling
+from mlx_lm.sample_utils import top_p_sampling, min_p_sampling, categorical_sampling
 
 from ..shard import Shard
 
@@ -24,7 +24,10 @@ class StatefulShardedModel:
     pixel_values=None,
     temp: float = 0.0,
     top_p: float = 1.0,
+    min_p: float = 0.0,
+    min_tokens_to_keep: int = 0,
     logit_bias: Optional[Dict[int, float]] = None,
+    prefill_step_size: int = 512,
   ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
       if logit_bias:
@@ -37,8 +40,10 @@ class StatefulShardedModel:
       else:
         if top_p > 0 and top_p < 1.0:
           token = top_p_sampling(logits, top_p, temp)
+        elif min_p != 0.0:
+          token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
         else:
-          token = mx.random.categorical(logits*(1/temp))
+          token = categorical_sampling(logits, temp)
 
       return token
 
@@ -52,6 +57,12 @@ class StatefulShardedModel:
     cache = self.caches[request_id]
 
     if pixel_values is None:
+      if self.shard.is_first_layer():
+        while y.size > prefill_step_size:
+          self.model(y[:prefill_step_size][None], cache=cache)
+          mx.eval([c.state for c in cache])
+          y = y[prefill_step_size:]
+
       output = self.model(y[None] if self.shard.is_first_layer() else y, cache=cache)
     else:
       output = self.model(y, pixel_values=pixel_values, cache=cache)
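The core of this commit is the loop above: a long prompt is pushed through the model in fixed-size slices, so the lazy compute graph never spans more than prefill_step_size tokens at once. Below is a minimal standalone sketch of the same pattern, assuming an mlx_lm-style model whose __call__ accepts a [batch, seq] token array plus a list of KV caches; the helper name prefill_in_chunks is illustrative, not part of the commit.

import mlx.core as mx

def prefill_in_chunks(model, tokens: mx.array, cache: list, prefill_step_size: int = 512) -> mx.array:
  # Push the prompt through in fixed-size slices. Each call only builds the
  # compute graph for `prefill_step_size` tokens, so peak memory stays bounded
  # even for very long prompts.
  while tokens.size > prefill_step_size:
    model(tokens[:prefill_step_size][None], cache=cache)  # extend the KV cache for this slice
    mx.eval([c.state for c in cache])                      # force evaluation, drop the lazy graph
    tokens = tokens[prefill_step_size:]
  # The remaining tail (at most prefill_step_size tokens) goes through the
  # normal forward pass that also produces the logits used for sampling.
  return tokens

Calling mx.eval after each slice is what actually caps memory: MLX is lazy, so without it the graph for the entire prompt would still be built up before anything is evaluated.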

main.py (+14, -3)

@@ -1,10 +1,13 @@
 import argparse
 import asyncio
+import aiofiles
 import signal
 import json
 import time
 import traceback
 import uuid
+from typing import Optional
+from pathlib import Path
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
@@ -39,6 +42,7 @@ parser.add_argument("--inference-engine", type=str, default=None, help="Inferenc
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
+parser.add_argument("--file", type=str, help="File to use for the model when using --run-model", default=None)
 args = parser.parse_args()
 
 print_yellow_exo()
@@ -131,7 +135,14 @@ async def shutdown(signal, loop):
   loop.stop()
 
 
-async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
+async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str, file_path: Optional[str] = None):
+  if file_path:
+    try:
+      import textract
+      prompt = "Input file: " + textract.process(file_path).decode('utf-8') + "\n\n---\n\n" + prompt
+    except Exception as e:
+      print(f"Error reading file {file_path}: {str(e)}")
+      return
   shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
   if not shard:
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
@@ -145,7 +156,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
   prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
 
   try:
-    print(f"Processing prompt: {prompt}")
+    print(f"Processing prompt (len=${len(prompt)}): {prompt}")
     await node.process_prompt(shard, prompt, None, request_id=request_id)
 
     _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
@@ -172,7 +183,7 @@ async def main():
   await node.start(wait_for_peers=args.wait_for_peers)
 
   if args.run_model:
-    await run_model_cli(node, inference_engine, args.run_model, args.prompt)
+    await run_model_cli(node, inference_engine, args.run_model, args.prompt, args.file)
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
     await asyncio.Event().wait()
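On the CLI side, the new --file flag boils down to extracting the document's text with textract and prepending it to the prompt before the chat template is applied. A minimal sketch of that step; the helper name build_prompt is illustrative, not part of the commit.

from typing import Optional

import textract

def build_prompt(prompt: str, file_path: Optional[str]) -> str:
  if not file_path:
    return prompt
  # textract.process returns bytes and handles pdf, docx, plain text, and more.
  file_text = textract.process(file_path).decode("utf-8")
  return "Input file: " + file_text + "\n\n---\n\n" + prompt

Typical invocation (model name and file are placeholders): python main.py --run-model <model-name> --prompt "Summarize the attached file" --file report.pdf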

setup.py (+3, -5)

@@ -10,8 +10,6 @@ install_requires = [
   "blobfile==2.1.1",
   "grpcio==1.64.1",
   "grpcio-tools==1.64.1",
-  "hf-transfer==0.1.8",
-  "huggingface-hub==0.24.5",
   "Jinja2==3.1.4",
   "netifaces==0.11.0",
   "numpy==2.0.0",
@@ -24,9 +22,9 @@ install_requires = [
   "rich==13.7.1",
   "safetensors==0.4.3",
   "tenacity==9.0.0",
+  "textract==1.6.5",
   "tiktoken==0.7.0",
   "tokenizers==0.19.1",
-  "tqdm==4.66.4",
   "transformers==4.43.3",
   "uuid==1.30",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@639af3f823cf242a1945dc24183e52a9df0af2b7",
@@ -35,8 +33,8 @@ install_requires = [
 # Add macOS-specific packages if on Darwin (macOS)
 if sys.platform.startswith("darwin"):
   install_requires.extend([
-    "mlx==0.17.1",
-    "mlx-lm==0.17.0",
+    "mlx==0.17.2",
+    "mlx-lm==0.18.1",
   ])
 
 extras_require = {
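The dependency changes drop the explicit hf-transfer, huggingface-hub, and tqdm pins, add textract for the new --file option, and bump the macOS MLX packages. An optional sanity check, not part of the commit, to confirm the environment picked up the expected versions:

import importlib.metadata as metadata

for pkg in ("mlx", "mlx-lm", "textract"):
  try:
    print(f"{pkg}=={metadata.version(pkg)}")
  except metadata.PackageNotFoundError:
    print(f"{pkg} is not installed")  # e.g. mlx is only installed on macOS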