This commit is contained in:
2026-06-14 11:16:42 +08:00
parent ca39190ad7
commit c9fc87cd89
35 changed files with 1480 additions and 18 deletions
+81 -1
View File
@@ -1,5 +1,6 @@
"""消息服务"""
import json
import uuid
from datetime import datetime, timezone
@@ -9,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.models.message import Message
from app.models.conversation import Conversation
from app.models.conversation_member import ConversationMember
from app.models.message_reaction import MessageReaction
class MessageService:
@@ -17,8 +19,24 @@ class MessageService:
async def send_message(self, conversation_id: str, sender_id: str,
content: str, msg_type: str = "text",
reply_to_id: str | None = None) -> Message:
reply_to_id: str | None = None,
mentioned_user_ids: list[str] | None = None) -> Message:
"""发送消息"""
# 全员禁言检查(群聊,非管理员不能发)
conv_result = await self.db.execute(
select(Conversation).where(Conversation.id == conversation_id)
)
conv = conv_result.scalars().first()
if conv and conv.type == "group" and conv.mute_all:
from app.models.conversation_member import ConversationMember as CM
m_result = await self.db.execute(
select(CM).where(CM.conversation_id == conversation_id, CM.user_id == sender_id)
)
m = m_result.scalars().first()
if m and m.role == "member":
raise ValueError("群已开启全员禁言")
import json
message = Message(
id=str(uuid.uuid4()),
conversation_id=conversation_id,
@@ -26,6 +44,7 @@ class MessageService:
type=msg_type,
content=content,
reply_to_id=reply_to_id,
mentions=json.dumps(mentioned_user_ids) if mentioned_user_ids else None,
)
self.db.add(message)
@@ -101,6 +120,27 @@ class MessageService:
)
reply_senders_map = {u.id: u for u in reply_senders_result.scalars().all()}
# 批量预加载这些消息的表情回应
msg_ids = [m.id for m in messages]
reaction_result = await self.db.execute(
select(MessageReaction).where(MessageReaction.message_id.in_(msg_ids))
)
from app.models.user import User as U
reactions_raw = reaction_result.scalars().all()
reaction_user_ids = list(set(r.user_id for r in reactions_raw))
reaction_users_map: dict[str, U] = {}
if reaction_user_ids:
ru_result = await self.db.execute(select(U).where(U.id.in_(reaction_user_ids)))
reaction_users_map = {u.id: u for u in ru_result.scalars().all()}
reaction_map: dict[str, list[dict]] = {}
for r in reactions_raw:
u = reaction_users_map.get(r.user_id)
reaction_map.setdefault(r.message_id, []).append({
"emoji": r.emoji,
"user_id": r.user_id,
"username": u.username if u else None,
})
message_list = []
for msg in reversed(messages):
sender = senders_map.get(msg.sender_id)
@@ -126,7 +166,10 @@ class MessageService:
"reply_to_id": msg.reply_to_id,
"reply_to_content": reply_to_content,
"reply_to_sender_name": reply_to_sender_name,
"mentions": json.loads(msg.mentions) if msg.mentions else None,
"is_deleted": msg.is_deleted,
"is_recalled": msg.is_recalled,
"reactions": reaction_map.get(msg.id, []),
"created_at": msg.created_at,
})
@@ -158,6 +201,43 @@ class MessageService:
raise ValueError("消息不存在或无权删除")
message.is_deleted = True
async def recall_message(self, message_id: str, user_id: str):
"""撤回消息(仅本人,2 分钟内)"""
result = await self.db.execute(
select(Message).where(Message.id == message_id, Message.sender_id == user_id)
)
message = result.scalars().first()
if not message:
raise ValueError("消息不存在或无权撤回")
if message.is_recalled:
raise ValueError("消息已撤回")
elapsed = (datetime.utcnow() - message.created_at).total_seconds()
if elapsed > 120:
raise ValueError("超过 2 分钟,无法撤回")
message.is_recalled = True
message.recalled_at = datetime.utcnow()
async def react(self, message_id: str, user_id: str, emoji: str):
"""添加表情回应(toggle:已存在则取消)"""
result = await self.db.execute(
select(MessageReaction).where(
MessageReaction.message_id == message_id,
MessageReaction.user_id == user_id,
MessageReaction.emoji == emoji,
)
)
existing = result.scalars().first()
if existing:
await self.db.delete(existing)
return {"action": "removed", "emoji": emoji}
reaction = MessageReaction(
id=str(uuid.uuid4()),
message_id=message_id, user_id=user_id, emoji=emoji,
)
self.db.add(reaction)
await self.db.flush()
return {"action": "added", "emoji": emoji}
async def get_total_count(self) -> int:
"""获取消息总数"""
result = await self.db.execute(select(func.count(Message.id)))