首个可运行的版本
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""服务层包"""
|
||||
@@ -0,0 +1,189 @@
|
||||
"""管理后台服务"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, func, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.models.user import User
|
||||
from app.models.message import Message
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.utils.security import verify_password, hash_password, create_access_token
|
||||
|
||||
|
||||
class AdminService:
|
||||
def __init__(self, db: AsyncSession = None):
|
||||
self.db = db
|
||||
|
||||
async def init_system_config(self):
|
||||
"""初始化系统默认配置"""
|
||||
if not self.db:
|
||||
return
|
||||
defaults = {
|
||||
"platform_name": "青叶",
|
||||
"announcement": "",
|
||||
"max_upload_size_mb": "10",
|
||||
"allow_registration": "true",
|
||||
"admin_password_hash": hash_password(settings.ADMIN_PASSWORD),
|
||||
}
|
||||
for key, value in defaults.items():
|
||||
result = await self.db.execute(
|
||||
select(SystemConfig).where(SystemConfig.key == key)
|
||||
)
|
||||
if not result.scalars().first():
|
||||
self.db.add(SystemConfig(id=str(uuid.uuid4()), key=key, value=value))
|
||||
|
||||
async def login(self, password: str) -> str | None:
|
||||
"""管理员登录(仅密码)"""
|
||||
result = await self.db.execute(
|
||||
select(SystemConfig).where(SystemConfig.key == "admin_password_hash")
|
||||
)
|
||||
config = result.scalars().first()
|
||||
if not config:
|
||||
return None
|
||||
|
||||
if not verify_password(password, config.value):
|
||||
return None
|
||||
|
||||
# 生成管理员 Token
|
||||
token = create_access_token({
|
||||
"sub": "admin",
|
||||
"username": "admin",
|
||||
"is_admin": True,
|
||||
})
|
||||
return token
|
||||
|
||||
async def get_dashboard_stats(self) -> dict:
|
||||
"""获取仪表盘统计数据"""
|
||||
total_users = await self.db.execute(select(func.count(User.id)))
|
||||
online_users = await self.db.execute(
|
||||
select(func.count(User.id)).where(User.status == "online")
|
||||
)
|
||||
total_messages = await self.db.execute(select(func.count(Message.id)))
|
||||
total_conversations = await self.db.execute(select(func.count(Conversation.id)))
|
||||
|
||||
today = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
today_messages = await self.db.execute(
|
||||
select(func.count(Message.id)).where(Message.created_at >= today)
|
||||
)
|
||||
seven_days_ago = datetime.utcnow() - __import__("datetime").timedelta(days=7)
|
||||
new_users_7d = await self.db.execute(
|
||||
select(func.count(User.id)).where(User.created_at >= seven_days_ago)
|
||||
)
|
||||
|
||||
return {
|
||||
"total_users": total_users.scalar() or 0,
|
||||
"online_users": online_users.scalar() or 0,
|
||||
"total_messages": total_messages.scalar() or 0,
|
||||
"today_messages": today_messages.scalar() or 0,
|
||||
"total_conversations": total_conversations.scalar() or 0,
|
||||
"new_users_7d": new_users_7d.scalar() or 0,
|
||||
}
|
||||
|
||||
async def get_trend_data(self, metric: str, days: int = 7) -> list[dict]:
|
||||
"""获取趋势数据"""
|
||||
from sqlalchemy import cast, Date
|
||||
trends = []
|
||||
for i in range(days - 1, -1, -1):
|
||||
day = (datetime.utcnow() - __import__("datetime").timedelta(days=i)).date()
|
||||
day_start = datetime.combine(day, __import__("datetime").time.min)
|
||||
day_end = datetime.combine(day, __import__("datetime").time.max)
|
||||
|
||||
if metric == "online":
|
||||
# 简化:使用当前在线数
|
||||
count = await self.db.execute(
|
||||
select(func.count(User.id)).where(User.status == "online")
|
||||
)
|
||||
value = count.scalar() or 0
|
||||
elif metric == "messages":
|
||||
count = await self.db.execute(
|
||||
select(func.count(Message.id)).where(
|
||||
Message.created_at >= day_start,
|
||||
Message.created_at <= day_end,
|
||||
)
|
||||
)
|
||||
value = count.scalar() or 0
|
||||
elif metric == "registrations":
|
||||
count = await self.db.execute(
|
||||
select(func.count(User.id)).where(
|
||||
User.created_at >= day_start,
|
||||
User.created_at <= day_end,
|
||||
)
|
||||
)
|
||||
value = count.scalar() or 0
|
||||
else:
|
||||
value = 0
|
||||
|
||||
trends.append({"date": day.isoformat(), "value": value})
|
||||
return trends
|
||||
|
||||
async def get_users_list(self, page: int = 1, page_size: int = 20,
|
||||
search: str | None = None, status: str | None = None) -> dict:
|
||||
"""获取用户列表(管理后台)"""
|
||||
query = select(User)
|
||||
count_query = select(func.count(User.id))
|
||||
|
||||
if search:
|
||||
query = query.where(User.username.ilike(f"%{search}%"))
|
||||
count_query = count_query.where(User.username.ilike(f"%{search}%"))
|
||||
if status == "online":
|
||||
query = query.where(User.status == "online")
|
||||
count_query = count_query.where(User.status == "online")
|
||||
elif status == "banned":
|
||||
query = query.where(User.is_banned == True)
|
||||
count_query = count_query.where(User.is_banned == True)
|
||||
|
||||
total = (await self.db.execute(count_query)).scalar() or 0
|
||||
result = await self.db.execute(
|
||||
query.order_by(User.created_at.desc())
|
||||
.offset((page - 1) * page_size)
|
||||
.limit(page_size)
|
||||
)
|
||||
|
||||
users = []
|
||||
for u in result.scalars().all():
|
||||
users.append({
|
||||
"id": u.id,
|
||||
"username": u.username,
|
||||
"email": u.email,
|
||||
"avatar_url": u.avatar_url,
|
||||
"status": u.status,
|
||||
"is_banned": u.is_banned,
|
||||
"banned_reason": u.banned_reason,
|
||||
"last_seen_at": u.last_seen_at,
|
||||
"created_at": u.created_at,
|
||||
})
|
||||
|
||||
return {"items": users, "total": total, "page": page, "page_size": page_size}
|
||||
|
||||
async def ban_user(self, user_id: str, is_banned: bool, reason: str | None = None):
|
||||
"""封禁/解封用户"""
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
raise ValueError("用户不存在")
|
||||
user.is_banned = is_banned
|
||||
user.banned_reason = reason if is_banned else None
|
||||
|
||||
async def delete_user(self, user_id: str):
|
||||
"""删除用户"""
|
||||
await self.db.execute(delete(User).where(User.id == user_id))
|
||||
|
||||
async def get_all_configs(self) -> list[dict]:
|
||||
"""获取所有系统配置"""
|
||||
result = await self.db.execute(select(SystemConfig))
|
||||
return [{"key": c.key, "value": c.value} for c in result.scalars().all()]
|
||||
|
||||
async def update_configs(self, configs: dict[str, str]):
|
||||
"""更新系统配置"""
|
||||
for key, value in configs.items():
|
||||
result = await self.db.execute(
|
||||
select(SystemConfig).where(SystemConfig.key == key)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
if config:
|
||||
config.value = value
|
||||
config.updated_at = datetime.utcnow()
|
||||
@@ -0,0 +1,79 @@
|
||||
"""认证服务"""
|
||||
|
||||
import uuid
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
from app.utils.security import hash_password, verify_password, create_access_token, create_refresh_token
|
||||
|
||||
|
||||
class AuthService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def register(self, username: str, email: str, password: str) -> dict:
|
||||
"""用户注册"""
|
||||
# 检查用户名是否已存在
|
||||
result = await self.db.execute(select(User).where(User.username == username))
|
||||
if result.scalars().first():
|
||||
raise ValueError("用户名已存在")
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
result = await self.db.execute(select(User).where(User.email == email))
|
||||
if result.scalars().first():
|
||||
raise ValueError("邮箱已被注册")
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
username=username,
|
||||
email=email,
|
||||
password_hash=hash_password(password),
|
||||
)
|
||||
self.db.add(user)
|
||||
await self.db.flush()
|
||||
|
||||
# 生成 Token
|
||||
tokens = self._generate_tokens(user)
|
||||
return {**tokens, "user": user}
|
||||
|
||||
async def login(self, username: str, password: str) -> dict:
|
||||
"""用户登录"""
|
||||
result = await self.db.execute(
|
||||
select(User).where(
|
||||
(User.username == username) | (User.email == username)
|
||||
)
|
||||
)
|
||||
user = result.scalars().first()
|
||||
|
||||
if not user or not verify_password(password, user.password_hash):
|
||||
raise ValueError("用户名或密码错误")
|
||||
|
||||
if user.is_banned:
|
||||
raise ValueError("账号已被封禁")
|
||||
|
||||
# 更新在线状态
|
||||
user.status = "online"
|
||||
from datetime import datetime, timezone
|
||||
user.last_seen_at = datetime.utcnow()
|
||||
|
||||
tokens = self._generate_tokens(user)
|
||||
return {**tokens, "user": user}
|
||||
|
||||
async def refresh_token(self, user_id: str) -> dict:
|
||||
"""刷新 Token"""
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
raise ValueError("用户不存在")
|
||||
|
||||
return self._generate_tokens(user)
|
||||
|
||||
def _generate_tokens(self, user: User) -> dict:
|
||||
"""生成 JWT Token 对"""
|
||||
data = {"sub": user.id, "username": user.username}
|
||||
return {
|
||||
"access_token": create_access_token(data),
|
||||
"refresh_token": create_refresh_token(data),
|
||||
}
|
||||
@@ -0,0 +1,210 @@
|
||||
"""会话服务"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.conversation_member import ConversationMember
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class ConversationService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_or_create_private(self, user1_id: str, user2_id: str) -> Conversation:
|
||||
"""获取或创建私聊会话"""
|
||||
# 查找已有的私聊
|
||||
result = await self.db.execute(
|
||||
select(Conversation).join(ConversationMember)
|
||||
.where(
|
||||
Conversation.type == "private",
|
||||
ConversationMember.user_id == user1_id,
|
||||
)
|
||||
)
|
||||
for conv in result.scalars().all():
|
||||
member_result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conv.id,
|
||||
ConversationMember.user_id == user2_id,
|
||||
)
|
||||
)
|
||||
if member_result.scalars().first():
|
||||
return conv
|
||||
|
||||
# 创建新私聊
|
||||
conv = Conversation(id=str(uuid.uuid4()), type="private")
|
||||
self.db.add(conv)
|
||||
await self.db.flush()
|
||||
|
||||
# 添加两个成员
|
||||
self.db.add(ConversationMember(
|
||||
id=str(uuid.uuid4()), conversation_id=conv.id, user_id=user1_id, role="member"
|
||||
))
|
||||
self.db.add(ConversationMember(
|
||||
id=str(uuid.uuid4()), conversation_id=conv.id, user_id=user2_id, role="member"
|
||||
))
|
||||
return conv
|
||||
|
||||
async def create_group(self, creator_id: str, name: str, member_ids: list[str],
|
||||
description: str | None = None) -> Conversation:
|
||||
"""创建群聊"""
|
||||
conv = Conversation(
|
||||
id=str(uuid.uuid4()),
|
||||
type="group",
|
||||
name=name,
|
||||
description=description,
|
||||
creator_id=creator_id,
|
||||
)
|
||||
self.db.add(conv)
|
||||
await self.db.flush()
|
||||
|
||||
# 创建者为 owner
|
||||
self.db.add(ConversationMember(
|
||||
id=str(uuid.uuid4()), conversation_id=conv.id,
|
||||
user_id=creator_id, role="owner"
|
||||
))
|
||||
# 其他成员
|
||||
for mid in member_ids:
|
||||
if mid != creator_id:
|
||||
self.db.add(ConversationMember(
|
||||
id=str(uuid.uuid4()), conversation_id=conv.id,
|
||||
user_id=mid, role="member"
|
||||
))
|
||||
return conv
|
||||
|
||||
async def get_user_conversations(self, user_id: str) -> list[dict]:
|
||||
"""获取用户的会话列表"""
|
||||
result = await self.db.execute(
|
||||
select(ConversationMember)
|
||||
.where(ConversationMember.user_id == user_id, ConversationMember.left_at.is_(None))
|
||||
.order_by(ConversationMember.joined_at.desc())
|
||||
)
|
||||
members = result.scalars().all()
|
||||
|
||||
conversations = []
|
||||
for member in members:
|
||||
conv_result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == member.conversation_id)
|
||||
)
|
||||
conv = conv_result.scalars().first()
|
||||
if not conv:
|
||||
continue
|
||||
|
||||
# 获取未读数
|
||||
unread = await self._get_unread_count(conv.id, member.last_read_message_id)
|
||||
|
||||
# 获取显示信息
|
||||
display_name = conv.name
|
||||
display_avatar = conv.avatar_url
|
||||
|
||||
if conv.type == "private":
|
||||
other = await self._get_other_member(conv.id, user_id)
|
||||
if other:
|
||||
display_name = other.username
|
||||
display_avatar = other.avatar_url
|
||||
|
||||
conversations.append({
|
||||
"id": conv.id,
|
||||
"type": conv.type,
|
||||
"name": display_name,
|
||||
"avatar_url": display_avatar,
|
||||
"description": conv.description,
|
||||
"last_message_preview": conv.last_message_preview,
|
||||
"last_message_at": conv.last_message_at,
|
||||
"unread_count": unread,
|
||||
"created_at": conv.created_at,
|
||||
})
|
||||
|
||||
# 按最后消息时间排序
|
||||
conversations.sort(key=lambda x: x["last_message_at"] or x["created_at"], reverse=True)
|
||||
return conversations
|
||||
|
||||
async def get_conversation_detail(self, conv_id: str, user_id: str) -> dict | None:
|
||||
"""获取会话详情"""
|
||||
conv_result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == conv_id)
|
||||
)
|
||||
conv = conv_result.scalars().first()
|
||||
if not conv:
|
||||
return None
|
||||
|
||||
# 验证成员身份
|
||||
member_result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conv_id,
|
||||
ConversationMember.user_id == user_id,
|
||||
ConversationMember.left_at.is_(None),
|
||||
)
|
||||
)
|
||||
if not member_result.scalars().first():
|
||||
return None
|
||||
|
||||
# 获取所有成员
|
||||
members_result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conv_id,
|
||||
ConversationMember.left_at.is_(None),
|
||||
)
|
||||
)
|
||||
members = []
|
||||
for m in members_result.scalars().all():
|
||||
user_result = await self.db.execute(select(User).where(User.id == m.user_id))
|
||||
user = user_result.scalars().first()
|
||||
if user:
|
||||
members.append({
|
||||
"id": m.id,
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"nickname": user.bio,
|
||||
"avatar_url": user.avatar_url,
|
||||
"role": m.role,
|
||||
"joined_at": m.joined_at,
|
||||
})
|
||||
|
||||
return {
|
||||
"id": conv.id,
|
||||
"type": conv.type,
|
||||
"name": conv.name,
|
||||
"avatar_url": conv.avatar_url,
|
||||
"description": conv.description,
|
||||
"last_message_preview": conv.last_message_preview,
|
||||
"last_message_at": conv.last_message_at,
|
||||
"created_at": conv.created_at,
|
||||
"members": members,
|
||||
"unread_count": await self._get_unread_count(conv.id, None),
|
||||
}
|
||||
|
||||
async def _get_other_member(self, conv_id: str, user_id: str) -> User | None:
|
||||
"""获取私聊的对方用户"""
|
||||
result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conv_id,
|
||||
ConversationMember.user_id != user_id,
|
||||
)
|
||||
)
|
||||
member = result.scalars().first()
|
||||
if member:
|
||||
user_result = await self.db.execute(select(User).where(User.id == member.user_id))
|
||||
return user_result.scalars().first()
|
||||
return None
|
||||
|
||||
async def _get_unread_count(self, conv_id: str, last_read_id: str | None) -> int:
|
||||
"""计算未读消息数"""
|
||||
from app.models.message import Message
|
||||
query = select(func := __import__("sqlalchemy").func).count(Message.id).where(
|
||||
Message.conversation_id == conv_id,
|
||||
Message.is_deleted == False,
|
||||
)
|
||||
if last_read_id:
|
||||
# 获取 last_read 消息的时间
|
||||
lr = await self.db.execute(select(Message).where(Message.id == last_read_id))
|
||||
lr_msg = lr.scalars().first()
|
||||
if lr_msg:
|
||||
query = query.where(Message.created_at > lr_msg.created_at)
|
||||
result = await self.db.execute(query)
|
||||
return result.scalar() or 0
|
||||
@@ -0,0 +1,162 @@
|
||||
"""好友服务"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.friend import Friend
|
||||
from app.models.friend_request import FriendRequest
|
||||
|
||||
|
||||
class FriendService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def send_request(self, from_user_id: str, to_user_id: str,
|
||||
message: str | None = None) -> FriendRequest:
|
||||
"""发送好友请求"""
|
||||
if from_user_id == to_user_id:
|
||||
raise ValueError("不能添加自己为好友")
|
||||
|
||||
# 检查目标用户是否存在
|
||||
target = await self.db.execute(select(User).where(User.id == to_user_id))
|
||||
if not target.scalars().first():
|
||||
raise ValueError("目标用户不存在")
|
||||
|
||||
# 检查是否已是好友
|
||||
existing = await self.db.execute(
|
||||
select(Friend).where(
|
||||
Friend.user_id == from_user_id,
|
||||
Friend.friend_user_id == to_user_id,
|
||||
)
|
||||
)
|
||||
if existing.scalars().first():
|
||||
raise ValueError("已经是好友了")
|
||||
|
||||
# 检查是否有待处理的请求
|
||||
pending = await self.db.execute(
|
||||
select(FriendRequest).where(
|
||||
FriendRequest.from_user_id == from_user_id,
|
||||
FriendRequest.to_user_id == to_user_id,
|
||||
FriendRequest.status == "pending",
|
||||
)
|
||||
)
|
||||
if pending.scalars().first():
|
||||
raise ValueError("已发送过好友请求")
|
||||
|
||||
request = FriendRequest(
|
||||
id=str(uuid.uuid4()),
|
||||
from_user_id=from_user_id,
|
||||
to_user_id=to_user_id,
|
||||
message=message,
|
||||
status="pending",
|
||||
)
|
||||
self.db.add(request)
|
||||
return request
|
||||
|
||||
async def accept_request(self, request_id: str, user_id: str):
|
||||
"""接受好友请求"""
|
||||
result = await self.db.execute(
|
||||
select(FriendRequest).where(FriendRequest.id == request_id)
|
||||
)
|
||||
request = result.scalars().first()
|
||||
if not request:
|
||||
raise ValueError("请求不存在")
|
||||
if request.to_user_id != user_id:
|
||||
raise ValueError("无权操作此请求")
|
||||
if request.status != "pending":
|
||||
raise ValueError("该请求已处理")
|
||||
|
||||
request.status = "accepted"
|
||||
request.responded_at = datetime.utcnow()
|
||||
|
||||
# 创建双向好友关系
|
||||
self.db.add(Friend(
|
||||
id=str(uuid.uuid4()), user_id=request.from_user_id,
|
||||
friend_user_id=request.to_user_id,
|
||||
))
|
||||
self.db.add(Friend(
|
||||
id=str(uuid.uuid4()), user_id=request.to_user_id,
|
||||
friend_user_id=request.from_user_id,
|
||||
))
|
||||
|
||||
async def reject_request(self, request_id: str, user_id: str):
|
||||
"""拒绝好友请求"""
|
||||
result = await self.db.execute(
|
||||
select(FriendRequest).where(FriendRequest.id == request_id)
|
||||
)
|
||||
request = result.scalars().first()
|
||||
if not request:
|
||||
raise ValueError("请求不存在")
|
||||
if request.to_user_id != user_id:
|
||||
raise ValueError("无权操作此请求")
|
||||
|
||||
request.status = "rejected"
|
||||
request.responded_at = datetime.utcnow()
|
||||
|
||||
async def get_friends(self, user_id: str) -> list[dict]:
|
||||
"""获取好友列表"""
|
||||
result = await self.db.execute(
|
||||
select(Friend).where(Friend.user_id == user_id)
|
||||
)
|
||||
friends = []
|
||||
for friendship in result.scalars().all():
|
||||
user_result = await self.db.execute(
|
||||
select(User).where(User.id == friendship.friend_user_id)
|
||||
)
|
||||
user = user_result.scalars().first()
|
||||
if user:
|
||||
friends.append({
|
||||
"id": friendship.id,
|
||||
"friend_user_id": user.id,
|
||||
"username": user.username,
|
||||
"nickname": user.bio,
|
||||
"avatar_url": user.avatar_url,
|
||||
"remark": friendship.remark,
|
||||
"status": user.status,
|
||||
})
|
||||
return friends
|
||||
|
||||
async def get_pending_requests(self, user_id: str) -> list[dict]:
|
||||
"""获取待处理的好友请求"""
|
||||
result = await self.db.execute(
|
||||
select(FriendRequest).where(
|
||||
FriendRequest.to_user_id == user_id,
|
||||
FriendRequest.status == "pending",
|
||||
).order_by(FriendRequest.created_at.desc())
|
||||
)
|
||||
requests = []
|
||||
for req in result.scalars().all():
|
||||
from_user = await self.db.execute(select(User).where(User.id == req.from_user_id))
|
||||
fu = from_user.scalars().first()
|
||||
requests.append({
|
||||
"id": req.id,
|
||||
"from_user_id": req.from_user_id,
|
||||
"from_username": fu.username if fu else "未知",
|
||||
"from_avatar": fu.avatar_url if fu else None,
|
||||
"to_user_id": req.to_user_id,
|
||||
"message": req.message,
|
||||
"status": req.status,
|
||||
"created_at": req.created_at,
|
||||
})
|
||||
return requests
|
||||
|
||||
async def remove_friend(self, user_id: str, friend_id: str):
|
||||
"""删除好友"""
|
||||
await self.db.execute(
|
||||
select(Friend).where(
|
||||
Friend.user_id == user_id,
|
||||
Friend.friend_user_id == friend_id,
|
||||
)
|
||||
)
|
||||
# 删除双向关系
|
||||
from sqlalchemy import delete
|
||||
await self.db.execute(
|
||||
delete(Friend).where(
|
||||
(Friend.user_id == user_id) & (Friend.friend_user_id == friend_id) |
|
||||
(Friend.user_id == friend_id) & (Friend.friend_user_id == user_id)
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,179 @@
|
||||
"""消息服务"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, func, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.message import Message
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.conversation_member import ConversationMember
|
||||
|
||||
|
||||
class MessageService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def send_message(self, conversation_id: str, sender_id: str,
|
||||
content: str, msg_type: str = "text",
|
||||
reply_to_id: str | None = None) -> Message:
|
||||
"""发送消息"""
|
||||
message = Message(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation_id=conversation_id,
|
||||
sender_id=sender_id,
|
||||
type=msg_type,
|
||||
content=content,
|
||||
reply_to_id=reply_to_id,
|
||||
)
|
||||
self.db.add(message)
|
||||
|
||||
# 更新会话的最后消息
|
||||
conv_result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == conversation_id)
|
||||
)
|
||||
conv = conv_result.scalars().first()
|
||||
if conv:
|
||||
conv.last_message_at = datetime.utcnow()
|
||||
preview = content[:200] if len(content) > 200 else content
|
||||
conv.last_message_preview = preview
|
||||
conv.updated_at = datetime.utcnow()
|
||||
|
||||
return message
|
||||
|
||||
async def get_messages(self, conversation_id: str, user_id: str,
|
||||
before: str | None = None, limit: int = 50) -> dict:
|
||||
"""获取消息列表(游标分页)"""
|
||||
# 验证成员身份
|
||||
member_result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conversation_id,
|
||||
ConversationMember.user_id == user_id,
|
||||
)
|
||||
)
|
||||
if not member_result.scalars().first():
|
||||
raise ValueError("无权访问该会话")
|
||||
|
||||
query = (
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id, Message.is_deleted == False)
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
|
||||
# 游标分页
|
||||
if before:
|
||||
before_msg = await self.db.execute(
|
||||
select(Message).where(Message.id == before)
|
||||
)
|
||||
before_msg_obj = before_msg.scalars().first()
|
||||
if before_msg_obj:
|
||||
query = query.where(Message.created_at < before_msg_obj.created_at)
|
||||
|
||||
query = query.limit(limit + 1)
|
||||
result = await self.db.execute(query)
|
||||
messages = list(result.scalars().all())
|
||||
|
||||
has_more = len(messages) > limit
|
||||
messages = messages[:limit]
|
||||
|
||||
# 获取发送者信息
|
||||
from app.models.user import User
|
||||
message_list = []
|
||||
for msg in reversed(messages):
|
||||
sender_result = await self.db.execute(
|
||||
select(User).where(User.id == msg.sender_id)
|
||||
)
|
||||
sender = sender_result.scalars().first()
|
||||
message_list.append({
|
||||
"id": msg.id,
|
||||
"conversation_id": msg.conversation_id,
|
||||
"sender_id": msg.sender_id,
|
||||
"sender_name": sender.username if sender else "未知",
|
||||
"sender_avatar": sender.avatar_url if sender else None,
|
||||
"type": msg.type,
|
||||
"content": msg.content,
|
||||
"reply_to_id": msg.reply_to_id,
|
||||
"is_deleted": msg.is_deleted,
|
||||
"created_at": msg.created_at,
|
||||
})
|
||||
|
||||
return {
|
||||
"messages": message_list,
|
||||
"has_more": has_more,
|
||||
"next_cursor": messages[-1].id if has_more and messages else None,
|
||||
}
|
||||
|
||||
async def mark_as_read(self, conversation_id: str, user_id: str, message_id: str):
|
||||
"""标记消息已读"""
|
||||
result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conversation_id,
|
||||
ConversationMember.user_id == user_id,
|
||||
)
|
||||
)
|
||||
member = result.scalars().first()
|
||||
if member:
|
||||
member.last_read_message_id = message_id
|
||||
|
||||
async def soft_delete(self, message_id: str, user_id: str):
|
||||
"""软删除消息(仅能删除自己的)"""
|
||||
result = await self.db.execute(
|
||||
select(Message).where(Message.id == message_id, Message.sender_id == user_id)
|
||||
)
|
||||
message = result.scalars().first()
|
||||
if not message:
|
||||
raise ValueError("消息不存在或无权删除")
|
||||
message.is_deleted = True
|
||||
|
||||
async def get_total_count(self) -> int:
|
||||
"""获取消息总数"""
|
||||
result = await self.db.execute(select(func.count(Message.id)))
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_today_count(self) -> int:
|
||||
"""获取今日消息数"""
|
||||
today = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
result = await self.db.execute(
|
||||
select(func.count(Message.id)).where(Message.created_at >= today)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def search_messages(self, user_id: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
keyword: str | None = None,
|
||||
date_from: str | None = None,
|
||||
date_to: str | None = None,
|
||||
limit: int = 50) -> list[dict]:
|
||||
"""管理后台搜索消息"""
|
||||
query = select(Message).where(Message.is_deleted == False)
|
||||
|
||||
if user_id:
|
||||
query = query.where(Message.sender_id == user_id)
|
||||
if conversation_id:
|
||||
query = query.where(Message.conversation_id == conversation_id)
|
||||
if keyword:
|
||||
query = query.where(Message.content.ilike(f"%{keyword}%"))
|
||||
if date_from:
|
||||
query = query.where(Message.created_at >= date_from)
|
||||
if date_to:
|
||||
query = query.where(Message.created_at <= date_to)
|
||||
|
||||
query = query.order_by(Message.created_at.desc()).limit(limit)
|
||||
result = await self.db.execute(query)
|
||||
|
||||
from app.models.user import User
|
||||
messages = []
|
||||
for msg in result.scalars().all():
|
||||
sender = await self.db.execute(select(User).where(User.id == msg.sender_id))
|
||||
s = sender.scalars().first()
|
||||
messages.append({
|
||||
"id": msg.id,
|
||||
"conversation_id": msg.conversation_id,
|
||||
"sender_id": msg.sender_id,
|
||||
"sender_name": s.username if s else "未知",
|
||||
"type": msg.type,
|
||||
"content": msg.content[:200],
|
||||
"created_at": msg.created_at,
|
||||
})
|
||||
return messages
|
||||
@@ -0,0 +1,69 @@
|
||||
"""用户服务"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, or_, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class UserService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_by_id(self, user_id: str) -> User | None:
|
||||
"""根据 ID 获取用户"""
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_by_username(self, username: str) -> User | None:
|
||||
"""根据用户名获取用户"""
|
||||
result = await self.db.execute(select(User).where(User.username == username))
|
||||
return result.scalars().first()
|
||||
|
||||
async def search_users(self, query: str, current_user_id: str, limit: int = 20) -> list[User]:
|
||||
"""搜索用户"""
|
||||
result = await self.db.execute(
|
||||
select(User).where(
|
||||
or_(
|
||||
User.username.ilike(f"%{query}%"),
|
||||
User.email.ilike(f"%{query}%"),
|
||||
),
|
||||
User.id != current_user_id,
|
||||
User.is_banned == False,
|
||||
).limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_profile(self, user_id: str, **kwargs) -> User:
|
||||
"""更新用户资料"""
|
||||
user = await self.get_by_id(user_id)
|
||||
if not user:
|
||||
raise ValueError("用户不存在")
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if value is not None and hasattr(user, key):
|
||||
setattr(user, key, value)
|
||||
|
||||
user.updated_at = datetime.utcnow()
|
||||
return user
|
||||
|
||||
async def update_status(self, user_id: str, status: str):
|
||||
"""更新用户在线状态"""
|
||||
user = await self.get_by_id(user_id)
|
||||
if user:
|
||||
user.status = status
|
||||
user.last_seen_at = datetime.utcnow()
|
||||
|
||||
async def get_total_count(self) -> int:
|
||||
"""获取用户总数"""
|
||||
result = await self.db.execute(select(func.count(User.id)))
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_online_count(self) -> int:
|
||||
"""获取在线用户数"""
|
||||
result = await self.db.execute(
|
||||
select(func.count(User.id)).where(User.status == "online")
|
||||
)
|
||||
return result.scalar() or 0
|
||||
Reference in New Issue
Block a user