Browse Source

Smart prompt longest-prefix matching to avoid sending the same text through the NN again; this speeds up prefill significantly.

Alex Cheema 9 months ago
parent
commit
5c67e24c35
2 changed files with 43 additions and 5 deletions
  1. 21 2
      exo/api/chatgpt_api.py
  2. 22 3
      exo/helpers.py

+ 21 - 2
exo/api/chatgpt_api.py

@@ -8,7 +8,7 @@ from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 from exo import DEBUG, VERSION
-from exo.helpers import terminal_link
+from exo.helpers import terminal_link, PrefixDict
 from exo.inference.shard import Shard
 from exo.orchestration import Node
 
@@ -49,6 +49,7 @@ shard_mappings = {
 }
 
 
+
 class Message:
     def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
         self.role = role
@@ -234,6 +235,11 @@ def parse_chat_request(data: dict):
     data.get("temperature", 0.0),
   )
 
class PromptSession:
  """A previously served prompt, remembered so that a follow-up request whose
  prompt extends this one can be matched by prefix and reuse the same
  request_id (and therefore the node's cached state)."""

  def __init__(self, request_id: str, timestamp: int, prompt: str):
    # request_id: identifier shared with the node for this generation session.
    # timestamp: Unix time (seconds) when the prompt was recorded.
    # prompt: the full prompt text that was sent through the model.
    self.request_id, self.timestamp, self.prompt = request_id, timestamp, prompt
 
 class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
@@ -241,6 +247,7 @@ class ChatGPTAPI:
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout_secs = response_timeout_secs
     self.app = web.Application(client_max_size=100 * 1024 * 1024)  # 100MB to support image upload
+    self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
     cors = aiohttp_cors.setup(self.app)
@@ -293,12 +300,24 @@ class ChatGPTAPI:
         {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         status=400,
       )
-    request_id = str(uuid.uuid4())
 
     tokenizer = await resolve_tokenizer(shard.model_id)
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
     prompt, image_str = build_prompt(tokenizer, chat_request.messages)
+    request_id = None
+    match = self.prompts.find_longest_prefix(prompt)
+    if match:
+        if DEBUG >= 2:
+            print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
+        request_id = match[1].request_id
+        self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
+        # remove the matching prefix from the prompt
+        prompt = prompt[len(match[1].prompt):]
+    else:
+      request_id = str(uuid.uuid4())
+      self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
+
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)
 

+ 22 - 3
exo/helpers.py

@@ -1,6 +1,7 @@
 import os
 import asyncio
-from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
+from typing import Any, Callable, TypeVar, Optional, Dict, Generic, Tuple, List
+from collections import defaultdict
 import socket
 import random
 import platform
@@ -97,8 +98,6 @@ def terminal_link(uri, label=None):
 
 T = TypeVar("T")
 K = TypeVar("K")
-
-
 class AsyncCallback(Generic[T]):
   def __init__(self) -> None:
     self.condition: asyncio.Condition = asyncio.Condition()
@@ -147,3 +146,23 @@ class AsyncCallbackSystem(Generic[K, T]):
   def trigger_all(self, *args: T) -> None:
     for callback in self.callbacks.values():
       callback.set(*args)
+
+
K = TypeVar('K', bound=str)
V = TypeVar('V')


class PrefixDict(Generic[K, V]):
  """Dictionary supporting longest-prefix lookup over string keys.

  Used to match a new prompt against previously seen prompts so the already
  processed prefix can be skipped. Lookups are a linear scan over all stored
  keys, which is acceptable for a small number of live sessions.
  """

  def __init__(self) -> None:
    self.items: Dict[K, V] = {}

  def add(self, key: K, value: V) -> None:
    """Insert or overwrite the entry for ``key``."""
    self.items[key] = value

  def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
    """Return all ``(key, value)`` pairs whose key is a prefix of ``argument``."""
    return [(key, value) for key, value in self.items.items() if argument.startswith(key)]

  def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
    """Return the ``(key, value)`` pair with the longest key that is a prefix
    of ``argument``, or ``None`` if no stored key is a prefix."""
    # max(..., default=None) replaces the manual empty-list check.
    return max(self.find_prefix(argument), key=lambda kv: len(kv[0]), default=None)