Browse Source

fix: oauth token

Timothy Jaeryang Baek 2 weeks ago
parent
commit
e4c4ba0979

+ 1 - 1
backend/open_webui/functions.py

@@ -239,7 +239,7 @@ async def generate_function_chat_completion(
     oauth_token = None
     try:
         if request.cookies.get("oauth_session_id", None):
-            oauth_token = request.app.state.oauth_manager.get_oauth_token(
+            oauth_token = await request.app.state.oauth_manager.get_oauth_token(
                 user.id,
                 request.cookies.get("oauth_session_id", None),
             )

+ 10 - 8
backend/open_webui/routers/openai.py

@@ -121,7 +121,7 @@ def openai_reasoning_model_handler(payload):
     return payload
 
 
-def get_headers_and_cookies(
+async def get_headers_and_cookies(
     request: Request,
     url,
     key=None,
@@ -174,7 +174,7 @@ def get_headers_and_cookies(
         oauth_token = None
         try:
             if request.cookies.get("oauth_session_id", None):
-                oauth_token = request.app.state.oauth_manager.get_oauth_token(
+                oauth_token = await request.app.state.oauth_manager.get_oauth_token(
                     user.id,
                     request.cookies.get("oauth_session_id", None),
                 )
@@ -305,7 +305,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             request.app.state.config.OPENAI_API_CONFIGS.get(url, {}),  # Legacy support
         )
 
-        headers, cookies = get_headers_and_cookies(
+        headers, cookies = await get_headers_and_cookies(
             request, url, key, api_config, user=user
         )
 
@@ -570,7 +570,7 @@ async def get_models(
             timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
         ) as session:
             try:
-                headers, cookies = get_headers_and_cookies(
+                headers, cookies = await get_headers_and_cookies(
                     request, url, key, api_config, user=user
                 )
 
@@ -656,7 +656,7 @@ async def verify_connection(
         timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
     ) as session:
         try:
-            headers, cookies = get_headers_and_cookies(
+            headers, cookies = await get_headers_and_cookies(
                 request, url, key, api_config, user=user
             )
 
@@ -901,7 +901,7 @@ async def generate_chat_completion(
             convert_logit_bias_input_to_json(payload["logit_bias"])
         )
 
-    headers, cookies = get_headers_and_cookies(
+    headers, cookies = await get_headers_and_cookies(
         request, url, key, api_config, metadata, user=user
     )
 
@@ -1010,7 +1010,9 @@ async def embeddings(request: Request, form_data: dict, user):
     session = None
     streaming = False
 
-    headers, cookies = get_headers_and_cookies(request, url, key, api_config, user=user)
+    headers, cookies = await get_headers_and_cookies(
+        request, url, key, api_config, user=user
+    )
     try:
         session = aiohttp.ClientSession(trust_env=True)
         r = await session.request(
@@ -1080,7 +1082,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     streaming = False
 
     try:
-        headers, cookies = get_headers_and_cookies(
+        headers, cookies = await get_headers_and_cookies(
             request, url, key, api_config, user=user
         )
 

+ 2 - 2
backend/open_webui/utils/middleware.py

@@ -818,7 +818,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     oauth_token = None
     try:
         if request.cookies.get("oauth_session_id", None):
-            oauth_token = request.app.state.oauth_manager.get_oauth_token(
+            oauth_token = await request.app.state.oauth_manager.get_oauth_token(
                 user.id,
                 request.cookies.get("oauth_session_id", None),
             )
@@ -1498,7 +1498,7 @@ async def process_chat_response(
     oauth_token = None
     try:
         if request.cookies.get("oauth_session_id", None):
-            oauth_token = request.app.state.oauth_manager.get_oauth_token(
+            oauth_token = await request.app.state.oauth_manager.get_oauth_token(
                 user.id,
                 request.cookies.get("oauth_session_id", None),
             )

+ 2 - 2
backend/open_webui/utils/oauth.py

@@ -157,7 +157,7 @@ class OAuthManager:
             )
         return None
 
-    def get_oauth_token(
+    async def get_oauth_token(
         self, user_id: str, session_id: str, force_refresh: bool = False
     ):
         """
@@ -186,7 +186,7 @@ class OAuthManager:
                 log.debug(
                     f"Token refresh needed for user {user_id}, provider {session.provider}"
                 )
-                refreshed_token = self._refresh_token(session)
+                refreshed_token = await self._refresh_token(session)
                 if refreshed_token:
                     return refreshed_token
                 else: