auths.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. import logging
  2. from fastapi import Request, UploadFile, File
  3. from fastapi import Depends, HTTPException, status
  4. from fastapi import APIRouter
  5. from pydantic import BaseModel
  6. import re
  7. import uuid
  8. import csv
  9. from apps.web.models.auths import (
  10. SigninForm,
  11. SignupForm,
  12. AddUserForm,
  13. UpdateProfileForm,
  14. UpdatePasswordForm,
  15. UserResponse,
  16. SigninResponse,
  17. Auths,
  18. ApiKey,
  19. )
  20. from apps.web.models.users import Users
  21. from utils.utils import (
  22. get_password_hash,
  23. get_current_user,
  24. get_admin_user,
  25. create_token,
  26. create_api_key,
  27. )
  28. from utils.misc import parse_duration, validate_email_format
  29. from utils.webhook import post_webhook
  30. from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
  31. from config import WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  32. router = APIRouter()
  33. ############################
  34. # GetSessionUser
  35. ############################
  @router.get("/", response_model=UserResponse)
  async def get_session_user(user=Depends(get_current_user)):
      """Return the profile fields of the user resolved for this request.

      The user object is supplied by the ``get_current_user`` dependency; the
      response contains its id, email, name, role and profile image URL.
      """
      exposed_fields = ("id", "email", "name", "role", "profile_image_url")
      return {field: getattr(user, field) for field in exposed_fields}
  45. ############################
  46. # Update Profile
  47. ############################
  @router.post("/update/profile", response_model=UserResponse)
  async def update_profile(
      form_data: UpdateProfileForm, session_user=Depends(get_current_user)
  ):
      """Update the display name and profile image of the session user.

      Returns whatever ``Users.update_user_by_id`` returns on success.

      Raises:
          HTTPException: 400 when no session user is available, or when the
              update call returns a falsy result.
      """
      if not session_user:
          raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

      target_id = session_user.id
      changes = {
          "profile_image_url": form_data.profile_image_url,
          "name": form_data.name,
      }
      updated_user = Users.update_user_by_id(target_id, changes)
      if not updated_user:
          raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
      return updated_user
  63. ############################
  64. # Update Password
  65. ############################
  @router.post("/update/password", response_model=bool)
  async def update_password(
      form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
  ):
      """Change the session user's password after re-checking the current one.

      The request is refused outright while ``WEBUI_AUTH_TRUSTED_EMAIL_HEADER``
      is configured. Otherwise the current password is verified through
      ``Auths.authenticate_user`` before the new one is hashed and stored.

      Raises:
          HTTPException: 400 when trusted-header auth is configured, when no
              session user is available, or when the current password does
              not authenticate.
      """
      if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
          raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
      if not session_user:
          raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

      verified_user = Auths.authenticate_user(session_user.email, form_data.password)
      if not verified_user:
          raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)

      new_hash = get_password_hash(form_data.new_password)
      return Auths.update_user_password_by_id(verified_user.id, new_hash)
  81. ############################
  82. # SignIn
  83. ############################
  @router.post("/signin", response_model=SigninResponse)
  async def signin(request: Request, form_data: SigninForm):
      """Authenticate a user and return a bearer token plus profile fields.

      Two modes are supported:

      * When ``WEBUI_AUTH_TRUSTED_EMAIL_HEADER`` is configured, the email is
        read from that request header (lower-cased). If no user exists for it,
        ``signup`` is called with a random UUID password, then the user is
        authenticated by the trusted header.
      * Otherwise the lower-cased form email and the form password are checked
        through ``Auths.authenticate_user``.

      Raises:
          HTTPException: 400 when the trusted header is configured but absent
              from the request, or when authentication yields no user.
      """
      header_name = WEBUI_AUTH_TRUSTED_EMAIL_HEADER

      if not header_name:
          user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
      else:
          if header_name not in request.headers:
              raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)

          trusted_email = request.headers[header_name].lower()
          known_user = Users.get_user_by_email(trusted_email)
          if not known_user:
              # First sight of this address: create the account on the fly with
              # a throwaway random password.
              auto_form = SignupForm(
                  email=trusted_email, password=str(uuid.uuid4()), name=trusted_email
              )
              await signup(request, auto_form)
          user = Auths.authenticate_user_by_trusted_header(trusted_email)

      if not user:
          raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

      lifetime = parse_duration(request.app.state.JWT_EXPIRES_IN)
      token = create_token(data={"id": user.id}, expires_delta=lifetime)
      return {
          "token": token,
          "token_type": "Bearer",
          "id": user.id,
          "email": user.email,
          "name": user.name,
          "role": user.role,
          "profile_image_url": user.profile_image_url,
      }
  116. ############################
  117. # SignUp
  118. ############################
  @router.post("/signup", response_model=SigninResponse)
  async def signup(request: Request, form_data: SignupForm):
      """Register a new account and return a bearer token plus profile fields.

      The email is lower-cased before validation, the duplicate check and
      storage. The role is ``"admin"`` when ``Users.get_num_users()`` is 0,
      otherwise ``request.app.state.DEFAULT_USER_ROLE``. When
      ``request.app.state.WEBHOOK_URL`` is set, a signup webhook is posted.

      Raises:
          HTTPException: 403 when sign-up is disabled; 400 when the email is
              malformed or already taken; 500 with ``CREATE_USER_ERROR`` when
              the insert returns nothing, or 500 with the default message for
              any other unexpected failure.
      """
      if not request.app.state.ENABLE_SIGNUP:
          raise HTTPException(
              status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
          )

      email = form_data.email.lower()

      if not validate_email_format(email):
          raise HTTPException(
              status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
          )

      if Users.get_user_by_email(email):
          raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

      try:
          role = (
              "admin"
              if Users.get_num_users() == 0
              else request.app.state.DEFAULT_USER_ROLE
          )
          hashed = get_password_hash(form_data.password)
          user = Auths.insert_new_auth(
              email,
              hashed,
              form_data.name,
              form_data.profile_image_url,
              role,
          )
          if not user:
              raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)

          token = create_token(
              data={"id": user.id},
              expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
          )
          # response.set_cookie(key='token', value=token, httponly=True)

          if request.app.state.WEBHOOK_URL:
              signup_message = WEBHOOK_MESSAGES.USER_SIGNUP(user.name)
              post_webhook(
                  request.app.state.WEBHOOK_URL,
                  signup_message,
                  {
                      "action": "signup",
                      "message": signup_message,
                      "user": user.model_dump_json(exclude_none=True),
                  },
              )

          return {
              "token": token,
              "token_type": "Bearer",
              "id": user.id,
              "email": user.email,
              "name": user.name,
              "role": user.role,
              "profile_image_url": user.profile_image_url,
          }
      except HTTPException:
          # Deliberate HTTP errors raised above must reach the client unchanged
          # instead of being re-wrapped as a generic failure below.
          raise
      except Exception as err:
          raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) from err
  174. ############################
  175. # AddUser
  176. ############################
  @router.post("/add", response_model=SigninResponse)
  async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
      """Create an account from the submitted form and return a token for it.

      Access is gated by the ``get_admin_user`` dependency. The email is
      lower-cased before validation, the duplicate check and storage; the role
      is taken from the form.

      Raises:
          HTTPException: 400 when the email is malformed or already taken; 500
              with ``CREATE_USER_ERROR`` when the insert returns nothing, or
              500 with the default message for any other unexpected failure.
      """
      email = form_data.email.lower()

      if not validate_email_format(email):
          raise HTTPException(
              status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
          )

      if Users.get_user_by_email(email):
          raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

      try:
          # The form is intentionally not printed or logged: it carries the
          # plaintext password.
          hashed = get_password_hash(form_data.password)
          new_user = Auths.insert_new_auth(
              email,
              hashed,
              form_data.name,
              form_data.profile_image_url,
              form_data.role,
          )
          if not new_user:
              raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)

          token = create_token(data={"id": new_user.id})
          return {
              "token": token,
              "token_type": "Bearer",
              "id": new_user.id,
              "email": new_user.email,
              "name": new_user.name,
              "role": new_user.role,
              "profile_image_url": new_user.profile_image_url,
          }
      except HTTPException:
          # Keep the specific error raised above instead of re-wrapping it.
          raise
      except Exception as err:
          raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) from err
  # NOTE(review): response_model is SigninResponse, but the handler returns a
  # plain message dict — confirm which of the two is intended.
  @router.post("/add/import", response_model=SigninResponse)
  async def add_user_csv_import(
      file: UploadFile = File(...), user=Depends(get_admin_user)
  ):
      """Accept an uploaded CSV file, parse it and print its rows.

      Access is gated by the ``get_admin_user`` dependency. The upload is
      decoded as UTF-8 and parsed with ``csv.reader``; at present the rows are
      only printed and a confirmation message is returned.

      Raises:
          HTTPException: 400 when the upload has no filename or the filename
              does not end in ``.csv``; 500 with the default message when
              reading, decoding or parsing fails.
      """
      import io  # local import: only this handler needs it

      # Reject anything that is not named *.csv (a missing filename included)
      # with a real HTTP error; raising a bare string is itself a TypeError.
      if not (file.filename and file.filename.endswith(".csv")):
          raise HTTPException(
              status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT("File must be a CSV."),
          )

      try:
          contents = await file.read()
          decoded_content = contents.decode("utf-8")

          # Hand the whole text to csv.reader through a file-like object so
          # that newlines inside quoted fields are preserved and no spurious
          # empty trailing row is produced.
          csv_data = list(csv.reader(io.StringIO(decoded_content, newline="")))

          for row in csv_data:
              print(row)

          return {"message": "CSV file uploaded successfully."}
      except Exception as err:
          raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) from err
  236. # if not validate_email_format(form_data.email.lower()):
  237. # raise HTTPException(
  238. # status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
  239. # )
  240. # if Users.get_user_by_email(form_data.email.lower()):
  241. # raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  242. # try:
  243. # print(form_data)
  244. # hashed = get_password_hash(form_data.password)
  245. # user = Auths.insert_new_auth(
  246. # form_data.email.lower(),
  247. # hashed,
  248. # form_data.name,
  249. # form_data.profile_image_url,
  250. # form_data.role,
  251. # )
  252. # if user:
  253. # token = create_token(data={"id": user.id})
  254. # return {
  255. # "token": token,
  256. # "token_type": "Bearer",
  257. # "id": user.id,
  258. # "email": user.email,
  259. # "name": user.name,
  260. # "role": user.role,
  261. # "profile_image_url": user.profile_image_url,
  262. # }
  263. # else:
  264. # raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
  265. # except Exception as err:
  266. # raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
  267. ############################
  268. # ToggleSignUp
  269. ############################
  @router.get("/signup/enabled", response_model=bool)
  async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
      """Return the current value of the application's ``ENABLE_SIGNUP`` flag."""
      app_state = request.app.state
      return app_state.ENABLE_SIGNUP
  @router.get("/signup/enabled/toggle", response_model=bool)
  async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
      """Invert the application's ``ENABLE_SIGNUP`` flag and return the new value."""
      app_state = request.app.state
      app_state.ENABLE_SIGNUP = not app_state.ENABLE_SIGNUP
      return app_state.ENABLE_SIGNUP
  277. ############################
  278. # Default User Role
  279. ############################
  @router.get("/signup/user/role")
  async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
      """Return the application's ``DEFAULT_USER_ROLE`` setting."""
      app_state = request.app.state
      return app_state.DEFAULT_USER_ROLE
  # Request body for POST /signup/user/role.
  class UpdateRoleForm(BaseModel):
      # Requested default role; the handler accepts "pending", "user" or "admin".
      role: str
  @router.post("/signup/user/role")
  async def update_default_user_role(
      request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
  ):
      """Set ``DEFAULT_USER_ROLE`` when the requested role is recognised.

      An unrecognised role is ignored without an error; in both cases the
      value currently stored on the application state is returned.
      """
      allowed_roles = ("pending", "user", "admin")
      app_state = request.app.state
      if form_data.role in allowed_roles:
          app_state.DEFAULT_USER_ROLE = form_data.role
      return app_state.DEFAULT_USER_ROLE
  292. ############################
  293. # JWT Expiration
  294. ############################
  @router.get("/token/expires")
  async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
      """Return the application's ``JWT_EXPIRES_IN`` duration string."""
      app_state = request.app.state
      return app_state.JWT_EXPIRES_IN
  # Request body for POST /token/expires/update.
  class UpdateJWTExpiresDurationForm(BaseModel):
      # Duration string; the handler checks it against its own pattern
      # ("-1", "0", or a number followed by ms/s/m/h/d/w).
      duration: str
  @router.post("/token/expires/update")
  async def update_token_expires_duration(
      request: Request,
      form_data: UpdateJWTExpiresDurationForm,
      user=Depends(get_admin_user),
  ):
      """Set ``JWT_EXPIRES_IN`` when the submitted duration is well-formed.

      Accepted values are ``-1``, ``0``, or an optionally negative decimal
      number followed by one of the units ``ms``, ``s``, ``m``, ``h``, ``d``,
      ``w``. A value that does not match is ignored without an error; in both
      cases the duration currently stored on the application state is returned.
      """
      pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"

      # fullmatch must consume the entire string. With re.match, "$" also
      # matches just before a trailing newline, so "1h\n" would be stored.
      if re.fullmatch(pattern, form_data.duration):
          request.app.state.JWT_EXPIRES_IN = form_data.duration
      return request.app.state.JWT_EXPIRES_IN
  313. ############################
  314. # API Key
  315. ############################
  316. # create api key
  @router.post("/api_key", response_model=ApiKey)
  async def create_api_key_(user=Depends(get_current_user)):
      """Generate a new API key, store it for the current user and return it.

      Raises:
          HTTPException: 500 when ``Users.update_user_api_key_by_id`` reports
              failure.
      """
      new_key = create_api_key()
      stored = Users.update_user_api_key_by_id(user.id, new_key)
      if not stored:
          raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
      return {"api_key": new_key}
  327. # delete api key
  @router.delete("/api_key", response_model=bool)
  async def delete_api_key(user=Depends(get_current_user)):
      """Clear the current user's API key by storing ``None`` for it.

      Returns the result of ``Users.update_user_api_key_by_id``.
      """
      return Users.update_user_api_key_by_id(user.id, None)
  332. # get api key
  @router.get("/api_key", response_model=ApiKey)
  async def get_api_key(user=Depends(get_current_user)):
      """Return the API key stored for the current user.

      Raises:
          HTTPException: 404 when no key is stored for the user.
      """
      stored_key = Users.get_user_api_key_by_id(user.id)
      if not stored_key:
          raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
      return {"api_key": stored_key}