audio.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159
  1. import hashlib
  2. import json
  3. import logging
  4. import os
  5. import uuid
  6. from functools import lru_cache
  7. from pydub import AudioSegment
  8. from pydub.silence import split_on_silence
  9. from concurrent.futures import ThreadPoolExecutor
  10. from typing import Optional
  11. from fnmatch import fnmatch
  12. import aiohttp
  13. import aiofiles
  14. import requests
  15. import mimetypes
  16. from urllib.parse import urljoin, quote
  17. from fastapi import (
  18. Depends,
  19. FastAPI,
  20. File,
  21. Form,
  22. HTTPException,
  23. Request,
  24. UploadFile,
  25. status,
  26. APIRouter,
  27. )
  28. from fastapi.middleware.cors import CORSMiddleware
  29. from fastapi.responses import FileResponse
  30. from pydantic import BaseModel
  31. from open_webui.utils.auth import get_admin_user, get_verified_user
  32. from open_webui.config import (
  33. WHISPER_MODEL_AUTO_UPDATE,
  34. WHISPER_MODEL_DIR,
  35. CACHE_DIR,
  36. WHISPER_LANGUAGE,
  37. )
  38. from open_webui.constants import ERROR_MESSAGES
  39. from open_webui.env import (
  40. AIOHTTP_CLIENT_SESSION_SSL,
  41. AIOHTTP_CLIENT_TIMEOUT,
  42. ENV,
  43. SRC_LOG_LEVELS,
  44. DEVICE_TYPE,
  45. ENABLE_FORWARD_USER_INFO_HEADERS,
  46. )
  47. router = APIRouter()
  48. # Constants
  49. MAX_FILE_SIZE_MB = 20
  50. MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
  51. AZURE_MAX_FILE_SIZE_MB = 200
  52. AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
  53. log = logging.getLogger(__name__)
  54. log.setLevel(SRC_LOG_LEVELS["AUDIO"])
  55. SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
  56. SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  57. ##########################################
  58. #
  59. # Utility functions
  60. #
  61. ##########################################
  62. from pydub import AudioSegment
  63. from pydub.utils import mediainfo
  64. def is_audio_conversion_required(file_path):
  65. """
  66. Check if the given audio file needs conversion to mp3.
  67. """
  68. SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}
  69. if not os.path.isfile(file_path):
  70. log.error(f"File not found: {file_path}")
  71. return False
  72. try:
  73. info = mediainfo(file_path)
  74. codec_name = info.get("codec_name", "").lower()
  75. codec_type = info.get("codec_type", "").lower()
  76. codec_tag_string = info.get("codec_tag_string", "").lower()
  77. if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
  78. # File is AAC/mp4a audio, recommend mp3 conversion
  79. return True
  80. # If the codec name is in the supported formats
  81. if codec_name in SUPPORTED_FORMATS:
  82. return False
  83. return True
  84. except Exception as e:
  85. log.error(f"Error getting audio format: {e}")
  86. return False
  87. def convert_audio_to_mp3(file_path):
  88. """Convert audio file to mp3 format."""
  89. try:
  90. output_path = os.path.splitext(file_path)[0] + ".mp3"
  91. audio = AudioSegment.from_file(file_path)
  92. audio.export(output_path, format="mp3")
  93. log.info(f"Converted {file_path} to {output_path}")
  94. return output_path
  95. except Exception as e:
  96. log.error(f"Error converting audio file: {e}")
  97. return None
  98. def set_faster_whisper_model(model: str, auto_update: bool = False):
  99. whisper_model = None
  100. if model:
  101. from faster_whisper import WhisperModel
  102. faster_whisper_kwargs = {
  103. "model_size_or_path": model,
  104. "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
  105. "compute_type": "int8",
  106. "download_root": WHISPER_MODEL_DIR,
  107. "local_files_only": not auto_update,
  108. }
  109. try:
  110. whisper_model = WhisperModel(**faster_whisper_kwargs)
  111. except Exception:
  112. log.warning(
  113. "WhisperModel initialization failed, attempting download with local_files_only=False"
  114. )
  115. faster_whisper_kwargs["local_files_only"] = False
  116. whisper_model = WhisperModel(**faster_whisper_kwargs)
  117. return whisper_model
  118. ##########################################
  119. #
  120. # Audio API
  121. #
  122. ##########################################
