
feat: user stt language

Timothy Jaeryang Baek committed 4 months ago
commit baaa285534

+ 26 - 9
backend/open_webui/routers/audio.py

@@ -8,6 +8,8 @@ from pathlib import Path
 from pydub import AudioSegment
 from pydub.silence import split_on_silence
 from concurrent.futures import ThreadPoolExecutor
+from typing import Optional
+
 
 import aiohttp
 import aiofiles
@@ -18,6 +20,7 @@ from fastapi import (
     Depends,
     FastAPI,
     File,
+    Form,
     HTTPException,
     Request,
     UploadFile,
@@ -527,11 +530,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         return FileResponse(file_path)
 
 
-def transcription_handler(request, file_path):
+def transcription_handler(request, file_path, metadata):
     filename = os.path.basename(file_path)
     file_dir = os.path.dirname(file_path)
     id = filename.split(".")[0]
 
+    metadata = metadata or {}
+
     if request.app.state.config.STT_ENGINE == "":
         if request.app.state.faster_whisper_model is None:
             request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -543,7 +548,7 @@ def transcription_handler(request, file_path):
             file_path,
             beam_size=5,
             vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
-            language=WHISPER_LANGUAGE,
+            language=metadata.get("language") or WHISPER_LANGUAGE,
         )
         log.info(
             "Detected language '%s' with probability %f"
@@ -569,7 +574,14 @@ def transcription_handler(request, file_path):
                     "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
                 },
                 files={"file": (filename, open(file_path, "rb"))},
-                data={"model": request.app.state.config.STT_MODEL},
+                data={
+                    "model": request.app.state.config.STT_MODEL,
+                    **(
+                        {"language": metadata.get("language")}
+                        if metadata.get("language")
+                        else {}
+                    ),
+                },
             )
 
             r.raise_for_status()
@@ -777,8 +789,8 @@ def transcription_handler(request, file_path):
             )
 
 
-def transcribe(request: Request, file_path):
-    log.info(f"transcribe: {file_path}")
+def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
+    log.info(f"transcribe: {file_path} {metadata}")
 
     if is_audio_conversion_required(file_path):
         file_path = convert_audio_to_mp3(file_path)
@@ -804,7 +816,7 @@ def transcribe(request: Request, file_path):
         with ThreadPoolExecutor() as executor:
             # Submit tasks for each chunk_path
             futures = [
-                executor.submit(transcription_handler, request, chunk_path)
+                executor.submit(transcription_handler, request, chunk_path, metadata)
                 for chunk_path in chunk_paths
             ]
             # Gather results as they complete
@@ -812,10 +824,9 @@ def transcribe(request: Request, file_path):
                 try:
                     results.append(future.result())
                 except Exception as transcribe_exc:
-                    log.exception(f"Error transcribing chunk: {transcribe_exc}")
                     raise HTTPException(
                         status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                        detail="Error during transcription.",
+                        detail=f"Error transcribing chunk: {transcribe_exc}",
                     )
     finally:
         # Clean up only the temporary chunks, never the original file
@@ -897,6 +908,7 @@ def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
 def transcription(
     request: Request,
     file: UploadFile = File(...),
+    language: Optional[str] = Form(None),
     user=Depends(get_verified_user),
 ):
     log.info(f"file.content_type: {file.content_type}")
@@ -926,7 +938,12 @@ def transcription(
             f.write(contents)
 
         try:
-            result = transcribe(request, file_path)
+            metadata = None
+
+            if language:
+                metadata = {"language": language}
+
+            result = transcribe(request, file_path, metadata)
 
             return {
                 **result,
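Taken together, this file wires the language end to end: the REST endpoint accepts an optional `language` form field, wraps it in a `metadata` dict, and `transcription_handler` either overrides `WHISPER_LANGUAGE` for local faster-whisper or forwards `language` to the OpenAI-compatible STT API. A minimal client sketch, assuming the `/api/v1/audio/transcriptions` path that the frontend wrapper below builds from `AUDIO_API_BASE_URL`:

```ts
// Minimal sketch: call the transcription endpoint with an explicit STT language.
// The '/api/v1/audio/transcriptions' path and the response shape are assumptions
// inferred from the frontend wrapper below; adjust to your deployment.
async function transcribeWithLanguage(token: string, file: File, language?: string) {
	const data = new FormData();
	data.append('file', file);
	if (language) {
		data.append('language', language); // ISO-639-1 code, e.g. 'en'; omit to auto-detect
	}

	const res = await fetch('/api/v1/audio/transcriptions', {
		method: 'POST',
		headers: { Authorization: `Bearer ${token}` },
		body: data
	});
	if (!res.ok) throw new Error(`Transcription failed: ${res.status}`);
	return res.json(); // expected to contain the transcribed text
}
```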

+ 16 - 4
backend/open_webui/routers/files.py

@@ -1,6 +1,7 @@
 import logging
 import os
 import uuid
+import json
 from fnmatch import fnmatch
 from pathlib import Path
 from typing import Optional
@@ -10,6 +11,7 @@ from fastapi import (
     APIRouter,
     Depends,
     File,
+    Form,
     HTTPException,
     Request,
     UploadFile,
@@ -84,13 +86,23 @@ def has_access_to_file(
 def upload_file(
     request: Request,
     file: UploadFile = File(...),
-    user=Depends(get_verified_user),
-    metadata: dict = None,
+    metadata: Optional[dict | str] = Form(None),
     process: bool = Query(True),
+    internal: bool = False,
+    user=Depends(get_verified_user),
 ):
     log.info(f"file.content_type: {file.content_type}")
 
+    if isinstance(metadata, str):
+        try:
+            metadata = json.loads(metadata)
+        except json.JSONDecodeError:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"),
+            )
     file_metadata = metadata if metadata else {}
+
     try:
         unsanitized_filename = file.filename
         filename = os.path.basename(unsanitized_filename)
@@ -99,7 +111,7 @@ def upload_file(
         # Remove the leading dot from the file extension
         file_extension = file_extension[1:] if file_extension else ""
 
-        if not file_metadata and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
+        if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
             request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
                 ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
             ]
@@ -147,7 +159,7 @@ def upload_file(
                         "video/webm"
                     }:
                         file_path = Storage.get_file(file_path)
-                        result = transcribe(request, file_path)
+                        result = transcribe(request, file_path, file_metadata)
 
                         process_file(
                             request,
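Because multipart form fields are flat strings, the frontend sends `metadata` as a JSON string and `upload_file` parses it back with `json.loads`; the new `internal` flag, set only by server-side callers such as `upload_image`, replaces the old `not file_metadata` condition for skipping the extension allow-list. A hedged sketch of the matching client request, assuming the `/api/v1/files/` path used by the frontend wrapper below:

```ts
// Sketch: upload an audio file with STT metadata as a JSON-encoded form field,
// mirroring what the backend's json.loads() expects. Path is an assumption.
async function uploadAudioWithLanguage(token: string, file: File, language: string) {
	const data = new FormData();
	data.append('file', file);
	// Nested objects cannot travel in multipart form data, so the metadata
	// dict is stringified here and parsed back to a dict server-side.
	data.append('metadata', JSON.stringify({ language }));

	const res = await fetch('/api/v1/files/', {
		method: 'POST',
		headers: { Authorization: `Bearer ${token}` },
		body: data
	});
	if (!res.ok) throw new Error(`Upload failed: ${res.status}`);
	return res.json();
}
```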

+ 1 - 1
backend/open_webui/routers/images.py

@@ -460,7 +460,7 @@ def upload_image(request, image_data, content_type, metadata, user):
             "content-type": content_type,
         },
     )
-    file_item = upload_file(request, file, user, metadata=metadata)
+    file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
     url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
     return url
 

+ 4 - 1
src/lib/apis/audio/index.ts

@@ -64,9 +64,12 @@ export const updateAudioConfig = async (token: string, payload: OpenAIConfigForm
 	return res;
 };
 
-export const transcribeAudio = async (token: string, file: File) => {
+export const transcribeAudio = async (token: string, file: File, language?: string) => {
 	const data = new FormData();
 	data.append('file', file);
+	if (language) {
+		data.append('language', language);
+	}
 
 	let error = null;
 	const res = await fetch(`${AUDIO_API_BASE_URL}/transcriptions`, {
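Since the new parameter is optional, existing callers of `transcribeAudio` compile unchanged. A usage sketch, assuming component scope provides `file` and `$settings`, and that the response carries a `text` field as the backend's `transcribe()` result suggests:

```ts
// Inside an async handler: forward the user's configured STT language;
// undefined falls back to server-side auto-detection.
const res = await transcribeAudio(localStorage.token, file, $settings?.audio?.stt?.language);
if (res) {
	console.log(res.text); // assumed response field
}
```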

+ 5 - 1
src/lib/apis/files/index.ts

@@ -1,8 +1,12 @@
 import { WEBUI_API_BASE_URL } from '$lib/constants';
 
-export const uploadFile = async (token: string, file: File) => {
+export const uploadFile = async (token: string, file: File, metadata?: object | null) => {
 	const data = new FormData();
 	data.append('file', file);
+	if (metadata) {
+		data.append('metadata', JSON.stringify(metadata));
+	}
+
 	let error = null;
 
 	const res = await fetch(`${WEBUI_API_BASE_URL}/files/`, {
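A usage sketch for the extended `uploadFile` (the `file` variable is assumed from surrounding scope):

```ts
// metadata is optional; when present it is JSON.stringify'd into the
// multipart form and parsed back to a dict on the server.
const uploaded = await uploadFile(localStorage.token, file, { language: 'en' });
// Existing two-argument calls are unaffected:
const plain = await uploadFile(localStorage.token, file);
```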

+ 13 - 2
src/lib/components/channel/MessageInput.svelte

@@ -17,7 +17,6 @@
 	import { WEBUI_API_BASE_URL } from '$lib/constants';
 	import FileItem from '../common/FileItem.svelte';
 	import Image from '../common/Image.svelte';
-	import { transcribeAudio } from '$lib/apis/audio';
 	import FilesOverlay from '../chat/MessageInput/FilesOverlay.svelte';
 
 	export let placeholder = $i18n.t('Send a Message');
@@ -160,7 +159,19 @@
 
 		try {
 			// During the file upload, file content is automatically extracted.
-			const uploadedFile = await uploadFile(localStorage.token, file);
+
+			// If the file is an audio file, provide the language for STT.
+			let metadata = null;
+			if (
+				(file.type.startsWith('audio/') || file.type.startsWith('video/')) &&
+				$settings?.audio?.stt?.language
+			) {
+				metadata = {
+					language: $settings?.audio?.stt?.language
+				};
+			}
+
+			const uploadedFile = await uploadFile(localStorage.token, file, metadata);
 
 			if (uploadedFile) {
 				console.info('File upload completed:', {
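The same audio/video metadata block recurs verbatim in Chat.svelte, chat/MessageInput.svelte, NoteEditor.svelte, and KnowledgeBase.svelte below; a shared helper, sketched here as a hypothetical extraction that is not part of this commit, would keep the five call sites in sync:

```ts
// Hypothetical helper (not in this commit): build STT metadata for
// audio/video uploads from the user's settings, or null for other files.
export function getSttUploadMetadata(
	file: File,
	settings: { audio?: { stt?: { language?: string } } } | undefined
): { language: string } | null {
	const language = settings?.audio?.stt?.language;
	if ((file.type.startsWith('audio/') || file.type.startsWith('video/')) && language) {
		return { language };
	}
	return null;
}
```

Call sites would then reduce to `const metadata = getSttUploadMetadata(file, $settings);`.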

+ 12 - 1
src/lib/components/chat/Chat.svelte

@@ -591,9 +591,20 @@
 				throw new Error('Created file is empty');
 			}
 
+			// If the file is an audio file, provide the language for STT.
+			let metadata = null;
+			if (
+				(file.type.startsWith('audio/') || file.type.startsWith('video/')) &&
+				$settings?.audio?.stt?.language
+			) {
+				metadata = {
+					language: $settings?.audio?.stt?.language
+				};
+			}
+
 			// Upload file to server
 			console.log('Uploading file to server...');
-			const uploadedFile = await uploadFile(localStorage.token, file);
+			const uploadedFile = await uploadFile(localStorage.token, file, metadata);
 
 			if (!uploadedFile) {
 				throw new Error('Server returned null response for file upload');

+ 12 - 2
src/lib/components/chat/MessageInput.svelte

@@ -27,7 +27,6 @@
 		createMessagesList,
 		extractCurlyBraceWords
 	} from '$lib/utils';
-	import { transcribeAudio } from '$lib/apis/audio';
 	import { uploadFile } from '$lib/apis/files';
 	import { generateAutoCompletion } from '$lib/apis';
 	import { deleteFileById } from '$lib/apis/files';
@@ -249,8 +248,19 @@
 		files = [...files, fileItem];
 
 		try {
+			// If the file is an audio file, provide the language for STT.
+			let metadata = null;
+			if (
+				(file.type.startsWith('audio/') || file.type.startsWith('video/')) &&
+				$settings?.audio?.stt?.language
+			) {
+				metadata = {
+					language: $settings?.audio?.stt?.language
+				};
+			}
+
 			// During the file upload, file content is automatically extracted.
-			const uploadedFile = await uploadFile(localStorage.token, file);
+			const uploadedFile = await uploadFile(localStorage.token, file, metadata);
 
 			if (uploadedFile) {
 				console.log('File upload completed:', {

+ 5 - 1
src/lib/components/chat/MessageInput/CallOverlay.svelte

@@ -153,7 +153,11 @@
 		await tick();
 		const file = blobToFile(audioBlob, 'recording.wav');
 
-		const res = await transcribeAudio(localStorage.token, file).catch((error) => {
+		const res = await transcribeAudio(
+			localStorage.token,
+			file,
+			$settings?.audio?.stt?.language
+		).catch((error) => {
 			toast.error(`${error}`);
 			return null;
 		});

+ 5 - 1
src/lib/components/chat/MessageInput/VoiceRecording.svelte

@@ -150,7 +150,11 @@
 				return;
 			}
 
-			const res = await transcribeAudio(localStorage.token, file).catch((error) => {
+			const res = await transcribeAudio(
+				localStorage.token,
+				file,
+				$settings?.audio?.stt?.language
+			).catch((error) => {
 				toast.error(`${error}`);
 				return null;
 			});

+ 26 - 2
src/lib/components/chat/Settings/Audio.svelte

@@ -9,6 +9,7 @@
 	import Switch from '$lib/components/common/Switch.svelte';
 	import { round } from '@huggingface/transformers';
 	import Spinner from '$lib/components/common/Spinner.svelte';
+	import Tooltip from '$lib/components/common/Tooltip.svelte';
 	const dispatch = createEventDispatcher();
 
 	const i18n = getContext('i18n');
@@ -22,6 +23,7 @@
 	let nonLocalVoices = false;
 
 	let STTEngine = '';
+	let STTLanguage = '';
 
 	let TTSEngine = '';
 	let TTSEngineConfig = {};
@@ -89,6 +91,7 @@
 		responseAutoPlayback = $settings.responseAutoPlayback ?? false;
 
 		STTEngine = $settings?.audio?.stt?.engine ?? '';
+		STTLanguage = $settings?.audio?.stt?.language ?? '';
 
 		TTSEngine = $settings?.audio?.tts?.engine ?? '';
 		TTSEngineConfig = $settings?.audio?.tts?.engineConfig ?? {};
@@ -156,7 +159,8 @@
 		saveSettings({
 			audio: {
 				stt: {
-					engine: STTEngine !== '' ? STTEngine : undefined
+					engine: STTEngine !== '' ? STTEngine : undefined,
+					language: STTLanguage !== '' ? STTLanguage : undefined
 				},
 				tts: {
 					engine: TTSEngine !== '' ? TTSEngine : undefined,
@@ -189,6 +193,26 @@
 						</select>
 					</div>
 				</div>
+
+				<div class=" py-0.5 flex w-full justify-between">
+					<div class=" self-center text-xs font-medium">{$i18n.t('Language')}</div>
+
+					<div class="flex items-center relative text-xs px-3">
+						<Tooltip
+							content={$i18n.t(
+								'The language of the input audio. Supplying the input language in ISO-639-1 (e.g. en) format will improve accuracy and latency. Leave blank to automatically detect the language.'
+							)}
+							placement="top"
+						>
+							<input
+								type="text"
+								bind:value={STTLanguage}
+								placeholder={$i18n.t('e.g. en')}
+								class=" text-sm text-right bg-transparent dark:text-gray-300 outline-hidden"
+							/>
+						</Tooltip>
+					</div>
+				</div>
 			{/if}
 
 			<div class=" py-0.5 flex w-full justify-between">
@@ -269,7 +293,7 @@
 			<div class=" py-0.5 flex w-full justify-between">
 				<div class=" self-center text-xs font-medium">{$i18n.t('Speech Playback Speed')}</div>
 
-				<div class="flex items-center relative text-xs">
+				<div class="flex items-center relative text-xs px-3">
 					<input
 						type="number"
 						min="0"

+ 0 - 73
src/lib/components/layout/Help/HelpMenu.svelte

@@ -1,73 +0,0 @@
-<script lang="ts">
-	import { DropdownMenu } from 'bits-ui';
-	import { getContext } from 'svelte';
-
-	import { showSettings } from '$lib/stores';
-	import { flyAndScale } from '$lib/utils/transitions';
-
-	import Dropdown from '$lib/components/common/Dropdown.svelte';
-	import QuestionMarkCircle from '$lib/components/icons/QuestionMarkCircle.svelte';
-	import Keyboard from '$lib/components/icons/Keyboard.svelte';
-	import Map from '$lib/components/icons/Map.svelte';
-
-	const i18n = getContext('i18n');
-
-	export let showDocsHandler: Function;
-	export let showShortcutsHandler: Function;
-
-	export let onClose: Function = () => {};
-</script>
-
-<Dropdown
-	on:change={(e) => {
-		if (e.detail === false) {
-			onClose();
-		}
-	}}
->
-	<slot />
-
-	<div slot="content">
-		<DropdownMenu.Content
-			class="w-full max-w-[200px] rounded-xl px-1 py-1.5 border border-gray-300/30 dark:border-gray-700/50 z-50 bg-white dark:bg-gray-850 dark:text-white shadow-lg"
-			sideOffset={4}
-			side="top"
-			align="end"
-			transition={flyAndScale}
-		>
-			<DropdownMenu.Item
-				class="flex gap-2 items-center px-3 py-2 text-sm  cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md"
-				id="chat-share-button"
-				on:click={() => {
-					window.open('https://docs.openwebui.com', '_blank');
-				}}
-			>
-				<QuestionMarkCircle className="size-5" />
-				<div class="flex items-center">{$i18n.t('Documentation')}</div>
-			</DropdownMenu.Item>
-
-			<!-- Releases -->
-			<DropdownMenu.Item
-				class="flex gap-2 items-center px-3 py-2 text-sm cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md"
-				id="menu-item-releases"
-				on:click={() => {
-					window.open('https://github.com/open-webui/open-webui/releases', '_blank');
-				}}
-			>
-				<Map className="size-5" />
-				<div class="flex items-center">{$i18n.t('Releases')}</div>
-			</DropdownMenu.Item>
-
-			<DropdownMenu.Item
-				class="flex gap-2 items-center px-3 py-2 text-sm  cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md"
-				id="chat-share-button"
-				on:click={() => {
-					showShortcutsHandler();
-				}}
-			>
-				<Keyboard className="size-5" />
-				<div class="flex items-center">{$i18n.t('Keyboard shortcuts')}</div>
-			</DropdownMenu.Item>
-		</DropdownMenu.Content>
-	</div>
-</Dropdown>

+ 12 - 1
src/lib/components/notes/NoteEditor.svelte

@@ -276,8 +276,19 @@
 		files = [...files, fileItem];
 
 		try {
+			// If the file is an audio file, provide the language for STT.
+			let metadata = null;
+			if (
+				(file.type.startsWith('audio/') || file.type.startsWith('video/')) &&
+				$settings?.audio?.stt?.language
+			) {
+				metadata = {
+					language: $settings?.audio?.stt?.language
+				};
+			}
+
 			// During the file upload, file content is automatically extracted.
-			const uploadedFile = await uploadFile(localStorage.token, file);
+			const uploadedFile = await uploadFile(localStorage.token, file, metadata);
 
 			if (uploadedFile) {
 				console.log('File upload completed:', {

+ 12 - 4
src/lib/components/workspace/Knowledge/KnowledgeBase.svelte

@@ -26,10 +26,7 @@
 		updateFileFromKnowledgeById,
 		updateKnowledgeById
 	} from '$lib/apis/knowledge';
-
-	import { transcribeAudio } from '$lib/apis/audio';
 	import { blobToFile } from '$lib/utils';
-	import { processFile } from '$lib/apis/retrieval';
 
 	import Spinner from '$lib/components/common/Spinner.svelte';
 	import Files from './KnowledgeBase/Files.svelte';
@@ -158,7 +155,18 @@
 		knowledge.files = [...(knowledge.files ?? []), fileItem];
 
 		try {
-			const uploadedFile = await uploadFile(localStorage.token, file).catch((e) => {
+			// If the file is an audio file, provide the language for STT.
+			let metadata = null;
+			if (
+				(file.type.startsWith('audio/') || file.type.startsWith('video/')) &&
+				$settings?.audio?.stt?.language
+			) {
+				metadata = {
+					language: $settings?.audio?.stt?.language
+				};
+			}
+
+			const uploadedFile = await uploadFile(localStorage.token, file, metadata).catch((e) => {
 				toast.error(`${e}`);
 				return null;
 			});