
Merge pull request #571 from exo-explore/function_calling

add chatgpt-api-compatible tools for function calling
Alex Cheema 7 months ago
parent
commit
fdc3b5ac02
4 changed files with 141 additions and 17 deletions
  1. examples/function_calling.py (+111, -0)
  2. exo/api/chatgpt_api.py (+23, -13)
  3. exo/inference/tokenizers.py (+1, -1)
  4. exo/models.py (+6, -3)
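
For orientation, here is the round trip the new example exercises, sketched below (field names follow the OpenAI chat-completions schema used in examples/function_calling.py; the <tool_call> wrapper is the format Qwen-family chat templates instruct the model to emit):

# 1. Client request: a standard chat completion plus a "tools" list.
request = {
  "model": "qwen-2.5-1.5b",
  "messages": [{"role": "user", "content": "What's the weather in Boston?"}],
  "tools": [...],  # JSON-schema function descriptions, as in the example file below
  "tool_choice": "auto",
}

# 2. Model reply: plain text containing a <tool_call> block for the client to parse.
reply = '<tool_call>\n{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}\n</tool_call>'

# 3. Client runs the function and continues the conversation with a "tool" message.
tool_msg = {"role": "tool", "name": "get_current_weather", "content": '{"temperature": 22}'}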

examples/function_calling.py (+111, -0)

@@ -0,0 +1,111 @@
+import json
+import re
+import requests
+
+def get_current_weather(location: str, unit: str = "celsius"):
+  """Mock weather data function"""
+  # Hardcoded response for demo purposes
+  return {
+    "location": location,
+    "temperature": 22 if unit == "celsius" else 72,
+    "unit": unit,
+    "forecast": "Sunny with light clouds"
+  }
+
+def try_parse_tool_calls(content: str):
+  """Try parse the tool calls."""
+  tool_calls = []
+  offset = 0
+  for i, m in enumerate(re.finditer(r"<tool_call>\n(.+)\n</tool_call>", content)):
+    if i == 0:
+      offset = m.start()
+    try:
+      func = json.loads(m.group(1))
+      if isinstance(func["arguments"], str):
+        func["arguments"] = json.loads(func["arguments"])
+      tool_calls.append({"type": "function", "function": func})
+    except json.JSONDecodeError as e:
+      # Skip malformed tool-call payloads instead of aborting the whole parse.
+      print(f"Failed to parse tool call {m.group(1)!r}: {e}")
+  if tool_calls:
+    if offset > 0 and content[:offset].strip():
+      c = content[:offset]
+    else:
+      c = ""
+    return {"role": "assistant", "content": c, "tool_calls": tool_calls}
+  return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)}
+
+def chat_completion(messages):
+  """Send chat completion request to local server"""
+  response = requests.post(
+    "http://localhost:52415/v1/chat/completions",
+    json={
+      "model": "qwen-2.5-1.5b",
+      "messages": messages,
+      "tools": [{
+        "type": "function",
+        "function": {
+          "name": "get_current_weather",
+          "description": "Get the current weather in a given location",
+          "parameters": {
+            "type": "object",
+            "properties": {
+              "location": {
+                "type": "string",
+                "description": "The city and state, e.g. San Francisco, CA"
+              },
+              "unit": {
+                "type": "string",
+                "enum": ["celsius", "fahrenheit"]
+              }
+            },
+            "required": ["location"]
+          }
+        }
+      }],
+      "tool_choice": "auto"
+    }
+  )
+  return response.json()
+
+def main():
+  # Initial conversation
+  messages = [{
+    "role": "user",
+    "content": "Hi there, what's the weather in Boston?"
+  }]
+  
+  # Get initial response
+  response = chat_completion(messages)
+  print(f"First response: {response}")
+  assistant_message = try_parse_tool_calls(response["choices"][0]["message"]["content"])
+  messages.append(assistant_message)
+  
+  # If there are tool calls, execute them and continue conversation
+  if "tool_calls" in assistant_message:
+    for tool_call in assistant_message["tool_calls"]:
+      if tool_call["function"]["name"] == "get_current_weather":
+        args = tool_call["function"]["arguments"]
+        weather_data = get_current_weather(**args)
+        
+        # Add tool response to messages
+        messages.append({
+          "role": "tool",
+          "content": json.dumps(weather_data),
+          "name": tool_call["function"]["name"]
+        })
+    
+    # Get final response with weather data
+    response = chat_completion(messages)
+    print(f"Final response: {response}")
+    messages.append({
+      "role": "assistant",
+      "content": response["choices"][0]["message"]["content"]
+    })
+  
+  # Print full conversation
+  for msg in messages:
+    print(f"\n{msg['role'].upper()}: {msg['content']}")
+
+if __name__ == "__main__":
+  main()
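
For reference, a quick check of the parser above (the sample assistant output is illustrative; any text before the first <tool_call> becomes the message content):

sample = (
  "Let me check that for you.\n"
  '<tool_call>\n{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}\n</tool_call>'
)
parsed = try_parse_tool_calls(sample)
# parsed["content"]    -> 'Let me check that for you.\n'
# parsed["tool_calls"] -> [{"type": "function",
#                           "function": {"name": "get_current_weather",
#                                        "arguments": {"location": "Boston, MA"}}}]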

exo/api/chatgpt_api.py (+23, -13)

@@ -5,7 +5,7 @@ import json
 import os
 from pathlib import Path
 from transformers import AutoTokenizer
-from typing import List, Literal, Union, Dict
+from typing import List, Literal, Union, Dict, Optional
 from aiohttp import web
 import aiohttp_cors
 import traceback
@@ -23,23 +23,28 @@ from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
 from exo.apputil import create_animation_mp4
 
 class Message:
-  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
+  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
     self.role = role
     self.content = content
+    self.tools = tools
 
   def to_dict(self):
-    return {"role": self.role, "content": self.content}
+    data = {"role": self.role, "content": self.content}
+    if self.tools:
+      data["tools"] = self.tools
+    return data
 
 
 
 class ChatCompletionRequest:
-  def __init__(self, model: str, messages: List[Message], temperature: float):
+  def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
     self.model = model
     self.messages = messages
     self.temperature = temperature
+    self.tools = tools
 
   def to_dict(self):
-    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
+    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
 
 
 def generate_completion(
@@ -119,20 +124,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
   return remapped_messages
 
 
-def build_prompt(tokenizer, _messages: List[Message]):
+def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
   messages = remap_messages(_messages)
-  prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
-  for message in messages:
-    if not isinstance(message.content, list):
-      continue
+  chat_template_args = {
+    "conversation": [m.to_dict() for m in messages],
+    "tokenize": False,
+    "add_generation_prompt": True
+  }
+  if tools: chat_template_args["tools"] = tools
 
+  prompt = tokenizer.apply_chat_template(**chat_template_args)
+  print(f"!!! Prompt: {prompt}")
   return prompt
 
 
 def parse_message(data: dict):
   if "role" not in data or "content" not in data:
     raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
-  return Message(data["role"], data["content"])
+  return Message(data["role"], data["content"], data.get("tools"))
 
 
 def parse_chat_request(data: dict, default_model: str):
@@ -140,6 +149,7 @@ def parse_chat_request(data: dict, default_model: str):
     data.get("model", default_model),
     [parse_message(msg) for msg in data["messages"]],
     data.get("temperature", 0.0),
+    data.get("tools", None),
   )
 
 
@@ -287,7 +297,7 @@ class ChatGPTAPI:
     shard = build_base_shard(model, self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    prompt = build_prompt(tokenizer, messages)
+    prompt = build_prompt(tokenizer, messages, data.get("tools", None))
     tokens = tokenizer.encode(prompt)
     return web.json_response({
       "length": len(prompt),
@@ -326,7 +336,7 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
-    prompt = build_prompt(tokenizer, chat_request.messages)
+    prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
       try:
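
For context, build_prompt now forwards the tools list straight to the tokenizer's chat template. A minimal standalone sketch (the model id is just an example; passing tools to apply_chat_template requires a recent transformers release, and the rendered prompt depends on the model's template):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
prompt = tokenizer.apply_chat_template(
  conversation=[{"role": "user", "content": "What's the weather in Boston?"}],
  tools=[{
    "type": "function",
    "function": {
      "name": "get_current_weather",
      "description": "Get the current weather in a given location",
      "parameters": {
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"],
      },
    },
  }],
  tokenize=False,
  add_generation_prompt=True,
)
# Qwen-style templates serialize the tool schemas into the system prompt and
# instruct the model to reply with <tool_call>...</tool_call> blocks, which the
# client (examples/function_calling.py) then parses.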

exo/inference/tokenizers.py (+1, -1)

@@ -14,7 +14,7 @@ class DummyTokenizer:
     self.eos_token_id = 69
     self.vocab_size = 1000
 
-  def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
+  def apply_chat_template(self, conversation, tokenize=True, add_generation_prompt=True, tools=None, **kwargs):
     return "dummy_tokenized_prompt"
 
   def encode(self, text):

exo/models.py (+6, -3)

@@ -136,14 +136,17 @@ pretty_name = {
   "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
   "deepseek-coder-v2.5": "Deepseek Coder V2.5",
   "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
+  "qwen-2.5-1.5b": "Qwen 2.5 1.5B",
   "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
+  "qwen-2.5-3b": "Qwen 2.5 3B",
   "qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
-  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
-  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
-  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
   "qwen-2.5-7b": "Qwen 2.5 7B",
+  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
   "qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
   "qwen-2.5-14b": "Qwen 2.5 14B",
+  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
+  "qwen-2.5-32b": "Qwen 2.5 32B",
+  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
   "qwen-2.5-72b": "Qwen 2.5 72B",
   "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
   "llama-3-8b": "Llama 3 8B",