Kaynağa Gözat

feat: experimental mcp support

Timothy Jaeryang Baek 2 hafta önce
ebeveyn
işleme
777e81f7a8

+ 8 - 0
backend/open_webui/main.py

@@ -1531,6 +1531,14 @@ async def chat_completion(
 
                 except:
                     pass
+        finally:
+            try:
+                if mcp_clients := metadata.get("mcp_clients"):
+                    for client in mcp_clients:
+                        await client.disconnect()
+            except Exception as e:
+                log.debug(f"Error cleaning up: {e}")
+                pass
 
     if (
         metadata.get("session_id")

+ 56 - 10
backend/open_webui/routers/configs.py

@@ -1,3 +1,4 @@
+from cmath import log
 from fastapi import APIRouter, Depends, Request, HTTPException
 from pydantic import BaseModel, ConfigDict
 
@@ -12,7 +13,7 @@ from open_webui.utils.tools import (
     get_tool_server_url,
     set_tool_servers,
 )
-
+from open_webui.utils.mcp.client import MCPClient
 
 router = APIRouter()
 
@@ -87,6 +88,7 @@ async def set_connections_config(
 class ToolServerConnection(BaseModel):
     url: str
     path: str
+    type: Optional[str] = "openapi"  # openapi, mcp
     auth_type: Optional[str]
     key: Optional[str]
     config: Optional[dict]
@@ -129,15 +131,59 @@ async def verify_tool_servers_config(
     Verify the connection to the tool server.
     """
     try:
-
-        token = None
-        if form_data.auth_type == "bearer":
-            token = form_data.key
-        elif form_data.auth_type == "session":
-            token = request.state.token.credentials
-
-        url = get_tool_server_url(form_data.url, form_data.path)
-        return await get_tool_server_data(token, url)
+        if form_data.type == "mcp":
+            try:
+                async with MCPClient() as client:
+                    auth = None
+                    headers = None
+
+                    token = None
+                    if form_data.auth_type == "bearer":
+                        token = form_data.key
+                    elif form_data.auth_type == "session":
+                        token = request.state.token.credentials
+                    elif form_data.auth_type == "system_oauth":
+                        try:
+                            if request.cookies.get("oauth_session_id", None):
+                                token = await request.app.state.oauth_manager.get_oauth_token(
+                                    user.id,
+                                    request.cookies.get("oauth_session_id", None),
+                                )
+                        except Exception as e:
+                            pass
+
+                    if token:
+                        headers = {"Authorization": f"Bearer {token}"}
+
+                    await client.connect(form_data.url, auth=auth, headers=headers)
+                    specs = await client.list_tool_specs()
+                    return {
+                        "status": True,
+                        "specs": specs,
+                    }
+            except Exception as e:
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Failed to create MCP client: {str(e)}",
+                )
+        else:  # openapi
+            token = None
+            if form_data.auth_type == "bearer":
+                token = form_data.key
+            elif form_data.auth_type == "session":
+                token = request.state.token.credentials
+            elif form_data.auth_type == "system_oauth":
+                try:
+                    if request.cookies.get("oauth_session_id", None):
+                        token = await request.app.state.oauth_manager.get_oauth_token(
+                            user.id,
+                            request.cookies.get("oauth_session_id", None),
+                        )
+                except Exception as e:
+                    pass
+
+            url = get_tool_server_url(form_data.url, form_data.path)
+            return await get_tool_server_data(token, url)
     except Exception as e:
         raise HTTPException(
             status_code=400,

+ 24 - 0
backend/open_webui/routers/tools.py

@@ -43,6 +43,7 @@ router = APIRouter()
 async def get_tools(request: Request, user=Depends(get_verified_user)):
     tools = Tools.get_tools()
 
+    # OpenAPI Tool Servers
     for server in await get_tool_servers(request):
         tools.append(
             ToolUserResponse(
@@ -68,6 +69,29 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
             )
         )
 
+    # MCP Tool Servers
+    for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
+        if server.get("type", "openapi") == "mcp":
+            tools.append(
+                ToolUserResponse(
+                    **{
+                        "id": f"server:mcp:{server.get('info', {}).get('id')}",
+                        "user_id": f"server:mcp:{server.get('info', {}).get('id')}",
+                        "name": server.get("info", {}).get("name", "MCP Tool Server"),
+                        "meta": {
+                            "description": server.get("info", {}).get(
+                                "description", ""
+                            ),
+                        },
+                        "access_control": server.get("config", {}).get(
+                            "access_control", None
+                        ),
+                        "updated_at": int(time.time()),
+                        "created_at": int(time.time()),
+                    }
+                )
+            )
+
     if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
         # Admin can see all tools
         return tools

+ 83 - 0
backend/open_webui/utils/mcp/client.py

@@ -0,0 +1,83 @@
+import asyncio
+from typing import Optional
+from contextlib import AsyncExitStack
+
+from mcp import ClientSession
+from mcp.client.auth import OAuthClientProvider, TokenStorage
+from mcp.client.streamable_http import streamablehttp_client
+from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
+
+
+class MCPClient:
+    def __init__(self):
+        self.session: Optional[ClientSession] = None
+        self.exit_stack = AsyncExitStack()
+
+    async def connect(
+        self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None
+    ):
+        self._streams_context = streamablehttp_client(url, headers=headers, auth=auth)
+        read_stream, write_stream, _ = (
+            await self._streams_context.__aenter__()
+        )  # pylint: disable=E1101
+
+        self._session_context = ClientSession(
+            read_stream, write_stream
+        )  # pylint: disable=W0201
+        self.session: ClientSession = (
+            await self._session_context.__aenter__()
+        )  # pylint: disable=C2801
+
+        await self.session.initialize()
+
+    async def list_tool_specs(self) -> Optional[dict]:
+        if not self.session:
+            raise RuntimeError("MCP client is not connected.")
+
+        result = await self.session.list_tools()
+        tools = result.tools
+
+        tool_specs = []
+        for tool in tools:
+            name = tool.name
+            description = tool.description
+
+            inputSchema = tool.inputSchema
+
+            # TODO: handle outputSchema if needed
+            outputSchema = getattr(tool, "outputSchema", None)
+
+            tool_specs.append(
+                {"name": name, "description": description, "parameters": inputSchema}
+            )
+
+        return tool_specs
+
+    async def call_tool(
+        self, function_name: str, function_args: dict
+    ) -> Optional[dict]:
+        if not self.session:
+            raise RuntimeError("MCP client is not connected.")
+
+        result = await self.session.call_tool(function_name, function_args)
+        return result.model_dump()
+
+    async def disconnect(self):
+        # Clean up and close the session
+        if self.session:
+            await self._session_context.__aexit__(
+                None, None, None
+            )  # pylint: disable=E1101
+        if self._streams_context:
+            await self._streams_context.__aexit__(
+                None, None, None
+            )  # pylint: disable=E1101
+        self.session = None
+
+    async def __aenter__(self):
+        await self.exit_stack.__aenter__()
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        await self.exit_stack.__aexit__(exc_type, exc_value, traceback)
+        await self.disconnect()

+ 99 - 4
backend/open_webui/utils/middleware.py

@@ -87,6 +87,7 @@ from open_webui.utils.filter import (
 )
 from open_webui.utils.code_interpreter import execute_code_jupyter
 from open_webui.utils.payload import apply_system_prompt_to_body
+from open_webui.utils.mcp.client import MCPClient
 
 
 from open_webui.config import (
@@ -988,14 +989,94 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     # Server side tools
     tool_ids = metadata.get("tool_ids", None)
     # Client side tools
-    tool_servers = metadata.get("tool_servers", None)
+    direct_tool_servers = metadata.get("tool_servers", None)
 
     log.debug(f"{tool_ids=}")
-    log.debug(f"{tool_servers=}")
+    log.debug(f"{direct_tool_servers=}")
 
     tools_dict = {}
 
+    mcp_clients = []
+    mcp_tools_dict = {}
+
     if tool_ids:
+        for tool_id in tool_ids:
+            if tool_id.startswith("server:mcp:"):
+                try:
+                    server_id = tool_id[len("server:mcp:") :]
+
+                    mcp_server_connection = None
+                    for (
+                        server_connection
+                    ) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
+                        if (
+                            server_connection.get("type", "") == "mcp"
+                            and server_connection.get("info", {}).get("id") == server_id
+                        ):
+                            mcp_server_connection = server_connection
+                            break
+
+                    if not mcp_server_connection:
+                        log.error(f"MCP server with id {server_id} not found")
+                        continue
+
+                    auth_type = mcp_server_connection.get("auth_type", "")
+
+                    headers = {}
+                    if auth_type == "bearer":
+                        headers["Authorization"] = (
+                            f"Bearer {mcp_server_connection.get('key', '')}"
+                        )
+                    elif auth_type == "none":
+                        # No authentication
+                        pass
+                    elif auth_type == "session":
+                        headers["Authorization"] = (
+                            f"Bearer {request.state.token.credentials}"
+                        )
+                    elif auth_type == "system_oauth":
+                        oauth_token = extra_params.get("__oauth_token__", None)
+                        if oauth_token:
+                            headers["Authorization"] = (
+                                f"Bearer {oauth_token.get('access_token', '')}"
+                            )
+
+                    mcp_client = MCPClient()
+                    await mcp_client.connect(
+                        url=mcp_server_connection.get("url", ""),
+                        headers=headers if headers else None,
+                    )
+
+                    tool_specs = await mcp_client.list_tool_specs()
+                    for tool_spec in tool_specs:
+
+                        def make_tool_function(function_name):
+                            async def tool_function(**kwargs):
+                                print(
+                                    f"Calling MCP tool {function_name} with args {kwargs}"
+                                )
+                                return await mcp_client.call_tool(
+                                    function_name,
+                                    function_args=kwargs,
+                                )
+
+                            return tool_function
+
+                        tool_function = make_tool_function(tool_spec["name"])
+
+                        mcp_tools_dict[tool_spec["name"]] = {
+                            "spec": tool_spec,
+                            "callable": tool_function,
+                            "type": "mcp",
+                            "client": mcp_client,
+                            "direct": False,
+                        }
+
+                    mcp_clients.append(mcp_client)
+                except Exception as e:
+                    log.debug(e)
+                    continue
+
         tools_dict = await get_tools(
             request,
             tool_ids,
@@ -1007,9 +1088,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                 "__files__": metadata.get("files", []),
             },
         )
+        if mcp_tools_dict:
+            tools_dict = {**tools_dict, **mcp_tools_dict}
 
-    if tool_servers:
-        for tool_server in tool_servers:
+    if direct_tool_servers:
+        for tool_server in direct_tool_servers:
             tool_specs = tool_server.pop("specs", [])
 
             for tool in tool_specs:
@@ -1019,7 +1102,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                     "server": tool_server,
                 }
 
+    if mcp_clients:
+        metadata["mcp_clients"] = mcp_clients
+
     if tools_dict:
+        log.info(f"tools_dict: {tools_dict}")
         if metadata.get("params", {}).get("function_calling") == "native":
             # If the function calling is native, then call the tools function calling handler
             metadata["tools"] = tools_dict
@@ -1027,6 +1114,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                 {"type": "function", "function": tool.get("spec", {})}
                 for tool in tools_dict.values()
             ]
+
         else:
             # If the function calling is not native, then call the tools function calling handler
             try:
@@ -2330,6 +2418,8 @@ async def process_chat_response(
                     results = []
 
                     for tool_call in response_tool_calls:
+
+                        print("tool_call", tool_call)
                         tool_call_id = tool_call.get("id", "")
                         tool_name = tool_call.get("function", {}).get("name", "")
                         tool_args = tool_call.get("function", {}).get("arguments", "{}")
@@ -2397,9 +2487,14 @@ async def process_chat_response(
 
                                 else:
                                     tool_function = tool["callable"]
+
+                                    print("tool_name", tool_name)
+                                    print("tool_function", tool_function)
+                                    print("tool_function_params", tool_function_params)
                                     tool_result = await tool_function(
                                         **tool_function_params
                                     )
+                                    print("tool_result", tool_result)
 
                             except Exception as e:
                                 tool_result = str(e)

+ 101 - 74
backend/open_webui/utils/tools.py

@@ -96,94 +96,118 @@ async def get_tools(
     for tool_id in tool_ids:
         tool = Tools.get_tool_by_id(tool_id)
         if tool is None:
+
             if tool_id.startswith("server:"):
-                server_id = tool_id.split(":")[1]
+                splits = tool_id.split(":")
+
+                if len(splits) == 2:
+                    type = "openapi"
+                    server_id = splits[1]
+                elif len(splits) == 3:
+                    type = splits[1]
+                    server_id = splits[2]
+
+                server_id_splits = server_id.split("|")
+                if len(server_id_splits) == 2:
+                    server_id = server_id_splits[0]
+                    function_names = server_id_splits[1].split(",")
+
+                if type == "openapi":
+
+                    tool_server_data = None
+                    for server in await get_tool_servers(request):
+                        if server["id"] == server_id:
+                            tool_server_data = server
+                            break
+
+                    if tool_server_data is None:
+                        log.warning(f"Tool server data not found for {server_id}")
+                        continue
+
+                    tool_server_idx = tool_server_data.get("idx", 0)
+                    tool_server_connection = (
+                        request.app.state.config.TOOL_SERVER_CONNECTIONS[
+                            tool_server_idx
+                        ]
+                    )
 
-                tool_server_data = None
-                for server in await get_tool_servers(request):
-                    if server["id"] == server_id:
-                        tool_server_data = server
-                        break
+                    specs = tool_server_data.get("specs", [])
+                    for spec in specs:
+                        function_name = spec["name"]
 
-                if tool_server_data is None:
-                    log.warning(f"Tool server data not found for {server_id}")
-                    continue
+                        auth_type = tool_server_connection.get("auth_type", "bearer")
 
-                tool_server_idx = tool_server_data.get("idx", 0)
-                tool_server_connection = (
-                    request.app.state.config.TOOL_SERVER_CONNECTIONS[tool_server_idx]
-                )
+                        cookies = {}
+                        headers = {}
 
-                specs = tool_server_data.get("specs", [])
-                for spec in specs:
-                    function_name = spec["name"]
+                        if auth_type == "bearer":
+                            headers["Authorization"] = (
+                                f"Bearer {tool_server_connection.get('key', '')}"
+                            )
+                        elif auth_type == "none":
+                            # No authentication
+                            pass
+                        elif auth_type == "session":
+                            cookies = request.cookies
+                            headers["Authorization"] = (
+                                f"Bearer {request.state.token.credentials}"
+                            )
+                        elif auth_type == "system_oauth":
+                            cookies = request.cookies
+                            oauth_token = extra_params.get("__oauth_token__", None)
+                            if oauth_token:
+                                headers["Authorization"] = (
+                                    f"Bearer {oauth_token.get('access_token', '')}"
+                                )
 
-                    auth_type = tool_server_connection.get("auth_type", "bearer")
+                        headers["Content-Type"] = "application/json"
+
+                        def make_tool_function(
+                            function_name, tool_server_data, headers
+                        ):
+                            async def tool_function(**kwargs):
+                                return await execute_tool_server(
+                                    url=tool_server_data["url"],
+                                    headers=headers,
+                                    cookies=cookies,
+                                    name=function_name,
+                                    params=kwargs,
+                                    server_data=tool_server_data,
+                                )
 
-                    cookies = {}
-                    headers = {}
+                            return tool_function
 
-                    if auth_type == "bearer":
-                        headers["Authorization"] = (
-                            f"Bearer {tool_server_connection.get('key', '')}"
+                        tool_function = make_tool_function(
+                            function_name, tool_server_data, headers
                         )
-                    elif auth_type == "none":
-                        # No authentication
-                        pass
-                    elif auth_type == "session":
-                        cookies = request.cookies
-                        headers["Authorization"] = (
-                            f"Bearer {request.state.token.credentials}"
+
+                        callable = get_async_tool_function_and_apply_extra_params(
+                            tool_function,
+                            {},
                         )
-                    elif auth_type == "system_oauth":
-                        cookies = request.cookies
-                        oauth_token = extra_params.get("__oauth_token__", None)
-                        if oauth_token:
-                            headers["Authorization"] = (
-                                f"Bearer {oauth_token.get('access_token', '')}"
-                            )
 
-                    headers["Content-Type"] = "application/json"
-
-                    def make_tool_function(function_name, tool_server_data, headers):
-                        async def tool_function(**kwargs):
-                            return await execute_tool_server(
-                                url=tool_server_data["url"],
-                                headers=headers,
-                                cookies=cookies,
-                                name=function_name,
-                                params=kwargs,
-                                server_data=tool_server_data,
+                        tool_dict = {
+                            "tool_id": tool_id,
+                            "callable": callable,
+                            "spec": spec,
+                            # Misc info
+                            "type": "external",
+                        }
+
+                        # Handle function name collisions
+                        while function_name in tools_dict:
+                            log.warning(
+                                f"Tool {function_name} already exists in another tools!"
                             )
+                            # Prepend server ID to function name
+                            function_name = f"{server_id}_{function_name}"
 
-                        return tool_function
-
-                    tool_function = make_tool_function(
-                        function_name, tool_server_data, headers
-                    )
-
-                    callable = get_async_tool_function_and_apply_extra_params(
-                        tool_function,
-                        {},
-                    )
+                        tools_dict[function_name] = tool_dict
 
-                    tool_dict = {
-                        "tool_id": tool_id,
-                        "callable": callable,
-                        "spec": spec,
-                        # Misc info
-                        "type": "external",
-                    }
-
-                    # Handle function name collisions
-                    while function_name in tools_dict:
-                        log.warning(
-                            f"Tool {function_name} already exists in another tools!"
-                        )
-                        # Prepend server ID to function name
-                        function_name = f"{server_id}_{function_name}"
+                else:
+                    log.warning(f"Unsupported tool server type: {type}")
+                    continue
 
-                    tools_dict[function_name] = tool_dict
             else:
                 continue
         else:
@@ -579,7 +603,10 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str,
     # Prepare list of enabled servers along with their original index
     server_entries = []
     for idx, server in enumerate(servers):
-        if server.get("config", {}).get("enable"):
+        if (
+            server.get("config", {}).get("enable")
+            and server.get("type", "openapi") == "openapi"
+        ):
             # Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL
             openapi_path = server.get("path", "openapi.json")
             full_url = get_tool_server_url(server.get("url"), openapi_path)

+ 32 - 3
src/lib/components/AddServerModal.svelte → src/lib/components/AddToolServerModal.svelte

@@ -100,6 +100,11 @@
 
 		// remove trailing slash from url
 		url = url.replace(/\/$/, '');
+		if (id.includes(':') || id.includes('|')) {
+			toast.error($i18n.t('ID cannot contain ":" or "|" characters'));
+			loading = false;
+			return;
+		}
 
 		const connection = {
 			url,
@@ -214,6 +219,7 @@
 												{$i18n.t('OpenAPI')}
 											{:else if type === 'mcp'}
 												{$i18n.t('MCP')}
+												<span class="text-gray-500">{$i18n.t('Streamable HTTP')}</span>
 											{/if}
 										</button>
 									</div>
@@ -221,6 +227,25 @@
 							</div>
 						{/if}
 
+						{#if type === 'mcp'}
+							<div
+								class=" bg-yellow-500/20 text-yellow-700 dark:text-yellow-200 rounded-2xl text-xs px-4 py-3 mb-2"
+							>
+								<span class="font-medium">
+									{$i18n.t('Warning')}:
+								</span>
+								{$i18n.t(
+									'MCP support is experimental and its specification changes often, which can lead to incompatibilities. OpenAPI specification support is directly maintained by the Open WebUI team, making it the more reliable option for compatibility.'
+								)}
+
+								<a
+									class="font-medium underline"
+									href="https://docs.openwebui.com/features/mcp"
+									target="_blank">{$i18n.t('Read more →')}</a
+								>
+							</div>
+						{/if}
+
 						<div class="flex gap-2">
 							<div class="flex flex-col w-full">
 								<div class="flex justify-between mb-0.5">
@@ -372,9 +397,12 @@
 										for="enter-id"
 										class={`mb-0.5 text-xs ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
 										>{$i18n.t('ID')}
-										<span class="text-xs text-gray-200 dark:text-gray-800 ml-0.5"
-											>{$i18n.t('Optional')}</span
-										>
+
+										{#if type !== 'mcp'}
+											<span class="text-xs text-gray-200 dark:text-gray-800 ml-0.5"
+												>{$i18n.t('Optional')}</span
+											>
+										{/if}
 									</label>
 
 									<div class="flex-1">
@@ -385,6 +413,7 @@
 											bind:value={id}
 											placeholder={$i18n.t('Enter ID')}
 											autocomplete="off"
+											required={type === 'mcp'}
 										/>
 									</div>
 								</div>

+ 2 - 2
src/lib/components/admin/Settings/Tools.svelte

@@ -14,7 +14,7 @@
 	import Plus from '$lib/components/icons/Plus.svelte';
 	import Connection from '$lib/components/chat/Settings/Tools/Connection.svelte';
 
-	import AddServerModal from '$lib/components/AddServerModal.svelte';
+	import AddToolServerModal from '$lib/components/AddToolServerModal.svelte';
 	import { getToolServerConnections, setToolServerConnections } from '$lib/apis/configs';
 
 	export let saveSettings: Function;
@@ -47,7 +47,7 @@
 	});
 </script>
 
-<AddServerModal bind:show={showConnectionModal} onSubmit={addConnectionHandler} />
+<AddToolServerModal bind:show={showConnectionModal} onSubmit={addConnectionHandler} />
 
 <form
 	class="flex flex-col h-full justify-between text-sm"

+ 2 - 2
src/lib/components/chat/Settings/Tools.svelte

@@ -14,7 +14,7 @@
 	import Plus from '$lib/components/icons/Plus.svelte';
 	import Connection from './Tools/Connection.svelte';
 
-	import AddServerModal from '$lib/components/AddServerModal.svelte';
+	import AddToolServerModal from '$lib/components/AddToolServerModal.svelte';
 
 	export let saveSettings: Function;
 
@@ -52,7 +52,7 @@
 	});
 </script>
 
-<AddServerModal bind:show={showConnectionModal} onSubmit={addConnectionHandler} direct />
+<AddToolServerModal bind:show={showConnectionModal} onSubmit={addConnectionHandler} direct />
 
 <form
 	id="tab-tools"

+ 2 - 2
src/lib/components/chat/Settings/Tools/Connection.svelte

@@ -6,7 +6,7 @@
 	import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';
 	import Cog6 from '$lib/components/icons/Cog6.svelte';
 	import ConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';
-	import AddServerModal from '$lib/components/AddServerModal.svelte';
+	import AddToolServerModal from '$lib/components/AddToolServerModal.svelte';
 
 	export let onDelete = () => {};
 	export let onSubmit = () => {};
@@ -18,7 +18,7 @@
 	let showDeleteConfirmDialog = false;
 </script>
 
-<AddServerModal
+<AddToolServerModal
 	edit
 	{direct}
 	bind:show={showConfigModal}