"""会话服务""" import uuid from datetime import datetime, timezone from sqlalchemy import select, and_ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from app.models.conversation import Conversation from app.models.conversation_member import ConversationMember from app.models.user import User class ConversationService: def __init__(self, db: AsyncSession): self.db = db async def get_or_create_private(self, user1_id: str, user2_id: str) -> Conversation: """获取或创建私聊会话""" # 查找已有的私聊 result = await self.db.execute( select(Conversation).join(ConversationMember) .where( Conversation.type == "private", ConversationMember.user_id == user1_id, ) ) for conv in result.scalars().all(): member_result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conv.id, ConversationMember.user_id == user2_id, ) ) if member_result.scalars().first(): return conv # 创建新私聊 conv = Conversation(id=str(uuid.uuid4()), type="private") self.db.add(conv) await self.db.flush() # 添加两个成员 self.db.add(ConversationMember( id=str(uuid.uuid4()), conversation_id=conv.id, user_id=user1_id, role="member" )) self.db.add(ConversationMember( id=str(uuid.uuid4()), conversation_id=conv.id, user_id=user2_id, role="member" )) return conv async def create_group(self, creator_id: str, name: str, member_ids: list[str], description: str | None = None) -> Conversation: """创建群聊""" conv = Conversation( id=str(uuid.uuid4()), type="group", name=name, description=description, creator_id=creator_id, ) self.db.add(conv) await self.db.flush() # 创建者为 owner self.db.add(ConversationMember( id=str(uuid.uuid4()), conversation_id=conv.id, user_id=creator_id, role="owner" )) # 其他成员 for mid in member_ids: if mid != creator_id: self.db.add(ConversationMember( id=str(uuid.uuid4()), conversation_id=conv.id, user_id=mid, role="member" )) return conv async def get_user_conversations(self, user_id: str) -> list[dict]: """获取用户的会话列表""" result = await self.db.execute( select(ConversationMember) .where(ConversationMember.user_id == user_id, ConversationMember.left_at.is_(None)) .order_by(ConversationMember.joined_at.desc()) ) members = result.scalars().all() conversations = [] for member in members: conv_result = await self.db.execute( select(Conversation).where(Conversation.id == member.conversation_id) ) conv = conv_result.scalars().first() if not conv: continue # 获取未读数 unread = await self._get_unread_count(conv.id, member.last_read_message_id) # 获取显示信息 display_name = conv.name display_avatar = conv.avatar_url if conv.type == "private": other = await self._get_other_member(conv.id, user_id) if other: display_name = other.username display_avatar = other.avatar_url conversations.append({ "id": conv.id, "type": conv.type, "name": display_name, "avatar_url": display_avatar, "description": conv.description, "last_message_preview": conv.last_message_preview, "last_message_at": conv.last_message_at, "unread_count": unread, "created_at": conv.created_at, }) # 按最后消息时间排序 conversations.sort(key=lambda x: x["last_message_at"] or x["created_at"], reverse=True) return conversations async def get_conversation_detail(self, conv_id: str, user_id: str) -> dict | None: """获取会话详情""" conv_result = await self.db.execute( select(Conversation).where(Conversation.id == conv_id) ) conv = conv_result.scalars().first() if not conv: return None # 验证成员身份 member_result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conv_id, ConversationMember.user_id == user_id, ConversationMember.left_at.is_(None), ) ) if not member_result.scalars().first(): return None # 获取所有成员 members_result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conv_id, ConversationMember.left_at.is_(None), ) ) members = [] for m in members_result.scalars().all(): user_result = await self.db.execute(select(User).where(User.id == m.user_id)) user = user_result.scalars().first() if user: members.append({ "id": m.id, "user_id": user.id, "username": user.username, "nickname": user.bio, "avatar_url": user.avatar_url, "role": m.role, "joined_at": m.joined_at, }) return { "id": conv.id, "type": conv.type, "name": conv.name, "avatar_url": conv.avatar_url, "description": conv.description, "last_message_preview": conv.last_message_preview, "last_message_at": conv.last_message_at, "created_at": conv.created_at, "members": members, "unread_count": await self._get_unread_count(conv.id, None), } async def _get_other_member(self, conv_id: str, user_id: str) -> User | None: """获取私聊的对方用户""" result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conv_id, ConversationMember.user_id != user_id, ) ) member = result.scalars().first() if member: user_result = await self.db.execute(select(User).where(User.id == member.user_id)) return user_result.scalars().first() return None async def _get_unread_count(self, conv_id: str, last_read_id: str | None) -> int: """计算未读消息数""" from app.models.message import Message query = select(func := __import__("sqlalchemy").func).count(Message.id).where( Message.conversation_id == conv_id, Message.is_deleted == False, ) if last_read_id: # 获取 last_read 消息的时间 lr = await self.db.execute(select(Message).where(Message.id == last_read_id)) lr_msg = lr.scalars().first() if lr_msg: query = query.where(Message.created_at > lr_msg.created_at) result = await self.db.execute(query) return result.scalar() or 0