audio.py 40 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143
  1. import hashlib
  2. import json
  3. import logging
  4. import os
  5. import uuid
  6. from functools import lru_cache
  7. from pathlib import Path
  8. from pydub import AudioSegment
  9. from pydub.silence import split_on_silence
  10. from concurrent.futures import ThreadPoolExecutor
  11. from typing import Optional
  12. from fnmatch import fnmatch
  13. import aiohttp
  14. import aiofiles
  15. import requests
  16. import mimetypes
  17. from urllib.parse import quote
  18. from fastapi import (
  19. Depends,
  20. FastAPI,
  21. File,
  22. Form,
  23. HTTPException,
  24. Request,
  25. UploadFile,
  26. status,
  27. APIRouter,
  28. )
  29. from fastapi.middleware.cors import CORSMiddleware
  30. from fastapi.responses import FileResponse
  31. from pydantic import BaseModel
  32. from open_webui.utils.auth import get_admin_user, get_verified_user
  33. from open_webui.config import (
  34. WHISPER_MODEL_AUTO_UPDATE,
  35. WHISPER_MODEL_DIR,
  36. CACHE_DIR,
  37. WHISPER_LANGUAGE,
  38. )
  39. from open_webui.constants import ERROR_MESSAGES
  40. from open_webui.env import (
  41. AIOHTTP_CLIENT_SESSION_SSL,
  42. AIOHTTP_CLIENT_TIMEOUT,
  43. ENV,
  44. SRC_LOG_LEVELS,
  45. DEVICE_TYPE,
  46. ENABLE_FORWARD_USER_INFO_HEADERS,
  47. )
  48. router = APIRouter()
  49. # Constants
  50. MAX_FILE_SIZE_MB = 20
  51. MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
  52. AZURE_MAX_FILE_SIZE_MB = 200
  53. AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
  54. log = logging.getLogger(__name__)
  55. log.setLevel(SRC_LOG_LEVELS["AUDIO"])
  56. SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
  57. SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  58. ##########################################
  59. #
  60. # Utility functions
  61. #
  62. ##########################################
  63. from pydub import AudioSegment
  64. from pydub.utils import mediainfo
  65. def is_audio_conversion_required(file_path):
  66. """
  67. Check if the given audio file needs conversion to mp3.
  68. """
  69. SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}
  70. if not os.path.isfile(file_path):
  71. log.error(f"File not found: {file_path}")
  72. return False
  73. try:
  74. info = mediainfo(file_path)
  75. codec_name = info.get("codec_name", "").lower()
  76. codec_type = info.get("codec_type", "").lower()
  77. codec_tag_string = info.get("codec_tag_string", "").lower()
  78. if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
  79. # File is AAC/mp4a audio, recommend mp3 conversion
  80. return True
  81. # If the codec name is in the supported formats
  82. if codec_name in SUPPORTED_FORMATS:
  83. return False
  84. return True
  85. except Exception as e:
  86. log.error(f"Error getting audio format: {e}")
  87. return False
  88. def convert_audio_to_mp3(file_path):
  89. """Convert audio file to mp3 format."""
  90. try:
  91. output_path = os.path.splitext(file_path)[0] + ".mp3"
  92. audio = AudioSegment.from_file(file_path)
  93. audio.export(output_path, format="mp3")
  94. log.info(f"Converted {file_path} to {output_path}")
  95. return output_path
  96. except Exception as e:
  97. log.error(f"Error converting audio file: {e}")
  98. return None
  99. def set_faster_whisper_model(model: str, auto_update: bool = False):
  100. whisper_model = None
  101. if model:
  102. from faster_whisper import WhisperModel
  103. faster_whisper_kwargs = {
  104. "model_size_or_path": model,
  105. "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
  106. "compute_type": "int8",
  107. "download_root": WHISPER_MODEL_DIR,
  108. "local_files_only": not auto_update,
  109. }
  110. try:
  111. whisper_model = WhisperModel(**faster_whisper_kwargs)
  112. except Exception:
  113. log.warning(
  114. "WhisperModel initialization failed, attempting download with local_files_only=False"
  115. )
  116. faster_whisper_kwargs["local_files_only"] = False
  117. whisper_model = WhisperModel(**faster_whisper_kwargs)
  118. return whisper_model
  119. ##########################################
  120. #
  121. # Audio API
  122. #
  123. ##########################################
  124. class TTSConfigForm(BaseModel):
  125. OPENAI_API_BASE_URL: str
  126. OPENAI_API_KEY: str
  127. API_KEY: str
  128. ENGINE: str
  129. MODEL: str
  130. VOICE: str
  131. SPLIT_ON: str
  132. AZURE_SPEECH_REGION: str
  133. AZURE_SPEECH_BASE_URL: str
  134. AZURE_SPEECH_OUTPUT_FORMAT: str
  135. class STTConfigForm(BaseModel):
  136. OPENAI_API_BASE_URL: str
  137. OPENAI_API_KEY: str
  138. ENGINE: str
  139. MODEL: str
  140. SUPPORTED_CONTENT_TYPES: list[str] = []
  141. WHISPER_MODEL: str
  142. DEEPGRAM_API_KEY: str
  143. AZURE_API_KEY: str
  144. AZURE_REGION: str
  145. AZURE_LOCALES: str
  146. AZURE_BASE_URL: str
  147. AZURE_MAX_SPEAKERS: str
  148. class AudioConfigUpdateForm(BaseModel):
  149. tts: TTSConfigForm
  150. stt: STTConfigForm
  151. @router.get("/config")
  152. async def get_audio_config(request: Request, user=Depends(get_admin_user)):
  153. return {
  154. "tts": {
  155. "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
  156. "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
  157. "API_KEY": request.app.state.config.TTS_API_KEY,
  158. "ENGINE": request.app.state.config.TTS_ENGINE,
  159. "MODEL": request.app.state.config.TTS_MODEL,
  160. "VOICE": request.app.state.config.TTS_VOICE,
  161. "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
  162. "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
  163. "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
  164. "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
  165. },
  166. "stt": {
  167. "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
  168. "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
  169. "ENGINE": request.app.state.config.STT_ENGINE,
  170. "MODEL": request.app.state.config.STT_MODEL,
  171. "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
  172. "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
  173. "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
  174. "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
  175. "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
  176. "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
  177. "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
  178. "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
  179. },
  180. }
  181. @router.post("/config/update")
  182. async def update_audio_config(
  183. request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
  184. ):
  185. request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
  186. request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
  187. request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
  188. request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
  189. request.app.state.config.TTS_MODEL = form_data.tts.MODEL
  190. request.app.state.config.TTS_VOICE = form_data.tts.VOICE
  191. request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
  192. request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
  193. request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = (
  194. form_data.tts.AZURE_SPEECH_BASE_URL
  195. )
  196. request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
  197. form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
  198. )
  199. request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
  200. request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
  201. request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
  202. request.app.state.config.STT_MODEL = form_data.stt.MODEL
  203. request.app.state.config.STT_SUPPORTED_CONTENT_TYPES = (
  204. form_data.stt.SUPPORTED_CONTENT_TYPES
  205. )
  206. request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
  207. request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
  208. request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
  209. request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
  210. request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
  211. request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
  212. request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
  213. form_data.stt.AZURE_MAX_SPEAKERS
  214. )
  215. if request.app.state.config.STT_ENGINE == "":
  216. request.app.state.faster_whisper_model = set_faster_whisper_model(
  217. form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
  218. )
  219. else:
  220. request.app.state.faster_whisper_model = None
  221. return {
  222. "tts": {
  223. "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
  224. "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
  225. "API_KEY": request.app.state.config.TTS_API_KEY,
  226. "ENGINE": request.app.state.config.TTS_ENGINE,
  227. "MODEL": request.app.state.config.TTS_MODEL,
  228. "VOICE": request.app.state.config.TTS_VOICE,
  229. "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
  230. "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
  231. "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
  232. "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
  233. },
  234. "stt": {
  235. "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
  236. "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
  237. "ENGINE": request.app.state.config.STT_ENGINE,
  238. "MODEL": request.app.state.config.STT_MODEL,
  239. "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
  240. "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
  241. "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
  242. "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
  243. "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
  244. "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
  245. "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
  246. "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
  247. },
  248. }
  249. def load_speech_pipeline(request):
  250. from transformers import pipeline
  251. from datasets import load_dataset
  252. if request.app.state.speech_synthesiser is None:
  253. request.app.state.speech_synthesiser = pipeline(
  254. "text-to-speech", "microsoft/speecht5_tts"
  255. )
  256. if request.app.state.speech_speaker_embeddings_dataset is None:
  257. request.app.state.speech_speaker_embeddings_dataset = load_dataset(
  258. "Matthijs/cmu-arctic-xvectors", split="validation"
  259. )
  260. @router.post("/speech")
  261. async def speech(request: Request, user=Depends(get_verified_user)):
  262. body = await request.body()
  263. name = hashlib.sha256(
  264. body
  265. + str(request.app.state.config.TTS_ENGINE).encode("utf-8")
  266. + str(request.app.state.config.TTS_MODEL).encode("utf-8")
  267. ).hexdigest()
  268. file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
  269. file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
  270. # Check if the file already exists in the cache
  271. if file_path.is_file():
  272. return FileResponse(file_path)
  273. payload = None
  274. try:
  275. payload = json.loads(body.decode("utf-8"))
  276. except Exception as e:
  277. log.exception(e)
  278. raise HTTPException(status_code=400, detail="Invalid JSON payload")
  279. r = None
  280. if request.app.state.config.TTS_ENGINE == "openai":
  281. payload["model"] = request.app.state.config.TTS_MODEL
  282. try:
  283. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  284. async with aiohttp.ClientSession(
  285. timeout=timeout, trust_env=True
  286. ) as session:
  287. r = await session.post(
  288. url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
  289. json=payload,
  290. headers={
  291. "Content-Type": "application/json",
  292. "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
  293. **(
  294. {
  295. "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
  296. "X-OpenWebUI-User-Id": user.id,
  297. "X-OpenWebUI-User-Email": user.email,
  298. "X-OpenWebUI-User-Role": user.role,
  299. }
  300. if ENABLE_FORWARD_USER_INFO_HEADERS
  301. else {}
  302. ),
  303. },
  304. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  305. )
  306. r.raise_for_status()
  307. async with aiofiles.open(file_path, "wb") as f:
  308. await f.write(await r.read())
  309. async with aiofiles.open(file_body_path, "w") as f:
  310. await f.write(json.dumps(payload))
  311. return FileResponse(file_path)
  312. except Exception as e:
  313. log.exception(e)
  314. detail = None
  315. status_code = 500
  316. detail = f"Open WebUI: Server Connection Error"
  317. if r is not None:
  318. status_code = r.status
  319. try:
  320. res = await r.json()
  321. if "error" in res:
  322. detail = f"External: {res['error']}"
  323. except Exception:
  324. detail = f"External: {e}"
  325. raise HTTPException(
  326. status_code=status_code,
  327. detail=detail,
  328. )
  329. elif request.app.state.config.TTS_ENGINE == "elevenlabs":
  330. voice_id = payload.get("voice", "")
  331. if voice_id not in get_available_voices(request):
  332. raise HTTPException(
  333. status_code=400,
  334. detail="Invalid voice id",
  335. )
  336. try:
  337. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  338. async with aiohttp.ClientSession(
  339. timeout=timeout, trust_env=True
  340. ) as session:
  341. async with session.post(
  342. f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
  343. json={
  344. "text": payload["input"],
  345. "model_id": request.app.state.config.TTS_MODEL,
  346. "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
  347. },
  348. headers={
  349. "Accept": "audio/mpeg",
  350. "Content-Type": "application/json",
  351. "xi-api-key": request.app.state.config.TTS_API_KEY,
  352. },
  353. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  354. ) as r:
  355. r.raise_for_status()
  356. async with aiofiles.open(file_path, "wb") as f:
  357. await f.write(await r.read())
  358. async with aiofiles.open(file_body_path, "w") as f:
  359. await f.write(json.dumps(payload))
  360. return FileResponse(file_path)
  361. except Exception as e:
  362. log.exception(e)
  363. detail = None
  364. try:
  365. if r.status != 200:
  366. res = await r.json()
  367. if "error" in res:
  368. detail = f"External: {res['error'].get('message', '')}"
  369. except Exception:
  370. detail = f"External: {e}"
  371. raise HTTPException(
  372. status_code=getattr(r, "status", 500) if r else 500,
  373. detail=detail if detail else "Open WebUI: Server Connection Error",
  374. )
  375. elif request.app.state.config.TTS_ENGINE == "azure":
  376. try:
  377. payload = json.loads(body.decode("utf-8"))
  378. except Exception as e:
  379. log.exception(e)
  380. raise HTTPException(status_code=400, detail="Invalid JSON payload")
  381. region = request.app.state.config.TTS_AZURE_SPEECH_REGION or "eastus"
  382. base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
  383. language = request.app.state.config.TTS_VOICE
  384. locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
  385. output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
  386. try:
  387. data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
  388. <voice name="{language}">{payload["input"]}</voice>
  389. </speak>"""
  390. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  391. async with aiohttp.ClientSession(
  392. timeout=timeout, trust_env=True
  393. ) as session:
  394. async with session.post(
  395. (base_url or f"https://{region}.tts.speech.microsoft.com")
  396. + "/cognitiveservices/v1",
  397. headers={
  398. "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
  399. "Content-Type": "application/ssml+xml",
  400. "X-Microsoft-OutputFormat": output_format,
  401. },
  402. data=data,
  403. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  404. ) as r:
  405. r.raise_for_status()
  406. async with aiofiles.open(file_path, "wb") as f:
  407. await f.write(await r.read())
  408. async with aiofiles.open(file_body_path, "w") as f:
  409. await f.write(json.dumps(payload))
  410. return FileResponse(file_path)
  411. except Exception as e:
  412. log.exception(e)
  413. detail = None
  414. try:
  415. if r.status != 200:
  416. res = await r.json()
  417. if "error" in res:
  418. detail = f"External: {res['error'].get('message', '')}"
  419. except Exception:
  420. detail = f"External: {e}"
  421. raise HTTPException(
  422. status_code=getattr(r, "status", 500) if r else 500,
  423. detail=detail if detail else "Open WebUI: Server Connection Error",
  424. )
  425. elif request.app.state.config.TTS_ENGINE == "transformers":
  426. payload = None
  427. try:
  428. payload = json.loads(body.decode("utf-8"))
  429. except Exception as e:
  430. log.exception(e)
  431. raise HTTPException(status_code=400, detail="Invalid JSON payload")
  432. import torch
  433. import soundfile as sf
  434. load_speech_pipeline(request)
  435. embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
  436. speaker_index = 6799
  437. try:
  438. speaker_index = embeddings_dataset["filename"].index(
  439. request.app.state.config.TTS_MODEL
  440. )
  441. except Exception:
  442. pass
  443. speaker_embedding = torch.tensor(
  444. embeddings_dataset[speaker_index]["xvector"]
  445. ).unsqueeze(0)
  446. speech = request.app.state.speech_synthesiser(
  447. payload["input"],
  448. forward_params={"speaker_embeddings": speaker_embedding},
  449. )
  450. sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
  451. async with aiofiles.open(file_body_path, "w") as f:
  452. await f.write(json.dumps(payload))
  453. return FileResponse(file_path)
  454. def transcription_handler(request, file_path, metadata):
  455. filename = os.path.basename(file_path)
  456. file_dir = os.path.dirname(file_path)
  457. id = filename.split(".")[0]
  458. metadata = metadata or {}
  459. if request.app.state.config.STT_ENGINE == "":
  460. if request.app.state.faster_whisper_model is None:
  461. request.app.state.faster_whisper_model = set_faster_whisper_model(
  462. request.app.state.config.WHISPER_MODEL
  463. )
  464. model = request.app.state.faster_whisper_model
  465. segments, info = model.transcribe(
  466. file_path,
  467. beam_size=5,
  468. vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
  469. language=metadata.get("language") or WHISPER_LANGUAGE,
  470. )
  471. log.info(
  472. "Detected language '%s' with probability %f"
  473. % (info.language, info.language_probability)
  474. )
  475. transcript = "".join([segment.text for segment in list(segments)])
  476. data = {"text": transcript.strip()}
  477. # save the transcript to a json file
  478. transcript_file = f"{file_dir}/{id}.json"
  479. with open(transcript_file, "w") as f:
  480. json.dump(data, f)
  481. log.debug(data)
  482. return data
  483. elif request.app.state.config.STT_ENGINE == "openai":
  484. r = None
  485. try:
  486. r = requests.post(
  487. url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
  488. headers={
  489. "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
  490. },
  491. files={"file": (filename, open(file_path, "rb"))},
  492. data={
  493. "model": request.app.state.config.STT_MODEL,
  494. **(
  495. {"language": metadata.get("language")}
  496. if metadata.get("language")
  497. else {}
  498. ),
  499. },
  500. )
  501. r.raise_for_status()
  502. data = r.json()
  503. # save the transcript to a json file
  504. transcript_file = f"{file_dir}/{id}.json"
  505. with open(transcript_file, "w") as f:
  506. json.dump(data, f)
  507. return data
  508. except Exception as e:
  509. log.exception(e)
  510. detail = None
  511. if r is not None:
  512. try:
  513. res = r.json()
  514. if "error" in res:
  515. detail = f"External: {res['error'].get('message', '')}"
  516. except Exception:
  517. detail = f"External: {e}"
  518. raise Exception(detail if detail else "Open WebUI: Server Connection Error")
  519. elif request.app.state.config.STT_ENGINE == "deepgram":
  520. try:
  521. # Determine the MIME type of the file
  522. mime, _ = mimetypes.guess_type(file_path)
  523. if not mime:
  524. mime = "audio/wav" # fallback to wav if undetectable
  525. # Read the audio file
  526. with open(file_path, "rb") as f:
  527. file_data = f.read()
  528. # Build headers and parameters
  529. headers = {
  530. "Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
  531. "Content-Type": mime,
  532. }
  533. # Add model if specified
  534. params = {}
  535. if request.app.state.config.STT_MODEL:
  536. params["model"] = request.app.state.config.STT_MODEL
  537. # Make request to Deepgram API
  538. r = requests.post(
  539. "https://api.deepgram.com/v1/listen?smart_format=true",
  540. headers=headers,
  541. params=params,
  542. data=file_data,
  543. )
  544. r.raise_for_status()
  545. response_data = r.json()
  546. # Extract transcript from Deepgram response
  547. try:
  548. transcript = response_data["results"]["channels"][0]["alternatives"][
  549. 0
  550. ].get("transcript", "")
  551. except (KeyError, IndexError) as e:
  552. log.error(f"Malformed response from Deepgram: {str(e)}")
  553. raise Exception(
  554. "Failed to parse Deepgram response - unexpected response format"
  555. )
  556. data = {"text": transcript.strip()}
  557. # Save transcript
  558. transcript_file = f"{file_dir}/{id}.json"
  559. with open(transcript_file, "w") as f:
  560. json.dump(data, f)
  561. return data
  562. except Exception as e:
  563. log.exception(e)
  564. detail = None
  565. if r is not None:
  566. try:
  567. res = r.json()
  568. if "error" in res:
  569. detail = f"External: {res['error'].get('message', '')}"
  570. except Exception:
  571. detail = f"External: {e}"
  572. raise Exception(detail if detail else "Open WebUI: Server Connection Error")
  573. elif request.app.state.config.STT_ENGINE == "azure":
  574. # Check file exists and size
  575. if not os.path.exists(file_path):
  576. raise HTTPException(status_code=400, detail="Audio file not found")
  577. # Check file size (Azure has a larger limit of 200MB)
  578. file_size = os.path.getsize(file_path)
  579. if file_size > AZURE_MAX_FILE_SIZE:
  580. raise HTTPException(
  581. status_code=400,
  582. detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
  583. )
  584. api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
  585. region = request.app.state.config.AUDIO_STT_AZURE_REGION or "eastus"
  586. locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
  587. base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
  588. max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3
  589. # IF NO LOCALES, USE DEFAULTS
  590. if len(locales) < 2:
  591. locales = [
  592. "en-US",
  593. "es-ES",
  594. "es-MX",
  595. "fr-FR",
  596. "hi-IN",
  597. "it-IT",
  598. "de-DE",
  599. "en-GB",
  600. "en-IN",
  601. "ja-JP",
  602. "ko-KR",
  603. "pt-BR",
  604. "zh-CN",
  605. ]
  606. locales = ",".join(locales)
  607. if not api_key or not region:
  608. raise HTTPException(
  609. status_code=400,
  610. detail="Azure API key is required for Azure STT",
  611. )
  612. r = None
  613. try:
  614. # Prepare the request
  615. data = {
  616. "definition": json.dumps(
  617. {
  618. "locales": locales.split(","),
  619. "diarization": {"maxSpeakers": max_speakers, "enabled": True},
  620. }
  621. if locales
  622. else {}
  623. )
  624. }
  625. url = (
  626. base_url or f"https://{region}.api.cognitive.microsoft.com"
  627. ) + "/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
  628. # Use context manager to ensure file is properly closed
  629. with open(file_path, "rb") as audio_file:
  630. r = requests.post(
  631. url=url,
  632. files={"audio": audio_file},
  633. data=data,
  634. headers={
  635. "Ocp-Apim-Subscription-Key": api_key,
  636. },
  637. )
  638. r.raise_for_status()
  639. response = r.json()
  640. # Extract transcript from response
  641. if not response.get("combinedPhrases"):
  642. raise ValueError("No transcription found in response")
  643. # Get the full transcript from combinedPhrases
  644. transcript = response["combinedPhrases"][0].get("text", "").strip()
  645. if not transcript:
  646. raise ValueError("Empty transcript in response")
  647. data = {"text": transcript}
  648. # Save transcript to json file (consistent with other providers)
  649. transcript_file = f"{file_dir}/{id}.json"
  650. with open(transcript_file, "w") as f:
  651. json.dump(data, f)
  652. log.debug(data)
  653. return data
  654. except (KeyError, IndexError, ValueError) as e:
  655. log.exception("Error parsing Azure response")
  656. raise HTTPException(
  657. status_code=500,
  658. detail=f"Failed to parse Azure response: {str(e)}",
  659. )
  660. except requests.exceptions.RequestException as e:
  661. log.exception(e)
  662. detail = None
  663. try:
  664. if r is not None and r.status_code != 200:
  665. res = r.json()
  666. if "error" in res:
  667. detail = f"External: {res['error'].get('message', '')}"
  668. except Exception:
  669. detail = f"External: {e}"
  670. raise HTTPException(
  671. status_code=getattr(r, "status_code", 500) if r else 500,
  672. detail=detail if detail else "Open WebUI: Server Connection Error",
  673. )
  674. def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
  675. log.info(f"transcribe: {file_path} {metadata}")
  676. if is_audio_conversion_required(file_path):
  677. file_path = convert_audio_to_mp3(file_path)
  678. try:
  679. file_path = compress_audio(file_path)
  680. except Exception as e:
  681. log.exception(e)
  682. # Always produce a list of chunk paths (could be one entry if small)
  683. try:
  684. chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
  685. print(f"Chunk paths: {chunk_paths}")
  686. except Exception as e:
  687. log.exception(e)
  688. raise HTTPException(
  689. status_code=status.HTTP_400_BAD_REQUEST,
  690. detail=ERROR_MESSAGES.DEFAULT(e),
  691. )
  692. results = []
  693. try:
  694. with ThreadPoolExecutor() as executor:
  695. # Submit tasks for each chunk_path
  696. futures = [
  697. executor.submit(transcription_handler, request, chunk_path, metadata)
  698. for chunk_path in chunk_paths
  699. ]
  700. # Gather results as they complete
  701. for future in futures:
  702. try:
  703. results.append(future.result())
  704. except Exception as transcribe_exc:
  705. raise HTTPException(
  706. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  707. detail=f"Error transcribing chunk: {transcribe_exc}",
  708. )
  709. finally:
  710. # Clean up only the temporary chunks, never the original file
  711. for chunk_path in chunk_paths:
  712. if chunk_path != file_path and os.path.isfile(chunk_path):
  713. try:
  714. os.remove(chunk_path)
  715. except Exception:
  716. pass
  717. return {
  718. "text": " ".join([result["text"] for result in results]),
  719. }
  720. def compress_audio(file_path):
  721. if os.path.getsize(file_path) > MAX_FILE_SIZE:
  722. id = os.path.splitext(os.path.basename(file_path))[
  723. 0
  724. ] # Handles names with multiple dots
  725. file_dir = os.path.dirname(file_path)
  726. audio = AudioSegment.from_file(file_path)
  727. audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
  728. compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
  729. audio.export(compressed_path, format="mp3", bitrate="32k")
  730. # log.debug(f"Compressed audio to {compressed_path}") # Uncomment if log is defined
  731. return compressed_path
  732. else:
  733. return file_path
  734. def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
  735. """
  736. Splits audio into chunks not exceeding max_bytes.
  737. Returns a list of chunk file paths. If audio fits, returns list with original path.
  738. """
  739. file_size = os.path.getsize(file_path)
  740. if file_size <= max_bytes:
  741. return [file_path] # Nothing to split
  742. audio = AudioSegment.from_file(file_path)
  743. duration_ms = len(audio)
  744. orig_size = file_size
  745. approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
  746. chunks = []
  747. start = 0
  748. i = 0
  749. base, _ = os.path.splitext(file_path)
  750. while start < duration_ms:
  751. end = min(start + approx_chunk_ms, duration_ms)
  752. chunk = audio[start:end]
  753. chunk_path = f"{base}_chunk_{i}.{format}"
  754. chunk.export(chunk_path, format=format, bitrate=bitrate)
  755. # Reduce chunk duration if still too large
  756. while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
  757. end = start + ((end - start) // 2)
  758. chunk = audio[start:end]
  759. chunk.export(chunk_path, format=format, bitrate=bitrate)
  760. if os.path.getsize(chunk_path) > max_bytes:
  761. os.remove(chunk_path)
  762. raise Exception("Audio chunk cannot be reduced below max file size.")
  763. chunks.append(chunk_path)
  764. start = end
  765. i += 1
  766. return chunks
  767. @router.post("/transcriptions")
  768. def transcription(
  769. request: Request,
  770. file: UploadFile = File(...),
  771. language: Optional[str] = Form(None),
  772. user=Depends(get_verified_user),
  773. ):
  774. log.info(f"file.content_type: {file.content_type}")
  775. stt_supported_content_types = getattr(
  776. request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
  777. )
  778. if not any(
  779. fnmatch(file.content_type, content_type)
  780. for content_type in (
  781. stt_supported_content_types
  782. if stt_supported_content_types
  783. and any(t.strip() for t in stt_supported_content_types)
  784. else ["audio/*", "video/webm"]
  785. )
  786. ):
  787. raise HTTPException(
  788. status_code=status.HTTP_400_BAD_REQUEST,
  789. detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
  790. )
  791. try:
  792. ext = file.filename.split(".")[-1]
  793. id = uuid.uuid4()
  794. filename = f"{id}.{ext}"
  795. contents = file.file.read()
  796. file_dir = f"{CACHE_DIR}/audio/transcriptions"
  797. os.makedirs(file_dir, exist_ok=True)
  798. file_path = f"{file_dir}/{filename}"
  799. with open(file_path, "wb") as f:
  800. f.write(contents)
  801. try:
  802. metadata = None
  803. if language:
  804. metadata = {"language": language}
  805. result = transcribe(request, file_path, metadata)
  806. return {
  807. **result,
  808. "filename": os.path.basename(file_path),
  809. }
  810. except Exception as e:
  811. log.exception(e)
  812. raise HTTPException(
  813. status_code=status.HTTP_400_BAD_REQUEST,
  814. detail=ERROR_MESSAGES.DEFAULT(e),
  815. )
  816. except Exception as e:
  817. log.exception(e)
  818. raise HTTPException(
  819. status_code=status.HTTP_400_BAD_REQUEST,
  820. detail=ERROR_MESSAGES.DEFAULT(e),
  821. )
  822. def get_available_models(request: Request) -> list[dict]:
  823. available_models = []
  824. if request.app.state.config.TTS_ENGINE == "openai":
  825. # Use custom endpoint if not using the official OpenAI API URL
  826. if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
  827. "https://api.openai.com"
  828. ):
  829. try:
  830. response = requests.get(
  831. f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
  832. )
  833. response.raise_for_status()
  834. data = response.json()
  835. available_models = data.get("models", [])
  836. except Exception as e:
  837. log.error(f"Error fetching models from custom endpoint: {str(e)}")
  838. available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
  839. else:
  840. available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
  841. elif request.app.state.config.TTS_ENGINE == "elevenlabs":
  842. try:
  843. response = requests.get(
  844. "https://api.elevenlabs.io/v1/models",
  845. headers={
  846. "xi-api-key": request.app.state.config.TTS_API_KEY,
  847. "Content-Type": "application/json",
  848. },
  849. timeout=5,
  850. )
  851. response.raise_for_status()
  852. models = response.json()
  853. available_models = [
  854. {"name": model["name"], "id": model["model_id"]} for model in models
  855. ]
  856. except requests.RequestException as e:
  857. log.error(f"Error fetching voices: {str(e)}")
  858. return available_models
  859. @router.get("/models")
  860. async def get_models(request: Request, user=Depends(get_verified_user)):
  861. return {"models": get_available_models(request)}
  862. def get_available_voices(request) -> dict:
  863. """Returns {voice_id: voice_name} dict"""
  864. available_voices = {}
  865. if request.app.state.config.TTS_ENGINE == "openai":
  866. # Use custom endpoint if not using the official OpenAI API URL
  867. if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
  868. "https://api.openai.com"
  869. ):
  870. try:
  871. response = requests.get(
  872. f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
  873. )
  874. response.raise_for_status()
  875. data = response.json()
  876. voices_list = data.get("voices", [])
  877. available_voices = {voice["id"]: voice["name"] for voice in voices_list}
  878. except Exception as e:
  879. log.error(f"Error fetching voices from custom endpoint: {str(e)}")
  880. available_voices = {
  881. "alloy": "alloy",
  882. "echo": "echo",
  883. "fable": "fable",
  884. "onyx": "onyx",
  885. "nova": "nova",
  886. "shimmer": "shimmer",
  887. }
  888. else:
  889. available_voices = {
  890. "alloy": "alloy",
  891. "echo": "echo",
  892. "fable": "fable",
  893. "onyx": "onyx",
  894. "nova": "nova",
  895. "shimmer": "shimmer",
  896. }
  897. elif request.app.state.config.TTS_ENGINE == "elevenlabs":
  898. try:
  899. available_voices = get_elevenlabs_voices(
  900. api_key=request.app.state.config.TTS_API_KEY
  901. )
  902. except Exception:
  903. # Avoided @lru_cache with exception
  904. pass
  905. elif request.app.state.config.TTS_ENGINE == "azure":
  906. try:
  907. region = request.app.state.config.TTS_AZURE_SPEECH_REGION
  908. base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
  909. url = (
  910. base_url or f"https://{region}.tts.speech.microsoft.com"
  911. ) + "/cognitiveservices/voices/list"
  912. headers = {
  913. "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
  914. }
  915. response = requests.get(url, headers=headers)
  916. response.raise_for_status()
  917. voices = response.json()
  918. for voice in voices:
  919. available_voices[voice["ShortName"]] = (
  920. f"{voice['DisplayName']} ({voice['ShortName']})"
  921. )
  922. except requests.RequestException as e:
  923. log.error(f"Error fetching voices: {str(e)}")
  924. return available_voices
  925. @lru_cache
  926. def get_elevenlabs_voices(api_key: str) -> dict:
  927. """
  928. Note, set the following in your .env file to use Elevenlabs:
  929. AUDIO_TTS_ENGINE=elevenlabs
  930. AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key
  931. AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices
  932. AUDIO_TTS_MODEL=eleven_multilingual_v2
  933. """
  934. try:
  935. # TODO: Add retries
  936. response = requests.get(
  937. "https://api.elevenlabs.io/v1/voices",
  938. headers={
  939. "xi-api-key": api_key,
  940. "Content-Type": "application/json",
  941. },
  942. )
  943. response.raise_for_status()
  944. voices_data = response.json()
  945. voices = {}
  946. for voice in voices_data.get("voices", []):
  947. voices[voice["voice_id"]] = voice["name"]
  948. except requests.RequestException as e:
  949. # Avoid @lru_cache with exception
  950. log.error(f"Error fetching voices: {str(e)}")
  951. raise RuntimeError(f"Error fetching voices: {str(e)}")
  952. return voices
  953. @router.get("/voices")
  954. async def get_voices(request: Request, user=Depends(get_verified_user)):
  955. return {
  956. "voices": [
  957. {"id": k, "name": v} for k, v in get_available_voices(request).items()
  958. ]
  959. }