1.9
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
"""群公告服务"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.group_announcement import GroupAnnouncement
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
|
||||
class AnnouncementService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get(self, conversation_id: str) -> dict | None:
|
||||
result = await self.db.execute(
|
||||
select(GroupAnnouncement).where(
|
||||
GroupAnnouncement.conversation_id == conversation_id
|
||||
).order_by(GroupAnnouncement.updated_at.desc())
|
||||
)
|
||||
ann = result.scalars().first()
|
||||
if not ann:
|
||||
return None
|
||||
return self._to_dict(ann)
|
||||
|
||||
async def upsert(self, conversation_id: str, author_id: str, content: str) -> dict:
|
||||
# 校验管理员
|
||||
conv_service = ConversationService(self.db)
|
||||
await conv_service._get_conv_if_admin(conversation_id, author_id)
|
||||
|
||||
result = await self.db.execute(
|
||||
select(GroupAnnouncement).where(GroupAnnouncement.conversation_id == conversation_id)
|
||||
)
|
||||
ann = result.scalars().first()
|
||||
if ann:
|
||||
ann.content = content
|
||||
ann.author_id = author_id
|
||||
ann.updated_at = datetime.utcnow()
|
||||
else:
|
||||
ann = GroupAnnouncement(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation_id=conversation_id,
|
||||
author_id=author_id, content=content,
|
||||
)
|
||||
self.db.add(ann)
|
||||
await self.db.flush()
|
||||
return self._to_dict(ann)
|
||||
|
||||
def _to_dict(self, ann: GroupAnnouncement) -> dict:
|
||||
return {
|
||||
"id": ann.id,
|
||||
"conversation_id": ann.conversation_id,
|
||||
"author_id": ann.author_id,
|
||||
"content": ann.content,
|
||||
"updated_at": ann.updated_at.isoformat() if ann.updated_at else None,
|
||||
}
|
||||
@@ -178,6 +178,21 @@ class ConversationService:
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def update_prefs(self, conv_id: str, user_id: str,
|
||||
is_pinned: bool | None = None,
|
||||
is_muted: bool | None = None) -> dict:
|
||||
"""更新会话个人偏好(置顶/免打扰)"""
|
||||
member = await self._get_member(conv_id, user_id)
|
||||
if not member:
|
||||
raise ValueError("无权访问该会话")
|
||||
if is_pinned is not None:
|
||||
member.is_pinned = is_pinned
|
||||
member.pinned_at = datetime.utcnow() if is_pinned else None
|
||||
if is_muted is not None:
|
||||
member.is_muted = is_muted
|
||||
await self.db.flush()
|
||||
return {"is_pinned": member.is_pinned, "is_muted": member.is_muted}
|
||||
|
||||
async def get_user_conversations(self, user_id: str) -> list[dict]:
|
||||
"""获取用户的会话列表"""
|
||||
result = await self.db.execute(
|
||||
@@ -217,9 +232,15 @@ class ConversationService:
|
||||
"last_message_at": conv.last_message_at,
|
||||
"unread_count": unread,
|
||||
"created_at": conv.created_at,
|
||||
"is_pinned": member.is_pinned,
|
||||
"is_muted": member.is_muted,
|
||||
})
|
||||
|
||||
conversations.sort(key=lambda x: x["last_message_at"] or x["created_at"], reverse=True)
|
||||
# 排序:置顶优先,组内按最后消息时间倒序
|
||||
conversations.sort(key=lambda x: (
|
||||
x["is_pinned"],
|
||||
x["last_message_at"] or x["created_at"],
|
||||
), reverse=True)
|
||||
return conversations
|
||||
|
||||
async def get_conversation_detail(self, conv_id: str, user_id: str) -> dict | None:
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""草稿服务(Redis 存储,零迁移)"""
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class DraftService:
|
||||
def __init__(self):
|
||||
self._redis = None
|
||||
|
||||
async def _get_redis(self):
|
||||
if self._redis is None:
|
||||
self._redis = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
return self._redis
|
||||
|
||||
async def get(self, user_id: str, conv_id: str) -> str:
|
||||
r = await self._get_redis()
|
||||
return await r.get(f"draft:{user_id}:{conv_id}") or ""
|
||||
|
||||
async def set(self, user_id: str, conv_id: str, text: str):
|
||||
r = await self._get_redis()
|
||||
key = f"draft:{user_id}:{conv_id}"
|
||||
if text:
|
||||
await r.set(key, text, ex=30 * 86400) # 30 天 TTL
|
||||
else:
|
||||
await r.delete(key)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""邮件服务(假实现:验证码打印到日志;以后替换为 SMTP)"""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
|
||||
def generate_code() -> str:
|
||||
"""生成 6 位数字验证码"""
|
||||
return f"{secrets.randbelow(1000000):06d}"
|
||||
|
||||
|
||||
def hash_code(code: str) -> str:
|
||||
return hashlib.sha256(code.encode()).hexdigest()
|
||||
|
||||
|
||||
async def send_verification_email(to_email: str, code: str, purpose: str = "验证"):
|
||||
"""发送验证邮件(开发期假实现:打印到日志)
|
||||
|
||||
生产环境替换此函数为真实 SMTP 发送即可,调用方无需改动。
|
||||
"""
|
||||
print(f"\n{'=' * 50}")
|
||||
print(f"📧 [邮件服务-开发模式] 收件人: {to_email}")
|
||||
print(f"📋 用途: {purpose}")
|
||||
print(f"🔢 验证码: {code}")
|
||||
print(f"{'=' * 50}\n")
|
||||
return True
|
||||
@@ -1,5 +1,6 @@
|
||||
"""消息服务"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@@ -9,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.message import Message
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.conversation_member import ConversationMember
|
||||
from app.models.message_reaction import MessageReaction
|
||||
|
||||
|
||||
class MessageService:
|
||||
@@ -17,8 +19,24 @@ class MessageService:
|
||||
|
||||
async def send_message(self, conversation_id: str, sender_id: str,
|
||||
content: str, msg_type: str = "text",
|
||||
reply_to_id: str | None = None) -> Message:
|
||||
reply_to_id: str | None = None,
|
||||
mentioned_user_ids: list[str] | None = None) -> Message:
|
||||
"""发送消息"""
|
||||
# 全员禁言检查(群聊,非管理员不能发)
|
||||
conv_result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == conversation_id)
|
||||
)
|
||||
conv = conv_result.scalars().first()
|
||||
if conv and conv.type == "group" and conv.mute_all:
|
||||
from app.models.conversation_member import ConversationMember as CM
|
||||
m_result = await self.db.execute(
|
||||
select(CM).where(CM.conversation_id == conversation_id, CM.user_id == sender_id)
|
||||
)
|
||||
m = m_result.scalars().first()
|
||||
if m and m.role == "member":
|
||||
raise ValueError("群已开启全员禁言")
|
||||
|
||||
import json
|
||||
message = Message(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation_id=conversation_id,
|
||||
@@ -26,6 +44,7 @@ class MessageService:
|
||||
type=msg_type,
|
||||
content=content,
|
||||
reply_to_id=reply_to_id,
|
||||
mentions=json.dumps(mentioned_user_ids) if mentioned_user_ids else None,
|
||||
)
|
||||
self.db.add(message)
|
||||
|
||||
@@ -101,6 +120,27 @@ class MessageService:
|
||||
)
|
||||
reply_senders_map = {u.id: u for u in reply_senders_result.scalars().all()}
|
||||
|
||||
# 批量预加载这些消息的表情回应
|
||||
msg_ids = [m.id for m in messages]
|
||||
reaction_result = await self.db.execute(
|
||||
select(MessageReaction).where(MessageReaction.message_id.in_(msg_ids))
|
||||
)
|
||||
from app.models.user import User as U
|
||||
reactions_raw = reaction_result.scalars().all()
|
||||
reaction_user_ids = list(set(r.user_id for r in reactions_raw))
|
||||
reaction_users_map: dict[str, U] = {}
|
||||
if reaction_user_ids:
|
||||
ru_result = await self.db.execute(select(U).where(U.id.in_(reaction_user_ids)))
|
||||
reaction_users_map = {u.id: u for u in ru_result.scalars().all()}
|
||||
reaction_map: dict[str, list[dict]] = {}
|
||||
for r in reactions_raw:
|
||||
u = reaction_users_map.get(r.user_id)
|
||||
reaction_map.setdefault(r.message_id, []).append({
|
||||
"emoji": r.emoji,
|
||||
"user_id": r.user_id,
|
||||
"username": u.username if u else None,
|
||||
})
|
||||
|
||||
message_list = []
|
||||
for msg in reversed(messages):
|
||||
sender = senders_map.get(msg.sender_id)
|
||||
@@ -126,7 +166,10 @@ class MessageService:
|
||||
"reply_to_id": msg.reply_to_id,
|
||||
"reply_to_content": reply_to_content,
|
||||
"reply_to_sender_name": reply_to_sender_name,
|
||||
"mentions": json.loads(msg.mentions) if msg.mentions else None,
|
||||
"is_deleted": msg.is_deleted,
|
||||
"is_recalled": msg.is_recalled,
|
||||
"reactions": reaction_map.get(msg.id, []),
|
||||
"created_at": msg.created_at,
|
||||
})
|
||||
|
||||
@@ -158,6 +201,43 @@ class MessageService:
|
||||
raise ValueError("消息不存在或无权删除")
|
||||
message.is_deleted = True
|
||||
|
||||
async def recall_message(self, message_id: str, user_id: str):
|
||||
"""撤回消息(仅本人,2 分钟内)"""
|
||||
result = await self.db.execute(
|
||||
select(Message).where(Message.id == message_id, Message.sender_id == user_id)
|
||||
)
|
||||
message = result.scalars().first()
|
||||
if not message:
|
||||
raise ValueError("消息不存在或无权撤回")
|
||||
if message.is_recalled:
|
||||
raise ValueError("消息已撤回")
|
||||
elapsed = (datetime.utcnow() - message.created_at).total_seconds()
|
||||
if elapsed > 120:
|
||||
raise ValueError("超过 2 分钟,无法撤回")
|
||||
message.is_recalled = True
|
||||
message.recalled_at = datetime.utcnow()
|
||||
|
||||
async def react(self, message_id: str, user_id: str, emoji: str):
|
||||
"""添加表情回应(toggle:已存在则取消)"""
|
||||
result = await self.db.execute(
|
||||
select(MessageReaction).where(
|
||||
MessageReaction.message_id == message_id,
|
||||
MessageReaction.user_id == user_id,
|
||||
MessageReaction.emoji == emoji,
|
||||
)
|
||||
)
|
||||
existing = result.scalars().first()
|
||||
if existing:
|
||||
await self.db.delete(existing)
|
||||
return {"action": "removed", "emoji": emoji}
|
||||
reaction = MessageReaction(
|
||||
id=str(uuid.uuid4()),
|
||||
message_id=message_id, user_id=user_id, emoji=emoji,
|
||||
)
|
||||
self.db.add(reaction)
|
||||
await self.db.flush()
|
||||
return {"action": "added", "emoji": emoji}
|
||||
|
||||
async def get_total_count(self) -> int:
|
||||
"""获取消息总数"""
|
||||
result = await self.db.execute(select(func.count(Message.id)))
|
||||
|
||||
Reference in New Issue
Block a user