class TTSConfigForm(BaseModel):
    """Text-to-speech settings accepted by ``POST /config/update``.

    ``update_audio_config`` copies each field onto the matching ``TTS_*``
    attribute of ``request.app.state.config`` (``API_KEY`` -> ``TTS_API_KEY``).
    """

    OPENAI_API_BASE_URL: str  # base URL of the OpenAI-compatible speech API ("openai" engine)
    OPENAI_API_KEY: str  # bearer token for that API
    API_KEY: str  # sent to ElevenLabs ("xi-api-key") and Azure ("Ocp-Apim-Subscription-Key")
    ENGINE: str  # speech() handles "openai", "elevenlabs", "azure" and "transformers"
    MODEL: str  # provider model id; for "transformers", a speaker filename in the embeddings dataset
    VOICE: str  # Azure voice name; its first "-"-separated part is used as xml:lang
    SPLIT_ON: str  # stored and echoed back; not read elsewhere in the visible code
    AZURE_SPEECH_REGION: str  # speech() falls back to "eastus" when empty
    AZURE_SPEECH_BASE_URL: str  # replaces the region-derived Azure endpoint when set
    AZURE_SPEECH_OUTPUT_FORMAT: str  # sent as the X-Microsoft-OutputFormat header
class STTConfigForm(BaseModel):
    """Speech-to-text settings accepted by ``POST /config/update``.

    ``update_audio_config`` copies these onto ``request.app.state.config``
    (``STT_*``, ``WHISPER_MODEL``, ``DEEPGRAM_API_KEY`` and ``AUDIO_STT_AZURE_*``).
    """

    OPENAI_API_BASE_URL: str  # base URL of the OpenAI-compatible transcription API ("openai" engine)
    OPENAI_API_KEY: str  # bearer token for that API
    ENGINE: str  # "" (local faster-whisper), "openai", "deepgram" or "azure"
    MODEL: str  # model name sent to the OpenAI-compatible and Deepgram APIs
    SUPPORTED_CONTENT_TYPES: list[str] = []  # optional; empty list when omitted
    WHISPER_MODEL: str  # faster-whisper model loaded when ENGINE is ""
    DEEPGRAM_API_KEY: str  # sent to Deepgram as "Authorization: Token ..."
    AZURE_API_KEY: str  # Azure subscription key; required for the "azure" engine
    AZURE_REGION: str  # falls back to "eastus" when empty
    AZURE_LOCALES: str  # comma-separated; fewer than 2 characters selects a built-in list
    AZURE_BASE_URL: str  # replaces the region-derived Azure endpoint when set
    AZURE_MAX_SPEAKERS: str  # diarization speaker limit; 3 when empty
class AudioConfigUpdateForm(BaseModel):
    """Request body of ``POST /config/update``: both halves of the audio config."""

    tts: TTSConfigForm  # text-to-speech settings
    stt: STTConfigForm  # speech-to-text settings
  150. @router.get("/config")
  151. async def get_audio_config(request: Request, user=Depends(get_admin_user)):
  152. return {
  153. "tts": {
  154. "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
  155. "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
  156. "API_KEY": request.app.state.config.TTS_API_KEY,
  157. "ENGINE": request.app.state.config.TTS_ENGINE,
  158. "MODEL": request.app.state.config.TTS_MODEL,
  159. "VOICE": request.app.state.config.TTS_VOICE,
  160. "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
  161. "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
  162. "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
  163. "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
  164. },
  165. "stt": {
  166. "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
  167. "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
  168. "ENGINE": request.app.state.config.STT_ENGINE,
  169. "MODEL": request.app.state.config.STT_MODEL,
  170. "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
  171. "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
  172. "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
  173. "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
  174. "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
  175. "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
  176. "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
  177. "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
  178. },
  179. }
  180. @router.post("/config/update")
  181. async def update_audio_config(
  182. request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
  183. ):
  184. request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
  185. request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
  186. request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
  187. request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
  188. request.app.state.config.TTS_MODEL = form_data.tts.MODEL
  189. request.app.state.config.TTS_VOICE = form_data.tts.VOICE
  190. request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
  191. request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
  192. request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = (
  193. form_data.tts.AZURE_SPEECH_BASE_URL
  194. )
  195. request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
  196. form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
  197. )
  198. request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
  199. request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
  200. request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
  201. request.app.state.config.STT_MODEL = form_data.stt.MODEL
  202. request.app.state.config.STT_SUPPORTED_CONTENT_TYPES = (
  203. form_data.stt.SUPPORTED_CONTENT_TYPES
  204. )
  205. request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
  206. request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
  207. request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
  208. request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
  209. request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
  210. request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
  211. request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
  212. form_data.stt.AZURE_MAX_SPEAKERS
  213. )
  214. if request.app.state.config.STT_ENGINE == "":
  215. request.app.state.faster_whisper_model = set_faster_whisper_model(
  216. form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
  217. )
  218. else:
  219. request.app.state.faster_whisper_model = None
  220. return {
  221. "tts": {
  222. "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
  223. "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
  224. "API_KEY": request.app.state.config.TTS_API_KEY,
  225. "ENGINE": request.app.state.config.TTS_ENGINE,
  226. "MODEL": request.app.state.config.TTS_MODEL,
  227. "VOICE": request.app.state.config.TTS_VOICE,
  228. "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
  229. "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
  230. "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
  231. "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
  232. },
  233. "stt": {
  234. "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
  235. "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
  236. "ENGINE": request.app.state.config.STT_ENGINE,
  237. "MODEL": request.app.state.config.STT_MODEL,
  238. "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
  239. "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
  240. "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
  241. "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
  242. "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
  243. "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
  244. "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
  245. "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
  246. },
  247. }
  248. def load_speech_pipeline(request):
  249. from transformers import pipeline
  250. from datasets import load_dataset
  251. if request.app.state.speech_synthesiser is None:
  252. request.app.state.speech_synthesiser = pipeline(
  253. "text-to-speech", "microsoft/speecht5_tts"
  254. )
  255. if request.app.state.speech_speaker_embeddings_dataset is None:
  256. request.app.state.speech_speaker_embeddings_dataset = load_dataset(
  257. "Matthijs/cmu-arctic-xvectors", split="validation"
  258. )
  259. @router.post("/speech")
  260. async def speech(request: Request, user=Depends(get_verified_user)):
  261. body = await request.body()
  262. name = hashlib.sha256(
  263. body
  264. + str(request.app.state.config.TTS_ENGINE).encode("utf-8")
  265. + str(request.app.state.config.TTS_MODEL).encode("utf-8")
  266. ).hexdigest()
  267. file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
  268. file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
  269. # Check if the file already exists in the cache
  270. if file_path.is_file():
  271. return FileResponse(file_path)
  272. payload = None
  273. try:
  274. payload = json.loads(body.decode("utf-8"))
  275. except Exception as e:
  276. log.exception(e)
  277. raise HTTPException(status_code=400, detail="Invalid JSON payload")
  278. r = None
  279. if request.app.state.config.TTS_ENGINE == "openai":
  280. payload["model"] = request.app.state.config.TTS_MODEL
  281. try:
  282. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  283. async with aiohttp.ClientSession(
  284. timeout=timeout, trust_env=True
  285. ) as session:
  286. r = await session.post(
  287. url=urljoin(request.app.state.config.TTS_OPENAI_API_BASE_URL, "/audio/speech"),
  288. json=payload,
  289. headers={
  290. "Content-Type": "application/json",
  291. "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
  292. **(
  293. {
  294. "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
  295. "X-OpenWebUI-User-Id": user.id,
  296. "X-OpenWebUI-User-Email": user.email,
  297. "X-OpenWebUI-User-Role": user.role,
  298. }
  299. if ENABLE_FORWARD_USER_INFO_HEADERS
  300. else {}
  301. ),
  302. },
  303. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  304. )
  305. r.raise_for_status()
  306. async with aiofiles.open(file_path, "wb") as f:
  307. await f.write(await r.read())
  308. async with aiofiles.open(file_body_path, "w") as f:
  309. await f.write(json.dumps(payload))
  310. return FileResponse(file_path)
  311. except Exception as e:
  312. log.exception(e)
  313. detail = None
  314. status_code = 500
  315. detail = f"Open WebUI: Server Connection Error"
  316. if r is not None:
  317. status_code = r.status
  318. try:
  319. res = await r.json()
  320. if "error" in res:
  321. detail = f"External: {res['error']}"
  322. except Exception:
  323. detail = f"External: {e}"
  324. raise HTTPException(
  325. status_code=status_code,
  326. detail=detail,
  327. )
  328. elif request.app.state.config.TTS_ENGINE == "elevenlabs":
  329. voice_id = payload.get("voice", "")
  330. if voice_id not in get_available_voices(request):
  331. raise HTTPException(
  332. status_code=400,
  333. detail="Invalid voice id",
  334. )
  335. try:
  336. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  337. async with aiohttp.ClientSession(
  338. timeout=timeout, trust_env=True
  339. ) as session:
  340. async with session.post(
  341. f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
  342. json={
  343. "text": payload["input"],
  344. "model_id": request.app.state.config.TTS_MODEL,
  345. "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
  346. },
  347. headers={
  348. "Accept": "audio/mpeg",
  349. "Content-Type": "application/json",
  350. "xi-api-key": request.app.state.config.TTS_API_KEY,
  351. },
  352. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  353. ) as r:
  354. r.raise_for_status()
  355. async with aiofiles.open(file_path, "wb") as f:
  356. await f.write(await r.read())
  357. async with aiofiles.open(file_body_path, "w") as f:
  358. await f.write(json.dumps(payload))
  359. return FileResponse(file_path)
  360. except Exception as e:
  361. log.exception(e)
  362. detail = None
  363. try:
  364. if r.status != 200:
  365. res = await r.json()
  366. if "error" in res:
  367. detail = f"External: {res['error'].get('message', '')}"
  368. except Exception:
  369. detail = f"External: {e}"
  370. raise HTTPException(
  371. status_code=getattr(r, "status", 500) if r else 500,
  372. detail=detail if detail else "Open WebUI: Server Connection Error",
  373. )
  374. elif request.app.state.config.TTS_ENGINE == "azure":
  375. try:
  376. payload = json.loads(body.decode("utf-8"))
  377. except Exception as e:
  378. log.exception(e)
  379. raise HTTPException(status_code=400, detail="Invalid JSON payload")
  380. region = request.app.state.config.TTS_AZURE_SPEECH_REGION or "eastus"
  381. base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
  382. language = request.app.state.config.TTS_VOICE
  383. locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
  384. output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
  385. try:
  386. data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
  387. <voice name="{language}">{payload["input"]}</voice>
  388. </speak>"""
  389. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  390. async with aiohttp.ClientSession(
  391. timeout=timeout, trust_env=True
  392. ) as session:
  393. async with session.post(
  394. urljoin(base_url or f"https://{region}.tts.speech.microsoft.com", "/cognitiveservices/v1"),
  395. headers={
  396. "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
  397. "Content-Type": "application/ssml+xml",
  398. "X-Microsoft-OutputFormat": output_format,
  399. },
  400. data=data,
  401. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  402. ) as r:
  403. r.raise_for_status()
  404. async with aiofiles.open(file_path, "wb") as f:
  405. await f.write(await r.read())
  406. async with aiofiles.open(file_body_path, "w") as f:
  407. await f.write(json.dumps(payload))
  408. return FileResponse(file_path)
  409. except Exception as e:
  410. log.exception(e)
  411. detail = None
  412. try:
  413. if r.status != 200:
  414. res = await r.json()
  415. if "error" in res:
  416. detail = f"External: {res['error'].get('message', '')}"
  417. except Exception:
  418. detail = f"External: {e}"
  419. raise HTTPException(
  420. status_code=getattr(r, "status", 500) if r else 500,
  421. detail=detail if detail else "Open WebUI: Server Connection Error",
  422. )
  423. elif request.app.state.config.TTS_ENGINE == "transformers":
  424. payload = None
  425. try:
  426. payload = json.loads(body.decode("utf-8"))
  427. except Exception as e:
  428. log.exception(e)
  429. raise HTTPException(status_code=400, detail="Invalid JSON payload")
  430. import torch
  431. import soundfile as sf
  432. load_speech_pipeline(request)
  433. embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
  434. speaker_index = 6799
  435. try:
  436. speaker_index = embeddings_dataset["filename"].index(
  437. request.app.state.config.TTS_MODEL
  438. )
  439. except Exception:
  440. pass
  441. speaker_embedding = torch.tensor(
  442. embeddings_dataset[speaker_index]["xvector"]
  443. ).unsqueeze(0)
  444. speech = request.app.state.speech_synthesiser(
  445. payload["input"],
  446. forward_params={"speaker_embeddings": speaker_embedding},
  447. )
  448. sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
  449. async with aiofiles.open(file_body_path, "w") as f:
  450. await f.write(json.dumps(payload))
  451. return FileResponse(file_path)
  452. def transcription_handler(request, file_path, metadata):
  453. filename = os.path.basename(file_path)
  454. file_dir = os.path.dirname(file_path)
  455. id = filename.split(".")[0]
  456. metadata = metadata or {}
  457. languages = [
  458. metadata.get("language", None) if WHISPER_LANGUAGE == "" else WHISPER_LANGUAGE,
  459. None, # Always fallback to None in case transcription fails
  460. ]
  461. if request.app.state.config.STT_ENGINE == "":
  462. if request.app.state.faster_whisper_model is None:
  463. request.app.state.faster_whisper_model = set_faster_whisper_model(
  464. request.app.state.config.WHISPER_MODEL
  465. )
  466. model = request.app.state.faster_whisper_model
  467. segments, info = model.transcribe(
  468. file_path,
  469. beam_size=5,
  470. vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
  471. language=languages[0],
  472. )
  473. log.info(
  474. "Detected language '%s' with probability %f"
  475. % (info.language, info.language_probability)
  476. )
  477. transcript = "".join([segment.text for segment in list(segments)])
  478. data = {"text": transcript.strip()}
  479. # save the transcript to a json file
  480. transcript_file = f"{file_dir}/{id}.json"
  481. with open(transcript_file, "w") as f:
  482. json.dump(data, f)
  483. log.debug(data)
  484. return data
  485. elif request.app.state.config.STT_ENGINE == "openai":
  486. r = None
  487. try:
  488. for language in languages:
  489. payload = {
  490. "model": request.app.state.config.STT_MODEL,
  491. }
  492. if language:
  493. payload["language"] = language
  494. r = requests.post(
  495. url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
  496. headers={
  497. "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
  498. },
  499. files={"file": (filename, open(file_path, "rb"))},
  500. data=payload,
  501. )
  502. if r.status_code == 200:
  503. # Successful transcription
  504. break
  505. r.raise_for_status()
  506. data = r.json()
  507. # save the transcript to a json file
  508. transcript_file = f"{file_dir}/{id}.json"
  509. with open(transcript_file, "w") as f:
  510. json.dump(data, f)
  511. return data
  512. except Exception as e:
  513. log.exception(e)
  514. detail = None
  515. if r is not None:
  516. try:
  517. res = r.json()
  518. if "error" in res:
  519. detail = f"External: {res['error'].get('message', '')}"
  520. except Exception:
  521. detail = f"External: {e}"
  522. raise Exception(detail if detail else "Open WebUI: Server Connection Error")
  523. elif request.app.state.config.STT_ENGINE == "deepgram":
  524. try:
  525. # Determine the MIME type of the file
  526. mime, _ = mimetypes.guess_type(file_path)
  527. if not mime:
  528. mime = "audio/wav" # fallback to wav if undetectable
  529. # Read the audio file
  530. with open(file_path, "rb") as f:
  531. file_data = f.read()
  532. # Build headers and parameters
  533. headers = {
  534. "Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
  535. "Content-Type": mime,
  536. }
  537. for language in languages:
  538. params = {}
  539. if request.app.state.config.STT_MODEL:
  540. params["model"] = request.app.state.config.STT_MODEL
  541. if language:
  542. params["language"] = language
  543. # Make request to Deepgram API
  544. r = requests.post(
  545. "https://api.deepgram.com/v1/listen?smart_format=true",
  546. headers=headers,
  547. params=params,
  548. data=file_data,
  549. )
  550. if r.status_code == 200:
  551. # Successful transcription
  552. break
  553. r.raise_for_status()
  554. response_data = r.json()
  555. # Extract transcript from Deepgram response
  556. try:
  557. transcript = response_data["results"]["channels"][0]["alternatives"][
  558. 0
  559. ].get("transcript", "")
  560. except (KeyError, IndexError) as e:
  561. log.error(f"Malformed response from Deepgram: {str(e)}")
  562. raise Exception(
  563. "Failed to parse Deepgram response - unexpected response format"
  564. )
  565. data = {"text": transcript.strip()}
  566. # Save transcript
  567. transcript_file = f"{file_dir}/{id}.json"
  568. with open(transcript_file, "w") as f:
  569. json.dump(data, f)
  570. return data
  571. except Exception as e:
  572. log.exception(e)
  573. detail = None
  574. if r is not None:
  575. try:
  576. res = r.json()
  577. if "error" in res:
  578. detail = f"External: {res['error'].get('message', '')}"
  579. except Exception:
  580. detail = f"External: {e}"
  581. raise Exception(detail if detail else "Open WebUI: Server Connection Error")
  582. elif request.app.state.config.STT_ENGINE == "azure":
  583. # Check file exists and size
  584. if not os.path.exists(file_path):
  585. raise HTTPException(status_code=400, detail="Audio file not found")
  586. # Check file size (Azure has a larger limit of 200MB)
  587. file_size = os.path.getsize(file_path)
  588. if file_size > AZURE_MAX_FILE_SIZE:
  589. raise HTTPException(
  590. status_code=400,
  591. detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
  592. )
  593. api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
  594. region = request.app.state.config.AUDIO_STT_AZURE_REGION or "eastus"
  595. locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
  596. base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
  597. max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3
  598. # IF NO LOCALES, USE DEFAULTS
  599. if len(locales) < 2:
  600. locales = [
  601. "en-US",
  602. "es-ES",
  603. "es-MX",
  604. "fr-FR",
  605. "hi-IN",
  606. "it-IT",
  607. "de-DE",
  608. "en-GB",
  609. "en-IN",
  610. "ja-JP",
  611. "ko-KR",
  612. "pt-BR",
  613. "zh-CN",
  614. ]
  615. locales = ",".join(locales)
  616. if not api_key or not region:
  617. raise HTTPException(
  618. status_code=400,
  619. detail="Azure API key is required for Azure STT",
  620. )
  621. r = None
  622. try:
  623. # Prepare the request
  624. data = {
  625. "definition": json.dumps(
  626. {
  627. "locales": locales.split(","),
  628. "diarization": {"maxSpeakers": max_speakers, "enabled": True},
  629. }
  630. if locales
  631. else {}
  632. )
  633. }
  634. url = (
  635. base_url or f"https://{region}.api.cognitive.microsoft.com"
  636. ) + "/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
  637. # Use context manager to ensure file is properly closed
  638. with open(file_path, "rb") as audio_file:
  639. r = requests.post(
  640. url=url,
  641. files={"audio": audio_file},
  642. data=data,
  643. headers={
  644. "Ocp-Apim-Subscription-Key": api_key,
  645. },
  646. )
  647. r.raise_for_status()
  648. response = r.json()
  649. # Extract transcript from response
  650. if not response.get("combinedPhrases"):
  651. raise ValueError("No transcription found in response")
  652. # Get the full transcript from combinedPhrases
  653. transcript = response["combinedPhrases"][0].get("text", "").strip()
  654. if not transcript:
  655. raise ValueError("Empty transcript in response")
  656. data = {"text": transcript}
  657. # Save transcript to json file (consistent with other providers)
  658. transcript_file = f"{file_dir}/{id}.json"
  659. with open(transcript_file, "w") as f:
  660. json.dump(data, f)
  661. log.debug(data)
  662. return data
  663. except (KeyError, IndexError, ValueError) as e:
  664. log.exception("Error parsing Azure response")
  665. raise HTTPException(
  666. status_code=500,
  667. detail=f"Failed to parse Azure response: {str(e)}",
  668. )
  669. except requests.exceptions.RequestException as e:
  670. log.exception(e)
  671. detail = None
  672. try:
  673. if r is not None and r.status_code != 200:
  674. res = r.json()
  675. if "error" in res:
  676. detail = f"External: {res['error'].get('message', '')}"
  677. except Exception:
  678. detail = f"External: {e}"
  679. raise HTTPException(
  680. status_code=getattr(r, "status_code", 500) if r else 500,
  681. detail=detail if detail else "Open WebUI: Server Connection Error",
  682. )
  683. def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
  684. log.info(f"transcribe: {file_path} {metadata}")
  685. if is_audio_conversion_required(file_path):
  686. file_path = convert_audio_to_mp3(file_path)
  687. try:
  688. file_path = compress_audio(file_path)
  689. except Exception as e:
  690. log.exception(e)
  691. # Always produce a list of chunk paths (could be one entry if small)
  692. try:
  693. chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
  694. print(f"Chunk paths: {chunk_paths}")
  695. except Exception as e:
  696. log.exception(e)
  697. raise HTTPException(
  698. status_code=status.HTTP_400_BAD_REQUEST,
  699. detail=ERROR_MESSAGES.DEFAULT(e),
  700. )
  701. results = []
  702. try:
  703. with ThreadPoolExecutor() as executor:
  704. # Submit tasks for each chunk_path
  705. futures = [
  706. executor.submit(transcription_handler, request, chunk_path, metadata)
  707. for chunk_path in chunk_paths
  708. ]
  709. # Gather results as they complete
  710. for future in futures:
  711. try:
  712. results.append(future.result())
  713. except Exception as transcribe_exc:
  714. raise HTTPException(
  715. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  716. detail=f"Error transcribing chunk: {transcribe_exc}",
  717. )
  718. finally:
  719. # Clean up only the temporary chunks, never the original file
  720. for chunk_path in chunk_paths:
  721. if chunk_path != file_path and os.path.isfile(chunk_path):
  722. try:
  723. os.remove(chunk_path)
  724. except Exception:
  725. pass
  726. return {
  727. "text": " ".join([result["text"] for result in results]),
  728. }
  729. def compress_audio(file_path):
  730. if os.path.getsize(file_path) > MAX_FILE_SIZE:
  731. id = os.path.splitext(os.path.basename(file_path))[
  732. 0
  733. ] # Handles names with multiple dots
  734. file_dir = os.path.dirname(file_path)
  735. audio = AudioSegment.from_file(file_path)
  736. audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
  737. compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
  738. audio.export(compressed_path, format="mp3", bitrate="32k")
  739. # log.debug(f"Compressed audio to {compressed_path}") # Uncomment if log is defined
  740. return compressed_path
  741. else:
  742. return file_path
  743. def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
  744. """
  745. Splits audio into chunks not exceeding max_bytes.
  746. Returns a list of chunk file paths. If audio fits, returns list with original path.
  747. """
  748. file_size = os.path.getsize(file_path)
  749. if file_size <= max_bytes:
  750. return [file_path] # Nothing to split
  751. audio = AudioSegment.from_file(file_path)
  752. duration_ms = len(audio)
  753. orig_size = file_size
  754. approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
  755. chunks = []
  756. start = 0
  757. i = 0
  758. base, _ = os.path.splitext(file_path)
  759. while start < duration_ms:
  760. end = min(start + approx_chunk_ms, duration_ms)
  761. chunk = audio[start:end]
  762. chunk_path = f"{base}_chunk_{i}.{format}"
  763. chunk.export(chunk_path, format=format, bitrate=bitrate)
  764. # Reduce chunk duration if still too large
  765. while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
  766. end = start + ((end - start) // 2)
  767. chunk = audio[start:end]
  768. chunk.export(chunk_path, format=format, bitrate=bitrate)
  769. if os.path.getsize(chunk_path) > max_bytes:
  770. os.remove(chunk_path)
  771. raise Exception("Audio chunk cannot be reduced below max file size.")
  772. chunks.append(chunk_path)
  773. start = end
  774. i += 1
  775. return chunks
  776. @router.post("/transcriptions")
  777. def transcription(
  778. request: Request,
  779. file: UploadFile = File(...),
  780. language: Optional[str] = Form(None),
  781. user=Depends(get_verified_user),
  782. ):
  783. log.info(f"file.content_type: {file.content_type}")
  784. stt_supported_content_types = getattr(
  785. request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
  786. )
  787. if not any(
  788. fnmatch(file.content_type, content_type)
  789. for content_type in (
  790. stt_supported_content_types
  791. if stt_supported_content_types
  792. and any(t.strip() for t in stt_supported_content_types)
  793. else ["audio/*", "video/webm"]
  794. )
  795. ):
  796. raise HTTPException(
  797. status_code=status.HTTP_400_BAD_REQUEST,
  798. detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
  799. )
  800. try:
  801. ext = file.filename.split(".")[-1]
  802. id = uuid.uuid4()
  803. filename = f"{id}.{ext}"
  804. contents = file.file.read()
  805. file_dir = f"{CACHE_DIR}/audio/transcriptions"
  806. os.makedirs(file_dir, exist_ok=True)
  807. file_path = f"{file_dir}/{filename}"
  808. with open(file_path, "wb") as f:
  809. f.write(contents)
  810. try:
  811. metadata = None
  812. if language:
  813. metadata = {"language": language}
  814. result = transcribe(request, file_path, metadata)
  815. return {
  816. **result,
  817. "filename": os.path.basename(file_path),
  818. }
  819. except Exception as e:
  820. log.exception(e)
  821. raise HTTPException(
  822. status_code=status.HTTP_400_BAD_REQUEST,
  823. detail=ERROR_MESSAGES.DEFAULT(e),
  824. )
  825. except Exception as e:
  826. log.exception(e)
  827. raise HTTPException(
  828. status_code=status.HTTP_400_BAD_REQUEST,
  829. detail=ERROR_MESSAGES.DEFAULT(e),
  830. )
  831. def get_available_models(request: Request) -> list[dict]:
  832. available_models = []
  833. if request.app.state.config.TTS_ENGINE == "openai":
  834. # Use custom endpoint if not using the official OpenAI API URL
  835. if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
  836. "https://api.openai.com"
  837. ):
  838. try:
  839. response = requests.get(
  840. f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
  841. )
  842. response.raise_for_status()
  843. data = response.json()
  844. available_models = data.get("models", [])
  845. except Exception as e:
  846. log.error(f"Error fetching models from custom endpoint: {str(e)}")
  847. available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
  848. else:
  849. available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
  850. elif request.app.state.config.TTS_ENGINE == "elevenlabs":
  851. try:
  852. response = requests.get(
  853. "https://api.elevenlabs.io/v1/models",
  854. headers={
  855. "xi-api-key": request.app.state.config.TTS_API_KEY,
  856. "Content-Type": "application/json",
  857. },
  858. timeout=5,
  859. )
  860. response.raise_for_status()
  861. models = response.json()
  862. available_models = [
  863. {"name": model["name"], "id": model["model_id"]} for model in models
  864. ]
  865. except requests.RequestException as e:
  866. log.error(f"Error fetching voices: {str(e)}")
  867. return available_models
  868. @router.get("/models")
  869. async def get_models(request: Request, user=Depends(get_verified_user)):
  870. return {"models": get_available_models(request)}
  871. def get_available_voices(request) -> dict:
  872. """Returns {voice_id: voice_name} dict"""
  873. available_voices = {}
  874. if request.app.state.config.TTS_ENGINE == "openai":
  875. # Use custom endpoint if not using the official OpenAI API URL
  876. if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
  877. "https://api.openai.com"
  878. ):
  879. try:
  880. response = requests.get(
  881. f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
  882. )
  883. response.raise_for_status()
  884. data = response.json()
  885. voices_list = data.get("voices", [])
  886. available_voices = {voice["id"]: voice["name"] for voice in voices_list}
  887. except Exception as e:
  888. log.error(f"Error fetching voices from custom endpoint: {str(e)}")
  889. available_voices = {
  890. "alloy": "alloy",
  891. "echo": "echo",
  892. "fable": "fable",
  893. "onyx": "onyx",
  894. "nova": "nova",
  895. "shimmer": "shimmer",
  896. }
  897. else:
  898. available_voices = {
  899. "alloy": "alloy",
  900. "echo": "echo",
  901. "fable": "fable",
  902. "onyx": "onyx",
  903. "nova": "nova",
  904. "shimmer": "shimmer",
  905. }
  906. elif request.app.state.config.TTS_ENGINE == "elevenlabs":
  907. try:
  908. available_voices = get_elevenlabs_voices(
  909. api_key=request.app.state.config.TTS_API_KEY
  910. )
  911. except Exception:
  912. # Avoided @lru_cache with exception
  913. pass
  914. elif request.app.state.config.TTS_ENGINE == "azure":
  915. try:
  916. region = request.app.state.config.TTS_AZURE_SPEECH_REGION
  917. base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
  918. url = (
  919. base_url or f"https://{region}.tts.speech.microsoft.com"
  920. ) + "/cognitiveservices/voices/list"
  921. headers = {
  922. "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
  923. }
  924. response = requests.get(url, headers=headers)
  925. response.raise_for_status()
  926. voices = response.json()
  927. for voice in voices:
  928. available_voices[voice["ShortName"]] = (
  929. f"{voice['DisplayName']} ({voice['ShortName']})"
  930. )
  931. except requests.RequestException as e:
  932. log.error(f"Error fetching voices: {str(e)}")
  933. return available_voices
  934. @lru_cache
  935. def get_elevenlabs_voices(api_key: str) -> dict:
  936. """
  937. Note, set the following in your .env file to use Elevenlabs:
  938. AUDIO_TTS_ENGINE=elevenlabs
  939. AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key
  940. AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices
  941. AUDIO_TTS_MODEL=eleven_multilingual_v2
  942. """
  943. try:
  944. # TODO: Add retries
  945. response = requests.get(
  946. "https://api.elevenlabs.io/v1/voices",
  947. headers={
  948. "xi-api-key": api_key,
  949. "Content-Type": "application/json",
  950. },
  951. )
  952. response.raise_for_status()
  953. voices_data = response.json()
  954. voices = {}
  955. for voice in voices_data.get("voices", []):
  956. voices[voice["voice_id"]] = voice["name"]
  957. except requests.RequestException as e:
  958. # Avoid @lru_cache with exception
  959. log.error(f"Error fetching voices: {str(e)}")
  960. raise RuntimeError(f"Error fetching voices: {str(e)}")
  961. return voices
  962. @router.get("/voices")
  963. async def get_voices(request: Request, user=Depends(get_verified_user)):
  964. return {
  965. "voices": [
  966. {"id": k, "name": v} for k, v in get_available_voices(request).items()
  967. ]
  968. }