"""
Audit logging middleware.

Records POST/PUT/DELETE requests to the audit_logs table.
"""

import json
import time
import logging

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response

from app.database import async_session
from app.models import AuditLog

logger = logging.getLogger(__name__)

# Methods to audit
_AUDIT_METHODS = {"POST", "PUT", "DELETE"}

# Paths to skip (noisy or non-business endpoints)
_SKIP_PREFIXES = ("/ws", "/health", "/docs", "/redoc", "/openapi.json")

# Max request body size to store (bytes)
_MAX_BODY_SIZE = 4096

# Request-body keys whose values are replaced before storage. Compared against
# the lower-cased key at every nesting level of the parsed JSON payload.
_SENSITIVE_KEYS = frozenset({
    "password",
    "new_password",
    "old_password",
    "current_password",
    "api_key",
    "key",
    "secret",
    "client_secret",
    "token",
    "access_token",
    "refresh_token",
    "authorization",
})

_REDACTED = "***REDACTED***"


def _get_client_ip(request: Request) -> str:
    """Extract the client IP, preferring proxy-supplied headers.

    Order of preference: first non-empty entry of ``X-Forwarded-For``, then
    ``CF-Connecting-IP``, then the socket peer address, else ``"unknown"``.
    """
    # NOTE(review): both headers can be set by the client unless a trusted
    # proxy strips/overwrites them, so the recorded IP is only as reliable as
    # the deployment's proxy configuration — confirm.
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        first_hop = forwarded.split(",")[0].strip()
        if first_hop:
            return first_hop
    cf_ip = request.headers.get("CF-Connecting-IP")
    if cf_ip:
        return cf_ip
    return request.client.host if request.client else "unknown"


def _get_operator(request: Request) -> str | None:
    """Return the operator name from ``request.state.key_info`` (set by verify_api_key).

    Returns None when ``key_info`` is absent or is not a dict.
    """
    key_info = getattr(request.state, "key_info", None)
    if key_info and isinstance(key_info, dict):
        return key_info.get("name")
    return None


def _should_audit(method: str, path: str) -> bool:
    """Return True if a request with this method and path should be audited."""
    if method not in _AUDIT_METHODS:
        return False
    if path.startswith(_SKIP_PREFIXES):
        return False
    # Only audit /api/ routes
    return path.startswith("/api/")


def _redact(payload: object) -> object:
    """Overwrite sensitive values in a parsed JSON payload, in place.

    Walks nested dicts and lists with an explicit stack (no recursion), and
    replaces the value of every key whose lower-cased form is in
    ``_SENSITIVE_KEYS`` with ``_REDACTED``. Returns ``payload``.
    """
    pending = [payload]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            for field, value in node.items():
                if isinstance(field, str) and field.lower() in _SENSITIVE_KEYS:
                    # Replacing a value (not a key) is safe during iteration.
                    node[field] = _REDACTED
                elif isinstance(value, (dict, list)):
                    pending.append(value)
        elif isinstance(node, list):
            pending.extend(item for item in node if isinstance(item, (dict, list)))
    return payload


def _capture_request_body(body_bytes: bytes) -> object | None:
    """Parse and redact a request body for storage.

    Returns None for an empty body, a body larger than ``_MAX_BODY_SIZE``, or
    a body that is not parseable JSON.
    """
    if not body_bytes or len(body_bytes) > _MAX_BODY_SIZE:
        return None
    try:
        parsed = json.loads(body_bytes)
    except (ValueError, RecursionError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError.
        # RecursionError is raised for pathologically nested input and must
        # not escape, or the audited request itself would fail.
        return None
    return _redact(parsed)


async def _write_audit_log(
    request: Request,
    path: str,
    status_code: int,
    request_body: object | None,
    duration_ms: int,
) -> None:
    """Persist one AuditLog row. Best-effort: failures are logged, not raised."""
    # Read after the handler ran: the operator is populated by the
    # verify_api_key dependency during request handling.
    operator = _get_operator(request)
    try:
        async with async_session() as session:
            async with session.begin():
                session.add(AuditLog(
                    method=request.method,
                    path=path,
                    status_code=status_code,
                    operator=operator,
                    client_ip=_get_client_ip(request),
                    request_body=request_body,
                    response_summary=f"HTTP {status_code}",
                    duration_ms=duration_ms,
                ))
    except Exception:
        # Never fail the client's request because auditing failed, but make
        # lost audit records visible to operators, with a traceback.
        logger.warning(
            "Failed to write audit log for %s %s", request.method, path, exc_info=True
        )


class AuditMiddleware(BaseHTTPMiddleware):
    """Record one AuditLog row per audited POST/PUT/DELETE ``/api/`` request.

    The response from the downstream app is always returned unchanged; the
    audit write happens after the handler completes.
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        path = request.url.path
        if not _should_audit(request.method, path):
            return await call_next(request)

        # Read request body for audit (Starlette caches it for downstream).
        # NOTE(review): relies on the installed Starlette version replaying the
        # cached body to the downstream app from BaseHTTPMiddleware — confirm.
        request_body = _capture_request_body(await request.body())

        start = time.monotonic()
        response = await call_next(request)
        duration_ms = int((time.monotonic() - start) * 1000)

        await _write_audit_log(request, path, response.status_code, request_body, duration_ms)
        return response