Просмотр исходного кода

refac: decouple api key restrictions from get user

Timothy Jaeryang Baek 3 месяцев назад
Родитель
Сommit
b160eef7eb
2 измененных файлов с 41 добавлено и 23 удалено
  1. 41 4
      backend/open_webui/main.py
  2. 0 19
      backend/open_webui/utils/auth.py

+ 41 - 4
backend/open_webui/main.py

@@ -1218,6 +1218,10 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
 
 app.state.MODELS = {}
 
+# Add the middleware to the app
+if ENABLE_COMPRESSION_MIDDLEWARE:
+    app.add_middleware(CompressMiddleware)
+
 
 class RedirectMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
@@ -1259,14 +1263,47 @@ class RedirectMiddleware(BaseHTTPMiddleware):
         return response
 
 
-# Add the middleware to the app
-if ENABLE_COMPRESSION_MIDDLEWARE:
-    app.add_middleware(CompressMiddleware)
-
 app.add_middleware(RedirectMiddleware)
 app.add_middleware(SecurityHeadersMiddleware)
 
 
+class APIKeyRestrictionMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        auth_header = request.headers.get("Authorization")
+
+        # Only apply restrictions if an sk- API key is used
+        if auth_header and auth_header.startswith("sk-"):
+            # Check if restrictions are enabled
+            if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
+                allowed_paths = [
+                    path.strip()
+                    for path in str(
+                        request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
+                    ).split(",")
+                    if path.strip()
+                ]
+
+                request_path = request.url.path
+
+                # Match exact path or prefix path
+                is_allowed = any(
+                    request_path == allowed or request_path.startswith(allowed + "/")
+                    for allowed in allowed_paths
+                )
+
+                if not is_allowed:
+                    raise HTTPException(
+                        status_code=status.HTTP_403_FORBIDDEN,
+                        detail="API key not allowed to access this endpoint.",
+                    )
+
+        response = await call_next(request)
+        return response
+
+
+app.add_middleware(APIKeyRestrictionMiddleware)
+
+
 @app.middleware("http")
 async def commit_session_after_request(request: Request, call_next):
     response = await call_next(request)

+ 0 - 19
backend/open_webui/utils/auth.py

@@ -233,24 +233,6 @@ def get_current_user(
                 status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
             )
 
-        if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
-            allowed_paths = [
-                path.strip()
-                for path in str(
-                    request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
-                ).split(",")
-            ]
-
-            # Check if the request path matches any allowed endpoint.
-            if not any(
-                request.url.path == allowed
-                or request.url.path.startswith(allowed + "/")
-                for allowed in allowed_paths
-            ):
-                raise HTTPException(
-                    status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
-                )
-
         user = get_current_user_by_api_key(token)
 
         # Add user info to current span
@@ -260,7 +242,6 @@ def get_current_user(
             current_span.set_attribute("client.user.email", user.email)
             current_span.set_attribute("client.user.role", user.role)
             current_span.set_attribute("client.auth.type", "api_key")
-
         return user
 
     # auth by jwt token