@@ -5,7 +5,7 @@ import json
 import os
 from pathlib import Path
 from transformers import AutoTokenizer
-from typing import List, Literal, Union, Dict
+from typing import List, Literal, Union, Dict, Optional
 from aiohttp import web
 import aiohttp_cors
 import traceback
@@ -23,23 +23,28 @@ from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
 from exo.apputil import create_animation_mp4


 class Message:
-  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
+  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
     self.role = role
     self.content = content
+    self.tools = tools

   def to_dict(self):
-    return {"role": self.role, "content": self.content}
+    data = {"role": self.role, "content": self.content}
+    if self.tools:
+      data["tools"] = self.tools
+    return data


 class ChatCompletionRequest:
-  def __init__(self, model: str, messages: List[Message], temperature: float):
+  def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
     self.model = model
     self.messages = messages
     self.temperature = temperature
+    self.tools = tools

   def to_dict(self):
-    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
+    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}


 def generate_completion(
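With the classes above, a per-message `tools` list only survives `to_dict()` when it is non-empty, so tool-free messages keep their original OpenAI-compatible shape. A minimal sketch of the intended round-trip, assuming an OpenAI-style function-tool spec (the tool payload below is invented for illustration and is not part of this diff):

```python
# Hypothetical tool spec in the OpenAI "function" format; illustrative only.
get_time_tool = {
  "type": "function",
  "function": {"name": "get_time", "description": "Return the current time", "parameters": {"type": "object", "properties": {}}},
}

# tools present: the key is serialized alongside role/content.
msg = Message("user", "What time is it?", tools=[get_time_tool])
assert msg.to_dict() == {"role": "user", "content": "What time is it?", "tools": [get_time_tool]}

# tools absent (None or empty): the key is omitted entirely.
assert Message("user", "hi").to_dict() == {"role": "user", "content": "hi"}
```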
@@ -119,20 +124,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
   return remapped_messages


-def build_prompt(tokenizer, _messages: List[Message]):
+def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
   messages = remap_messages(_messages)
-  prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
-  for message in messages:
-    if not isinstance(message.content, list):
-      continue
+  chat_template_args = {
+    "conversation": [m.to_dict() for m in messages],
+    "tokenize": False,
+    "add_generation_prompt": True
+  }
+  if tools: chat_template_args["tools"] = tools

+  prompt = tokenizer.apply_chat_template(**chat_template_args)
+  if DEBUG >= 4: print(f"Prompt: {prompt}")
   return prompt


 def parse_message(data: dict):
   if "role" not in data or "content" not in data:
     raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
-  return Message(data["role"], data["content"])
+  return Message(data["role"], data["content"], data.get("tools"))


 def parse_chat_request(data: dict, default_model: str):
@@ -140,6 +149,7 @@ def parse_chat_request(data: dict, default_model: str):
     data.get("model", default_model),
     [parse_message(msg) for msg in data["messages"]],
     data.get("temperature", 0.0),
+    data.get("tools", None),
   )
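`build_prompt` now forwards `tools` straight into `tokenizer.apply_chat_template`, which renders them only if the model's chat template actually uses a `tools` variable (supported in transformers 4.42+). A sketch of the equivalent direct call, assuming a tool-aware checkpoint; the model id and tool spec below are illustrative, not taken from the diff:

```python
from transformers import AutoTokenizer

# Illustrative model; any checkpoint whose chat template renders tools works.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
tools = [{
  "type": "function",
  "function": {
    "name": "get_weather",
    "description": "Get the current weather for a city",
    "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
  },
}]
prompt = tokenizer.apply_chat_template(
  conversation=[{"role": "user", "content": "What's the weather in Paris?"}],
  tools=tools,
  tokenize=False,
  add_generation_prompt=True,
)
print(prompt)  # the tool schema is embedded wherever the template places it
```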
@@ -287,7 +297,7 @@ class ChatGPTAPI:
     shard = build_base_shard(model, self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    prompt = build_prompt(tokenizer, messages)
+    prompt = build_prompt(tokenizer, messages, data.get("tools", None))
     tokens = tokenizer.encode(prompt)
     return web.json_response({
       "length": len(prompt),
@@ -326,7 +336,7 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

-    prompt = build_prompt(tokenizer, chat_request.messages)
+    prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
       try:
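End to end, the `tools` array now flows from the HTTP body through `parse_chat_request` and `build_prompt` into the rendered prompt. A usage sketch against the OpenAI-compatible chat endpoint; the URL, port, and model name are assumptions about a local deployment, not part of this diff:

```python
import json
import urllib.request

# Assumed local endpoint and model name; adjust for your exo deployment.
url = "http://localhost:52415/v1/chat/completions"
payload = {
  "model": "llama-3.1-8b",
  "temperature": 0.0,
  "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
  "tools": [{
    "type": "function",
    "function": {
      "name": "get_weather",
      "description": "Get the current weather for a city",
      "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    },
  }],
}

req = urllib.request.Request(url, data=json.dumps(payload).encode(), headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as resp:
  print(json.load(resp))
```

The same body shape works for the token-encode handler changed above, which now counts tokens against the tool-augmented prompt rather than the bare conversation.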