Files
chat/backend/app/websocket/manager.py
T
2026-06-15 21:21:20 +08:00

82 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""WebSocket 连接管理器"""
import json
from datetime import datetime, timezone
from typing import Dict, Set
from fastapi import WebSocket
class ConnectionManager:
"""管理所有 WebSocket 连接"""
def __init__(self):
# user_id -> set of WebSocket connections (一个用户可能有多个标签页)
self.active_connections: Dict[str, Set[WebSocket]] = {}
async def connect(self, websocket: WebSocket, user_id: str):
"""接受新连接"""
await websocket.accept()
if user_id not in self.active_connections:
self.active_connections[user_id] = set()
self.active_connections[user_id].add(websocket)
def disconnect(self, websocket: WebSocket, user_id: str):
"""断开连接"""
if user_id in self.active_connections:
self.active_connections[user_id].discard(websocket)
if not self.active_connections[user_id]:
del self.active_connections[user_id]
async def send_to_user(self, user_id: str, event_type: str, data: dict):
"""向指定用户发送事件"""
if user_id not in self.active_connections:
return
message = json.dumps({
"type": event_type,
"data": data,
"timestamp": datetime.utcnow().isoformat(),
}, ensure_ascii=False, default=str)
disconnected = set()
for ws in self.active_connections[user_id]:
try:
await ws.send_text(message)
except Exception:
disconnected.add(ws)
# 清理断开的连接
for ws in disconnected:
self.active_connections[user_id].discard(ws)
async def broadcast_to_conversation(self, user_ids: list[str],
event_type: str, data: dict,
exclude_user: str | None = None):
"""向会话中所有用户广播事件"""
for uid in user_ids:
if uid != exclude_user:
await self.send_to_user(uid, event_type, data)
async def broadcast(self, event_type: str, data: dict):
"""向所有在线用户广播"""
for user_id in list(self.active_connections.keys()):
await self.send_to_user(user_id, event_type, data)
def is_online(self, user_id: str) -> bool:
"""检查用户是否在线"""
return user_id in self.active_connections and len(self.active_connections[user_id]) > 0
def get_online_user_ids(self) -> list[str]:
"""获取所有在线用户 ID"""
return list(self.active_connections.keys())
# 全局单例
# ⚠️ 重要约束:本管理器为「进程内存」单例 —— active_connections 只存在于
# 当前进程。因此后端必须以「单进程」方式运行(生产镜像 Dockerfile.prod
# 已固定为单 worker uvicorn)。若使用 --workers N 或 gunicorn 多 worker
# 连接会分散到不同进程,跨用户 / 跨标签页的实时消息(聊天、撤回、好友请求、
# 互动通知等)将无法投递。水平扩展前需先将其迁移到基于 Redis 的 pub/sub。
manager = ConnectionManager()