|
@@ -401,185 +401,207 @@ class OAuthManager:
|
|
|
async def handle_callback(self, request, provider, response):
|
|
|
if provider not in OAUTH_PROVIDERS:
|
|
|
raise HTTPException(404)
|
|
|
- client = self.get_client(provider)
|
|
|
+
|
|
|
+ error_message = None
|
|
|
try:
|
|
|
- token = await client.authorize_access_token(request)
|
|
|
- except Exception as e:
|
|
|
- log.warning(f"OAuth callback error: {e}")
|
|
|
- raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
- user_data: UserInfo = token.get("userinfo")
|
|
|
- if (
|
|
|
- (not user_data)
|
|
|
- or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data)
|
|
|
- or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
|
|
|
- ):
|
|
|
- user_data: UserInfo = await client.userinfo(token=token)
|
|
|
- if not user_data:
|
|
|
- log.warning(f"OAuth callback failed, user data is missing: {token}")
|
|
|
- raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
+ client = self.get_client(provider)
|
|
|
+ try:
|
|
|
+ token = await client.authorize_access_token(request)
|
|
|
+ except Exception as e:
|
|
|
+ log.warning(f"OAuth callback error: {e}")
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
+ user_data: UserInfo = token.get("userinfo")
|
|
|
+ if (
|
|
|
+ (not user_data)
|
|
|
+ or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data)
|
|
|
+ or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
|
|
|
+ ):
|
|
|
+ user_data: UserInfo = await client.userinfo(token=token)
|
|
|
+ if not user_data:
|
|
|
+ log.warning(f"OAuth callback failed, user data is missing: {token}")
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
|
|
|
- if auth_manager_config.OAUTH_SUB_CLAIM:
|
|
|
- sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
|
|
|
- else:
|
|
|
- # Fallback to the default sub claim if not configured
|
|
|
- sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
|
|
|
-
|
|
|
- if not sub:
|
|
|
- log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
|
|
- raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
-
|
|
|
- provider_sub = f"{provider}@{sub}"
|
|
|
-
|
|
|
- email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
|
|
- email = user_data.get(email_claim, "")
|
|
|
- # We currently mandate that email addresses are provided
|
|
|
- if not email:
|
|
|
- # If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email
|
|
|
- if provider == "github":
|
|
|
- try:
|
|
|
- access_token = token.get("access_token")
|
|
|
- headers = {"Authorization": f"Bearer {access_token}"}
|
|
|
- async with aiohttp.ClientSession(trust_env=True) as session:
|
|
|
- async with session.get(
|
|
|
- "https://api.github.com/user/emails",
|
|
|
- headers=headers,
|
|
|
- ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
|
|
- ) as resp:
|
|
|
- if resp.ok:
|
|
|
- emails = await resp.json()
|
|
|
- # use the primary email as the user's email
|
|
|
- primary_email = next(
|
|
|
- (e["email"] for e in emails if e.get("primary")),
|
|
|
- None,
|
|
|
- )
|
|
|
- if primary_email:
|
|
|
- email = primary_email
|
|
|
- else:
|
|
|
- log.warning(
|
|
|
- "No primary email found in GitHub response"
|
|
|
+ if auth_manager_config.OAUTH_SUB_CLAIM:
|
|
|
+ sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
|
|
|
+ else:
|
|
|
+ # Fallback to the default sub claim if not configured
|
|
|
+ sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
|
|
|
+
|
|
|
+ if not sub:
|
|
|
+ log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
+
|
|
|
+ provider_sub = f"{provider}@{sub}"
|
|
|
+
|
|
|
+ email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
|
|
+ email = user_data.get(email_claim, "")
|
|
|
+ # We currently mandate that email addresses are provided
|
|
|
+ if not email:
|
|
|
+ # If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email
|
|
|
+ if provider == "github":
|
|
|
+ try:
|
|
|
+ access_token = token.get("access_token")
|
|
|
+ headers = {"Authorization": f"Bearer {access_token}"}
|
|
|
+ async with aiohttp.ClientSession(trust_env=True) as session:
|
|
|
+ async with session.get(
|
|
|
+ "https://api.github.com/user/emails",
|
|
|
+ headers=headers,
|
|
|
+ ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
|
|
+ ) as resp:
|
|
|
+ if resp.ok:
|
|
|
+ emails = await resp.json()
|
|
|
+ # use the primary email as the user's email
|
|
|
+ primary_email = next(
|
|
|
+ (
|
|
|
+ e["email"]
|
|
|
+ for e in emails
|
|
|
+ if e.get("primary")
|
|
|
+ ),
|
|
|
+ None,
|
|
|
)
|
|
|
+ if primary_email:
|
|
|
+ email = primary_email
|
|
|
+ else:
|
|
|
+ log.warning(
|
|
|
+ "No primary email found in GitHub response"
|
|
|
+ )
|
|
|
+ raise HTTPException(
|
|
|
+ 400, detail=ERROR_MESSAGES.INVALID_CRED
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ log.warning("Failed to fetch GitHub email")
|
|
|
raise HTTPException(
|
|
|
400, detail=ERROR_MESSAGES.INVALID_CRED
|
|
|
)
|
|
|
- else:
|
|
|
- log.warning("Failed to fetch GitHub email")
|
|
|
- raise HTTPException(
|
|
|
- 400, detail=ERROR_MESSAGES.INVALID_CRED
|
|
|
- )
|
|
|
- except Exception as e:
|
|
|
- log.warning(f"Error fetching GitHub email: {e}")
|
|
|
+ except Exception as e:
|
|
|
+ log.warning(f"Error fetching GitHub email: {e}")
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
+ else:
|
|
|
+ log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
- else:
|
|
|
- log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
|
|
+ email = email.lower()
|
|
|
+ if (
|
|
|
+ "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
|
|
+ and email.split("@")[-1]
|
|
|
+ not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
|
|
+ ):
|
|
|
+ log.warning(
|
|
|
+ f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}"
|
|
|
+ )
|
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
- email = email.lower()
|
|
|
- if (
|
|
|
- "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
|
|
- and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
|
|
- ):
|
|
|
- log.warning(
|
|
|
- f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}"
|
|
|
- )
|
|
|
- raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
-
|
|
|
- # Check if the user exists
|
|
|
- user = Users.get_user_by_oauth_sub(provider_sub)
|
|
|
-
|
|
|
- if not user:
|
|
|
- # If the user does not exist, check if merging is enabled
|
|
|
- if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
|
|
- # Check if the user exists by email
|
|
|
- user = Users.get_user_by_email(email)
|
|
|
- if user:
|
|
|
- # Update the user with the new oauth sub
|
|
|
- Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
|
|
-
|
|
|
- if user:
|
|
|
- determined_role = self.get_user_role(user, user_data)
|
|
|
- if user.role != determined_role:
|
|
|
- Users.update_user_role_by_id(user.id, determined_role)
|
|
|
-
|
|
|
- # Update profile picture if enabled and different from current
|
|
|
- if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
|
|
|
- picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
|
|
- if picture_claim:
|
|
|
- new_picture_url = user_data.get(
|
|
|
- picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "")
|
|
|
- )
|
|
|
- processed_picture_url = await self._process_picture_url(
|
|
|
- new_picture_url, token.get("access_token")
|
|
|
- )
|
|
|
- if processed_picture_url != user.profile_image_url:
|
|
|
- Users.update_user_profile_image_url_by_id(
|
|
|
- user.id, processed_picture_url
|
|
|
+
|
|
|
+ # Check if the user exists
|
|
|
+ user = Users.get_user_by_oauth_sub(provider_sub)
|
|
|
+
|
|
|
+ if not user:
|
|
|
+ # If the user does not exist, check if merging is enabled
|
|
|
+ if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
|
|
+ # Check if the user exists by email
|
|
|
+ user = Users.get_user_by_email(email)
|
|
|
+ if user:
|
|
|
+ # Update the user with the new oauth sub
|
|
|
+ Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
|
|
+
|
|
|
+ if user:
|
|
|
+ determined_role = self.get_user_role(user, user_data)
|
|
|
+ if user.role != determined_role:
|
|
|
+ Users.update_user_role_by_id(user.id, determined_role)
|
|
|
+
|
|
|
+ # Update profile picture if enabled and different from current
|
|
|
+ if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
|
|
|
+ picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
|
|
+ if picture_claim:
|
|
|
+ new_picture_url = user_data.get(
|
|
|
+ picture_claim,
|
|
|
+ OAUTH_PROVIDERS[provider].get("picture_url", ""),
|
|
|
)
|
|
|
- log.debug(f"Updated profile picture for user {user.email}")
|
|
|
-
|
|
|
- if not user:
|
|
|
- # If the user does not exist, check if signups are enabled
|
|
|
- if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
|
|
- # Check if an existing user with the same email already exists
|
|
|
- existing_user = Users.get_user_by_email(email)
|
|
|
- if existing_user:
|
|
|
- raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
|
|
-
|
|
|
- picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
|
|
- if picture_claim:
|
|
|
- picture_url = user_data.get(
|
|
|
- picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "")
|
|
|
- )
|
|
|
- picture_url = await self._process_picture_url(
|
|
|
- picture_url, token.get("access_token")
|
|
|
+ processed_picture_url = await self._process_picture_url(
|
|
|
+ new_picture_url, token.get("access_token")
|
|
|
+ )
|
|
|
+ if processed_picture_url != user.profile_image_url:
|
|
|
+ Users.update_user_profile_image_url_by_id(
|
|
|
+ user.id, processed_picture_url
|
|
|
+ )
|
|
|
+ log.debug(f"Updated profile picture for user {user.email}")
|
|
|
+
|
|
|
+ if not user:
|
|
|
+ # If the user does not exist, check if signups are enabled
|
|
|
+ if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
|
|
+ # Check if an existing user with the same email already exists
|
|
|
+ existing_user = Users.get_user_by_email(email)
|
|
|
+ if existing_user:
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
|
|
+
|
|
|
+ picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
|
|
+ if picture_claim:
|
|
|
+ picture_url = user_data.get(
|
|
|
+ picture_claim,
|
|
|
+ OAUTH_PROVIDERS[provider].get("picture_url", ""),
|
|
|
+ )
|
|
|
+ picture_url = await self._process_picture_url(
|
|
|
+ picture_url, token.get("access_token")
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ picture_url = "/user.png"
|
|
|
+
|
|
|
+ username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
|
|
|
+
|
|
|
+ name = user_data.get(username_claim)
|
|
|
+ if not name:
|
|
|
+ log.warning("Username claim is missing, using email as name")
|
|
|
+ name = email
|
|
|
+
|
|
|
+ role = self.get_user_role(None, user_data)
|
|
|
+
|
|
|
+ user = Auths.insert_new_auth(
|
|
|
+ email=email,
|
|
|
+ password=get_password_hash(
|
|
|
+ str(uuid.uuid4())
|
|
|
+ ), # Random password, not used
|
|
|
+ name=name,
|
|
|
+ profile_image_url=picture_url,
|
|
|
+ role=role,
|
|
|
+ oauth_sub=provider_sub,
|
|
|
)
|
|
|
- else:
|
|
|
- picture_url = "/user.png"
|
|
|
-
|
|
|
- username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
|
|
|
-
|
|
|
- name = user_data.get(username_claim)
|
|
|
- if not name:
|
|
|
- log.warning("Username claim is missing, using email as name")
|
|
|
- name = email
|
|
|
-
|
|
|
- role = self.get_user_role(None, user_data)
|
|
|
-
|
|
|
- user = Auths.insert_new_auth(
|
|
|
- email=email,
|
|
|
- password=get_password_hash(
|
|
|
- str(uuid.uuid4())
|
|
|
- ), # Random password, not used
|
|
|
- name=name,
|
|
|
- profile_image_url=picture_url,
|
|
|
- role=role,
|
|
|
- oauth_sub=provider_sub,
|
|
|
- )
|
|
|
|
|
|
- if auth_manager_config.WEBHOOK_URL:
|
|
|
- await post_webhook(
|
|
|
- WEBUI_NAME,
|
|
|
- auth_manager_config.WEBHOOK_URL,
|
|
|
- WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
|
- {
|
|
|
- "action": "signup",
|
|
|
- "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
|
- "user": user.model_dump_json(exclude_none=True),
|
|
|
- },
|
|
|
+ if auth_manager_config.WEBHOOK_URL:
|
|
|
+ await post_webhook(
|
|
|
+ WEBUI_NAME,
|
|
|
+ auth_manager_config.WEBHOOK_URL,
|
|
|
+ WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
|
+ {
|
|
|
+ "action": "signup",
|
|
|
+ "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
|
+ "user": user.model_dump_json(exclude_none=True),
|
|
|
+ },
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ raise HTTPException(
|
|
|
+ status.HTTP_403_FORBIDDEN,
|
|
|
+ detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
|
)
|
|
|
- else:
|
|
|
- raise HTTPException(
|
|
|
- status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
|
|
- )
|
|
|
|
|
|
- jwt_token = create_token(
|
|
|
- data={"id": user.id},
|
|
|
- expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
|
|
|
- )
|
|
|
+ jwt_token = create_token(
|
|
|
+ data={"id": user.id},
|
|
|
+ expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
|
|
|
+ )
|
|
|
|
|
|
- if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT and user.role != "admin":
|
|
|
- self.update_user_groups(
|
|
|
- user=user,
|
|
|
- user_data=user_data,
|
|
|
- default_permissions=request.app.state.config.USER_PERMISSIONS,
|
|
|
+ if (
|
|
|
+ auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT
|
|
|
+ and user.role != "admin"
|
|
|
+ ):
|
|
|
+ self.update_user_groups(
|
|
|
+ user=user,
|
|
|
+ user_data=user_data,
|
|
|
+ default_permissions=request.app.state.config.USER_PERMISSIONS,
|
|
|
+ )
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error during OAuth process: {e}")
|
|
|
+ error_message = (
|
|
|
+ e.detail
|
|
|
+ if isinstance(e, HTTPException) and e.detail
|
|
|
+ else ERROR_MESSAGES.DEFAULT("Error during OAuth process")
|
|
|
)
|
|
|
|
|
|
redirect_base_url = str(request.app.state.config.WEBUI_URL or request.base_url)
|
|
@@ -587,6 +609,10 @@ class OAuthManager:
|
|
|
redirect_base_url = redirect_base_url[:-1]
|
|
|
redirect_url = f"{redirect_base_url}/auth"
|
|
|
|
|
|
+ if error_message:
|
|
|
+ redirect_url = f"{redirect_url}?error={error_message}"
|
|
|
+ return RedirectResponse(url=redirect_url, headers=response.headers)
|
|
|
+
|
|
|
response = RedirectResponse(url=redirect_url, headers=response.headers)
|
|
|
|
|
|
# Set the cookie token
|