audio.py

import hashlib
import json
import logging
import os
import uuid
from functools import lru_cache
from pathlib import Path
from pydub import AudioSegment
from pydub.silence import split_on_silence
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from fnmatch import fnmatch

import aiohttp
import aiofiles
import requests
import mimetypes

from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
    status,
    APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse

from pydantic import BaseModel

from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import (
    WHISPER_MODEL_AUTO_UPDATE,
    WHISPER_MODEL_DIR,
    CACHE_DIR,
    WHISPER_LANGUAGE,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    ENV,
    SRC_LOG_LEVELS,
    DEVICE_TYPE,
    ENABLE_FORWARD_USER_INFO_HEADERS,
)

router = APIRouter()

# Constants
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
AZURE_MAX_FILE_SIZE_MB = 200
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)


##########################################
#
# Utility functions
#
##########################################

from pydub.utils import mediainfo


def is_audio_conversion_required(file_path):
    """
    Check if the given audio file needs conversion to mp3.
    """
    SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}

    if not os.path.isfile(file_path):
        log.error(f"File not found: {file_path}")
        return False

    try:
        info = mediainfo(file_path)
        codec_name = info.get("codec_name", "").lower()
        codec_type = info.get("codec_type", "").lower()
        codec_tag_string = info.get("codec_tag_string", "").lower()

        if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
            # File is AAC/mp4a audio, recommend mp3 conversion
            return True

        # If the codec name is in the supported formats
        if codec_name in SUPPORTED_FORMATS:
            return False

        return True
    except Exception as e:
        log.error(f"Error getting audio format: {e}")
        return False


def convert_audio_to_mp3(file_path):
    """Convert audio file to mp3 format."""
    try:
        output_path = os.path.splitext(file_path)[0] + ".mp3"
        audio = AudioSegment.from_file(file_path)
        audio.export(output_path, format="mp3")
        log.info(f"Converted {file_path} to {output_path}")
        return output_path
    except Exception as e:
        log.error(f"Error converting audio file: {e}")
        return None
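
# Both helpers above rely on pydub, which shells out to ffmpeg/ffprobe (or avconv/avprobe);
# those binaries must be available on PATH for probing via mediainfo() and for format
# conversion to work.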


def set_faster_whisper_model(model: str, auto_update: bool = False):
    whisper_model = None
    if model:
        from faster_whisper import WhisperModel

        faster_whisper_kwargs = {
            "model_size_or_path": model,
            "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
            "compute_type": "int8",
            "download_root": WHISPER_MODEL_DIR,
            "local_files_only": not auto_update,
        }

        try:
            whisper_model = WhisperModel(**faster_whisper_kwargs)
        except Exception:
            log.warning(
                "WhisperModel initialization failed, attempting download with local_files_only=False"
            )
            faster_whisper_kwargs["local_files_only"] = False
            whisper_model = WhisperModel(**faster_whisper_kwargs)
    return whisper_model
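
# Note: if the model is not already present in WHISPER_MODEL_DIR, the first attempt with
# local_files_only=True fails and the helper retries once with local_files_only=False so
# the weights can be downloaded. A minimal sketch of eager initialization (hypothetical;
# the router normally initializes this lazily inside transcription_handler):
#
#     app.state.faster_whisper_model = set_faster_whisper_model("base", auto_update=True)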


##########################################
#
# Audio API
#
##########################################


class TTSConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    API_KEY: str
    ENGINE: str
    MODEL: str
    VOICE: str
    SPLIT_ON: str
    AZURE_SPEECH_REGION: str
    AZURE_SPEECH_BASE_URL: str
    AZURE_SPEECH_OUTPUT_FORMAT: str


class STTConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    ENGINE: str
    MODEL: str
    SUPPORTED_CONTENT_TYPES: list[str] = []
    WHISPER_MODEL: str
    DEEPGRAM_API_KEY: str
    AZURE_API_KEY: str
    AZURE_REGION: str
    AZURE_LOCALES: str
    AZURE_BASE_URL: str
    AZURE_MAX_SPEAKERS: str


class AudioConfigUpdateForm(BaseModel):
    tts: TTSConfigForm
    stt: STTConfigForm


@router.get("/config")
async def get_audio_config(request: Request, user=Depends(get_admin_user)):
    return {
        "tts": {
            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": request.app.state.config.TTS_API_KEY,
            "ENGINE": request.app.state.config.TTS_ENGINE,
            "MODEL": request.app.state.config.TTS_MODEL,
            "VOICE": request.app.state.config.TTS_VOICE,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": request.app.state.config.STT_ENGINE,
            "MODEL": request.app.state.config.STT_MODEL,
            "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }


@router.post("/config/update")
async def update_audio_config(
    request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
    request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
    request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
    request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
    request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
    request.app.state.config.TTS_MODEL = form_data.tts.MODEL
    request.app.state.config.TTS_VOICE = form_data.tts.VOICE
    request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
    request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
    request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = (
        form_data.tts.AZURE_SPEECH_BASE_URL
    )
    request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
        form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
    )

    request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
    request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
    request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
    request.app.state.config.STT_MODEL = form_data.stt.MODEL
    request.app.state.config.STT_SUPPORTED_CONTENT_TYPES = (
        form_data.stt.SUPPORTED_CONTENT_TYPES
    )
    request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
    request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
    request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
    request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
    request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
        form_data.stt.AZURE_MAX_SPEAKERS
    )

    if request.app.state.config.STT_ENGINE == "":
        request.app.state.faster_whisper_model = set_faster_whisper_model(
            form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
        )
    else:
        request.app.state.faster_whisper_model = None

    return {
        "tts": {
            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": request.app.state.config.TTS_API_KEY,
            "ENGINE": request.app.state.config.TTS_ENGINE,
            "MODEL": request.app.state.config.TTS_MODEL,
            "VOICE": request.app.state.config.TTS_VOICE,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": request.app.state.config.STT_ENGINE,
            "MODEL": request.app.state.config.STT_MODEL,
            "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }


def load_speech_pipeline(request):
    from transformers import pipeline
    from datasets import load_dataset

    if request.app.state.speech_synthesiser is None:
        request.app.state.speech_synthesiser = pipeline(
            "text-to-speech", "microsoft/speecht5_tts"
        )

    if request.app.state.speech_speaker_embeddings_dataset is None:
        request.app.state.speech_speaker_embeddings_dataset = load_dataset(
            "Matthijs/cmu-arctic-xvectors", split="validation"
        )
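
# The /speech endpoint below caches synthesized audio: the mp3 and the request payload are
# stored under a SHA-256 of (request body + TTS_ENGINE + TTS_MODEL), so repeated requests
# with the same text, engine, and model are served straight from SPEECH_CACHE_DIR.
# Example payload (hypothetical values): {"input": "Hello there", "voice": "alloy", "model": "tts-1"}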


@router.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    body = await request.body()
    name = hashlib.sha256(
        body
        + str(request.app.state.config.TTS_ENGINE).encode("utf-8")
        + str(request.app.state.config.TTS_MODEL).encode("utf-8")
    ).hexdigest()

    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    payload = None
    try:
        payload = json.loads(body.decode("utf-8"))
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=400, detail="Invalid JSON payload")

    if request.app.state.config.TTS_ENGINE == "openai":
        payload["model"] = request.app.state.config.TTS_MODEL

        r = None
        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                    json=payload,
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
                        **(
                            {
                                "X-OpenWebUI-User-Name": user.name,
                                "X-OpenWebUI-User-Id": user.id,
                                "X-OpenWebUI-User-Email": user.email,
                                "X-OpenWebUI-User-Role": user.role,
                            }
                            if ENABLE_FORWARD_USER_INFO_HEADERS
                            else {}
                        ),
                    },
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )

    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        voice_id = payload.get("voice", "")

        if voice_id not in get_available_voices(request):
            raise HTTPException(
                status_code=400,
                detail="Invalid voice id",
            )

        r = None
        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
                    json={
                        "text": payload["input"],
                        "model_id": request.app.state.config.TTS_MODEL,
                        "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
                    },
                    headers={
                        "Accept": "audio/mpeg",
                        "Content-Type": "application/json",
                        "xi-api-key": request.app.state.config.TTS_API_KEY,
                    },
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )

    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        region = request.app.state.config.TTS_AZURE_SPEECH_REGION or "eastus"
        base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
        language = request.app.state.config.TTS_VOICE
        locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
        output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT

        r = None
        try:
            data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
                <voice name="{language}">{payload["input"]}</voice>
            </speak>"""
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    (base_url or f"https://{region}.tts.speech.microsoft.com")
                    + "/cognitiveservices/v1",
                    headers={
                        "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
                        "Content-Type": "application/ssml+xml",
                        "X-Microsoft-OutputFormat": output_format,
                    },
                    data=data,
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )

    elif request.app.state.config.TTS_ENGINE == "transformers":
        payload = None
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        import torch
        import soundfile as sf

        load_speech_pipeline(request)

        embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset

        speaker_index = 6799
        try:
            speaker_index = embeddings_dataset["filename"].index(
                request.app.state.config.TTS_MODEL
            )
        except Exception:
            pass

        speaker_embedding = torch.tensor(
            embeddings_dataset[speaker_index]["xvector"]
        ).unsqueeze(0)

        speech = request.app.state.speech_synthesiser(
            payload["input"],
            forward_params={"speaker_embeddings": speaker_embedding},
        )

        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])

        async with aiofiles.open(file_body_path, "w") as f:
            await f.write(json.dumps(payload))

        return FileResponse(file_path)


def transcription_handler(request, file_path, metadata):
    filename = os.path.basename(file_path)
    file_dir = os.path.dirname(file_path)
    id = filename.split(".")[0]

    metadata = metadata or {}

    if request.app.state.config.STT_ENGINE == "":
        if request.app.state.faster_whisper_model is None:
            request.app.state.faster_whisper_model = set_faster_whisper_model(
                request.app.state.config.WHISPER_MODEL
            )

        model = request.app.state.faster_whisper_model
        segments, info = model.transcribe(
            file_path,
            beam_size=5,
            vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
            language=metadata.get("language") or WHISPER_LANGUAGE,
        )
        log.info(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )

        transcript = "".join([segment.text for segment in list(segments)])
        data = {"text": transcript.strip()}

        # save the transcript to a json file
        transcript_file = f"{file_dir}/{id}.json"
        with open(transcript_file, "w") as f:
            json.dump(data, f)

        log.debug(data)
        return data
    elif request.app.state.config.STT_ENGINE == "openai":
        r = None
        try:
            r = requests.post(
                url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
                headers={
                    "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
                },
                files={"file": (filename, open(file_path, "rb"))},
                data={
                    "model": request.app.state.config.STT_MODEL,
                    **(
                        {"language": metadata.get("language")}
                        if metadata.get("language")
                        else {}
                    ),
                },
            )

            r.raise_for_status()
            data = r.json()

            # save the transcript to a json file
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            return data
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
                except Exception:
                    detail = f"External: {e}"

            raise Exception(detail if detail else "Open WebUI: Server Connection Error")

    elif request.app.state.config.STT_ENGINE == "deepgram":
        r = None
        try:
            # Determine the MIME type of the file
            mime, _ = mimetypes.guess_type(file_path)
            if not mime:
                mime = "audio/wav"  # fallback to wav if undetectable

            # Read the audio file
            with open(file_path, "rb") as f:
                file_data = f.read()

            # Build headers and parameters
            headers = {
                "Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
                "Content-Type": mime,
            }

            # Add model if specified
            params = {}
            if request.app.state.config.STT_MODEL:
                params["model"] = request.app.state.config.STT_MODEL

            # Make request to Deepgram API
            r = requests.post(
                "https://api.deepgram.com/v1/listen?smart_format=true",
                headers=headers,
                params=params,
                data=file_data,
            )
            r.raise_for_status()
            response_data = r.json()

            # Extract transcript from Deepgram response
            try:
                transcript = response_data["results"]["channels"][0]["alternatives"][
                    0
                ].get("transcript", "")
            except (KeyError, IndexError) as e:
                log.error(f"Malformed response from Deepgram: {str(e)}")
                raise Exception(
                    "Failed to parse Deepgram response - unexpected response format"
                )

            data = {"text": transcript.strip()}

            # Save transcript
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            return data
        except Exception as e:
            log.exception(e)
            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
                except Exception:
                    detail = f"External: {e}"
            raise Exception(detail if detail else "Open WebUI: Server Connection Error")

    elif request.app.state.config.STT_ENGINE == "azure":
        # Check file exists and size
        if not os.path.exists(file_path):
            raise HTTPException(status_code=400, detail="Audio file not found")

        # Check file size (Azure has a larger limit of 200MB)
        file_size = os.path.getsize(file_path)
        if file_size > AZURE_MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
            )

        api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
        region = request.app.state.config.AUDIO_STT_AZURE_REGION or "eastus"
        locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
        base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
        max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3

        # If no locales are configured, fall back to a default set
        if len(locales) < 2:
            locales = [
                "en-US",
                "es-ES",
                "es-MX",
                "fr-FR",
                "hi-IN",
                "it-IT",
                "de-DE",
                "en-GB",
                "en-IN",
                "ja-JP",
                "ko-KR",
                "pt-BR",
                "zh-CN",
            ]
            locales = ",".join(locales)

        if not api_key or not region:
            raise HTTPException(
                status_code=400,
                detail="Azure API key is required for Azure STT",
            )

        r = None
        try:
            # Prepare the request
            data = {
                "definition": json.dumps(
                    {
                        "locales": locales.split(","),
                        "diarization": {"maxSpeakers": max_speakers, "enabled": True},
                    }
                    if locales
                    else {}
                )
            }
            url = (
                base_url or f"https://{region}.api.cognitive.microsoft.com"
            ) + "/speechtotext/transcriptions:transcribe?api-version=2024-11-15"

            # Use context manager to ensure file is properly closed
            with open(file_path, "rb") as audio_file:
                r = requests.post(
                    url=url,
                    files={"audio": audio_file},
                    data=data,
                    headers={
                        "Ocp-Apim-Subscription-Key": api_key,
                    },
                )

            r.raise_for_status()
            response = r.json()

            # Extract transcript from response
            if not response.get("combinedPhrases"):
                raise ValueError("No transcription found in response")

            # Get the full transcript from combinedPhrases
            transcript = response["combinedPhrases"][0].get("text", "").strip()
            if not transcript:
                raise ValueError("Empty transcript in response")

            data = {"text": transcript}

            # Save transcript to json file (consistent with other providers)
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            log.debug(data)
            return data
        except (KeyError, IndexError, ValueError) as e:
            log.exception("Error parsing Azure response")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to parse Azure response: {str(e)}",
            )
        except requests.exceptions.RequestException as e:
            log.exception(e)
            detail = None

            try:
                if r is not None and r.status_code != 200:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status_code", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
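
# Each engine branch above writes its result to "<file_dir>/<id>.json" next to the audio
# file and returns a dict containing the transcript text, so transcribe() below can treat
# all engines uniformly when merging per-chunk results via result["text"].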


def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
    log.info(f"transcribe: {file_path} {metadata}")

    if is_audio_conversion_required(file_path):
        file_path = convert_audio_to_mp3(file_path)

    try:
        file_path = compress_audio(file_path)
    except Exception as e:
        log.exception(e)

    # Always produce a list of chunk paths (could be one entry if small)
    try:
        chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
        log.debug(f"Chunk paths: {chunk_paths}")
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    results = []
    try:
        with ThreadPoolExecutor() as executor:
            # Submit tasks for each chunk_path
            futures = [
                executor.submit(transcription_handler, request, chunk_path, metadata)
                for chunk_path in chunk_paths
            ]
            # Gather results as they complete
            for future in futures:
                try:
                    results.append(future.result())
                except Exception as transcribe_exc:
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail=f"Error transcribing chunk: {transcribe_exc}",
                    )
    finally:
        # Clean up only the temporary chunks, never the original file
        for chunk_path in chunk_paths:
            if chunk_path != file_path and os.path.isfile(chunk_path):
                try:
                    os.remove(chunk_path)
                except Exception:
                    pass

    return {
        "text": " ".join([result["text"] for result in results]),
    }


def compress_audio(file_path):
    if os.path.getsize(file_path) > MAX_FILE_SIZE:
        id = os.path.splitext(os.path.basename(file_path))[
            0
        ]  # Handles names with multiple dots
        file_dir = os.path.dirname(file_path)

        audio = AudioSegment.from_file(file_path)
        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio

        compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
        audio.export(compressed_path, format="mp3", bitrate="32k")
        log.debug(f"Compressed audio to {compressed_path}")
        return compressed_path
    else:
        return file_path


def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
    """
    Splits audio into chunks not exceeding max_bytes.
    Returns a list of chunk file paths. If audio fits, returns list with original path.
    """
    file_size = os.path.getsize(file_path)
    if file_size <= max_bytes:
        return [file_path]  # Nothing to split

    audio = AudioSegment.from_file(file_path)
    duration_ms = len(audio)
    orig_size = file_size

    approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
    chunks = []
    start = 0
    i = 0

    base, _ = os.path.splitext(file_path)

    while start < duration_ms:
        end = min(start + approx_chunk_ms, duration_ms)
        chunk = audio[start:end]
        chunk_path = f"{base}_chunk_{i}.{format}"
        chunk.export(chunk_path, format=format, bitrate=bitrate)

        # Reduce chunk duration if still too large
        while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
            end = start + ((end - start) // 2)
            chunk = audio[start:end]
            chunk.export(chunk_path, format=format, bitrate=bitrate)

        if os.path.getsize(chunk_path) > max_bytes:
            os.remove(chunk_path)
            raise Exception("Audio chunk cannot be reduced below max file size.")

        chunks.append(chunk_path)
        start = end
        i += 1

    return chunks
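
# Rough usage sketch (hypothetical path and size): for a recording larger than
# MAX_FILE_SIZE, compress_audio() first re-encodes it to 16 kHz mono 32 kbps mp3; if the
# result is still over the limit, split_audio() exports "<name>_chunk_0.mp3",
# "<name>_chunk_1.mp3", ... each at or under max_bytes:
#
#     chunk_paths = split_audio(compress_audio("/tmp/long_recording.webm"), MAX_FILE_SIZE)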


@router.post("/transcriptions")
def transcription(
    request: Request,
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    user=Depends(get_verified_user),
):
    log.info(f"file.content_type: {file.content_type}")

    supported_content_types = request.app.state.config.STT_SUPPORTED_CONTENT_TYPES or [
        "audio/*",
        "video/webm",
    ]

    if not any(
        fnmatch(file.content_type, content_type)
        for content_type in supported_content_types
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    try:
        ext = file.filename.split(".")[-1]
        id = uuid.uuid4()

        filename = f"{id}.{ext}"
        contents = file.file.read()

        file_dir = f"{CACHE_DIR}/audio/transcriptions"
        os.makedirs(file_dir, exist_ok=True)
        file_path = f"{file_dir}/{filename}"

        with open(file_path, "wb") as f:
            f.write(contents)

        try:
            metadata = None

            if language:
                metadata = {"language": language}

            result = transcribe(request, file_path, metadata)

            return {
                **result,
                "filename": os.path.basename(file_path),
            }
        except Exception as e:
            log.exception(e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


def get_available_models(request: Request) -> list[dict]:
    available_models = []
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
                )
                response.raise_for_status()
                data = response.json()
                available_models = data.get("models", [])
            except Exception as e:
                log.error(f"Error fetching models from custom endpoint: {str(e)}")
                available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
        else:
            available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            response = requests.get(
                "https://api.elevenlabs.io/v1/models",
                headers={
                    "xi-api-key": request.app.state.config.TTS_API_KEY,
                    "Content-Type": "application/json",
                },
                timeout=5,
            )
            response.raise_for_status()
            models = response.json()

            available_models = [
                {"name": model["name"], "id": model["model_id"]} for model in models
            ]
        except requests.RequestException as e:
            log.error(f"Error fetching models: {str(e)}")

    return available_models


@router.get("/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
    return {"models": get_available_models(request)}


def get_available_voices(request) -> dict:
    """Returns {voice_id: voice_name} dict"""
    available_voices = {}
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
                )
                response.raise_for_status()
                data = response.json()
                voices_list = data.get("voices", [])
                available_voices = {voice["id"]: voice["name"] for voice in voices_list}
            except Exception as e:
                log.error(f"Error fetching voices from custom endpoint: {str(e)}")
                available_voices = {
                    "alloy": "alloy",
                    "echo": "echo",
                    "fable": "fable",
                    "onyx": "onyx",
                    "nova": "nova",
                    "shimmer": "shimmer",
                }
        else:
            available_voices = {
                "alloy": "alloy",
                "echo": "echo",
                "fable": "fable",
                "onyx": "onyx",
                "nova": "nova",
                "shimmer": "shimmer",
            }
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            available_voices = get_elevenlabs_voices(
                api_key=request.app.state.config.TTS_API_KEY
            )
        except Exception:
            # Avoided @lru_cache with exception
            pass
    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            region = request.app.state.config.TTS_AZURE_SPEECH_REGION
            base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
            url = (
                base_url or f"https://{region}.tts.speech.microsoft.com"
            ) + "/cognitiveservices/voices/list"
            headers = {
                "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
            }

            response = requests.get(url, headers=headers)
            response.raise_for_status()
            voices = response.json()

            for voice in voices:
                available_voices[voice["ShortName"]] = (
                    f"{voice['DisplayName']} ({voice['ShortName']})"
                )
        except requests.RequestException as e:
            log.error(f"Error fetching voices: {str(e)}")

    return available_voices
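
# get_elevenlabs_voices below is cached per API key via lru_cache; on a failed request it
# raises instead of returning an empty dict so a transient error is not cached permanently
# (get_available_voices above catches the exception and simply leaves the voice list empty).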


@lru_cache
def get_elevenlabs_voices(api_key: str) -> dict:
    """
    Note, set the following in your .env file to use Elevenlabs:
    AUDIO_TTS_ENGINE=elevenlabs
    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
    AUDIO_TTS_MODEL=eleven_multilingual_v2
    """

    try:
        # TODO: Add retries
        response = requests.get(
            "https://api.elevenlabs.io/v1/voices",
            headers={
                "xi-api-key": api_key,
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()
        voices_data = response.json()

        voices = {}
        for voice in voices_data.get("voices", []):
            voices[voice["voice_id"]] = voice["name"]
    except requests.RequestException as e:
        # Avoid @lru_cache with exception
        log.error(f"Error fetching voices: {str(e)}")
        raise RuntimeError(f"Error fetching voices: {str(e)}")

    return voices


@router.get("/voices")
async def get_voices(request: Request, user=Depends(get_verified_user)):
    return {
        "voices": [
            {"id": k, "name": v} for k, v in get_available_voices(request).items()
        ]
    }