import hashlib
import json
import logging
import os
import uuid

from functools import lru_cache
from pydub import AudioSegment
from pydub.silence import split_on_silence
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from fnmatch import fnmatch

import aiohttp
import aiofiles
import requests
import mimetypes

from urllib.parse import urljoin, quote

from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
    status,
    APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel

from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import (
    WHISPER_MODEL_AUTO_UPDATE,
    WHISPER_MODEL_DIR,
    CACHE_DIR,
    WHISPER_LANGUAGE,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    ENV,
    SRC_LOG_LEVELS,
    DEVICE_TYPE,
    ENABLE_FORWARD_USER_INFO_HEADERS,
)

router = APIRouter()

# Constants
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
AZURE_MAX_FILE_SIZE_MB = 200
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)


##########################################
#
# Utility functions
#
##########################################

from pydub import AudioSegment
from pydub.utils import mediainfo


def is_audio_conversion_required(file_path):
    """
    Check if the given audio file needs conversion to mp3.
    """
    SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}

    if not os.path.isfile(file_path):
        log.error(f"File not found: {file_path}")
        return False

    try:
        info = mediainfo(file_path)
        codec_name = info.get("codec_name", "").lower()
        codec_type = info.get("codec_type", "").lower()
        codec_tag_string = info.get("codec_tag_string", "").lower()

        if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
            # File is AAC/mp4a audio, recommend mp3 conversion
            return True

        # If the codec name is in the supported formats
        if codec_name in SUPPORTED_FORMATS:
            return False

        return True
    except Exception as e:
        log.error(f"Error getting audio format: {e}")
        return False


def convert_audio_to_mp3(file_path):
    """Convert audio file to mp3 format."""
    try:
        output_path = os.path.splitext(file_path)[0] + ".mp3"
        audio = AudioSegment.from_file(file_path)
        audio.export(output_path, format="mp3")
        log.info(f"Converted {file_path} to {output_path}")
        return output_path
    except Exception as e:
        log.error(f"Error converting audio file: {e}")
        return None


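# Lazy construction of the local faster-whisper model: the model is first
# loaded with local_files_only=not auto_update; if initialization fails (e.g.
# the model has not been downloaded yet), the call is retried with
# local_files_only=False so the weights can be fetched into WHISPER_MODEL_DIR.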
def set_faster_whisper_model(model: str, auto_update: bool = False):
    whisper_model = None
    if model:
        from faster_whisper import WhisperModel

        faster_whisper_kwargs = {
            "model_size_or_path": model,
            "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
            "compute_type": "int8",
            "download_root": WHISPER_MODEL_DIR,
            "local_files_only": not auto_update,
        }

        try:
            whisper_model = WhisperModel(**faster_whisper_kwargs)
        except Exception:
            log.warning(
                "WhisperModel initialization failed, attempting download with local_files_only=False"
            )
            faster_whisper_kwargs["local_files_only"] = False
            whisper_model = WhisperModel(**faster_whisper_kwargs)

    return whisper_model


##########################################
#
# Audio API
#
##########################################


class TTSConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    API_KEY: str
    ENGINE: str
    MODEL: str
    VOICE: str
    SPLIT_ON: str
    AZURE_SPEECH_REGION: str
    AZURE_SPEECH_BASE_URL: str
    AZURE_SPEECH_OUTPUT_FORMAT: str


class STTConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    ENGINE: str
    MODEL: str
    SUPPORTED_CONTENT_TYPES: list[str] = []
    WHISPER_MODEL: str
    DEEPGRAM_API_KEY: str
    AZURE_API_KEY: str
    AZURE_REGION: str
    AZURE_LOCALES: str
    AZURE_BASE_URL: str
    AZURE_MAX_SPEAKERS: str


class AudioConfigUpdateForm(BaseModel):
    tts: TTSConfigForm
    stt: STTConfigForm


@router.get("/config")
async def get_audio_config(request: Request, user=Depends(get_admin_user)):
    return {
        "tts": {
            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": request.app.state.config.TTS_API_KEY,
            "ENGINE": request.app.state.config.TTS_ENGINE,
            "MODEL": request.app.state.config.TTS_MODEL,
            "VOICE": request.app.state.config.TTS_VOICE,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": request.app.state.config.STT_ENGINE,
            "MODEL": request.app.state.config.STT_MODEL,
            "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }


@router.post("/config/update")
async def update_audio_config(
    request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
    request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
    request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
    request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
    request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
    request.app.state.config.TTS_MODEL = form_data.tts.MODEL
    request.app.state.config.TTS_VOICE = form_data.tts.VOICE
    request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
    request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
    request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = (
        form_data.tts.AZURE_SPEECH_BASE_URL
    )
    request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
        form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
    )

    request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
    request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
    request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
    request.app.state.config.STT_MODEL = form_data.stt.MODEL
    request.app.state.config.STT_SUPPORTED_CONTENT_TYPES = (
        form_data.stt.SUPPORTED_CONTENT_TYPES
    )
    request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
    request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
    request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
    request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
    request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
        form_data.stt.AZURE_MAX_SPEAKERS
    )

    if request.app.state.config.STT_ENGINE == "":
        request.app.state.faster_whisper_model = set_faster_whisper_model(
            form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
        )
    else:
        request.app.state.faster_whisper_model = None

    return {
        "tts": {
            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": request.app.state.config.TTS_API_KEY,
            "ENGINE": request.app.state.config.TTS_ENGINE,
            "MODEL": request.app.state.config.TTS_MODEL,
            "VOICE": request.app.state.config.TTS_VOICE,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": request.app.state.config.STT_ENGINE,
            "MODEL": request.app.state.config.STT_MODEL,
            "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }


def load_speech_pipeline(request):
    from transformers import pipeline
    from datasets import load_dataset

    if request.app.state.speech_synthesiser is None:
        request.app.state.speech_synthesiser = pipeline(
            "text-to-speech", "microsoft/speecht5_tts"
        )

    if request.app.state.speech_speaker_embeddings_dataset is None:
        request.app.state.speech_speaker_embeddings_dataset = load_dataset(
            "Matthijs/cmu-arctic-xvectors", split="validation"
        )


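# Text-to-speech (TTS). Responses are cached on disk: the cache key is the
# SHA-256 of the raw request body plus the configured TTS engine and model,
# and the synthesized audio is stored as <hash>.mp3 alongside a <hash>.json
# copy of the request payload.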
@router.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    body = await request.body()
    name = hashlib.sha256(
        body
        + str(request.app.state.config.TTS_ENGINE).encode("utf-8")
        + str(request.app.state.config.TTS_MODEL).encode("utf-8")
    ).hexdigest()

    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    payload = None
    try:
        payload = json.loads(body.decode("utf-8"))
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=400, detail="Invalid JSON payload")

    r = None
    if request.app.state.config.TTS_ENGINE == "openai":
        payload["model"] = request.app.state.config.TTS_MODEL

        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                r = await session.post(
                    url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                    json=payload,
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
                        **(
                            {
                                "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
                                "X-OpenWebUI-User-Id": user.id,
                                "X-OpenWebUI-User-Email": user.email,
                                "X-OpenWebUI-User-Role": user.role,
                            }
                            if ENABLE_FORWARD_USER_INFO_HEADERS
                            else {}
                        ),
                    },
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                )

                r.raise_for_status()

                async with aiofiles.open(file_path, "wb") as f:
                    await f.write(await r.read())

                async with aiofiles.open(file_body_path, "w") as f:
                    await f.write(json.dumps(payload))

            return FileResponse(file_path)

        except Exception as e:
            log.exception(e)
            detail = None

            status_code = 500
            detail = "Open WebUI: Server Connection Error"
            if r is not None:
                status_code = r.status
                try:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error']}"
                except Exception:
                    detail = f"External: {e}"

            raise HTTPException(
                status_code=status_code,
                detail=detail,
            )
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        voice_id = payload.get("voice", "")

        if voice_id not in get_available_voices(request):
            raise HTTPException(
                status_code=400,
                detail="Invalid voice id",
            )

        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
                    json={
                        "text": payload["input"],
                        "model_id": request.app.state.config.TTS_MODEL,
                        "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
                    },
                    headers={
                        "Accept": "audio/mpeg",
                        "Content-Type": "application/json",
                        "xi-api-key": request.app.state.config.TTS_API_KEY,
                    },
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)

        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        region = request.app.state.config.TTS_AZURE_SPEECH_REGION or "eastus"
        base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
        language = request.app.state.config.TTS_VOICE
        locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
        output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT

        try:
            data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
                <voice name="{language}">{payload["input"]}</voice>
            </speak>"""
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    (base_url or f"https://{region}.tts.speech.microsoft.com")
                    + "/cognitiveservices/v1",
                    headers={
                        "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
                        "Content-Type": "application/ssml+xml",
                        "X-Microsoft-OutputFormat": output_format,
                    },
                    data=data,
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)

        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
    elif request.app.state.config.TTS_ENGINE == "transformers":
        payload = None
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        import torch
        import soundfile as sf

        load_speech_pipeline(request)

        embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset

        speaker_index = 6799
        try:
            speaker_index = embeddings_dataset["filename"].index(
                request.app.state.config.TTS_MODEL
            )
        except Exception:
            pass

        speaker_embedding = torch.tensor(
            embeddings_dataset[speaker_index]["xvector"]
        ).unsqueeze(0)

        speech = request.app.state.speech_synthesiser(
            payload["input"],
            forward_params={"speaker_embeddings": speaker_embedding},
        )

        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])

        async with aiofiles.open(file_body_path, "w") as f:
            await f.write(json.dumps(payload))

        return FileResponse(file_path)


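# Speech-to-text (STT). transcription_handler() transcribes a single audio
# file with the configured engine (local faster-whisper when STT_ENGINE is
# empty, otherwise OpenAI, Deepgram, or Azure) and writes the transcript next
# to the audio file as <id>.json.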
def transcription_handler(request, file_path, metadata):
    filename = os.path.basename(file_path)
    file_dir = os.path.dirname(file_path)
    id = filename.split(".")[0]

    metadata = metadata or {}

    languages = [
        metadata.get("language", None) if WHISPER_LANGUAGE == "" else WHISPER_LANGUAGE,
        None,  # Always fallback to None in case transcription fails
    ]

    if request.app.state.config.STT_ENGINE == "":
        if request.app.state.faster_whisper_model is None:
            request.app.state.faster_whisper_model = set_faster_whisper_model(
                request.app.state.config.WHISPER_MODEL
            )

        model = request.app.state.faster_whisper_model
        segments, info = model.transcribe(
            file_path,
            beam_size=5,
            vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
            language=languages[0],
        )
        log.info(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )

        transcript = "".join([segment.text for segment in list(segments)])
        data = {"text": transcript.strip()}

        # save the transcript to a json file
        transcript_file = f"{file_dir}/{id}.json"
        with open(transcript_file, "w") as f:
            json.dump(data, f)

        log.debug(data)
        return data
    elif request.app.state.config.STT_ENGINE == "openai":
        r = None
        try:
            for language in languages:
                payload = {
                    "model": request.app.state.config.STT_MODEL,
                }
                if language:
                    payload["language"] = language

                r = requests.post(
                    url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
                    headers={
                        "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
                    },
                    files={"file": (filename, open(file_path, "rb"))},
                    data=payload,
                )
                if r.status_code == 200:
                    # Successful transcription
                    break

            r.raise_for_status()
            data = r.json()

            # save the transcript to a json file
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            return data
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
                except Exception:
                    detail = f"External: {e}"

            raise Exception(detail if detail else "Open WebUI: Server Connection Error")
    elif request.app.state.config.STT_ENGINE == "deepgram":
        r = None  # ensure r exists for the error handler below
        try:
            # Determine the MIME type of the file
            mime, _ = mimetypes.guess_type(file_path)
            if not mime:
                mime = "audio/wav"  # fallback to wav if undetectable

            # Read the audio file
            with open(file_path, "rb") as f:
                file_data = f.read()

            # Build headers and parameters
            headers = {
                "Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
                "Content-Type": mime,
            }

            for language in languages:
                params = {}
                if request.app.state.config.STT_MODEL:
                    params["model"] = request.app.state.config.STT_MODEL

                if language:
                    params["language"] = language

                # Make request to Deepgram API
                r = requests.post(
                    "https://api.deepgram.com/v1/listen?smart_format=true",
                    headers=headers,
                    params=params,
                    data=file_data,
                )
                if r.status_code == 200:
                    # Successful transcription
                    break

            r.raise_for_status()
            response_data = r.json()

            # Extract transcript from Deepgram response
            try:
                transcript = response_data["results"]["channels"][0]["alternatives"][
                    0
                ].get("transcript", "")
            except (KeyError, IndexError) as e:
                log.error(f"Malformed response from Deepgram: {str(e)}")
                raise Exception(
                    "Failed to parse Deepgram response - unexpected response format"
                )

            data = {"text": transcript.strip()}

            # Save transcript
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            return data
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
                except Exception:
                    detail = f"External: {e}"

            raise Exception(detail if detail else "Open WebUI: Server Connection Error")
    elif request.app.state.config.STT_ENGINE == "azure":
        # Check file exists and size
        if not os.path.exists(file_path):
            raise HTTPException(status_code=400, detail="Audio file not found")

        # Check file size (Azure has a larger limit of 200MB)
        file_size = os.path.getsize(file_path)
        if file_size > AZURE_MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
            )

        api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
        region = request.app.state.config.AUDIO_STT_AZURE_REGION or "eastus"
        locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
        base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
        max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3

        # IF NO LOCALES, USE DEFAULTS
        if len(locales) < 2:
            locales = [
                "en-US",
                "es-ES",
                "es-MX",
                "fr-FR",
                "hi-IN",
                "it-IT",
                "de-DE",
                "en-GB",
                "en-IN",
                "ja-JP",
                "ko-KR",
                "pt-BR",
                "zh-CN",
            ]
            locales = ",".join(locales)

        if not api_key or not region:
            raise HTTPException(
                status_code=400,
                detail="Azure API key is required for Azure STT",
            )

        r = None
        try:
            # Prepare the request
            data = {
                "definition": json.dumps(
                    {
                        "locales": locales.split(","),
                        "diarization": {"maxSpeakers": max_speakers, "enabled": True},
                    }
                    if locales
                    else {}
                )
            }
            url = (
                base_url or f"https://{region}.api.cognitive.microsoft.com"
            ) + "/speechtotext/transcriptions:transcribe?api-version=2024-11-15"

            # Use context manager to ensure file is properly closed
            with open(file_path, "rb") as audio_file:
                r = requests.post(
                    url=url,
                    files={"audio": audio_file},
                    data=data,
                    headers={
                        "Ocp-Apim-Subscription-Key": api_key,
                    },
                )

            r.raise_for_status()
            response = r.json()

            # Extract transcript from response
            if not response.get("combinedPhrases"):
                raise ValueError("No transcription found in response")

            # Get the full transcript from combinedPhrases
            transcript = response["combinedPhrases"][0].get("text", "").strip()
            if not transcript:
                raise ValueError("Empty transcript in response")

            data = {"text": transcript}

            # Save transcript to json file (consistent with other providers)
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            log.debug(data)
            return data

        except (KeyError, IndexError, ValueError) as e:
            log.exception("Error parsing Azure response")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to parse Azure response: {str(e)}",
            )
        except requests.exceptions.RequestException as e:
            log.exception(e)
            detail = None

            try:
                if r is not None and r.status_code != 200:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status_code", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )


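# transcribe() is the high-level pipeline used by the /transcriptions endpoint
# below: convert unsupported formats to mp3, compress audio that exceeds
# MAX_FILE_SIZE, split it into chunks the provider will accept, transcribe the
# chunks in parallel, and join the chunk transcripts into a single text.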
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
    log.info(f"transcribe: {file_path} {metadata}")

    if is_audio_conversion_required(file_path):
        file_path = convert_audio_to_mp3(file_path)

    try:
        file_path = compress_audio(file_path)
    except Exception as e:
        log.exception(e)

    # Always produce a list of chunk paths (could be one entry if small)
    try:
        chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
        log.debug(f"Chunk paths: {chunk_paths}")
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    results = []
    try:
        with ThreadPoolExecutor() as executor:
            # Submit tasks for each chunk_path
            futures = [
                executor.submit(transcription_handler, request, chunk_path, metadata)
                for chunk_path in chunk_paths
            ]
            # Gather results as they complete
            for future in futures:
                try:
                    results.append(future.result())
                except Exception as transcribe_exc:
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail=f"Error transcribing chunk: {transcribe_exc}",
                    )
    finally:
        # Clean up only the temporary chunks, never the original file
        for chunk_path in chunk_paths:
            if chunk_path != file_path and os.path.isfile(chunk_path):
                try:
                    os.remove(chunk_path)
                except Exception:
                    pass

    return {
        "text": " ".join([result["text"] for result in results]),
    }


def compress_audio(file_path):
    if os.path.getsize(file_path) > MAX_FILE_SIZE:
        id = os.path.splitext(os.path.basename(file_path))[
            0
        ]  # Handles names with multiple dots
        file_dir = os.path.dirname(file_path)

        audio = AudioSegment.from_file(file_path)
        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio

        compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
        audio.export(compressed_path, format="mp3", bitrate="32k")
        log.debug(f"Compressed audio to {compressed_path}")
        return compressed_path
    else:
        return file_path


def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
    """
    Splits audio into chunks not exceeding max_bytes.
    Returns a list of chunk file paths. If audio fits, returns list with original path.
    """
    file_size = os.path.getsize(file_path)
    if file_size <= max_bytes:
        return [file_path]  # Nothing to split

    audio = AudioSegment.from_file(file_path)
    duration_ms = len(audio)
    orig_size = file_size

    approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
    chunks = []
    start = 0
    i = 0

    base, _ = os.path.splitext(file_path)

    while start < duration_ms:
        end = min(start + approx_chunk_ms, duration_ms)
        chunk = audio[start:end]
        chunk_path = f"{base}_chunk_{i}.{format}"
        chunk.export(chunk_path, format=format, bitrate=bitrate)

        # Reduce chunk duration if still too large
        while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
            end = start + ((end - start) // 2)
            chunk = audio[start:end]
            chunk.export(chunk_path, format=format, bitrate=bitrate)

        if os.path.getsize(chunk_path) > max_bytes:
            os.remove(chunk_path)
            raise Exception("Audio chunk cannot be reduced below max file size.")

        chunks.append(chunk_path)
        start = end
        i += 1

    return chunks


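# API endpoint: validates the uploaded file's content type against
# STT_SUPPORTED_CONTENT_TYPES (falling back to audio/* and video/webm), saves
# the upload under CACHE_DIR/audio/transcriptions, and runs transcribe() on it.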
@router.post("/transcriptions")
def transcription(
    request: Request,
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    user=Depends(get_verified_user),
):
    log.info(f"file.content_type: {file.content_type}")

    stt_supported_content_types = getattr(
        request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
    )

    if not any(
        fnmatch(file.content_type, content_type)
        for content_type in (
            stt_supported_content_types
            if stt_supported_content_types
            and any(t.strip() for t in stt_supported_content_types)
            else ["audio/*", "video/webm"]
        )
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    try:
        ext = file.filename.split(".")[-1]
        id = uuid.uuid4()

        filename = f"{id}.{ext}"
        contents = file.file.read()

        file_dir = f"{CACHE_DIR}/audio/transcriptions"
        os.makedirs(file_dir, exist_ok=True)
        file_path = f"{file_dir}/{filename}"

        with open(file_path, "wb") as f:
            f.write(contents)

        try:
            metadata = None

            if language:
                metadata = {"language": language}

            result = transcribe(request, file_path, metadata)

            return {
                **result,
                "filename": os.path.basename(file_path),
            }
        except Exception as e:
            log.exception(e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


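# Model and voice discovery helpers backing the /models and /voices endpoints.
# ElevenLabs voices are fetched through get_elevenlabs_voices(), which is
# wrapped in @lru_cache so repeated lookups do not hit the API on every call.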
def get_available_models(request: Request) -> list[dict]:
    available_models = []
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
                )
                response.raise_for_status()
                data = response.json()
                available_models = data.get("models", [])
            except Exception as e:
                log.error(f"Error fetching models from custom endpoint: {str(e)}")
                available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
        else:
            available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            response = requests.get(
                "https://api.elevenlabs.io/v1/models",
                headers={
                    "xi-api-key": request.app.state.config.TTS_API_KEY,
                    "Content-Type": "application/json",
                },
                timeout=5,
            )
            response.raise_for_status()
            models = response.json()

            available_models = [
                {"name": model["name"], "id": model["model_id"]} for model in models
            ]
        except requests.RequestException as e:
            log.error(f"Error fetching voices: {str(e)}")

    return available_models


@router.get("/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
    return {"models": get_available_models(request)}


def get_available_voices(request) -> dict:
    """Returns {voice_id: voice_name} dict"""
    available_voices = {}
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
                )
                response.raise_for_status()
                data = response.json()
                voices_list = data.get("voices", [])
                available_voices = {voice["id"]: voice["name"] for voice in voices_list}
            except Exception as e:
                log.error(f"Error fetching voices from custom endpoint: {str(e)}")
                available_voices = {
                    "alloy": "alloy",
                    "echo": "echo",
                    "fable": "fable",
                    "onyx": "onyx",
                    "nova": "nova",
                    "shimmer": "shimmer",
                }
        else:
            available_voices = {
                "alloy": "alloy",
                "echo": "echo",
                "fable": "fable",
                "onyx": "onyx",
                "nova": "nova",
                "shimmer": "shimmer",
            }
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            available_voices = get_elevenlabs_voices(
                api_key=request.app.state.config.TTS_API_KEY
            )
        except Exception:
            # Avoided @lru_cache with exception
            pass
    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            region = request.app.state.config.TTS_AZURE_SPEECH_REGION
            base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
            url = (
                base_url or f"https://{region}.tts.speech.microsoft.com"
            ) + "/cognitiveservices/voices/list"
            headers = {
                "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
            }
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            voices = response.json()

            for voice in voices:
                available_voices[voice["ShortName"]] = (
                    f"{voice['DisplayName']} ({voice['ShortName']})"
                )
        except requests.RequestException as e:
            log.error(f"Error fetching voices: {str(e)}")

    return available_voices


@lru_cache
def get_elevenlabs_voices(api_key: str) -> dict:
    """
    Note, set the following in your .env file to use Elevenlabs:
    AUDIO_TTS_ENGINE=elevenlabs
    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
    AUDIO_TTS_MODEL=eleven_multilingual_v2
    """
    try:
        # TODO: Add retries
        response = requests.get(
            "https://api.elevenlabs.io/v1/voices",
            headers={
                "xi-api-key": api_key,
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()
        voices_data = response.json()

        voices = {}
        for voice in voices_data.get("voices", []):
            voices[voice["voice_id"]] = voice["name"]
    except requests.RequestException as e:
        # Avoid @lru_cache with exception
        log.error(f"Error fetching voices: {str(e)}")
        raise RuntimeError(f"Error fetching voices: {str(e)}")

    return voices


@router.get("/voices")
async def get_voices(request: Request, user=Depends(get_verified_user)):
    return {
        "voices": [
            {"id": k, "name": v} for k, v in get_available_voices(request).items()
        ]
    }