102 lines
3.4 KiB
Python
102 lines
3.4 KiB
Python
|
|
"""
|
||
|
|
Audit logging middleware.
|
||
|
|
Records POST/PUT/DELETE requests under /api/ to the audit_logs table.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import time
|
||
|
|
import logging
|
||
|
|
|
||
|
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||
|
|
from starlette.requests import Request
|
||
|
|
from starlette.responses import Response
|
||
|
|
|
||
|
|
from app.database import async_session
|
||
|
|
from app.models import AuditLog
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)

# HTTP methods whose requests are candidates for an audit record.
_AUDIT_METHODS = {
    "POST",
    "PUT",
    "DELETE",
}

# Path prefixes that are never audited (noisy or non-business endpoints).
_SKIP_PREFIXES = (
    "/ws",
    "/health",
    "/docs",
    "/redoc",
    "/openapi.json",
)

# Largest request body, in bytes, that is parsed and stored with an audit row.
_MAX_BODY_SIZE = 4 * 1024
|
||
|
|
|
||
|
|
|
||
|
|
def _get_client_ip(request: Request) -> str:
    """Return the best-guess client IP for *request*.

    Sources are tried in this order:

    1. the first non-empty entry of the ``X-Forwarded-For`` header;
    2. the ``CF-Connecting-IP`` header;
    3. ``request.client.host``.

    Returns ``"unknown"`` when none of them yields a value.

    NOTE(review): both headers can be set by the caller unless a trusted
    proxy overwrites them -- confirm the deployment guarantees that.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        # A whitespace-only header or a blank leading entry (" , 10.0.0.1")
        # must not be recorded as an empty IP, so skip empty hops.
        for hop in forwarded.split(","):
            hop = hop.strip()
            if hop:
                return hop

    cf_ip = (request.headers.get("CF-Connecting-IP") or "").strip()
    if cf_ip:
        return cf_ip

    return request.client.host if request.client else "unknown"
|
||
|
|
|
||
|
|
|
||
|
|
def _get_operator(request: Request) -> str | None:
    """Return the operator name stored on ``request.state.key_info``.

    ``key_info`` is expected to be set by ``verify_api_key``. When it is
    missing, empty, or not a dict, ``None`` is returned; otherwise the value
    under its ``"name"`` key (which may itself be absent, giving ``None``).
    """
    info = getattr(request.state, "key_info", None)
    if not info or not isinstance(info, dict):
        return None
    return info.get("name")
|
||
|
|
|
||
|
|
|
||
|
|
class AuditMiddleware(BaseHTTPMiddleware):
    """Persist an audit record for mutating requests under ``/api/``.

    A request is audited when its method is in ``_AUDIT_METHODS``, its path
    starts with ``/api/`` and it does not start with any of
    ``_SKIP_PREFIXES``; every other request is passed straight through.

    Writing the audit row is best-effort: a failure is logged with its
    traceback and never alters the response returned to the client.
    """

    # Body keys whose values are masked before the payload is stored.
    _SENSITIVE_KEYS = frozenset({"password", "api_key", "key", "secret", "token"})
    _REDACTED = "***REDACTED***"

    @classmethod
    def _redact(cls, value):
        """Return a copy of *value* with sensitive keys masked at every depth.

        Dicts and lists are walked recursively so that secrets nested inside
        objects or arrays are masked too; any other value is returned as-is.
        """
        if isinstance(value, dict):
            return {
                key: cls._REDACTED if key in cls._SENSITIVE_KEYS else cls._redact(item)
                for key, item in value.items()
            }
        if isinstance(value, list):
            return [cls._redact(item) for item in value]
        return value

    @classmethod
    def _parse_body(cls, body_bytes: bytes):
        """Return the redacted JSON payload to store, or ``None``.

        ``None`` is returned for an empty body, a body larger than
        ``_MAX_BODY_SIZE`` bytes, or a body that cannot be decoded as JSON.
        """
        if not body_bytes or len(body_bytes) > _MAX_BODY_SIZE:
            return None
        try:
            return cls._redact(json.loads(body_bytes))
        except (json.JSONDecodeError, UnicodeDecodeError, RecursionError):
            # RecursionError: a pathologically nested payload must not turn
            # the audited request into a server error.
            return None

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Forward *request* downstream and record an audit row when eligible.

        Args:
            request: The incoming request.
            call_next: Callable that invokes the rest of the application.

        Returns:
            The downstream response, unchanged.
        """
        path = request.url.path
        if (
            request.method not in _AUDIT_METHODS
            or path.startswith(_SKIP_PREFIXES)
            or not path.startswith("/api/")  # only /api/ routes are audited
        ):
            return await call_next(request)

        # Read the body up front so it can be stored with the audit row.
        # NOTE(review): this relies on Starlette making the already-read body
        # available to downstream handlers -- confirm for the pinned version.
        request_body = self._parse_body(await request.body())

        start = time.monotonic()
        response = await call_next(request)
        duration_ms = int((time.monotonic() - start) * 1000)

        try:
            async with async_session() as session:
                async with session.begin():
                    session.add(AuditLog(
                        method=request.method,
                        path=path,
                        status_code=response.status_code,
                        # key_info is only available once downstream has run.
                        operator=_get_operator(request),
                        client_ip=_get_client_ip(request),
                        request_body=request_body,
                        response_summary=f"HTTP {response.status_code}",
                        duration_ms=duration_ms,
                    ))
        except Exception:
            # Best-effort: never fail the request because auditing failed, but
            # keep the traceback so lost audit rows are visible in the logs.
            logger.exception("Failed to write audit log for %s %s", request.method, path)

        return response
|