Explorar o código

refac: audio lang fallback logic

Timothy Jaeryang Baek hai 1 mes
pai
achega
f23eb2a31c
Modificouse 1 ficheiro con 45 adicións e 31 borrados
  1. 45 31
      backend/open_webui/routers/audio.py

+ 45 - 31
backend/open_webui/routers/audio.py

@@ -550,6 +550,11 @@ def transcription_handler(request, file_path, metadata):
 
     metadata = metadata or {}
 
+    languages = [
+        metadata.get("language", None) if WHISPER_LANGUAGE == "" else WHISPER_LANGUAGE,
+        None,  # Always fall back to None in case transcription fails
+    ]
+
     if request.app.state.config.STT_ENGINE == "":
         if request.app.state.faster_whisper_model is None:
             request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -561,11 +566,7 @@ def transcription_handler(request, file_path, metadata):
             file_path,
             beam_size=5,
             vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
-            language=(
-                metadata.get("language", None)
-                if WHISPER_LANGUAGE == ""
-                else WHISPER_LANGUAGE
-            ),
+            language=languages[0],
         )
         log.info(
             "Detected language '%s' with probability %f"
@@ -585,21 +586,26 @@ def transcription_handler(request, file_path, metadata):
     elif request.app.state.config.STT_ENGINE == "openai":
         r = None
         try:
-            r = requests.post(
-                url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
-                headers={
-                    "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
-                },
-                files={"file": (filename, open(file_path, "rb"))},
-                data={
+            for language in languages:
+                payload = {
                     "model": request.app.state.config.STT_MODEL,
-                    **(
-                        {"language": metadata.get("language")}
-                        if metadata.get("language")
-                        else {}
-                    ),
-                },
-            )
+                }
+
+                if language:
+                    payload["language"] = language
+
+                r = requests.post(
+                    url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
+                    headers={
+                        "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
+                    },
+                    files={"file": (filename, open(file_path, "rb"))},
+                    data=payload,
+                )
+
+                if r.status_code == 200:
+                    # Successful transcription
+                    break
 
             r.raise_for_status()
             data = r.json()
@@ -641,18 +647,26 @@ def transcription_handler(request, file_path, metadata):
                 "Content-Type": mime,
             }
 
-            # Add model if specified
-            params = {}
-            if request.app.state.config.STT_MODEL:
-                params["model"] = request.app.state.config.STT_MODEL
-
-            # Make request to Deepgram API
-            r = requests.post(
-                "https://api.deepgram.com/v1/listen?smart_format=true",
-                headers=headers,
-                params=params,
-                data=file_data,
-            )
+            for language in languages:
+                params = {}
+                if request.app.state.config.STT_MODEL:
+                    params["model"] = request.app.state.config.STT_MODEL
+
+                if language:
+                    params["language"] = language
+
+                # Make request to Deepgram API
+                r = requests.post(
+                    "https://api.deepgram.com/v1/listen?smart_format=true",
+                    headers=headers,
+                    params=params,
+                    data=file_data,
+                )
+
+                if r.status_code == 200:
+                    # Successful transcription
+                    break
+
             r.raise_for_status()
             response_data = r.json()