|
@@ -8,7 +8,7 @@ from typing import List, Literal, Union, Dict
|
|
from aiohttp import web
|
|
from aiohttp import web
|
|
import aiohttp_cors
|
|
import aiohttp_cors
|
|
from exo import DEBUG, VERSION
|
|
from exo import DEBUG, VERSION
|
|
-from exo.helpers import terminal_link
|
|
|
|
|
|
+from exo.helpers import terminal_link, PrefixDict
|
|
from exo.inference.shard import Shard
|
|
from exo.inference.shard import Shard
|
|
from exo.orchestration import Node
|
|
from exo.orchestration import Node
|
|
|
|
|
|
@@ -49,6 +49,7 @@ shard_mappings = {
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
+
|
|
class Message:
|
|
class Message:
|
|
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
|
|
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
|
|
self.role = role
|
|
self.role = role
|
|
@@ -234,6 +235,11 @@ def parse_chat_request(data: dict):
|
|
data.get("temperature", 0.0),
|
|
data.get("temperature", 0.0),
|
|
)
|
|
)
|
|
|
|
|
|
|
|
class PromptSession:
  """Record tying a submitted prompt to the request id it was issued under.

  Instances are simple attribute holders: the three constructor arguments
  are stored unchanged as public attributes. Equality and hashing remain
  identity-based (the defaults inherited from ``object``).
  """

  def __init__(self, request_id: str, timestamp: int, prompt: str) -> None:
    """Store the session fields.

    Args:
      request_id: Identifier of the request this prompt belongs to.
      timestamp: Integer timestamp supplied by the caller.
      prompt: Prompt text associated with the request.
    """
    self.request_id = request_id
    self.timestamp = timestamp
    self.prompt = prompt

  def __repr__(self) -> str:
    """Return a constructor-style representation for debug output."""
    # Without this, printing a session shows only an opaque object address.
    return (
      f"{type(self).__name__}("
      f"request_id={self.request_id!r}, "
      f"timestamp={self.timestamp!r}, "
      f"prompt={self.prompt!r})"
    )
|
|
|
|
|
|
class ChatGPTAPI:
|
|
class ChatGPTAPI:
|
|
def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
|
|
def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
|
|
@@ -241,6 +247,7 @@ class ChatGPTAPI:
|
|
self.inference_engine_classname = inference_engine_classname
|
|
self.inference_engine_classname = inference_engine_classname
|
|
self.response_timeout_secs = response_timeout_secs
|
|
self.response_timeout_secs = response_timeout_secs
|
|
self.app = web.Application(client_max_size=100 * 1024 * 1024) # 100MB to support image upload
|
|
self.app = web.Application(client_max_size=100 * 1024 * 1024) # 100MB to support image upload
|
|
|
|
+ self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
|
|
self.prev_token_lens: Dict[str, int] = {}
|
|
self.prev_token_lens: Dict[str, int] = {}
|
|
self.stream_tasks: Dict[str, asyncio.Task] = {}
|
|
self.stream_tasks: Dict[str, asyncio.Task] = {}
|
|
cors = aiohttp_cors.setup(self.app)
|
|
cors = aiohttp_cors.setup(self.app)
|
|
@@ -293,12 +300,24 @@ class ChatGPTAPI:
|
|
{"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
|
|
{"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
|
|
status=400,
|
|
status=400,
|
|
)
|
|
)
|
|
- request_id = str(uuid.uuid4())
|
|
|
|
|
|
|
|
tokenizer = await resolve_tokenizer(shard.model_id)
|
|
tokenizer = await resolve_tokenizer(shard.model_id)
|
|
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
|
|
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
|
|
|
|
|
|
prompt, image_str = build_prompt(tokenizer, chat_request.messages)
|
|
prompt, image_str = build_prompt(tokenizer, chat_request.messages)
|
|
|
|
+ request_id = None
|
|
|
|
+ match = self.prompts.find_longest_prefix(prompt)
|
|
|
|
+ if match:
|
|
|
|
+ if DEBUG >= 2:
|
|
|
|
+ print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
|
|
|
|
+ request_id = match[1].request_id
|
|
|
|
+ self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
|
|
|
|
+ # remove the matching prefix from the prompt
|
|
|
|
+ prompt = prompt[len(match[1].prompt):]
|
|
|
|
+ else:
|
|
|
|
+ request_id = str(uuid.uuid4())
|
|
|
|
+ self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
|
|
|
|
+
|
|
callback_id = f"chatgpt-api-wait-response-{request_id}"
|
|
callback_id = f"chatgpt-api-wait-response-{request_id}"
|
|
callback = self.node.on_token.register(callback_id)
|
|
callback = self.node.on_token.register(callback_id)
|
|
|
|
|