292 lines
12 KiB
Python
292 lines
12 KiB
Python
"""消息服务"""
|
|
|
|
import json
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
|
|
from sqlalchemy import select, func, and_
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.message import Message
|
|
from app.models.conversation import Conversation
|
|
from app.models.conversation_member import ConversationMember
|
|
from app.models.message_reaction import MessageReaction
|
|
|
|
|
|
class MessageService:
    """Chat message operations on top of an ``AsyncSession``.

    Covers sending, cursor-paginated listing, read markers, soft delete,
    recall, emoji reactions, and admin statistics/search. No method commits;
    the caller owns the transaction (``send_message`` and ``react`` flush).
    """

    # Characters kept for conversation-list, reply and admin-search previews.
    _PREVIEW_LENGTH = 200
    # A message can be recalled by its sender only within this many seconds.
    _RECALL_WINDOW_SECONDS = 120
    # Escape character for LIKE patterns; "/" avoids the extra quoting a
    # backslash needs in some SQL dialects.
    _LIKE_ESCAPE = "/"

    def __init__(self, db: AsyncSession):
        self.db = db

    # ------------------------------------------------------------------
    # internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a *naive* datetime.

        Same value ``datetime.utcnow()`` produced (that call is deprecated);
        naive UTC is what this service writes to timestamp columns.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @classmethod
    def _escape_like(cls, text: str) -> str:
        """Escape LIKE wildcards (``%``, ``_``) and the escape char itself in *text*."""
        esc = cls._LIKE_ESCAPE
        return (
            text.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    @staticmethod
    def _date_filter_bound(value: str, *, upper: bool) -> tuple[object, bool]:
        """Turn a ``date_from``/``date_to`` string into a comparison bound.

        Returns ``(bound, exclusive)``:

        * ``YYYY-MM-DD`` -> midnight of that day; as an upper bound, midnight of
          the *next* day with ``exclusive=True`` so the whole day is included.
        * other ISO-8601 datetimes -> ``datetime``; aware values are converted
          to naive UTC (created_at is presumed naive UTC — TODO confirm model).
        * anything unparseable -> the original string, unchanged.
        """
        from datetime import date, timedelta

        text = value.strip()
        try:
            day = date.fromisoformat(text)  # only ever accepts date-only input
        except ValueError:
            day = None
        if day is not None:
            midnight = datetime(day.year, day.month, day.day)
            if upper:
                return midnight + timedelta(days=1), True
            return midnight, False

        try:
            moment = datetime.fromisoformat(text)
        except ValueError:
            return value, False
        if moment.tzinfo is not None:
            moment = moment.astimezone(timezone.utc).replace(tzinfo=None)
        return moment, False

    async def _load_users(self, user_ids: list) -> dict:
        """Fetch users by id in one query; returns ``{user_id: User}``.

        ``None`` ids are ignored and no query is issued when nothing is left.
        """
        ids = list({uid for uid in user_ids if uid is not None})
        if not ids:
            return {}
        from app.models.user import User  # imported locally, as elsewhere in this module

        result = await self.db.execute(select(User).where(User.id.in_(ids)))
        return {u.id: u for u in result.scalars().all()}

    # ------------------------------------------------------------------
    # messaging
    # ------------------------------------------------------------------
    async def send_message(self, conversation_id: str, sender_id: str,
                           content: str, msg_type: str = "text",
                           reply_to_id: str | None = None,
                           mentioned_user_ids: list[str] | None = None) -> Message:
        """Create a message and update the conversation's last-message summary.

        Args:
            conversation_id: Target conversation.
            sender_id: Author of the message.
            content: Message body; its first ``_PREVIEW_LENGTH`` characters
                become the conversation preview.
            msg_type: Stored in ``Message.type`` (default ``"text"``).
            reply_to_id: Optional id of a quoted message in the same conversation.
            mentioned_user_ids: Optional user ids, stored JSON-encoded.

        Returns:
            The new ``Message`` (added and flushed, not committed).

        Raises:
            ValueError: the group has mute-all enabled and the sender's role is
                ``"member"``, or *reply_to_id* is not a message of this
                conversation.
        """
        conv_result = await self.db.execute(
            select(Conversation).where(Conversation.id == conversation_id)
        )
        conv = conv_result.scalars().first()

        # Mute-all in groups blocks only rows whose role is "member"; other
        # roles may still send.
        # NOTE(review): a sender with no membership row passes this check —
        # membership is presumably enforced by the caller; confirm.
        if conv and conv.type == "group" and conv.mute_all:
            member_result = await self.db.execute(
                select(ConversationMember).where(
                    ConversationMember.conversation_id == conversation_id,
                    ConversationMember.user_id == sender_id,
                )
            )
            member = member_result.scalars().first()
            if member and member.role == "member":
                raise ValueError("群已开启全员禁言")

        # A quoted message must live in the same conversation; otherwise its
        # text would be shown (via get_messages) to readers of this one.
        if reply_to_id:
            reply_result = await self.db.execute(
                select(Message.id).where(
                    Message.id == reply_to_id,
                    Message.conversation_id == conversation_id,
                )
            )
            if reply_result.first() is None:
                raise ValueError("引用的消息不存在")

        message = Message(
            id=str(uuid.uuid4()),
            conversation_id=conversation_id,
            sender_id=sender_id,
            type=msg_type,
            content=content,
            reply_to_id=reply_to_id,
            mentions=json.dumps(mentioned_user_ids) if mentioned_user_ids else None,
        )
        self.db.add(message)

        # Update the conversation summary (same row loaded above).
        if conv:
            now = self._utcnow()
            conv.last_message_at = now
            conv.last_message_preview = content[:self._PREVIEW_LENGTH]
            conv.updated_at = now

        # Flush explicitly so the INSERT (and column defaults) happens here
        # instead of depending on the autoflush of some later query.
        await self.db.flush()
        return message

    async def get_messages(self, conversation_id: str, user_id: str,
                           before: str | None = None, limit: int = 50) -> dict:
        """List messages of a conversation with cursor pagination.

        Args:
            conversation_id: Conversation to read.
            user_id: Requesting user; must be a member of the conversation.
            before: Optional cursor (a message id). When it resolves to an
                existing message, only messages older than it are returned;
                an unknown id is ignored and the newest page is returned.
            limit: Page size.

        Returns:
            ``{"messages": [...], "has_more": bool, "next_cursor": str | None}``.
            ``messages`` is ordered oldest -> newest and excludes soft-deleted
            messages. ``next_cursor`` is the id of the oldest message on the
            page (pass it back as ``before``), set only when ``has_more``.

        Raises:
            ValueError: *user_id* is not a member of the conversation.
        """
        from sqlalchemy import or_

        member_result = await self.db.execute(
            select(ConversationMember).where(
                ConversationMember.conversation_id == conversation_id,
                ConversationMember.user_id == user_id,
            )
        )
        if not member_result.scalars().first():
            raise ValueError("无权访问该会话")

        # (created_at, id) is a total order, so messages that share a
        # timestamp are neither skipped nor repeated across pages.
        query = (
            select(Message)
            .where(Message.conversation_id == conversation_id, Message.is_deleted == False)  # noqa: E712
            .order_by(Message.created_at.desc(), Message.id.desc())
        )

        if before:
            cursor_result = await self.db.execute(
                select(Message.created_at, Message.id).where(Message.id == before)
            )
            cursor = cursor_result.first()
            if cursor is not None:
                query = query.where(
                    or_(
                        Message.created_at < cursor.created_at,
                        and_(
                            Message.created_at == cursor.created_at,
                            Message.id < cursor.id,
                        ),
                    )
                )

        # Fetch one extra row to learn whether an older page exists.
        result = await self.db.execute(query.limit(limit + 1))
        messages = list(result.scalars().all())
        has_more = len(messages) > limit
        messages = messages[:limit]
        if not messages:
            return {"messages": [], "has_more": has_more, "next_cursor": None}

        # Quoted messages, batch-loaded and restricted to this conversation so
        # a stored reply_to_id pointing elsewhere never exposes foreign content.
        reply_ids = list({m.reply_to_id for m in messages if m.reply_to_id})
        replies: dict[str, Message] = {}
        if reply_ids:
            reply_result = await self.db.execute(
                select(Message).where(
                    Message.id.in_(reply_ids),
                    Message.conversation_id == conversation_id,
                )
            )
            replies = {m.id: m for m in reply_result.scalars().all()}

        reaction_result = await self.db.execute(
            select(MessageReaction).where(
                MessageReaction.message_id.in_([m.id for m in messages])
            )
        )
        reactions = reaction_result.scalars().all()

        # One query for every user referenced on this page: senders, quoted
        # senders and reacting users.
        users = await self._load_users(
            [m.sender_id for m in messages]
            + [r.sender_id for r in replies.values()]
            + [r.user_id for r in reactions]
        )

        reaction_map: dict[str, list[dict]] = {}
        for reaction in reactions:
            reactor = users.get(reaction.user_id)
            reaction_map.setdefault(reaction.message_id, []).append({
                "emoji": reaction.emoji,
                "user_id": reaction.user_id,
                "username": reactor.username if reactor else None,
            })

        message_list = []
        for msg in reversed(messages):  # query is newest-first; return oldest-first
            sender = users.get(msg.sender_id)

            reply_to_content = None
            reply_to_sender_name = None
            reply_msg = replies.get(msg.reply_to_id) if msg.reply_to_id else None
            if reply_msg is not None:
                # Don't echo quoted text its author deleted or recalled.
                hidden = reply_msg.is_deleted or reply_msg.is_recalled
                if reply_msg.content and not hidden:
                    reply_to_content = reply_msg.content[:self._PREVIEW_LENGTH]
                reply_sender = users.get(reply_msg.sender_id)
                reply_to_sender_name = reply_sender.username if reply_sender else None

            message_list.append({
                "id": msg.id,
                "conversation_id": msg.conversation_id,
                "sender_id": msg.sender_id,
                "sender_name": sender.username if sender else "未知",
                "sender_avatar": sender.avatar_url if sender else None,
                "type": msg.type,
                # NOTE(review): recalled messages still return their content
                # (only flagged via is_recalled) — confirm clients hide it.
                "content": msg.content,
                "reply_to_id": msg.reply_to_id,
                "reply_to_content": reply_to_content,
                "reply_to_sender_name": reply_to_sender_name,
                "mentions": json.loads(msg.mentions) if msg.mentions else None,
                "is_deleted": msg.is_deleted,
                "is_recalled": msg.is_recalled,
                "reactions": reaction_map.get(msg.id, []),
                "created_at": msg.created_at,
            })

        return {
            "messages": message_list,
            "has_more": has_more,
            "next_cursor": messages[-1].id if has_more else None,
        }

    async def mark_as_read(self, conversation_id: str, user_id: str, message_id: str):
        """Record *message_id* as the member's last-read message.

        Does nothing if *user_id* is not a member of the conversation.
        NOTE(review): *message_id* is not checked to belong to the conversation.
        """
        result = await self.db.execute(
            select(ConversationMember).where(
                ConversationMember.conversation_id == conversation_id,
                ConversationMember.user_id == user_id,
            )
        )
        member = result.scalars().first()
        if member:
            member.last_read_message_id = message_id

    async def soft_delete(self, message_id: str, user_id: str):
        """Soft-delete a message (``is_deleted = True``); only its sender may do so.

        Raises:
            ValueError: no message with this id was sent by *user_id*.
        """
        result = await self.db.execute(
            select(Message).where(Message.id == message_id, Message.sender_id == user_id)
        )
        message = result.scalars().first()
        if not message:
            raise ValueError("消息不存在或无权删除")
        message.is_deleted = True

    async def recall_message(self, message_id: str, user_id: str):
        """Recall a message: sender only, within ``_RECALL_WINDOW_SECONDS`` of creation.

        Sets ``is_recalled`` and ``recalled_at``; the content itself is kept.

        Raises:
            ValueError: message not found / not sent by *user_id*, already
                recalled, or the recall window has passed.
        """
        result = await self.db.execute(
            select(Message).where(Message.id == message_id, Message.sender_id == user_id)
        )
        message = result.scalars().first()
        if not message:
            raise ValueError("消息不存在或无权撤回")
        if message.is_recalled:
            raise ValueError("消息已撤回")

        created_at = message.created_at
        # Match created_at's awareness: subtracting naive from aware raises TypeError.
        now = datetime.now(timezone.utc) if created_at.tzinfo is not None else self._utcnow()
        if (now - created_at).total_seconds() > self._RECALL_WINDOW_SECONDS:
            raise ValueError("超过 2 分钟,无法撤回")
        message.is_recalled = True
        message.recalled_at = now

    async def react(self, message_id: str, user_id: str, emoji: str):
        """Toggle *user_id*'s *emoji* reaction on a message.

        Returns:
            ``{"action": "removed", "emoji": emoji}`` if the same reaction
            existed (it is deleted), otherwise ``{"action": "added", ...}``
            after inserting and flushing a new reaction.

        NOTE(review): neither the message's existence nor the user's
        membership is checked here — presumably done by the caller.
        """
        result = await self.db.execute(
            select(MessageReaction).where(
                MessageReaction.message_id == message_id,
                MessageReaction.user_id == user_id,
                MessageReaction.emoji == emoji,
            )
        )
        existing = result.scalars().first()
        if existing:
            await self.db.delete(existing)
            return {"action": "removed", "emoji": emoji}

        self.db.add(MessageReaction(
            id=str(uuid.uuid4()),
            message_id=message_id,
            user_id=user_id,
            emoji=emoji,
        ))
        await self.db.flush()
        return {"action": "added", "emoji": emoji}

    # ------------------------------------------------------------------
    # admin / statistics
    # ------------------------------------------------------------------
    async def get_total_count(self) -> int:
        """Return the total number of message rows (deleted ones included)."""
        result = await self.db.execute(select(func.count(Message.id)))
        return result.scalar() or 0

    async def get_today_count(self) -> int:
        """Return the number of messages created since today's midnight (UTC)."""
        today = self._utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
        result = await self.db.execute(
            select(func.count(Message.id)).where(Message.created_at >= today)
        )
        return result.scalar() or 0

    async def search_messages(self, user_id: str | None = None,
                              conversation_id: str | None = None,
                              keyword: str | None = None,
                              date_from: str | None = None,
                              date_to: str | None = None,
                              limit: int = 50) -> list[dict]:
        """Admin search over non-deleted messages, newest first.

        All filters are optional and combined with AND:

        * *user_id* / *conversation_id*: exact match on sender / conversation.
        * *keyword*: case-insensitive substring match; ``%`` and ``_`` are
          matched literally.
        * *date_from* / *date_to*: ISO-8601 date or datetime bounds (inclusive);
          a date-only *date_to* includes that whole day.

        Returns at most *limit* dicts; ``content`` is truncated to
        ``_PREVIEW_LENGTH`` characters.
        """
        query = select(Message).where(Message.is_deleted == False)  # noqa: E712

        if user_id:
            query = query.where(Message.sender_id == user_id)
        if conversation_id:
            query = query.where(Message.conversation_id == conversation_id)
        if keyword:
            pattern = f"%{self._escape_like(keyword)}%"
            query = query.where(Message.content.ilike(pattern, escape=self._LIKE_ESCAPE))
        if date_from:
            bound_from, _ = self._date_filter_bound(date_from, upper=False)
            query = query.where(Message.created_at >= bound_from)
        if date_to:
            bound_to, exclusive = self._date_filter_bound(date_to, upper=True)
            query = query.where(
                Message.created_at < bound_to if exclusive else Message.created_at <= bound_to
            )

        query = query.order_by(Message.created_at.desc()).limit(limit)
        result = await self.db.execute(query)
        rows = result.scalars().all()

        # One sender query for the whole page instead of one per message.
        senders = await self._load_users([m.sender_id for m in rows])

        messages = []
        for msg in rows:
            sender = senders.get(msg.sender_id)
            messages.append({
                "id": msg.id,
                "conversation_id": msg.conversation_id,
                "sender_id": msg.sender_id,
                "sender_name": sender.username if sender else "未知",
                "type": msg.type,
                "content": (
                    msg.content[:self._PREVIEW_LENGTH] if msg.content is not None else None
                ),
                "created_at": msg.created_at,
            })
        return messages