Timothy Jaeryang Baek 3 月之前
父節點
當前提交
22af53f60c
共有 1 個文件被更改,包括 6 次插入4 次删除
  1. 6 4
      backend/open_webui/config.py

+ 6 - 4
backend/open_webui/config.py

@@ -13,6 +13,8 @@ from urllib.parse import urlparse
 import requests
 import requests
 from pydantic import BaseModel
 from pydantic import BaseModel
 from sqlalchemy import JSON, Column, DateTime, Integer, func
 from sqlalchemy import JSON, Column, DateTime, Integer, func
+from authlib.integrations.starlette_client import OAuth
+
 
 
 from open_webui.env import (
 from open_webui.env import (
     DATA_DIR,
     DATA_DIR,
@@ -546,7 +548,7 @@ def load_oauth_providers():
     OAUTH_PROVIDERS.clear()
     OAUTH_PROVIDERS.clear()
     if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:
     if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:
 
 
-        def google_oauth_register(client):
+        def google_oauth_register(client: OAuth):
             client.register(
             client.register(
                 name="google",
                 name="google",
                 client_id=GOOGLE_CLIENT_ID.value,
                 client_id=GOOGLE_CLIENT_ID.value,
@@ -574,7 +576,7 @@ def load_oauth_providers():
         and MICROSOFT_CLIENT_TENANT_ID.value
         and MICROSOFT_CLIENT_TENANT_ID.value
     ):
     ):
 
 
-        def microsoft_oauth_register(client):
+        def microsoft_oauth_register(client: OAuth):
             client.register(
             client.register(
                 name="microsoft",
                 name="microsoft",
                 client_id=MICROSOFT_CLIENT_ID.value,
                 client_id=MICROSOFT_CLIENT_ID.value,
@@ -599,7 +601,7 @@ def load_oauth_providers():
 
 
     if GITHUB_CLIENT_ID.value and GITHUB_CLIENT_SECRET.value:
     if GITHUB_CLIENT_ID.value and GITHUB_CLIENT_SECRET.value:
 
 
-        def github_oauth_register(client):
+        def github_oauth_register(client: OAuth):
             client.register(
             client.register(
                 name="github",
                 name="github",
                 client_id=GITHUB_CLIENT_ID.value,
                 client_id=GITHUB_CLIENT_ID.value,
@@ -631,7 +633,7 @@ def load_oauth_providers():
         and OPENID_PROVIDER_URL.value
         and OPENID_PROVIDER_URL.value
     ):
     ):
 
 
-        def oidc_oauth_register(client):
+        def oidc_oauth_register(client: OAuth):
             client_kwargs = {
             client_kwargs = {
                 "scope": OAUTH_SCOPES.value,
                 "scope": OAUTH_SCOPES.value,
                 **(
                 **(