Files
desungongpai/app/middleware.py

102 lines
3.4 KiB
Python
Raw Permalink Normal View History

"""
Audit logging middleware.

Records POST/PUT/DELETE requests to ``/api/`` routes in the audit_logs table
(via the ``AuditLog`` model). Auditing is best-effort: if writing the audit
row fails, the failure is logged and the request's response is still returned.
"""
import json
import time
import logging
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from app.database import async_session
from app.models import AuditLog
logger = logging.getLogger(__name__)

# HTTP methods whose requests receive an audit record.
# NOTE(review): PATCH is not in this set, so PATCH requests are never
# audited — confirm that is intended.
_AUDIT_METHODS = {"DELETE", "POST", "PUT"}

# Path prefixes that are never audited (noisy or non-business endpoints).
# NOTE(review): AuditMiddleware only audits paths starting with "/api/",
# and none of these prefixes do, so this list is currently redundant.
_SKIP_PREFIXES = (
    "/ws",
    "/health",
    "/docs",
    "/redoc",
    "/openapi.json",
)

# Request bodies longer than this many bytes are not stored in the audit log.
_MAX_BODY_SIZE = 4 * 1024
def _get_client_ip(request: Request) -> str:
    """Best-effort client address for the audit record.

    Sources are tried in order:

    1. the left-most entry of ``X-Forwarded-For`` (skipped if blank),
    2. ``CF-Connecting-IP``,
    3. the socket peer address.

    NOTE(review): both headers arrive with the request and can be set to any
    value by the client unless a proxy in front of this app overwrites them.
    Treat the result as informational, not as an authenticated identity;
    confirm the deployment's proxy setup before relying on it.

    Returns:
        The chosen address, or ``"unknown"`` when no source provides one.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        # A malformed header such as " " or ", 1.2.3.4" has a blank first
        # entry; fall through to the other sources rather than record "".
        first_hop = forwarded.split(",")[0].strip()
        if first_hop:
            return first_hop
    cf_ip = request.headers.get("CF-Connecting-IP")
    if cf_ip:
        return cf_ip
    return request.client.host if request.client else "unknown"
def _get_operator(request: Request) -> str | None:
    """Look up the caller's name from ``request.state.key_info``.

    ``key_info`` is presumably attached by the ``verify_api_key`` dependency.
    Returns its ``"name"`` entry when ``key_info`` is a non-empty dict, and
    ``None`` otherwise (missing attribute, falsy value, or non-dict).
    """
    info = getattr(request.state, "key_info", None)
    if not info or not isinstance(info, dict):
        return None
    return info.get("name")
# Field names whose values are masked before a request body is stored.
# Matching is case-insensitive and applies at every nesting level.
_SENSITIVE_KEYS = frozenset({"password", "api_key", "key", "secret", "token"})
_REDACTED = "***REDACTED***"


def _redact(value):
    """Return a copy of *value* with sensitive dict entries replaced.

    Recurses into nested dicts and lists so credentials inside sub-objects
    (e.g. ``{"user": {"password": ...}}``) are masked too. Scalars are
    returned unchanged.
    """
    if isinstance(value, dict):
        return {
            k: _REDACTED
            if isinstance(k, str) and k.lower() in _SENSITIVE_KEYS
            else _redact(v)
            for k, v in value.items()
        }
    if isinstance(value, list):
        return [_redact(item) for item in value]
    return value


def _parse_request_body(body_bytes: bytes):
    """Decode a raw request body into a redacted JSON value for auditing.

    Returns ``None`` for an empty body, a body larger than ``_MAX_BODY_SIZE``
    bytes, or a body that cannot be decoded as JSON.
    """
    if not body_bytes or len(body_bytes) > _MAX_BODY_SIZE:
        return None
    try:
        return _redact(json.loads(body_bytes))
    except (json.JSONDecodeError, UnicodeDecodeError, RecursionError):
        # RecursionError: a deeply nested body (e.g. thousands of "[") still
        # fits under _MAX_BODY_SIZE and can exhaust the recursion limit in
        # json.loads or _redact; skip storing it rather than fail the request.
        return None


class AuditMiddleware(BaseHTTPMiddleware):
    """Write one ``AuditLog`` row per audited request.

    A request is audited when its method is in ``_AUDIT_METHODS``, its path
    starts with ``/api/``, and its path does not start with any of
    ``_SKIP_PREFIXES``; all other requests pass straight through. Writing the
    row is best-effort: failures are logged and the downstream response is
    returned unchanged.
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        path = request.url.path
        if (
            request.method not in _AUDIT_METHODS
            or path.startswith(_SKIP_PREFIXES)
            or not path.startswith("/api/")
        ):
            return await call_next(request)

        # Read the body before the endpoint runs; the original code relies on
        # Starlette caching it for downstream consumers.
        # NOTE(review): reading the body inside BaseHTTPMiddleware is only safe
        # on Starlette versions that replay the cached body to the endpoint —
        # confirm against the pinned version.
        request_body = _parse_request_body(await request.body())

        start = time.monotonic()
        response = await call_next(request)
        duration_ms = int((time.monotonic() - start) * 1000)

        # Read after call_next: key_info is presumably placed on request.state
        # by the auth dependency while the endpoint runs.
        operator = _get_operator(request)
        response_summary = f"HTTP {response.status_code}"

        try:
            async with async_session() as session:
                async with session.begin():
                    session.add(AuditLog(
                        method=request.method,
                        path=path,
                        status_code=response.status_code,
                        operator=operator,
                        client_ip=_get_client_ip(request),
                        request_body=request_body,
                        response_summary=response_summary,
                        duration_ms=duration_ms,
                    ))
        except Exception:
            # Never fail the caller's request because auditing failed, but make
            # the lost audit record visible (it used to be a silent DEBUG line).
            logger.warning(
                "Failed to write audit log for %s %s",
                request.method,
                path,
                exc_info=True,
            )
        return response