Просмотр исходного кода

Merge pull request #18884 from ShirasawaSama/feature/handle-large-stream-chunks

feat: handle large stream chunks responses to support Nano Banana [Test Needed]
Tim Baek 3 месяцев назад
Родитель
Сommit
27df461abd
3 измененных файлов с 82 добавлено и 2 удалено
  1. 13 0
      backend/open_webui/env.py
  2. 2 1
      backend/open_webui/routers/openai.py
  3. 67 1
      backend/open_webui/utils/misc.py

+ 13 - 0
backend/open_webui/env.py

@@ -569,6 +569,19 @@ else:
         CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
 
 
+CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = os.environ.get(
+    "CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE", "10485760" # 10MB
+)
+
+if CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE == "":
+    CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = 1024 * 1024 * 10
+else:
+    try:
+        CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = int(CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE)
+    except Exception:
+        CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = 1024 * 1024 * 10
+
+
 ####################################
 # WEBSOCKET SUPPORT
 ####################################

+ 2 - 1
backend/open_webui/routers/openai.py

@@ -45,6 +45,7 @@ from open_webui.utils.payload import (
 )
 from open_webui.utils.misc import (
     convert_logit_bias_input_to_json,
+    handle_large_stream_chunks,
 )
 
 from open_webui.utils.auth import get_admin_user, get_verified_user
@@ -952,7 +953,7 @@ async def generate_chat_completion(
         if "text/event-stream" in r.headers.get("Content-Type", ""):
             streaming = True
             return StreamingResponse(
-                r.content,
+                handle_large_stream_chunks(r.content),
                 status_code=r.status,
                 headers=dict(r.headers),
                 background=BackgroundTask(

+ 67 - 1
backend/open_webui/utils/misc.py

@@ -8,10 +8,11 @@ from datetime import timedelta
 from pathlib import Path
 from typing import Callable, Optional
 import json
+import aiohttp
 
 
 import collections.abc
-from open_webui.env import SRC_LOG_LEVELS
+from open_webui.env import SRC_LOG_LEVELS, CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
@@ -539,3 +540,68 @@ def extract_urls(text: str) -> list[str]:
         r"(https?://[^\s]+)", re.IGNORECASE
     )  # Matches http and https URLs
     return url_pattern.findall(text)
+
+
+def handle_large_stream_chunks(stream: aiohttp.StreamReader, max_buffer_size: int = CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE):
+    """
+    Handle stream response chunks, supporting large data chunks that exceed the original 16kb limit.
+    When a single line exceeds max_buffer_size, returns an empty JSON string {} and skips subsequent data
+    until encountering normally sized data.
+
+    :param stream: The stream reader to handle.
+    :param max_buffer_size: The maximum buffer size in bytes, -1 means not handle large chunks, default is 10MB.
+    :return: An async generator that yields the stream data.
+    """
+
+    if max_buffer_size <= 0:
+        return stream
+
+    async def handle_stream_chunks():
+        buffer = b""
+        skip_mode = False
+
+        async for data, _ in stream.iter_chunks():
+            if not data:
+                continue
+
+            # In skip_mode, if buffer already exceeds the limit, clear it (it's part of an oversized line)
+            if skip_mode and len(buffer) > max_buffer_size:
+                buffer = b""
+
+            lines = (buffer + data).split(b"\n")
+
+            # Process complete lines (except the last possibly incomplete fragment)
+            for i in range(len(lines) - 1):
+                line = lines[i]
+
+                if skip_mode:
+                    # Skip mode: check if current line is small enough to exit skip mode
+                    if len(line) <= max_buffer_size:
+                        skip_mode = False
+                        yield line
+                    else:
+                        yield b"data: {}"
+                else:
+                    # Normal mode: check if line exceeds limit
+                    if len(line) > max_buffer_size:
+                        skip_mode = True
+                        yield b"data: {}"
+                        log.info(f"Skip mode triggered, line size: {len(line)}")
+                    else:
+                        yield line
+
+            # Save the last incomplete fragment
+            buffer = lines[-1]
+
+            # Check if buffer exceeds limit
+            if not skip_mode and len(buffer) > max_buffer_size:
+                skip_mode = True
+                log.info(f"Skip mode triggered, buffer size: {len(buffer)}")
+                # Clear oversized buffer to prevent unlimited growth
+                buffer = b""
+
+        # Process remaining buffer data
+        if buffer and not skip_mode:
+            yield buffer
+    
+    return handle_stream_chunks()