# audio.py

import hashlib
import json
import logging
import os
import uuid
from functools import lru_cache
from pathlib import Path

from pydub import AudioSegment
from pydub.silence import split_on_silence
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import aiohttp
import aiofiles
import requests
import mimetypes

from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
    status,
    APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse

from pydantic import BaseModel

from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import (
    WHISPER_MODEL_AUTO_UPDATE,
    WHISPER_MODEL_DIR,
    CACHE_DIR,
    WHISPER_LANGUAGE,
)

from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    ENV,
    SRC_LOG_LEVELS,
    DEVICE_TYPE,
    ENABLE_FORWARD_USER_INFO_HEADERS,
)

router = APIRouter()

# Constants
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
AZURE_MAX_FILE_SIZE_MB = 200
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)


##########################################
#
# Utility functions
#
##########################################

from pydub import AudioSegment
from pydub.utils import mediainfo


def is_audio_conversion_required(file_path):
    """
    Check if the given audio file needs conversion to mp3.
    """
    SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}

    if not os.path.isfile(file_path):
        log.error(f"File not found: {file_path}")
        return False

    try:
        info = mediainfo(file_path)
        codec_name = info.get("codec_name", "").lower()
        codec_type = info.get("codec_type", "").lower()
        codec_tag_string = info.get("codec_tag_string", "").lower()

        if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
            # File is AAC/mp4a audio, recommend mp3 conversion
            return True

        # If the codec name is in the supported formats
        if codec_name in SUPPORTED_FORMATS:
            return False

        return True
    except Exception as e:
        log.error(f"Error getting audio format: {e}")
        return False


def convert_audio_to_mp3(file_path):
    """Convert audio file to mp3 format."""
    try:
        output_path = os.path.splitext(file_path)[0] + ".mp3"
        audio = AudioSegment.from_file(file_path)
        audio.export(output_path, format="mp3")
        log.info(f"Converted {file_path} to {output_path}")
        return output_path
    except Exception as e:
        log.error(f"Error converting audio file: {e}")
        return None
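
# Example (illustrative sketch, not part of the module API): how these two helpers
# are typically combined before transcription. "recording.m4a" is a hypothetical path.
#
#     if is_audio_conversion_required("recording.m4a"):
#         mp3_path = convert_audio_to_mp3("recording.m4a")
#         if mp3_path is None:
#             # conversion failed; fall back to the original file
#             mp3_path = "recording.m4a"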


def set_faster_whisper_model(model: str, auto_update: bool = False):
    whisper_model = None
    if model:
        from faster_whisper import WhisperModel

        faster_whisper_kwargs = {
            "model_size_or_path": model,
            "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
            "compute_type": "int8",
            "download_root": WHISPER_MODEL_DIR,
            "local_files_only": not auto_update,
        }

        try:
            whisper_model = WhisperModel(**faster_whisper_kwargs)
        except Exception:
            log.warning(
                "WhisperModel initialization failed, attempting download with local_files_only=False"
            )
            faster_whisper_kwargs["local_files_only"] = False
            whisper_model = WhisperModel(**faster_whisper_kwargs)

    return whisper_model
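
# Example (illustrative sketch): loading the "base" model and allowing a download on
# first use. Model names follow faster-whisper conventions ("tiny", "base", "small", ...);
# "sample.wav" is a hypothetical local file.
#
#     model = set_faster_whisper_model("base", auto_update=True)
#     if model is not None:
#         segments, info = model.transcribe("sample.wav", beam_size=5)
#         text = "".join(segment.text for segment in segments)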


##########################################
#
# Audio API
#
##########################################


class TTSConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    API_KEY: str
    ENGINE: str
    MODEL: str
    VOICE: str
    SPLIT_ON: str
    AZURE_SPEECH_REGION: str
    AZURE_SPEECH_BASE_URL: str
    AZURE_SPEECH_OUTPUT_FORMAT: str


class STTConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    ENGINE: str
    MODEL: str
    WHISPER_MODEL: str
    DEEPGRAM_API_KEY: str
    AZURE_API_KEY: str
    AZURE_REGION: str
    AZURE_LOCALES: str
    AZURE_BASE_URL: str
    AZURE_MAX_SPEAKERS: str


class AudioConfigUpdateForm(BaseModel):
    tts: TTSConfigForm
    stt: STTConfigForm
  148. @router.get("/config")
  149. async def get_audio_config(request: Request, user=Depends(get_admin_user)):
  150. return {
  151. "tts": {
  152. "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
  153. "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
  154. "API_KEY": request.app.state.config.TTS_API_KEY,
  155. "ENGINE": request.app.state.config.TTS_ENGINE,
  156. "MODEL": request.app.state.config.TTS_MODEL,
  157. "VOICE": request.app.state.config.TTS_VOICE,
  158. "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
  159. "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
  160. "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
  161. "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
  162. },
  163. "stt": {
  164. "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
  165. "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
  166. "ENGINE": request.app.state.config.STT_ENGINE,
  167. "MODEL": request.app.state.config.STT_MODEL,
  168. "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
  169. "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
  170. "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
  171. "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
  172. "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
  173. "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
  174. "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
  175. },
  176. }
  177. @router.post("/config/update")
  178. async def update_audio_config(
  179. request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
  180. ):
  181. request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
  182. request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
  183. request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
  184. request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
  185. request.app.state.config.TTS_MODEL = form_data.tts.MODEL
  186. request.app.state.config.TTS_VOICE = form_data.tts.VOICE
  187. request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
  188. request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
  189. request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = (
  190. form_data.tts.AZURE_SPEECH_BASE_URL
  191. )
  192. request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
  193. form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
  194. )
  195. request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
  196. request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
  197. request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
  198. request.app.state.config.STT_MODEL = form_data.stt.MODEL
  199. request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
  200. request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
  201. request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
  202. request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
  203. request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
  204. request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
  205. request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
  206. form_data.stt.AZURE_MAX_SPEAKERS
  207. )
  208. if request.app.state.config.STT_ENGINE == "":
  209. request.app.state.faster_whisper_model = set_faster_whisper_model(
  210. form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
  211. )
  212. return {
  213. "tts": {
  214. "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
  215. "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
  216. "API_KEY": request.app.state.config.TTS_API_KEY,
  217. "ENGINE": request.app.state.config.TTS_ENGINE,
  218. "MODEL": request.app.state.config.TTS_MODEL,
  219. "VOICE": request.app.state.config.TTS_VOICE,
  220. "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
  221. "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
  222. "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
  223. "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
  224. },
  225. "stt": {
  226. "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
  227. "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
  228. "ENGINE": request.app.state.config.STT_ENGINE,
  229. "MODEL": request.app.state.config.STT_MODEL,
  230. "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
  231. "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
  232. "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
  233. "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
  234. "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
  235. "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
  236. "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
  237. },
  238. }
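
# Example (illustrative client sketch, not part of this module): switching the STT
# engine from an admin script. The base URL and token are placeholders; the update
# form expects the same "tts"/"stt" structure that GET /config returns above.
#
#     import requests
#
#     base = "http://localhost:8080/api/v1/audio"          # hypothetical deployment URL
#     headers = {"Authorization": "Bearer <admin-token>"}
#     config = requests.get(f"{base}/config", headers=headers).json()
#     config["stt"]["ENGINE"] = "openai"
#     requests.post(f"{base}/config/update", json=config, headers=headers)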


def load_speech_pipeline(request):
    from transformers import pipeline
    from datasets import load_dataset

    if request.app.state.speech_synthesiser is None:
        request.app.state.speech_synthesiser = pipeline(
            "text-to-speech", "microsoft/speecht5_tts"
        )

    if request.app.state.speech_speaker_embeddings_dataset is None:
        request.app.state.speech_speaker_embeddings_dataset = load_dataset(
            "Matthijs/cmu-arctic-xvectors", split="validation"
        )
  250. @router.post("/speech")
  251. async def speech(request: Request, user=Depends(get_verified_user)):
  252. body = await request.body()
  253. name = hashlib.sha256(
  254. body
  255. + str(request.app.state.config.TTS_ENGINE).encode("utf-8")
  256. + str(request.app.state.config.TTS_MODEL).encode("utf-8")
  257. ).hexdigest()
  258. file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
  259. file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
  260. # Check if the file already exists in the cache
  261. if file_path.is_file():
  262. return FileResponse(file_path)
  263. payload = None
  264. try:
  265. payload = json.loads(body.decode("utf-8"))
  266. except Exception as e:
  267. log.exception(e)
  268. raise HTTPException(status_code=400, detail="Invalid JSON payload")
  269. if request.app.state.config.TTS_ENGINE == "openai":
  270. payload["model"] = request.app.state.config.TTS_MODEL
  271. try:
  272. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  273. async with aiohttp.ClientSession(
  274. timeout=timeout, trust_env=True
  275. ) as session:
  276. async with session.post(
  277. url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
  278. json=payload,
  279. headers={
  280. "Content-Type": "application/json",
  281. "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
  282. **(
  283. {
  284. "X-OpenWebUI-User-Name": user.name,
  285. "X-OpenWebUI-User-Id": user.id,
  286. "X-OpenWebUI-User-Email": user.email,
  287. "X-OpenWebUI-User-Role": user.role,
  288. }
  289. if ENABLE_FORWARD_USER_INFO_HEADERS
  290. else {}
  291. ),
  292. },
  293. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  294. ) as r:
  295. r.raise_for_status()
  296. async with aiofiles.open(file_path, "wb") as f:
  297. await f.write(await r.read())
  298. async with aiofiles.open(file_body_path, "w") as f:
  299. await f.write(json.dumps(payload))
  300. return FileResponse(file_path)
  301. except Exception as e:
  302. log.exception(e)
  303. detail = None
  304. try:
  305. if r.status != 200:
  306. res = await r.json()
  307. if "error" in res:
  308. detail = f"External: {res['error'].get('message', '')}"
  309. except Exception:
  310. detail = f"External: {e}"
  311. raise HTTPException(
  312. status_code=getattr(r, "status", 500) if r else 500,
  313. detail=detail if detail else "Open WebUI: Server Connection Error",
  314. )
  315. elif request.app.state.config.TTS_ENGINE == "elevenlabs":
  316. voice_id = payload.get("voice", "")
  317. if voice_id not in get_available_voices(request):
  318. raise HTTPException(
  319. status_code=400,
  320. detail="Invalid voice id",
  321. )
  322. try:
  323. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  324. async with aiohttp.ClientSession(
  325. timeout=timeout, trust_env=True
  326. ) as session:
  327. async with session.post(
  328. f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
  329. json={
  330. "text": payload["input"],
  331. "model_id": request.app.state.config.TTS_MODEL,
  332. "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
  333. },
  334. headers={
  335. "Accept": "audio/mpeg",
  336. "Content-Type": "application/json",
  337. "xi-api-key": request.app.state.config.TTS_API_KEY,
  338. },
  339. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  340. ) as r:
  341. r.raise_for_status()
  342. async with aiofiles.open(file_path, "wb") as f:
  343. await f.write(await r.read())
  344. async with aiofiles.open(file_body_path, "w") as f:
  345. await f.write(json.dumps(payload))
  346. return FileResponse(file_path)
  347. except Exception as e:
  348. log.exception(e)
  349. detail = None
  350. try:
  351. if r.status != 200:
  352. res = await r.json()
  353. if "error" in res:
  354. detail = f"External: {res['error'].get('message', '')}"
  355. except Exception:
  356. detail = f"External: {e}"
  357. raise HTTPException(
  358. status_code=getattr(r, "status", 500) if r else 500,
  359. detail=detail if detail else "Open WebUI: Server Connection Error",
  360. )
  361. elif request.app.state.config.TTS_ENGINE == "azure":
  362. try:
  363. payload = json.loads(body.decode("utf-8"))
  364. except Exception as e:
  365. log.exception(e)
  366. raise HTTPException(status_code=400, detail="Invalid JSON payload")
  367. region = request.app.state.config.TTS_AZURE_SPEECH_REGION or "eastus"
  368. base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
  369. language = request.app.state.config.TTS_VOICE
  370. locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
  371. output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
  372. try:
  373. data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
  374. <voice name="{language}">{payload["input"]}</voice>
  375. </speak>"""
  376. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  377. async with aiohttp.ClientSession(
  378. timeout=timeout, trust_env=True
  379. ) as session:
  380. async with session.post(
  381. (base_url or f"https://{region}.tts.speech.microsoft.com")
  382. + "/cognitiveservices/v1",
  383. headers={
  384. "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
  385. "Content-Type": "application/ssml+xml",
  386. "X-Microsoft-OutputFormat": output_format,
  387. },
  388. data=data,
  389. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  390. ) as r:
  391. r.raise_for_status()
  392. async with aiofiles.open(file_path, "wb") as f:
  393. await f.write(await r.read())
  394. async with aiofiles.open(file_body_path, "w") as f:
  395. await f.write(json.dumps(payload))
  396. return FileResponse(file_path)
  397. except Exception as e:
  398. log.exception(e)
  399. detail = None
  400. try:
  401. if r.status != 200:
  402. res = await r.json()
  403. if "error" in res:
  404. detail = f"External: {res['error'].get('message', '')}"
  405. except Exception:
  406. detail = f"External: {e}"
  407. raise HTTPException(
  408. status_code=getattr(r, "status", 500) if r else 500,
  409. detail=detail if detail else "Open WebUI: Server Connection Error",
  410. )
  411. elif request.app.state.config.TTS_ENGINE == "transformers":
  412. payload = None
  413. try:
  414. payload = json.loads(body.decode("utf-8"))
  415. except Exception as e:
  416. log.exception(e)
  417. raise HTTPException(status_code=400, detail="Invalid JSON payload")
  418. import torch
  419. import soundfile as sf
  420. load_speech_pipeline(request)
  421. embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
  422. speaker_index = 6799
  423. try:
  424. speaker_index = embeddings_dataset["filename"].index(
  425. request.app.state.config.TTS_MODEL
  426. )
  427. except Exception:
  428. pass
  429. speaker_embedding = torch.tensor(
  430. embeddings_dataset[speaker_index]["xvector"]
  431. ).unsqueeze(0)
  432. speech = request.app.state.speech_synthesiser(
  433. payload["input"],
  434. forward_params={"speaker_embeddings": speaker_embedding},
  435. )
  436. sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
  437. async with aiofiles.open(file_body_path, "w") as f:
  438. await f.write(json.dumps(payload))
  439. return FileResponse(file_path)
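
# Example (illustrative client sketch, not part of this module): requesting synthesized
# speech and saving the returned mp3. The URL prefix and token are placeholders; the
# request body, hashed together with the configured engine and model, is what keys the
# on-disk cache used above.
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:8080/api/v1/audio/speech",             # hypothetical URL
#         headers={"Authorization": "Bearer <token>"},
#         json={"input": "Hello from Open WebUI", "voice": "alloy"},
#     )
#     with open("hello.mp3", "wb") as f:
#         f.write(resp.content)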


def transcription_handler(request, file_path, metadata):
    filename = os.path.basename(file_path)
    file_dir = os.path.dirname(file_path)
    id = filename.split(".")[0]

    metadata = metadata or {}

    if request.app.state.config.STT_ENGINE == "":
        if request.app.state.faster_whisper_model is None:
            request.app.state.faster_whisper_model = set_faster_whisper_model(
                request.app.state.config.WHISPER_MODEL
            )

        model = request.app.state.faster_whisper_model
        segments, info = model.transcribe(
            file_path,
            beam_size=5,
            vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
            language=metadata.get("language") or WHISPER_LANGUAGE,
        )
        log.info(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )

        transcript = "".join([segment.text for segment in list(segments)])
        data = {"text": transcript.strip()}

        # save the transcript to a json file
        transcript_file = f"{file_dir}/{id}.json"
        with open(transcript_file, "w") as f:
            json.dump(data, f)

        log.debug(data)
        return data
  469. elif request.app.state.config.STT_ENGINE == "openai":
  470. r = None
  471. try:
  472. r = requests.post(
  473. url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
  474. headers={
  475. "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
  476. },
  477. files={"file": (filename, open(file_path, "rb"))},
  478. data={
  479. "model": request.app.state.config.STT_MODEL,
  480. **(
  481. {"language": metadata.get("language")}
  482. if metadata.get("language")
  483. else {}
  484. ),
  485. },
  486. )
  487. r.raise_for_status()
  488. data = r.json()
  489. # save the transcript to a json file
  490. transcript_file = f"{file_dir}/{id}.json"
  491. with open(transcript_file, "w") as f:
  492. json.dump(data, f)
  493. return data
  494. except Exception as e:
  495. log.exception(e)
  496. detail = None
  497. if r is not None:
  498. try:
  499. res = r.json()
  500. if "error" in res:
  501. detail = f"External: {res['error'].get('message', '')}"
  502. except Exception:
  503. detail = f"External: {e}"
  504. raise Exception(detail if detail else "Open WebUI: Server Connection Error")
  505. elif request.app.state.config.STT_ENGINE == "deepgram":
  506. try:
  507. # Determine the MIME type of the file
  508. mime, _ = mimetypes.guess_type(file_path)
  509. if not mime:
  510. mime = "audio/wav" # fallback to wav if undetectable
  511. # Read the audio file
  512. with open(file_path, "rb") as f:
  513. file_data = f.read()
  514. # Build headers and parameters
  515. headers = {
  516. "Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
  517. "Content-Type": mime,
  518. }
  519. # Add model if specified
  520. params = {}
  521. if request.app.state.config.STT_MODEL:
  522. params["model"] = request.app.state.config.STT_MODEL
  523. # Make request to Deepgram API
  524. r = requests.post(
  525. "https://api.deepgram.com/v1/listen",
  526. headers=headers,
  527. params=params,
  528. data=file_data,
  529. )
  530. r.raise_for_status()
  531. response_data = r.json()
  532. # Extract transcript from Deepgram response
  533. try:
  534. transcript = response_data["results"]["channels"][0]["alternatives"][
  535. 0
  536. ].get("transcript", "")
  537. except (KeyError, IndexError) as e:
  538. log.error(f"Malformed response from Deepgram: {str(e)}")
  539. raise Exception(
  540. "Failed to parse Deepgram response - unexpected response format"
  541. )
  542. data = {"text": transcript.strip()}
  543. # Save transcript
  544. transcript_file = f"{file_dir}/{id}.json"
  545. with open(transcript_file, "w") as f:
  546. json.dump(data, f)
  547. return data
  548. except Exception as e:
  549. log.exception(e)
  550. detail = None
  551. if r is not None:
  552. try:
  553. res = r.json()
  554. if "error" in res:
  555. detail = f"External: {res['error'].get('message', '')}"
  556. except Exception:
  557. detail = f"External: {e}"
  558. raise Exception(detail if detail else "Open WebUI: Server Connection Error")
  559. elif request.app.state.config.STT_ENGINE == "azure":
  560. # Check file exists and size
  561. if not os.path.exists(file_path):
  562. raise HTTPException(status_code=400, detail="Audio file not found")
  563. # Check file size (Azure has a larger limit of 200MB)
  564. file_size = os.path.getsize(file_path)
  565. if file_size > AZURE_MAX_FILE_SIZE:
  566. raise HTTPException(
  567. status_code=400,
  568. detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
  569. )
  570. api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
  571. region = request.app.state.config.AUDIO_STT_AZURE_REGION or "eastus"
  572. locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
  573. base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
  574. max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3
  575. # IF NO LOCALES, USE DEFAULTS
  576. if len(locales) < 2:
  577. locales = [
  578. "en-US",
  579. "es-ES",
  580. "es-MX",
  581. "fr-FR",
  582. "hi-IN",
  583. "it-IT",
  584. "de-DE",
  585. "en-GB",
  586. "en-IN",
  587. "ja-JP",
  588. "ko-KR",
  589. "pt-BR",
  590. "zh-CN",
  591. ]
  592. locales = ",".join(locales)
  593. if not api_key or not region:
  594. raise HTTPException(
  595. status_code=400,
  596. detail="Azure API key is required for Azure STT",
  597. )
  598. r = None
  599. try:
  600. # Prepare the request
  601. data = {
  602. "definition": json.dumps(
  603. {
  604. "locales": locales.split(","),
  605. "diarization": {"maxSpeakers": max_speakers, "enabled": True},
  606. }
  607. if locales
  608. else {}
  609. )
  610. }
  611. url = (
  612. base_url or f"https://{region}.api.cognitive.microsoft.com"
  613. ) + "/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
  614. # Use context manager to ensure file is properly closed
  615. with open(file_path, "rb") as audio_file:
  616. r = requests.post(
  617. url=url,
  618. files={"audio": audio_file},
  619. data=data,
  620. headers={
  621. "Ocp-Apim-Subscription-Key": api_key,
  622. },
  623. )
  624. r.raise_for_status()
  625. response = r.json()
  626. # Extract transcript from response
  627. if not response.get("combinedPhrases"):
  628. raise ValueError("No transcription found in response")
  629. # Get the full transcript from combinedPhrases
  630. transcript = response["combinedPhrases"][0].get("text", "").strip()
  631. if not transcript:
  632. raise ValueError("Empty transcript in response")
  633. data = {"text": transcript}
  634. # Save transcript to json file (consistent with other providers)
  635. transcript_file = f"{file_dir}/{id}.json"
  636. with open(transcript_file, "w") as f:
  637. json.dump(data, f)
  638. log.debug(data)
  639. return data
  640. except (KeyError, IndexError, ValueError) as e:
  641. log.exception("Error parsing Azure response")
  642. raise HTTPException(
  643. status_code=500,
  644. detail=f"Failed to parse Azure response: {str(e)}",
  645. )
  646. except requests.exceptions.RequestException as e:
  647. log.exception(e)
  648. detail = None
  649. try:
  650. if r is not None and r.status_code != 200:
  651. res = r.json()
  652. if "error" in res:
  653. detail = f"External: {res['error'].get('message', '')}"
  654. except Exception:
  655. detail = f"External: {e}"
  656. raise HTTPException(
  657. status_code=getattr(r, "status_code", 500) if r else 500,
  658. detail=detail if detail else "Open WebUI: Server Connection Error",
  659. )


def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
    log.info(f"transcribe: {file_path} {metadata}")

    if is_audio_conversion_required(file_path):
        converted_path = convert_audio_to_mp3(file_path)
        # convert_audio_to_mp3 returns None on failure; keep the original file in that case
        if converted_path:
            file_path = converted_path

    try:
        file_path = compress_audio(file_path)
    except Exception as e:
        log.exception(e)

    # Always produce a list of chunk paths (could be one entry if small)
    try:
        chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
        log.debug(f"Chunk paths: {chunk_paths}")
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    results = []
    try:
        with ThreadPoolExecutor() as executor:
            # Submit tasks for each chunk_path
            futures = [
                executor.submit(transcription_handler, request, chunk_path, metadata)
                for chunk_path in chunk_paths
            ]
            # Gather results as they complete
            for future in futures:
                try:
                    results.append(future.result())
                except Exception as transcribe_exc:
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail=f"Error transcribing chunk: {transcribe_exc}",
                    )
    finally:
        # Clean up only the temporary chunks, never the original file
        for chunk_path in chunk_paths:
            if chunk_path != file_path and os.path.isfile(chunk_path):
                try:
                    os.remove(chunk_path)
                except Exception:
                    pass

    return {
        "text": " ".join([result["text"] for result in results]),
    }


def compress_audio(file_path):
    if os.path.getsize(file_path) > MAX_FILE_SIZE:
        id = os.path.splitext(os.path.basename(file_path))[
            0
        ]  # Handles names with multiple dots
        file_dir = os.path.dirname(file_path)

        audio = AudioSegment.from_file(file_path)
        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
        compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
        audio.export(compressed_path, format="mp3", bitrate="32k")
        log.debug(f"Compressed audio to {compressed_path}")
        return compressed_path
    else:
        return file_path
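
# Example (illustrative sketch): compress_audio only rewrites files larger than
# MAX_FILE_SIZE; smaller files are returned untouched. "podcast.wav" is hypothetical.
#
#     path = compress_audio("podcast.wav")
#     # path is either ".../podcast_compressed.mp3" or the original "podcast.wav"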


def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
    """
    Splits audio into chunks not exceeding max_bytes.
    Returns a list of chunk file paths. If audio fits, returns list with original path.
    """
    file_size = os.path.getsize(file_path)
    if file_size <= max_bytes:
        return [file_path]  # Nothing to split

    audio = AudioSegment.from_file(file_path)
    duration_ms = len(audio)
    orig_size = file_size

    approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
    chunks = []
    start = 0
    i = 0

    base, _ = os.path.splitext(file_path)

    while start < duration_ms:
        end = min(start + approx_chunk_ms, duration_ms)
        chunk = audio[start:end]
        chunk_path = f"{base}_chunk_{i}.{format}"
        chunk.export(chunk_path, format=format, bitrate=bitrate)

        # Reduce chunk duration if still too large
        while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
            end = start + ((end - start) // 2)
            chunk = audio[start:end]
            chunk.export(chunk_path, format=format, bitrate=bitrate)

        if os.path.getsize(chunk_path) > max_bytes:
            os.remove(chunk_path)
            raise Exception("Audio chunk cannot be reduced below max file size.")

        chunks.append(chunk_path)
        start = end
        i += 1

    return chunks
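
# Example (illustrative sketch): splitting an oversized recording into chunks that each
# fit under the 20 MB request limit. "long_recording.mp3" is a hypothetical local file.
#
#     chunks = split_audio("long_recording.mp3", MAX_FILE_SIZE)
#     # e.g. ["long_recording_chunk_0.mp3", "long_recording_chunk_1.mp3", ...]
#     # the caller is responsible for removing the temporary chunk files afterwards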
  753. @router.post("/transcriptions")
  754. def transcription(
  755. request: Request,
  756. file: UploadFile = File(...),
  757. language: Optional[str] = Form(None),
  758. user=Depends(get_verified_user),
  759. ):
  760. log.info(f"file.content_type: {file.content_type}")
  761. SUPPORTED_CONTENT_TYPES = {"video/webm"} # Extend if you add more video types!
  762. if not (
  763. file.content_type.startswith("audio/")
  764. or file.content_type in SUPPORTED_CONTENT_TYPES
  765. ):
  766. raise HTTPException(
  767. status_code=status.HTTP_400_BAD_REQUEST,
  768. detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
  769. )
  770. try:
  771. ext = file.filename.split(".")[-1]
  772. id = uuid.uuid4()
  773. filename = f"{id}.{ext}"
  774. contents = file.file.read()
  775. file_dir = f"{CACHE_DIR}/audio/transcriptions"
  776. os.makedirs(file_dir, exist_ok=True)
  777. file_path = f"{file_dir}/{filename}"
  778. with open(file_path, "wb") as f:
  779. f.write(contents)
  780. try:
  781. metadata = None
  782. if language:
  783. metadata = {"language": language}
  784. result = transcribe(request, file_path, metadata)
  785. return {
  786. **result,
  787. "filename": os.path.basename(file_path),
  788. }
  789. except Exception as e:
  790. log.exception(e)
  791. raise HTTPException(
  792. status_code=status.HTTP_400_BAD_REQUEST,
  793. detail=ERROR_MESSAGES.DEFAULT(e),
  794. )
  795. except Exception as e:
  796. log.exception(e)
  797. raise HTTPException(
  798. status_code=status.HTTP_400_BAD_REQUEST,
  799. detail=ERROR_MESSAGES.DEFAULT(e),
  800. )
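
# Example (illustrative client sketch, not part of this module): uploading audio to the
# transcription endpoint. URL and token are placeholders; "meeting.mp3" is a hypothetical
# local file, and "language" is the optional form field handled above.
#
#     import requests
#
#     with open("meeting.mp3", "rb") as f:
#         resp = requests.post(
#             "http://localhost:8080/api/v1/audio/transcriptions",  # hypothetical URL
#             headers={"Authorization": "Bearer <token>"},
#             files={"file": ("meeting.mp3", f, "audio/mpeg")},
#             data={"language": "en"},
#         )
#     print(resp.json()["text"])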
  801. def get_available_models(request: Request) -> list[dict]:
  802. available_models = []
  803. if request.app.state.config.TTS_ENGINE == "openai":
  804. # Use custom endpoint if not using the official OpenAI API URL
  805. if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
  806. "https://api.openai.com"
  807. ):
  808. try:
  809. response = requests.get(
  810. f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
  811. )
  812. response.raise_for_status()
  813. data = response.json()
  814. available_models = data.get("models", [])
  815. except Exception as e:
  816. log.error(f"Error fetching models from custom endpoint: {str(e)}")
  817. available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
  818. else:
  819. available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
  820. elif request.app.state.config.TTS_ENGINE == "elevenlabs":
  821. try:
  822. response = requests.get(
  823. "https://api.elevenlabs.io/v1/models",
  824. headers={
  825. "xi-api-key": request.app.state.config.TTS_API_KEY,
  826. "Content-Type": "application/json",
  827. },
  828. timeout=5,
  829. )
  830. response.raise_for_status()
  831. models = response.json()
  832. available_models = [
  833. {"name": model["name"], "id": model["model_id"]} for model in models
  834. ]
  835. except requests.RequestException as e:
  836. log.error(f"Error fetching voices: {str(e)}")
  837. return available_models
  838. @router.get("/models")
  839. async def get_models(request: Request, user=Depends(get_verified_user)):
  840. return {"models": get_available_models(request)}
  841. def get_available_voices(request) -> dict:
  842. """Returns {voice_id: voice_name} dict"""
  843. available_voices = {}
  844. if request.app.state.config.TTS_ENGINE == "openai":
  845. # Use custom endpoint if not using the official OpenAI API URL
  846. if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
  847. "https://api.openai.com"
  848. ):
  849. try:
  850. response = requests.get(
  851. f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
  852. )
  853. response.raise_for_status()
  854. data = response.json()
  855. voices_list = data.get("voices", [])
  856. available_voices = {voice["id"]: voice["name"] for voice in voices_list}
  857. except Exception as e:
  858. log.error(f"Error fetching voices from custom endpoint: {str(e)}")
  859. available_voices = {
  860. "alloy": "alloy",
  861. "echo": "echo",
  862. "fable": "fable",
  863. "onyx": "onyx",
  864. "nova": "nova",
  865. "shimmer": "shimmer",
  866. }
  867. else:
  868. available_voices = {
  869. "alloy": "alloy",
  870. "echo": "echo",
  871. "fable": "fable",
  872. "onyx": "onyx",
  873. "nova": "nova",
  874. "shimmer": "shimmer",
  875. }
  876. elif request.app.state.config.TTS_ENGINE == "elevenlabs":
  877. try:
  878. available_voices = get_elevenlabs_voices(
  879. api_key=request.app.state.config.TTS_API_KEY
  880. )
  881. except Exception:
  882. # Avoided @lru_cache with exception
  883. pass
  884. elif request.app.state.config.TTS_ENGINE == "azure":
  885. try:
  886. region = request.app.state.config.TTS_AZURE_SPEECH_REGION
  887. base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
  888. url = (
  889. base_url or f"https://{region}.tts.speech.microsoft.com"
  890. ) + "/cognitiveservices/voices/list"
  891. headers = {
  892. "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
  893. }
  894. response = requests.get(url, headers=headers)
  895. response.raise_for_status()
  896. voices = response.json()
  897. for voice in voices:
  898. available_voices[voice["ShortName"]] = (
  899. f"{voice['DisplayName']} ({voice['ShortName']})"
  900. )
  901. except requests.RequestException as e:
  902. log.error(f"Error fetching voices: {str(e)}")
  903. return available_voices


@lru_cache
def get_elevenlabs_voices(api_key: str) -> dict:
    """
    Note, set the following in your .env file to use Elevenlabs:
    AUDIO_TTS_ENGINE=elevenlabs
    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
    AUDIO_TTS_MODEL=eleven_multilingual_v2
    """
    try:
        # TODO: Add retries
        response = requests.get(
            "https://api.elevenlabs.io/v1/voices",
            headers={
                "xi-api-key": api_key,
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()
        voices_data = response.json()

        voices = {}
        for voice in voices_data.get("voices", []):
            voices[voice["voice_id"]] = voice["name"]
    except requests.RequestException as e:
        # Avoid @lru_cache with exception
        log.error(f"Error fetching voices: {str(e)}")
        raise RuntimeError(f"Error fetching voices: {str(e)}")

    return voices
  932. @router.get("/voices")
  933. async def get_voices(request: Request, user=Depends(get_verified_user)):
  934. return {
  935. "voices": [
  936. {"id": k, "name": v} for k, v in get_available_voices(request).items()
  937. ]
  938. }
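
# Example (illustrative client sketch): listing the voices exposed by the configured
# TTS engine. URL and token are placeholders.
#
#     import requests
#
#     voices = requests.get(
#         "http://localhost:8080/api/v1/audio/voices",             # hypothetical URL
#         headers={"Authorization": "Bearer <token>"},
#     ).json()["voices"]
#     print([v["id"] for v in voices])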