Browse Source

refac: audio upload handling

Timothy Jaeryang Baek 1 month ago
parent
commit
73e64fe7fb
2 changed files with 41 additions and 46 deletions
  1. 38 36
      backend/open_webui/routers/audio.py
  2. 3 10
      backend/open_webui/routers/files.py

+ 38 - 36
backend/open_webui/routers/audio.py

@@ -73,33 +73,50 @@ from pydub import AudioSegment
 from pydub.utils import mediainfo
 
 
-def get_audio_convert_format(file_path):
-    """Check if the given file needs to be converted to a different format."""
+def is_audio_conversion_required(file_path):
+    """
+    Check if the given audio file needs conversion to mp3.
+    """
+    SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}
+
     if not os.path.isfile(file_path):
         log.error(f"File not found: {file_path}")
         return False
 
     try:
         info = mediainfo(file_path)
+        codec_name = info.get("codec_name", "").lower()
+        codec_type = info.get("codec_type", "").lower()
+        codec_tag_string = info.get("codec_tag_string", "").lower()
+
+        if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
+            # File is AAC/mp4a audio, recommend mp3 conversion
+            return True
 
+        # If the codec name or file extension is in the supported formats
         if (
-            info.get("codec_name") == "aac"
-            and info.get("codec_type") == "audio"
-            and info.get("codec_tag_string") == "mp4a"
+            codec_name in SUPPORTED_FORMATS
+            or os.path.splitext(file_path)[1][1:].lower() in SUPPORTED_FORMATS
         ):
-            return "mp4"
+            return False  # Already supported
+
+        return True
     except Exception as e:
         log.error(f"Error getting audio format: {e}")
         return False
 
-    return None
-
 
-def convert_audio_to_wav(file_path, output_path, conversion_type):
-    """Convert MP4/OGG audio file to WAV format."""
-    audio = AudioSegment.from_file(file_path, format=conversion_type)
-    audio.export(output_path, format="wav")
-    log.info(f"Converted {file_path} to {output_path}")
+def convert_audio_to_mp3(file_path):
+    """Convert audio file to mp3 format."""
+    try:
+        output_path = os.path.splitext(file_path)[0] + ".mp3"
+        audio = AudioSegment.from_file(file_path)
+        audio.export(output_path, format="mp3")
+        log.info(f"Converted {file_path} to {output_path}")
+        return output_path
+    except Exception as e:
+        log.error(f"Error converting audio file: {e}")
+        return None
 
 
 def set_faster_whisper_model(model: str, auto_update: bool = False):
@@ -544,19 +561,6 @@ def transcription_handler(request, file_path):
         log.debug(data)
         return data
     elif request.app.state.config.STT_ENGINE == "openai":
-        convert_format = get_audio_convert_format(file_path)
-
-        if convert_format:
-            ext = convert_format.split(".")[-1]
-
-            os.rename(file_path, file_path.replace(".{ext}", f".{convert_format}"))
-            # Convert unsupported audio file to WAV format
-            convert_audio_to_wav(
-                file_path.replace(".{ext}", f".{convert_format}"),
-                file_path,
-                convert_format,
-            )
-
         r = None
         try:
             r = requests.post(
@@ -776,6 +780,9 @@ def transcription_handler(request, file_path):
 def transcribe(request: Request, file_path):
     log.info(f"transcribe: {file_path}")
 
+    if is_audio_conversion_required(file_path):
+        file_path = convert_audio_to_mp3(file_path)
+
     try:
         file_path = compress_audio(file_path)
     except Exception as e:
@@ -894,16 +901,11 @@ def transcription(
 ):
     log.info(f"file.content_type: {file.content_type}")
 
-    supported_filetypes = (
-        "audio/mpeg",
-        "audio/wav",
-        "audio/ogg",
-        "audio/x-m4a",
-        "audio/webm",
-        "video/webm",
-    )
-
-    if not file.content_type.startswith(supported_filetypes):
+    SUPPORTED_CONTENT_TYPES = {"video/webm"}  # Extend if you add more video types!
+    if not (
+        file.content_type.startswith("audio/")
+        or file.content_type in SUPPORTED_CONTENT_TYPES
+    ):
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,

+ 3 - 10
backend/open_webui/routers/files.py

@@ -140,16 +140,9 @@ def upload_file(
         if process:
             try:
                 if file.content_type:
-                    if file.content_type.startswith(
-                        (
-                            "audio/mpeg",
-                            "audio/wav",
-                            "audio/ogg",
-                            "audio/x-m4a",
-                            "audio/webm",
-                            "video/webm",
-                        )
-                    ):
+                    if file.content_type.startswith("audio/") or file.content_type in {
+                        "video/webm"
+                    }:
                         file_path = Storage.get_file(file_path)
                         result = transcribe(request, file_path)