|
@@ -87,6 +87,7 @@ from open_webui.utils.filter import (
|
|
)
|
|
)
|
|
from open_webui.utils.code_interpreter import execute_code_jupyter
|
|
from open_webui.utils.code_interpreter import execute_code_jupyter
|
|
from open_webui.utils.payload import apply_system_prompt_to_body
|
|
from open_webui.utils.payload import apply_system_prompt_to_body
|
|
|
|
+from open_webui.utils.mcp.client import MCPClient
|
|
|
|
|
|
|
|
|
|
from open_webui.config import (
|
|
from open_webui.config import (
|
|
@@ -988,14 +989,94 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|
# Server side tools
|
|
# Server side tools
|
|
tool_ids = metadata.get("tool_ids", None)
|
|
tool_ids = metadata.get("tool_ids", None)
|
|
# Client side tools
|
|
# Client side tools
|
|
- tool_servers = metadata.get("tool_servers", None)
|
|
|
|
|
|
+ direct_tool_servers = metadata.get("tool_servers", None)
|
|
|
|
|
|
log.debug(f"{tool_ids=}")
|
|
log.debug(f"{tool_ids=}")
|
|
- log.debug(f"{tool_servers=}")
|
|
|
|
|
|
+ log.debug(f"{direct_tool_servers=}")
|
|
|
|
|
|
tools_dict = {}
|
|
tools_dict = {}
|
|
|
|
|
|
|
|
+ mcp_clients = []
|
|
|
|
+ mcp_tools_dict = {}
|
|
|
|
+
|
|
if tool_ids:
|
|
if tool_ids:
|
|
|
|
+ for tool_id in tool_ids:
|
|
|
|
+ if tool_id.startswith("server:mcp:"):
|
|
|
|
+ try:
|
|
|
|
+ server_id = tool_id[len("server:mcp:") :]
|
|
|
|
+
|
|
|
|
+ mcp_server_connection = None
|
|
|
|
+ for (
|
|
|
|
+ server_connection
|
|
|
|
+ ) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
|
|
|
+ if (
|
|
|
|
+ server_connection.get("type", "") == "mcp"
|
|
|
|
+ and server_connection.get("info", {}).get("id") == server_id
|
|
|
|
+ ):
|
|
|
|
+ mcp_server_connection = server_connection
|
|
|
|
+ break
|
|
|
|
+
|
|
|
|
+ if not mcp_server_connection:
|
|
|
|
+ log.error(f"MCP server with id {server_id} not found")
|
|
|
|
+ continue
|
|
|
|
+
|
|
|
|
+ auth_type = mcp_server_connection.get("auth_type", "")
|
|
|
|
+
|
|
|
|
+ headers = {}
|
|
|
|
+ if auth_type == "bearer":
|
|
|
|
+ headers["Authorization"] = (
|
|
|
|
+ f"Bearer {mcp_server_connection.get('key', '')}"
|
|
|
|
+ )
|
|
|
|
+ elif auth_type == "none":
|
|
|
|
+ # No authentication
|
|
|
|
+ pass
|
|
|
|
+ elif auth_type == "session":
|
|
|
|
+ headers["Authorization"] = (
|
|
|
|
+ f"Bearer {request.state.token.credentials}"
|
|
|
|
+ )
|
|
|
|
+ elif auth_type == "system_oauth":
|
|
|
|
+ oauth_token = extra_params.get("__oauth_token__", None)
|
|
|
|
+ if oauth_token:
|
|
|
|
+ headers["Authorization"] = (
|
|
|
|
+ f"Bearer {oauth_token.get('access_token', '')}"
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ mcp_client = MCPClient()
|
|
|
|
+ await mcp_client.connect(
|
|
|
|
+ url=mcp_server_connection.get("url", ""),
|
|
|
|
+ headers=headers if headers else None,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ tool_specs = await mcp_client.list_tool_specs()
|
|
|
|
+ for tool_spec in tool_specs:
|
|
|
|
+
|
|
|
|
+ def make_tool_function(function_name):
|
|
|
|
+ async def tool_function(**kwargs):
|
|
|
|
+ print(
|
|
|
|
+ f"Calling MCP tool {function_name} with args {kwargs}"
|
|
|
|
+ )
|
|
|
|
+ return await mcp_client.call_tool(
|
|
|
|
+ function_name,
|
|
|
|
+ function_args=kwargs,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ return tool_function
|
|
|
|
+
|
|
|
|
+ tool_function = make_tool_function(tool_spec["name"])
|
|
|
|
+
|
|
|
|
+ mcp_tools_dict[tool_spec["name"]] = {
|
|
|
|
+ "spec": tool_spec,
|
|
|
|
+ "callable": tool_function,
|
|
|
|
+ "type": "mcp",
|
|
|
|
+ "client": mcp_client,
|
|
|
|
+ "direct": False,
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ mcp_clients.append(mcp_client)
|
|
|
|
+ except Exception as e:
|
|
|
|
+ log.debug(e)
|
|
|
|
+ continue
|
|
|
|
+
|
|
tools_dict = await get_tools(
|
|
tools_dict = await get_tools(
|
|
request,
|
|
request,
|
|
tool_ids,
|
|
tool_ids,
|
|
@@ -1007,9 +1088,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|
"__files__": metadata.get("files", []),
|
|
"__files__": metadata.get("files", []),
|
|
},
|
|
},
|
|
)
|
|
)
|
|
|
|
+ if mcp_tools_dict:
|
|
|
|
+ tools_dict = {**tools_dict, **mcp_tools_dict}
|
|
|
|
|
|
- if tool_servers:
|
|
|
|
- for tool_server in tool_servers:
|
|
|
|
|
|
+ if direct_tool_servers:
|
|
|
|
+ for tool_server in direct_tool_servers:
|
|
tool_specs = tool_server.pop("specs", [])
|
|
tool_specs = tool_server.pop("specs", [])
|
|
|
|
|
|
for tool in tool_specs:
|
|
for tool in tool_specs:
|
|
@@ -1019,7 +1102,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|
"server": tool_server,
|
|
"server": tool_server,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ if mcp_clients:
|
|
|
|
+ metadata["mcp_clients"] = mcp_clients
|
|
|
|
+
|
|
if tools_dict:
|
|
if tools_dict:
|
|
|
|
+ log.info(f"tools_dict: {tools_dict}")
|
|
if metadata.get("params", {}).get("function_calling") == "native":
|
|
if metadata.get("params", {}).get("function_calling") == "native":
|
|
# If the function calling is native, then call the tools function calling handler
|
|
# If the function calling is native, then call the tools function calling handler
|
|
metadata["tools"] = tools_dict
|
|
metadata["tools"] = tools_dict
|
|
@@ -1027,6 +1114,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|
{"type": "function", "function": tool.get("spec", {})}
|
|
{"type": "function", "function": tool.get("spec", {})}
|
|
for tool in tools_dict.values()
|
|
for tool in tools_dict.values()
|
|
]
|
|
]
|
|
|
|
+
|
|
else:
|
|
else:
|
|
# If the function calling is not native, then call the tools function calling handler
|
|
# If the function calling is not native, then call the tools function calling handler
|
|
try:
|
|
try:
|
|
@@ -2330,6 +2418,8 @@ async def process_chat_response(
|
|
results = []
|
|
results = []
|
|
|
|
|
|
for tool_call in response_tool_calls:
|
|
for tool_call in response_tool_calls:
|
|
|
|
+
|
|
|
|
+ print("tool_call", tool_call)
|
|
tool_call_id = tool_call.get("id", "")
|
|
tool_call_id = tool_call.get("id", "")
|
|
tool_name = tool_call.get("function", {}).get("name", "")
|
|
tool_name = tool_call.get("function", {}).get("name", "")
|
|
tool_args = tool_call.get("function", {}).get("arguments", "{}")
|
|
tool_args = tool_call.get("function", {}).get("arguments", "{}")
|
|
@@ -2397,9 +2487,14 @@ async def process_chat_response(
|
|
|
|
|
|
else:
|
|
else:
|
|
tool_function = tool["callable"]
|
|
tool_function = tool["callable"]
|
|
|
|
+
|
|
|
|
+ print("tool_name", tool_name)
|
|
|
|
+ print("tool_function", tool_function)
|
|
|
|
+ print("tool_function_params", tool_function_params)
|
|
tool_result = await tool_function(
|
|
tool_result = await tool_function(
|
|
**tool_function_params
|
|
**tool_function_params
|
|
)
|
|
)
|
|
|
|
+ print("tool_result", tool_result)
|
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
tool_result = str(e)
|
|
tool_result = str(e)
|