auths.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. import logging
  2. from authlib.integrations.starlette_client import OAuth
  3. from authlib.oidc.core import UserInfo
  4. from fastapi import Request, UploadFile, File
  5. from fastapi import Depends, HTTPException, status
  6. from fastapi import APIRouter
  7. from pydantic import BaseModel
  8. import re
  9. import uuid
  10. import csv
  11. from starlette.responses import RedirectResponse
  12. from apps.webui.models.auths import (
  13. SigninForm,
  14. SignupForm,
  15. AddUserForm,
  16. UpdateProfileForm,
  17. UpdatePasswordForm,
  18. UserResponse,
  19. SigninResponse,
  20. Auths,
  21. ApiKey,
  22. )
  23. from apps.webui.models.users import Users
  24. from utils.utils import (
  25. get_password_hash,
  26. get_current_user,
  27. get_admin_user,
  28. create_token,
  29. create_api_key,
  30. )
  31. from utils.misc import parse_duration, validate_email_format
  32. from utils.webhook import post_webhook
  33. from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
  34. from config import (
  35. WEBUI_AUTH,
  36. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  37. OAUTH_PROVIDERS,
  38. ENABLE_OAUTH_SIGNUP,
  39. )
  40. router = APIRouter()
  41. ############################
  42. # GetSessionUser
  43. ############################
  44. @router.get("/", response_model=UserResponse)
  45. async def get_session_user(user=Depends(get_current_user)):
  46. return {
  47. "id": user.id,
  48. "email": user.email,
  49. "name": user.name,
  50. "role": user.role,
  51. "profile_image_url": user.profile_image_url,
  52. }
  53. ############################
  54. # Update Profile
  55. ############################
  56. @router.post("/update/profile", response_model=UserResponse)
  57. async def update_profile(
  58. form_data: UpdateProfileForm, session_user=Depends(get_current_user)
  59. ):
  60. if session_user:
  61. user = Users.update_user_by_id(
  62. session_user.id,
  63. {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
  64. )
  65. if user:
  66. return user
  67. else:
  68. raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
  69. else:
  70. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  71. ############################
  72. # Update Password
  73. ############################
  74. @router.post("/update/password", response_model=bool)
  75. async def update_password(
  76. form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
  77. ):
  78. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  79. raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
  80. if session_user:
  81. user = Auths.authenticate_user(session_user.email, form_data.password)
  82. if user:
  83. hashed = get_password_hash(form_data.new_password)
  84. return Auths.update_user_password_by_id(user.id, hashed)
  85. else:
  86. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
  87. else:
  88. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  89. ############################
  90. # SignIn
  91. ############################
  92. @router.post("/signin", response_model=SigninResponse)
  93. async def signin(request: Request, form_data: SigninForm):
  94. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  95. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
  96. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
  97. trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
  98. if not Users.get_user_by_email(trusted_email.lower()):
  99. await signup(
  100. request,
  101. SignupForm(
  102. email=trusted_email, password=str(uuid.uuid4()), name=trusted_email
  103. ),
  104. )
  105. user = Auths.authenticate_user_by_trusted_header(trusted_email)
  106. elif WEBUI_AUTH == False:
  107. admin_email = "admin@localhost"
  108. admin_password = "admin"
  109. if Users.get_user_by_email(admin_email.lower()):
  110. user = Auths.authenticate_user(admin_email.lower(), admin_password)
  111. else:
  112. if Users.get_num_users() != 0:
  113. raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
  114. await signup(
  115. request,
  116. SignupForm(email=admin_email, password=admin_password, name="User"),
  117. )
  118. user = Auths.authenticate_user(admin_email.lower(), admin_password)
  119. else:
  120. user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
  121. if user:
  122. token = create_token(
  123. data={"id": user.id},
  124. expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
  125. )
  126. return {
  127. "token": token,
  128. "token_type": "Bearer",
  129. "id": user.id,
  130. "email": user.email,
  131. "name": user.name,
  132. "role": user.role,
  133. "profile_image_url": user.profile_image_url,
  134. }
  135. else:
  136. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  137. ############################
  138. # SignUp
  139. ############################
  140. @router.post("/signup", response_model=SigninResponse)
  141. async def signup(request: Request, form_data: SignupForm):
  142. if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
  143. raise HTTPException(
  144. status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
  145. )
  146. if not validate_email_format(form_data.email.lower()):
  147. raise HTTPException(
  148. status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
  149. )
  150. if Users.get_user_by_email(form_data.email.lower()):
  151. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  152. try:
  153. role = (
  154. "admin"
  155. if Users.get_num_users() == 0
  156. else request.app.state.config.DEFAULT_USER_ROLE
  157. )
  158. hashed = get_password_hash(form_data.password)
  159. user = Auths.insert_new_auth(
  160. form_data.email.lower(),
  161. hashed,
  162. form_data.name,
  163. form_data.profile_image_url,
  164. role,
  165. )
  166. if user:
  167. token = create_token(
  168. data={"id": user.id},
  169. expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
  170. )
  171. # response.set_cookie(key='token', value=token, httponly=True)
  172. if request.app.state.config.WEBHOOK_URL:
  173. post_webhook(
  174. request.app.state.config.WEBHOOK_URL,
  175. WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  176. {
  177. "action": "signup",
  178. "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  179. "user": user.model_dump_json(exclude_none=True),
  180. },
  181. )
  182. return {
  183. "token": token,
  184. "token_type": "Bearer",
  185. "id": user.id,
  186. "email": user.email,
  187. "name": user.name,
  188. "role": user.role,
  189. "profile_image_url": user.profile_image_url,
  190. }
  191. else:
  192. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
  193. except Exception as err:
  194. raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
  195. ############################
  196. # AddUser
  197. ############################
  198. @router.post("/add", response_model=SigninResponse)
  199. async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
  200. if not validate_email_format(form_data.email.lower()):
  201. raise HTTPException(
  202. status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
  203. )
  204. if Users.get_user_by_email(form_data.email.lower()):
  205. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  206. try:
  207. print(form_data)
  208. hashed = get_password_hash(form_data.password)
  209. user = Auths.insert_new_auth(
  210. form_data.email.lower(),
  211. hashed,
  212. form_data.name,
  213. form_data.profile_image_url,
  214. form_data.role,
  215. )
  216. if user:
  217. token = create_token(data={"id": user.id})
  218. return {
  219. "token": token,
  220. "token_type": "Bearer",
  221. "id": user.id,
  222. "email": user.email,
  223. "name": user.name,
  224. "role": user.role,
  225. "profile_image_url": user.profile_image_url,
  226. }
  227. else:
  228. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
  229. except Exception as err:
  230. raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
  231. ############################
  232. # ToggleSignUp
  233. ############################
  234. @router.get("/signup/enabled", response_model=bool)
  235. async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
  236. return request.app.state.config.ENABLE_SIGNUP
  237. @router.get("/signup/enabled/toggle", response_model=bool)
  238. async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
  239. request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP
  240. return request.app.state.config.ENABLE_SIGNUP
  241. ############################
  242. # Default User Role
  243. ############################
  244. @router.get("/signup/user/role")
  245. async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
  246. return request.app.state.config.DEFAULT_USER_ROLE
  247. class UpdateRoleForm(BaseModel):
  248. role: str
  249. @router.post("/signup/user/role")
  250. async def update_default_user_role(
  251. request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
  252. ):
  253. if form_data.role in ["pending", "user", "admin"]:
  254. request.app.state.config.DEFAULT_USER_ROLE = form_data.role
  255. return request.app.state.config.DEFAULT_USER_ROLE
  256. ############################
  257. # JWT Expiration
  258. ############################
  259. @router.get("/token/expires")
  260. async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
  261. return request.app.state.config.JWT_EXPIRES_IN
  262. class UpdateJWTExpiresDurationForm(BaseModel):
  263. duration: str
  264. @router.post("/token/expires/update")
  265. async def update_token_expires_duration(
  266. request: Request,
  267. form_data: UpdateJWTExpiresDurationForm,
  268. user=Depends(get_admin_user),
  269. ):
  270. pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
  271. # Check if the input string matches the pattern
  272. if re.match(pattern, form_data.duration):
  273. request.app.state.config.JWT_EXPIRES_IN = form_data.duration
  274. return request.app.state.config.JWT_EXPIRES_IN
  275. else:
  276. return request.app.state.config.JWT_EXPIRES_IN
  277. ############################
  278. # API Key
  279. ############################
  280. # create api key
  281. @router.post("/api_key", response_model=ApiKey)
  282. async def create_api_key_(user=Depends(get_current_user)):
  283. api_key = create_api_key()
  284. success = Users.update_user_api_key_by_id(user.id, api_key)
  285. if success:
  286. return {
  287. "api_key": api_key,
  288. }
  289. else:
  290. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
  291. # delete api key
  292. @router.delete("/api_key", response_model=bool)
  293. async def delete_api_key(user=Depends(get_current_user)):
  294. success = Users.update_user_api_key_by_id(user.id, None)
  295. return success
  296. # get api key
  297. @router.get("/api_key", response_model=ApiKey)
  298. async def get_api_key(user=Depends(get_current_user)):
  299. api_key = Users.get_user_api_key_by_id(user.id)
  300. if api_key:
  301. return {
  302. "api_key": api_key,
  303. }
  304. else:
  305. raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  306. ############################
  307. # OAuth Login & Callback
  308. ############################
  309. oauth = OAuth()
  310. for provider_name, provider_config in OAUTH_PROVIDERS.items():
  311. oauth.register(
  312. name=provider_name,
  313. client_id=provider_config["client_id"],
  314. client_secret=provider_config["client_secret"],
  315. server_metadata_url=provider_config["server_metadata_url"],
  316. client_kwargs={
  317. "scope": provider_config["scope"],
  318. },
  319. )
  320. @router.get("/oauth/{provider}/login")
  321. async def oauth_login(provider: str, request: Request):
  322. if provider not in OAUTH_PROVIDERS:
  323. raise HTTPException(404)
  324. redirect_uri = request.url_for("oauth_callback", provider=provider)
  325. return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
  326. @router.get("/oauth/{provider}/callback")
  327. async def oauth_callback(provider: str, request: Request):
  328. if provider not in OAUTH_PROVIDERS:
  329. raise HTTPException(404)
  330. client = oauth.create_client(provider)
  331. token = await client.authorize_access_token(request)
  332. user_data: UserInfo = token["userinfo"]
  333. sub = user_data.get("sub")
  334. if not sub:
  335. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  336. provider_sub = f"{provider}@{sub}"
  337. # Check if the user exists
  338. user = Users.get_user_by_oauth_sub(provider_sub)
  339. if not user:
  340. # If the user does not exist, create a new user if signup is enabled
  341. if ENABLE_OAUTH_SIGNUP.value:
  342. user = Auths.insert_new_auth(
  343. email=user_data.get("email", "").lower(),
  344. password=get_password_hash(
  345. str(uuid.uuid4())
  346. ), # Random password, not used
  347. name=user_data.get("name", "User"),
  348. profile_image_url=user_data.get("picture", "/user.png"),
  349. role=request.app.state.config.DEFAULT_USER_ROLE,
  350. oauth_sub=provider_sub,
  351. )
  352. if request.app.state.config.WEBHOOK_URL:
  353. post_webhook(
  354. request.app.state.config.WEBHOOK_URL,
  355. WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  356. {
  357. "action": "signup",
  358. "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  359. "user": user.model_dump_json(exclude_none=True),
  360. },
  361. )
  362. else:
  363. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  364. jwt_token = create_token(
  365. data={"id": user.id},
  366. expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
  367. )
  368. # Redirect back to the frontend with the JWT token
  369. redirect_url = f"{request.base_url}auth#token={jwt_token}"
  370. return RedirectResponse(url=redirect_url)