335 lines
13 KiB
Python
335 lines
13 KiB
Python
"""会话服务"""
|
|
|
|
import uuid
from datetime import datetime

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import aliased

from app.models.conversation import Conversation
from app.models.conversation_member import ConversationMember
from app.models.user import User
|
|
|
|
|
|
class ConversationService:
    """Service layer for private chats and group chats.

    Every method works on the injected ``AsyncSession``. Write methods only
    ``add``/mutate/``flush``; none of them commits, so the caller owns the
    transaction boundary.
    """

    # Roles a group member may hold.
    _MEMBER_ROLES = frozenset({"owner", "admin", "member"})

    # Attributes update_group must never overwrite from caller-supplied kwargs:
    # identity/type columns and the system-maintained last-message summary.
    _PROTECTED_GROUP_FIELDS = frozenset({
        "id",
        "type",
        "creator_id",
        "created_at",
        "last_message_preview",
        "last_message_at",
    })

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_or_create_private(self, user1_id: str, user2_id: str) -> Conversation:
        """Return the private conversation between two users, creating it if absent.

        A match requires two *distinct* member rows (one for each user) in the
        same private conversation. Because of this, a self-chat
        (``user1_id == user2_id``) only matches a conversation that holds two
        rows for that user, not any private chat the user happens to be in.

        When nothing matches, a new private conversation is added and flushed,
        then one ``"member"`` row per user is added to the session.
        """
        m1 = aliased(ConversationMember)
        m2 = aliased(ConversationMember)
        # A single self-join replaces the previous one-query-per-candidate loop.
        result = await self.db.execute(
            select(Conversation)
            .join(m1, m1.conversation_id == Conversation.id)
            .join(m2, m2.conversation_id == Conversation.id)
            .where(
                Conversation.type == "private",
                m1.user_id == user1_id,
                m2.user_id == user2_id,
                m1.id != m2.id,
            )
            .limit(1)
        )
        existing = result.scalars().first()
        if existing is not None:
            return existing

        conv = Conversation(id=str(uuid.uuid4()), type="private")
        self.db.add(conv)
        await self.db.flush()

        for uid in (user1_id, user2_id):
            self.db.add(ConversationMember(
                id=str(uuid.uuid4()), conversation_id=conv.id, user_id=uid, role="member"
            ))
        return conv

    async def create_group(self, creator_id: str, name: str, member_ids: list[str],
                           description: str | None = None) -> Conversation:
        """Create a group chat owned by ``creator_id``.

        The creator gets role ``"owner"``. Each other id in ``member_ids`` gets
        role ``"member"``. Duplicates in ``member_ids`` and the creator's own id
        are skipped, so every user ends up with exactly one member row.
        """
        conv = Conversation(
            id=str(uuid.uuid4()),
            type="group",
            name=name,
            description=description,
            creator_id=creator_id,
        )
        self.db.add(conv)
        await self.db.flush()

        self.db.add(ConversationMember(
            id=str(uuid.uuid4()), conversation_id=conv.id,
            user_id=creator_id, role="owner"
        ))
        # dict.fromkeys drops duplicates while keeping the caller's order.
        for mid in dict.fromkeys(member_ids):
            if mid == creator_id:
                continue
            self.db.add(ConversationMember(
                id=str(uuid.uuid4()), conversation_id=conv.id,
                user_id=mid, role="member"
            ))
        return conv

    async def update_group(self, conv_id: str, user_id: str, **kwargs):
        """Update group attributes (owner/admin only).

        A keyword argument is applied only if all of these hold:

        - its value is not None;
        - its name is an existing attribute of the conversation;
        - it is not private (leading underscore);
        - it is not in ``_PROTECTED_GROUP_FIELDS``.

        Anything else is silently ignored, as unknown keys already were.

        Raises:
            ValueError: conversation missing, not a group, or caller is not owner/admin.
        """
        conv = await self._get_conv_if_admin(conv_id, user_id)
        for key, value in kwargs.items():
            if value is None or key.startswith("_") or key in self._PROTECTED_GROUP_FIELDS:
                continue
            if hasattr(conv, key):
                setattr(conv, key, value)

    async def add_members(self, conv_id: str, user_id: str, new_member_ids: list[str]):
        """Add members to a group (owner/admin only).

        Users that already have an active (``left_at IS NULL``) membership are
        skipped. Repeated ids in ``new_member_ids`` are only considered once.

        Raises:
            ValueError: conversation missing, not a group, or caller is not owner/admin.
        """
        await self._get_conv_if_admin(conv_id, user_id)
        for mid in dict.fromkeys(new_member_ids):
            if await self._get_member(conv_id, mid) is not None:
                continue
            self.db.add(ConversationMember(
                id=str(uuid.uuid4()), conversation_id=conv_id,
                user_id=mid, role="member"
            ))

    async def remove_member(self, conv_id: str, user_id: str, target_user_id: str):
        """Remove a member from a group (owner/admin only; the owner cannot be removed).

        Removal is a soft delete: the target's ``left_at`` is set to the current UTC time.

        Raises:
            ValueError: caller lacks rights, target not an active member, or target is the owner.
        """
        await self._get_conv_if_admin(conv_id, user_id)
        member = await self._get_member(conv_id, target_user_id)
        if not member:
            raise ValueError("该用户不在群中")
        if member.role == "owner":
            raise ValueError("不能移除群主")
        member.left_at = datetime.utcnow()

    async def leave_group(self, conv_id: str, user_id: str):
        """Leave a group chat (soft delete via ``left_at``).

        Only group conversations can be left. Letting a participant "leave" a
        private chat would hide it from their list, while
        ``get_or_create_private`` would keep returning that same conversation.

        Raises:
            ValueError: caller not an active member, conversation is not a group,
                or caller is the owner (ownership must be transferred first).
        """
        member = await self._get_member(conv_id, user_id)
        if not member:
            raise ValueError("你不在该群中")
        conv = await self._get_conv(conv_id)
        if conv is None or conv.type != "group":
            raise ValueError("仅群聊支持此操作")
        if member.role == "owner":
            raise ValueError("群主不能退出,请先转让群主身份")
        member.left_at = datetime.utcnow()

    async def dissolve_group(self, conv_id: str, user_id: str):
        """Dissolve a group (owner only).

        Every active member, including the owner, gets ``left_at`` set to the
        same timestamp. The conversation row itself is not deleted.

        Raises:
            ValueError: caller not an active member, not the owner, or not a group.
        """
        member = await self._get_member(conv_id, user_id)
        if not member:
            raise ValueError("你不在该群中")
        if member.role != "owner":
            raise ValueError("只有群主可以解散群聊")

        conv = await self._get_conv(conv_id)
        if not conv or conv.type != "group":
            raise ValueError("群聊不存在")

        members_result = await self.db.execute(
            select(ConversationMember).where(
                ConversationMember.conversation_id == conv_id,
                ConversationMember.left_at.is_(None),
            )
        )
        now = datetime.utcnow()
        for m in members_result.scalars().all():
            m.left_at = now

    async def update_member_role(self, conv_id: str, user_id: str, target_user_id: str, role: str):
        """Change a member's role (owner only).

        Rules:

        - ``role`` must be one of ``_MEMBER_ROLES``.
        - The owner cannot change their own role, since that would leave the
          group without an owner.
        - Assigning ``"owner"`` to another member transfers ownership. The
          caller is demoted to ``"admin"`` so the group does not end up with
          two owners.

        Raises:
            ValueError: caller is not the owner, role is invalid, caller targets
                themselves, or target is not an active member.
        """
        member = await self._get_member(conv_id, user_id)
        if not member or member.role != "owner":
            raise ValueError("只有群主可以修改角色")
        if role not in self._MEMBER_ROLES:
            raise ValueError("无效的角色")
        if target_user_id == user_id:
            raise ValueError("不能修改自己的角色")
        target = await self._get_member(conv_id, target_user_id)
        if not target:
            raise ValueError("目标用户不在群中")
        if role == "owner":
            # Ownership transfer: keep a single owner.
            member.role = "admin"
        target.role = role

    async def _get_conv(self, conv_id: str) -> Conversation | None:
        """Load a conversation by id, or return None if it does not exist."""
        result = await self.db.execute(
            select(Conversation).where(Conversation.id == conv_id)
        )
        return result.scalars().first()

    async def _get_conv_if_admin(self, conv_id: str, user_id: str) -> Conversation:
        """Return the group conversation if ``user_id`` is an active owner/admin of it.

        Raises:
            ValueError: conversation missing, not a group, or user is not owner/admin.
        """
        conv = await self._get_conv(conv_id)
        if not conv:
            raise ValueError("会话不存在")
        if conv.type != "group":
            raise ValueError("仅群聊支持此操作")
        member = await self._get_member(conv_id, user_id)
        if not member or member.role not in ("owner", "admin"):
            raise ValueError("仅群主或管理员可执行此操作")
        return conv

    async def _get_member(self, conv_id: str, user_id: str) -> ConversationMember | None:
        """Return the user's active membership row (``left_at IS NULL``), or None."""
        result = await self.db.execute(
            select(ConversationMember).where(
                ConversationMember.conversation_id == conv_id,
                ConversationMember.user_id == user_id,
                ConversationMember.left_at.is_(None),
            )
        )
        return result.scalars().first()

    async def update_prefs(self, conv_id: str, user_id: str,
                           is_pinned: bool | None = None,
                           is_muted: bool | None = None) -> dict:
        """Update per-user conversation preferences (pin / mute).

        A None argument leaves that preference unchanged. Pinning sets
        ``pinned_at`` to the current UTC time; unpinning clears it.

        Returns:
            ``{"is_pinned": ..., "is_muted": ...}`` after the update.

        Raises:
            ValueError: user is not an active member of the conversation.
        """
        member = await self._get_member(conv_id, user_id)
        if not member:
            raise ValueError("无权访问该会话")
        if is_pinned is not None:
            member.is_pinned = is_pinned
            member.pinned_at = datetime.utcnow() if is_pinned else None
        if is_muted is not None:
            member.is_muted = is_muted
        await self.db.flush()
        return {"is_pinned": member.is_pinned, "is_muted": member.is_muted}

    async def get_user_conversations(self, user_id: str) -> list[dict]:
        """List the user's active conversations as summary dicts.

        For private chats, the name and avatar are taken from the other
        participant (nickname, falling back to username).

        Ordering: pinned conversations first. Within each group, conversations
        are sorted newest first by ``last_message_at``, falling back to
        ``created_at``.
        """
        result = await self.db.execute(
            select(ConversationMember)
            .where(ConversationMember.user_id == user_id, ConversationMember.left_at.is_(None))
            .order_by(ConversationMember.joined_at.desc())
        )
        members = result.scalars().all()
        if not members:
            return []

        # Load every referenced conversation in one query instead of one per row.
        conv_result = await self.db.execute(
            select(Conversation).where(
                Conversation.id.in_([m.conversation_id for m in members])
            )
        )
        convs_by_id = {c.id: c for c in conv_result.scalars().all()}

        conversations = []
        for member in members:
            conv = convs_by_id.get(member.conversation_id)
            if conv is None:
                continue

            unread = await self._get_unread_count(conv.id, member.last_read_message_id)

            display_name = conv.name
            display_avatar = conv.avatar_url
            if conv.type == "private":
                other = await self._get_other_member(conv.id, user_id)
                if other:
                    display_name = other.nickname or other.username
                    display_avatar = other.avatar_url

            conversations.append({
                "id": conv.id,
                "type": conv.type,
                "name": display_name,
                "avatar_url": display_avatar,
                "description": conv.description,
                "last_message_preview": conv.last_message_preview,
                "last_message_at": conv.last_message_at,
                "unread_count": unread,
                "created_at": conv.created_at,
                "is_pinned": member.is_pinned,
                "is_muted": member.is_muted,
            })

        # bool() ensures a NULL is_pinned sorts as "not pinned" instead of
        # raising TypeError when compared against True/False.
        conversations.sort(key=lambda x: (
            bool(x["is_pinned"]),
            x["last_message_at"] or x["created_at"],
        ), reverse=True)
        return conversations

    async def get_conversation_detail(self, conv_id: str, user_id: str) -> dict | None:
        """Return conversation details with its active members.

        Returns None if the conversation does not exist or ``user_id`` is not
        an active member. ``unread_count`` is computed from the requesting
        member's ``last_read_message_id``. Members whose user row cannot be
        found are omitted.
        """
        conv = await self._get_conv(conv_id)
        if not conv:
            return None

        viewer = await self._get_member(conv_id, user_id)
        if not viewer:
            return None

        members_result = await self.db.execute(
            select(ConversationMember).where(
                ConversationMember.conversation_id == conv_id,
                ConversationMember.left_at.is_(None),
            )
        )
        member_rows = members_result.scalars().all()

        # Fetch all member users in one query.
        member_user_ids = [m.user_id for m in member_rows]
        users_result = await self.db.execute(
            select(User).where(User.id.in_(member_user_ids))
        )
        users_map = {u.id: u for u in users_result.scalars().all()}

        members = []
        for m in member_rows:
            user = users_map.get(m.user_id)
            if user:
                members.append({
                    "id": m.id,
                    "user_id": user.id,
                    "username": user.username,
                    "nickname": user.nickname,
                    "avatar_url": user.avatar_url,
                    "role": m.role,
                    "joined_at": m.joined_at,
                })

        return {
            "id": conv.id,
            "type": conv.type,
            "name": conv.name,
            "avatar_url": conv.avatar_url,
            "description": conv.description,
            "last_message_preview": conv.last_message_preview,
            "last_message_at": conv.last_message_at,
            "created_at": conv.created_at,
            "members": members,
            # Previously passed None here, which counted every message as unread.
            "unread_count": await self._get_unread_count(conv.id, viewer.last_read_message_id),
        }

    async def _get_other_member(self, conv_id: str, user_id: str) -> User | None:
        """Return the first member of ``conv_id`` other than ``user_id``, or None.

        Used for private chats. Note that this does not filter on ``left_at``.
        """
        result = await self.db.execute(
            select(ConversationMember).where(
                ConversationMember.conversation_id == conv_id,
                ConversationMember.user_id != user_id,
            )
        )
        member = result.scalars().first()
        if member:
            user_result = await self.db.execute(select(User).where(User.id == member.user_id))
            return user_result.scalars().first()
        return None

    async def _get_unread_count(self, conv_id: str, last_read_id: str | None) -> int:
        """Count non-deleted messages in the conversation newer than the last-read one.

        If ``last_read_id`` is None or does not resolve to a message, every
        non-deleted message in the conversation is counted.
        """
        from app.models.message import Message

        query = select(func.count(Message.id)).where(
            Message.conversation_id == conv_id,
            Message.is_deleted == False,  # noqa: E712 - builds a SQL expression, not a Python comparison
        )
        if last_read_id:
            lr = await self.db.execute(select(Message).where(Message.id == last_read_id))
            lr_msg = lr.scalars().first()
            if lr_msg:
                query = query.where(Message.created_at > lr_msg.created_at)
        result = await self.db.execute(query)
        return result.scalar() or 0