Browse Source

enh: failed login attempts audit log

Timothy Jaeryang Baek 3 months ago
parent
commit
f2314596ba
1 changed files with 44 additions and 10 deletions
  1. 44 10
      backend/open_webui/utils/audit.py

+ 44 - 10
backend/open_webui/utils/audit.py

@@ -37,7 +37,7 @@ if TYPE_CHECKING:
 class AuditLogEntry:
     # `Metadata` audit level properties
     id: str
-    user: dict[str, Any]
+    user: Optional[dict[str, Any]]
     audit_level: str
     verb: str
     request_uri: str
@@ -190,21 +190,40 @@ class AuditLoggingMiddleware:
         finally:
             await self._log_audit_entry(request, context)
 
-    async def _get_authenticated_user(self, request: Request) -> UserModel:
-
+    async def _get_authenticated_user(self, request: Request) -> Optional[UserModel]:
         auth_header = request.headers.get("Authorization")
-        assert auth_header
-        user = get_current_user(request, None, get_http_authorization_cred(auth_header))
 
-        return user
+        try:
+            user = get_current_user(
+                request, None, get_http_authorization_cred(auth_header)
+            )
+            return user
+        except Exception as e:
+            logger.debug(f"Failed to get authenticated user: {str(e)}")
+
+        return None
 
     def _should_skip_auditing(self, request: Request) -> bool:
         if (
             request.method not in {"POST", "PUT", "PATCH", "DELETE"}
             or AUDIT_LOG_LEVEL == "NONE"
-            or not request.headers.get("authorization")
         ):
             return True
+
+        ALWAYS_LOG_ENDPOINTS = {
+            "/api/v1/auths/signin",
+            "/api/v1/auths/signout",
+            "/api/v1/auths/signup",
+        }
+        path = request.url.path.lower()
+        for endpoint in ALWAYS_LOG_ENDPOINTS:
+            if path.startswith(endpoint):
+                return False  # Do NOT skip logging for auth endpoints
+
+        # Skip logging if the request is not authenticated
+        if not request.headers.get("authorization"):
+            return True
+
         # match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/...
         pattern = re.compile(
             r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b"
@@ -231,17 +250,32 @@ class AuditLoggingMiddleware:
         try:
             user = await self._get_authenticated_user(request)
 
+            user = (
+                user.model_dump(include={"id", "name", "email", "role"}) if user else {}
+            )
+
+            request_body = context.request_body.decode("utf-8", errors="replace")
+            response_body = context.response_body.decode("utf-8", errors="replace")
+
+            # Redact sensitive information
+            if "password" in request_body:
+                request_body = re.sub(
+                    r'"password":\s*"(.*?)"',
+                    '"password": "********"',
+                    request_body,
+                )
+
             entry = AuditLogEntry(
                 id=str(uuid.uuid4()),
-                user=user.model_dump(include={"id", "name", "email", "role"}),
+                user=user,
                 audit_level=self.audit_level.value,
                 verb=request.method,
                 request_uri=str(request.url),
                 response_status_code=context.metadata.get("response_status_code", None),
                 source_ip=request.client.host if request.client else None,
                 user_agent=request.headers.get("user-agent"),
-                request_object=context.request_body.decode("utf-8", errors="replace"),
-                response_object=context.response_body.decode("utf-8", errors="replace"),
+                request_object=request_body,
+                response_object=response_body,
             )
 
             self.audit_logger.write(entry)