Просмотр исходного кода

Merge pull request #18537 from OAburub/patch

fix: prevent cancellation scope corruption by exitting in LIFO and ha…
Tim Baek 3 месяцев назад
Родитель
Сommit
a4d0bd1073
2 измененных файлов с 29 добавлено и 22 удалено
  1. 5 3
      backend/open_webui/main.py
  2. 24 19
      backend/open_webui/utils/mcp/client.py

+ 5 - 3
backend/open_webui/main.py

@@ -1556,11 +1556,13 @@ async def chat_completion(
             log.info("Chat processing was cancelled")
             try:
                 event_emitter = get_event_emitter(metadata)
-                await event_emitter(
+                await asyncio.shield(event_emitter(
                     {"type": "chat:tasks:cancel"},
-                )
+                ))
             except Exception as e:
                 pass
+            finally:
+                raise # re-raise to ensure proper task cancellation handling
         except Exception as e:
             log.debug(f"Error processing chat payload: {e}")
             if metadata.get("chat_id") and metadata.get("message_id"):
@@ -1591,7 +1593,7 @@ async def chat_completion(
         finally:
             try:
                 if mcp_clients := metadata.get("mcp_clients"):
-                    for client in mcp_clients.values():
+                    for client in reversed(mcp_clients.values()):
                         await client.disconnect()
             except Exception as e:
                 log.debug(f"Error cleaning up: {e}")

+ 24 - 19
backend/open_webui/utils/mcp/client.py

@@ -2,35 +2,40 @@ import asyncio
 from typing import Optional
 from contextlib import AsyncExitStack
 
+import anyio
+
 from mcp import ClientSession
 from mcp.client.auth import OAuthClientProvider, TokenStorage
 from mcp.client.streamable_http import streamablehttp_client
 from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
 
-
 class MCPClient:
     def __init__(self):
         self.session: Optional[ClientSession] = None
-        self.exit_stack = AsyncExitStack()
+        self.exit_stack = None
 
     async def connect(self, url: str, headers: Optional[dict] = None):
-        try:
-            self._streams_context = streamablehttp_client(url, headers=headers)
-
-            transport = await self.exit_stack.enter_async_context(self._streams_context)
-            read_stream, write_stream, _ = transport
-
-            self._session_context = ClientSession(
-                read_stream, write_stream
-            )  # pylint: disable=W0201
-
-            self.session = await self.exit_stack.enter_async_context(
-                self._session_context
-            )
-            await self.session.initialize()
-        except Exception as e:
-            await self.disconnect()
-            raise e
+        async with AsyncExitStack() as exit_stack:
+            try:
+                self._streams_context = streamablehttp_client(url, headers=headers)
+
+                transport = await exit_stack.enter_async_context(self._streams_context)
+                read_stream, write_stream, _ = transport
+
+                self._session_context = ClientSession(
+                    read_stream, write_stream
+                )  # pylint: disable=W0201
+
+                self.session = await exit_stack.enter_async_context(
+                    self._session_context
+                )
+                with anyio.fail_after(10):
+                    await self.session.initialize()
+                self.exit_stack = exit_stack.pop_all()
+            except Exception as e:
+                await asyncio.shield(self.disconnect())
+                raise e
+                
 
     async def list_tool_specs(self) -> Optional[dict]:
         if not self.session: