
add chatgpt-api-compatible tools for function calling

Alex Cheema, 4 months ago
Parent
Current commit: 456fbdd2b0
1 file changed, 23 insertions and 13 deletions
  1. exo/api/chatgpt_api.py (+23, -13)

exo/api/chatgpt_api.py   +23 -13

@@ -5,7 +5,7 @@ import json
 import os
 from pathlib import Path
 from transformers import AutoTokenizer
-from typing import List, Literal, Union, Dict
+from typing import List, Literal, Union, Dict, Optional
 from aiohttp import web
 import aiohttp_cors
 import traceback
@@ -23,23 +23,28 @@ from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
 from exo.apputil import create_animation_mp4
 
 class Message:
-  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
+  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
     self.role = role
     self.content = content
+    self.tools = tools
 
   def to_dict(self):
-    return {"role": self.role, "content": self.content}
+    data = {"role": self.role, "content": self.content}
+    if self.tools:
+      data["tools"] = self.tools
+    return data
 
 
 
 class ChatCompletionRequest:
-  def __init__(self, model: str, messages: List[Message], temperature: float):
+  def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
     self.model = model
     self.messages = messages
     self.temperature = temperature
+    self.tools = tools
 
   def to_dict(self):
-    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
+    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
 
 
 def generate_completion(
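
For context, here is a minimal sketch of how the extended Message and ChatCompletionRequest classes from the hunk above can carry an OpenAI-style tool definition. The get_weather schema and the model name are illustrative assumptions, not part of this commit:

    # Sketch only: the tool schema and model name below are hypothetical examples.
    from exo.api.chatgpt_api import Message, ChatCompletionRequest

    get_weather_tool = {
      "type": "function",
      "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
          "type": "object",
          "properties": {"city": {"type": "string"}},
          "required": ["city"],
        },
      },
    }

    request = ChatCompletionRequest(
      model="llama-3.2-1b",                                  # assumed model name
      messages=[Message("user", "What's the weather in Paris?")],
      temperature=0.0,
      tools=[get_weather_tool],
    )
    print(request.to_dict())  # now includes "tools" alongside model/messages/temperature
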
@@ -119,20 +124,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
   return remapped_messages
 
 
-def build_prompt(tokenizer, _messages: List[Message]):
+def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
   messages = remap_messages(_messages)
-  prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
-  for message in messages:
-    if not isinstance(message.content, list):
-      continue
+  chat_template_args = {
+    "conversation": [m.to_dict() for m in messages],
+    "tokenize": False,
+    "add_generation_prompt": True
+  }
+  if tools: chat_template_args["tools"] = tools
 
+  prompt = tokenizer.apply_chat_template(**chat_template_args)
+  print(f"!!! Prompt: {prompt}")
   return prompt
 
 
 def parse_message(data: dict):
   if "role" not in data or "content" not in data:
     raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
-  return Message(data["role"], data["content"])
+  return Message(data["role"], data["content"], data.get("tools"))
 
 
 def parse_chat_request(data: dict, default_model: str):
@@ -140,6 +149,7 @@ def parse_chat_request(data: dict, default_model: str):
     data.get("model", default_model),
     [parse_message(msg) for msg in data["messages"]],
     data.get("temperature", 0.0),
+    data.get("tools", None),
   )
 
 
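
A short sketch of how tools now flow through the helpers changed in the two hunks above: parse_chat_request reads "tools" from an OpenAI-style request body, and build_prompt forwards them to tokenizer.apply_chat_template, which only has an effect for models whose chat template understands tools. The request body below reuses the illustrative get_weather_tool schema from the earlier sketch and is an assumption, not taken from the commit:

    from exo.api.chatgpt_api import parse_chat_request, build_prompt

    data = {
      "model": "llama-3.2-1b",                     # assumed model name
      "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
      "temperature": 0.0,
      "tools": [get_weather_tool],                 # illustrative schema from the sketch above
    }
    chat_request = parse_chat_request(data, "llama-3.2-1b")
    assert chat_request.tools == data["tools"]

    # The handler then renders the prompt with the tool definitions included, roughly:
    #   prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
    # where `tokenizer` comes from resolve_tokenizer(...) as in the handlers below.
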
@@ -287,7 +297,7 @@ class ChatGPTAPI:
     shard = build_base_shard(model, self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    prompt = build_prompt(tokenizer, messages)
+    prompt = build_prompt(tokenizer, messages, data.get("tools", None))
     tokens = tokenizer.encode(prompt)
     return web.json_response({
       "length": len(prompt),
@@ -326,7 +336,7 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
-    prompt = build_prompt(tokenizer, chat_request.messages)
+    prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
       try:
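
For completeness, a hedged end-to-end sketch of exercising the ChatGPT-compatible endpoint with tools over HTTP. The host, port, and model name are assumptions about a local exo node; the "tools" field follows the OpenAI chat-completions format that this commit threads through to the chat template:

    import requests  # third-party HTTP client, used here purely for illustration

    resp = requests.post(
      "http://localhost:52415/v1/chat/completions",   # assumed local exo address and port
      json={
        "model": "llama-3.2-1b",                      # assumed model name
        "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
        "temperature": 0.0,
        "tools": [get_weather_tool],                  # illustrative schema from the first sketch
      },
    )
    print(resp.json())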