"""会话服务""" import uuid from datetime import datetime from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession from app.models.conversation import Conversation from app.models.conversation_member import ConversationMember from app.models.user import User class ConversationService: def __init__(self, db: AsyncSession): self.db = db async def get_or_create_private(self, user1_id: str, user2_id: str) -> Conversation: """获取或创建私聊会话""" result = await self.db.execute( select(Conversation).join(ConversationMember) .where( Conversation.type == "private", ConversationMember.user_id == user1_id, ) ) for conv in result.scalars().all(): member_result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conv.id, ConversationMember.user_id == user2_id, ) ) if member_result.scalars().first(): return conv conv = Conversation(id=str(uuid.uuid4()), type="private") self.db.add(conv) await self.db.flush() self.db.add(ConversationMember( id=str(uuid.uuid4()), conversation_id=conv.id, user_id=user1_id, role="member" )) self.db.add(ConversationMember( id=str(uuid.uuid4()), conversation_id=conv.id, user_id=user2_id, role="member" )) return conv async def create_group(self, creator_id: str, name: str, member_ids: list[str], description: str | None = None) -> Conversation: """创建群聊""" conv = Conversation( id=str(uuid.uuid4()), type="group", name=name, description=description, creator_id=creator_id, ) self.db.add(conv) await self.db.flush() self.db.add(ConversationMember( id=str(uuid.uuid4()), conversation_id=conv.id, user_id=creator_id, role="owner" )) for mid in member_ids: if mid != creator_id: self.db.add(ConversationMember( id=str(uuid.uuid4()), conversation_id=conv.id, user_id=mid, role="member" )) return conv async def update_group(self, conv_id: str, user_id: str, **kwargs): """更新群聊信息(仅群主/管理员)""" conv = await self._get_conv_if_admin(conv_id, user_id) for key, value in kwargs.items(): if value is not None and hasattr(conv, key): setattr(conv, key, value) async def add_members(self, conv_id: str, user_id: str, new_member_ids: list[str]): """添加群成员(仅群主/管理员)""" await self._get_conv_if_admin(conv_id, user_id) for mid in new_member_ids: # 检查是否已在群中 existing = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conv_id, ConversationMember.user_id == mid, ConversationMember.left_at.is_(None), ) ) if not existing.scalars().first(): self.db.add(ConversationMember( id=str(uuid.uuid4()), conversation_id=conv_id, user_id=mid, role="member" )) async def remove_member(self, conv_id: str, user_id: str, target_user_id: str): """移除群成员(仅群主/管理员,不能移除群主)""" await self._get_conv_if_admin(conv_id, user_id) member = await self._get_member(conv_id, target_user_id) if not member: raise ValueError("该用户不在群中") if member.role == "owner": raise ValueError("不能移除群主") member.left_at = datetime.utcnow() async def leave_group(self, conv_id: str, user_id: str): """退出群聊""" member = await self._get_member(conv_id, user_id) if not member: raise ValueError("你不在该群中") if member.role == "owner": raise ValueError("群主不能退出,请先转让群主身份") member.left_at = datetime.utcnow() async def dissolve_group(self, conv_id: str, user_id: str): """解散群聊(仅群主)""" member = await self._get_member(conv_id, user_id) if not member: raise ValueError("你不在该群中") if member.role != "owner": raise ValueError("只有群主可以解散群聊") # 验证会话存在且为群聊 conv_result = await self.db.execute( select(Conversation).where(Conversation.id == conv_id) ) conv = conv_result.scalars().first() if not conv or conv.type != "group": raise ValueError("群聊不存在") # 软删除所有成员(设置 left_at) members_result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conv_id, ConversationMember.left_at.is_(None), ) ) now = datetime.utcnow() for m in members_result.scalars().all(): m.left_at = now async def update_member_role(self, conv_id: str, user_id: str, target_user_id: str, role: str): """修改成员角色(仅群主)""" member = await self._get_member(conv_id, user_id) if not member or member.role != "owner": raise ValueError("只有群主可以修改角色") target = await self._get_member(conv_id, target_user_id) if not target: raise ValueError("目标用户不在群中") target.role = role async def _get_conv_if_admin(self, conv_id: str, user_id: str) -> Conversation: """获取会话并验证管理员权限""" conv_result = await self.db.execute( select(Conversation).where(Conversation.id == conv_id) ) conv = conv_result.scalars().first() if not conv: raise ValueError("会话不存在") if conv.type != "group": raise ValueError("仅群聊支持此操作") member = await self._get_member(conv_id, user_id) if not member or member.role not in ("owner", "admin"): raise ValueError("仅群主或管理员可执行此操作") return conv async def _get_member(self, conv_id: str, user_id: str) -> ConversationMember | None: """获取成员记录""" result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conv_id, ConversationMember.user_id == user_id, ConversationMember.left_at.is_(None), ) ) return result.scalars().first() async def update_prefs(self, conv_id: str, user_id: str, is_pinned: bool | None = None, is_muted: bool | None = None) -> dict: """更新会话个人偏好(置顶/免打扰)""" member = await self._get_member(conv_id, user_id) if not member: raise ValueError("无权访问该会话") if is_pinned is not None: member.is_pinned = is_pinned member.pinned_at = datetime.utcnow() if is_pinned else None if is_muted is not None: member.is_muted = is_muted await self.db.flush() return {"is_pinned": member.is_pinned, "is_muted": member.is_muted} async def get_user_conversations(self, user_id: str) -> list[dict]: """获取用户的会话列表""" result = await self.db.execute( select(ConversationMember) .where(ConversationMember.user_id == user_id, ConversationMember.left_at.is_(None)) .order_by(ConversationMember.joined_at.desc()) ) members = result.scalars().all() conversations = [] for member in members: conv_result = await self.db.execute( select(Conversation).where(Conversation.id == member.conversation_id) ) conv = conv_result.scalars().first() if not conv: continue unread = await self._get_unread_count(conv.id, member.last_read_message_id) display_name = conv.name display_avatar = conv.avatar_url if conv.type == "private": other = await self._get_other_member(conv.id, user_id) if other: display_name = other.nickname or other.username display_avatar = other.avatar_url conversations.append({ "id": conv.id, "type": conv.type, "name": display_name, "avatar_url": display_avatar, "description": conv.description, "last_message_preview": conv.last_message_preview, "last_message_at": conv.last_message_at, "unread_count": unread, "created_at": conv.created_at, "is_pinned": member.is_pinned, "is_muted": member.is_muted, }) # 排序:置顶优先,组内按最后消息时间倒序 conversations.sort(key=lambda x: ( x["is_pinned"], x["last_message_at"] or x["created_at"], ), reverse=True) return conversations async def get_conversation_detail(self, conv_id: str, user_id: str) -> dict | None: """获取会话详情""" conv_result = await self.db.execute( select(Conversation).where(Conversation.id == conv_id) ) conv = conv_result.scalars().first() if not conv: return None member_result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conv_id, ConversationMember.user_id == user_id, ConversationMember.left_at.is_(None), ) ) if not member_result.scalars().first(): return None members_result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conv_id, ConversationMember.left_at.is_(None), ) ) member_rows = members_result.scalars().all() # 批量获取所有成员用户信息 member_user_ids = [m.user_id for m in member_rows] users_result = await self.db.execute( select(User).where(User.id.in_(member_user_ids)) ) users_map = {u.id: u for u in users_result.scalars().all()} members = [] for m in member_rows: user = users_map.get(m.user_id) if user: members.append({ "id": m.id, "user_id": user.id, "username": user.username, "nickname": user.nickname, "avatar_url": user.avatar_url, "role": m.role, "joined_at": m.joined_at, }) return { "id": conv.id, "type": conv.type, "name": conv.name, "avatar_url": conv.avatar_url, "description": conv.description, "last_message_preview": conv.last_message_preview, "last_message_at": conv.last_message_at, "created_at": conv.created_at, "members": members, "unread_count": await self._get_unread_count(conv.id, None), } async def _get_other_member(self, conv_id: str, user_id: str) -> User | None: """获取私聊的对方用户""" result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conv_id, ConversationMember.user_id != user_id, ) ) member = result.scalars().first() if member: user_result = await self.db.execute(select(User).where(User.id == member.user_id)) return user_result.scalars().first() return None async def _get_unread_count(self, conv_id: str, last_read_id: str | None) -> int: """计算未读消息数""" from app.models.message import Message query = select(func.count(Message.id)).where( Message.conversation_id == conv_id, Message.is_deleted == False, ) if last_read_id: lr = await self.db.execute(select(Message).where(Message.id == last_read_id)) lr_msg = lr.scalars().first() if lr_msg: query = query.where(Message.created_at > lr_msg.created_at) result = await self.db.execute(query) return result.scalar() or 0