|
|
@@ -4,6 +4,7 @@ import logging
|
|
|
import os
|
|
|
import uuid
|
|
|
import html
|
|
|
+import base64
|
|
|
from functools import lru_cache
|
|
|
from pydub import AudioSegment
|
|
|
from pydub.silence import split_on_silence
|
|
|
@@ -178,6 +179,9 @@ class STTConfigForm(BaseModel):
|
|
|
AZURE_LOCALES: str
|
|
|
AZURE_BASE_URL: str
|
|
|
AZURE_MAX_SPEAKERS: str
|
|
|
+ MISTRAL_API_KEY: str
|
|
|
+ MISTRAL_API_BASE_URL: str
|
|
|
+ MISTRAL_USE_CHAT_COMPLETIONS: bool
|
|
|
|
|
|
|
|
|
class AudioConfigUpdateForm(BaseModel):
|
|
|
@@ -214,6 +218,9 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
|
|
|
"AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
|
|
|
"AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
|
|
|
"AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
|
|
|
+ "MISTRAL_API_KEY": request.app.state.config.AUDIO_STT_MISTRAL_API_KEY,
|
|
|
+ "MISTRAL_API_BASE_URL": request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL,
|
|
|
+ "MISTRAL_USE_CHAT_COMPLETIONS": request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS,
|
|
|
},
|
|
|
}
|
|
|
|
|
|
@@ -255,6 +262,13 @@ async def update_audio_config(
|
|
|
request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
|
|
|
form_data.stt.AZURE_MAX_SPEAKERS
|
|
|
)
|
|
|
+ request.app.state.config.AUDIO_STT_MISTRAL_API_KEY = form_data.stt.MISTRAL_API_KEY
|
|
|
+ request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL = (
|
|
|
+ form_data.stt.MISTRAL_API_BASE_URL
|
|
|
+ )
|
|
|
+ request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = (
|
|
|
+ form_data.stt.MISTRAL_USE_CHAT_COMPLETIONS
|
|
|
+ )
|
|
|
|
|
|
if request.app.state.config.STT_ENGINE == "":
|
|
|
request.app.state.faster_whisper_model = set_faster_whisper_model(
|
|
|
@@ -290,6 +304,9 @@ async def update_audio_config(
|
|
|
"AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
|
|
|
"AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
|
|
|
"AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
|
|
|
+ "MISTRAL_API_KEY": request.app.state.config.AUDIO_STT_MISTRAL_API_KEY,
|
|
|
+ "MISTRAL_API_BASE_URL": request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL,
|
|
|
+ "MISTRAL_USE_CHAT_COMPLETIONS": request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS,
|
|
|
},
|
|
|
}
|
|
|
|
|
|
@@ -828,6 +845,184 @@ def transcription_handler(request, file_path, metadata):
|
|
|
detail=detail if detail else "Open WebUI: Server Connection Error",
|
|
|
)
|
|
|
|
|
|
+ elif request.app.state.config.STT_ENGINE == "mistral":
|
|
|
+ # Check file exists
|
|
|
+ if not os.path.exists(file_path):
|
|
|
+ raise HTTPException(status_code=400, detail="Audio file not found")
|
|
|
+
|
|
|
+ # Check file size
|
|
|
+ file_size = os.path.getsize(file_path)
|
|
|
+ if file_size > MAX_FILE_SIZE:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=400,
|
|
|
+ detail=f"File size exceeds limit of {MAX_FILE_SIZE_MB}MB",
|
|
|
+ )
|
|
|
+
|
|
|
+ api_key = request.app.state.config.AUDIO_STT_MISTRAL_API_KEY
|
|
|
+ api_base_url = (
|
|
|
+ request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL
|
|
|
+ or "https://api.mistral.ai/v1"
|
|
|
+ )
|
|
|
+ use_chat_completions = (
|
|
|
+ request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS
|
|
|
+ )
|
|
|
+
|
|
|
+ if not api_key:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=400,
|
|
|
+ detail="Mistral API key is required for Mistral STT",
|
|
|
+ )
|
|
|
+
|
|
|
+ r = None
|
|
|
+ try:
|
|
|
+ # Use voxtral-mini-latest as the default model for transcription
|
|
|
+ model = request.app.state.config.STT_MODEL or "voxtral-mini-latest"
|
|
|
+
|
|
|
+ log.info(
|
|
|
+ f"Mistral STT - model: {model}, "
|
|
|
+ f"method: {'chat_completions' if use_chat_completions else 'transcriptions'}"
|
|
|
+ )
|
|
|
+
|
|
|
+ if use_chat_completions:
|
|
|
+ # Use chat completions API with audio input
|
|
|
+ # This method requires mp3 or wav format
|
|
|
+ audio_file_to_use = file_path
|
|
|
+
|
|
|
+ if is_audio_conversion_required(file_path):
|
|
|
+ log.debug("Converting audio to mp3 for chat completions API")
|
|
|
+ converted_path = convert_audio_to_mp3(file_path)
|
|
|
+ if converted_path:
|
|
|
+ audio_file_to_use = converted_path
|
|
|
+ else:
|
|
|
+ log.error("Audio conversion failed")
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=500,
|
|
|
+ detail="Audio conversion failed. Chat completions API requires mp3 or wav format.",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Read and encode audio file as base64
|
|
|
+ with open(audio_file_to_use, "rb") as audio_file:
|
|
|
+ audio_base64 = base64.b64encode(audio_file.read()).decode('utf-8')
|
|
|
+
|
|
|
+ # Prepare chat completions request
|
|
|
+ url = f"{api_base_url}/chat/completions"
|
|
|
+
|
|
|
+ # Add language instruction if specified
|
|
|
+ language = metadata.get("language", None) if metadata else None
|
|
|
+ if language:
|
|
|
+ text_instruction = f"Transcribe this audio exactly as spoken in {language}. Do not translate it."
|
|
|
+ else:
|
|
|
+ text_instruction = "Transcribe this audio exactly as spoken in its original language. Do not translate it to another language."
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "model": model,
|
|
|
+ "messages": [
|
|
|
+ {
|
|
|
+ "role": "user",
|
|
|
+ "content": [
|
|
|
+ {
|
|
|
+ "type": "input_audio",
|
|
|
+ "input_audio": audio_base64,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "type": "text",
|
|
|
+ "text": text_instruction
|
|
|
+ }
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ]
|
|
|
+ }
|
|
|
+
|
|
|
+ r = requests.post(
|
|
|
+ url=url,
|
|
|
+ json=payload,
|
|
|
+ headers={
|
|
|
+ "Authorization": f"Bearer {api_key}",
|
|
|
+ "Content-Type": "application/json",
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ r.raise_for_status()
|
|
|
+ response = r.json()
|
|
|
+
|
|
|
+ # Extract transcript from chat completion response
|
|
|
+ transcript = response.get("choices", [{}])[0].get("message", {}).get("content", "").strip()
|
|
|
+ if not transcript:
|
|
|
+ raise ValueError("Empty transcript in response")
|
|
|
+
|
|
|
+ data = {"text": transcript}
|
|
|
+
|
|
|
+ else:
|
|
|
+ # Use dedicated transcriptions API
|
|
|
+ url = f"{api_base_url}/audio/transcriptions"
|
|
|
+
|
|
|
+ # Determine the MIME type
|
|
|
+ mime_type, _ = mimetypes.guess_type(file_path)
|
|
|
+ if not mime_type:
|
|
|
+ mime_type = "audio/webm"
|
|
|
+
|
|
|
+ # Use context manager to ensure file is properly closed
|
|
|
+ with open(file_path, "rb") as audio_file:
|
|
|
+ files = {"file": (filename, audio_file, mime_type)}
|
|
|
+ data_form = {"model": model}
|
|
|
+
|
|
|
+ # Add language if specified in metadata
|
|
|
+ language = metadata.get("language", None) if metadata else None
|
|
|
+ if language:
|
|
|
+ data_form["language"] = language
|
|
|
+
|
|
|
+ r = requests.post(
|
|
|
+ url=url,
|
|
|
+ files=files,
|
|
|
+ data=data_form,
|
|
|
+ headers={
|
|
|
+ "Authorization": f"Bearer {api_key}",
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ r.raise_for_status()
|
|
|
+ response = r.json()
|
|
|
+
|
|
|
+ # Extract transcript from response
|
|
|
+ transcript = response.get("text", "").strip()
|
|
|
+ if not transcript:
|
|
|
+ raise ValueError("Empty transcript in response")
|
|
|
+
|
|
|
+ data = {"text": transcript}
|
|
|
+
|
|
|
+ # Save transcript to json file (consistent with other providers)
|
|
|
+ transcript_file = f"{file_dir}/{id}.json"
|
|
|
+ with open(transcript_file, "w") as f:
|
|
|
+ json.dump(data, f)
|
|
|
+
|
|
|
+ log.debug(data)
|
|
|
+ return data
|
|
|
+
|
|
|
+ except ValueError as e:
|
|
|
+ log.exception("Error parsing Mistral response")
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=500,
|
|
|
+ detail=f"Failed to parse Mistral response: {str(e)}",
|
|
|
+ )
|
|
|
+ except requests.exceptions.RequestException as e:
|
|
|
+ log.exception(e)
|
|
|
+ detail = None
|
|
|
+
|
|
|
+ try:
|
|
|
+ if r is not None and r.status_code != 200:
|
|
|
+ res = r.json()
|
|
|
+ if "error" in res:
|
|
|
+ detail = f"External: {res['error'].get('message', '')}"
|
|
|
+ else:
|
|
|
+ detail = f"External: {r.text}"
|
|
|
+ except Exception:
|
|
|
+ detail = f"External: {e}"
|
|
|
+
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=getattr(r, "status_code", 500) if r else 500,
|
|
|
+ detail=detail if detail else "Open WebUI: Server Connection Error",
|
|
|
+ )
|
|
|
+
|
|
|
|
|
|
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
|
|
|
log.info(f"transcribe: {file_path} {metadata}")
|