"""
WebSocket Manager.

Manages client connections, topic subscriptions, and broadcasting.
"""

import asyncio
import json
import logging
from datetime import datetime, timezone

from fastapi import WebSocket

logger = logging.getLogger(__name__)

# Maximum concurrent WebSocket connections (accepted plus mid-handshake).
MAX_CONNECTIONS = 100

# Topics a client may subscribe to; unknown requested topics are ignored.
VALID_TOPICS = {"location", "alarm", "device_status", "attendance", "bluetooth"}


class WebSocketManager:
    """Manages WebSocket connections with topic-based subscriptions.

    Each accepted connection is mapped to the set of topics it receives.
    ``broadcast`` delivers a JSON envelope only to connections subscribed to
    the message's topic and prunes connections whose send fails.
    """

    def __init__(self):
        # {websocket: set_of_topics} for every accepted connection.
        self.active_connections: dict[WebSocket, set[str]] = {}
        # Handshakes that passed the capacity check but are still suspended in
        # accept(); counted against MAX_CONNECTIONS so concurrent connects
        # cannot overshoot the limit while awaiting.
        self._pending_accepts = 0
        # Strong references to in-flight fire-and-forget broadcast tasks. The
        # event loop keeps only weak references to tasks, so without this set
        # a task could be garbage-collected before it completes.
        self._background_tasks: set[asyncio.Task] = set()

    @property
    def connection_count(self) -> int:
        """Number of currently accepted (registered) connections."""
        return len(self.active_connections)

    async def connect(self, websocket: WebSocket, topics: set[str]) -> bool:
        """Accept and register a WebSocket connection.

        Args:
            websocket: The incoming, not-yet-accepted connection.
            topics: Requested topics (any iterable of str). Entries not in
                ``VALID_TOPICS`` are dropped; if nothing valid remains, the
                client is subscribed to every valid topic.

        Returns:
            True if the connection was accepted and registered; False if the
            connection limit was reached, in which case the socket is closed
            with code 1013 ("Try Again Later").
        """
        if self.connection_count + self._pending_accepts >= MAX_CONNECTIONS:
            await websocket.close(code=1013, reason="Max connections reached")
            return False

        # Reserve a slot before yielding to the event loop inside accept().
        self._pending_accepts += 1
        try:
            await websocket.accept()
        finally:
            self._pending_accepts -= 1

        filtered = set(topics or ()) & VALID_TOPICS
        # Copy VALID_TOPICS so a per-connection subscription set never aliases
        # (and can never mutate) the module-level constant.
        self.active_connections[websocket] = filtered or set(VALID_TOPICS)
        logger.info(
            "WebSocket connected (%d total), topics: %s",
            self.connection_count,
            self.active_connections[websocket],
        )
        return True

    def disconnect(self, websocket: WebSocket):
        """Remove a WebSocket connection (no-op if it is not registered)."""
        self.active_connections.pop(websocket, None)
        logger.info("WebSocket disconnected (%d remaining)", self.connection_count)

    async def broadcast(self, topic: str, data: dict):
        """Broadcast a message to all subscribers of the given topic.

        The payload is wrapped as ``{"topic", "data", "timestamp"}`` with a
        UTC ISO-8601 timestamp and JSON-encoded once for all recipients;
        values ``json`` cannot encode natively are converted with ``str``.
        Unknown topics are ignored. Connections whose send raises are removed
        from ``active_connections``.
        """
        if topic not in VALID_TOPICS:
            logger.debug("Ignoring broadcast for unknown topic %s", topic)
            return
        message = json.dumps(
            {
                "topic": topic,
                "data": data,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
            default=str,
            ensure_ascii=False,
        )
        disconnected = []
        # Snapshot dict to avoid RuntimeError from concurrent modification
        # (connect/disconnect may run while this coroutine awaits a send).
        for ws, topics in list(self.active_connections.items()):
            if topic not in topics:
                continue
            try:
                await ws.send_text(message)
            except Exception:
                # Best-effort delivery: drop the broken client instead of
                # aborting the broadcast for everyone else.
                logger.debug("WebSocket send failed; dropping connection", exc_info=True)
                disconnected.append(ws)
        for ws in disconnected:
            self.active_connections.pop(ws, None)
        if disconnected:
            logger.info(
                "Dropped %d unreachable WebSocket(s) (%d remaining)",
                len(disconnected),
                self.connection_count,
            )

    def broadcast_nonblocking(self, topic: str, data: dict):
        """Fire-and-forget broadcast (used from TCP handler context).

        Schedules ``broadcast`` as a task and returns immediately; errors are
        logged, never raised to the caller. Must be called from code running
        in the event loop's thread, since ``asyncio.create_task`` raises
        RuntimeError when no loop is running.
        """
        task = asyncio.create_task(self._safe_broadcast(topic, data))
        # Hold a strong reference until the task finishes (see __init__).
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _safe_broadcast(self, topic: str, data: dict):
        """Run ``broadcast`` and log, rather than propagate, any exception."""
        try:
            await self.broadcast(topic, data)
        except Exception:
            logger.exception("WebSocket broadcast error for topic %s", topic)


# Singleton instance
ws_manager = WebSocketManager()