Browse Source

feat: native tool calling support

Timothy Jaeryang Baek 5 months ago
parent
commit
314b674f32
2 changed files with 185 additions and 6 deletions
  1. 172 6
      backend/open_webui/utils/middleware.py
  2. 13 0
      backend/open_webui/utils/misc.py

+ 172 - 6
backend/open_webui/utils/middleware.py

@@ -57,6 +57,7 @@ from open_webui.utils.task import (
     tools_function_calling_generation_template,
 )
 from open_webui.utils.misc import (
+    deep_update,
     get_message_list,
     add_or_update_system_message,
     add_or_update_user_message,
@@ -1126,8 +1127,18 @@ async def process_chat_response(
                 for block in content_blocks:
                     if block["type"] == "text":
                         content = f"{content}{block['content'].strip()}\n"
-                    elif block["type"] == "tool":
-                        pass
+                    elif block["type"] == "tool_calls":
+                        attributes = block.get("attributes", {})
+
+                        block_content = block.get("content", [])
+                        results = block.get("results", [])
+
+                        if results:
+                            if not raw:
+                                content = f'{content}\n<details type="tool_calls" done="true" results="{html.escape(json.dumps(results))}">\n<summary>Tool Executed</summary>\n```json\n{block_content}\n```\n```json\n{results}\n```\n</details>\n'
+                        else:
+                            if not raw:
+                                content = f'{content}\n<details type="tool_calls" done="false">\n<summary>Tool Executing...</summary>\n```json\n{block_content}\n```\n</details>\n'
 
                     elif block["type"] == "reasoning":
                         reasoning_display_content = "\n".join(
@@ -1254,6 +1265,7 @@ async def process_chat_response(
                 metadata["chat_id"], metadata["message_id"]
             )
 
+            tool_calls = []
             content = message.get("content", "") if message else ""
             content_blocks = [
                 {
@@ -1293,6 +1305,8 @@ async def process_chat_response(
                     nonlocal content
                     nonlocal content_blocks
 
+                    response_tool_calls = []
+
                     async for line in response.body_iterator:
                         line = line.decode("utf-8") if isinstance(line, bytes) else line
                         data = line
@@ -1326,7 +1340,42 @@ async def process_chat_response(
                                 if not choices:
                                     continue
 
-                                value = choices[0].get("delta", {}).get("content")
+                                delta = choices[0].get("delta", {})
+                                delta_tool_calls = delta.get("tool_calls", None)
+
+                                if delta_tool_calls:
+                                    for delta_tool_call in delta_tool_calls:
+                                        tool_call_index = delta_tool_call.get("index")
+
+                                        if tool_call_index is not None:
+                                            if (
+                                                len(response_tool_calls)
+                                                <= tool_call_index
+                                            ):
+                                                response_tool_calls.append(
+                                                    delta_tool_call
+                                                )
+                                            else:
+                                                delta_name = delta_tool_call.get(
+                                                    "function", {}
+                                                ).get("name")
+                                                delta_arguments = delta_tool_call.get(
+                                                    "function", {}
+                                                ).get("arguments")
+
+                                                if delta_name:
+                                                    response_tool_calls[
+                                                        tool_call_index
+                                                    ]["function"]["name"] += delta_name
+
+                                                if delta_arguments:
+                                                    response_tool_calls[
+                                                        tool_call_index
+                                                    ]["function"][
+                                                        "arguments"
+                                                    ] += delta_arguments
+
+                                value = delta.get("content")
 
                                 if value:
                                     content = f"{content}{value}"
@@ -1398,6 +1447,29 @@ async def process_chat_response(
                         if not content_blocks[-1]["content"]:
                             content_blocks.pop()
 
+                    if response_tool_calls:
+                        tool_calls.append(response_tool_calls)
+
+                    if response.background:
+                        await response.background()
+
+                await stream_body_handler(response)
+
+                MAX_TOOL_CALL_RETRIES = 5
+                tool_call_retries = 0
+
+                while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
+                    tool_call_retries += 1
+
+                    response_tool_calls = tool_calls.pop(0)
+
+                    content_blocks.append(
+                        {
+                            "type": "tool_calls",
+                            "content": response_tool_calls,
+                        }
+                    )
+
                     await event_emitter(
                         {
                             "type": "chat:completion",
@@ -1407,10 +1479,103 @@ async def process_chat_response(
                         }
                     )
 
-                    if response.background:
-                        await response.background()
+                    tools = metadata.get("tools", {})
 
-                await stream_body_handler(response)
+                    results = []
+                    for tool_call in response_tool_calls:
+                        tool_call_id = tool_call.get("id", "")
+                        tool_name = tool_call.get("function", {}).get("name", "")
+
+                        tool_function_params = {}
+                        try:
+                            tool_function_params = json.loads(
+                                tool_call.get("function", {}).get("arguments", "{}")
+                            )
+                        except Exception as e:
+                            log.debug(e)
+
+                        tool_result = None
+
+                        if tool_name in tools:
+                            tool = tools[tool_name]
+                            spec = tool.get("spec", {})
+
+                            try:
+                                required_params = spec.get("parameters", {}).get(
+                                    "required", []
+                                )
+                                tool_function = tool["callable"]
+                                tool_function_params = {
+                                    k: v
+                                    for k, v in tool_function_params.items()
+                                    if k in required_params
+                                }
+                                tool_result = await tool_function(
+                                    **tool_function_params
+                                )
+                            except Exception as e:
+                                tool_result = str(e)
+
+                        results.append(
+                            {
+                                "tool_call_id": tool_call_id,
+                                "content": tool_result,
+                            }
+                        )
+
+                    content_blocks[-1]["results"] = results
+
+                    content_blocks.append(
+                        {
+                            "type": "text",
+                            "content": "",
+                        }
+                    )
+
+                    await event_emitter(
+                        {
+                            "type": "chat:completion",
+                            "data": {
+                                "content": serialize_content_blocks(content_blocks),
+                            },
+                        }
+                    )
+
+                    try:
+                        res = await generate_chat_completion(
+                            request,
+                            {
+                                "model": model_id,
+                                "stream": True,
+                                "messages": [
+                                    *form_data["messages"],
+                                    {
+                                        "role": "assistant",
+                                        "content": serialize_content_blocks(
+                                            content_blocks, raw=True
+                                        ),
+                                        "tool_calls": response_tool_calls,
+                                    },
+                                    *[
+                                        {
+                                            "role": "tool",
+                                            "tool_call_id": result["tool_call_id"],
+                                            "content": result["content"],
+                                        }
+                                        for result in results
+                                    ],
+                                ],
+                            },
+                            user,
+                        )
+
+                        if isinstance(res, StreamingResponse):
+                            await stream_body_handler(res)
+                        else:
+                            break
+                    except Exception as e:
+                        log.debug(e)
+                        break
 
                 if DETECT_CODE_INTERPRETER:
                     MAX_RETRIES = 5
@@ -1472,6 +1637,7 @@ async def process_chat_response(
                             output = str(e)
 
                         content_blocks[-1]["output"] = output
+
                         content_blocks.append(
                             {
                                 "type": "text",

+ 13 - 0
backend/open_webui/utils/misc.py

@@ -7,6 +7,18 @@ from pathlib import Path
 from typing import Callable, Optional
 
 
+import collections.abc
+
+
+def deep_update(d, u):
+    """
+    Recursively merges the mapping ``u`` into the dict ``d`` in place.
+
+    For every key in ``u`` whose value is a mapping, the value is merged key
+    by key into the corresponding entry of ``d``; any other value simply
+    overwrites the entry in ``d``.
+
+    :param d: Dictionary that is updated in place.
+    :param u: Mapping holding the updates to apply.
+    :return: The same ``d`` object, after the merge.
+    """
+    for k, v in u.items():
+        if isinstance(v, collections.abc.Mapping):
+            current = d.get(k)
+            # A missing or non-mapping entry (e.g. None, a string, a number)
+            # cannot be merged into; start from a fresh dict rather than
+            # failing on item assignment inside the recursive call.
+            if not isinstance(current, collections.abc.Mapping):
+                current = {}
+            d[k] = deep_update(current, v)
+        else:
+            d[k] = v
+    return d
+
+
 def get_message_list(messages, message_id):
     """
     Reconstructs a list of messages in order up to the specified message_id.
@@ -187,6 +199,7 @@ def openai_chat_chunk_message_template(
     template = openai_chat_message_template(model)
     template["object"] = "chat.completion.chunk"
 
+    template["choices"][0]["index"] = 0
     template["choices"][0]["delta"] = {}
 
     if content: