import hashlib
import json
import logging
import os
import uuid
import html
from functools import lru_cache
from pydub import AudioSegment
from pydub.silence import split_on_silence
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from fnmatch import fnmatch

import aiohttp
import aiofiles
import requests
import mimetypes
from urllib.parse import urljoin, quote

from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
    status,
    APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel

from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import (
    WHISPER_MODEL_AUTO_UPDATE,
    WHISPER_MODEL_DIR,
    CACHE_DIR,
    WHISPER_LANGUAGE,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    ENV,
    SRC_LOG_LEVELS,
    DEVICE_TYPE,
    ENABLE_FORWARD_USER_INFO_HEADERS,
)

router = APIRouter()

# Constants
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
AZURE_MAX_FILE_SIZE_MB = 200
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)


##########################################
#
# Utility functions
#
##########################################

from pydub.utils import mediainfo


def is_audio_conversion_required(file_path):
    """
    Check if the given audio file needs conversion to mp3.
    """
    SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}

    if not os.path.isfile(file_path):
        log.error(f"File not found: {file_path}")
        return False

    try:
        info = mediainfo(file_path)
        codec_name = info.get("codec_name", "").lower()
        codec_type = info.get("codec_type", "").lower()
        codec_tag_string = info.get("codec_tag_string", "").lower()

        if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
            # File is AAC/mp4a audio, recommend mp3 conversion
            return True

        # If the codec name is in the supported formats
        if codec_name in SUPPORTED_FORMATS:
            return False

        return True
    except Exception as e:
        log.error(f"Error getting audio format: {e}")
        return False


def convert_audio_to_mp3(file_path):
    """Convert audio file to mp3 format."""
    try:
        output_path = os.path.splitext(file_path)[0] + ".mp3"
        audio = AudioSegment.from_file(file_path)
        audio.export(output_path, format="mp3")
        log.info(f"Converted {file_path} to {output_path}")
        return output_path
    except Exception as e:
        log.error(f"Error converting audio file: {e}")
        return None
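

# Lazily construct a faster-whisper WhisperModel. The model is resolved against
# WHISPER_MODEL_DIR with local_files_only=True unless auto_update is set; if that
# fails, initialization is retried once with downloads enabled. The GPU is used
# only when DEVICE_TYPE is "cuda", otherwise everything runs on CPU with int8
# compute.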
def set_faster_whisper_model(model: str, auto_update: bool = False):
    whisper_model = None
    if model:
        from faster_whisper import WhisperModel

        faster_whisper_kwargs = {
            "model_size_or_path": model,
            "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
            "compute_type": "int8",
            "download_root": WHISPER_MODEL_DIR,
            "local_files_only": not auto_update,
        }

        try:
            whisper_model = WhisperModel(**faster_whisper_kwargs)
        except Exception:
            log.warning(
                "WhisperModel initialization failed, attempting download with local_files_only=False"
            )
            faster_whisper_kwargs["local_files_only"] = False
            whisper_model = WhisperModel(**faster_whisper_kwargs)

    return whisper_model


##########################################
#
# Audio API
#
##########################################
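

# Pydantic schemas for the audio admin config endpoints. Field names mirror the
# TTS_*/STT_* keys stored on app.state.config, so /config and /config/update can
# map them one-to-one.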
class TTSConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    API_KEY: str
    ENGINE: str
    MODEL: str
    VOICE: str
    SPLIT_ON: str
    AZURE_SPEECH_REGION: str
    AZURE_SPEECH_BASE_URL: str
    AZURE_SPEECH_OUTPUT_FORMAT: str


class STTConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    ENGINE: str
    MODEL: str
    SUPPORTED_CONTENT_TYPES: list[str] = []
    WHISPER_MODEL: str
    DEEPGRAM_API_KEY: str
    AZURE_API_KEY: str
    AZURE_REGION: str
    AZURE_LOCALES: str
    AZURE_BASE_URL: str
    AZURE_MAX_SPEAKERS: str


class AudioConfigUpdateForm(BaseModel):
    tts: TTSConfigForm
    stt: STTConfigForm


@router.get("/config")
async def get_audio_config(request: Request, user=Depends(get_admin_user)):
    return {
        "tts": {
            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": request.app.state.config.TTS_API_KEY,
            "ENGINE": request.app.state.config.TTS_ENGINE,
            "MODEL": request.app.state.config.TTS_MODEL,
            "VOICE": request.app.state.config.TTS_VOICE,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": request.app.state.config.STT_ENGINE,
            "MODEL": request.app.state.config.STT_MODEL,
            "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }


@router.post("/config/update")
async def update_audio_config(
    request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
    request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
    request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
    request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
    request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
    request.app.state.config.TTS_MODEL = form_data.tts.MODEL
    request.app.state.config.TTS_VOICE = form_data.tts.VOICE
    request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
    request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
    request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = (
        form_data.tts.AZURE_SPEECH_BASE_URL
    )
    request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
        form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
    )

    request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
    request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
    request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
    request.app.state.config.STT_MODEL = form_data.stt.MODEL
    request.app.state.config.STT_SUPPORTED_CONTENT_TYPES = (
        form_data.stt.SUPPORTED_CONTENT_TYPES
    )
    request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
    request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
    request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
    request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
    request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
        form_data.stt.AZURE_MAX_SPEAKERS
    )

    if request.app.state.config.STT_ENGINE == "":
        request.app.state.faster_whisper_model = set_faster_whisper_model(
            form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
        )
    else:
        request.app.state.faster_whisper_model = None

    return {
        "tts": {
            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": request.app.state.config.TTS_API_KEY,
            "ENGINE": request.app.state.config.TTS_ENGINE,
            "MODEL": request.app.state.config.TTS_MODEL,
            "VOICE": request.app.state.config.TTS_VOICE,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": request.app.state.config.STT_ENGINE,
            "MODEL": request.app.state.config.STT_MODEL,
            "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }


def load_speech_pipeline(request):
    from transformers import pipeline
    from datasets import load_dataset

    if request.app.state.speech_synthesiser is None:
        request.app.state.speech_synthesiser = pipeline(
            "text-to-speech", "microsoft/speecht5_tts"
        )

    if request.app.state.speech_speaker_embeddings_dataset is None:
        request.app.state.speech_speaker_embeddings_dataset = load_dataset(
            "Matthijs/cmu-arctic-xvectors", split="validation"
        )
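

# POST /speech caches synthesized audio on disk: the cache key is the SHA-256 of
# the raw request body plus the configured TTS engine and model, so repeated
# requests for the same text, engine, and model are served directly from
# SPEECH_CACHE_DIR without contacting the TTS backend.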
@router.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    body = await request.body()
    name = hashlib.sha256(
        body
        + str(request.app.state.config.TTS_ENGINE).encode("utf-8")
        + str(request.app.state.config.TTS_MODEL).encode("utf-8")
    ).hexdigest()

    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    payload = None
    try:
        payload = json.loads(body.decode("utf-8"))
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=400, detail="Invalid JSON payload")

    r = None
    if request.app.state.config.TTS_ENGINE == "openai":
        payload["model"] = request.app.state.config.TTS_MODEL

        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                r = await session.post(
                    url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                    json=payload,
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
                        **(
                            {
                                "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
                                "X-OpenWebUI-User-Id": user.id,
                                "X-OpenWebUI-User-Email": user.email,
                                "X-OpenWebUI-User-Role": user.role,
                            }
                            if ENABLE_FORWARD_USER_INFO_HEADERS
                            else {}
                        ),
                    },
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                )
                r.raise_for_status()

                async with aiofiles.open(file_path, "wb") as f:
                    await f.write(await r.read())

                async with aiofiles.open(file_body_path, "w") as f:
                    await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)

            status_code = 500
            detail = "Open WebUI: Server Connection Error"

            if r is not None:
                status_code = r.status
                try:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error']}"
                except Exception:
                    detail = f"External: {e}"

            raise HTTPException(
                status_code=status_code,
                detail=detail,
            )
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        voice_id = payload.get("voice", "")

        if voice_id not in get_available_voices(request):
            raise HTTPException(
                status_code=400,
                detail="Invalid voice id",
            )

        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
                    json={
                        "text": payload["input"],
                        "model_id": request.app.state.config.TTS_MODEL,
                        "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
                    },
                    headers={
                        "Accept": "audio/mpeg",
                        "Content-Type": "application/json",
                        "xi-api-key": request.app.state.config.TTS_API_KEY,
                    },
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        region = request.app.state.config.TTS_AZURE_SPEECH_REGION or "eastus"
        base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
        language = request.app.state.config.TTS_VOICE
        locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
        output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT

        try:
            data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
                <voice name="{language}">{html.escape(payload["input"])}</voice>
            </speak>"""
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    (base_url or f"https://{region}.tts.speech.microsoft.com")
                    + "/cognitiveservices/v1",
                    headers={
                        "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
                        "Content-Type": "application/ssml+xml",
                        "X-Microsoft-OutputFormat": output_format,
                    },
                    data=data,
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
    elif request.app.state.config.TTS_ENGINE == "transformers":
        payload = None
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        import torch
        import soundfile as sf

        load_speech_pipeline(request)

        embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset

        speaker_index = 6799
        try:
            speaker_index = embeddings_dataset["filename"].index(
                request.app.state.config.TTS_MODEL
            )
        except Exception:
            pass

        speaker_embedding = torch.tensor(
            embeddings_dataset[speaker_index]["xvector"]
        ).unsqueeze(0)

        speech = request.app.state.speech_synthesiser(
            payload["input"],
            forward_params={"speaker_embeddings": speaker_embedding},
        )

        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])

        async with aiofiles.open(file_body_path, "w") as f:
            await f.write(json.dumps(payload))

        return FileResponse(file_path)
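

# Run speech-to-text on a single audio file with the configured STT engine: the
# built-in faster-whisper model when STT_ENGINE is empty, otherwise the OpenAI,
# Deepgram, or Azure transcription APIs. For the OpenAI and Deepgram engines, a
# preferred language (from request metadata or WHISPER_LANGUAGE) is tried first
# and the request falls back to auto-detection if that attempt fails. The
# transcript is written next to the audio file as <id>.json and returned as
# {"text": ...}.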
def transcription_handler(request, file_path, metadata):
    filename = os.path.basename(file_path)
    file_dir = os.path.dirname(file_path)
    id = filename.split(".")[0]

    metadata = metadata or {}

    languages = [
        metadata.get("language", None) if not WHISPER_LANGUAGE else WHISPER_LANGUAGE,
        None,  # Always fallback to None in case transcription fails
    ]

    if request.app.state.config.STT_ENGINE == "":
        if request.app.state.faster_whisper_model is None:
            request.app.state.faster_whisper_model = set_faster_whisper_model(
                request.app.state.config.WHISPER_MODEL
            )

        model = request.app.state.faster_whisper_model
        segments, info = model.transcribe(
            file_path,
            beam_size=5,
            vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
            language=languages[0],
        )
        log.info(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )

        transcript = "".join([segment.text for segment in list(segments)])
        data = {"text": transcript.strip()}

        # save the transcript to a json file
        transcript_file = f"{file_dir}/{id}.json"
        with open(transcript_file, "w") as f:
            json.dump(data, f)

        log.debug(data)
        return data
    elif request.app.state.config.STT_ENGINE == "openai":
        r = None
        try:
            for language in languages:
                payload = {
                    "model": request.app.state.config.STT_MODEL,
                }

                if language:
                    payload["language"] = language

                # Open the file per attempt so the handle is always closed, even
                # when the request fails and a retry without a language hint is made
                with open(file_path, "rb") as audio_file:
                    r = requests.post(
                        url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
                        headers={
                            "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
                        },
                        files={"file": (filename, audio_file)},
                        data=payload,
                    )

                if r.status_code == 200:
                    # Successful transcription
                    break

            r.raise_for_status()
            data = r.json()

            # save the transcript to a json file
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            return data
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
                except Exception:
                    detail = f"External: {e}"

            raise Exception(detail if detail else "Open WebUI: Server Connection Error")
    elif request.app.state.config.STT_ENGINE == "deepgram":
        r = None  # Defined up front so the error handler below can inspect it safely
        try:
            # Determine the MIME type of the file
            mime, _ = mimetypes.guess_type(file_path)
            if not mime:
                mime = "audio/wav"  # fallback to wav if undetectable

            # Read the audio file
            with open(file_path, "rb") as f:
                file_data = f.read()

            # Build headers and parameters
            headers = {
                "Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
                "Content-Type": mime,
            }

            for language in languages:
                params = {}
                if request.app.state.config.STT_MODEL:
                    params["model"] = request.app.state.config.STT_MODEL

                if language:
                    params["language"] = language

                # Make request to Deepgram API
                r = requests.post(
                    "https://api.deepgram.com/v1/listen?smart_format=true",
                    headers=headers,
                    params=params,
                    data=file_data,
                )
                if r.status_code == 200:
                    # Successful transcription
                    break

            r.raise_for_status()
            response_data = r.json()

            # Extract transcript from Deepgram response
            try:
                transcript = response_data["results"]["channels"][0]["alternatives"][
                    0
                ].get("transcript", "")
            except (KeyError, IndexError) as e:
                log.error(f"Malformed response from Deepgram: {str(e)}")
                raise Exception(
                    "Failed to parse Deepgram response - unexpected response format"
                )

            data = {"text": transcript.strip()}

            # Save transcript
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            return data
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
                except Exception:
                    detail = f"External: {e}"

            raise Exception(detail if detail else "Open WebUI: Server Connection Error")
    elif request.app.state.config.STT_ENGINE == "azure":
        # Check file exists and size
        if not os.path.exists(file_path):
            raise HTTPException(status_code=400, detail="Audio file not found")

        # Check file size (Azure has a larger limit of 200MB)
        file_size = os.path.getsize(file_path)
        if file_size > AZURE_MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
            )

        api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
        region = request.app.state.config.AUDIO_STT_AZURE_REGION or "eastus"
        locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
        base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
        max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3

        # IF NO LOCALES, USE DEFAULTS
        if len(locales) < 2:
            locales = [
                "en-US",
                "es-ES",
                "es-MX",
                "fr-FR",
                "hi-IN",
                "it-IT",
                "de-DE",
                "en-GB",
                "en-IN",
                "ja-JP",
                "ko-KR",
                "pt-BR",
                "zh-CN",
            ]
            locales = ",".join(locales)

        if not api_key or not region:
            raise HTTPException(
                status_code=400,
                detail="Azure API key is required for Azure STT",
            )

        r = None
        try:
            # Prepare the request
            data = {
                "definition": json.dumps(
                    {
                        "locales": locales.split(","),
                        "diarization": {"maxSpeakers": max_speakers, "enabled": True},
                    }
                    if locales
                    else {}
                )
            }
            url = (
                base_url or f"https://{region}.api.cognitive.microsoft.com"
            ) + "/speechtotext/transcriptions:transcribe?api-version=2024-11-15"

            # Use context manager to ensure file is properly closed
            with open(file_path, "rb") as audio_file:
                r = requests.post(
                    url=url,
                    files={"audio": audio_file},
                    data=data,
                    headers={
                        "Ocp-Apim-Subscription-Key": api_key,
                    },
                )

            r.raise_for_status()
            response = r.json()

            # Extract transcript from response
            if not response.get("combinedPhrases"):
                raise ValueError("No transcription found in response")

            # Get the full transcript from combinedPhrases
            transcript = response["combinedPhrases"][0].get("text", "").strip()
            if not transcript:
                raise ValueError("Empty transcript in response")

            data = {"text": transcript}

            # Save transcript to json file (consistent with other providers)
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            log.debug(data)
            return data

        except (KeyError, IndexError, ValueError) as e:
            log.exception("Error parsing Azure response")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to parse Azure response: {str(e)}",
            )
        except requests.exceptions.RequestException as e:
            log.exception(e)
            detail = None

            try:
                if r is not None and r.status_code != 200:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status_code", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
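

# End-to-end transcription of an audio file: convert unsupported containers to
# mp3, compress anything above MAX_FILE_SIZE, split the result into size-bounded
# chunks, transcribe the chunks in parallel via transcription_handler, and join
# the chunk transcripts into a single string.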
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
    log.info(f"transcribe: {file_path} {metadata}")

    if is_audio_conversion_required(file_path):
        converted_path = convert_audio_to_mp3(file_path)
        # convert_audio_to_mp3 returns None on failure; keep the original file then
        if converted_path:
            file_path = converted_path

    try:
        file_path = compress_audio(file_path)
    except Exception as e:
        log.exception(e)

    # Always produce a list of chunk paths (could be one entry if small)
    try:
        chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
        log.debug(f"Chunk paths: {chunk_paths}")
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    results = []
    try:
        with ThreadPoolExecutor() as executor:
            # Submit tasks for each chunk_path
            futures = [
                executor.submit(transcription_handler, request, chunk_path, metadata)
                for chunk_path in chunk_paths
            ]
            # Gather results as they complete
            for future in futures:
                try:
                    results.append(future.result())
                except Exception as transcribe_exc:
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail=f"Error transcribing chunk: {transcribe_exc}",
                    )
    finally:
        # Clean up only the temporary chunks, never the original file
        for chunk_path in chunk_paths:
            if chunk_path != file_path and os.path.isfile(chunk_path):
                try:
                    os.remove(chunk_path)
                except Exception:
                    pass

    return {
        "text": " ".join([result["text"] for result in results]),
    }
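

# Files larger than MAX_FILE_SIZE are re-encoded to 16 kHz mono 32 kbps mp3
# before transcription; smaller files are passed through untouched.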
def compress_audio(file_path):
    if os.path.getsize(file_path) > MAX_FILE_SIZE:
        id = os.path.splitext(os.path.basename(file_path))[
            0
        ]  # Handles names with multiple dots
        file_dir = os.path.dirname(file_path)

        audio = AudioSegment.from_file(file_path)
        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio

        compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
        audio.export(compressed_path, format="mp3", bitrate="32k")
        log.debug(f"Compressed audio to {compressed_path}")
        return compressed_path
    else:
        return file_path
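

# Chunk duration is first estimated proportionally (target bytes over the whole
# file's bytes, minus a one-second safety margin); if an exported chunk still
# exceeds max_bytes, its duration is halved repeatedly down to a five-second
# floor before giving up.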
def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
    """
    Splits audio into chunks not exceeding max_bytes.
    Returns a list of chunk file paths. If audio fits, returns list with original path.
    """
    file_size = os.path.getsize(file_path)
    if file_size <= max_bytes:
        return [file_path]  # Nothing to split

    audio = AudioSegment.from_file(file_path)
    duration_ms = len(audio)
    orig_size = file_size

    approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
    chunks = []
    start = 0
    i = 0

    base, _ = os.path.splitext(file_path)

    while start < duration_ms:
        end = min(start + approx_chunk_ms, duration_ms)
        chunk = audio[start:end]
        chunk_path = f"{base}_chunk_{i}.{format}"
        chunk.export(chunk_path, format=format, bitrate=bitrate)

        # Reduce chunk duration if still too large
        while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
            end = start + ((end - start) // 2)
            chunk = audio[start:end]
            chunk.export(chunk_path, format=format, bitrate=bitrate)

        if os.path.getsize(chunk_path) > max_bytes:
            os.remove(chunk_path)
            raise Exception("Audio chunk cannot be reduced below max file size.")

        chunks.append(chunk_path)
        start = end
        i += 1

    return chunks
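

# POST /transcriptions accepts an uploaded audio (or video/webm) file, stores it
# under CACHE_DIR/audio/transcriptions, and runs it through transcribe(). The
# upload's content type is matched with fnmatch against the configured
# STT_SUPPORTED_CONTENT_TYPES patterns, falling back to ["audio/*", "video/webm"]
# when none are configured.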
@router.post("/transcriptions")
def transcription(
    request: Request,
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    user=Depends(get_verified_user),
):
    log.info(f"file.content_type: {file.content_type}")

    stt_supported_content_types = getattr(
        request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
    )

    if not any(
        fnmatch(file.content_type, content_type)
        for content_type in (
            stt_supported_content_types
            if stt_supported_content_types
            and any(t.strip() for t in stt_supported_content_types)
            else ["audio/*", "video/webm"]
        )
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    try:
        ext = file.filename.split(".")[-1]
        id = uuid.uuid4()

        filename = f"{id}.{ext}"
        contents = file.file.read()

        file_dir = f"{CACHE_DIR}/audio/transcriptions"
        os.makedirs(file_dir, exist_ok=True)
        file_path = f"{file_dir}/{filename}"

        with open(file_path, "wb") as f:
            f.write(contents)

        try:
            metadata = None

            if language:
                metadata = {"language": language}

            result = transcribe(request, file_path, metadata)

            return {
                **result,
                "filename": os.path.basename(file_path),
            }
        except Exception as e:
            log.exception(e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
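

# List the TTS models the configured engine exposes. OpenAI-compatible backends
# on a custom base URL are probed at /audio/models and the response's "models"
# list is used; the official OpenAI API (and any probe failure) falls back to the
# static tts-1 / tts-1-hd pair. For ElevenLabs, /v1/models is queried directly.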
def get_available_models(request: Request) -> list[dict]:
    available_models = []
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
                )
                response.raise_for_status()
                data = response.json()
                available_models = data.get("models", [])
            except Exception as e:
                log.error(f"Error fetching models from custom endpoint: {str(e)}")
                available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
        else:
            available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            response = requests.get(
                "https://api.elevenlabs.io/v1/models",
                headers={
                    "xi-api-key": request.app.state.config.TTS_API_KEY,
                    "Content-Type": "application/json",
                },
                timeout=5,
            )
            response.raise_for_status()
            models = response.json()

            available_models = [
                {"name": model["name"], "id": model["model_id"]} for model in models
            ]
        except requests.RequestException as e:
            log.error(f"Error fetching models: {str(e)}")

    return available_models


@router.get("/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
    return {"models": get_available_models(request)}


def get_available_voices(request) -> dict:
    """Returns {voice_id: voice_name} dict"""
    available_voices = {}
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
                )
                response.raise_for_status()
                data = response.json()
                voices_list = data.get("voices", [])
                available_voices = {voice["id"]: voice["name"] for voice in voices_list}
            except Exception as e:
                log.error(f"Error fetching voices from custom endpoint: {str(e)}")
                available_voices = {
                    "alloy": "alloy",
                    "echo": "echo",
                    "fable": "fable",
                    "onyx": "onyx",
                    "nova": "nova",
                    "shimmer": "shimmer",
                }
        else:
            available_voices = {
                "alloy": "alloy",
                "echo": "echo",
                "fable": "fable",
                "onyx": "onyx",
                "nova": "nova",
                "shimmer": "shimmer",
            }
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            available_voices = get_elevenlabs_voices(
                api_key=request.app.state.config.TTS_API_KEY
            )
        except Exception:
            # Avoided @lru_cache with exception
            pass
    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            region = request.app.state.config.TTS_AZURE_SPEECH_REGION
            base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
            url = (
                base_url or f"https://{region}.tts.speech.microsoft.com"
            ) + "/cognitiveservices/voices/list"
            headers = {
                "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
            }

            response = requests.get(url, headers=headers)
            response.raise_for_status()
            voices = response.json()

            for voice in voices:
                available_voices[voice["ShortName"]] = (
                    f"{voice['DisplayName']} ({voice['ShortName']})"
                )
        except requests.RequestException as e:
            log.error(f"Error fetching voices: {str(e)}")

    return available_voices


@lru_cache
def get_elevenlabs_voices(api_key: str) -> dict:
    """
    Note, set the following in your .env file to use Elevenlabs:
    AUDIO_TTS_ENGINE=elevenlabs
    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
    AUDIO_TTS_MODEL=eleven_multilingual_v2
    """
    try:
        # TODO: Add retries
        response = requests.get(
            "https://api.elevenlabs.io/v1/voices",
            headers={
                "xi-api-key": api_key,
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()
        voices_data = response.json()

        voices = {}
        for voice in voices_data.get("voices", []):
            voices[voice["voice_id"]] = voice["name"]
    except requests.RequestException as e:
        # Avoid @lru_cache with exception
        log.error(f"Error fetching voices: {str(e)}")
        raise RuntimeError(f"Error fetching voices: {str(e)}")
    return voices


@router.get("/voices")
async def get_voices(request: Request, user=Depends(get_verified_user)):
    return {
        "voices": [
            {"id": k, "name": v} for k, v in get_available_voices(request).items()
        ]
    }
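

# A minimal client usage sketch. Assumptions (not defined in this module): the
# router is mounted under some BASE_URL prefix and the deployment authenticates
# requests with a bearer token; both values below are placeholders.
#
#   import requests
#
#   headers = {"Authorization": f"Bearer {TOKEN}"}  # TOKEN is a placeholder
#
#   # Text-to-speech: the response body is the cached mp3
#   r = requests.post(
#       f"{BASE_URL}/speech",  # BASE_URL is a placeholder
#       json={"input": "Hello there", "voice": "alloy"},
#       headers=headers,
#   )
#   with open("hello.mp3", "wb") as f:
#       f.write(r.content)
#
#   # Speech-to-text: the optional "language" form field hints the STT engine
#   with open("recording.wav", "rb") as audio:
#       r = requests.post(
#           f"{BASE_URL}/transcriptions",
#           files={"file": ("recording.wav", audio, "audio/wav")},
#           data={"language": "en"},
#           headers=headers,
#       )
#   print(r.json()["text"])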