首个可运行的版本
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
"""应用配置 - 通过环境变量读取"""
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# 数据库
|
||||
DATABASE_URL: str = "postgresql+asyncpg://qingye:qingye_secret@postgres:5432/qingye"
|
||||
|
||||
# Redis
|
||||
REDIS_URL: str = "redis://redis:6379/0"
|
||||
|
||||
# JWT
|
||||
JWT_SECRET_KEY: str = "qingye-jwt-secret-change-in-production"
|
||||
JWT_REFRESH_SECRET_KEY: str = "qingye-refresh-secret-change-in-production"
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7
|
||||
|
||||
# CORS
|
||||
CORS_ORIGINS: str = "http://localhost:5173"
|
||||
|
||||
# 管理员默认密码
|
||||
ADMIN_PASSWORD: str = "admin123"
|
||||
|
||||
# 上传
|
||||
MAX_UPLOAD_SIZE_MB: int = 10
|
||||
UPLOAD_DIR: str = "/app/uploads"
|
||||
|
||||
model_config = {"env_file": ".env", "extra": "ignore"}
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,26 @@
|
||||
"""异步数据库引擎和会话管理"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config import settings
|
||||
|
||||
# 异步引擎
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=False,
|
||||
pool_size=20,
|
||||
max_overflow=10,
|
||||
)
|
||||
|
||||
# 异步会话工厂
|
||||
async_session = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
# 声明式基类
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
@@ -0,0 +1,96 @@
|
||||
"""FastAPI 公共依赖"""
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import redis.asyncio as redis
|
||||
|
||||
from app.config import settings
|
||||
from app.database import async_session
|
||||
from app.utils.security import decode_access_token
|
||||
|
||||
# HTTP Bearer 认证
|
||||
security = HTTPBearer()
|
||||
|
||||
# Redis 连接池
|
||||
_redis_pool: redis.Redis | None = None
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取数据库会话"""
|
||||
async with async_session() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
async def get_redis() -> redis.Redis:
|
||||
"""获取 Redis 连接"""
|
||||
global _redis_pool
|
||||
if _redis_pool is None:
|
||||
_redis_pool = redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
return _redis_pool
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取当前认证用户"""
|
||||
token = credentials.credentials
|
||||
payload = decode_access_token(token)
|
||||
if payload is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效或过期的 Token",
|
||||
)
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的 Token 格式",
|
||||
)
|
||||
|
||||
from app.services.user_service import UserService
|
||||
user_service = UserService(db)
|
||||
user = await user_service.get_by_id(user_id)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户不存在",
|
||||
)
|
||||
if user.is_banned:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="账号已被封禁",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_admin_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取管理员用户(需要 is_admin=True)"""
|
||||
token = credentials.credentials
|
||||
payload = decode_access_token(token)
|
||||
if payload is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效或过期的 Token",
|
||||
)
|
||||
|
||||
if not payload.get("is_admin"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="需要管理员权限",
|
||||
)
|
||||
|
||||
return payload
|
||||
@@ -0,0 +1,84 @@
|
||||
"""青叶 - FastAPI 应用入口"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from app.config import settings
|
||||
|
||||
import os
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期:启动和关闭"""
|
||||
# 启动时
|
||||
print("🌿 青叶后端启动中...")
|
||||
|
||||
# 确保上传目录存在
|
||||
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
|
||||
|
||||
# 初始化数据库表(开发阶段用,生产用 Alembic 迁移)
|
||||
from app.database import engine, Base
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
print("📦 数据库表已创建")
|
||||
|
||||
# 初始化系统配置
|
||||
from app.database import async_session
|
||||
async with async_session() as db:
|
||||
from app.services.admin_service import AdminService
|
||||
admin_service = AdminService(db)
|
||||
await admin_service.init_system_config()
|
||||
await db.commit()
|
||||
print("⚙️ 系统配置已初始化")
|
||||
|
||||
print("🚀 青叶后端启动完成!")
|
||||
yield
|
||||
|
||||
# 关闭时
|
||||
print("🌿 青叶后端关闭")
|
||||
from app.database import engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="青叶 - QingYe",
|
||||
description="青叶社交聊天应用后端 API",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS 中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS.split(","),
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 静态文件(上传的文件)
|
||||
app.mount("/uploads", StaticFiles(directory=settings.UPLOAD_DIR), name="uploads")
|
||||
|
||||
# 注册路由
|
||||
from app.routers import auth, users, conversations, messages, friends, admin, uploads
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1/auth", tags=["认证"])
|
||||
app.include_router(users.router, prefix="/api/v1/users", tags=["用户"])
|
||||
app.include_router(conversations.router, prefix="/api/v1/conversations", tags=["会话"])
|
||||
app.include_router(messages.router, prefix="/api/v1/conversations", tags=["消息"])
|
||||
app.include_router(friends.router, prefix="/api/v1/friends", tags=["好友"])
|
||||
app.include_router(admin.router, prefix="/api/v1/admin", tags=["管理"])
|
||||
app.include_router(uploads.router, prefix="/api/v1/uploads", tags=["上传"])
|
||||
|
||||
# WebSocket
|
||||
from app.websocket.router import websocket_router
|
||||
app.include_router(websocket_router)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"name": "青叶 QingYe", "version": "0.1.0", "status": "running"}
|
||||
@@ -0,0 +1,19 @@
|
||||
"""SQLAlchemy 模型包 - 导入所有模型供 Alembic 自动检测"""
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.conversation_member import ConversationMember
|
||||
from app.models.message import Message
|
||||
from app.models.friend import Friend
|
||||
from app.models.friend_request import FriendRequest
|
||||
from app.models.system_config import SystemConfig
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
"Conversation",
|
||||
"ConversationMember",
|
||||
"Message",
|
||||
"Friend",
|
||||
"FriendRequest",
|
||||
"SystemConfig",
|
||||
]
|
||||
@@ -0,0 +1,32 @@
|
||||
"""会话模型(私聊 + 群聊)"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Conversation(Base):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
type: Mapped[str] = mapped_column(String(20), nullable=False) # private / group
|
||||
name: Mapped[str | None] = mapped_column(String(100), nullable=True) # 群聊名称
|
||||
avatar_url: Mapped[str | None] = mapped_column(String(500), nullable=True) # 群头像
|
||||
description: Mapped[str | None] = mapped_column(String(500), nullable=True) # 群描述
|
||||
creator_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
last_message_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
last_message_preview: Mapped[str | None] = mapped_column(String(200), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.utcnow())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
default=lambda: datetime.utcnow(),
|
||||
onupdate=lambda: datetime.utcnow(),
|
||||
)
|
||||
|
||||
# 关系
|
||||
members = relationship("ConversationMember", back_populates="conversation", cascade="all, delete-orphan")
|
||||
messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan")
|
||||
creator = relationship("User", foreign_keys=[creator_id])
|
||||
@@ -0,0 +1,32 @@
|
||||
"""会话成员模型"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import String, DateTime, ForeignKey, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class ConversationMember(Base):
|
||||
__tablename__ = "conversation_members"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("conversation_id", "user_id", name="uq_conv_user"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
conversation_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
role: Mapped[str] = mapped_column(String(20), default="member") # owner / admin / member
|
||||
nickname: Mapped[str | None] = mapped_column(String(50), nullable=True) # 群内昵称
|
||||
last_read_message_id: Mapped[str | None] = mapped_column(String(36), nullable=True)
|
||||
joined_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.utcnow())
|
||||
left_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# 关系
|
||||
conversation = relationship("Conversation", back_populates="members")
|
||||
user = relationship("User", back_populates="conversations")
|
||||
@@ -0,0 +1,29 @@
|
||||
"""好友关系模型"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import String, DateTime, ForeignKey, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Friend(Base):
|
||||
__tablename__ = "friends"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("user_id", "friend_user_id", name="uq_friendship"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
friend_user_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
remark: Mapped[str | None] = mapped_column(String(50), nullable=True) # 好友备注
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.utcnow())
|
||||
|
||||
# 关系
|
||||
user = relationship("User", foreign_keys=[user_id], back_populates="friends")
|
||||
friend_user = relationship("User", foreign_keys=[friend_user_id], back_populates="friend_of")
|
||||
@@ -0,0 +1,28 @@
|
||||
"""好友请求模型"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import String, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class FriendRequest(Base):
|
||||
__tablename__ = "friend_requests"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
from_user_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
to_user_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
message: Mapped[str | None] = mapped_column(String(200), nullable=True) # 验证消息
|
||||
status: Mapped[str] = mapped_column(String(20), default="pending") # pending / accepted / rejected
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.utcnow())
|
||||
responded_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# 关系
|
||||
from_user = relationship("User", foreign_keys=[from_user_id], back_populates="sent_requests")
|
||||
to_user = relationship("User", foreign_keys=[to_user_id], back_populates="received_requests")
|
||||
@@ -0,0 +1,37 @@
|
||||
"""消息模型"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey, Index
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "messages"
|
||||
__table_args__ = (
|
||||
Index("ix_messages_conv_created", "conversation_id", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
conversation_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
sender_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
type: Mapped[str] = mapped_column(String(20), default="text") # text / image / file / system
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
reply_to_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("messages.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=lambda: datetime.utcnow(), index=True
|
||||
)
|
||||
|
||||
# 关系
|
||||
conversation = relationship("Conversation", back_populates="messages")
|
||||
sender = relationship("User", back_populates="sent_messages", foreign_keys=[sender_id])
|
||||
reply_to = relationship("Message", remote_side="Message.id")
|
||||
@@ -0,0 +1,24 @@
|
||||
"""系统配置模型"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import String, Text, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class SystemConfig(Base):
|
||||
__tablename__ = "system_configs"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
key: Mapped[str] = mapped_column(String(100), unique=True, nullable=False, index=True)
|
||||
value: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
default=lambda: datetime.utcnow(),
|
||||
onupdate=lambda: datetime.utcnow(),
|
||||
)
|
||||
updated_by: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
@@ -0,0 +1,38 @@
|
||||
"""用户模型"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import String, Boolean, DateTime, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
username: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, index=True)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
avatar_url: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
bio: Mapped[str | None] = mapped_column(String(200), nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="offline") # online/offline/away
|
||||
is_admin: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
is_banned: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
banned_reason: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
last_seen_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.utcnow())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
default=lambda: datetime.utcnow(),
|
||||
onupdate=lambda: datetime.utcnow(),
|
||||
)
|
||||
|
||||
# 关系
|
||||
sent_messages = relationship("Message", back_populates="sender", foreign_keys="Message.sender_id")
|
||||
conversations = relationship("ConversationMember", back_populates="user")
|
||||
friends = relationship("Friend", foreign_keys="Friend.user_id", back_populates="user")
|
||||
friend_of = relationship("Friend", foreign_keys="Friend.friend_user_id", back_populates="friend_user")
|
||||
sent_requests = relationship("FriendRequest", foreign_keys="FriendRequest.from_user_id", back_populates="from_user")
|
||||
received_requests = relationship("FriendRequest", foreign_keys="FriendRequest.to_user_id", back_populates="to_user")
|
||||
@@ -0,0 +1 @@
|
||||
"""路由包"""
|
||||
@@ -0,0 +1,130 @@
|
||||
"""管理后台路由"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, get_admin_user
|
||||
from app.schemas.admin import (
|
||||
AdminLoginRequest, AdminLoginResponse, DashboardStats,
|
||||
UserBanRequest, SystemConfigUpdate,
|
||||
)
|
||||
from app.services.admin_service import AdminService
|
||||
from app.services.message_service import MessageService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/login", response_model=AdminLoginResponse)
|
||||
async def admin_login(
|
||||
req: AdminLoginRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""管理员登录(仅密码)"""
|
||||
service = AdminService(db)
|
||||
token = await service.login(req.password)
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="管理员密码错误")
|
||||
return AdminLoginResponse(access_token=token)
|
||||
|
||||
|
||||
@router.get("/dashboard", response_model=DashboardStats)
|
||||
async def admin_dashboard(
|
||||
_=Depends(get_admin_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""仪表盘统计数据"""
|
||||
service = AdminService(db)
|
||||
return await service.get_dashboard_stats()
|
||||
|
||||
|
||||
@router.get("/stats/{metric}")
|
||||
async def admin_stats(
|
||||
metric: str,
|
||||
days: int = Query(7, ge=1, le=90),
|
||||
_=Depends(get_admin_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取趋势数据 (online/messages/registrations)"""
|
||||
if metric not in ("online", "messages", "registrations"):
|
||||
raise HTTPException(status_code=400, detail="无效的指标类型")
|
||||
service = AdminService(db)
|
||||
return await service.get_trend_data(metric, days)
|
||||
|
||||
|
||||
@router.get("/users")
|
||||
async def admin_users(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: str | None = Query(None),
|
||||
status: str | None = Query(None),
|
||||
_=Depends(get_admin_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""用户管理列表"""
|
||||
service = AdminService(db)
|
||||
return await service.get_users_list(page, page_size, search, status)
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/ban")
|
||||
async def admin_ban_user(
|
||||
user_id: str,
|
||||
req: UserBanRequest,
|
||||
_=Depends(get_admin_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""封禁/解封用户"""
|
||||
service = AdminService(db)
|
||||
try:
|
||||
await service.ban_user(user_id, req.is_banned, req.reason)
|
||||
return {"success": True}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}")
|
||||
async def admin_delete_user(
|
||||
user_id: str,
|
||||
_=Depends(get_admin_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除用户"""
|
||||
service = AdminService(db)
|
||||
await service.delete_user(user_id)
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@router.get("/messages")
|
||||
async def admin_messages(
|
||||
user_id: str | None = Query(None),
|
||||
conversation_id: str | None = Query(None),
|
||||
keyword: str | None = Query(None),
|
||||
date_from: str | None = Query(None),
|
||||
date_to: str | None = Query(None),
|
||||
_=Depends(get_admin_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""搜索消息(管理审查)"""
|
||||
service = MessageService(db)
|
||||
return await service.search_messages(user_id, conversation_id, keyword, date_from, date_to)
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def admin_get_config(
|
||||
_=Depends(get_admin_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取系统配置"""
|
||||
service = AdminService(db)
|
||||
return await service.get_all_configs()
|
||||
|
||||
|
||||
@router.put("/config")
|
||||
async def admin_update_config(
|
||||
req: SystemConfigUpdate,
|
||||
_=Depends(get_admin_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新系统配置"""
|
||||
service = AdminService(db)
|
||||
await service.update_configs(req.configs)
|
||||
return {"success": True}
|
||||
@@ -0,0 +1,56 @@
|
||||
"""认证路由"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db
|
||||
from app.schemas.auth import RegisterRequest, LoginRequest, TokenResponse, RefreshRequest
|
||||
from app.services.auth_service import AuthService
|
||||
from app.utils.security import decode_refresh_token
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/register", response_model=TokenResponse)
|
||||
async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""用户注册"""
|
||||
service = AuthService(db)
|
||||
try:
|
||||
result = await service.register(req.username, req.email, req.password)
|
||||
return TokenResponse(
|
||||
access_token=result["access_token"],
|
||||
refresh_token=result["refresh_token"],
|
||||
user=result["user"],
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""用户登录"""
|
||||
service = AuthService(db)
|
||||
try:
|
||||
result = await service.login(req.username, req.password)
|
||||
return TokenResponse(
|
||||
access_token=result["access_token"],
|
||||
refresh_token=result["refresh_token"],
|
||||
user=result["user"],
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=dict)
|
||||
async def refresh_token(req: RefreshRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""刷新 Token"""
|
||||
payload = decode_refresh_token(req.refresh_token)
|
||||
if not payload:
|
||||
raise HTTPException(status_code=401, detail="无效的 Refresh Token")
|
||||
|
||||
service = AuthService(db)
|
||||
try:
|
||||
result = await service.refresh_token(payload.get("sub"))
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
@@ -0,0 +1,67 @@
|
||||
"""会话路由"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, get_current_user
|
||||
from app.models.user import User
|
||||
from app.schemas.conversation import ConversationCreate, ConversationRead, ConversationDetail, GroupCreate
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=list[dict])
|
||||
async def list_conversations(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取会话列表"""
|
||||
service = ConversationService(db)
|
||||
return await service.get_user_conversations(user.id)
|
||||
|
||||
|
||||
@router.post("/", response_model=dict)
|
||||
async def create_conversation(
|
||||
req: ConversationCreate,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建会话(私聊或群聊)"""
|
||||
service = ConversationService(db)
|
||||
if req.type == "private":
|
||||
if len(req.member_ids) != 1:
|
||||
raise HTTPException(status_code=400, detail="私聊只能选择一个用户")
|
||||
conv = await service.get_or_create_private(user.id, req.member_ids[0])
|
||||
else:
|
||||
conv = await service.create_group(user.id, req.name or "群聊", req.member_ids)
|
||||
|
||||
detail = await service.get_conversation_detail(conv.id, user.id)
|
||||
return detail
|
||||
|
||||
|
||||
@router.post("/group", response_model=dict)
|
||||
async def create_group(
|
||||
req: GroupCreate,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建群聊"""
|
||||
service = ConversationService(db)
|
||||
conv = await service.create_group(user.id, req.name, req.member_ids, req.description)
|
||||
detail = await service.get_conversation_detail(conv.id, user.id)
|
||||
return detail
|
||||
|
||||
|
||||
@router.get("/{conversation_id}", response_model=dict)
|
||||
async def get_conversation(
|
||||
conversation_id: str,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取会话详情"""
|
||||
service = ConversationService(db)
|
||||
detail = await service.get_conversation_detail(conversation_id, user.id)
|
||||
if not detail:
|
||||
raise HTTPException(status_code=404, detail="会话不存在或无权访问")
|
||||
return detail
|
||||
@@ -0,0 +1,88 @@
|
||||
"""好友路由"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, get_current_user
|
||||
from app.models.user import User
|
||||
from app.schemas.friend import FriendRequestCreate, FriendRead, FriendRequestRead
|
||||
from app.services.friend_service import FriendService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=list[dict])
|
||||
async def list_friends(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取好友列表"""
|
||||
service = FriendService(db)
|
||||
return await service.get_friends(user.id)
|
||||
|
||||
|
||||
@router.get("/requests", response_model=list[dict])
|
||||
async def list_requests(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取待处理的好友请求"""
|
||||
service = FriendService(db)
|
||||
return await service.get_pending_requests(user.id)
|
||||
|
||||
|
||||
@router.post("/request")
|
||||
async def send_friend_request(
|
||||
req: FriendRequestCreate,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""发送好友请求"""
|
||||
service = FriendService(db)
|
||||
try:
|
||||
await service.send_request(user.id, req.to_user_id, req.message)
|
||||
return {"success": True, "message": "好友请求已发送"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/request/{request_id}/accept")
|
||||
async def accept_friend_request(
|
||||
request_id: str,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""接受好友请求"""
|
||||
service = FriendService(db)
|
||||
try:
|
||||
await service.accept_request(request_id, user.id)
|
||||
return {"success": True, "message": "已添加好友"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/request/{request_id}/reject")
|
||||
async def reject_friend_request(
|
||||
request_id: str,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""拒绝好友请求"""
|
||||
service = FriendService(db)
|
||||
try:
|
||||
await service.reject_request(request_id, user.id)
|
||||
return {"success": True, "message": "已拒绝"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{friend_id}")
|
||||
async def remove_friend(
|
||||
friend_id: str,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除好友"""
|
||||
service = FriendService(db)
|
||||
await service.remove_friend(user.id, friend_id)
|
||||
return {"success": True, "message": "已删除好友"}
|
||||
@@ -0,0 +1,56 @@
|
||||
"""消息路由"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, get_current_user
|
||||
from app.models.user import User
|
||||
from app.schemas.message import MessageSend, MessagePage, MarkReadRequest
|
||||
from app.services.message_service import MessageService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{conversation_id}/messages", response_model=dict)
|
||||
async def get_messages(
|
||||
conversation_id: str,
|
||||
before: str | None = Query(None),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取消息列表(游标分页)"""
|
||||
service = MessageService(db)
|
||||
try:
|
||||
return await service.get_messages(conversation_id, user.id, before, limit)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{conversation_id}/messages/{message_id}/read")
|
||||
async def mark_as_read(
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""标记消息已读"""
|
||||
service = MessageService(db)
|
||||
await service.mark_as_read(conversation_id, user.id, message_id)
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@router.delete("/{conversation_id}/messages/{message_id}")
|
||||
async def delete_message(
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除消息"""
|
||||
service = MessageService(db)
|
||||
try:
|
||||
await service.soft_delete(message_id, user.id)
|
||||
return {"success": True}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
@@ -0,0 +1,63 @@
|
||||
"""文件上传路由"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
from app.config import settings
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user import User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/avatar")
|
||||
async def upload_avatar(
|
||||
file: UploadFile = File(...),
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""上传头像"""
|
||||
if not file.content_type or not file.content_type.startswith("image/"):
|
||||
raise HTTPException(status_code=400, detail="只能上传图片文件")
|
||||
|
||||
# 检查文件大小
|
||||
contents = await file.read()
|
||||
max_size = settings.MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
if len(contents) > max_size:
|
||||
raise HTTPException(status_code=400, detail=f"文件大小超过 {settings.MAX_UPLOAD_SIZE_MB}MB")
|
||||
|
||||
# 保存文件
|
||||
ext = os.path.splitext(file.filename or "image.jpg")[1]
|
||||
filename = f"avatar_{user.id}{ext}"
|
||||
filepath = os.path.join(settings.UPLOAD_DIR, filename)
|
||||
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(contents)
|
||||
|
||||
return {"url": f"/uploads/{filename}"}
|
||||
|
||||
|
||||
@router.post("/file")
|
||||
async def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""上传文件(聊天中使用)"""
|
||||
contents = await file.read()
|
||||
max_size = settings.MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
if len(contents) > max_size:
|
||||
raise HTTPException(status_code=400, detail=f"文件大小超过 {settings.MAX_UPLOAD_SIZE_MB}MB")
|
||||
|
||||
ext = os.path.splitext(file.filename or "file")[1]
|
||||
filename = f"{uuid.uuid4().hex}{ext}"
|
||||
filepath = os.path.join(settings.UPLOAD_DIR, filename)
|
||||
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(contents)
|
||||
|
||||
return {
|
||||
"url": f"/uploads/{filename}",
|
||||
"filename": file.filename,
|
||||
"size": len(contents),
|
||||
"content_type": file.content_type,
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
"""用户路由"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, get_current_user
|
||||
from app.models.user import User
|
||||
from app.schemas.user import UserRead, UserProfile, UserUpdate, UserSearchResult
|
||||
from app.services.user_service import UserService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserRead)
|
||||
async def get_me(user: User = Depends(get_current_user)):
|
||||
"""获取当前用户信息"""
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/me", response_model=UserRead)
|
||||
async def update_me(
|
||||
req: UserUpdate,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新当前用户信息"""
|
||||
service = UserService(db)
|
||||
updated = await service.update_profile(user.id, **req.model_dump(exclude_none=True))
|
||||
return updated
|
||||
|
||||
|
||||
@router.get("/search", response_model=list[UserSearchResult])
|
||||
async def search_users(
|
||||
q: str = Query(..., min_length=1),
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""搜索用户"""
|
||||
service = UserService(db)
|
||||
users = await service.search_users(q, user.id)
|
||||
return users
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserProfile)
|
||||
async def get_user(
|
||||
user_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取用户公开信息"""
|
||||
service = UserService(db)
|
||||
user = await service.get_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
return user
|
||||
@@ -0,0 +1 @@
|
||||
"""Pydantic Schema 包"""
|
||||
@@ -0,0 +1,66 @@
|
||||
"""管理后台 Schema"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AdminLoginRequest(BaseModel):
|
||||
password: str
|
||||
|
||||
|
||||
class AdminLoginResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class DashboardStats(BaseModel):
|
||||
total_users: int
|
||||
online_users: int
|
||||
total_messages: int
|
||||
today_messages: int
|
||||
total_conversations: int
|
||||
new_users_7d: int
|
||||
|
||||
|
||||
class TrendDataPoint(BaseModel):
|
||||
date: str
|
||||
value: int
|
||||
|
||||
|
||||
class UserAdminRead(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
email: str
|
||||
avatar_url: str | None = None
|
||||
status: str
|
||||
is_banned: bool
|
||||
banned_reason: str | None = None
|
||||
last_seen_at: datetime | None = None
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UserBanRequest(BaseModel):
|
||||
is_banned: bool
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class SystemConfigRead(BaseModel):
|
||||
key: str
|
||||
value: str
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class SystemConfigUpdate(BaseModel):
|
||||
configs: dict[str, str]
|
||||
|
||||
|
||||
class AdminMessageFilter(BaseModel):
|
||||
user_id: str | None = None
|
||||
conversation_id: str | None = None
|
||||
keyword: str | None = None
|
||||
date_from: str | None = None
|
||||
date_to: str | None = None
|
||||
@@ -0,0 +1,34 @@
|
||||
"""认证相关 Schema"""
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
username: str = Field(..., min_length=2, max_length=50)
|
||||
email: EmailStr
|
||||
password: str = Field(..., min_length=6, max_length=100)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
user: "UserBrief"
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class UserBrief(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
avatar_url: str | None = None
|
||||
is_admin: bool = False
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
@@ -0,0 +1,21 @@
|
||||
"""通用 Schema"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SuccessResponse(BaseModel):
|
||||
success: bool = True
|
||||
message: str = "操作成功"
|
||||
|
||||
|
||||
class PageParams(BaseModel):
|
||||
page: int = 1
|
||||
page_size: int = 20
|
||||
|
||||
|
||||
class PageResult(BaseModel):
|
||||
items: list
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
has_more: bool
|
||||
@@ -0,0 +1,53 @@
|
||||
"""会话相关 Schema"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ConversationCreate(BaseModel):
|
||||
type: str = Field(..., pattern="^(private|group)$")
|
||||
name: str | None = Field(None, max_length=100)
|
||||
member_ids: list[str] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class ConversationRead(BaseModel):
|
||||
id: str
|
||||
type: str
|
||||
name: str | None = None
|
||||
avatar_url: str | None = None
|
||||
description: str | None = None
|
||||
last_message_preview: str | None = None
|
||||
last_message_at: datetime | None = None
|
||||
created_at: datetime
|
||||
unread_count: int = 0
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ConversationDetail(ConversationRead):
|
||||
members: list["MemberRead"] = []
|
||||
|
||||
|
||||
class MemberRead(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
username: str
|
||||
nickname: str | None = None
|
||||
avatar_url: str | None = None
|
||||
role: str = "member"
|
||||
joined_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ConversationUpdate(BaseModel):
|
||||
name: str | None = Field(None, max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
avatar_url: str | None = None
|
||||
|
||||
|
||||
class GroupCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
member_ids: list[str] = Field(..., min_length=1)
|
||||
@@ -0,0 +1,36 @@
|
||||
"""好友相关 Schema"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class FriendRequestCreate(BaseModel):
|
||||
to_user_id: str
|
||||
message: str | None = None
|
||||
|
||||
|
||||
class FriendRequestRead(BaseModel):
|
||||
id: str
|
||||
from_user_id: str
|
||||
from_username: str | None = None
|
||||
from_avatar: str | None = None
|
||||
to_user_id: str
|
||||
to_username: str | None = None
|
||||
message: str | None = None
|
||||
status: str
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class FriendRead(BaseModel):
|
||||
id: str
|
||||
friend_user_id: str
|
||||
username: str
|
||||
nickname: str | None = None
|
||||
avatar_url: str | None = None
|
||||
remark: str | None = None
|
||||
status: str = "offline"
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
@@ -0,0 +1,37 @@
|
||||
"""消息相关 Schema"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MessageSend(BaseModel):
|
||||
conversation_id: str
|
||||
content: str = Field(..., min_length=1, max_length=5000)
|
||||
type: str = Field(default="text", pattern="^(text|image|file)$")
|
||||
reply_to_id: str | None = None
|
||||
|
||||
|
||||
class MessageRead(BaseModel):
|
||||
id: str
|
||||
conversation_id: str
|
||||
sender_id: str
|
||||
sender_name: str | None = None
|
||||
sender_avatar: str | None = None
|
||||
type: str
|
||||
content: str
|
||||
reply_to_id: str | None = None
|
||||
is_deleted: bool = False
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class MessagePage(BaseModel):
|
||||
messages: list[MessageRead]
|
||||
has_more: bool = False
|
||||
next_cursor: str | None = None
|
||||
|
||||
|
||||
class MarkReadRequest(BaseModel):
|
||||
message_id: str
|
||||
@@ -0,0 +1,49 @@
|
||||
"""用户相关 Schema"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class UserRead(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
email: str
|
||||
avatar_url: str | None = None
|
||||
bio: str | None = None
|
||||
status: str = "offline"
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UserProfile(BaseModel):
|
||||
"""他人可见的公开信息"""
|
||||
id: str
|
||||
username: str
|
||||
avatar_url: str | None = None
|
||||
bio: str | None = None
|
||||
status: str = "offline"
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
username: str | None = Field(None, min_length=2, max_length=50)
|
||||
bio: str | None = Field(None, max_length=200)
|
||||
avatar_url: str | None = None
|
||||
|
||||
|
||||
class PasswordChange(BaseModel):
|
||||
old_password: str
|
||||
new_password: str = Field(..., min_length=6, max_length=100)
|
||||
|
||||
|
||||
class UserSearchResult(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
avatar_url: str | None = None
|
||||
bio: str | None = None
|
||||
status: str = "offline"
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
@@ -0,0 +1 @@
|
||||
"""服务层包"""
|
||||
@@ -0,0 +1,189 @@
|
||||
"""管理后台服务"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, func, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.models.user import User
|
||||
from app.models.message import Message
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.utils.security import verify_password, hash_password, create_access_token
|
||||
|
||||
|
||||
class AdminService:
|
||||
def __init__(self, db: AsyncSession = None):
|
||||
self.db = db
|
||||
|
||||
async def init_system_config(self):
|
||||
"""初始化系统默认配置"""
|
||||
if not self.db:
|
||||
return
|
||||
defaults = {
|
||||
"platform_name": "青叶",
|
||||
"announcement": "",
|
||||
"max_upload_size_mb": "10",
|
||||
"allow_registration": "true",
|
||||
"admin_password_hash": hash_password(settings.ADMIN_PASSWORD),
|
||||
}
|
||||
for key, value in defaults.items():
|
||||
result = await self.db.execute(
|
||||
select(SystemConfig).where(SystemConfig.key == key)
|
||||
)
|
||||
if not result.scalars().first():
|
||||
self.db.add(SystemConfig(id=str(uuid.uuid4()), key=key, value=value))
|
||||
|
||||
async def login(self, password: str) -> str | None:
|
||||
"""管理员登录(仅密码)"""
|
||||
result = await self.db.execute(
|
||||
select(SystemConfig).where(SystemConfig.key == "admin_password_hash")
|
||||
)
|
||||
config = result.scalars().first()
|
||||
if not config:
|
||||
return None
|
||||
|
||||
if not verify_password(password, config.value):
|
||||
return None
|
||||
|
||||
# 生成管理员 Token
|
||||
token = create_access_token({
|
||||
"sub": "admin",
|
||||
"username": "admin",
|
||||
"is_admin": True,
|
||||
})
|
||||
return token
|
||||
|
||||
async def get_dashboard_stats(self) -> dict:
|
||||
"""获取仪表盘统计数据"""
|
||||
total_users = await self.db.execute(select(func.count(User.id)))
|
||||
online_users = await self.db.execute(
|
||||
select(func.count(User.id)).where(User.status == "online")
|
||||
)
|
||||
total_messages = await self.db.execute(select(func.count(Message.id)))
|
||||
total_conversations = await self.db.execute(select(func.count(Conversation.id)))
|
||||
|
||||
today = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
today_messages = await self.db.execute(
|
||||
select(func.count(Message.id)).where(Message.created_at >= today)
|
||||
)
|
||||
seven_days_ago = datetime.utcnow() - __import__("datetime").timedelta(days=7)
|
||||
new_users_7d = await self.db.execute(
|
||||
select(func.count(User.id)).where(User.created_at >= seven_days_ago)
|
||||
)
|
||||
|
||||
return {
|
||||
"total_users": total_users.scalar() or 0,
|
||||
"online_users": online_users.scalar() or 0,
|
||||
"total_messages": total_messages.scalar() or 0,
|
||||
"today_messages": today_messages.scalar() or 0,
|
||||
"total_conversations": total_conversations.scalar() or 0,
|
||||
"new_users_7d": new_users_7d.scalar() or 0,
|
||||
}
|
||||
|
||||
async def get_trend_data(self, metric: str, days: int = 7) -> list[dict]:
|
||||
"""获取趋势数据"""
|
||||
from sqlalchemy import cast, Date
|
||||
trends = []
|
||||
for i in range(days - 1, -1, -1):
|
||||
day = (datetime.utcnow() - __import__("datetime").timedelta(days=i)).date()
|
||||
day_start = datetime.combine(day, __import__("datetime").time.min)
|
||||
day_end = datetime.combine(day, __import__("datetime").time.max)
|
||||
|
||||
if metric == "online":
|
||||
# 简化:使用当前在线数
|
||||
count = await self.db.execute(
|
||||
select(func.count(User.id)).where(User.status == "online")
|
||||
)
|
||||
value = count.scalar() or 0
|
||||
elif metric == "messages":
|
||||
count = await self.db.execute(
|
||||
select(func.count(Message.id)).where(
|
||||
Message.created_at >= day_start,
|
||||
Message.created_at <= day_end,
|
||||
)
|
||||
)
|
||||
value = count.scalar() or 0
|
||||
elif metric == "registrations":
|
||||
count = await self.db.execute(
|
||||
select(func.count(User.id)).where(
|
||||
User.created_at >= day_start,
|
||||
User.created_at <= day_end,
|
||||
)
|
||||
)
|
||||
value = count.scalar() or 0
|
||||
else:
|
||||
value = 0
|
||||
|
||||
trends.append({"date": day.isoformat(), "value": value})
|
||||
return trends
|
||||
|
||||
async def get_users_list(self, page: int = 1, page_size: int = 20,
|
||||
search: str | None = None, status: str | None = None) -> dict:
|
||||
"""获取用户列表(管理后台)"""
|
||||
query = select(User)
|
||||
count_query = select(func.count(User.id))
|
||||
|
||||
if search:
|
||||
query = query.where(User.username.ilike(f"%{search}%"))
|
||||
count_query = count_query.where(User.username.ilike(f"%{search}%"))
|
||||
if status == "online":
|
||||
query = query.where(User.status == "online")
|
||||
count_query = count_query.where(User.status == "online")
|
||||
elif status == "banned":
|
||||
query = query.where(User.is_banned == True)
|
||||
count_query = count_query.where(User.is_banned == True)
|
||||
|
||||
total = (await self.db.execute(count_query)).scalar() or 0
|
||||
result = await self.db.execute(
|
||||
query.order_by(User.created_at.desc())
|
||||
.offset((page - 1) * page_size)
|
||||
.limit(page_size)
|
||||
)
|
||||
|
||||
users = []
|
||||
for u in result.scalars().all():
|
||||
users.append({
|
||||
"id": u.id,
|
||||
"username": u.username,
|
||||
"email": u.email,
|
||||
"avatar_url": u.avatar_url,
|
||||
"status": u.status,
|
||||
"is_banned": u.is_banned,
|
||||
"banned_reason": u.banned_reason,
|
||||
"last_seen_at": u.last_seen_at,
|
||||
"created_at": u.created_at,
|
||||
})
|
||||
|
||||
return {"items": users, "total": total, "page": page, "page_size": page_size}
|
||||
|
||||
async def ban_user(self, user_id: str, is_banned: bool, reason: str | None = None):
|
||||
"""封禁/解封用户"""
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
raise ValueError("用户不存在")
|
||||
user.is_banned = is_banned
|
||||
user.banned_reason = reason if is_banned else None
|
||||
|
||||
async def delete_user(self, user_id: str):
|
||||
"""删除用户"""
|
||||
await self.db.execute(delete(User).where(User.id == user_id))
|
||||
|
||||
async def get_all_configs(self) -> list[dict]:
|
||||
"""获取所有系统配置"""
|
||||
result = await self.db.execute(select(SystemConfig))
|
||||
return [{"key": c.key, "value": c.value} for c in result.scalars().all()]
|
||||
|
||||
async def update_configs(self, configs: dict[str, str]):
|
||||
"""更新系统配置"""
|
||||
for key, value in configs.items():
|
||||
result = await self.db.execute(
|
||||
select(SystemConfig).where(SystemConfig.key == key)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
if config:
|
||||
config.value = value
|
||||
config.updated_at = datetime.utcnow()
|
||||
@@ -0,0 +1,79 @@
|
||||
"""认证服务"""
|
||||
|
||||
import uuid
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
from app.utils.security import hash_password, verify_password, create_access_token, create_refresh_token
|
||||
|
||||
|
||||
class AuthService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def register(self, username: str, email: str, password: str) -> dict:
|
||||
"""用户注册"""
|
||||
# 检查用户名是否已存在
|
||||
result = await self.db.execute(select(User).where(User.username == username))
|
||||
if result.scalars().first():
|
||||
raise ValueError("用户名已存在")
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
result = await self.db.execute(select(User).where(User.email == email))
|
||||
if result.scalars().first():
|
||||
raise ValueError("邮箱已被注册")
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
username=username,
|
||||
email=email,
|
||||
password_hash=hash_password(password),
|
||||
)
|
||||
self.db.add(user)
|
||||
await self.db.flush()
|
||||
|
||||
# 生成 Token
|
||||
tokens = self._generate_tokens(user)
|
||||
return {**tokens, "user": user}
|
||||
|
||||
async def login(self, username: str, password: str) -> dict:
|
||||
"""用户登录"""
|
||||
result = await self.db.execute(
|
||||
select(User).where(
|
||||
(User.username == username) | (User.email == username)
|
||||
)
|
||||
)
|
||||
user = result.scalars().first()
|
||||
|
||||
if not user or not verify_password(password, user.password_hash):
|
||||
raise ValueError("用户名或密码错误")
|
||||
|
||||
if user.is_banned:
|
||||
raise ValueError("账号已被封禁")
|
||||
|
||||
# 更新在线状态
|
||||
user.status = "online"
|
||||
from datetime import datetime, timezone
|
||||
user.last_seen_at = datetime.utcnow()
|
||||
|
||||
tokens = self._generate_tokens(user)
|
||||
return {**tokens, "user": user}
|
||||
|
||||
async def refresh_token(self, user_id: str) -> dict:
|
||||
"""刷新 Token"""
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
raise ValueError("用户不存在")
|
||||
|
||||
return self._generate_tokens(user)
|
||||
|
||||
def _generate_tokens(self, user: User) -> dict:
|
||||
"""生成 JWT Token 对"""
|
||||
data = {"sub": user.id, "username": user.username}
|
||||
return {
|
||||
"access_token": create_access_token(data),
|
||||
"refresh_token": create_refresh_token(data),
|
||||
}
|
||||
@@ -0,0 +1,210 @@
|
||||
"""会话服务"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.conversation_member import ConversationMember
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class ConversationService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_or_create_private(self, user1_id: str, user2_id: str) -> Conversation:
|
||||
"""获取或创建私聊会话"""
|
||||
# 查找已有的私聊
|
||||
result = await self.db.execute(
|
||||
select(Conversation).join(ConversationMember)
|
||||
.where(
|
||||
Conversation.type == "private",
|
||||
ConversationMember.user_id == user1_id,
|
||||
)
|
||||
)
|
||||
for conv in result.scalars().all():
|
||||
member_result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conv.id,
|
||||
ConversationMember.user_id == user2_id,
|
||||
)
|
||||
)
|
||||
if member_result.scalars().first():
|
||||
return conv
|
||||
|
||||
# 创建新私聊
|
||||
conv = Conversation(id=str(uuid.uuid4()), type="private")
|
||||
self.db.add(conv)
|
||||
await self.db.flush()
|
||||
|
||||
# 添加两个成员
|
||||
self.db.add(ConversationMember(
|
||||
id=str(uuid.uuid4()), conversation_id=conv.id, user_id=user1_id, role="member"
|
||||
))
|
||||
self.db.add(ConversationMember(
|
||||
id=str(uuid.uuid4()), conversation_id=conv.id, user_id=user2_id, role="member"
|
||||
))
|
||||
return conv
|
||||
|
||||
async def create_group(self, creator_id: str, name: str, member_ids: list[str],
|
||||
description: str | None = None) -> Conversation:
|
||||
"""创建群聊"""
|
||||
conv = Conversation(
|
||||
id=str(uuid.uuid4()),
|
||||
type="group",
|
||||
name=name,
|
||||
description=description,
|
||||
creator_id=creator_id,
|
||||
)
|
||||
self.db.add(conv)
|
||||
await self.db.flush()
|
||||
|
||||
# 创建者为 owner
|
||||
self.db.add(ConversationMember(
|
||||
id=str(uuid.uuid4()), conversation_id=conv.id,
|
||||
user_id=creator_id, role="owner"
|
||||
))
|
||||
# 其他成员
|
||||
for mid in member_ids:
|
||||
if mid != creator_id:
|
||||
self.db.add(ConversationMember(
|
||||
id=str(uuid.uuid4()), conversation_id=conv.id,
|
||||
user_id=mid, role="member"
|
||||
))
|
||||
return conv
|
||||
|
||||
async def get_user_conversations(self, user_id: str) -> list[dict]:
|
||||
"""获取用户的会话列表"""
|
||||
result = await self.db.execute(
|
||||
select(ConversationMember)
|
||||
.where(ConversationMember.user_id == user_id, ConversationMember.left_at.is_(None))
|
||||
.order_by(ConversationMember.joined_at.desc())
|
||||
)
|
||||
members = result.scalars().all()
|
||||
|
||||
conversations = []
|
||||
for member in members:
|
||||
conv_result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == member.conversation_id)
|
||||
)
|
||||
conv = conv_result.scalars().first()
|
||||
if not conv:
|
||||
continue
|
||||
|
||||
# 获取未读数
|
||||
unread = await self._get_unread_count(conv.id, member.last_read_message_id)
|
||||
|
||||
# 获取显示信息
|
||||
display_name = conv.name
|
||||
display_avatar = conv.avatar_url
|
||||
|
||||
if conv.type == "private":
|
||||
other = await self._get_other_member(conv.id, user_id)
|
||||
if other:
|
||||
display_name = other.username
|
||||
display_avatar = other.avatar_url
|
||||
|
||||
conversations.append({
|
||||
"id": conv.id,
|
||||
"type": conv.type,
|
||||
"name": display_name,
|
||||
"avatar_url": display_avatar,
|
||||
"description": conv.description,
|
||||
"last_message_preview": conv.last_message_preview,
|
||||
"last_message_at": conv.last_message_at,
|
||||
"unread_count": unread,
|
||||
"created_at": conv.created_at,
|
||||
})
|
||||
|
||||
# 按最后消息时间排序
|
||||
conversations.sort(key=lambda x: x["last_message_at"] or x["created_at"], reverse=True)
|
||||
return conversations
|
||||
|
||||
async def get_conversation_detail(self, conv_id: str, user_id: str) -> dict | None:
|
||||
"""获取会话详情"""
|
||||
conv_result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == conv_id)
|
||||
)
|
||||
conv = conv_result.scalars().first()
|
||||
if not conv:
|
||||
return None
|
||||
|
||||
# 验证成员身份
|
||||
member_result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conv_id,
|
||||
ConversationMember.user_id == user_id,
|
||||
ConversationMember.left_at.is_(None),
|
||||
)
|
||||
)
|
||||
if not member_result.scalars().first():
|
||||
return None
|
||||
|
||||
# 获取所有成员
|
||||
members_result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conv_id,
|
||||
ConversationMember.left_at.is_(None),
|
||||
)
|
||||
)
|
||||
members = []
|
||||
for m in members_result.scalars().all():
|
||||
user_result = await self.db.execute(select(User).where(User.id == m.user_id))
|
||||
user = user_result.scalars().first()
|
||||
if user:
|
||||
members.append({
|
||||
"id": m.id,
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"nickname": user.bio,
|
||||
"avatar_url": user.avatar_url,
|
||||
"role": m.role,
|
||||
"joined_at": m.joined_at,
|
||||
})
|
||||
|
||||
return {
|
||||
"id": conv.id,
|
||||
"type": conv.type,
|
||||
"name": conv.name,
|
||||
"avatar_url": conv.avatar_url,
|
||||
"description": conv.description,
|
||||
"last_message_preview": conv.last_message_preview,
|
||||
"last_message_at": conv.last_message_at,
|
||||
"created_at": conv.created_at,
|
||||
"members": members,
|
||||
"unread_count": await self._get_unread_count(conv.id, None),
|
||||
}
|
||||
|
||||
async def _get_other_member(self, conv_id: str, user_id: str) -> User | None:
|
||||
"""获取私聊的对方用户"""
|
||||
result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conv_id,
|
||||
ConversationMember.user_id != user_id,
|
||||
)
|
||||
)
|
||||
member = result.scalars().first()
|
||||
if member:
|
||||
user_result = await self.db.execute(select(User).where(User.id == member.user_id))
|
||||
return user_result.scalars().first()
|
||||
return None
|
||||
|
||||
async def _get_unread_count(self, conv_id: str, last_read_id: str | None) -> int:
|
||||
"""计算未读消息数"""
|
||||
from app.models.message import Message
|
||||
query = select(func := __import__("sqlalchemy").func).count(Message.id).where(
|
||||
Message.conversation_id == conv_id,
|
||||
Message.is_deleted == False,
|
||||
)
|
||||
if last_read_id:
|
||||
# 获取 last_read 消息的时间
|
||||
lr = await self.db.execute(select(Message).where(Message.id == last_read_id))
|
||||
lr_msg = lr.scalars().first()
|
||||
if lr_msg:
|
||||
query = query.where(Message.created_at > lr_msg.created_at)
|
||||
result = await self.db.execute(query)
|
||||
return result.scalar() or 0
|
||||
@@ -0,0 +1,162 @@
|
||||
"""好友服务"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.friend import Friend
|
||||
from app.models.friend_request import FriendRequest
|
||||
|
||||
|
||||
class FriendService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def send_request(self, from_user_id: str, to_user_id: str,
|
||||
message: str | None = None) -> FriendRequest:
|
||||
"""发送好友请求"""
|
||||
if from_user_id == to_user_id:
|
||||
raise ValueError("不能添加自己为好友")
|
||||
|
||||
# 检查目标用户是否存在
|
||||
target = await self.db.execute(select(User).where(User.id == to_user_id))
|
||||
if not target.scalars().first():
|
||||
raise ValueError("目标用户不存在")
|
||||
|
||||
# 检查是否已是好友
|
||||
existing = await self.db.execute(
|
||||
select(Friend).where(
|
||||
Friend.user_id == from_user_id,
|
||||
Friend.friend_user_id == to_user_id,
|
||||
)
|
||||
)
|
||||
if existing.scalars().first():
|
||||
raise ValueError("已经是好友了")
|
||||
|
||||
# 检查是否有待处理的请求
|
||||
pending = await self.db.execute(
|
||||
select(FriendRequest).where(
|
||||
FriendRequest.from_user_id == from_user_id,
|
||||
FriendRequest.to_user_id == to_user_id,
|
||||
FriendRequest.status == "pending",
|
||||
)
|
||||
)
|
||||
if pending.scalars().first():
|
||||
raise ValueError("已发送过好友请求")
|
||||
|
||||
request = FriendRequest(
|
||||
id=str(uuid.uuid4()),
|
||||
from_user_id=from_user_id,
|
||||
to_user_id=to_user_id,
|
||||
message=message,
|
||||
status="pending",
|
||||
)
|
||||
self.db.add(request)
|
||||
return request
|
||||
|
||||
async def accept_request(self, request_id: str, user_id: str):
|
||||
"""接受好友请求"""
|
||||
result = await self.db.execute(
|
||||
select(FriendRequest).where(FriendRequest.id == request_id)
|
||||
)
|
||||
request = result.scalars().first()
|
||||
if not request:
|
||||
raise ValueError("请求不存在")
|
||||
if request.to_user_id != user_id:
|
||||
raise ValueError("无权操作此请求")
|
||||
if request.status != "pending":
|
||||
raise ValueError("该请求已处理")
|
||||
|
||||
request.status = "accepted"
|
||||
request.responded_at = datetime.utcnow()
|
||||
|
||||
# 创建双向好友关系
|
||||
self.db.add(Friend(
|
||||
id=str(uuid.uuid4()), user_id=request.from_user_id,
|
||||
friend_user_id=request.to_user_id,
|
||||
))
|
||||
self.db.add(Friend(
|
||||
id=str(uuid.uuid4()), user_id=request.to_user_id,
|
||||
friend_user_id=request.from_user_id,
|
||||
))
|
||||
|
||||
async def reject_request(self, request_id: str, user_id: str):
|
||||
"""拒绝好友请求"""
|
||||
result = await self.db.execute(
|
||||
select(FriendRequest).where(FriendRequest.id == request_id)
|
||||
)
|
||||
request = result.scalars().first()
|
||||
if not request:
|
||||
raise ValueError("请求不存在")
|
||||
if request.to_user_id != user_id:
|
||||
raise ValueError("无权操作此请求")
|
||||
|
||||
request.status = "rejected"
|
||||
request.responded_at = datetime.utcnow()
|
||||
|
||||
async def get_friends(self, user_id: str) -> list[dict]:
|
||||
"""获取好友列表"""
|
||||
result = await self.db.execute(
|
||||
select(Friend).where(Friend.user_id == user_id)
|
||||
)
|
||||
friends = []
|
||||
for friendship in result.scalars().all():
|
||||
user_result = await self.db.execute(
|
||||
select(User).where(User.id == friendship.friend_user_id)
|
||||
)
|
||||
user = user_result.scalars().first()
|
||||
if user:
|
||||
friends.append({
|
||||
"id": friendship.id,
|
||||
"friend_user_id": user.id,
|
||||
"username": user.username,
|
||||
"nickname": user.bio,
|
||||
"avatar_url": user.avatar_url,
|
||||
"remark": friendship.remark,
|
||||
"status": user.status,
|
||||
})
|
||||
return friends
|
||||
|
||||
async def get_pending_requests(self, user_id: str) -> list[dict]:
|
||||
"""获取待处理的好友请求"""
|
||||
result = await self.db.execute(
|
||||
select(FriendRequest).where(
|
||||
FriendRequest.to_user_id == user_id,
|
||||
FriendRequest.status == "pending",
|
||||
).order_by(FriendRequest.created_at.desc())
|
||||
)
|
||||
requests = []
|
||||
for req in result.scalars().all():
|
||||
from_user = await self.db.execute(select(User).where(User.id == req.from_user_id))
|
||||
fu = from_user.scalars().first()
|
||||
requests.append({
|
||||
"id": req.id,
|
||||
"from_user_id": req.from_user_id,
|
||||
"from_username": fu.username if fu else "未知",
|
||||
"from_avatar": fu.avatar_url if fu else None,
|
||||
"to_user_id": req.to_user_id,
|
||||
"message": req.message,
|
||||
"status": req.status,
|
||||
"created_at": req.created_at,
|
||||
})
|
||||
return requests
|
||||
|
||||
async def remove_friend(self, user_id: str, friend_id: str):
|
||||
"""删除好友"""
|
||||
await self.db.execute(
|
||||
select(Friend).where(
|
||||
Friend.user_id == user_id,
|
||||
Friend.friend_user_id == friend_id,
|
||||
)
|
||||
)
|
||||
# 删除双向关系
|
||||
from sqlalchemy import delete
|
||||
await self.db.execute(
|
||||
delete(Friend).where(
|
||||
(Friend.user_id == user_id) & (Friend.friend_user_id == friend_id) |
|
||||
(Friend.user_id == friend_id) & (Friend.friend_user_id == user_id)
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,179 @@
|
||||
"""消息服务"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, func, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.message import Message
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.conversation_member import ConversationMember
|
||||
|
||||
|
||||
class MessageService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def send_message(self, conversation_id: str, sender_id: str,
|
||||
content: str, msg_type: str = "text",
|
||||
reply_to_id: str | None = None) -> Message:
|
||||
"""发送消息"""
|
||||
message = Message(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation_id=conversation_id,
|
||||
sender_id=sender_id,
|
||||
type=msg_type,
|
||||
content=content,
|
||||
reply_to_id=reply_to_id,
|
||||
)
|
||||
self.db.add(message)
|
||||
|
||||
# 更新会话的最后消息
|
||||
conv_result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == conversation_id)
|
||||
)
|
||||
conv = conv_result.scalars().first()
|
||||
if conv:
|
||||
conv.last_message_at = datetime.utcnow()
|
||||
preview = content[:200] if len(content) > 200 else content
|
||||
conv.last_message_preview = preview
|
||||
conv.updated_at = datetime.utcnow()
|
||||
|
||||
return message
|
||||
|
||||
async def get_messages(self, conversation_id: str, user_id: str,
|
||||
before: str | None = None, limit: int = 50) -> dict:
|
||||
"""获取消息列表(游标分页)"""
|
||||
# 验证成员身份
|
||||
member_result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conversation_id,
|
||||
ConversationMember.user_id == user_id,
|
||||
)
|
||||
)
|
||||
if not member_result.scalars().first():
|
||||
raise ValueError("无权访问该会话")
|
||||
|
||||
query = (
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id, Message.is_deleted == False)
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
|
||||
# 游标分页
|
||||
if before:
|
||||
before_msg = await self.db.execute(
|
||||
select(Message).where(Message.id == before)
|
||||
)
|
||||
before_msg_obj = before_msg.scalars().first()
|
||||
if before_msg_obj:
|
||||
query = query.where(Message.created_at < before_msg_obj.created_at)
|
||||
|
||||
query = query.limit(limit + 1)
|
||||
result = await self.db.execute(query)
|
||||
messages = list(result.scalars().all())
|
||||
|
||||
has_more = len(messages) > limit
|
||||
messages = messages[:limit]
|
||||
|
||||
# 获取发送者信息
|
||||
from app.models.user import User
|
||||
message_list = []
|
||||
for msg in reversed(messages):
|
||||
sender_result = await self.db.execute(
|
||||
select(User).where(User.id == msg.sender_id)
|
||||
)
|
||||
sender = sender_result.scalars().first()
|
||||
message_list.append({
|
||||
"id": msg.id,
|
||||
"conversation_id": msg.conversation_id,
|
||||
"sender_id": msg.sender_id,
|
||||
"sender_name": sender.username if sender else "未知",
|
||||
"sender_avatar": sender.avatar_url if sender else None,
|
||||
"type": msg.type,
|
||||
"content": msg.content,
|
||||
"reply_to_id": msg.reply_to_id,
|
||||
"is_deleted": msg.is_deleted,
|
||||
"created_at": msg.created_at,
|
||||
})
|
||||
|
||||
return {
|
||||
"messages": message_list,
|
||||
"has_more": has_more,
|
||||
"next_cursor": messages[-1].id if has_more and messages else None,
|
||||
}
|
||||
|
||||
async def mark_as_read(self, conversation_id: str, user_id: str, message_id: str):
|
||||
"""标记消息已读"""
|
||||
result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conversation_id,
|
||||
ConversationMember.user_id == user_id,
|
||||
)
|
||||
)
|
||||
member = result.scalars().first()
|
||||
if member:
|
||||
member.last_read_message_id = message_id
|
||||
|
||||
async def soft_delete(self, message_id: str, user_id: str):
|
||||
"""软删除消息(仅能删除自己的)"""
|
||||
result = await self.db.execute(
|
||||
select(Message).where(Message.id == message_id, Message.sender_id == user_id)
|
||||
)
|
||||
message = result.scalars().first()
|
||||
if not message:
|
||||
raise ValueError("消息不存在或无权删除")
|
||||
message.is_deleted = True
|
||||
|
||||
async def get_total_count(self) -> int:
|
||||
"""获取消息总数"""
|
||||
result = await self.db.execute(select(func.count(Message.id)))
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_today_count(self) -> int:
|
||||
"""获取今日消息数"""
|
||||
today = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
result = await self.db.execute(
|
||||
select(func.count(Message.id)).where(Message.created_at >= today)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def search_messages(self, user_id: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
keyword: str | None = None,
|
||||
date_from: str | None = None,
|
||||
date_to: str | None = None,
|
||||
limit: int = 50) -> list[dict]:
|
||||
"""管理后台搜索消息"""
|
||||
query = select(Message).where(Message.is_deleted == False)
|
||||
|
||||
if user_id:
|
||||
query = query.where(Message.sender_id == user_id)
|
||||
if conversation_id:
|
||||
query = query.where(Message.conversation_id == conversation_id)
|
||||
if keyword:
|
||||
query = query.where(Message.content.ilike(f"%{keyword}%"))
|
||||
if date_from:
|
||||
query = query.where(Message.created_at >= date_from)
|
||||
if date_to:
|
||||
query = query.where(Message.created_at <= date_to)
|
||||
|
||||
query = query.order_by(Message.created_at.desc()).limit(limit)
|
||||
result = await self.db.execute(query)
|
||||
|
||||
from app.models.user import User
|
||||
messages = []
|
||||
for msg in result.scalars().all():
|
||||
sender = await self.db.execute(select(User).where(User.id == msg.sender_id))
|
||||
s = sender.scalars().first()
|
||||
messages.append({
|
||||
"id": msg.id,
|
||||
"conversation_id": msg.conversation_id,
|
||||
"sender_id": msg.sender_id,
|
||||
"sender_name": s.username if s else "未知",
|
||||
"type": msg.type,
|
||||
"content": msg.content[:200],
|
||||
"created_at": msg.created_at,
|
||||
})
|
||||
return messages
|
||||
@@ -0,0 +1,69 @@
|
||||
"""用户服务"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, or_, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class UserService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_by_id(self, user_id: str) -> User | None:
|
||||
"""根据 ID 获取用户"""
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_by_username(self, username: str) -> User | None:
|
||||
"""根据用户名获取用户"""
|
||||
result = await self.db.execute(select(User).where(User.username == username))
|
||||
return result.scalars().first()
|
||||
|
||||
async def search_users(self, query: str, current_user_id: str, limit: int = 20) -> list[User]:
|
||||
"""搜索用户"""
|
||||
result = await self.db.execute(
|
||||
select(User).where(
|
||||
or_(
|
||||
User.username.ilike(f"%{query}%"),
|
||||
User.email.ilike(f"%{query}%"),
|
||||
),
|
||||
User.id != current_user_id,
|
||||
User.is_banned == False,
|
||||
).limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_profile(self, user_id: str, **kwargs) -> User:
|
||||
"""更新用户资料"""
|
||||
user = await self.get_by_id(user_id)
|
||||
if not user:
|
||||
raise ValueError("用户不存在")
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if value is not None and hasattr(user, key):
|
||||
setattr(user, key, value)
|
||||
|
||||
user.updated_at = datetime.utcnow()
|
||||
return user
|
||||
|
||||
async def update_status(self, user_id: str, status: str):
|
||||
"""更新用户在线状态"""
|
||||
user = await self.get_by_id(user_id)
|
||||
if user:
|
||||
user.status = status
|
||||
user.last_seen_at = datetime.utcnow()
|
||||
|
||||
async def get_total_count(self) -> int:
|
||||
"""获取用户总数"""
|
||||
result = await self.db.execute(select(func.count(User.id)))
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_online_count(self) -> int:
|
||||
"""获取在线用户数"""
|
||||
result = await self.db.execute(
|
||||
select(func.count(User.id)).where(User.status == "online")
|
||||
)
|
||||
return result.scalar() or 0
|
||||
@@ -0,0 +1,14 @@
|
||||
"""通用工具函数"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def generate_uuid() -> str:
|
||||
"""生成 UUID"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def now_utc() -> datetime:
|
||||
"""获取当前 UTC 时间"""
|
||||
return datetime.now(timezone.utc)
|
||||
@@ -0,0 +1,67 @@
|
||||
"""安全工具:JWT Token 和密码哈希"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.config import settings
|
||||
|
||||
# 密码哈希上下文
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""密码加密"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证密码"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def create_access_token(data: dict) -> str:
|
||||
"""创建 Access Token"""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
to_encode.update({"exp": expire, "type": "access"})
|
||||
return jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm="HS256")
|
||||
|
||||
|
||||
def create_refresh_token(data: dict) -> str:
|
||||
"""创建 Refresh Token"""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
to_encode.update({"exp": expire, "type": "refresh"})
|
||||
return jwt.encode(
|
||||
to_encode, settings.JWT_REFRESH_SECRET_KEY, algorithm="HS256"
|
||||
)
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> dict | None:
|
||||
"""解码 Access Token"""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=["HS256"])
|
||||
if payload.get("type") != "access":
|
||||
return None
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
def decode_refresh_token(token: str) -> dict | None:
|
||||
"""解码 Refresh Token"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_REFRESH_SECRET_KEY, algorithms=["HS256"]
|
||||
)
|
||||
if payload.get("type") != "refresh":
|
||||
return None
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
@@ -0,0 +1 @@
|
||||
"""WebSocket 包"""
|
||||
@@ -0,0 +1,58 @@
|
||||
"""WebSocket 事件类型定义"""
|
||||
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class EventType(str, Enum):
|
||||
# 客户端 -> 服务端
|
||||
CHAT_SEND = "chat.send"
|
||||
CHAT_TYPING = "chat.typing"
|
||||
CHAT_READ = "chat.read"
|
||||
PRESENCE_UPDATE = "presence.update"
|
||||
|
||||
# 服务端 -> 客户端
|
||||
CHAT_MESSAGE = "chat.message"
|
||||
CHAT_TYPING_INDICATOR = "chat.typing"
|
||||
CHAT_READ_RECEIPT = "chat.read"
|
||||
CHAT_MESSAGE_DELETED = "chat.message_deleted"
|
||||
CONVERSATION_UPDATED = "conversation.updated"
|
||||
CONVERSATION_MEMBER_ADDED = "conversation.member_added"
|
||||
CONVERSATION_MEMBER_REMOVED = "conversation.member_removed"
|
||||
FRIEND_REQUEST = "friend.request"
|
||||
FRIEND_ACCEPTED = "friend.accepted"
|
||||
PRESENCE_ONLINE = "presence.online"
|
||||
PRESENCE_OFFLINE = "presence.offline"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class WSEvent(BaseModel):
|
||||
"""WebSocket 事件信封"""
|
||||
type: str
|
||||
data: dict
|
||||
timestamp: str | None = None
|
||||
|
||||
|
||||
class ChatSendData(BaseModel):
|
||||
"""发送消息数据"""
|
||||
conversation_id: str
|
||||
content: str
|
||||
type: str = "text"
|
||||
reply_to_id: str | None = None
|
||||
|
||||
|
||||
class ChatTypingData(BaseModel):
|
||||
"""输入中数据"""
|
||||
conversation_id: str
|
||||
|
||||
|
||||
class ChatReadData(BaseModel):
|
||||
"""已读数据"""
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
|
||||
|
||||
class PresenceUpdateData(BaseModel):
|
||||
"""在线状态更新"""
|
||||
status: str # online / offline / away
|
||||
@@ -0,0 +1,106 @@
|
||||
"""WebSocket 事件处理器"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import WebSocket
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.websocket.events import EventType
|
||||
from app.websocket.manager import manager
|
||||
|
||||
|
||||
async def handle_chat_send(ws: WebSocket, user_id: str, data: dict, db: AsyncSession):
|
||||
"""处理发送消息事件"""
|
||||
from app.services.message_service import MessageService
|
||||
service = MessageService(db)
|
||||
try:
|
||||
message = await service.send_message(
|
||||
conversation_id=data["conversation_id"],
|
||||
sender_id=user_id,
|
||||
content=data["content"],
|
||||
msg_type=data.get("type", "text"),
|
||||
reply_to_id=data.get("reply_to_id"),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
# 获取发送者信息
|
||||
from app.services.user_service import UserService
|
||||
user_service = UserService(db)
|
||||
sender = await user_service.get_by_id(user_id)
|
||||
|
||||
# 获取会话成员列表
|
||||
from app.services.conversation_service import ConversationService
|
||||
conv_service = ConversationService(db)
|
||||
detail = await conv_service.get_conversation_detail(data["conversation_id"], user_id)
|
||||
|
||||
msg_data = {
|
||||
"id": message.id,
|
||||
"conversation_id": message.conversation_id,
|
||||
"sender_id": user_id,
|
||||
"sender_name": sender.username if sender else "未知",
|
||||
"sender_avatar": sender.avatar_url if sender else None,
|
||||
"type": message.type,
|
||||
"content": message.content,
|
||||
"created_at": message.created_at.isoformat(),
|
||||
}
|
||||
|
||||
# 广播给会话中的所有成员
|
||||
if detail and "members" in detail:
|
||||
member_ids = [m["user_id"] for m in detail["members"]]
|
||||
await manager.broadcast_to_conversation(
|
||||
member_ids, EventType.CHAT_MESSAGE, msg_data
|
||||
)
|
||||
except Exception as e:
|
||||
await manager.send_to_user(user_id, EventType.ERROR, {"message": str(e)})
|
||||
|
||||
|
||||
async def handle_chat_typing(ws: WebSocket, user_id: str, data: dict, db: AsyncSession):
|
||||
"""处理输入中事件"""
|
||||
from app.services.user_service import UserService
|
||||
user_service = UserService(db)
|
||||
user = await user_service.get_by_id(user_id)
|
||||
|
||||
from app.services.conversation_service import ConversationService
|
||||
conv_service = ConversationService(db)
|
||||
detail = await conv_service.get_conversation_detail(data["conversation_id"], user_id)
|
||||
|
||||
if detail and "members" in detail:
|
||||
member_ids = [m["user_id"] for m in detail["members"]]
|
||||
await manager.broadcast_to_conversation(
|
||||
member_ids, EventType.CHAT_TYPING_INDICATOR,
|
||||
{"conversation_id": data["conversation_id"], "user_id": user_id,
|
||||
"username": user.username if user else "未知"},
|
||||
exclude_user=user_id,
|
||||
)
|
||||
|
||||
|
||||
async def handle_chat_read(ws: WebSocket, user_id: str, data: dict, db: AsyncSession):
|
||||
"""处理已读事件"""
|
||||
from app.services.message_service import MessageService
|
||||
service = MessageService(db)
|
||||
await service.mark_as_read(data["conversation_id"], user_id, data["message_id"])
|
||||
await db.commit()
|
||||
|
||||
from app.services.conversation_service import ConversationService
|
||||
conv_service = ConversationService(db)
|
||||
detail = await conv_service.get_conversation_detail(data["conversation_id"], user_id)
|
||||
|
||||
if detail and "members" in detail:
|
||||
member_ids = [m["user_id"] for m in detail["members"]]
|
||||
await manager.broadcast_to_conversation(
|
||||
member_ids, EventType.CHAT_READ_RECEIPT,
|
||||
{"conversation_id": data["conversation_id"], "user_id": user_id,
|
||||
"read_up_to": data["message_id"]},
|
||||
)
|
||||
|
||||
|
||||
async def handle_presence_update(ws: WebSocket, user_id: str, data: dict, db: AsyncSession):
|
||||
"""处理在线状态更新"""
|
||||
from app.services.user_service import UserService
|
||||
user_service = UserService(db)
|
||||
await user_service.update_status(user_id, data["status"])
|
||||
await db.commit()
|
||||
|
||||
event = EventType.PRESENCE_ONLINE if data["status"] == "online" else EventType.PRESENCE_OFFLINE
|
||||
await manager.broadcast(event, {"user_id": user_id})
|
||||
@@ -0,0 +1,76 @@
|
||||
"""WebSocket 连接管理器"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Set
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""管理所有 WebSocket 连接"""
|
||||
|
||||
def __init__(self):
|
||||
# user_id -> set of WebSocket connections (一个用户可能有多个标签页)
|
||||
self.active_connections: Dict[str, Set[WebSocket]] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, user_id: str):
|
||||
"""接受新连接"""
|
||||
await websocket.accept()
|
||||
if user_id not in self.active_connections:
|
||||
self.active_connections[user_id] = set()
|
||||
self.active_connections[user_id].add(websocket)
|
||||
|
||||
def disconnect(self, websocket: WebSocket, user_id: str):
|
||||
"""断开连接"""
|
||||
if user_id in self.active_connections:
|
||||
self.active_connections[user_id].discard(websocket)
|
||||
if not self.active_connections[user_id]:
|
||||
del self.active_connections[user_id]
|
||||
|
||||
async def send_to_user(self, user_id: str, event_type: str, data: dict):
|
||||
"""向指定用户发送事件"""
|
||||
if user_id not in self.active_connections:
|
||||
return
|
||||
|
||||
message = json.dumps({
|
||||
"type": event_type,
|
||||
"data": data,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}, ensure_ascii=False, default=str)
|
||||
|
||||
disconnected = set()
|
||||
for ws in self.active_connections[user_id]:
|
||||
try:
|
||||
await ws.send_text(message)
|
||||
except Exception:
|
||||
disconnected.add(ws)
|
||||
|
||||
# 清理断开的连接
|
||||
for ws in disconnected:
|
||||
self.active_connections[user_id].discard(ws)
|
||||
|
||||
async def broadcast_to_conversation(self, user_ids: list[str],
|
||||
event_type: str, data: dict,
|
||||
exclude_user: str | None = None):
|
||||
"""向会话中所有用户广播事件"""
|
||||
for uid in user_ids:
|
||||
if uid != exclude_user:
|
||||
await self.send_to_user(uid, event_type, data)
|
||||
|
||||
async def broadcast(self, event_type: str, data: dict):
|
||||
"""向所有在线用户广播"""
|
||||
for user_id in list(self.active_connections.keys()):
|
||||
await self.send_to_user(user_id, event_type, data)
|
||||
|
||||
def is_online(self, user_id: str) -> bool:
|
||||
"""检查用户是否在线"""
|
||||
return user_id in self.active_connections and len(self.active_connections[user_id]) > 0
|
||||
|
||||
def get_online_user_ids(self) -> list[str]:
|
||||
"""获取所有在线用户 ID"""
|
||||
return list(self.active_connections.keys())
|
||||
|
||||
|
||||
# 全局单例
|
||||
manager = ConnectionManager()
|
||||
@@ -0,0 +1,101 @@
|
||||
"""WebSocket 路由"""
|
||||
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import async_session
|
||||
from app.utils.security import decode_access_token
|
||||
from app.websocket.manager import manager
|
||||
from app.websocket.events import EventType
|
||||
from app.websocket.handlers import (
|
||||
handle_chat_send, handle_chat_typing,
|
||||
handle_chat_read, handle_presence_update,
|
||||
)
|
||||
|
||||
websocket_router = APIRouter()
|
||||
|
||||
|
||||
@websocket_router.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket, token: str = Query(None)):
|
||||
"""WebSocket 连接端点"""
|
||||
|
||||
# 验证 Token
|
||||
if not token:
|
||||
await websocket.close(code=4001, reason="Missing token")
|
||||
return
|
||||
|
||||
payload = decode_access_token(token)
|
||||
if not payload:
|
||||
await websocket.close(code=4001, reason="Invalid token")
|
||||
return
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
await websocket.close(code=4001, reason="Invalid token payload")
|
||||
return
|
||||
|
||||
# 接受连接
|
||||
await manager.connect(websocket, user_id)
|
||||
|
||||
# 更新在线状态
|
||||
async with async_session() as db:
|
||||
from app.services.user_service import UserService
|
||||
user_service = UserService(db)
|
||||
await user_service.update_status(user_id, "online")
|
||||
await db.commit()
|
||||
|
||||
# 广播上线通知
|
||||
await manager.broadcast(EventType.PRESENCE_ONLINE, {"user_id": user_id})
|
||||
print(f"🌿 用户 {user_id} 已连接 WebSocket")
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw = await websocket.receive_text()
|
||||
try:
|
||||
event = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
await manager.send_to_user(user_id, EventType.ERROR, {"message": "无效的 JSON"})
|
||||
continue
|
||||
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
# 创建新的数据库会话处理事件
|
||||
async with async_session() as db:
|
||||
handler_map = {
|
||||
EventType.CHAT_SEND: handle_chat_send,
|
||||
EventType.CHAT_TYPING: handle_chat_typing,
|
||||
EventType.CHAT_READ: handle_chat_read,
|
||||
EventType.PRESENCE_UPDATE: handle_presence_update,
|
||||
}
|
||||
|
||||
handler = handler_map.get(event_type)
|
||||
if handler:
|
||||
try:
|
||||
await handler(websocket, user_id, data, db)
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
await manager.send_to_user(
|
||||
user_id, EventType.ERROR, {"message": str(e)}
|
||||
)
|
||||
await db.rollback()
|
||||
else:
|
||||
await manager.send_to_user(
|
||||
user_id, EventType.ERROR,
|
||||
{"message": f"未知事件类型: {event_type}"}
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, user_id)
|
||||
|
||||
# 更新离线状态
|
||||
async with async_session() as db:
|
||||
from app.services.user_service import UserService
|
||||
user_service = UserService(db)
|
||||
await user_service.update_status(user_id, "offline")
|
||||
await db.commit()
|
||||
|
||||
await manager.broadcast(EventType.PRESENCE_OFFLINE, {"user_id": user_id})
|
||||
print(f"🌿 用户 {user_id} 已断开 WebSocket")
|
||||
Reference in New Issue
Block a user