mglo 3 months ago
Parent
Commit
3561c7eedd

+ 18 - 0
backend/open_webui/config.py

@@ -3403,6 +3403,24 @@ AUDIO_STT_AZURE_MAX_SPEAKERS = PersistentConfig(
     os.getenv("AUDIO_STT_AZURE_MAX_SPEAKERS", ""),
 )
 
+AUDIO_STT_MISTRAL_API_KEY = PersistentConfig(
+    "AUDIO_STT_MISTRAL_API_KEY",
+    "audio.stt.mistral.api_key",
+    os.getenv("AUDIO_STT_MISTRAL_API_KEY", ""),
+)
+
+AUDIO_STT_MISTRAL_API_BASE_URL = PersistentConfig(
+    "AUDIO_STT_MISTRAL_API_BASE_URL",
+    "audio.stt.mistral.api_base_url",
+    os.getenv("AUDIO_STT_MISTRAL_API_BASE_URL", "https://api.mistral.ai/v1"),
+)
+
+AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = PersistentConfig(
+    "AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS",
+    "audio.stt.mistral.use_chat_completions",
+    os.getenv("AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS", "false").lower() == "true",
+)
+
 AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
     "AUDIO_TTS_OPENAI_API_BASE_URL",
     "audio.tts.openai.api_base_url",

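Note that the boolean flag above only evaluates to True for the literal string "true" (case-insensitive); other truthy-looking values such as "1" or "yes" fall through to False. A minimal sketch of that parsing rule, reusing the exact expression from config.py (the sample values below are illustrative only):

import os

# Mirrors the parsing expression used for AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS above.
def parse_flag(value: str) -> bool:
    return value.lower() == "true"

print(parse_flag("true"))  # True
print(parse_flag("TRUE"))  # True -- lower() makes the check case-insensitive
print(parse_flag("1"))     # False -- "1" is not recognized as true
print(parse_flag(os.getenv("AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS", "false")))  # False unless the env var is "true"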
+ 7 - 0
backend/open_webui/main.py

@@ -175,6 +175,9 @@ from open_webui.config import (
     AUDIO_STT_AZURE_LOCALES,
     AUDIO_STT_AZURE_BASE_URL,
     AUDIO_STT_AZURE_MAX_SPEAKERS,
+    AUDIO_STT_MISTRAL_API_KEY,
+    AUDIO_STT_MISTRAL_API_BASE_URL,
+    AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS,
     AUDIO_TTS_ENGINE,
     AUDIO_TTS_MODEL,
     AUDIO_TTS_VOICE,
@@ -1108,6 +1111,10 @@ app.state.config.AUDIO_STT_AZURE_LOCALES = AUDIO_STT_AZURE_LOCALES
 app.state.config.AUDIO_STT_AZURE_BASE_URL = AUDIO_STT_AZURE_BASE_URL
 app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = AUDIO_STT_AZURE_MAX_SPEAKERS
 
+app.state.config.AUDIO_STT_MISTRAL_API_KEY = AUDIO_STT_MISTRAL_API_KEY
+app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL = AUDIO_STT_MISTRAL_API_BASE_URL
+app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS
+
 app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
 
 app.state.config.TTS_MODEL = AUDIO_TTS_MODEL

+ 195 - 0
backend/open_webui/routers/audio.py

@@ -4,6 +4,7 @@ import logging
 import os
 import uuid
 import html
+import base64
 from functools import lru_cache
 from pydub import AudioSegment
 from pydub.silence import split_on_silence
@@ -178,6 +179,9 @@ class STTConfigForm(BaseModel):
     AZURE_LOCALES: str
     AZURE_BASE_URL: str
     AZURE_MAX_SPEAKERS: str
+    MISTRAL_API_KEY: str
+    MISTRAL_API_BASE_URL: str
+    MISTRAL_USE_CHAT_COMPLETIONS: bool
 
 
 class AudioConfigUpdateForm(BaseModel):
@@ -214,6 +218,9 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
             "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
             "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
             "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
+            "MISTRAL_API_KEY": request.app.state.config.AUDIO_STT_MISTRAL_API_KEY,
+            "MISTRAL_API_BASE_URL": request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL,
+            "MISTRAL_USE_CHAT_COMPLETIONS": request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS,
         },
     }
 
@@ -255,6 +262,13 @@ async def update_audio_config(
     request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
         form_data.stt.AZURE_MAX_SPEAKERS
     )
+    request.app.state.config.AUDIO_STT_MISTRAL_API_KEY = form_data.stt.MISTRAL_API_KEY
+    request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL = (
+        form_data.stt.MISTRAL_API_BASE_URL
+    )
+    request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = (
+        form_data.stt.MISTRAL_USE_CHAT_COMPLETIONS
+    )
 
     if request.app.state.config.STT_ENGINE == "":
         request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -290,6 +304,9 @@ async def update_audio_config(
             "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
             "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
             "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
+            "MISTRAL_API_KEY": request.app.state.config.AUDIO_STT_MISTRAL_API_KEY,
+            "MISTRAL_API_BASE_URL": request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL,
+            "MISTRAL_USE_CHAT_COMPLETIONS": request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS,
         },
     }
 
@@ -828,6 +845,184 @@ def transcription_handler(request, file_path, metadata):
                 detail=detail if detail else "Open WebUI: Server Connection Error",
             )
 
+    elif request.app.state.config.STT_ENGINE == "mistral":
+        # Check file exists
+        if not os.path.exists(file_path):
+            raise HTTPException(status_code=400, detail="Audio file not found")
+
+        # Check file size
+        file_size = os.path.getsize(file_path)
+        if file_size > MAX_FILE_SIZE:
+            raise HTTPException(
+                status_code=400,
+                detail=f"File size exceeds limit of {MAX_FILE_SIZE_MB}MB",
+            )
+
+        api_key = request.app.state.config.AUDIO_STT_MISTRAL_API_KEY
+        api_base_url = (
+            request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL
+            or "https://api.mistral.ai/v1"
+        )
+        use_chat_completions = (
+            request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS
+        )
+
+        if not api_key:
+            raise HTTPException(
+                status_code=400,
+                detail="Mistral API key is required for Mistral STT",
+            )
+
+        r = None
+        try:
+            # Use voxtral-mini-latest as the default model for transcription
+            model = request.app.state.config.STT_MODEL or "voxtral-mini-latest"
+            
+            log.info(
+                f"Mistral STT - model: {model}, "
+                f"method: {'chat_completions' if use_chat_completions else 'transcriptions'}"
+            )
+
+            if use_chat_completions:
+                # Use chat completions API with audio input
+                # This method requires mp3 or wav format
+                audio_file_to_use = file_path
+                
+                if is_audio_conversion_required(file_path):
+                    log.debug("Converting audio to mp3 for chat completions API")
+                    converted_path = convert_audio_to_mp3(file_path)
+                    if converted_path:
+                        audio_file_to_use = converted_path
+                    else:
+                        log.error("Audio conversion failed")
+                        raise HTTPException(
+                            status_code=500,
+                            detail="Audio conversion failed. Chat completions API requires mp3 or wav format.",
+                        )
+                
+                # Read and encode audio file as base64
+                with open(audio_file_to_use, "rb") as audio_file:
+                    audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8")
+                
+                # Prepare chat completions request
+                url = f"{api_base_url}/chat/completions"
+                
+                # Add language instruction if specified
+                language = metadata.get("language", None) if metadata else None
+                if language:
+                    text_instruction = f"Transcribe this audio exactly as spoken in {language}. Do not translate it."
+                else:
+                    text_instruction = "Transcribe this audio exactly as spoken in its original language. Do not translate it to another language."
+                
+                payload = {
+                    "model": model,
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "type": "input_audio",
+                                    "input_audio": audio_base64,
+                                },
+                                {
+                                    "type": "text",
+                                    "text": text_instruction
+                                }
+                            ]
+                        }
+                    ]
+                }
+                
+                r = requests.post(
+                    url=url,
+                    json=payload,
+                    headers={
+                        "Authorization": f"Bearer {api_key}",
+                        "Content-Type": "application/json",
+                    },
+                )
+                
+                r.raise_for_status()
+                response = r.json()
+                
+                # Extract transcript from chat completion response
+                transcript = response.get("choices", [{}])[0].get("message", {}).get("content", "").strip()
+                if not transcript:
+                    raise ValueError("Empty transcript in response")
+                
+                data = {"text": transcript}
+                
+            else:
+                # Use dedicated transcriptions API
+                url = f"{api_base_url}/audio/transcriptions"
+
+                # Determine the MIME type
+                mime_type, _ = mimetypes.guess_type(file_path)
+                if not mime_type:
+                    mime_type = "audio/webm"
+
+                # Use context manager to ensure file is properly closed
+                with open(file_path, "rb") as audio_file:
+                    files = {"file": (filename, audio_file, mime_type)}
+                    data_form = {"model": model}
+
+                    # Add language if specified in metadata
+                    language = metadata.get("language", None) if metadata else None
+                    if language:
+                        data_form["language"] = language
+                    
+                    r = requests.post(
+                        url=url,
+                        files=files,
+                        data=data_form,
+                        headers={
+                            "Authorization": f"Bearer {api_key}",
+                        },
+                    )
+                
+                r.raise_for_status()
+                response = r.json()
+
+                # Extract transcript from response
+                transcript = response.get("text", "").strip()
+                if not transcript:
+                    raise ValueError("Empty transcript in response")
+
+                data = {"text": transcript}
+
+            # Save transcript to json file (consistent with other providers)
+            transcript_file = f"{file_dir}/{id}.json"
+            with open(transcript_file, "w") as f:
+                json.dump(data, f)
+
+            log.debug(data)
+            return data
+
+        except ValueError as e:
+            log.exception("Error parsing Mistral response")
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to parse Mistral response: {str(e)}",
+            )
+        except requests.exceptions.RequestException as e:
+            log.exception(e)
+            detail = None
+
+            try:
+                if r is not None and r.status_code != 200:
+                    res = r.json()
+                    if "error" in res:
+                        detail = f"External: {res['error'].get('message', '')}"
+                    else:
+                        detail = f"External: {r.text}"
+            except Exception:
+                detail = f"External: {e}"
+
+            raise HTTPException(
+                status_code=getattr(r, "status_code", 500) if r else 500,
+                detail=detail if detail else "Open WebUI: Server Connection Error",
+            )
+
 
 def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
     log.info(f"transcribe: {file_path} {metadata}")

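Outside the app, the dedicated transcriptions path in the handler above reduces to a single multipart POST. A minimal standalone sketch against the same endpoint, mirroring the request shape shown in the diff (sample.mp3 and the environment-supplied key are placeholders):

import os
import requests

api_key = os.environ["MISTRAL_API_KEY"]  # placeholder; supply your own key
url = "https://api.mistral.ai/v1/audio/transcriptions"

# Same multipart shape as the handler: the audio file plus a "model" form field.
with open("sample.mp3", "rb") as audio_file:  # placeholder audio file
    r = requests.post(
        url,
        headers={"Authorization": f"Bearer {api_key}"},
        files={"file": ("sample.mp3", audio_file, "audio/mpeg")},
        data={"model": "voxtral-mini-latest"},
    )

r.raise_for_status()
print(r.json().get("text", "").strip())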
+ 72 - 1
src/lib/components/admin/Settings/Audio.svelte

@@ -50,6 +50,9 @@
 	let STT_AZURE_BASE_URL = '';
 	let STT_AZURE_MAX_SPEAKERS = '';
 	let STT_DEEPGRAM_API_KEY = '';
+	let STT_MISTRAL_API_KEY = '';
+	let STT_MISTRAL_API_BASE_URL = '';
+	let STT_MISTRAL_USE_CHAT_COMPLETIONS = false;
 
 	let STT_WHISPER_MODEL_LOADING = false;
 
@@ -135,7 +138,10 @@
 				AZURE_REGION: STT_AZURE_REGION,
 				AZURE_LOCALES: STT_AZURE_LOCALES,
 				AZURE_BASE_URL: STT_AZURE_BASE_URL,
-				AZURE_MAX_SPEAKERS: STT_AZURE_MAX_SPEAKERS
+				AZURE_MAX_SPEAKERS: STT_AZURE_MAX_SPEAKERS,
+				MISTRAL_API_KEY: STT_MISTRAL_API_KEY,
+				MISTRAL_API_BASE_URL: STT_MISTRAL_API_BASE_URL,
+				MISTRAL_USE_CHAT_COMPLETIONS: STT_MISTRAL_USE_CHAT_COMPLETIONS
 			}
 		});
 
@@ -184,6 +190,9 @@
 			STT_AZURE_BASE_URL = res.stt.AZURE_BASE_URL;
 			STT_AZURE_MAX_SPEAKERS = res.stt.AZURE_MAX_SPEAKERS;
 			STT_DEEPGRAM_API_KEY = res.stt.DEEPGRAM_API_KEY;
+			STT_MISTRAL_API_KEY = res.stt.MISTRAL_API_KEY;
+			STT_MISTRAL_API_BASE_URL = res.stt.MISTRAL_API_BASE_URL;
+			STT_MISTRAL_USE_CHAT_COMPLETIONS = res.stt.MISTRAL_USE_CHAT_COMPLETIONS;
 		}
 
 		await getVoices();
@@ -235,6 +244,7 @@
 							<option value="web">{$i18n.t('Web API')}</option>
 							<option value="deepgram">{$i18n.t('Deepgram')}</option>
 							<option value="azure">{$i18n.t('Azure AI Speech')}</option>
+							<option value="mistral">{$i18n.t('MistralAI')}</option>
 						</select>
 					</div>
 				</div>
@@ -367,6 +377,67 @@
 							</div>
 						</div>
 					</div>
+				{:else if STT_ENGINE === 'mistral'}
+					<div>
+						<div class="mt-1 flex gap-2 mb-1">
+							<input
+								class="flex-1 w-full bg-transparent outline-hidden"
+								placeholder={$i18n.t('API Base URL')}
+								bind:value={STT_MISTRAL_API_BASE_URL}
+								required
+							/>
+
+							<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={STT_MISTRAL_API_KEY} />
+						</div>
+					</div>
+
+					<hr class="border-gray-100 dark:border-gray-850 my-2" />
+
+					<div>
+						<div class=" mb-1.5 text-xs font-medium">{$i18n.t('STT Model')}</div>
+						<div class="flex w-full">
+							<div class="flex-1">
+								<input
+									class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-hidden"
+									bind:value={STT_MODEL}
+									placeholder="voxtral-mini-latest"
+								/>
+							</div>
+						</div>
+						<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
+							{$i18n.t('Leave empty to use the default model (voxtral-mini-latest).')}
+							<a
+								class=" hover:underline dark:text-gray-200 text-gray-800"
+								href="https://docs.mistral.ai/capabilities/audio_transcription"
+								target="_blank"
+							>
+								{$i18n.t('Learn more about Voxtral transcription.')}
+							</a>
+						</div>
+					</div>
+
+					<hr class="border-gray-100 dark:border-gray-850 my-2" />
+
+					<div>
+						<div class="flex items-center justify-between mb-2">
+							<div class="text-xs font-medium">{$i18n.t('Use Chat Completions API')}</div>
+							<label class="relative inline-flex items-center cursor-pointer">
+								<input
+									type="checkbox"
+									bind:checked={STT_MISTRAL_USE_CHAT_COMPLETIONS}
+									class="sr-only peer"
+								/>
+								<div
+									class="w-9 h-5 bg-gray-200 peer-focus:outline-none peer-focus:ring-2 peer-focus:ring-blue-300 dark:peer-focus:ring-blue-800 rounded-full peer dark:bg-gray-700 peer-checked:after:translate-x-full peer-checked:after:border-white after:content-[''] after:absolute after:top-[2px] after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-4 after:w-4 after:transition-all dark:border-gray-600 peer-checked:bg-blue-600"
+								></div>
+							</label>
+						</div>
+						<div class="text-xs text-gray-400 dark:text-gray-500">
+							{$i18n.t(
+								'Use /v1/chat/completions endpoint instead of /v1/audio/transcriptions for potentially better accuracy.'
+							)}
+						</div>
+					</div>
 				{:else if STT_ENGINE === ''}
 					<div>
 						<div class=" mb-1.5 text-xs font-medium">{$i18n.t('STT Model')}</div>

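The toggle above selects the chat-completions branch implemented in audio.py, which inlines the audio as base64 rather than uploading it as multipart form data. A standalone sketch of that request, using the same message shape and default instruction as the handler (sample.mp3 is a placeholder and must be mp3 or wav):

import base64
import os
import requests

api_key = os.environ["MISTRAL_API_KEY"]  # placeholder; supply your own key
url = "https://api.mistral.ai/v1/chat/completions"

with open("sample.mp3", "rb") as f:  # placeholder; mp3 or wav only
    audio_base64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "model": "voxtral-mini-latest",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "input_audio", "input_audio": audio_base64},
                {
                    "type": "text",
                    "text": "Transcribe this audio exactly as spoken in its original language. Do not translate it to another language.",
                },
            ],
        }
    ],
}

r = requests.post(
    url,
    json=payload,
    headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
)
r.raise_for_status()
print(r.json()["choices"][0]["message"]["content"].strip())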
+ 1 - 0
src/lib/i18n/locales/ca-ES/translation.json

@@ -179,6 +179,7 @@
 	"Away": "Absent",
 	"Awful": "Terrible",
 	"Azure AI Speech": "Azure AI Speech",
+	"MistralAI": "MistralAI",
 	"Azure OpenAI": "Azure OpenAI",
 	"Azure Region": "Regió d'Azure",
 	"Back": "Enrere",