1
0

main.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import os
  2. import logging
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from faster_whisper import WhisperModel
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. decode_token,
  18. get_current_user,
  19. get_verified_user,
  20. get_admin_user,
  21. )
  22. from utils.misc import calculate_sha256
  23. from config import (
  24. SRC_LOG_LEVELS,
  25. CACHE_DIR,
  26. UPLOAD_DIR,
  27. WHISPER_MODEL,
  28. WHISPER_MODEL_DIR,
  29. DEVICE_TYPE,
  30. )
  31. log = logging.getLogger(__name__)
  32. log.setLevel(SRC_LOG_LEVELS["AUDIO"])
  33. whisper_device_type = DEVICE_TYPE
  34. if whisper_device_type != "cuda":
  35. whisper_device_type = "cpu"
  36. log.info(f"whisper_device_type: {whisper_device_type}")
  37. app = FastAPI()
  38. app.add_middleware(
  39. CORSMiddleware,
  40. allow_origins=["*"],
  41. allow_credentials=True,
  42. allow_methods=["*"],
  43. allow_headers=["*"],
  44. )
  45. @app.post("/transcribe")
  46. def transcribe(
  47. file: UploadFile = File(...),
  48. user=Depends(get_current_user),
  49. ):
  50. log.info(f"file.content_type: {file.content_type}")
  51. if file.content_type not in ["audio/mpeg", "audio/wav"]:
  52. raise HTTPException(
  53. status_code=status.HTTP_400_BAD_REQUEST,
  54. detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
  55. )
  56. try:
  57. filename = file.filename
  58. file_path = f"{UPLOAD_DIR}/{filename}"
  59. contents = file.file.read()
  60. with open(file_path, "wb") as f:
  61. f.write(contents)
  62. f.close()
  63. model = WhisperModel(
  64. WHISPER_MODEL,
  65. device=whisper_device_type,
  66. compute_type="int8",
  67. download_root=WHISPER_MODEL_DIR,
  68. )
  69. segments, info = model.transcribe(file_path, beam_size=5)
  70. log.info(
  71. "Detected language '%s' with probability %f"
  72. % (info.language, info.language_probability)
  73. )
  74. transcript = "".join([segment.text for segment in list(segments)])
  75. return {"text": transcript.strip()}
  76. except Exception as e:
  77. log.exception(e)
  78. raise HTTPException(
  79. status_code=status.HTTP_400_BAD_REQUEST,
  80. detail=ERROR_MESSAGES.DEFAULT(e),
  81. )