1
0

oauth.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855
  1. import base64
  2. import logging
  3. import mimetypes
  4. import sys
  5. import uuid
  6. import json
  7. from datetime import datetime, timedelta
  8. import re
  9. import fnmatch
  10. import time
  11. import aiohttp
  12. from authlib.integrations.starlette_client import OAuth
  13. from authlib.oidc.core import UserInfo
  14. from fastapi import (
  15. HTTPException,
  16. status,
  17. )
  18. from starlette.responses import RedirectResponse
  19. from open_webui.models.auths import Auths
  20. from open_webui.models.oauth_sessions import OAuthSessions
  21. from open_webui.models.users import Users
  22. from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
  23. from open_webui.config import (
  24. DEFAULT_USER_ROLE,
  25. ENABLE_OAUTH_SIGNUP,
  26. OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
  27. OAUTH_PROVIDERS,
  28. ENABLE_OAUTH_ROLE_MANAGEMENT,
  29. ENABLE_OAUTH_GROUP_MANAGEMENT,
  30. ENABLE_OAUTH_GROUP_CREATION,
  31. OAUTH_BLOCKED_GROUPS,
  32. OAUTH_ROLES_CLAIM,
  33. OAUTH_SUB_CLAIM,
  34. OAUTH_GROUPS_CLAIM,
  35. OAUTH_EMAIL_CLAIM,
  36. OAUTH_PICTURE_CLAIM,
  37. OAUTH_USERNAME_CLAIM,
  38. OAUTH_ALLOWED_ROLES,
  39. OAUTH_ADMIN_ROLES,
  40. OAUTH_ALLOWED_DOMAINS,
  41. OAUTH_UPDATE_PICTURE_ON_LOGIN,
  42. WEBHOOK_URL,
  43. JWT_EXPIRES_IN,
  44. AppConfig,
  45. )
  46. from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
  47. from open_webui.env import (
  48. AIOHTTP_CLIENT_SESSION_SSL,
  49. WEBUI_NAME,
  50. WEBUI_AUTH_COOKIE_SAME_SITE,
  51. WEBUI_AUTH_COOKIE_SECURE,
  52. ENABLE_OAUTH_ID_TOKEN_COOKIE,
  53. )
  54. from open_webui.utils.misc import parse_duration
  55. from open_webui.utils.auth import get_password_hash, create_token
  56. from open_webui.utils.webhook import post_webhook
  57. from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
  58. logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
  59. log = logging.getLogger(__name__)
  60. log.setLevel(SRC_LOG_LEVELS["OAUTH"])
  61. auth_manager_config = AppConfig()
  62. auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
  63. auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
  64. auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
  65. auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
  66. auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT
  67. auth_manager_config.ENABLE_OAUTH_GROUP_CREATION = ENABLE_OAUTH_GROUP_CREATION
  68. auth_manager_config.OAUTH_BLOCKED_GROUPS = OAUTH_BLOCKED_GROUPS
  69. auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
  70. auth_manager_config.OAUTH_SUB_CLAIM = OAUTH_SUB_CLAIM
  71. auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM
  72. auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
  73. auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
  74. auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
  75. auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
  76. auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
  77. auth_manager_config.OAUTH_ALLOWED_DOMAINS = OAUTH_ALLOWED_DOMAINS
  78. auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
  79. auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
  80. auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN = OAUTH_UPDATE_PICTURE_ON_LOGIN
  81. def is_in_blocked_groups(group_name: str, groups: list) -> bool:
  82. """
  83. Check if a group name matches any blocked pattern.
  84. Supports exact matches, shell-style wildcards (*, ?), and regex patterns.
  85. Args:
  86. group_name: The group name to check
  87. groups: List of patterns to match against
  88. Returns:
  89. True if the group is blocked, False otherwise
  90. """
  91. if not groups:
  92. return False
  93. for group_pattern in groups:
  94. if not group_pattern: # Skip empty patterns
  95. continue
  96. # Exact match
  97. if group_name == group_pattern:
  98. return True
  99. # Try as regex pattern first if it contains regex-specific characters
  100. if any(
  101. char in group_pattern
  102. for char in ["^", "$", "[", "]", "(", ")", "{", "}", "+", "\\", "|"]
  103. ):
  104. try:
  105. # Use the original pattern as-is for regex matching
  106. if re.search(group_pattern, group_name):
  107. return True
  108. except re.error:
  109. # If regex is invalid, fall through to wildcard check
  110. pass
  111. # Shell-style wildcard match (supports * and ?)
  112. if "*" in group_pattern or "?" in group_pattern:
  113. if fnmatch.fnmatch(group_name, group_pattern):
  114. return True
  115. return False
  116. class OAuthManager:
  117. def __init__(self, app):
  118. self.oauth = OAuth()
  119. self.app = app
  120. self._clients = {}
  121. for _, provider_config in OAUTH_PROVIDERS.items():
  122. provider_config["register"](self.oauth)
  123. def get_client(self, provider_name):
  124. if provider_name not in self._clients:
  125. self._clients[provider_name] = self.oauth.create_client(provider_name)
  126. return self._clients[provider_name]
  127. def get_server_metadata_url(self, provider_name):
  128. if provider_name in self._clients:
  129. client = self._clients[provider_name]
  130. return (
  131. client.server_metadata_url
  132. if hasattr(client, "server_metadata_url")
  133. else None
  134. )
  135. return None
  136. def get_oauth_token(
  137. self, user_id: str, session_id: str, force_refresh: bool = False
  138. ):
  139. """
  140. Get a valid OAuth token for the user, automatically refreshing if needed.
  141. Args:
  142. user_id: The user ID
  143. provider: Optional provider name. If None, gets the most recent session.
  144. force_refresh: Force token refresh even if current token appears valid
  145. Returns:
  146. dict: OAuth token data with access_token, or None if no valid token available
  147. """
  148. try:
  149. # Get the OAuth session
  150. session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
  151. if not session:
  152. log.warning(
  153. f"No OAuth session found for user {user_id}, session {session_id}"
  154. )
  155. return None
  156. if force_refresh or datetime.now() + timedelta(
  157. minutes=5
  158. ) >= datetime.fromtimestamp(session.expires_at):
  159. log.debug(
  160. f"Token refresh needed for user {user_id}, provider {session.provider}"
  161. )
  162. refreshed_token = self._refresh_token(session)
  163. if refreshed_token:
  164. return refreshed_token
  165. else:
  166. log.warning(
  167. f"Token refresh failed for user {user_id}, provider {session.provider}"
  168. )
  169. return None
  170. return session.token
  171. except Exception as e:
  172. log.error(f"Error getting OAuth token for user {user_id}: {e}")
  173. return None
  174. async def _refresh_token(self, session) -> dict:
  175. """
  176. Refresh an OAuth token if needed, with concurrency protection.
  177. Args:
  178. session: The OAuth session object
  179. Returns:
  180. dict: Refreshed token data, or None if refresh failed
  181. """
  182. try:
  183. # Perform the actual refresh
  184. refreshed_token = await self._perform_token_refresh(session)
  185. if refreshed_token:
  186. # Update the session with new token data
  187. session = OAuthSessions.update_session_by_id(
  188. session.id, refreshed_token
  189. )
  190. log.info(f"Successfully refreshed token for session {session.id}")
  191. return session.token
  192. else:
  193. log.error(f"Failed to refresh token for session {session.id}")
  194. return None
  195. except Exception as e:
  196. log.error(f"Error refreshing token for session {session.id}: {e}")
  197. return None
  198. async def _perform_token_refresh(self, session) -> dict:
  199. """
  200. Perform the actual OAuth token refresh.
  201. Args:
  202. session: The OAuth session object
  203. Returns:
  204. dict: New token data, or None if refresh failed
  205. """
  206. provider = session.provider
  207. token_data = session.token
  208. if not token_data.get("refresh_token"):
  209. log.warning(f"No refresh token available for session {session.id}")
  210. return None
  211. try:
  212. client = self.get_client(provider)
  213. if not client:
  214. log.error(f"No OAuth client found for provider {provider}")
  215. return None
  216. token_endpoint = None
  217. async with aiohttp.ClientSession(trust_env=True) as session_http:
  218. async with session_http.get(client.gserver_metadata_url) as r:
  219. if r.status == 200:
  220. openid_data = await r.json()
  221. token_endpoint = openid_data.get("token_endpoint")
  222. else:
  223. log.error(
  224. f"Failed to fetch OpenID configuration for provider {provider}"
  225. )
  226. if not token_endpoint:
  227. log.error(f"No token endpoint found for provider {provider}")
  228. return None
  229. # Prepare refresh request
  230. refresh_data = {
  231. "grant_type": "refresh_token",
  232. "refresh_token": token_data["refresh_token"],
  233. "client_id": client.client_id,
  234. }
  235. # Add client_secret if available (some providers require it)
  236. if hasattr(client, "client_secret") and client.client_secret:
  237. refresh_data["client_secret"] = client.client_secret
  238. # Make refresh request
  239. async with aiohttp.ClientSession(trust_env=True) as session_http:
  240. async with session_http.post(
  241. token_endpoint,
  242. data=refresh_data,
  243. headers={"Content-Type": "application/x-www-form-urlencoded"},
  244. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  245. ) as r:
  246. if r.status == 200:
  247. new_token_data = await r.json()
  248. # Merge with existing token data (preserve refresh_token if not provided)
  249. if "refresh_token" not in new_token_data:
  250. new_token_data["refresh_token"] = token_data[
  251. "refresh_token"
  252. ]
  253. # Add timestamp for tracking
  254. new_token_data["issued_at"] = datetime.now().timestamp()
  255. # Calculate expires_at if we have expires_in
  256. if (
  257. "expires_in" in new_token_data
  258. and "expires_at" not in new_token_data
  259. ):
  260. new_token_data["expires_at"] = (
  261. datetime.now().timestamp()
  262. + new_token_data["expires_in"]
  263. )
  264. log.debug(f"Token refresh successful for provider {provider}")
  265. return new_token_data
  266. else:
  267. error_text = await r.text()
  268. log.error(
  269. f"Token refresh failed for provider {provider}: {r.status} - {error_text}"
  270. )
  271. return None
  272. except Exception as e:
  273. log.error(f"Exception during token refresh for provider {provider}: {e}")
  274. return None
  275. def get_user_role(self, user, user_data):
  276. user_count = Users.get_num_users()
  277. if user and user_count == 1:
  278. # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
  279. log.debug("Assigning the only user the admin role")
  280. return "admin"
  281. if not user and user_count == 0:
  282. # If there are no users, assign the role "admin", as the first user will be an admin
  283. log.debug("Assigning the first user the admin role")
  284. return "admin"
  285. if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
  286. log.debug("Running OAUTH Role management")
  287. oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
  288. oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
  289. oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
  290. oauth_roles = []
  291. # Default/fallback role if no matching roles are found
  292. role = auth_manager_config.DEFAULT_USER_ROLE
  293. # Next block extracts the roles from the user data, accepting nested claims of any depth
  294. if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
  295. claim_data = user_data
  296. nested_claims = oauth_claim.split(".")
  297. for nested_claim in nested_claims:
  298. claim_data = claim_data.get(nested_claim, {})
  299. oauth_roles = []
  300. if isinstance(claim_data, list):
  301. oauth_roles = claim_data
  302. if isinstance(claim_data, str) or isinstance(claim_data, int):
  303. oauth_roles = [str(claim_data)]
  304. log.debug(f"Oauth Roles claim: {oauth_claim}")
  305. log.debug(f"User roles from oauth: {oauth_roles}")
  306. log.debug(f"Accepted user roles: {oauth_allowed_roles}")
  307. log.debug(f"Accepted admin roles: {oauth_admin_roles}")
  308. # If any roles are found, check if they match the allowed or admin roles
  309. if oauth_roles:
  310. # If role management is enabled, and matching roles are provided, use the roles
  311. for allowed_role in oauth_allowed_roles:
  312. # If the user has any of the allowed roles, assign the role "user"
  313. if allowed_role in oauth_roles:
  314. log.debug("Assigned user the user role")
  315. role = "user"
  316. break
  317. for admin_role in oauth_admin_roles:
  318. # If the user has any of the admin roles, assign the role "admin"
  319. if admin_role in oauth_roles:
  320. log.debug("Assigned user the admin role")
  321. role = "admin"
  322. break
  323. else:
  324. if not user:
  325. # If role management is disabled, use the default role for new users
  326. role = auth_manager_config.DEFAULT_USER_ROLE
  327. else:
  328. # If role management is disabled, use the existing role for existing users
  329. role = user.role
  330. return role
  331. def update_user_groups(self, user, user_data, default_permissions):
  332. log.debug("Running OAUTH Group management")
  333. oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
  334. try:
  335. blocked_groups = json.loads(auth_manager_config.OAUTH_BLOCKED_GROUPS)
  336. except Exception as e:
  337. log.exception(f"Error loading OAUTH_BLOCKED_GROUPS: {e}")
  338. blocked_groups = []
  339. user_oauth_groups = []
  340. # Nested claim search for groups claim
  341. if oauth_claim:
  342. claim_data = user_data
  343. nested_claims = oauth_claim.split(".")
  344. for nested_claim in nested_claims:
  345. claim_data = claim_data.get(nested_claim, {})
  346. if isinstance(claim_data, list):
  347. user_oauth_groups = claim_data
  348. elif isinstance(claim_data, str):
  349. user_oauth_groups = [claim_data]
  350. else:
  351. user_oauth_groups = []
  352. user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
  353. all_available_groups: list[GroupModel] = Groups.get_groups()
  354. # Create groups if they don't exist and creation is enabled
  355. if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
  356. log.debug("Checking for missing groups to create...")
  357. all_group_names = {g.name for g in all_available_groups}
  358. groups_created = False
  359. # Determine creator ID: Prefer admin, fallback to current user if no admin exists
  360. admin_user = Users.get_super_admin_user()
  361. creator_id = admin_user.id if admin_user else user.id
  362. log.debug(f"Using creator ID {creator_id} for potential group creation.")
  363. for group_name in user_oauth_groups:
  364. if group_name not in all_group_names:
  365. log.info(
  366. f"Group '{group_name}' not found via OAuth claim. Creating group..."
  367. )
  368. try:
  369. new_group_form = GroupForm(
  370. name=group_name,
  371. description=f"Group '{group_name}' created automatically via OAuth.",
  372. permissions=default_permissions, # Use default permissions from function args
  373. user_ids=[], # Start with no users, user will be added later by subsequent logic
  374. )
  375. # Use determined creator ID (admin or fallback to current user)
  376. created_group = Groups.insert_new_group(
  377. creator_id, new_group_form
  378. )
  379. if created_group:
  380. log.info(
  381. f"Successfully created group '{group_name}' with ID {created_group.id} using creator ID {creator_id}"
  382. )
  383. groups_created = True
  384. # Add to local set to prevent duplicate creation attempts in this run
  385. all_group_names.add(group_name)
  386. else:
  387. log.error(
  388. f"Failed to create group '{group_name}' via OAuth."
  389. )
  390. except Exception as e:
  391. log.error(f"Error creating group '{group_name}' via OAuth: {e}")
  392. # Refresh the list of all available groups if any were created
  393. if groups_created:
  394. all_available_groups = Groups.get_groups()
  395. log.debug("Refreshed list of all available groups after creation.")
  396. log.debug(f"Oauth Groups claim: {oauth_claim}")
  397. log.debug(f"User oauth groups: {user_oauth_groups}")
  398. log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
  399. log.debug(
  400. f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
  401. )
  402. # Remove groups that user is no longer a part of
  403. for group_model in user_current_groups:
  404. if (
  405. user_oauth_groups
  406. and group_model.name not in user_oauth_groups
  407. and not is_in_blocked_groups(group_model.name, blocked_groups)
  408. ):
  409. # Remove group from user
  410. log.debug(
  411. f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
  412. )
  413. user_ids = group_model.user_ids
  414. user_ids = [i for i in user_ids if i != user.id]
  415. # In case a group is created, but perms are never assigned to the group by hitting "save"
  416. group_permissions = group_model.permissions
  417. if not group_permissions:
  418. group_permissions = default_permissions
  419. update_form = GroupUpdateForm(
  420. name=group_model.name,
  421. description=group_model.description,
  422. permissions=group_permissions,
  423. user_ids=user_ids,
  424. )
  425. Groups.update_group_by_id(
  426. id=group_model.id, form_data=update_form, overwrite=False
  427. )
  428. # Add user to new groups
  429. for group_model in all_available_groups:
  430. if (
  431. user_oauth_groups
  432. and group_model.name in user_oauth_groups
  433. and not any(gm.name == group_model.name for gm in user_current_groups)
  434. and not is_in_blocked_groups(group_model.name, blocked_groups)
  435. ):
  436. # Add user to group
  437. log.debug(
  438. f"Adding user to group {group_model.name} as it was found in their oauth groups"
  439. )
  440. user_ids = group_model.user_ids
  441. user_ids.append(user.id)
  442. # In case a group is created, but perms are never assigned to the group by hitting "save"
  443. group_permissions = group_model.permissions
  444. if not group_permissions:
  445. group_permissions = default_permissions
  446. update_form = GroupUpdateForm(
  447. name=group_model.name,
  448. description=group_model.description,
  449. permissions=group_permissions,
  450. user_ids=user_ids,
  451. )
  452. Groups.update_group_by_id(
  453. id=group_model.id, form_data=update_form, overwrite=False
  454. )
  455. async def _process_picture_url(
  456. self, picture_url: str, access_token: str = None
  457. ) -> str:
  458. """Process a picture URL and return a base64 encoded data URL.
  459. Args:
  460. picture_url: The URL of the picture to process
  461. access_token: Optional OAuth access token for authenticated requests
  462. Returns:
  463. A data URL containing the base64 encoded picture, or "/user.png" if processing fails
  464. """
  465. if not picture_url:
  466. return "/user.png"
  467. try:
  468. get_kwargs = {}
  469. if access_token:
  470. get_kwargs["headers"] = {
  471. "Authorization": f"Bearer {access_token}",
  472. }
  473. async with aiohttp.ClientSession(trust_env=True) as session:
  474. async with session.get(
  475. picture_url, **get_kwargs, ssl=AIOHTTP_CLIENT_SESSION_SSL
  476. ) as resp:
  477. if resp.ok:
  478. picture = await resp.read()
  479. base64_encoded_picture = base64.b64encode(picture).decode(
  480. "utf-8"
  481. )
  482. guessed_mime_type = mimetypes.guess_type(picture_url)[0]
  483. if guessed_mime_type is None:
  484. guessed_mime_type = "image/jpeg"
  485. return (
  486. f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
  487. )
  488. else:
  489. log.warning(
  490. f"Failed to fetch profile picture from {picture_url}"
  491. )
  492. return "/user.png"
  493. except Exception as e:
  494. log.error(f"Error processing profile picture '{picture_url}': {e}")
  495. return "/user.png"
  496. async def handle_login(self, request, provider):
  497. if provider not in OAUTH_PROVIDERS:
  498. raise HTTPException(404)
  499. # If the provider has a custom redirect URL, use that, otherwise automatically generate one
  500. redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
  501. "oauth_callback", provider=provider
  502. )
  503. client = self.get_client(provider)
  504. if client is None:
  505. raise HTTPException(404)
  506. return await client.authorize_redirect(request, redirect_uri)
  507. async def handle_callback(self, request, provider, response):
  508. if provider not in OAUTH_PROVIDERS:
  509. raise HTTPException(404)
  510. error_message = None
  511. try:
  512. client = self.get_client(provider)
  513. try:
  514. token = await client.authorize_access_token(request)
  515. except Exception as e:
  516. log.warning(f"OAuth callback error: {e}")
  517. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  518. # Try to get userinfo from the token first, some providers include it there
  519. user_data: UserInfo = token.get("userinfo")
  520. if (
  521. (not user_data)
  522. or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data)
  523. or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
  524. ):
  525. user_data: UserInfo = await client.userinfo(token=token)
  526. if provider == "feishu" and isinstance(user_data, dict) and "data" in user_data:
  527. user_data = user_data["data"]
  528. if not user_data:
  529. log.warning(f"OAuth callback failed, user data is missing: {token}")
  530. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  531. # Extract the "sub" claim, using custom claim if configured
  532. if auth_manager_config.OAUTH_SUB_CLAIM:
  533. sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
  534. else:
  535. # Fallback to the default sub claim if not configured
  536. sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
  537. if not sub:
  538. log.warning(f"OAuth callback failed, sub is missing: {user_data}")
  539. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  540. provider_sub = f"{provider}@{sub}"
  541. # Email extraction
  542. email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
  543. email = user_data.get(email_claim, "")
  544. # We currently mandate that email addresses are provided
  545. if not email:
  546. # If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email
  547. if provider == "github":
  548. try:
  549. access_token = token.get("access_token")
  550. headers = {"Authorization": f"Bearer {access_token}"}
  551. async with aiohttp.ClientSession(trust_env=True) as session:
  552. async with session.get(
  553. "https://api.github.com/user/emails",
  554. headers=headers,
  555. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  556. ) as resp:
  557. if resp.ok:
  558. emails = await resp.json()
  559. # use the primary email as the user's email
  560. primary_email = next(
  561. (
  562. e["email"]
  563. for e in emails
  564. if e.get("primary")
  565. ),
  566. None,
  567. )
  568. if primary_email:
  569. email = primary_email
  570. else:
  571. log.warning(
  572. "No primary email found in GitHub response"
  573. )
  574. raise HTTPException(
  575. 400, detail=ERROR_MESSAGES.INVALID_CRED
  576. )
  577. else:
  578. log.warning("Failed to fetch GitHub email")
  579. raise HTTPException(
  580. 400, detail=ERROR_MESSAGES.INVALID_CRED
  581. )
  582. except Exception as e:
  583. log.warning(f"Error fetching GitHub email: {e}")
  584. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  585. else:
  586. log.warning(f"OAuth callback failed, email is missing: {user_data}")
  587. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  588. email = email.lower()
  589. # If allowed domains are configured, check if the email domain is in the list
  590. if (
  591. "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
  592. and email.split("@")[-1]
  593. not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
  594. ):
  595. log.warning(
  596. f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}"
  597. )
  598. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  599. # Check if the user exists
  600. user = Users.get_user_by_oauth_sub(provider_sub)
  601. if not user:
  602. # If the user does not exist, check if merging is enabled
  603. if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
  604. # Check if the user exists by email
  605. user = Users.get_user_by_email(email)
  606. if user:
  607. # Update the user with the new oauth sub
  608. Users.update_user_oauth_sub_by_id(user.id, provider_sub)
  609. if user:
  610. determined_role = self.get_user_role(user, user_data)
  611. if user.role != determined_role:
  612. Users.update_user_role_by_id(user.id, determined_role)
  613. # Update profile picture if enabled and different from current
  614. if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
  615. picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
  616. if picture_claim:
  617. new_picture_url = user_data.get(
  618. picture_claim,
  619. OAUTH_PROVIDERS[provider].get("picture_url", ""),
  620. )
  621. processed_picture_url = await self._process_picture_url(
  622. new_picture_url, token.get("access_token")
  623. )
  624. if processed_picture_url != user.profile_image_url:
  625. Users.update_user_profile_image_url_by_id(
  626. user.id, processed_picture_url
  627. )
  628. log.debug(f"Updated profile picture for user {user.email}")
  629. else:
  630. # If the user does not exist, check if signups are enabled
  631. if auth_manager_config.ENABLE_OAUTH_SIGNUP:
  632. # Check if an existing user with the same email already exists
  633. existing_user = Users.get_user_by_email(email)
  634. if existing_user:
  635. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  636. picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
  637. if picture_claim:
  638. picture_url = user_data.get(
  639. picture_claim,
  640. OAUTH_PROVIDERS[provider].get("picture_url", ""),
  641. )
  642. picture_url = await self._process_picture_url(
  643. picture_url, token.get("access_token")
  644. )
  645. else:
  646. picture_url = "/user.png"
  647. username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
  648. name = user_data.get(username_claim)
  649. if not name:
  650. log.warning("Username claim is missing, using email as name")
  651. name = email
  652. user = Auths.insert_new_auth(
  653. email=email,
  654. password=get_password_hash(
  655. str(uuid.uuid4())
  656. ), # Random password, not used
  657. name=name,
  658. profile_image_url=picture_url,
  659. role=self.get_user_role(None, user_data),
  660. oauth_sub=provider_sub,
  661. )
  662. if auth_manager_config.WEBHOOK_URL:
  663. await post_webhook(
  664. WEBUI_NAME,
  665. auth_manager_config.WEBHOOK_URL,
  666. WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  667. {
  668. "action": "signup",
  669. "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  670. "user": user.model_dump_json(exclude_none=True),
  671. },
  672. )
  673. else:
  674. raise HTTPException(
  675. status.HTTP_403_FORBIDDEN,
  676. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  677. )
  678. jwt_token = create_token(
  679. data={"id": user.id},
  680. expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
  681. )
  682. if (
  683. auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT
  684. and user.role != "admin"
  685. ):
  686. self.update_user_groups(
  687. user=user,
  688. user_data=user_data,
  689. default_permissions=request.app.state.config.USER_PERMISSIONS,
  690. )
  691. except Exception as e:
  692. log.error(f"Error during OAuth process: {e}")
  693. error_message = (
  694. e.detail
  695. if isinstance(e, HTTPException) and e.detail
  696. else ERROR_MESSAGES.DEFAULT("Error during OAuth process")
  697. )
  698. redirect_base_url = str(request.app.state.config.WEBUI_URL or request.base_url)
  699. if redirect_base_url.endswith("/"):
  700. redirect_base_url = redirect_base_url[:-1]
  701. redirect_url = f"{redirect_base_url}/auth"
  702. if error_message:
  703. redirect_url = f"{redirect_url}?error={error_message}"
  704. return RedirectResponse(url=redirect_url, headers=response.headers)
  705. response = RedirectResponse(url=redirect_url, headers=response.headers)
  706. # Set the cookie token
  707. # Redirect back to the frontend with the JWT token
  708. response.set_cookie(
  709. key="token",
  710. value=jwt_token,
  711. httponly=False, # Required for frontend access
  712. samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
  713. secure=WEBUI_AUTH_COOKIE_SECURE,
  714. )
  715. # Legacy cookies for compatibility with older frontend versions
  716. if ENABLE_OAUTH_ID_TOKEN_COOKIE:
  717. response.set_cookie(
  718. key="oauth_id_token",
  719. value=token.get("id_token"),
  720. httponly=True,
  721. samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
  722. secure=WEBUI_AUTH_COOKIE_SECURE,
  723. )
  724. try:
  725. # Add timestamp for tracking
  726. token["issued_at"] = datetime.now().timestamp()
  727. # Calculate expires_at if we have expires_in
  728. if "expires_in" in token and "expires_at" not in token:
  729. token["expires_at"] = datetime.now().timestamp() + token["expires_in"]
  730. # Clean up any existing sessions for this user/provider first
  731. sessions = OAuthSessions.get_sessions_by_user_id(user.id)
  732. for session in sessions:
  733. if session.provider == provider:
  734. OAuthSessions.delete_session_by_id(session.id)
  735. session = OAuthSessions.create_session(
  736. user_id=user.id,
  737. provider=provider,
  738. token=token,
  739. )
  740. response.set_cookie(
  741. key="oauth_session_id",
  742. value=session.id,
  743. httponly=True,
  744. samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
  745. secure=WEBUI_AUTH_COOKIE_SECURE,
  746. )
  747. log.info(
  748. f"Stored OAuth session server-side for user {user.id}, provider {provider}"
  749. )
  750. except Exception as e:
  751. log.error(f"Failed to store OAuth session server-side: {e}")
  752. return response