浏览代码

Merge pull request #13581 from kaytwo/refreshOauthPfp

feat: refresh oauth profile picture
Tim Jaeryang Baek 5 月之前
父节点
当前提交
410af53eca
共有 2 个文件被更改,包括 70 次插入34 次删除
  1. 6 0
      backend/open_webui/config.py
  2. 64 34
      backend/open_webui/utils/oauth.py

+ 6 - 0
backend/open_webui/config.py

@@ -552,6 +552,12 @@ OAUTH_ALLOWED_DOMAINS = PersistentConfig(
     ],
 )
 
+OAUTH_UPDATE_PICTURE_ON_LOGIN = PersistentConfig(
+    "OAUTH_UPDATE_PICTURE_ON_LOGIN",
+    "oauth.update_picture_on_login",
+    os.environ.get("OAUTH_UPDATE_PICTURE_ON_LOGIN", "False").lower() == "true",
+)
+
 
 def load_oauth_providers():
     OAUTH_PROVIDERS.clear()

+ 64 - 34
backend/open_webui/utils/oauth.py

@@ -34,6 +34,7 @@ from open_webui.config import (
     OAUTH_ALLOWED_ROLES,
     OAUTH_ADMIN_ROLES,
     OAUTH_ALLOWED_DOMAINS,
+    OAUTH_UPDATE_PICTURE_ON_LOGIN,
     WEBHOOK_URL,
     JWT_EXPIRES_IN,
     AppConfig,
@@ -72,6 +73,7 @@ auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
 auth_manager_config.OAUTH_ALLOWED_DOMAINS = OAUTH_ALLOWED_DOMAINS
 auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
 auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
+auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN = OAUTH_UPDATE_PICTURE_ON_LOGIN
 
 
 class OAuthManager:
@@ -282,6 +284,49 @@ class OAuthManager:
                     id=group_model.id, form_data=update_form, overwrite=False
                 )
 
+    async def _process_picture_url(
+        self, picture_url: str, access_token: str = None
+    ) -> str:
+        """Process a picture URL and return a base64 encoded data URL.
+
+        Args:
+            picture_url: The URL of the picture to process
+            access_token: Optional OAuth access token for authenticated requests
+
+        Returns:
+            A data URL containing the base64 encoded picture, or "/user.png" if processing fails
+        """
+        if not picture_url:
+            return "/user.png"
+
+        try:
+            get_kwargs = {}
+            if access_token:
+                get_kwargs["headers"] = {
+                    "Authorization": f"Bearer {access_token}",
+                }
+            async with aiohttp.ClientSession() as session:
+                async with session.get(picture_url, **get_kwargs) as resp:
+                    if resp.ok:
+                        picture = await resp.read()
+                        base64_encoded_picture = base64.b64encode(picture).decode(
+                            "utf-8"
+                        )
+                        guessed_mime_type = mimetypes.guess_type(picture_url)[0]
+                        if guessed_mime_type is None:
+                            guessed_mime_type = "image/jpeg"
+                        return (
+                            f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
+                        )
+                    else:
+                        log.warning(
+                            f"Failed to fetch profile picture from {picture_url}"
+                        )
+                        return "/user.png"
+        except Exception as e:
+            log.error(f"Error processing profile picture '{picture_url}': {e}")
+            return "/user.png"
+
     async def handle_login(self, request, provider):
         if provider not in OAUTH_PROVIDERS:
             raise HTTPException(404)
@@ -382,6 +427,22 @@ class OAuthManager:
             if user.role != determined_role:
                 Users.update_user_role_by_id(user.id, determined_role)
 
+            # Update profile picture if enabled and different from current
+            if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
+                picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
+                if picture_claim:
+                    new_picture_url = user_data.get(
+                        picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "")
+                    )
+                    processed_picture_url = await self._process_picture_url(
+                        new_picture_url, token.get("access_token")
+                    )
+                    if processed_picture_url != user.profile_image_url:
+                        Users.update_user_profile_image_url_by_id(
+                            user.id, processed_picture_url
+                        )
+                        log.debug(f"Updated profile picture for user {user.email}")
+
         if not user:
             user_count = Users.get_num_users()
 
@@ -397,40 +458,9 @@ class OAuthManager:
                     picture_url = user_data.get(
                         picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "")
                     )
-                    if picture_url:
-                        # Download the profile image into a base64 string
-                        try:
-                            access_token = token.get("access_token")
-                            get_kwargs = {}
-                            if access_token:
-                                get_kwargs["headers"] = {
-                                    "Authorization": f"Bearer {access_token}",
-                                }
-                            async with aiohttp.ClientSession(trust_env=True) as session:
-                                async with session.get(
-                                    picture_url, **get_kwargs
-                                ) as resp:
-                                    if resp.ok:
-                                        picture = await resp.read()
-                                        base64_encoded_picture = base64.b64encode(
-                                            picture
-                                        ).decode("utf-8")
-                                        guessed_mime_type = mimetypes.guess_type(
-                                            picture_url
-                                        )[0]
-                                        if guessed_mime_type is None:
-                                            # assume JPG, browsers are tolerant enough of image formats
-                                            guessed_mime_type = "image/jpeg"
-                                        picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
-                                    else:
-                                        picture_url = "/user.png"
-                        except Exception as e:
-                            log.error(
-                                f"Error downloading profile image '{picture_url}': {e}"
-                            )
-                            picture_url = "/user.png"
-                    if not picture_url:
-                        picture_url = "/user.png"
+                    picture_url = await self._process_picture_url(
+                        picture_url, token.get("access_token")
+                    )
                 else:
                     picture_url = "/user.png"