|
@@ -126,24 +126,24 @@ except Exception as e:
|
|
|
raise
|
|
|
|
|
|
|
|
|
-def encrypt_token(token) -> str:
|
|
|
- """Encrypt OAuth tokens for storage"""
|
|
|
+def encrypt_data(data) -> str:
|
|
|
+ """Encrypt data for storage"""
|
|
|
try:
|
|
|
- token_json = json.dumps(token)
|
|
|
- encrypted = FERNET.encrypt(token_json.encode()).decode()
|
|
|
+ data_json = json.dumps(data)
|
|
|
+ encrypted = FERNET.encrypt(data_json.encode()).decode()
|
|
|
return encrypted
|
|
|
except Exception as e:
|
|
|
- log.error(f"Error encrypting tokens: {e}")
|
|
|
+ log.error(f"Error encrypting data: {e}")
|
|
|
raise
|
|
|
|
|
|
|
|
|
-def decrypt_token(token: str):
|
|
|
- """Decrypt OAuth tokens from storage"""
|
|
|
+def decrypt_data(data: str):
|
|
|
+ """Decrypt data from storage"""
|
|
|
try:
|
|
|
- decrypted = FERNET.decrypt(token.encode()).decode()
|
|
|
+ decrypted = FERNET.decrypt(data.encode()).decode()
|
|
|
return json.loads(decrypted)
|
|
|
except Exception as e:
|
|
|
- log.error(f"Error decrypting tokens: {e}")
|
|
|
+ log.error(f"Error decrypting data: {e}")
|
|
|
raise
|
|
|
|
|
|
|
|
@@ -212,7 +212,10 @@ def get_discovery_urls(server_url) -> list[str]:
|
|
|
# TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration.
|
|
|
# This is not currently supported.
|
|
|
async def get_oauth_client_info_with_dynamic_client_registration(
|
|
|
- request, oauth_server_url, oauth_server_key: Optional[str] = None
|
|
|
+ request,
|
|
|
+ client_id: str,
|
|
|
+ oauth_server_url: str,
|
|
|
+ oauth_server_key: Optional[str] = None,
|
|
|
) -> OAuthClientInformationFull:
|
|
|
try:
|
|
|
oauth_server_metadata = None
|
|
@@ -221,9 +224,10 @@ async def get_oauth_client_info_with_dynamic_client_registration(
|
|
|
redirect_base_url = (
|
|
|
str(request.app.state.config.WEBUI_URL or request.base_url)
|
|
|
).rstrip("/")
|
|
|
+
|
|
|
oauth_client_metadata = OAuthClientMetadata(
|
|
|
client_name="Open WebUI",
|
|
|
- redirect_uris=[f"{redirect_base_url}/oauth/callback"],
|
|
|
+ redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
|
|
|
grant_types=["authorization_code", "refresh_token"],
|
|
|
response_types=["code"],
|
|
|
token_endpoint_auth_method="client_secret_post",
|
|
@@ -315,23 +319,22 @@ class OAuthClientManager:
|
|
|
self.clients = {}
|
|
|
|
|
|
def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull):
|
|
|
- if client_id not in self.clients:
|
|
|
- self.clients[client_id] = {
|
|
|
- "client": self.oauth.register(
|
|
|
- name=client_id,
|
|
|
- client_id=oauth_client_info.client_id,
|
|
|
- client_secret=oauth_client_info.client_secret,
|
|
|
- client_kwargs=(
|
|
|
- {"scope": oauth_client_info.scope}
|
|
|
- if oauth_client_info.scope
|
|
|
- else {}
|
|
|
- ),
|
|
|
- server_metadata_url=(
|
|
|
- oauth_client_info.issuer if oauth_client_info.issuer else None
|
|
|
- ),
|
|
|
+ self.clients[client_id] = {
|
|
|
+ "client": self.oauth.register(
|
|
|
+ name=client_id,
|
|
|
+ client_id=oauth_client_info.client_id,
|
|
|
+ client_secret=oauth_client_info.client_secret,
|
|
|
+ client_kwargs=(
|
|
|
+ {"scope": oauth_client_info.scope}
|
|
|
+ if oauth_client_info.scope
|
|
|
+ else {}
|
|
|
),
|
|
|
- "client_info": oauth_client_info,
|
|
|
- }
|
|
|
+ server_metadata_url=(
|
|
|
+ oauth_client_info.issuer if oauth_client_info.issuer else None
|
|
|
+ ),
|
|
|
+ ),
|
|
|
+ "client_info": oauth_client_info,
|
|
|
+ }
|
|
|
return self.clients[client_id]
|
|
|
|
|
|
def remove_client(self, client_id):
|
|
@@ -359,7 +362,7 @@ class OAuthClientManager:
|
|
|
return None
|
|
|
|
|
|
async def get_oauth_token(
|
|
|
- self, user_id: str, session_id: str, force_refresh: bool = False
|
|
|
+ self, user_id: str, client_id: str, force_refresh: bool = False
|
|
|
):
|
|
|
"""
|
|
|
Get a valid OAuth token for the user, automatically refreshing if needed.
|
|
@@ -374,10 +377,12 @@ class OAuthClientManager:
|
|
|
"""
|
|
|
try:
|
|
|
# Get the OAuth session
|
|
|
- session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
|
|
|
+ session = OAuthSessions.get_session_by_provider_and_user_id(
|
|
|
+ client_id, user_id
|
|
|
+ )
|
|
|
if not session:
|
|
|
log.warning(
|
|
|
- f"No OAuth session found for user {user_id}, session {session_id}"
|
|
|
+ f"No OAuth session found for user {user_id}, client_id {client_id}"
|
|
|
)
|
|
|
return None
|
|
|
|
|
@@ -392,8 +397,9 @@ class OAuthClientManager:
|
|
|
return refreshed_token
|
|
|
else:
|
|
|
log.warning(
|
|
|
- f"Token refresh failed for user {user_id}, client_id {session.provider}"
|
|
|
+ f"Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}"
|
|
|
)
|
|
|
+ OAuthSessions.delete_session_by_id(session.id)
|
|
|
return None
|
|
|
return session.token
|
|
|
|
|
@@ -533,7 +539,7 @@ class OAuthClientManager:
|
|
|
redirect_uri = (
|
|
|
client_info.redirect_uris[0] if client_info.redirect_uris else None
|
|
|
)
|
|
|
- return await client.authorize_redirect(request, redirect_uri)
|
|
|
+ return await client.authorize_redirect(request, str(redirect_uri))
|
|
|
|
|
|
async def handle_callback(self, request, client_id: str, user_id: str, response):
|
|
|
client = self.get_client(client_id)
|
|
@@ -565,7 +571,6 @@ class OAuthClientManager:
|
|
|
provider=client_id,
|
|
|
token=token,
|
|
|
)
|
|
|
-
|
|
|
log.info(
|
|
|
f"Stored OAuth session server-side for user {user_id}, client_id {client_id}"
|
|
|
)
|
|
@@ -579,16 +584,17 @@ class OAuthClientManager:
|
|
|
error_message = "OAuth callback error"
|
|
|
log.warning(f"OAuth callback error: {e}")
|
|
|
|
|
|
- redirect_base_url = (
|
|
|
+ redirect_url = (
|
|
|
str(request.app.state.config.WEBUI_URL or request.base_url)
|
|
|
).rstrip("/")
|
|
|
- redirect_url = f"{redirect_base_url}/auth"
|
|
|
|
|
|
if error_message:
|
|
|
- redirect_url = f"{redirect_url}?error={error_message}"
|
|
|
+ log.debug(error_message)
|
|
|
+ redirect_url = f"{redirect_url}/?error={error_message}"
|
|
|
return RedirectResponse(url=redirect_url, headers=response.headers)
|
|
|
|
|
|
response = RedirectResponse(url=redirect_url, headers=response.headers)
|
|
|
+ return response
|
|
|
|
|
|
|
|
|
class OAuthManager:
|
|
@@ -649,8 +655,10 @@ class OAuthManager:
|
|
|
return refreshed_token
|
|
|
else:
|
|
|
log.warning(
|
|
|
- f"Token refresh failed for user {user_id}, provider {session.provider}"
|
|
|
+ f"Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}"
|
|
|
)
|
|
|
+ OAuthSessions.delete_session_by_id(session.id)
|
|
|
+
|
|
|
return None
|
|
|
return session.token
|
|
|
|