audio.py
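"""Audio API router for Open WebUI: text-to-speech (TTS) and speech-to-text (STT) endpoints, plus configuration, model, and voice listing routes."""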

import hashlib
import json
import logging
import os
import uuid
from functools import lru_cache
from pathlib import Path

from pydub import AudioSegment
from pydub.silence import split_on_silence

import aiohttp
import aiofiles
import requests
import mimetypes

from fastapi import (
    Depends,
    FastAPI,
    File,
    HTTPException,
    Request,
    UploadFile,
    status,
    APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel

from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import (
    WHISPER_MODEL_AUTO_UPDATE,
    WHISPER_MODEL_DIR,
    CACHE_DIR,
    WHISPER_LANGUAGE,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
    AIOHTTP_CLIENT_TIMEOUT,
    ENV,
    SRC_LOG_LEVELS,
    DEVICE_TYPE,
    ENABLE_FORWARD_USER_INFO_HEADERS,
)

router = APIRouter()

# Constants
MAX_FILE_SIZE_MB = 25
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
AZURE_MAX_FILE_SIZE_MB = 200
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)


##########################################
#
# Utility functions
#
##########################################

from pydub import AudioSegment
from pydub.utils import mediainfo


def get_audio_format(file_path):
    """Detect whether the file needs conversion; return its source format ("mp4", "ogg", "webm") or None."""
    if not os.path.isfile(file_path):
        log.error(f"File not found: {file_path}")
        return None

    info = mediainfo(file_path)
    if (
        info.get("codec_name") == "aac"
        and info.get("codec_type") == "audio"
        and info.get("codec_tag_string") == "mp4a"
    ):
        return "mp4"
    elif info.get("format_name") == "ogg":
        return "ogg"
    elif info.get("format_name") == "matroska,webm":
        return "webm"
    return None


def convert_audio_to_wav(file_path, output_path, conversion_type):
    """Convert an MP4/OGG/WebM audio file to WAV format."""
    audio = AudioSegment.from_file(file_path, format=conversion_type)
    audio.export(output_path, format="wav")
    log.info(f"Converted {file_path} to {output_path}")
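

# Loads a faster-whisper model for local transcription. The first attempt uses
# only locally cached weights; if that fails, it retries with downloads enabled.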
def set_faster_whisper_model(model: str, auto_update: bool = False):
    whisper_model = None
    if model:
        from faster_whisper import WhisperModel

        faster_whisper_kwargs = {
            "model_size_or_path": model,
            "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
            "compute_type": "int8",
            "download_root": WHISPER_MODEL_DIR,
            "local_files_only": not auto_update,
        }

        try:
            whisper_model = WhisperModel(**faster_whisper_kwargs)
        except Exception:
            log.warning(
                "WhisperModel initialization failed, attempting download with local_files_only=False"
            )
            faster_whisper_kwargs["local_files_only"] = False
            whisper_model = WhisperModel(**faster_whisper_kwargs)
    return whisper_model


##########################################
#
# Audio API
#
##########################################


class TTSConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    API_KEY: str
    ENGINE: str
    MODEL: str
    VOICE: str
    SPLIT_ON: str
    AZURE_SPEECH_REGION: str
    AZURE_SPEECH_OUTPUT_FORMAT: str


class STTConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    ENGINE: str
    MODEL: str
    WHISPER_MODEL: str
    DEEPGRAM_API_KEY: str
    AZURE_API_KEY: str
    AZURE_REGION: str
    AZURE_LOCALES: str
    AZURE_BASE_URL: str
    AZURE_MAX_SPEAKERS: str


class AudioConfigUpdateForm(BaseModel):
    tts: TTSConfigForm
    stt: STTConfigForm


@router.get("/config")
async def get_audio_config(request: Request, user=Depends(get_admin_user)):
    return {
        "tts": {
            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": request.app.state.config.TTS_API_KEY,
            "ENGINE": request.app.state.config.TTS_ENGINE,
            "MODEL": request.app.state.config.TTS_MODEL,
            "VOICE": request.app.state.config.TTS_VOICE,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": request.app.state.config.STT_ENGINE,
            "MODEL": request.app.state.config.STT_MODEL,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }


@router.post("/config/update")
async def update_audio_config(
    request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
    request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
    request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
    request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
    request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
    request.app.state.config.TTS_MODEL = form_data.tts.MODEL
    request.app.state.config.TTS_VOICE = form_data.tts.VOICE
    request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
    request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
    request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
        form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
    )

    request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
    request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
    request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
    request.app.state.config.STT_MODEL = form_data.stt.MODEL
    request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
    request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
    request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
    request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
    request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
        form_data.stt.AZURE_MAX_SPEAKERS
    )

    if request.app.state.config.STT_ENGINE == "":
        request.app.state.faster_whisper_model = set_faster_whisper_model(
            form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
        )

    return {
        "tts": {
            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": request.app.state.config.TTS_API_KEY,
            "ENGINE": request.app.state.config.TTS_ENGINE,
            "MODEL": request.app.state.config.TTS_MODEL,
            "VOICE": request.app.state.config.TTS_VOICE,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": request.app.state.config.STT_ENGINE,
            "MODEL": request.app.state.config.STT_MODEL,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }
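

# Lazily initializes the transformers SpeechT5 text-to-speech pipeline and the
# CMU Arctic x-vector speaker-embedding dataset on first use.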
def load_speech_pipeline(request):
    from transformers import pipeline
    from datasets import load_dataset

    if request.app.state.speech_synthesiser is None:
        request.app.state.speech_synthesiser = pipeline(
            "text-to-speech", "microsoft/speecht5_tts"
        )

    if request.app.state.speech_speaker_embeddings_dataset is None:
        request.app.state.speech_speaker_embeddings_dataset = load_dataset(
            "Matthijs/cmu-arctic-xvectors", split="validation"
        )
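

# Generated speech is cached on disk, keyed by a SHA-256 hash of the request
# body plus the configured TTS engine and model, so repeated requests for the
# same text are served straight from the cache.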
@router.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    body = await request.body()
    name = hashlib.sha256(
        body
        + str(request.app.state.config.TTS_ENGINE).encode("utf-8")
        + str(request.app.state.config.TTS_MODEL).encode("utf-8")
    ).hexdigest()

    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    payload = None
    try:
        payload = json.loads(body.decode("utf-8"))
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=400, detail="Invalid JSON payload")
    if request.app.state.config.TTS_ENGINE == "openai":
        payload["model"] = request.app.state.config.TTS_MODEL

        r = None  # keep a reference for error reporting in the except block
        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                    json=payload,
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
                        **(
                            {
                                "X-OpenWebUI-User-Name": user.name,
                                "X-OpenWebUI-User-Id": user.id,
                                "X-OpenWebUI-User-Email": user.email,
                                "X-OpenWebUI-User-Role": user.role,
                            }
                            if ENABLE_FORWARD_USER_INFO_HEADERS
                            else {}
                        ),
                    },
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r is not None and r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        voice_id = payload.get("voice", "")

        if voice_id not in get_available_voices(request):
            raise HTTPException(
                status_code=400,
                detail="Invalid voice id",
            )

        r = None  # keep a reference for error reporting in the except block
        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
                    json={
                        "text": payload["input"],
                        "model_id": request.app.state.config.TTS_MODEL,
                        "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
                    },
                    headers={
                        "Accept": "audio/mpeg",
                        "Content-Type": "application/json",
                        "xi-api-key": request.app.state.config.TTS_API_KEY,
                    },
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r is not None and r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        region = request.app.state.config.TTS_AZURE_SPEECH_REGION
        language = request.app.state.config.TTS_VOICE
        locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
        output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT

        r = None  # keep a reference for error reporting in the except block
        try:
            data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
                <voice name="{language}">{payload["input"]}</voice>
            </speak>"""

            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
                    headers={
                        "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
                        "Content-Type": "application/ssml+xml",
                        "X-Microsoft-OutputFormat": output_format,
                    },
                    data=data,
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r is not None and r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
    elif request.app.state.config.TTS_ENGINE == "transformers":
        payload = None
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        import torch
        import soundfile as sf

        load_speech_pipeline(request)

        embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset

        speaker_index = 6799
        try:
            speaker_index = embeddings_dataset["filename"].index(
                request.app.state.config.TTS_MODEL
            )
        except Exception:
            pass

        speaker_embedding = torch.tensor(
            embeddings_dataset[speaker_index]["xvector"]
        ).unsqueeze(0)

        speech = request.app.state.speech_synthesiser(
            payload["input"],
            forward_params={"speaker_embeddings": speaker_embedding},
        )

        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])

        async with aiofiles.open(file_body_path, "w") as f:
            await f.write(json.dumps(payload))

        return FileResponse(file_path)
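

# Dispatches speech-to-text to the configured engine: the local faster-whisper
# model (default when STT_ENGINE is empty), an OpenAI-compatible
# /audio/transcriptions endpoint, Deepgram, or Azure Speech.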
def transcribe(request: Request, file_path):
    log.info(f"transcribe: {file_path}")
    filename = os.path.basename(file_path)
    file_dir = os.path.dirname(file_path)
    id = filename.split(".")[0]

    if request.app.state.config.STT_ENGINE == "":
        if request.app.state.faster_whisper_model is None:
            request.app.state.faster_whisper_model = set_faster_whisper_model(
                request.app.state.config.WHISPER_MODEL
            )

        model = request.app.state.faster_whisper_model
        segments, info = model.transcribe(
            file_path,
            beam_size=5,
            vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
            language=WHISPER_LANGUAGE,
        )
        log.info(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )

        transcript = "".join([segment.text for segment in list(segments)])
        data = {"text": transcript.strip()}

        # save the transcript to a json file
        transcript_file = f"{file_dir}/{id}.json"
        with open(transcript_file, "w") as f:
            json.dump(data, f)

        log.debug(data)
        return data
    elif request.app.state.config.STT_ENGINE == "openai":
        audio_format = get_audio_format(file_path)
        if audio_format:
            os.rename(file_path, file_path.replace(".wav", f".{audio_format}"))
            # Convert unsupported audio file to WAV format
            convert_audio_to_wav(
                file_path.replace(".wav", f".{audio_format}"),
                file_path,
                audio_format,
            )

        r = None
        try:
            r = requests.post(
                url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
                headers={
                    "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
                },
                files={"file": (filename, open(file_path, "rb"))},
                data={"model": request.app.state.config.STT_MODEL},
            )

            r.raise_for_status()
            data = r.json()

            # save the transcript to a json file
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            return data
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
                except Exception:
                    detail = f"External: {e}"

            raise Exception(detail if detail else "Open WebUI: Server Connection Error")
    elif request.app.state.config.STT_ENGINE == "deepgram":
        r = None  # keep a reference for error reporting in the except block
        try:
            # Determine the MIME type of the file
            mime, _ = mimetypes.guess_type(file_path)
            if not mime:
                mime = "audio/wav"  # fallback to wav if undetectable

            # Read the audio file
            with open(file_path, "rb") as f:
                file_data = f.read()

            # Build headers and parameters
            headers = {
                "Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
                "Content-Type": mime,
            }

            # Add model if specified
            params = {}
            if request.app.state.config.STT_MODEL:
                params["model"] = request.app.state.config.STT_MODEL

            # Make request to Deepgram API
            r = requests.post(
                "https://api.deepgram.com/v1/listen",
                headers=headers,
                params=params,
                data=file_data,
            )
            r.raise_for_status()
            response_data = r.json()

            # Extract transcript from Deepgram response
            try:
                transcript = response_data["results"]["channels"][0]["alternatives"][
                    0
                ].get("transcript", "")
            except (KeyError, IndexError) as e:
                log.error(f"Malformed response from Deepgram: {str(e)}")
                raise Exception(
                    "Failed to parse Deepgram response - unexpected response format"
                )

            data = {"text": transcript.strip()}

            # Save transcript
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            return data
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
                except Exception:
                    detail = f"External: {e}"

            raise Exception(detail if detail else "Open WebUI: Server Connection Error")
    elif request.app.state.config.STT_ENGINE == "azure":
        # Check file exists and size
        if not os.path.exists(file_path):
            raise HTTPException(status_code=400, detail="Audio file not found")

        # Check file size (Azure has a larger limit of 200MB)
        file_size = os.path.getsize(file_path)
        if file_size > AZURE_MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
            )

        api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
        region = request.app.state.config.AUDIO_STT_AZURE_REGION
        locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
        base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
        max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS

        # If no locales are configured, fall back to a default set
        if len(locales) < 2:
            locales = [
                "en-US",
                "es-ES",
                "es-MX",
                "fr-FR",
                "hi-IN",
                "it-IT",
                "de-DE",
                "en-GB",
                "en-IN",
                "ja-JP",
                "ko-KR",
                "pt-BR",
                "zh-CN",
            ]
            locales = ",".join(locales)

        if not api_key:
            raise HTTPException(
                status_code=400,
                detail="Azure API key is required for Azure STT",
            )

        if not base_url and not region:
            raise HTTPException(
                status_code=400,
                detail="Azure region or base url is required for Azure STT",
            )
        r = None
        try:
            # Prepare the request
            data = {
                "definition": json.dumps(
                    {
                        "locales": locales.split(","),
                        "diarization": {"maxSpeakers": max_speakers, "enabled": True},
                    }
                    if locales
                    else {}
                )
            }
            url = (
                base_url
                or f"https://{region}.api.cognitive.microsoft.com/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
            )

            # Use context manager to ensure file is properly closed
            with open(file_path, "rb") as audio_file:
                r = requests.post(
                    url=url,
                    files={"audio": audio_file},
                    data=data,
                    headers={
                        "Ocp-Apim-Subscription-Key": api_key,
                    },
                )

            r.raise_for_status()
            response = r.json()

            # Extract transcript from response
            if not response.get("combinedPhrases"):
                raise ValueError("No transcription found in response")

            # Get the full transcript from combinedPhrases
            transcript = response["combinedPhrases"][0].get("text", "").strip()
            if not transcript:
                raise ValueError("Empty transcript in response")

            data = {"text": transcript}

            # Save transcript to json file (consistent with other providers)
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            log.debug(data)
            return data
        except (KeyError, IndexError, ValueError) as e:
            log.exception("Error parsing Azure response")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to parse Azure response: {str(e)}",
            )
        except requests.exceptions.RequestException as e:
            log.exception(e)
            detail = None

            try:
                if r is not None and r.status_code != 200:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status_code", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
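

# Downsamples oversized uploads to 16 kHz mono Opus so they fit under the
# MAX_FILE_SIZE limit enforced for transcription uploads.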
def compress_audio(file_path):
    if os.path.getsize(file_path) > MAX_FILE_SIZE:
        file_dir = os.path.dirname(file_path)
        # File id is the filename without its extension
        id = os.path.basename(file_path).split(".")[0]

        audio = AudioSegment.from_file(file_path)
        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
        compressed_path = f"{file_dir}/{id}_compressed.opus"
        audio.export(compressed_path, format="opus", bitrate="32k")
        log.debug(f"Compressed audio to {compressed_path}")

        if (
            os.path.getsize(compressed_path) > MAX_FILE_SIZE
        ):  # Still larger than MAX_FILE_SIZE after compression
            raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
        return compressed_path
    else:
        return file_path
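

# Accepts an uploaded audio file, stores it in the transcription cache,
# compresses it if it exceeds the size limit, and runs it through the
# configured STT engine.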
@router.post("/transcriptions")
def transcription(
    request: Request,
    file: UploadFile = File(...),
    user=Depends(get_verified_user),
):
    log.info(f"file.content_type: {file.content_type}")

    supported_filetypes = ("audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a")
    if not file.content_type.startswith(supported_filetypes):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    try:
        ext = file.filename.split(".")[-1]
        id = uuid.uuid4()

        filename = f"{id}.{ext}"
        contents = file.file.read()

        file_dir = f"{CACHE_DIR}/audio/transcriptions"
        os.makedirs(file_dir, exist_ok=True)
        file_path = f"{file_dir}/{filename}"

        with open(file_path, "wb") as f:
            f.write(contents)

        try:
            try:
                file_path = compress_audio(file_path)
            except Exception as e:
                log.exception(e)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=ERROR_MESSAGES.DEFAULT(e),
                )

            data = transcribe(request, file_path)
            file_path = file_path.split("/")[-1]
            return {**data, "filename": file_path}
        except Exception as e:
            log.exception(e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


def get_available_models(request: Request) -> list[dict]:
    available_models = []
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
                )
                response.raise_for_status()
                data = response.json()
                available_models = data.get("models", [])
            except Exception as e:
                log.error(f"Error fetching models from custom endpoint: {str(e)}")
                available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
        else:
            available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            response = requests.get(
                "https://api.elevenlabs.io/v1/models",
                headers={
                    "xi-api-key": request.app.state.config.TTS_API_KEY,
                    "Content-Type": "application/json",
                },
                timeout=5,
            )
            response.raise_for_status()
            models = response.json()

            available_models = [
                {"name": model["name"], "id": model["model_id"]} for model in models
            ]
        except requests.RequestException as e:
            log.error(f"Error fetching models: {str(e)}")

    return available_models


@router.get("/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
    return {"models": get_available_models(request)}


def get_available_voices(request) -> dict:
    """Returns {voice_id: voice_name} dict"""
    available_voices = {}
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
                )
                response.raise_for_status()
                data = response.json()
                voices_list = data.get("voices", [])
                available_voices = {voice["id"]: voice["name"] for voice in voices_list}
            except Exception as e:
                log.error(f"Error fetching voices from custom endpoint: {str(e)}")
                available_voices = {
                    "alloy": "alloy",
                    "echo": "echo",
                    "fable": "fable",
                    "onyx": "onyx",
                    "nova": "nova",
                    "shimmer": "shimmer",
                }
        else:
            available_voices = {
                "alloy": "alloy",
                "echo": "echo",
                "fable": "fable",
                "onyx": "onyx",
                "nova": "nova",
                "shimmer": "shimmer",
            }
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            available_voices = get_elevenlabs_voices(
                api_key=request.app.state.config.TTS_API_KEY
            )
        except Exception:
            # Avoided @lru_cache with exception
            pass
    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            region = request.app.state.config.TTS_AZURE_SPEECH_REGION
            url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
            headers = {
                "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
            }

            response = requests.get(url, headers=headers)
            response.raise_for_status()
            voices = response.json()

            for voice in voices:
                available_voices[voice["ShortName"]] = (
                    f"{voice['DisplayName']} ({voice['ShortName']})"
                )
        except requests.RequestException as e:
            log.error(f"Error fetching voices: {str(e)}")

    return available_voices


@lru_cache
def get_elevenlabs_voices(api_key: str) -> dict:
    """
    Note: set the following in your .env file to use Elevenlabs:
    AUDIO_TTS_ENGINE=elevenlabs
    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
    AUDIO_TTS_MODEL=eleven_multilingual_v2
    """
    try:
        # TODO: Add retries
        response = requests.get(
            "https://api.elevenlabs.io/v1/voices",
            headers={
                "xi-api-key": api_key,
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()
        voices_data = response.json()

        voices = {}
        for voice in voices_data.get("voices", []):
            voices[voice["voice_id"]] = voice["name"]
    except requests.RequestException as e:
        # Raise instead of returning so @lru_cache does not cache a failed lookup
        log.error(f"Error fetching voices: {str(e)}")
        raise RuntimeError(f"Error fetching voices: {str(e)}")

    return voices


@router.get("/voices")
async def get_voices(request: Request, user=Depends(get_verified_user)):
    return {
        "voices": [
            {"id": k, "name": v} for k, v in get_available_voices(request).items()
        ]
    }