98 lines
3.5 KiB
Python
98 lines
3.5 KiB
Python
"""用户服务"""
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
from sqlalchemy import select, or_, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.user import User
|
|
from app.utils.security import hash_password, verify_password
|
|
|
|
|
|
class UserService:
    """Data-access helpers for ``User`` records.

    Every method works through the ``AsyncSession`` given to the constructor
    and mutates loaded ``User`` objects in place. None of them call
    ``flush``/``commit`` here, so persisting changes is presumably the
    caller's responsibility (NOTE(review): confirm against the request /
    unit-of-work layer).
    """

    # Fields that the generic update_profile path must never overwrite:
    # - the primary key and the moderation flag;
    # - fields that have dedicated, password-verified flows below
    #   (change_password / change_email).
    _PROTECTED_PROFILE_FIELDS = frozenset({"id", "password_hash", "email", "is_banned"})

    # Escape character for LIKE patterns. "/" is used instead of "\" because
    # backslash is itself special inside string literals on some backends.
    _LIKE_ESCAPE = "/"

    def __init__(self, db: AsyncSession) -> None:
        self.db = db

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Replaces the deprecated ``datetime.utcnow()`` (Python 3.12+). It yields
        the same value and the same naive (tz-less) type, so timestamp columns
        keep receiving what they did before.
        """
        # NOTE(review): if the model's DateTime columns are timezone-aware,
        # drop the .replace() and store the aware value instead.
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @classmethod
    def _like_pattern(cls, text: str) -> str:
        """Wrap *text* as a ``%text%`` substring pattern with LIKE wildcards escaped.

        Without escaping, ``%`` and ``_`` typed by the user would behave as
        wildcards. The escape character itself is escaped first, so the
        replacements do not interfere with each other.
        """
        esc = cls._LIKE_ESCAPE
        escaped = text.replace(esc, esc * 2).replace("%", esc + "%").replace("_", esc + "_")
        return f"%{escaped}%"

    async def _require_user(self, user_id: str) -> User:
        """Return the user with ``user_id``, or raise ``ValueError`` if absent."""
        user = await self.get_by_id(user_id)
        if not user:
            raise ValueError("用户不存在")
        return user

    async def get_by_id(self, user_id: str) -> User | None:
        """Return the user whose ``id`` equals ``user_id``, or ``None``."""
        result = await self.db.execute(select(User).where(User.id == user_id))
        return result.scalars().first()

    async def get_by_username(self, username: str) -> User | None:
        """Return the user whose ``username`` equals ``username``, or ``None``."""
        result = await self.db.execute(select(User).where(User.username == username))
        return result.scalars().first()

    async def search_users(self, query: str, current_user_id: str, limit: int = 20) -> list[User]:
        """Search users by case-insensitive substring of username, nickname or email.

        The user ``current_user_id`` and banned users are excluded. ``%`` and
        ``_`` in ``query`` are matched literally. An empty ``query`` produces
        the pattern ``%%``, which matches any non-NULL value. At most
        ``limit`` rows are returned; no ORDER BY is applied, so the order is
        not guaranteed.
        """
        pattern = self._like_pattern(query)
        esc = self._LIKE_ESCAPE
        result = await self.db.execute(
            select(User)
            .where(
                or_(
                    User.username.ilike(pattern, escape=esc),
                    User.nickname.ilike(pattern, escape=esc),
                    User.email.ilike(pattern, escape=esc),
                ),
                User.id != current_user_id,
                # This builds an SQL expression, so `==` (not `is`) is required.
                User.is_banned == False,  # noqa: E712
            )
            .limit(limit)
        )
        return list(result.scalars().all())

    async def update_profile(self, user_id: str, **kwargs) -> User:
        """Apply profile field updates to a user, bump ``updated_at``, and return it.

        A keyword argument is applied only if all of the following hold:

        - its value is not ``None`` (so ``None`` cannot clear a field);
        - its name is an existing attribute of the user object;
        - its name does not start with ``_`` (ORM internals);
        - its name is not in ``_PROTECTED_PROFILE_FIELDS``.

        Anything else is silently ignored. The protected-field rule stops this
        generic path from bypassing the password check in ``change_email`` /
        ``change_password``.

        Raises:
            ValueError: if no user has ``user_id``.
        """
        user = await self._require_user(user_id)

        for key, value in kwargs.items():
            if value is None:
                continue
            if key.startswith("_") or key in self._PROTECTED_PROFILE_FIELDS:
                continue
            if not hasattr(user, key):
                continue
            setattr(user, key, value)

        user.updated_at = self._utcnow()
        return user

    async def change_password(self, user_id: str, old_password: str, new_password: str) -> None:
        """Replace the user's password hash after verifying ``old_password``.

        Raises:
            ValueError: if the user does not exist or ``old_password`` is wrong.
        """
        user = await self._require_user(user_id)
        if not verify_password(old_password, user.password_hash):
            raise ValueError("原密码错误")
        user.password_hash = hash_password(new_password)
        user.updated_at = self._utcnow()

    async def change_email(self, user_id: str, new_email: str, password: str) -> None:
        """Change the user's bound email after verifying ``password``.

        Raises:
            ValueError: if the user does not exist, ``password`` is wrong, or
                another account already uses ``new_email``.
        """
        user = await self._require_user(user_id)
        if not verify_password(password, user.password_hash):
            raise ValueError("密码错误")

        # Check whether another account already uses this email. The
        # comparison is exact (case-sensitive) equality. This check-then-set
        # is not atomic; presumably a DB unique constraint on email is the
        # real guard against concurrent changes (NOTE(review): confirm).
        result = await self.db.execute(
            select(User).where(User.email == new_email, User.id != user_id)
        )
        if result.scalars().first():
            raise ValueError("该邮箱已被其他账号使用")

        user.email = new_email
        user.updated_at = self._utcnow()

    async def update_status(self, user_id: str, status: str) -> None:
        """Set the user's online ``status`` and ``last_seen_at``.

        This is best effort: an unknown ``user_id`` is silently ignored.
        """
        user = await self.get_by_id(user_id)
        if user:
            user.status = status
            user.last_seen_at = self._utcnow()

    async def get_total_count(self) -> int:
        """Return the total number of users (0 if the query yields NULL)."""
        result = await self.db.execute(select(func.count(User.id)))
        return result.scalar() or 0

    async def get_online_count(self) -> int:
        """Return the number of users whose ``status`` is ``"online"``."""
        result = await self.db.execute(
            select(func.count(User.id)).where(User.status == "online")
        )
        return result.scalar() or 0