171 lines
5.0 KiB
Python
171 lines
5.0 KiB
Python
"""用户路由"""
|
|
|
|
from datetime import datetime
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from pydantic import BaseModel
|
|
|
|
from app.dependencies import get_db, get_current_user
|
|
from app.models.user import User
|
|
from app.models.user_block import UserBlock
|
|
from app.schemas.user import (
|
|
UserRead, UserProfile, UserUpdate, UserSearchResult,
|
|
PasswordChange, EmailChange,
|
|
)
|
|
from app.services.user_service import UserService
|
|
|
|
# Router that collects the user-related endpoints defined in this module.
# NOTE(review): presumably included by the application with a URL prefix
# configured elsewhere — confirm in the app setup.
router = APIRouter()
|
|
|
|
|
|
class StatusUpdate(BaseModel):
    # Request body for the PUT /me/status endpoint (update_status).

    # Free-text status; the endpoint stores an empty string as None.
    custom_status: str | None = None

    # Emoji shown with the status; the endpoint stores an empty string as None.
    status_emoji: str | None = None

    # Hours until the status expires; None or 0 means the status never expires.
    expires_hours: int | None = None
|
|
|
|
|
|
@router.get("/me", response_model=UserRead)
async def get_me(user: User = Depends(get_current_user)):
    """Return the authenticated caller's own account record.

    The ``get_current_user`` dependency resolves the caller; the object is
    returned as-is and serialized through the ``UserRead`` response model.
    """
    current_user = user
    return current_user
|
|
|
|
|
|
@router.put("/me", response_model=UserRead)
async def update_me(
    req: UserUpdate,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update the authenticated caller's profile fields.

    Only fields whose value is not ``None`` are forwarded as keyword arguments
    to ``UserService.update_profile``, so a ``null`` value is never passed to
    the service. The service's return value is serialized as ``UserRead``.
    """
    profile_service = UserService(db)
    changes = req.model_dump(exclude_none=True)
    return await profile_service.update_profile(user.id, **changes)
|
|
|
|
|
|
@router.put("/me/password")
async def change_password(
    req: PasswordChange,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Change the authenticated caller's password.

    ``UserService.change_password`` receives the old and new password and
    reports a rejected change by raising ``ValueError``.

    Returns:
        ``{"success": True, "message": ...}`` on success.

    Raises:
        HTTPException: 400, with the service's message as detail, when the
            service raises ``ValueError``.
    """
    service = UserService(db)
    # Keep the try body to the one call that can raise ValueError.
    try:
        await service.change_password(user.id, req.old_password, req.new_password)
    except ValueError as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"success": True, "message": "密码修改成功"}
