audio.py

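"""
Audio API router for Open WebUI: text-to-speech (TTS) and speech-to-text (STT)
endpoints, their admin configuration routes, and helpers for audio conversion,
compression, and chunked transcription.
"""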
import hashlib
import json
import logging
import os
import uuid
from functools import lru_cache
from pathlib import Path
from pydub import AudioSegment
from pydub.silence import split_on_silence
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from fnmatch import fnmatch

import aiohttp
import aiofiles
import requests
import mimetypes

from urllib.parse import quote

from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
    status,
    APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse

from pydantic import BaseModel

from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import (
    WHISPER_MODEL_AUTO_UPDATE,
    WHISPER_MODEL_DIR,
    CACHE_DIR,
    WHISPER_LANGUAGE,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    ENV,
    SRC_LOG_LEVELS,
    DEVICE_TYPE,
    ENABLE_FORWARD_USER_INFO_HEADERS,
)

router = APIRouter()

# Constants
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
AZURE_MAX_FILE_SIZE_MB = 200
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)


##########################################
#
# Utility functions
#
##########################################

from pydub.utils import mediainfo


def is_audio_conversion_required(file_path):
    """
    Check if the given audio file needs conversion to mp3.
    """
    SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}

    if not os.path.isfile(file_path):
        log.error(f"File not found: {file_path}")
        return False

    try:
        info = mediainfo(file_path)
        codec_name = info.get("codec_name", "").lower()
        codec_type = info.get("codec_type", "").lower()
        codec_tag_string = info.get("codec_tag_string", "").lower()

        if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
            # File is AAC/mp4a audio, recommend mp3 conversion
            return True

        # If the codec name is in the supported formats
        if codec_name in SUPPORTED_FORMATS:
            return False

        return True
    except Exception as e:
        log.error(f"Error getting audio format: {e}")
        return False


def convert_audio_to_mp3(file_path):
    """Convert audio file to mp3 format."""
    try:
        output_path = os.path.splitext(file_path)[0] + ".mp3"
        audio = AudioSegment.from_file(file_path)
        audio.export(output_path, format="mp3")
        log.info(f"Converted {file_path} to {output_path}")
        return output_path
    except Exception as e:
        log.error(f"Error converting audio file: {e}")
        return None


def set_faster_whisper_model(model: str, auto_update: bool = False):
    """Load a faster-whisper model, retrying with a download if a local-only load fails."""
    whisper_model = None
    if model:
        from faster_whisper import WhisperModel

        faster_whisper_kwargs = {
            "model_size_or_path": model,
            "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
            "compute_type": "int8",
            "download_root": WHISPER_MODEL_DIR,
            "local_files_only": not auto_update,
        }

        try:
            whisper_model = WhisperModel(**faster_whisper_kwargs)
        except Exception:
            log.warning(
                "WhisperModel initialization failed, attempting download with local_files_only=False"
            )
            faster_whisper_kwargs["local_files_only"] = False
            whisper_model = WhisperModel(**faster_whisper_kwargs)

    return whisper_model


##########################################
#
# Audio API
#
##########################################


class TTSConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    API_KEY: str
    ENGINE: str
    MODEL: str
    VOICE: str
    SPLIT_ON: str
    AZURE_SPEECH_REGION: str
    AZURE_SPEECH_BASE_URL: str
    AZURE_SPEECH_OUTPUT_FORMAT: str


class STTConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    ENGINE: str
    MODEL: str
    SUPPORTED_CONTENT_TYPES: list[str] = []
    WHISPER_MODEL: str
    DEEPGRAM_API_KEY: str
    AZURE_API_KEY: str
    AZURE_REGION: str
    AZURE_LOCALES: str
    AZURE_BASE_URL: str
    AZURE_MAX_SPEAKERS: str


class AudioConfigUpdateForm(BaseModel):
    tts: TTSConfigForm
    stt: STTConfigForm


@router.get("/config")
async def get_audio_config(request: Request, user=Depends(get_admin_user)):
    return {
        "tts": {
            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": request.app.state.config.TTS_API_KEY,
            "ENGINE": request.app.state.config.TTS_ENGINE,
            "MODEL": request.app.state.config.TTS_MODEL,
            "VOICE": request.app.state.config.TTS_VOICE,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": request.app.state.config.STT_ENGINE,
            "MODEL": request.app.state.config.STT_MODEL,
            "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }


@router.post("/config/update")
async def update_audio_config(
    request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
    request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
    request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
    request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
    request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
    request.app.state.config.TTS_MODEL = form_data.tts.MODEL
    request.app.state.config.TTS_VOICE = form_data.tts.VOICE
    request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
    request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
    request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = (
        form_data.tts.AZURE_SPEECH_BASE_URL
    )
    request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
        form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
    )

    request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
    request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
    request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
    request.app.state.config.STT_MODEL = form_data.stt.MODEL
    request.app.state.config.STT_SUPPORTED_CONTENT_TYPES = (
        form_data.stt.SUPPORTED_CONTENT_TYPES
    )
    request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
    request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
    request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
    request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
    request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
        form_data.stt.AZURE_MAX_SPEAKERS
    )

    if request.app.state.config.STT_ENGINE == "":
        request.app.state.faster_whisper_model = set_faster_whisper_model(
            form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
        )
    else:
        request.app.state.faster_whisper_model = None

    return {
        "tts": {
            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": request.app.state.config.TTS_API_KEY,
            "ENGINE": request.app.state.config.TTS_ENGINE,
            "MODEL": request.app.state.config.TTS_MODEL,
            "VOICE": request.app.state.config.TTS_VOICE,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": request.app.state.config.STT_ENGINE,
            "MODEL": request.app.state.config.STT_MODEL,
            "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }


def load_speech_pipeline(request):
    from transformers import pipeline
    from datasets import load_dataset

    if request.app.state.speech_synthesiser is None:
        request.app.state.speech_synthesiser = pipeline(
            "text-to-speech", "microsoft/speecht5_tts"
        )

    if request.app.state.speech_speaker_embeddings_dataset is None:
        request.app.state.speech_speaker_embeddings_dataset = load_dataset(
            "Matthijs/cmu-arctic-xvectors", split="validation"
        )


@router.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Synthesize speech for the request body with the configured TTS engine, caching results by content hash."""
    body = await request.body()
    name = hashlib.sha256(
        body
        + str(request.app.state.config.TTS_ENGINE).encode("utf-8")
        + str(request.app.state.config.TTS_MODEL).encode("utf-8")
    ).hexdigest()

    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    payload = None
    try:
        payload = json.loads(body.decode("utf-8"))
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=400, detail="Invalid JSON payload")

    r = None
    if request.app.state.config.TTS_ENGINE == "openai":
        payload["model"] = request.app.state.config.TTS_MODEL

        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                r = await session.post(
                    url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                    json=payload,
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
                        **(
                            {
                                "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
                                "X-OpenWebUI-User-Id": user.id,
                                "X-OpenWebUI-User-Email": user.email,
                                "X-OpenWebUI-User-Role": user.role,
                            }
                            if ENABLE_FORWARD_USER_INFO_HEADERS
                            else {}
                        ),
                    },
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                )

                r.raise_for_status()

                async with aiofiles.open(file_path, "wb") as f:
                    await f.write(await r.read())

                async with aiofiles.open(file_body_path, "w") as f:
                    await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)

            status_code = 500
            detail = "Open WebUI: Server Connection Error"
            if r is not None:
                status_code = r.status
                res = await r.json()
                if "error" in res:
                    detail = f"External: {res['error'].get('message', '')}"

            raise HTTPException(
                status_code=status_code,
                detail=detail,
            )
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        voice_id = payload.get("voice", "")

        if voice_id not in get_available_voices(request):
            raise HTTPException(
                status_code=400,
                detail="Invalid voice id",
            )

        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
                    json={
                        "text": payload["input"],
                        "model_id": request.app.state.config.TTS_MODEL,
                        "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
                    },
                    headers={
                        "Accept": "audio/mpeg",
                        "Content-Type": "application/json",
                        "xi-api-key": request.app.state.config.TTS_API_KEY,
                    },
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        region = request.app.state.config.TTS_AZURE_SPEECH_REGION or "eastus"
        base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
        language = request.app.state.config.TTS_VOICE
        locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
        output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT

        try:
            data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
                <voice name="{language}">{payload["input"]}</voice>
            </speak>"""

            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    (base_url or f"https://{region}.tts.speech.microsoft.com")
                    + "/cognitiveservices/v1",
                    headers={
                        "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
                        "Content-Type": "application/ssml+xml",
                        "X-Microsoft-OutputFormat": output_format,
                    },
                    data=data,
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    r.raise_for_status()

                    async with aiofiles.open(file_path, "wb") as f:
                        await f.write(await r.read())

                    async with aiofiles.open(file_body_path, "w") as f:
                        await f.write(json.dumps(payload))

            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            detail = None

            try:
                if r.status != 200:
                    res = await r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
    elif request.app.state.config.TTS_ENGINE == "transformers":
        payload = None
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        import torch
        import soundfile as sf

        load_speech_pipeline(request)

        embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset

        speaker_index = 6799
        try:
            speaker_index = embeddings_dataset["filename"].index(
                request.app.state.config.TTS_MODEL
            )
        except Exception:
            pass

        speaker_embedding = torch.tensor(
            embeddings_dataset[speaker_index]["xvector"]
        ).unsqueeze(0)

        speech = request.app.state.speech_synthesiser(
            payload["input"],
            forward_params={"speaker_embeddings": speaker_embedding},
        )

        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])

        async with aiofiles.open(file_body_path, "w") as f:
            await f.write(json.dumps(payload))

        return FileResponse(file_path)


def transcription_handler(request, file_path, metadata):
    """Transcribe a single audio file with the configured STT engine and save the transcript alongside it."""
    filename = os.path.basename(file_path)
    file_dir = os.path.dirname(file_path)
    id = filename.split(".")[0]

    metadata = metadata or {}

    if request.app.state.config.STT_ENGINE == "":
        if request.app.state.faster_whisper_model is None:
            request.app.state.faster_whisper_model = set_faster_whisper_model(
                request.app.state.config.WHISPER_MODEL
            )

        model = request.app.state.faster_whisper_model
        segments, info = model.transcribe(
            file_path,
            beam_size=5,
            vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
            language=metadata.get("language") or WHISPER_LANGUAGE,
        )
        log.info(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )

        transcript = "".join([segment.text for segment in list(segments)])
        data = {"text": transcript.strip()}

        # save the transcript to a json file
        transcript_file = f"{file_dir}/{id}.json"
        with open(transcript_file, "w") as f:
            json.dump(data, f)

        log.debug(data)
        return data
    elif request.app.state.config.STT_ENGINE == "openai":
        r = None
        try:
            # Open the upload in a context manager so the file handle is closed
            with open(file_path, "rb") as audio_file:
                r = requests.post(
                    url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
                    headers={
                        "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
                    },
                    files={"file": (filename, audio_file)},
                    data={
                        "model": request.app.state.config.STT_MODEL,
                        **(
                            {"language": metadata.get("language")}
                            if metadata.get("language")
                            else {}
                        ),
                    },
                )

            r.raise_for_status()
            data = r.json()

            # save the transcript to a json file
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            return data
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
                except Exception:
                    detail = f"External: {e}"

            raise Exception(detail if detail else "Open WebUI: Server Connection Error")
    elif request.app.state.config.STT_ENGINE == "deepgram":
        r = None  # defined up front so the error handler below can inspect it
        try:
            # Determine the MIME type of the file
            mime, _ = mimetypes.guess_type(file_path)
            if not mime:
                mime = "audio/wav"  # fallback to wav if undetectable

            # Read the audio file
            with open(file_path, "rb") as f:
                file_data = f.read()

            # Build headers and parameters
            headers = {
                "Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
                "Content-Type": mime,
            }

            # Add model if specified
            params = {}
            if request.app.state.config.STT_MODEL:
                params["model"] = request.app.state.config.STT_MODEL

            # Make request to Deepgram API
            r = requests.post(
                "https://api.deepgram.com/v1/listen?smart_format=true",
                headers=headers,
                params=params,
                data=file_data,
            )
            r.raise_for_status()
            response_data = r.json()

            # Extract transcript from Deepgram response
            try:
                transcript = response_data["results"]["channels"][0]["alternatives"][
                    0
                ].get("transcript", "")
            except (KeyError, IndexError) as e:
                log.error(f"Malformed response from Deepgram: {str(e)}")
                raise Exception(
                    "Failed to parse Deepgram response - unexpected response format"
                )

            data = {"text": transcript.strip()}

            # Save transcript
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            return data
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
                except Exception:
                    detail = f"External: {e}"

            raise Exception(detail if detail else "Open WebUI: Server Connection Error")
    elif request.app.state.config.STT_ENGINE == "azure":
        # Check file exists and size
        if not os.path.exists(file_path):
            raise HTTPException(status_code=400, detail="Audio file not found")

        # Check file size (Azure has a larger limit of 200MB)
        file_size = os.path.getsize(file_path)
        if file_size > AZURE_MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
            )

        api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
        region = request.app.state.config.AUDIO_STT_AZURE_REGION or "eastus"
        locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
        base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
        max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3

        # IF NO LOCALES, USE DEFAULTS
        if len(locales) < 2:
            locales = [
                "en-US",
                "es-ES",
                "es-MX",
                "fr-FR",
                "hi-IN",
                "it-IT",
                "de-DE",
                "en-GB",
                "en-IN",
                "ja-JP",
                "ko-KR",
                "pt-BR",
                "zh-CN",
            ]
            locales = ",".join(locales)

        if not api_key or not region:
            raise HTTPException(
                status_code=400,
                detail="Azure API key is required for Azure STT",
            )

        r = None
        try:
            # Prepare the request
            data = {
                "definition": json.dumps(
                    {
                        "locales": locales.split(","),
                        "diarization": {"maxSpeakers": max_speakers, "enabled": True},
                    }
                    if locales
                    else {}
                )
            }
            url = (
                base_url or f"https://{region}.api.cognitive.microsoft.com"
            ) + "/speechtotext/transcriptions:transcribe?api-version=2024-11-15"

            # Use context manager to ensure file is properly closed
            with open(file_path, "rb") as audio_file:
                r = requests.post(
                    url=url,
                    files={"audio": audio_file},
                    data=data,
                    headers={
                        "Ocp-Apim-Subscription-Key": api_key,
                    },
                )

            r.raise_for_status()
            response = r.json()

            # Extract transcript from response
            if not response.get("combinedPhrases"):
                raise ValueError("No transcription found in response")

            # Get the full transcript from combinedPhrases
            transcript = response["combinedPhrases"][0].get("text", "").strip()
            if not transcript:
                raise ValueError("Empty transcript in response")

            data = {"text": transcript}

            # Save transcript to json file (consistent with other providers)
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            log.debug(data)
            return data
        except (KeyError, IndexError, ValueError) as e:
            log.exception("Error parsing Azure response")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to parse Azure response: {str(e)}",
            )
        except requests.exceptions.RequestException as e:
            log.exception(e)
            detail = None

            try:
                if r is not None and r.status_code != 200:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status_code", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )


def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
    """Convert, compress, and split the audio as needed, then transcribe every chunk and join the text."""
    log.info(f"transcribe: {file_path} {metadata}")

    if is_audio_conversion_required(file_path):
        file_path = convert_audio_to_mp3(file_path)

    try:
        file_path = compress_audio(file_path)
    except Exception as e:
        log.exception(e)

    # Always produce a list of chunk paths (could be one entry if small)
    try:
        chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
        log.debug(f"Chunk paths: {chunk_paths}")
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    results = []
    try:
        with ThreadPoolExecutor() as executor:
            # Submit tasks for each chunk_path
            futures = [
                executor.submit(transcription_handler, request, chunk_path, metadata)
                for chunk_path in chunk_paths
            ]
            # Gather results as they complete
            for future in futures:
                try:
                    results.append(future.result())
                except Exception as transcribe_exc:
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail=f"Error transcribing chunk: {transcribe_exc}",
                    )
    finally:
        # Clean up only the temporary chunks, never the original file
        for chunk_path in chunk_paths:
            if chunk_path != file_path and os.path.isfile(chunk_path):
                try:
                    os.remove(chunk_path)
                except Exception:
                    pass

    return {
        "text": " ".join([result["text"] for result in results]),
    }


def compress_audio(file_path):
    if os.path.getsize(file_path) > MAX_FILE_SIZE:
        id = os.path.splitext(os.path.basename(file_path))[
            0
        ]  # Handles names with multiple dots
        file_dir = os.path.dirname(file_path)

        audio = AudioSegment.from_file(file_path)
        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
        compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
        audio.export(compressed_path, format="mp3", bitrate="32k")
        log.debug(f"Compressed audio to {compressed_path}")
        return compressed_path
    else:
        return file_path


def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
    """
    Splits audio into chunks not exceeding max_bytes.
    Returns a list of chunk file paths. If audio fits, returns list with original path.
    """
    file_size = os.path.getsize(file_path)
    if file_size <= max_bytes:
        return [file_path]  # Nothing to split

    audio = AudioSegment.from_file(file_path)
    duration_ms = len(audio)
    orig_size = file_size

    approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
    chunks = []
    start = 0
    i = 0

    base, _ = os.path.splitext(file_path)

    while start < duration_ms:
        end = min(start + approx_chunk_ms, duration_ms)
        chunk = audio[start:end]
        chunk_path = f"{base}_chunk_{i}.{format}"
        chunk.export(chunk_path, format=format, bitrate=bitrate)

        # Reduce chunk duration if still too large
        while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
            end = start + ((end - start) // 2)
            chunk = audio[start:end]
            chunk.export(chunk_path, format=format, bitrate=bitrate)

        if os.path.getsize(chunk_path) > max_bytes:
            os.remove(chunk_path)
            raise Exception("Audio chunk cannot be reduced below max file size.")

        chunks.append(chunk_path)
        start = end
        i += 1

    return chunks


@router.post("/transcriptions")
def transcription(
    request: Request,
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    user=Depends(get_verified_user),
):
    log.info(f"file.content_type: {file.content_type}")

    stt_supported_content_types = getattr(
        request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
    )

    if not any(
        fnmatch(file.content_type, content_type)
        for content_type in (
            stt_supported_content_types
            if stt_supported_content_types
            and any(t.strip() for t in stt_supported_content_types)
            else ["audio/*", "video/webm"]
        )
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    try:
        ext = file.filename.split(".")[-1]
        id = uuid.uuid4()

        filename = f"{id}.{ext}"
        contents = file.file.read()

        file_dir = f"{CACHE_DIR}/audio/transcriptions"
        os.makedirs(file_dir, exist_ok=True)
        file_path = f"{file_dir}/{filename}"

        with open(file_path, "wb") as f:
            f.write(contents)

        try:
            metadata = None

            if language:
                metadata = {"language": language}

            result = transcribe(request, file_path, metadata)

            return {
                **result,
                "filename": os.path.basename(file_path),
            }
        except Exception as e:
            log.exception(e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


def get_available_models(request: Request) -> list[dict]:
    available_models = []
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
                )
                response.raise_for_status()
                data = response.json()
                available_models = data.get("models", [])
            except Exception as e:
                log.error(f"Error fetching models from custom endpoint: {str(e)}")
                available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
        else:
            available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            response = requests.get(
                "https://api.elevenlabs.io/v1/models",
                headers={
                    "xi-api-key": request.app.state.config.TTS_API_KEY,
                    "Content-Type": "application/json",
                },
                timeout=5,
            )
            response.raise_for_status()
            models = response.json()

            available_models = [
                {"name": model["name"], "id": model["model_id"]} for model in models
            ]
        except requests.RequestException as e:
            log.error(f"Error fetching models: {str(e)}")

    return available_models


@router.get("/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
    return {"models": get_available_models(request)}


def get_available_voices(request) -> dict:
    """Returns {voice_id: voice_name} dict"""
    available_voices = {}
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
                )
                response.raise_for_status()
                data = response.json()
                voices_list = data.get("voices", [])
                available_voices = {voice["id"]: voice["name"] for voice in voices_list}
            except Exception as e:
                log.error(f"Error fetching voices from custom endpoint: {str(e)}")
                available_voices = {
                    "alloy": "alloy",
                    "echo": "echo",
                    "fable": "fable",
                    "onyx": "onyx",
                    "nova": "nova",
                    "shimmer": "shimmer",
                }
        else:
            available_voices = {
                "alloy": "alloy",
                "echo": "echo",
                "fable": "fable",
                "onyx": "onyx",
                "nova": "nova",
                "shimmer": "shimmer",
            }
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            available_voices = get_elevenlabs_voices(
                api_key=request.app.state.config.TTS_API_KEY
            )
        except Exception:
            # Avoided @lru_cache with exception
            pass
    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            region = request.app.state.config.TTS_AZURE_SPEECH_REGION
            base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
            url = (
                base_url or f"https://{region}.tts.speech.microsoft.com"
            ) + "/cognitiveservices/voices/list"
            headers = {
                "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
            }

            response = requests.get(url, headers=headers)
            response.raise_for_status()
            voices = response.json()

            for voice in voices:
                available_voices[voice["ShortName"]] = (
                    f"{voice['DisplayName']} ({voice['ShortName']})"
                )
        except requests.RequestException as e:
            log.error(f"Error fetching voices: {str(e)}")

    return available_voices


@lru_cache
def get_elevenlabs_voices(api_key: str) -> dict:
    """
    Note, set the following in your .env file to use Elevenlabs:
    AUDIO_TTS_ENGINE=elevenlabs
    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
    AUDIO_TTS_MODEL=eleven_multilingual_v2
    """
    try:
        # TODO: Add retries
        response = requests.get(
            "https://api.elevenlabs.io/v1/voices",
            headers={
                "xi-api-key": api_key,
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()
        voices_data = response.json()

        voices = {}
        for voice in voices_data.get("voices", []):
            voices[voice["voice_id"]] = voice["name"]
    except requests.RequestException as e:
        # Avoid @lru_cache with exception
        log.error(f"Error fetching voices: {str(e)}")
        raise RuntimeError(f"Error fetching voices: {str(e)}")

    return voices


@router.get("/voices")
async def get_voices(request: Request, user=Depends(get_verified_user)):
    return {
        "voices": [
            {"id": k, "name": v} for k, v in get_available_voices(request).items()
        ]
    }