|
|
@@ -1218,6 +1218,10 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
|
|
|
|
|
|
app.state.MODELS = {}
|
|
|
|
|
|
+# Add the middleware to the app
|
|
|
+if ENABLE_COMPRESSION_MIDDLEWARE:
|
|
|
+ app.add_middleware(CompressMiddleware)
|
|
|
+
|
|
|
|
|
|
class RedirectMiddleware(BaseHTTPMiddleware):
|
|
|
async def dispatch(self, request: Request, call_next):
|
|
|
@@ -1259,14 +1263,47 @@ class RedirectMiddleware(BaseHTTPMiddleware):
|
|
|
return response
|
|
|
|
|
|
|
|
|
-# Add the middleware to the app
|
|
|
-if ENABLE_COMPRESSION_MIDDLEWARE:
|
|
|
- app.add_middleware(CompressMiddleware)
|
|
|
-
|
|
|
app.add_middleware(RedirectMiddleware)
|
|
|
app.add_middleware(SecurityHeadersMiddleware)
|
|
|
|
|
|
|
|
|
+class APIKeyRestrictionMiddleware(BaseHTTPMiddleware):
|
|
|
+ async def dispatch(self, request: Request, call_next):
|
|
|
+ auth_header = request.headers.get("Authorization")
|
|
|
+
|
|
|
+ # Only apply restrictions if an sk- API key is used
|
|
|
+ if auth_header and auth_header.startswith("sk-"):
|
|
|
+ # Check if restrictions are enabled
|
|
|
+ if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
|
|
|
+ allowed_paths = [
|
|
|
+ path.strip()
|
|
|
+ for path in str(
|
|
|
+ request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
|
|
|
+ ).split(",")
|
|
|
+ if path.strip()
|
|
|
+ ]
|
|
|
+
|
|
|
+ request_path = request.url.path
|
|
|
+
|
|
|
+ # Match exact path or prefix path
|
|
|
+ is_allowed = any(
|
|
|
+ request_path == allowed or request_path.startswith(allowed + "/")
|
|
|
+ for allowed in allowed_paths
|
|
|
+ )
|
|
|
+
|
|
|
+ if not is_allowed:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_403_FORBIDDEN,
|
|
|
+ detail="API key not allowed to access this endpoint.",
|
|
|
+ )
|
|
|
+
|
|
|
+ response = await call_next(request)
|
|
|
+ return response
|
|
|
+
|
|
|
+
|
|
|
+app.add_middleware(APIKeyRestrictionMiddleware)
|
|
|
+
|
|
|
+
|
|
|
@app.middleware("http")
|
|
|
async def commit_session_after_request(request: Request, call_next):
|
|
|
response = await call_next(request)
|