Timothy Jaeryang Baek hai 1 semana
pai
achega
742e2ff193
Modificáronse 2 ficheiros con 14 adicións e 11 borrados
  1. 1 1
      backend/open_webui/main.py
  2. 13 10
      backend/open_webui/utils/middleware.py

+ 1 - 1
backend/open_webui/main.py

@@ -1552,7 +1552,7 @@ async def chat_completion(
         finally:
             try:
                 if mcp_clients := metadata.get("mcp_clients"):
-                    for client in mcp_clients:
+                    for client in mcp_clients.values():
                         await client.disconnect()
             except Exception as e:
                 log.debug(f"Error cleaning up: {e}")

+ 13 - 10
backend/open_webui/utils/middleware.py

@@ -1096,7 +1096,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
 
     tools_dict = {}
 
-    mcp_clients = []
+    mcp_clients = {}
     mcp_tools_dict = {}
 
     if tool_ids:
@@ -1157,25 +1157,30 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                             log.error(f"Error getting OAuth token: {e}")
                             oauth_token = None
 
-                    mcp_client = MCPClient()
-                    await mcp_client.connect(
+                    mcp_clients[server_id] = MCPClient()
+                    await mcp_clients[server_id].connect(
                         url=mcp_server_connection.get("url", ""),
                         headers=headers if headers else None,
                     )
 
-                    tool_specs = await mcp_client.list_tool_specs()
+                    tool_specs = await mcp_clients[server_id].list_tool_specs()
                     for tool_spec in tool_specs:
 
-                        def make_tool_function(function_name):
+                        def make_tool_function(client, function_name):
                             async def tool_function(**kwargs):
-                                return await mcp_client.call_tool(
+                                print(kwargs)
+                                print(client)
+                                print(await client.list_tool_specs())
+                                return await client.call_tool(
                                     function_name,
                                     function_args=kwargs,
                                 )
 
                             return tool_function
 
-                        tool_function = make_tool_function(tool_spec["name"])
+                        tool_function = make_tool_function(
+                            mcp_clients[server_id], tool_spec["name"]
+                        )
 
                         mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
                             "spec": {
@@ -1184,11 +1189,9 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                             },
                             "callable": tool_function,
                             "type": "mcp",
-                            "client": mcp_client,
+                            "client": mcp_clients[server_id],
                             "direct": False,
                         }
-
-                    mcp_clients.append(mcp_client)
                 except Exception as e:
                     log.debug(e)
                     continue