|
|
|
|
|
|
@router.put("/me/email")
async def change_email(
    req: EmailChange,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Change the email address bound to the authenticated caller's account.

    The new ``email`` and ``password`` from the request are passed to
    ``UserService.change_email``. The password is presumably the current
    account password, used for confirmation; verify this in the service. The
    service reports a rejected change by raising ``ValueError``.

    Returns:
        ``{"success": True, "message": ...}`` on success.

    Raises:
        HTTPException: 400, with the service's message as detail, when the
            service raises ``ValueError``.
    """
    service = UserService(db)
    # Keep the try body to the one call that can raise ValueError.
    try:
        await service.change_email(user.id, req.email, req.password)
    except ValueError as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"success": True, "message": "邮箱已更新"}
|
|
|
|
|
|
@router.get("/search", response_model=list[UserSearchResult])
async def search_users(
    q: str = Query(..., min_length=1),
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Search users by the non-empty query string ``q``.

    The caller's id is passed to ``UserService.search_users`` together with
    the query. How the service uses it (e.g. excluding the caller) is decided
    there. Results are serialized as a list of ``UserSearchResult``.
    """
    return await UserService(db).search_users(q, user.id)
|
|
|
|
|
|
@router.get("/{user_id}", response_model=UserProfile)
async def get_user(
    user_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Return the public profile of the user identified by ``user_id``.

    No authentication dependency is declared on this route.

    Raises:
        HTTPException: 404 when no user with that id exists.
    """
    # NOTE: FastAPI matches routes in declaration order. This catch-all GET
    # path must stay declared after the literal "/me" and "/search" routes,
    # or it would capture them.
    found = await UserService(db).get_by_id(user_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="用户不存在")
|
|
|
|
|
|
@router.put("/me/status")
async def update_status(
    req: StatusUpdate,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Set the caller's publicly visible custom status, optionally with expiry.

    Empty strings for ``custom_status`` / ``status_emoji`` are stored as
    ``None``. A positive ``expires_hours`` sets the expiry that many hours
    from now (naive UTC via ``datetime.utcnow``). ``None`` or ``0`` clears the
    expiry.

    Returns:
        The stored status fields, with ``status_expires_at`` as an ISO-8601
        string or ``None``.

    Raises:
        HTTPException: 400 when ``expires_hours`` is negative, or is so large
            that the expiry falls outside the representable ``datetime`` range
            (previously an unhandled ``OverflowError`` / HTTP 500).
    """
    from datetime import timedelta

    # Validate and compute the expiry before mutating ``user``, so a rejected
    # request leaves the ORM object untouched.
    expires_at = None
    if req.expires_hours:
        if req.expires_hours < 0:
            raise HTTPException(status_code=400, detail="过期时间必须为正数")
        try:
            # NOTE(review): datetime.utcnow() is deprecated since Python 3.12.
            # It is kept because the column presumably stores naive UTC;
            # confirm before switching to datetime.now(timezone.utc).
            expires_at = datetime.utcnow() + timedelta(hours=req.expires_hours)
        except OverflowError:
            raise HTTPException(status_code=400, detail="过期时间过长") from None

    user.custom_status = req.custom_status or None
    user.status_emoji = req.status_emoji or None
    user.status_expires_at = expires_at
    await db.flush()
    return {
        "custom_status": user.custom_status,
        "status_emoji": user.status_emoji,
        "status_expires_at": user.status_expires_at.isoformat() if user.status_expires_at else None,
    }
|
|
|
|
|
|
@router.post("/{user_id}/block")
async def block_user(
    user_id: str,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Add ``user_id`` to the caller's block list.

    If a block row for (caller, ``user_id``) already exists, nothing is added.
    The check and the insert are separate statements, so concurrent requests
    could both insert unless the table has a unique constraint (not visible
    here).

    Returns:
        ``{"success": True}``.

    Raises:
        HTTPException: 400 when the caller tries to block themselves; 404 when
            no user with ``user_id`` exists.
    """
    import uuid

    if user_id == user.id:
        raise HTTPException(status_code=400, detail="不能拉黑自己")

    # Do not create block rows pointing at nonexistent users. The lookup and
    # 404 message match the get_user endpoint.
    if not await UserService(db).get_by_id(user_id):
        raise HTTPException(status_code=404, detail="用户不存在")

    existing = await db.execute(
        select(UserBlock).where(UserBlock.blocker_id == user.id, UserBlock.blocked_id == user_id)
    )
    if not existing.scalars().first():
        db.add(UserBlock(id=str(uuid.uuid4()), blocker_id=user.id, blocked_id=user_id))
    return {"success": True}
|
|
|
|
|
|
@router.delete("/{user_id}/block")
async def unblock_user(
    user_id: str,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Remove ``user_id`` from the caller's block list.

    Issues a bulk DELETE for the (caller, ``user_id``) pair and reports
    success whether or not a matching row existed.
    """
    from sqlalchemy import delete as sql_delete

    unblock_stmt = sql_delete(UserBlock).where(
        UserBlock.blocker_id == user.id,
        UserBlock.blocked_id == user_id,
    )
    await db.execute(unblock_stmt)
    return {"success": True}
|
|
|
|
|
|
@router.get("/me/blocks")
async def list_blocks(
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the users the caller has blocked.

    Returns:
        A list of ``{"user_id", "username", "nickname"}`` dicts, one per
        block row. Block rows whose target user no longer exists are skipped.
    """
    # A single joined query replaces the former per-row SELECT (N+1 queries).
    # The inner join drops block rows with no matching user, as the old
    # ``if bu:`` check did.
    result = await db.execute(
        select(User.id, User.username, User.nickname)
        .join(UserBlock, UserBlock.blocked_id == User.id)
        .where(UserBlock.blocker_id == user.id)
    )
    return [
        {"user_id": uid, "username": username, "nickname": nickname}
        for uid, username, nickname in result.all()
    ]
|