Преглед на файлове

enh: very long audio transcription

Timothy Jaeryang Baek преди 4 месеца
родител
ревизия
b280f828b0
променени са 1 файла, в които са добавени 111 реда и са изтрити 23 реда
  1. 111 23
      backend/open_webui/routers/audio.py

+ 111 - 23
backend/open_webui/routers/audio.py

@@ -7,6 +7,7 @@ from functools import lru_cache
 from pathlib import Path
 from pydub import AudioSegment
 from pydub.silence import split_on_silence
+from concurrent.futures import ThreadPoolExecutor
 
 import aiohttp
 import aiofiles
@@ -50,7 +51,7 @@ from open_webui.env import (
 router = APIRouter()
 
 # Constants
-MAX_FILE_SIZE_MB = 25
+MAX_FILE_SIZE_MB = 20
 MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
 AZURE_MAX_FILE_SIZE_MB = 200
 AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
@@ -87,8 +88,6 @@ def get_audio_convert_format(file_path):
             and info.get("codec_tag_string") == "mp4a"
         ):
             return "mp4"
-        elif info.get("format_name") == "ogg":
-            return "ogg"
     except Exception as e:
         log.error(f"Error getting audio format: {e}")
         return False
@@ -511,8 +510,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         return FileResponse(file_path)
 
 
-def transcribe(request: Request, file_path):
-    log.info(f"transcribe: {file_path}")
+def transcription_handler(request, file_path):
     filename = os.path.basename(file_path)
     file_dir = os.path.dirname(file_path)
     id = filename.split(".")[0]
@@ -775,24 +773,119 @@ def transcribe(request: Request, file_path):
             )
 
 
+def transcribe(request: Request, file_path):
+    log.info(f"transcribe: {file_path}")
+
+    try:
+        file_path = compress_audio(file_path)
+    except Exception as e:
+        log.exception(e)
+
+    # Always produce a list of chunk paths (could be one entry if small)
+    try:
+        chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
+        print(f"Chunk paths: {chunk_paths}")
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+    results = []
+    try:
+        with ThreadPoolExecutor() as executor:
+            # Submit tasks for each chunk_path
+            futures = [
+                executor.submit(transcription_handler, request, chunk_path)
+                for chunk_path in chunk_paths
+            ]
+            # Gather results as they complete
+            for future in futures:
+                try:
+                    results.append(future.result())
+                except Exception as transcribe_exc:
+                    log.exception(f"Error transcribing chunk: {transcribe_exc}")
+                    raise HTTPException(
+                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                        detail="Error during transcription.",
+                    )
+    finally:
+        # Clean up only the temporary chunks, never the original file
+        for chunk_path in chunk_paths:
+            if chunk_path != file_path and os.path.isfile(chunk_path):
+                try:
+                    os.remove(chunk_path)
+                except Exception:
+                    pass
+
+    return {
+        "text": " ".join([result["text"] for result in results]),
+    }
+
+
 def compress_audio(file_path):
     if os.path.getsize(file_path) > MAX_FILE_SIZE:
+        id = os.path.splitext(os.path.basename(file_path))[
+            0
+        ]  # Handles names with multiple dots
         file_dir = os.path.dirname(file_path)
+
         audio = AudioSegment.from_file(file_path)
         audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
-        compressed_path = f"{file_dir}/{id}_compressed.opus"
-        audio.export(compressed_path, format="opus", bitrate="32k")
-        log.debug(f"Compressed audio to {compressed_path}")
 
-        if (
-            os.path.getsize(compressed_path) > MAX_FILE_SIZE
-        ):  # Still larger than MAX_FILE_SIZE after compression
-            raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
+        compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
+        audio.export(compressed_path, format="mp3", bitrate="32k")
+        # log.debug(f"Compressed audio to {compressed_path}")  # Uncomment if log is defined
+
         return compressed_path
     else:
         return file_path
 
 
+def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
+    """
+    Splits audio into chunks not exceeding max_bytes.
+    Returns a list of chunk file paths. If audio fits, returns list with original path.
+    """
+    file_size = os.path.getsize(file_path)
+    if file_size <= max_bytes:
+        return [file_path]  # Nothing to split
+
+    audio = AudioSegment.from_file(file_path)
+    duration_ms = len(audio)
+    orig_size = file_size
+
+    approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
+    chunks = []
+    start = 0
+    i = 0
+
+    base, _ = os.path.splitext(file_path)
+
+    while start < duration_ms:
+        end = min(start + approx_chunk_ms, duration_ms)
+        chunk = audio[start:end]
+        chunk_path = f"{base}_chunk_{i}.{format}"
+        chunk.export(chunk_path, format=format, bitrate=bitrate)
+
+        # Reduce chunk duration if still too large
+        while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
+            end = start + ((end - start) // 2)
+            chunk = audio[start:end]
+            chunk.export(chunk_path, format=format, bitrate=bitrate)
+
+        if os.path.getsize(chunk_path) > max_bytes:
+            os.remove(chunk_path)
+            raise Exception("Audio chunk cannot be reduced below max file size.")
+
+        chunks.append(chunk_path)
+        start = end
+        i += 1
+
+    return chunks
+
+
 @router.post("/transcriptions")
 def transcription(
     request: Request,
@@ -807,6 +900,7 @@ def transcription(
         "audio/ogg",
         "audio/x-m4a",
         "audio/webm",
+        "video/webm",
     )
 
     if not file.content_type.startswith(supported_filetypes):
@@ -830,19 +924,13 @@ def transcription(
             f.write(contents)
 
         try:
-            try:
-                file_path = compress_audio(file_path)
-            except Exception as e:
-                log.exception(e)
+            result = transcribe(request, file_path)
 
-                raise HTTPException(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    detail=ERROR_MESSAGES.DEFAULT(e),
-                )
+            return {
+                **result,
+                "filename": os.path.basename(file_path),
+            }
 
-            data = transcribe(request, file_path)
-            file_path = file_path.split("/")[-1]
-            return {**data, "filename": file_path}
         except Exception as e:
             log.exception(e)