# audit.py
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass
from enum import Enum
import re
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Dict,
    MutableMapping,
    Optional,
    cast,
)
import uuid

from asgiref.typing import (
    ASGI3Application,
    ASGIReceiveCallable,
    ASGIReceiveEvent,
    ASGISendCallable,
    ASGISendEvent,
    Scope as ASGIScope,
)
from loguru import logger
from starlette.requests import Request

from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE
from open_webui.utils.auth import get_current_user, get_http_authorization_cred
from open_webui.models.users import UserModel

if TYPE_CHECKING:
    from loguru import Logger


  30. @dataclass(frozen=True)
  31. class AuditLogEntry:
  32. # `Metadata` audit level properties
  33. id: str
  34. user: Optional[dict[str, Any]]
  35. audit_level: str
  36. verb: str
  37. request_uri: str
  38. user_agent: Optional[str] = None
  39. source_ip: Optional[str] = None
  40. # `Request` audit level properties
  41. request_object: Any = None
  42. # `Request Response` level
  43. response_object: Any = None
  44. response_status_code: Optional[int] = None
  45. class AuditLevel(str, Enum):
  46. NONE = "NONE"
  47. METADATA = "METADATA"
  48. REQUEST = "REQUEST"
  49. REQUEST_RESPONSE = "REQUEST_RESPONSE"
  50. class AuditLogger:
  51. """
  52. A helper class that encapsulates audit logging functionality. It uses Loguru’s logger with an auditable binding to ensure that audit log entries are filtered correctly.
  53. Parameters:
  54. logger (Logger): An instance of Loguru’s logger.
  55. """
  56. def __init__(self, logger: "Logger"):
  57. self.logger = logger.bind(auditable=True)
  58. def write(
  59. self,
  60. audit_entry: AuditLogEntry,
  61. *,
  62. log_level: str = "INFO",
  63. extra: Optional[dict] = None,
  64. ):
  65. entry = asdict(audit_entry)
  66. if extra:
  67. entry["extra"] = extra
  68. self.logger.log(
  69. log_level,
  70. "",
  71. **entry,
  72. )
  73. class AuditContext:
  74. """
  75. Captures and aggregates the HTTP request and response bodies during the processing of a request. It ensures that only a configurable maximum amount of data is stored to prevent excessive memory usage.
  76. Attributes:
  77. request_body (bytearray): Accumulated request payload.
  78. response_body (bytearray): Accumulated response payload.
  79. max_body_size (int): Maximum number of bytes to capture.
  80. metadata (Dict[str, Any]): A dictionary to store additional audit metadata (user, http verb, user agent, etc.).
  81. """
  82. def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
  83. self.request_body = bytearray()
  84. self.response_body = bytearray()
  85. self.max_body_size = max_body_size
  86. self.metadata: Dict[str, Any] = {}
  87. def add_request_chunk(self, chunk: bytes):
  88. if len(self.request_body) < self.max_body_size:
  89. self.request_body.extend(
  90. chunk[: self.max_body_size - len(self.request_body)]
  91. )
  92. def add_response_chunk(self, chunk: bytes):
  93. if len(self.response_body) < self.max_body_size:
  94. self.response_body.extend(
  95. chunk[: self.max_body_size - len(self.response_body)]
  96. )
  97. class AuditLoggingMiddleware:
  98. """
  99. ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle.
  100. """
  101. AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}
  102. def __init__(
  103. self,
  104. app: ASGI3Application,
  105. *,
  106. excluded_paths: Optional[list[str]] = None,
  107. max_body_size: int = MAX_BODY_LOG_SIZE,
  108. audit_level: AuditLevel = AuditLevel.NONE,
  109. ) -> None:
  110. self.app = app
  111. self.audit_logger = AuditLogger(logger)
  112. self.excluded_paths = excluded_paths or []
  113. self.max_body_size = max_body_size
  114. self.audit_level = audit_level
  115. async def __call__(
  116. self,
  117. scope: ASGIScope,
  118. receive: ASGIReceiveCallable,
  119. send: ASGISendCallable,
  120. ) -> None:
  121. if scope["type"] != "http":
  122. return await self.app(scope, receive, send)
  123. request = Request(scope=cast(MutableMapping, scope))
  124. if self._should_skip_auditing(request):
  125. return await self.app(scope, receive, send)
  126. async with self._audit_context(request) as context:
  127. async def send_wrapper(message: ASGISendEvent) -> None:
  128. if self.audit_level == AuditLevel.REQUEST_RESPONSE:
  129. await self._capture_response(message, context)
  130. await send(message)
  131. original_receive = receive
  132. async def receive_wrapper() -> ASGIReceiveEvent:
  133. nonlocal original_receive
  134. message = await original_receive()
  135. if self.audit_level in (
  136. AuditLevel.REQUEST,
  137. AuditLevel.REQUEST_RESPONSE,
  138. ):
  139. await self._capture_request(message, context)
  140. return message
  141. await self.app(scope, receive_wrapper, send_wrapper)
  142. @asynccontextmanager
  143. async def _audit_context(
  144. self, request: Request
  145. ) -> AsyncGenerator[AuditContext, None]:
  146. """
  147. async context manager that ensures that an audit log entry is recorded after the request is processed.
  148. """
  149. context = AuditContext()
  150. try:
  151. yield context
  152. finally:
  153. await self._log_audit_entry(request, context)
  154. async def _get_authenticated_user(self, request: Request) -> Optional[UserModel]:
  155. auth_header = request.headers.get("Authorization")
  156. try:
  157. user = get_current_user(
  158. request, None, get_http_authorization_cred(auth_header)
  159. )
  160. return user
  161. except Exception as e:
  162. logger.debug(f"Failed to get authenticated user: {str(e)}")
  163. return None
  164. def _should_skip_auditing(self, request: Request) -> bool:
  165. if (
  166. request.method not in {"POST", "PUT", "PATCH", "DELETE"}
  167. or AUDIT_LOG_LEVEL == "NONE"
  168. ):
  169. return True
  170. ALWAYS_LOG_ENDPOINTS = {
  171. "/api/v1/auths/signin",
  172. "/api/v1/auths/signout",
  173. "/api/v1/auths/signup",
  174. }
  175. path = request.url.path.lower()
  176. for endpoint in ALWAYS_LOG_ENDPOINTS:
  177. if path.startswith(endpoint):
  178. return False # Do NOT skip logging for auth endpoints
  179. # Skip logging if the request is not authenticated
  180. if not request.headers.get("authorization"):
  181. return True
  182. # match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/...
  183. pattern = re.compile(
  184. r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b"
  185. )
  186. if pattern.match(request.url.path):
  187. return True
  188. return False
  189. async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
  190. if message["type"] == "http.request":
  191. body = message.get("body", b"")
  192. context.add_request_chunk(body)
  193. async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
  194. if message["type"] == "http.response.start":
  195. context.metadata["response_status_code"] = message["status"]
  196. elif message["type"] == "http.response.body":
  197. body = message.get("body", b"")
  198. context.add_response_chunk(body)
  199. async def _log_audit_entry(self, request: Request, context: AuditContext):
  200. try:
  201. user = await self._get_authenticated_user(request)
  202. user = (
  203. user.model_dump(include={"id", "name", "email", "role"}) if user else {}
  204. )
  205. request_body = context.request_body.decode("utf-8", errors="replace")
  206. response_body = context.response_body.decode("utf-8", errors="replace")
  207. # Redact sensitive information
  208. if "password" in request_body:
  209. request_body = re.sub(
  210. r'"password":\s*"(.*?)"',
  211. '"password": "********"',
  212. request_body,
  213. )
  214. entry = AuditLogEntry(
  215. id=str(uuid.uuid4()),
  216. user=user,
  217. audit_level=self.audit_level.value,
  218. verb=request.method,
  219. request_uri=str(request.url),
  220. response_status_code=context.metadata.get("response_status_code", None),
  221. source_ip=request.client.host if request.client else None,
  222. user_agent=request.headers.get("user-agent"),
  223. request_object=request_body,
  224. response_object=response_body,
  225. )
  226. self.audit_logger.write(entry)
  227. except Exception as e:
  228. logger.error(f"Failed to log audit entry: {str(e)}")