"""消息服务""" import json import uuid from datetime import datetime, timezone from sqlalchemy import select, func, and_ from sqlalchemy.ext.asyncio import AsyncSession from app.models.message import Message from app.models.conversation import Conversation from app.models.conversation_member import ConversationMember from app.models.message_reaction import MessageReaction class MessageService: def __init__(self, db: AsyncSession): self.db = db async def send_message(self, conversation_id: str, sender_id: str, content: str, msg_type: str = "text", reply_to_id: str | None = None, mentioned_user_ids: list[str] | None = None) -> Message: """发送消息""" # 全员禁言检查(群聊,非管理员不能发) conv_result = await self.db.execute( select(Conversation).where(Conversation.id == conversation_id) ) conv = conv_result.scalars().first() if conv and conv.type == "group" and conv.mute_all: from app.models.conversation_member import ConversationMember as CM m_result = await self.db.execute( select(CM).where(CM.conversation_id == conversation_id, CM.user_id == sender_id) ) m = m_result.scalars().first() if m and m.role == "member": raise ValueError("群已开启全员禁言") import json message = Message( id=str(uuid.uuid4()), conversation_id=conversation_id, sender_id=sender_id, type=msg_type, content=content, reply_to_id=reply_to_id, mentions=json.dumps(mentioned_user_ids) if mentioned_user_ids else None, ) self.db.add(message) # 更新会话的最后消息 conv_result = await self.db.execute( select(Conversation).where(Conversation.id == conversation_id) ) conv = conv_result.scalars().first() if conv: conv.last_message_at = datetime.utcnow() preview = content[:200] if len(content) > 200 else content conv.last_message_preview = preview conv.updated_at = datetime.utcnow() return message async def get_messages(self, conversation_id: str, user_id: str, before: str | None = None, limit: int = 50) -> dict: """获取消息列表(游标分页)""" # 验证成员身份 member_result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conversation_id, ConversationMember.user_id == user_id, ) ) if not member_result.scalars().first(): raise ValueError("无权访问该会话") query = ( select(Message) .where(Message.conversation_id == conversation_id, Message.is_deleted == False) .order_by(Message.created_at.desc()) ) # 游标分页 if before: before_msg = await self.db.execute( select(Message).where(Message.id == before) ) before_msg_obj = before_msg.scalars().first() if before_msg_obj: query = query.where(Message.created_at < before_msg_obj.created_at) query = query.limit(limit + 1) result = await self.db.execute(query) messages = list(result.scalars().all()) has_more = len(messages) > limit messages = messages[:limit] # 批量预加载发送者信息 from app.models.user import User sender_ids = list(set(m.sender_id for m in messages)) senders_result = await self.db.execute( select(User).where(User.id.in_(sender_ids)) ) senders_map = {u.id: u for u in senders_result.scalars().all()} # 批量预加载被引用消息 reply_to_ids = list(set(m.reply_to_id for m in messages if m.reply_to_id)) reply_msgs_map: dict[str, Message] = {} if reply_to_ids: reply_result = await self.db.execute( select(Message).where(Message.id.in_(reply_to_ids)) ) reply_msgs_map = {m.id: m for m in reply_result.scalars().all()} # 批量预加载被引用消息的发送者 reply_sender_ids = list(set(rm.sender_id for rm in reply_msgs_map.values())) reply_senders_result = await self.db.execute( select(User).where(User.id.in_(reply_sender_ids)) ) reply_senders_map = {u.id: u for u in reply_senders_result.scalars().all()} # 批量预加载这些消息的表情回应 msg_ids = [m.id for m in messages] reaction_result = await self.db.execute( select(MessageReaction).where(MessageReaction.message_id.in_(msg_ids)) ) from app.models.user import User as U reactions_raw = reaction_result.scalars().all() reaction_user_ids = list(set(r.user_id for r in reactions_raw)) reaction_users_map: dict[str, U] = {} if reaction_user_ids: ru_result = await self.db.execute(select(U).where(U.id.in_(reaction_user_ids))) reaction_users_map = {u.id: u for u in ru_result.scalars().all()} reaction_map: dict[str, list[dict]] = {} for r in reactions_raw: u = reaction_users_map.get(r.user_id) reaction_map.setdefault(r.message_id, []).append({ "emoji": r.emoji, "user_id": r.user_id, "username": u.username if u else None, }) message_list = [] for msg in reversed(messages): sender = senders_map.get(msg.sender_id) # 获取被引用消息的信息 reply_to_content = None reply_to_sender_name = None if msg.reply_to_id: reply_msg = reply_msgs_map.get(msg.reply_to_id) if reply_msg: reply_to_content = reply_msg.content[:200] if reply_msg.content else None reply_sender = reply_senders_map.get(reply_msg.sender_id) reply_to_sender_name = reply_sender.username if reply_sender else None message_list.append({ "id": msg.id, "conversation_id": msg.conversation_id, "sender_id": msg.sender_id, "sender_name": sender.username if sender else "未知", "sender_avatar": sender.avatar_url if sender else None, "type": msg.type, "content": msg.content, "reply_to_id": msg.reply_to_id, "reply_to_content": reply_to_content, "reply_to_sender_name": reply_to_sender_name, "mentions": json.loads(msg.mentions) if msg.mentions else None, "is_deleted": msg.is_deleted, "is_recalled": msg.is_recalled, "reactions": reaction_map.get(msg.id, []), "created_at": msg.created_at, }) return { "messages": message_list, "has_more": has_more, "next_cursor": messages[-1].id if has_more and messages else None, } async def mark_as_read(self, conversation_id: str, user_id: str, message_id: str): """标记消息已读""" result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conversation_id, ConversationMember.user_id == user_id, ) ) member = result.scalars().first() if member: member.last_read_message_id = message_id async def soft_delete(self, message_id: str, user_id: str): """软删除消息(仅能删除自己的)""" result = await self.db.execute( select(Message).where(Message.id == message_id, Message.sender_id == user_id) ) message = result.scalars().first() if not message: raise ValueError("消息不存在或无权删除") message.is_deleted = True async def recall_message(self, message_id: str, user_id: str): """撤回消息(仅本人,2 分钟内)""" result = await self.db.execute( select(Message).where(Message.id == message_id, Message.sender_id == user_id) ) message = result.scalars().first() if not message: raise ValueError("消息不存在或无权撤回") if message.is_recalled: raise ValueError("消息已撤回") elapsed = (datetime.utcnow() - message.created_at).total_seconds() if elapsed > 120: raise ValueError("超过 2 分钟,无法撤回") message.is_recalled = True message.recalled_at = datetime.utcnow() async def react(self, message_id: str, user_id: str, emoji: str): """添加表情回应(toggle:已存在则取消)""" result = await self.db.execute( select(MessageReaction).where( MessageReaction.message_id == message_id, MessageReaction.user_id == user_id, MessageReaction.emoji == emoji, ) ) existing = result.scalars().first() if existing: await self.db.delete(existing) return {"action": "removed", "emoji": emoji} reaction = MessageReaction( id=str(uuid.uuid4()), message_id=message_id, user_id=user_id, emoji=emoji, ) self.db.add(reaction) await self.db.flush() return {"action": "added", "emoji": emoji} async def get_total_count(self) -> int: """获取消息总数""" result = await self.db.execute(select(func.count(Message.id))) return result.scalar() or 0 async def get_today_count(self) -> int: """获取今日消息数""" today = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) result = await self.db.execute( select(func.count(Message.id)).where(Message.created_at >= today) ) return result.scalar() or 0 async def search_messages(self, user_id: str | None = None, conversation_id: str | None = None, keyword: str | None = None, date_from: str | None = None, date_to: str | None = None, limit: int = 50) -> list[dict]: """管理后台搜索消息""" query = select(Message).where(Message.is_deleted == False) if user_id: query = query.where(Message.sender_id == user_id) if conversation_id: query = query.where(Message.conversation_id == conversation_id) if keyword: query = query.where(Message.content.ilike(f"%{keyword}%")) if date_from: query = query.where(Message.created_at >= date_from) if date_to: query = query.where(Message.created_at <= date_to) query = query.order_by(Message.created_at.desc()).limit(limit) result = await self.db.execute(query) from app.models.user import User messages = [] for msg in result.scalars().all(): sender = await self.db.execute(select(User).where(User.id == msg.sender_id)) s = sender.scalars().first() messages.append({ "id": msg.id, "conversation_id": msg.conversation_id, "sender_id": msg.sender_id, "sender_name": s.username if s else "未知", "type": msg.type, "content": msg.content[:200], "created_at": msg.created_at, }) return messages