main.py

import os
import logging
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

import uuid
import requests
import hashlib
from pathlib import Path
import json

from constants import ERROR_MESSAGES
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
from utils.misc import calculate_sha256

from config import (
    SRC_LOG_LEVELS,
    CACHE_DIR,
    UPLOAD_DIR,
    WHISPER_MODEL,
    WHISPER_MODEL_DIR,
    WHISPER_MODEL_AUTO_UPDATE,
    DEVICE_TYPE,
    AUDIO_STT_OPENAI_API_BASE_URL,
    AUDIO_STT_OPENAI_API_KEY,
    AUDIO_TTS_OPENAI_API_BASE_URL,
    AUDIO_TTS_OPENAI_API_KEY,
    AUDIO_TTS_API_KEY,
    AUDIO_STT_ENGINE,
    AUDIO_STT_MODEL,
    AUDIO_TTS_ENGINE,
    AUDIO_TTS_MODEL,
    AUDIO_TTS_VOICE,
    AppConfig,
)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.config = AppConfig()

app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL

app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY

# Set the device type for the local Whisper model
whisper_device_type = "cuda" if DEVICE_TYPE == "cuda" else "cpu"
log.info(f"whisper_device_type: {whisper_device_type}")

SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)

class TTSConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    API_KEY: str
    ENGINE: str
    MODEL: str
    VOICE: str


class STTConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    ENGINE: str
    MODEL: str


class AudioConfigUpdateForm(BaseModel):
    tts: TTSConfigForm
    stt: STTConfigForm
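
# pydub is used to detect MP4/AAC uploads and to convert them to WAV before transcription.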
from pydub import AudioSegment
from pydub.utils import mediainfo


def is_mp4_audio(file_path):
    """Check if the given file is an MP4 audio file."""
    if not os.path.isfile(file_path):
        log.error(f"File not found: {file_path}")
        return False

    info = mediainfo(file_path)
    if (
        info.get("codec_name") == "aac"
        and info.get("codec_type") == "audio"
        and info.get("codec_tag_string") == "mp4a"
    ):
        return True
    return False


def convert_mp4_to_wav(file_path, output_path):
    """Convert an MP4 audio file to WAV format."""
    audio = AudioSegment.from_file(file_path, format="mp4")
    audio.export(output_path, format="wav")
    log.info(f"Converted {file_path} to {output_path}")
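
# Admin endpoints for reading and updating the audio (TTS/STT) configuration.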

@app.get("/config")
async def get_audio_config(user=Depends(get_admin_user)):
    return {
        "tts": {
            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": app.state.config.TTS_API_KEY,
            "ENGINE": app.state.config.TTS_ENGINE,
            "MODEL": app.state.config.TTS_MODEL,
            "VOICE": app.state.config.TTS_VOICE,
        },
        "stt": {
            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": app.state.config.STT_ENGINE,
            "MODEL": app.state.config.STT_MODEL,
        },
    }

@app.post("/config/update")
async def update_audio_config(
    form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
    app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
    app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
    app.state.config.TTS_API_KEY = form_data.tts.API_KEY
    app.state.config.TTS_ENGINE = form_data.tts.ENGINE
    app.state.config.TTS_MODEL = form_data.tts.MODEL
    app.state.config.TTS_VOICE = form_data.tts.VOICE

    app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
    app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
    app.state.config.STT_ENGINE = form_data.stt.ENGINE
    app.state.config.STT_MODEL = form_data.stt.MODEL

    return {
        "tts": {
            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": app.state.config.TTS_API_KEY,
            "ENGINE": app.state.config.TTS_ENGINE,
            "MODEL": app.state.config.TTS_MODEL,
            "VOICE": app.state.config.TTS_VOICE,
        },
        "stt": {
            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": app.state.config.STT_ENGINE,
            "MODEL": app.state.config.STT_MODEL,
        },
    }
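
# POST /speech: text-to-speech. The raw request body is hashed (SHA-256) so repeated
# requests with the same payload are served from the on-disk cache instead of calling
# the provider again.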
@app.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    if app.state.config.TTS_ENGINE == "openai":
        headers = {}
        headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
        headers["Content-Type"] = "application/json"

        try:
            body = body.decode("utf-8")
            body = json.loads(body)
            body["model"] = app.state.config.TTS_MODEL
            body = json.dumps(body).encode("utf-8")
        except Exception as e:
            # If the body is not valid JSON, log it and forward the payload unchanged;
            # the upstream provider will return its own error.
            log.exception(e)

        r = None
        try:
            r = requests.post(
                url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                data=body,
                headers=headers,
                stream=True,
            )
            r.raise_for_status()

            # Save the streaming content to a file
            with open(file_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

            with open(file_body_path, "w") as f:
                json.dump(json.loads(body.decode("utf-8")), f)

            # Return the saved file
            return FileResponse(file_path)

        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"External: {res['error']['message']}"
                except Exception:
                    error_detail = f"External: {e}"

            raise HTTPException(
                status_code=r.status_code if r is not None else 500,
                detail=error_detail,
            )

    elif app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            # Without a parsable payload there is no voice or text to send, so fail fast.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid JSON payload",
            )

        url = f"https://api.elevenlabs.io/v1/text-to-speech/{payload['voice']}"

        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": app.state.config.TTS_API_KEY,
        }

        data = {
            "text": payload["input"],
            "model_id": app.state.config.TTS_MODEL,
            "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
        }

        r = None
        try:
            r = requests.post(url, json=data, headers=headers)
            r.raise_for_status()

            # Save the audio content to the cache file
            with open(file_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

            with open(file_body_path, "w") as f:
                json.dump(json.loads(body.decode("utf-8")), f)

            # Return the saved file
            return FileResponse(file_path)

        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"External: {res['error']['message']}"
                except Exception:
                    error_detail = f"External: {e}"

            raise HTTPException(
                status_code=r.status_code if r is not None else 500,
                detail=error_detail,
            )
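
# POST /transcriptions: speech-to-text. The upload is saved under
# CACHE_DIR/audio/transcriptions and transcribed either locally with faster-whisper
# (when no STT engine is configured) or via an OpenAI-compatible endpoint.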
@app.post("/transcriptions")
def transcribe(
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    log.info(f"file.content_type: {file.content_type}")

    if file.content_type not in ["audio/mpeg", "audio/wav"]:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    try:
        ext = file.filename.split(".")[-1]
        id = uuid.uuid4()
        filename = f"{id}.{ext}"

        file_dir = f"{CACHE_DIR}/audio/transcriptions"
        os.makedirs(file_dir, exist_ok=True)
        file_path = f"{file_dir}/{filename}"

        log.debug(filename)

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if app.state.config.STT_ENGINE == "":
            from faster_whisper import WhisperModel

            whisper_kwargs = {
                "model_size_or_path": WHISPER_MODEL,
                "device": whisper_device_type,
                "compute_type": "int8",
                "download_root": WHISPER_MODEL_DIR,
                "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
            }

            log.debug(f"whisper_kwargs: {whisper_kwargs}")

            try:
                model = WhisperModel(**whisper_kwargs)
            except Exception:
                log.warning(
                    "WhisperModel initialization failed, attempting download with local_files_only=False"
                )
                whisper_kwargs["local_files_only"] = False
                model = WhisperModel(**whisper_kwargs)

            segments, info = model.transcribe(file_path, beam_size=5)
            log.info(
                "Detected language '%s' with probability %f"
                % (info.language, info.language_probability)
            )

            transcript = "".join([segment.text for segment in segments])
            data = {"text": transcript.strip()}

            # Save the transcript to a JSON file next to the audio
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            log.debug(data)
            return data

        elif app.state.config.STT_ENGINE == "openai":
            if is_mp4_audio(file_path):
                log.debug("is_mp4_audio")
                os.rename(file_path, file_path.replace(".wav", ".mp4"))
                # Convert the MP4 audio file to WAV format
                convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)

            headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}

            files = {"file": (filename, open(file_path, "rb"))}
            data = {"model": app.state.config.STT_MODEL}

            log.debug(f"files: {files}, data: {data}")

            r = None
            try:
                r = requests.post(
                    url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
                    headers=headers,
                    files=files,
                    data=data,
                )
                r.raise_for_status()

                data = r.json()

                # Save the transcript to a JSON file next to the audio
                transcript_file = f"{file_dir}/{id}.json"
                with open(transcript_file, "w") as f:
                    json.dump(data, f)

                log.debug(data)
                return data
            except Exception as e:
                log.exception(e)
                error_detail = "Open WebUI: Server Connection Error"
                if r is not None:
                    try:
                        res = r.json()
                        if "error" in res:
                            error_detail = f"External: {res['error']['message']}"
                    except Exception:
                        error_detail = f"External: {e}"

                raise HTTPException(
                    status_code=r.status_code if r is not None else 500,
                    detail=error_detail,
                )

    except Exception as e:
        log.exception(e)

        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
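
# Example client usage (a sketch only; the host, mount path, and bearer token below are
# assumptions for illustration, not part of this module):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8080/audio/speech",
#       headers={"Authorization": "Bearer <token>"},
#       json={"input": "Hello there", "voice": "alloy"},
#   )
#   with open("speech.mp3", "wb") as f:
#       f.write(resp.content)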