Timothy Jaeryang Baek há 2 semanas atrás
pai
commit
f1bbf3a91e

+ 3 - 3
backend/open_webui/env.py

@@ -547,16 +547,16 @@ else:
 
 
 CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = os.environ.get(
-    "CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES", "10"
+    "CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES", "30"
 )
 
 if CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES == "":
-    CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 10
+    CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
 else:
     try:
         CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = int(CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES)
     except Exception:
-        CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 10
+        CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
 
 
 ####################################

+ 76 - 0
backend/open_webui/utils/files.py

@@ -3,6 +3,20 @@ from open_webui.routers.images import (
     upload_image,
 )
 
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    Request,
+    UploadFile,
+)
+
+from open_webui.routers.files import upload_file_handler
+
+import mimetypes
+import base64
+import io
+
 
 def get_image_url_from_base64(request, base64_image_string, metadata, user):
     if "data:image/png;base64" in base64_image_string:
@@ -19,3 +33,65 @@ def get_image_url_from_base64(request, base64_image_string, metadata, user):
             )
         return image_url
     return None
+
+
+def load_b64_audio_data(b64_str):
+    try:
+        if "," in b64_str:
+            header, b64_data = b64_str.split(",", 1)
+        else:
+            b64_data = b64_str
+            header = "data:audio/wav;base64"
+        audio_data = base64.b64decode(b64_data)
+        content_type = (
+            header.split(";")[0].split(":")[1] if ";" in header else "audio/wav"
+        )
+        return audio_data, content_type
+    except Exception as e:
+        print(f"Error decoding base64 audio data: {e}")
+        return None, None
+
+
+def upload_audio(request, audio_data, content_type, metadata, user):
+    audio_format = mimetypes.guess_extension(content_type)
+    file = UploadFile(
+        file=io.BytesIO(audio_data),
+        filename=f"generated-{audio_format}",  # will be converted to a unique ID on upload_file
+        headers={
+            "content-type": content_type,
+        },
+    )
+    file_item = upload_file_handler(
+        request,
+        file=file,
+        metadata=metadata,
+        process=False,
+        user=user,
+    )
+    url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
+    return url
+
+
+def get_audio_url_from_base64(request, base64_audio_string, metadata, user):
+    if "data:audio/wav;base64" in base64_audio_string:
+        audio_url = ""
+        # Extract base64 audio data from the line
+        audio_data, content_type = load_b64_audio_data(base64_audio_string)
+        if audio_data is not None:
+            audio_url = upload_audio(
+                request,
+                audio_data,
+                content_type,
+                metadata,
+                user,
+            )
+        return audio_url
+    return None
+
+
+def get_file_url_from_base64(request, base64_file_string, metadata, user):
+    if "data:image/png;base64" in base64_file_string:
+        return get_image_url_from_base64(request, base64_file_string, metadata, user)
+    elif "data:audio/wav;base64" in base64_file_string:
+        return get_audio_url_from_base64(request, base64_file_string, metadata, user)
+    return None

+ 35 - 29
backend/open_webui/utils/middleware.py

@@ -53,7 +53,11 @@ from open_webui.routers.pipelines import (
 from open_webui.routers.memories import query_memory, QueryMemoryForm
 
 from open_webui.utils.webhook import post_webhook
-from open_webui.utils.files import get_image_url_from_base64
+from open_webui.utils.files import (
+    get_audio_url_from_base64,
+    get_file_url_from_base64,
+    get_image_url_from_base64,
+)
 
 
 from open_webui.models.users import UserModel
@@ -2573,34 +2577,36 @@ async def process_chat_response(
                                     tool_result.remove(item)
 
                                 if tool.get("type") == "mcp":
-                                    if (
-                                        isinstance(item, dict)
-                                        and item.get("type") == "image"
-                                    ):
-                                        image_url = get_image_url_from_base64(
-                                            request,
-                                            f"data:{item.get('mimeType', 'image/png')};base64,{item.get('data', '')}",
-                                            {
-                                                "chat_id": metadata.get(
-                                                    "chat_id", None
-                                                ),
-                                                "message_id": metadata.get(
-                                                    "message_id", None
-                                                ),
-                                                "session_id": metadata.get(
-                                                    "session_id", None
-                                                ),
-                                            },
-                                            user,
-                                        )
+                                    if isinstance(item, dict):
+                                        if (
+                                            item.get("type") == "image"
+                                            or item.get("type") == "audio"
+                                        ):
+                                            file_url = get_file_url_from_base64(
+                                                request,
+                                                f"data:{item.get('mimeType')};base64,{item.get('data', '')}",
+                                                {
+                                                    "chat_id": metadata.get(
+                                                        "chat_id", None
+                                                    ),
+                                                    "message_id": metadata.get(
+                                                        "message_id", None
+                                                    ),
+                                                    "session_id": metadata.get(
+                                                        "session_id", None
+                                                    ),
+                                                    "result": item,
+                                                },
+                                                user,
+                                            )
 
-                                        tool_result_files.append(
-                                            {
-                                                "type": "image",
-                                                "url": image_url,
-                                            }
-                                        )
-                                        tool_result.remove(item)
+                                            tool_result_files.append(
+                                                {
+                                                    "type": item.get("type", "data"),
+                                                    "url": file_url,
+                                                }
+                                            )
+                                            tool_result.remove(item)
 
                         if tool_result_files:
                             if not isinstance(tool_result, list):
@@ -2612,7 +2618,7 @@ async def process_chat_response(
                                 tool_result.append(
                                     {
                                         "type": file.get("type", "data"),
-                                        "content": "Displayed",
+                                        "content": "Result is being displayed as a file.",
                                     }
                                 )