"""WebSocket connection manager."""
import json
import logging
from datetime import datetime, timezone
from typing import Dict, Set

from fastapi import WebSocket

logger = logging.getLogger(__name__)


class ConnectionManager:
    """Track all active WebSocket connections, grouped by user id."""

    def __init__(self):
        # user_id -> set of WebSocket connections (one user may have several
        # connections open at once, e.g. multiple browser tabs).
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept *websocket* and register it under *user_id*."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, set()).add(websocket)

    def disconnect(self, websocket: WebSocket, user_id: str):
        """Unregister *websocket* from *user_id*.

        The user entry is removed once its last connection is gone. Unknown
        users or sockets are ignored, so calling this twice is harmless.
        """
        connections = self.active_connections.get(user_id)
        if connections is None:
            return
        connections.discard(websocket)
        if not connections:
            del self.active_connections[user_id]

    async def send_to_user(self, user_id: str, event_type: str, data: dict):
        """Send one event to every connection registered for *user_id*.

        The message is a JSON object with keys ``type``, ``data`` and
        ``timestamp``. ``timestamp`` is the current UTC time in ISO-8601 form
        with an explicit ``+00:00`` offset. Values that ``json`` cannot
        serialize natively are converted with ``str``.

        A connection whose send raises is treated as dead and unregistered.
        Does nothing if the user has no registered connections.
        """
        connections = self.active_connections.get(user_id)
        if not connections:
            return
        message = json.dumps(
            {
                "type": event_type,
                "data": data,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
            ensure_ascii=False,
            default=str,
        )
        dead = []
        # Iterate over a snapshot: every await yields to the event loop, and a
        # connect()/disconnect() running meanwhile would otherwise mutate the
        # set mid-iteration (RuntimeError).
        for ws in list(connections):
            try:
                await ws.send_text(message)
            except Exception:
                # Best-effort delivery: a failed send is taken to mean the
                # socket is unusable, so drop it instead of propagating.
                logger.debug(
                    "WebSocket send to user %s failed; dropping connection",
                    user_id,
                    exc_info=True,
                )
                dead.append(ws)
        # Clean up through disconnect(). It tolerates the user entry having
        # been removed during the awaits above, and it deletes the entry when
        # the last connection goes away.
        for ws in dead:
            self.disconnect(ws, user_id)

    async def broadcast_to_conversation(self, user_ids: list[str], event_type: str,
                                        data: dict, exclude_user: str | None = None):
        """Send an event to each user in *user_ids*, skipping *exclude_user*.

        Users are messaged one after another via ``send_to_user``.
        """
        for uid in user_ids:
            if uid != exclude_user:
                await self.send_to_user(uid, event_type, data)

    async def broadcast(self, event_type: str, data: dict):
        """Send an event to every user that currently has a connection."""
        # Snapshot the keys: send_to_user may remove users whose sockets died.
        for user_id in list(self.active_connections):
            await self.send_to_user(user_id, event_type, data)

    def is_online(self, user_id: str) -> bool:
        """Return True if *user_id* has at least one registered connection."""
        return bool(self.active_connections.get(user_id))

    def get_online_user_ids(self) -> list[str]:
        """Return the ids of all users with at least one registered connection."""
        # Filter empty sets so the result always agrees with is_online().
        return [uid for uid, conns in self.active_connections.items() if conns]


# Module-level shared instance.
manager = ConnectionManager()