首个可运行的版本
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""WebSocket 包"""
|
||||
@@ -0,0 +1,58 @@
|
||||
"""WebSocket 事件类型定义"""
|
||||
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class EventType(str, Enum):
|
||||
# 客户端 -> 服务端
|
||||
CHAT_SEND = "chat.send"
|
||||
CHAT_TYPING = "chat.typing"
|
||||
CHAT_READ = "chat.read"
|
||||
PRESENCE_UPDATE = "presence.update"
|
||||
|
||||
# 服务端 -> 客户端
|
||||
CHAT_MESSAGE = "chat.message"
|
||||
CHAT_TYPING_INDICATOR = "chat.typing"
|
||||
CHAT_READ_RECEIPT = "chat.read"
|
||||
CHAT_MESSAGE_DELETED = "chat.message_deleted"
|
||||
CONVERSATION_UPDATED = "conversation.updated"
|
||||
CONVERSATION_MEMBER_ADDED = "conversation.member_added"
|
||||
CONVERSATION_MEMBER_REMOVED = "conversation.member_removed"
|
||||
FRIEND_REQUEST = "friend.request"
|
||||
FRIEND_ACCEPTED = "friend.accepted"
|
||||
PRESENCE_ONLINE = "presence.online"
|
||||
PRESENCE_OFFLINE = "presence.offline"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class WSEvent(BaseModel):
|
||||
"""WebSocket 事件信封"""
|
||||
type: str
|
||||
data: dict
|
||||
timestamp: str | None = None
|
||||
|
||||
|
||||
class ChatSendData(BaseModel):
|
||||
"""发送消息数据"""
|
||||
conversation_id: str
|
||||
content: str
|
||||
type: str = "text"
|
||||
reply_to_id: str | None = None
|
||||
|
||||
|
||||
class ChatTypingData(BaseModel):
|
||||
"""输入中数据"""
|
||||
conversation_id: str
|
||||
|
||||
|
||||
class ChatReadData(BaseModel):
|
||||
"""已读数据"""
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
|
||||
|
||||
class PresenceUpdateData(BaseModel):
|
||||
"""在线状态更新"""
|
||||
status: str # online / offline / away
|
||||
@@ -0,0 +1,106 @@
|
||||
"""WebSocket 事件处理器"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import WebSocket
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.websocket.events import EventType
|
||||
from app.websocket.manager import manager
|
||||
|
||||
|
||||
async def handle_chat_send(ws: WebSocket, user_id: str, data: dict, db: AsyncSession):
|
||||
"""处理发送消息事件"""
|
||||
from app.services.message_service import MessageService
|
||||
service = MessageService(db)
|
||||
try:
|
||||
message = await service.send_message(
|
||||
conversation_id=data["conversation_id"],
|
||||
sender_id=user_id,
|
||||
content=data["content"],
|
||||
msg_type=data.get("type", "text"),
|
||||
reply_to_id=data.get("reply_to_id"),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
# 获取发送者信息
|
||||
from app.services.user_service import UserService
|
||||
user_service = UserService(db)
|
||||
sender = await user_service.get_by_id(user_id)
|
||||
|
||||
# 获取会话成员列表
|
||||
from app.services.conversation_service import ConversationService
|
||||
conv_service = ConversationService(db)
|
||||
detail = await conv_service.get_conversation_detail(data["conversation_id"], user_id)
|
||||
|
||||
msg_data = {
|
||||
"id": message.id,
|
||||
"conversation_id": message.conversation_id,
|
||||
"sender_id": user_id,
|
||||
"sender_name": sender.username if sender else "未知",
|
||||
"sender_avatar": sender.avatar_url if sender else None,
|
||||
"type": message.type,
|
||||
"content": message.content,
|
||||
"created_at": message.created_at.isoformat(),
|
||||
}
|
||||
|
||||
# 广播给会话中的所有成员
|
||||
if detail and "members" in detail:
|
||||
member_ids = [m["user_id"] for m in detail["members"]]
|
||||
await manager.broadcast_to_conversation(
|
||||
member_ids, EventType.CHAT_MESSAGE, msg_data
|
||||
)
|
||||
except Exception as e:
|
||||
await manager.send_to_user(user_id, EventType.ERROR, {"message": str(e)})
|
||||
|
||||
|
||||
async def handle_chat_typing(ws: WebSocket, user_id: str, data: dict, db: AsyncSession):
|
||||
"""处理输入中事件"""
|
||||
from app.services.user_service import UserService
|
||||
user_service = UserService(db)
|
||||
user = await user_service.get_by_id(user_id)
|
||||
|
||||
from app.services.conversation_service import ConversationService
|
||||
conv_service = ConversationService(db)
|
||||
detail = await conv_service.get_conversation_detail(data["conversation_id"], user_id)
|
||||
|
||||
if detail and "members" in detail:
|
||||
member_ids = [m["user_id"] for m in detail["members"]]
|
||||
await manager.broadcast_to_conversation(
|
||||
member_ids, EventType.CHAT_TYPING_INDICATOR,
|
||||
{"conversation_id": data["conversation_id"], "user_id": user_id,
|
||||
"username": user.username if user else "未知"},
|
||||
exclude_user=user_id,
|
||||
)
|
||||
|
||||
|
||||
async def handle_chat_read(ws: WebSocket, user_id: str, data: dict, db: AsyncSession):
|
||||
"""处理已读事件"""
|
||||
from app.services.message_service import MessageService
|
||||
service = MessageService(db)
|
||||
await service.mark_as_read(data["conversation_id"], user_id, data["message_id"])
|
||||
await db.commit()
|
||||
|
||||
from app.services.conversation_service import ConversationService
|
||||
conv_service = ConversationService(db)
|
||||
detail = await conv_service.get_conversation_detail(data["conversation_id"], user_id)
|
||||
|
||||
if detail and "members" in detail:
|
||||
member_ids = [m["user_id"] for m in detail["members"]]
|
||||
await manager.broadcast_to_conversation(
|
||||
member_ids, EventType.CHAT_READ_RECEIPT,
|
||||
{"conversation_id": data["conversation_id"], "user_id": user_id,
|
||||
"read_up_to": data["message_id"]},
|
||||
)
|
||||
|
||||
|
||||
async def handle_presence_update(ws: WebSocket, user_id: str, data: dict, db: AsyncSession):
|
||||
"""处理在线状态更新"""
|
||||
from app.services.user_service import UserService
|
||||
user_service = UserService(db)
|
||||
await user_service.update_status(user_id, data["status"])
|
||||
await db.commit()
|
||||
|
||||
event = EventType.PRESENCE_ONLINE if data["status"] == "online" else EventType.PRESENCE_OFFLINE
|
||||
await manager.broadcast(event, {"user_id": user_id})
|
||||
@@ -0,0 +1,76 @@
|
||||
"""WebSocket 连接管理器"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Set
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""管理所有 WebSocket 连接"""
|
||||
|
||||
def __init__(self):
|
||||
# user_id -> set of WebSocket connections (一个用户可能有多个标签页)
|
||||
self.active_connections: Dict[str, Set[WebSocket]] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, user_id: str):
|
||||
"""接受新连接"""
|
||||
await websocket.accept()
|
||||
if user_id not in self.active_connections:
|
||||
self.active_connections[user_id] = set()
|
||||
self.active_connections[user_id].add(websocket)
|
||||
|
||||
def disconnect(self, websocket: WebSocket, user_id: str):
|
||||
"""断开连接"""
|
||||
if user_id in self.active_connections:
|
||||
self.active_connections[user_id].discard(websocket)
|
||||
if not self.active_connections[user_id]:
|
||||
del self.active_connections[user_id]
|
||||
|
||||
async def send_to_user(self, user_id: str, event_type: str, data: dict):
|
||||
"""向指定用户发送事件"""
|
||||
if user_id not in self.active_connections:
|
||||
return
|
||||
|
||||
message = json.dumps({
|
||||
"type": event_type,
|
||||
"data": data,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}, ensure_ascii=False, default=str)
|
||||
|
||||
disconnected = set()
|
||||
for ws in self.active_connections[user_id]:
|
||||
try:
|
||||
await ws.send_text(message)
|
||||
except Exception:
|
||||
disconnected.add(ws)
|
||||
|
||||
# 清理断开的连接
|
||||
for ws in disconnected:
|
||||
self.active_connections[user_id].discard(ws)
|
||||
|
||||
async def broadcast_to_conversation(self, user_ids: list[str],
|
||||
event_type: str, data: dict,
|
||||
exclude_user: str | None = None):
|
||||
"""向会话中所有用户广播事件"""
|
||||
for uid in user_ids:
|
||||
if uid != exclude_user:
|
||||
await self.send_to_user(uid, event_type, data)
|
||||
|
||||
async def broadcast(self, event_type: str, data: dict):
|
||||
"""向所有在线用户广播"""
|
||||
for user_id in list(self.active_connections.keys()):
|
||||
await self.send_to_user(user_id, event_type, data)
|
||||
|
||||
def is_online(self, user_id: str) -> bool:
|
||||
"""检查用户是否在线"""
|
||||
return user_id in self.active_connections and len(self.active_connections[user_id]) > 0
|
||||
|
||||
def get_online_user_ids(self) -> list[str]:
|
||||
"""获取所有在线用户 ID"""
|
||||
return list(self.active_connections.keys())
|
||||
|
||||
|
||||
# 全局单例
|
||||
manager = ConnectionManager()
|
||||
@@ -0,0 +1,101 @@
|
||||
"""WebSocket 路由"""
|
||||
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import async_session
|
||||
from app.utils.security import decode_access_token
|
||||
from app.websocket.manager import manager
|
||||
from app.websocket.events import EventType
|
||||
from app.websocket.handlers import (
|
||||
handle_chat_send, handle_chat_typing,
|
||||
handle_chat_read, handle_presence_update,
|
||||
)
|
||||
|
||||
websocket_router = APIRouter()
|
||||
|
||||
|
||||
@websocket_router.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket, token: str = Query(None)):
|
||||
"""WebSocket 连接端点"""
|
||||
|
||||
# 验证 Token
|
||||
if not token:
|
||||
await websocket.close(code=4001, reason="Missing token")
|
||||
return
|
||||
|
||||
payload = decode_access_token(token)
|
||||
if not payload:
|
||||
await websocket.close(code=4001, reason="Invalid token")
|
||||
return
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
await websocket.close(code=4001, reason="Invalid token payload")
|
||||
return
|
||||
|
||||
# 接受连接
|
||||
await manager.connect(websocket, user_id)
|
||||
|
||||
# 更新在线状态
|
||||
async with async_session() as db:
|
||||
from app.services.user_service import UserService
|
||||
user_service = UserService(db)
|
||||
await user_service.update_status(user_id, "online")
|
||||
await db.commit()
|
||||
|
||||
# 广播上线通知
|
||||
await manager.broadcast(EventType.PRESENCE_ONLINE, {"user_id": user_id})
|
||||
print(f"🌿 用户 {user_id} 已连接 WebSocket")
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw = await websocket.receive_text()
|
||||
try:
|
||||
event = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
await manager.send_to_user(user_id, EventType.ERROR, {"message": "无效的 JSON"})
|
||||
continue
|
||||
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
# 创建新的数据库会话处理事件
|
||||
async with async_session() as db:
|
||||
handler_map = {
|
||||
EventType.CHAT_SEND: handle_chat_send,
|
||||
EventType.CHAT_TYPING: handle_chat_typing,
|
||||
EventType.CHAT_READ: handle_chat_read,
|
||||
EventType.PRESENCE_UPDATE: handle_presence_update,
|
||||
}
|
||||
|
||||
handler = handler_map.get(event_type)
|
||||
if handler:
|
||||
try:
|
||||
await handler(websocket, user_id, data, db)
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
await manager.send_to_user(
|
||||
user_id, EventType.ERROR, {"message": str(e)}
|
||||
)
|
||||
await db.rollback()
|
||||
else:
|
||||
await manager.send_to_user(
|
||||
user_id, EventType.ERROR,
|
||||
{"message": f"未知事件类型: {event_type}"}
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, user_id)
|
||||
|
||||
# 更新离线状态
|
||||
async with async_session() as db:
|
||||
from app.services.user_service import UserService
|
||||
user_service = UserService(db)
|
||||
await user_service.update_status(user_id, "offline")
|
||||
await db.commit()
|
||||
|
||||
await manager.broadcast(EventType.PRESENCE_OFFLINE, {"user_id": user_id})
|
||||
print(f"🌿 用户 {user_id} 已断开 WebSocket")
|
||||
Reference in New Issue
Block a user