"""Message service: send, list, mark-read, delete and admin-search chat messages."""
import uuid
from datetime import datetime, timezone

from sqlalchemy import select, func, and_, or_
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.message import Message
from app.models.conversation import Conversation
from app.models.conversation_member import ConversationMember

# Maximum number of characters kept in previews / truncated message content.
_PREVIEW_LEN = 200

# Escape character used for LIKE/ILIKE patterns built from user input.
_LIKE_ESCAPE = "/"


def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    Same value as the deprecated ``datetime.utcnow()``; tzinfo is stripped so
    the result compares against stored timestamps exactly as before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _escape_like(value: str) -> str:
    """Escape LIKE wildcards (``%``, ``_``) so *value* matches literally.

    Must be used together with ``escape=_LIKE_ESCAPE`` on the like/ilike call.
    """
    return (
        value.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


class MessageService:
    """Data-access service for chat messages, bound to one ``AsyncSession``.

    Methods only stage changes on the session (no flush/commit is issued
    here); presumably the caller commits the transaction.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def _load_users(self, user_ids) -> dict:
        """Fetch users by id in a single query and return ``{user_id: User}``.

        Duplicate ids are collapsed; no query is issued for an empty input.
        """
        ids = list(set(user_ids))
        if not ids:
            return {}
        # Imported locally, like the original code did (likely to avoid an import cycle).
        from app.models.user import User

        result = await self.db.execute(select(User).where(User.id.in_(ids)))
        return {u.id: u for u in result.scalars().all()}

    async def send_message(self, conversation_id: str, sender_id: str, content: str,
                           msg_type: str = "text", reply_to_id: str | None = None) -> Message:
        """Create a message and refresh the conversation's last-message summary.

        The new ``Message`` is added to the session. If the conversation row
        exists, its ``last_message_at``/``updated_at`` are set to now and
        ``last_message_preview`` to the first 200 characters of *content*.
        Sender membership is not checked here.

        Returns:
            The newly created (pending) ``Message``.
        """
        message = Message(
            id=str(uuid.uuid4()),
            conversation_id=conversation_id,
            sender_id=sender_id,
            type=msg_type,
            content=content,
            reply_to_id=reply_to_id,
        )
        self.db.add(message)

        # Update the conversation's last-message fields.
        conv_result = await self.db.execute(
            select(Conversation).where(Conversation.id == conversation_id)
        )
        conv = conv_result.scalars().first()
        if conv:
            # One timestamp so last_message_at and updated_at agree exactly.
            now = _utcnow()
            conv.last_message_at = now
            conv.last_message_preview = content[:_PREVIEW_LEN]
            conv.updated_at = now
        return message

    async def get_messages(self, conversation_id: str, user_id: str,
                           before: str | None = None, limit: int = 50) -> dict:
        """List a conversation's non-deleted messages using cursor pagination.

        Args:
            conversation_id: Conversation to read.
            user_id: Requesting user; must be a member of the conversation.
            before: Optional cursor (a message id). Only messages strictly
                older than it, ordered by ``(created_at, id)``, are returned.
                An id that does not exist is ignored (newest page returned).
            limit: Maximum number of messages in the page.

        Returns:
            ``{"messages": [...], "has_more": bool, "next_cursor": str | None}``.
            Messages are oldest-first; pass ``next_cursor`` as *before* to
            fetch the next older page.

        Raises:
            ValueError: if *user_id* is not a member of the conversation.
        """
        # Verify membership.
        member_result = await self.db.execute(
            select(ConversationMember).where(
                ConversationMember.conversation_id == conversation_id,
                ConversationMember.user_id == user_id,
            )
        )
        if not member_result.scalars().first():
            raise ValueError("无权访问该会话")

        query = (
            select(Message)
            .where(Message.conversation_id == conversation_id, Message.is_deleted == False)  # noqa: E712
            # id breaks ties so messages sharing a timestamp have a stable order.
            .order_by(Message.created_at.desc(), Message.id.desc())
        )

        # Cursor pagination.
        if before:
            cursor_result = await self.db.execute(
                select(Message).where(Message.id == before)
            )
            cursor = cursor_result.scalars().first()
            if cursor:
                # Keyset condition on (created_at, id). A bare `created_at <` would
                # drop messages that share the cursor's timestamp.
                query = query.where(
                    or_(
                        Message.created_at < cursor.created_at,
                        and_(
                            Message.created_at == cursor.created_at,
                            Message.id < cursor.id,
                        ),
                    )
                )

        # Fetch one extra row to learn whether an older page exists.
        result = await self.db.execute(query.limit(limit + 1))
        messages = list(result.scalars().all())
        has_more = len(messages) > limit
        messages = messages[:limit]

        # Batch-load senders.
        senders_map = await self._load_users(m.sender_id for m in messages)

        # Batch-load the messages being replied to.
        reply_to_ids = {m.reply_to_id for m in messages if m.reply_to_id}
        reply_msgs_map: dict[str, Message] = {}
        if reply_to_ids:
            reply_result = await self.db.execute(
                select(Message).where(Message.id.in_(list(reply_to_ids)))
            )
            reply_msgs_map = {m.id: m for m in reply_result.scalars().all()}

        # Batch-load the senders of the replied-to messages.
        reply_senders_map = await self._load_users(
            rm.sender_id for rm in reply_msgs_map.values()
        )

        message_list = []
        # The query ran newest-first; reverse so the page is oldest-first.
        for msg in reversed(messages):
            sender = senders_map.get(msg.sender_id)

            # Preview of the replied-to message, if any.
            # NOTE(review): replied-to messages are not filtered by is_deleted, so a
            # soft-deleted message's content can still surface here — confirm intent.
            reply_to_content = None
            reply_to_sender_name = None
            if msg.reply_to_id:
                reply_msg = reply_msgs_map.get(msg.reply_to_id)
                if reply_msg:
                    reply_to_content = reply_msg.content[:_PREVIEW_LEN] if reply_msg.content else None
                    reply_sender = reply_senders_map.get(reply_msg.sender_id)
                    reply_to_sender_name = reply_sender.username if reply_sender else None

            message_list.append({
                "id": msg.id,
                "conversation_id": msg.conversation_id,
                "sender_id": msg.sender_id,
                "sender_name": sender.username if sender else "未知",
                "sender_avatar": sender.avatar_url if sender else None,
                "type": msg.type,
                "content": msg.content,
                "reply_to_id": msg.reply_to_id,
                "reply_to_content": reply_to_content,
                "reply_to_sender_name": reply_to_sender_name,
                "is_deleted": msg.is_deleted,
                "created_at": msg.created_at,
            })

        return {
            "messages": message_list,
            "has_more": has_more,
            # Oldest message of this page; it becomes the cursor for the next page.
            "next_cursor": messages[-1].id if has_more and messages else None,
        }

    async def mark_as_read(self, conversation_id: str, user_id: str, message_id: str):
        """Record *message_id* as the member's last-read message.

        Silently does nothing if *user_id* is not a member of the conversation.
        """
        result = await self.db.execute(
            select(ConversationMember).where(
                ConversationMember.conversation_id == conversation_id,
                ConversationMember.user_id == user_id,
            )
        )
        member = result.scalars().first()
        if member:
            member.last_read_message_id = message_id

    async def soft_delete(self, message_id: str, user_id: str):
        """Soft-delete a message; only its sender may delete it.

        Raises:
            ValueError: if the message does not exist or was not sent by *user_id*.
        """
        result = await self.db.execute(
            select(Message).where(Message.id == message_id, Message.sender_id == user_id)
        )
        message = result.scalars().first()
        if not message:
            raise ValueError("消息不存在或无权删除")
        message.is_deleted = True

    async def get_total_count(self) -> int:
        """Return the total number of messages (soft-deleted ones included)."""
        result = await self.db.execute(select(func.count(Message.id)))
        return result.scalar() or 0

    async def get_today_count(self) -> int:
        """Return the number of messages created since 00:00 UTC today."""
        today = _utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
        result = await self.db.execute(
            select(func.count(Message.id)).where(Message.created_at >= today)
        )
        return result.scalar() or 0

    async def search_messages(self, user_id: str | None = None, conversation_id: str | None = None,
                              keyword: str | None = None, date_from: str | None = None,
                              date_to: str | None = None, limit: int = 50) -> list[dict]:
        """Admin search over non-deleted messages, newest first.

        All filters are optional and combined with AND. *keyword* is matched
        case-insensitively as a literal substring (LIKE wildcards are escaped).
        *date_from*/*date_to* are inclusive bounds passed to the query as given
        (presumably ISO-8601 strings — confirm the DB driver coerces them).

        Returns:
            Up to *limit* dicts; ``content`` is truncated to 200 characters.
        """
        query = select(Message).where(Message.is_deleted == False)  # noqa: E712
        if user_id:
            query = query.where(Message.sender_id == user_id)
        if conversation_id:
            query = query.where(Message.conversation_id == conversation_id)
        if keyword:
            query = query.where(
                Message.content.ilike(f"%{_escape_like(keyword)}%", escape=_LIKE_ESCAPE)
            )
        if date_from:
            query = query.where(Message.created_at >= date_from)
        if date_to:
            query = query.where(Message.created_at <= date_to)
        query = query.order_by(Message.created_at.desc()).limit(limit)

        result = await self.db.execute(query)
        found = list(result.scalars().all())

        # One batched user lookup instead of one query per message.
        senders_map = await self._load_users(m.sender_id for m in found)

        messages = []
        for msg in found:
            s = senders_map.get(msg.sender_id)
            messages.append({
                "id": msg.id,
                "conversation_id": msg.conversation_id,
                "sender_id": msg.sender_id,
                "sender_name": s.username if s else "未知",
                "type": msg.type,
                "content": msg.content[:_PREVIEW_LEN] if msg.content is not None else None,
                "created_at": msg.created_at,
            })
        return messages