import hashlib
import json
import logging
import os
import uuid
import html
from functools import lru_cache
from pydub import AudioSegment
from pydub.silence import split_on_silence
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from fnmatch import fnmatch

import aiohttp
import aiofiles
import requests
import mimetypes

from urllib.parse import urljoin, quote

from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
    status,
    APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel

from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import (
    WHISPER_MODEL_AUTO_UPDATE,
    WHISPER_MODEL_DIR,
    CACHE_DIR,
    WHISPER_LANGUAGE,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    ENV,
    SRC_LOG_LEVELS,
    DEVICE_TYPE,
    ENABLE_FORWARD_USER_INFO_HEADERS,
)

router = APIRouter()

# Constants
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
AZURE_MAX_FILE_SIZE_MB = 200
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)

##########################################
#
# Utility functions
#
##########################################

from pydub import AudioSegment
from pydub.utils import mediainfo


def is_audio_conversion_required(file_path):
    """
    Check if the given audio file needs conversion to mp3.
    """
    SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}

    if not os.path.isfile(file_path):
        log.error(f"File not found: {file_path}")
        return False

    try:
        info = mediainfo(file_path)
        codec_name = info.get("codec_name", "").lower()
        codec_type = info.get("codec_type", "").lower()
        codec_tag_string = info.get("codec_tag_string", "").lower()

        if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
            # File is AAC/mp4a audio, recommend mp3 conversion
            return True

        # If the codec name is in the supported formats
        if codec_name in SUPPORTED_FORMATS:
            return False

        return True
    except Exception as e:
        log.error(f"Error getting audio format: {e}")
        return False


def convert_audio_to_mp3(file_path):
    """Convert audio file to mp3 format."""
    try:
        output_path = os.path.splitext(file_path)[0] + ".mp3"
        audio = AudioSegment.from_file(file_path)
        audio.export(output_path, format="mp3")
        log.info(f"Converted {file_path} to {output_path}")
        return output_path
    except Exception as e:
        log.error(f"Error converting audio file: {e}")
        return None


def set_faster_whisper_model(model: str, auto_update: bool = False):
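    """
    Initialize a faster-whisper WhisperModel for local transcription.

    When auto_update is False the model is loaded from local files only;
    if that fails, the load is retried with downloads enabled.
    """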
    whisper_model = None
    if model:
        from faster_whisper import WhisperModel

        faster_whisper_kwargs = {
            "model_size_or_path": model,
            "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
            "compute_type": "int8",
            "download_root": WHISPER_MODEL_DIR,
            "local_files_only": not auto_update,
        }

        try:
            whisper_model = WhisperModel(**faster_whisper_kwargs)
        except Exception:
            log.warning(
                "WhisperModel initialization failed, attempting download with local_files_only=False"
            )
            faster_whisper_kwargs["local_files_only"] = False
            whisper_model = WhisperModel(**faster_whisper_kwargs)

    return whisper_model

##########################################
#
# Audio API
#
##########################################


class TTSConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    OPENAI_PARAMS: Optional[dict] = None
    API_KEY: str
    ENGINE: str
    MODEL: str
    VOICE: str
    SPLIT_ON: str
    AZURE_SPEECH_REGION: str
    AZURE_SPEECH_BASE_URL: str
    AZURE_SPEECH_OUTPUT_FORMAT: str


class STTConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    ENGINE: str
    MODEL: str
    SUPPORTED_CONTENT_TYPES: list[str] = []
    WHISPER_MODEL: str
    DEEPGRAM_API_KEY: str
    AZURE_API_KEY: str
    AZURE_REGION: str
    AZURE_LOCALES: str
    AZURE_BASE_URL: str
    AZURE_MAX_SPEAKERS: str


class AudioConfigUpdateForm(BaseModel):
    tts: TTSConfigForm
    stt: STTConfigForm

@router.get("/config")
async def get_audio_config(request: Request, user=Depends(get_admin_user)):
    return {
        "tts": {
            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
            "OPENAI_PARAMS": request.app.state.config.TTS_OPENAI_PARAMS,
            "API_KEY": request.app.state.config.TTS_API_KEY,
            "ENGINE": request.app.state.config.TTS_ENGINE,
            "MODEL": request.app.state.config.TTS_MODEL,
            "VOICE": request.app.state.config.TTS_VOICE,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": request.app.state.config.STT_ENGINE,
            "MODEL": request.app.state.config.STT_MODEL,
            "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }


@router.post("/config/update")
async def update_audio_config(
    request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
    request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
    request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
    request.app.state.config.TTS_OPENAI_PARAMS = form_data.tts.OPENAI_PARAMS
    request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
    request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
    request.app.state.config.TTS_MODEL = form_data.tts.MODEL
    request.app.state.config.TTS_VOICE = form_data.tts.VOICE
    request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
    request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
    request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = (
        form_data.tts.AZURE_SPEECH_BASE_URL
    )
    request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
        form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
    )

    request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
    request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
    request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
    request.app.state.config.STT_MODEL = form_data.stt.MODEL
    request.app.state.config.STT_SUPPORTED_CONTENT_TYPES = (
        form_data.stt.SUPPORTED_CONTENT_TYPES
    )
    request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
    request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
    request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
    request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
    request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
        form_data.stt.AZURE_MAX_SPEAKERS
    )

    if request.app.state.config.STT_ENGINE == "":
        request.app.state.faster_whisper_model = set_faster_whisper_model(
            form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
        )
    else:
        request.app.state.faster_whisper_model = None

    return {
        "tts": {
            "ENGINE": request.app.state.config.TTS_ENGINE,
            "MODEL": request.app.state.config.TTS_MODEL,
            "VOICE": request.app.state.config.TTS_VOICE,
            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
            "OPENAI_PARAMS": request.app.state.config.TTS_OPENAI_PARAMS,
            "API_KEY": request.app.state.config.TTS_API_KEY,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": request.app.state.config.STT_ENGINE,
            "MODEL": request.app.state.config.STT_MODEL,
            "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }

def load_speech_pipeline(request):
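    """
    Lazily initialize the local SpeechT5 text-to-speech pipeline and the
    speaker-embedding dataset on app state, loading each only once.
    """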
    from transformers import pipeline
    from datasets import load_dataset

    if request.app.state.speech_synthesiser is None:
        request.app.state.speech_synthesiser = pipeline(
            "text-to-speech", "microsoft/speecht5_tts"
        )

    if request.app.state.speech_speaker_embeddings_dataset is None:
        request.app.state.speech_speaker_embeddings_dataset = load_dataset(
            "Matthijs/cmu-arctic-xvectors", split="validation"
        )

@router.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
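    """
    Generate speech for the request body using the configured TTS engine
    (openai, elevenlabs, azure, or transformers). Results are cached on disk,
    keyed by a hash of the request body, engine, and model.
    """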
    body = await request.body()
    name = hashlib.sha256(
        body
        + str(request.app.state.config.TTS_ENGINE).encode("utf-8")
        + str(request.app.state.config.TTS_MODEL).encode("utf-8")
    ).hexdigest()

    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    payload = None
    try:
        payload = json.loads(body.decode("utf-8"))
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=400, detail="Invalid JSON payload")

    r = None
    if request.app.state.config.TTS_ENGINE == "openai":
        payload["model"] = request.app.state.config.TTS_MODEL

        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                payload = {
                    **payload,
                    **(request.app.state.config.TTS_OPENAI_PARAMS or {}),
                }

                r = await session.post(
                    url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                    json=payload,
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
                        **(
                            {
                                "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
                                "X-OpenWebUI-User-Id": user.id,
                                "X-OpenWebUI-User-Email": user.email,
                                "X-OpenWebUI-User-Role": user.role,
                            }
                            if ENABLE_FORWARD_USER_INFO_HEADERS
                            else {}
                        ),
                    },
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                )

                r.raise_for_status()

                async with aiofiles.open(file_path, "wb") as f:
                    await f.write(await r.read())

                async with aiofiles.open(file_body_path, "w") as f:
                    await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            detail = None

            status_code = 500
            detail = f"Open WebUI: Server Connection Error"
            if r is not None:
                status_code = r.status
                try:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error']}"
                except Exception:
                    detail = f"External: {e}"

            raise HTTPException(
                status_code=status_code,
                detail=detail,
            )

    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        voice_id = payload.get("voice", "")

        if voice_id not in get_available_voices(request):
            raise HTTPException(
                status_code=400,
                detail="Invalid voice id",
            )

        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
                    json={
                        "text": payload["input"],
                        "model_id": request.app.state.config.TTS_MODEL,
                        "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
                    },
                    headers={
                        "Accept": "audio/mpeg",
                        "Content-Type": "application/json",
                        "xi-api-key": request.app.state.config.TTS_API_KEY,
                    },
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )

    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        region = request.app.state.config.TTS_AZURE_SPEECH_REGION or "eastus"
        base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
        language = request.app.state.config.TTS_VOICE
        locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
        output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT

        try:
            data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
                <voice name="{language}">{html.escape(payload["input"])}</voice>
            </speak>"""
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    (base_url or f"https://{region}.tts.speech.microsoft.com")
                    + "/cognitiveservices/v1",
                    headers={
                        "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
                        "Content-Type": "application/ssml+xml",
                        "X-Microsoft-OutputFormat": output_format,
                    },
                    data=data,
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )

    elif request.app.state.config.TTS_ENGINE == "transformers":
        payload = None
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        import torch
        import soundfile as sf

        load_speech_pipeline(request)

        embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset

        speaker_index = 6799
        try:
            speaker_index = embeddings_dataset["filename"].index(
                request.app.state.config.TTS_MODEL
            )
        except Exception:
            pass

        speaker_embedding = torch.tensor(
            embeddings_dataset[speaker_index]["xvector"]
        ).unsqueeze(0)

        speech = request.app.state.speech_synthesiser(
            payload["input"],
            forward_params={"speaker_embeddings": speaker_embedding},
        )

        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])

        async with aiofiles.open(file_body_path, "w") as f:
            await f.write(json.dumps(payload))

        return FileResponse(file_path)

def transcription_handler(request, file_path, metadata):
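    """
    Transcribe a single audio file with the configured STT engine
    (local faster-whisper, openai, deepgram, or azure) and persist the
    result as a JSON file next to the audio file.
    """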
    filename = os.path.basename(file_path)
    file_dir = os.path.dirname(file_path)
    id = filename.split(".")[0]

    metadata = metadata or {}

    languages = [
        metadata.get("language", None) if not WHISPER_LANGUAGE else WHISPER_LANGUAGE,
        None,  # Always fallback to None in case transcription fails
    ]

    if request.app.state.config.STT_ENGINE == "":
        if request.app.state.faster_whisper_model is None:
            request.app.state.faster_whisper_model = set_faster_whisper_model(
                request.app.state.config.WHISPER_MODEL
            )

        model = request.app.state.faster_whisper_model
        segments, info = model.transcribe(
            file_path,
            beam_size=5,
            vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
            language=languages[0],
        )
        log.info(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )

        transcript = "".join([segment.text for segment in list(segments)])
        data = {"text": transcript.strip()}

        # save the transcript to a json file
        transcript_file = f"{file_dir}/{id}.json"
        with open(transcript_file, "w") as f:
            json.dump(data, f)

        log.debug(data)
        return data
    elif request.app.state.config.STT_ENGINE == "openai":
        r = None
        try:
            for language in languages:
                payload = {
                    "model": request.app.state.config.STT_MODEL,
                }
                if language:
                    payload["language"] = language

                r = requests.post(
                    url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
                    headers={
                        "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
                    },
                    files={"file": (filename, open(file_path, "rb"))},
                    data=payload,
                )

                if r.status_code == 200:
                    # Successful transcription
                    break

            r.raise_for_status()
            data = r.json()

            # save the transcript to a json file
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            return data
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
                except Exception:
                    detail = f"External: {e}"

            raise Exception(detail if detail else "Open WebUI: Server Connection Error")

    elif request.app.state.config.STT_ENGINE == "deepgram":
        try:
            # Determine the MIME type of the file
            mime, _ = mimetypes.guess_type(file_path)
            if not mime:
                mime = "audio/wav"  # fallback to wav if undetectable

            # Read the audio file
            with open(file_path, "rb") as f:
                file_data = f.read()

            # Build headers and parameters
            headers = {
                "Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
                "Content-Type": mime,
            }

            for language in languages:
                params = {}
                if request.app.state.config.STT_MODEL:
                    params["model"] = request.app.state.config.STT_MODEL

                if language:
                    params["language"] = language

                # Make request to Deepgram API
                r = requests.post(
                    "https://api.deepgram.com/v1/listen?smart_format=true",
                    headers=headers,
                    params=params,
                    data=file_data,
                )
                if r.status_code == 200:
                    # Successful transcription
                    break

            r.raise_for_status()
            response_data = r.json()

            # Extract transcript from Deepgram response
            try:
                transcript = response_data["results"]["channels"][0]["alternatives"][
                    0
                ].get("transcript", "")
            except (KeyError, IndexError) as e:
                log.error(f"Malformed response from Deepgram: {str(e)}")
                raise Exception(
                    "Failed to parse Deepgram response - unexpected response format"
                )

            data = {"text": transcript.strip()}

            # Save transcript
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            return data
        except Exception as e:
            log.exception(e)
            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
                except Exception:
                    detail = f"External: {e}"

            raise Exception(detail if detail else "Open WebUI: Server Connection Error")

    elif request.app.state.config.STT_ENGINE == "azure":
        # Check file exists and size
        if not os.path.exists(file_path):
            raise HTTPException(status_code=400, detail="Audio file not found")

        # Check file size (Azure has a larger limit of 200MB)
        file_size = os.path.getsize(file_path)
        if file_size > AZURE_MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
            )

        api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
        region = request.app.state.config.AUDIO_STT_AZURE_REGION or "eastus"
        locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
        base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
        max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3

        # IF NO LOCALES, USE DEFAULTS
        if len(locales) < 2:
            locales = [
                "en-US",
                "es-ES",
                "es-MX",
                "fr-FR",
                "hi-IN",
                "it-IT",
                "de-DE",
                "en-GB",
                "en-IN",
                "ja-JP",
                "ko-KR",
                "pt-BR",
                "zh-CN",
            ]
            locales = ",".join(locales)

        if not api_key or not region:
            raise HTTPException(
                status_code=400,
                detail="Azure API key is required for Azure STT",
            )

        r = None
        try:
            # Prepare the request
            data = {
                "definition": json.dumps(
                    {
                        "locales": locales.split(","),
                        "diarization": {"maxSpeakers": max_speakers, "enabled": True},
                    }
                    if locales
                    else {}
                )
            }
            url = (
                base_url or f"https://{region}.api.cognitive.microsoft.com"
            ) + "/speechtotext/transcriptions:transcribe?api-version=2024-11-15"

            # Use context manager to ensure file is properly closed
            with open(file_path, "rb") as audio_file:
                r = requests.post(
                    url=url,
                    files={"audio": audio_file},
                    data=data,
                    headers={
                        "Ocp-Apim-Subscription-Key": api_key,
                    },
                )

            r.raise_for_status()
            response = r.json()

            # Extract transcript from response
            if not response.get("combinedPhrases"):
                raise ValueError("No transcription found in response")

            # Get the full transcript from combinedPhrases
            transcript = response["combinedPhrases"][0].get("text", "").strip()
            if not transcript:
                raise ValueError("Empty transcript in response")

            data = {"text": transcript}

            # Save transcript to json file (consistent with other providers)
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            log.debug(data)
            return data
        except (KeyError, IndexError, ValueError) as e:
            log.exception("Error parsing Azure response")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to parse Azure response: {str(e)}",
            )
        except requests.exceptions.RequestException as e:
            log.exception(e)
            detail = None

            try:
                if r is not None and r.status_code != 200:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status_code", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )

def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
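    """
    Transcribe an audio file end to end: convert/compress it if needed, split
    it into chunks under MAX_FILE_SIZE, transcribe the chunks concurrently,
    and join the chunk transcripts into a single text.
    """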
    log.info(f"transcribe: {file_path} {metadata}")

    if is_audio_conversion_required(file_path):
        file_path = convert_audio_to_mp3(file_path)

    try:
        file_path = compress_audio(file_path)
    except Exception as e:
        log.exception(e)

    # Always produce a list of chunk paths (could be one entry if small)
    try:
        chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
        print(f"Chunk paths: {chunk_paths}")
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    results = []
    try:
        with ThreadPoolExecutor() as executor:
            # Submit tasks for each chunk_path
            futures = [
                executor.submit(transcription_handler, request, chunk_path, metadata)
                for chunk_path in chunk_paths
            ]
            # Gather results as they complete
            for future in futures:
                try:
                    results.append(future.result())
                except Exception as transcribe_exc:
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail=f"Error transcribing chunk: {transcribe_exc}",
                    )
    finally:
        # Clean up only the temporary chunks, never the original file
        for chunk_path in chunk_paths:
            if chunk_path != file_path and os.path.isfile(chunk_path):
                try:
                    os.remove(chunk_path)
                except Exception:
                    pass

    return {
        "text": " ".join([result["text"] for result in results]),
    }

def compress_audio(file_path):
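    """
    Re-encode audio that exceeds MAX_FILE_SIZE to 16 kHz mono 32 kbps mp3 and
    return the compressed file path; otherwise return the original path.
    """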
    if os.path.getsize(file_path) > MAX_FILE_SIZE:
        id = os.path.splitext(os.path.basename(file_path))[
            0
        ]  # Handles names with multiple dots
        file_dir = os.path.dirname(file_path)

        audio = AudioSegment.from_file(file_path)
        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
        compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
        audio.export(compressed_path, format="mp3", bitrate="32k")
        # log.debug(f"Compressed audio to {compressed_path}")  # Uncomment if log is defined
        return compressed_path
    else:
        return file_path


def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
    """
    Splits audio into chunks not exceeding max_bytes.
    Returns a list of chunk file paths. If audio fits, returns list with original path.
    """
    file_size = os.path.getsize(file_path)
    if file_size <= max_bytes:
        return [file_path]  # Nothing to split

    audio = AudioSegment.from_file(file_path)
    duration_ms = len(audio)
    orig_size = file_size

    approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
    chunks = []
    start = 0
    i = 0

    base, _ = os.path.splitext(file_path)

    while start < duration_ms:
        end = min(start + approx_chunk_ms, duration_ms)
        chunk = audio[start:end]
        chunk_path = f"{base}_chunk_{i}.{format}"
        chunk.export(chunk_path, format=format, bitrate=bitrate)

        # Reduce chunk duration if still too large
        while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
            end = start + ((end - start) // 2)
            chunk = audio[start:end]
            chunk.export(chunk_path, format=format, bitrate=bitrate)

        if os.path.getsize(chunk_path) > max_bytes:
            os.remove(chunk_path)
            raise Exception("Audio chunk cannot be reduced below max file size.")

        chunks.append(chunk_path)
        start = end
        i += 1

    return chunks

@router.post("/transcriptions")
def transcription(
    request: Request,
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    user=Depends(get_verified_user),
):
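    """
    Accept an uploaded audio (or webm video) file, store it in the
    transcription cache directory, and return its transcript.
    """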
    log.info(f"file.content_type: {file.content_type}")

    stt_supported_content_types = getattr(
        request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
    )

    if not any(
        fnmatch(file.content_type, content_type)
        for content_type in (
            stt_supported_content_types
            if stt_supported_content_types
            and any(t.strip() for t in stt_supported_content_types)
            else ["audio/*", "video/webm"]
        )
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    try:
        ext = file.filename.split(".")[-1]
        id = uuid.uuid4()

        filename = f"{id}.{ext}"
        contents = file.file.read()

        file_dir = f"{CACHE_DIR}/audio/transcriptions"
        os.makedirs(file_dir, exist_ok=True)
        file_path = f"{file_dir}/{filename}"

        with open(file_path, "wb") as f:
            f.write(contents)

        try:
            metadata = None

            if language:
                metadata = {"language": language}

            result = transcribe(request, file_path, metadata)

            return {
                **result,
                "filename": os.path.basename(file_path),
            }
        except Exception as e:
            log.exception(e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

def get_available_models(request: Request) -> list[dict]:
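    """
    List TTS models for the configured engine: queried from a custom
    OpenAI-compatible endpoint or the ElevenLabs API, falling back to the
    static tts-1 / tts-1-hd defaults for the official OpenAI API.
    """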
    available_models = []
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
                )
                response.raise_for_status()
                data = response.json()
                available_models = data.get("models", [])
            except Exception as e:
                log.error(f"Error fetching models from custom endpoint: {str(e)}")
                available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
        else:
            available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            response = requests.get(
                "https://api.elevenlabs.io/v1/models",
                headers={
                    "xi-api-key": request.app.state.config.TTS_API_KEY,
                    "Content-Type": "application/json",
                },
                timeout=5,
            )
            response.raise_for_status()
            models = response.json()

            available_models = [
                {"name": model["name"], "id": model["model_id"]} for model in models
            ]
        except requests.RequestException as e:
            log.error(f"Error fetching voices: {str(e)}")

    return available_models


@router.get("/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
    return {"models": get_available_models(request)}


def get_available_voices(request) -> dict:
    """Returns {voice_id: voice_name} dict"""
    available_voices = {}
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
                )
                response.raise_for_status()
                data = response.json()
                voices_list = data.get("voices", [])
                available_voices = {voice["id"]: voice["name"] for voice in voices_list}
            except Exception as e:
                log.error(f"Error fetching voices from custom endpoint: {str(e)}")
                available_voices = {
                    "alloy": "alloy",
                    "echo": "echo",
                    "fable": "fable",
                    "onyx": "onyx",
                    "nova": "nova",
                    "shimmer": "shimmer",
                }
        else:
            available_voices = {
                "alloy": "alloy",
                "echo": "echo",
                "fable": "fable",
                "onyx": "onyx",
                "nova": "nova",
                "shimmer": "shimmer",
            }
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            available_voices = get_elevenlabs_voices(
                api_key=request.app.state.config.TTS_API_KEY
            )
        except Exception:
            # Avoided @lru_cache with exception
            pass
    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            region = request.app.state.config.TTS_AZURE_SPEECH_REGION
            base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
            url = (
                base_url or f"https://{region}.tts.speech.microsoft.com"
            ) + "/cognitiveservices/voices/list"
            headers = {
                "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
            }

            response = requests.get(url, headers=headers)
            response.raise_for_status()
            voices = response.json()

            for voice in voices:
                available_voices[voice["ShortName"]] = (
                    f"{voice['DisplayName']} ({voice['ShortName']})"
                )
        except requests.RequestException as e:
            log.error(f"Error fetching voices: {str(e)}")

    return available_voices


@lru_cache
def get_elevenlabs_voices(api_key: str) -> dict:
    """
    Note, set the following in your .env file to use Elevenlabs:
    AUDIO_TTS_ENGINE=elevenlabs
    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
    AUDIO_TTS_MODEL=eleven_multilingual_v2
    """
    try:
        # TODO: Add retries
        response = requests.get(
            "https://api.elevenlabs.io/v1/voices",
            headers={
                "xi-api-key": api_key,
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()
        voices_data = response.json()

        voices = {}
        for voice in voices_data.get("voices", []):
            voices[voice["voice_id"]] = voice["name"]
    except requests.RequestException as e:
        # Avoid @lru_cache with exception
        log.error(f"Error fetching voices: {str(e)}")
        raise RuntimeError(f"Error fetching voices: {str(e)}")

    return voices


@router.get("/voices")
async def get_voices(request: Request, user=Depends(get_verified_user)):
    return {
        "voices": [
            {"id": k, "name": v} for k, v in get_available_voices(request).items()
        ]
    }