Files
chat/backend/app/services/conversation_service.py
T
2026-06-12 23:14:12 +08:00

211 lines
7.6 KiB
Python

"""会话服务"""
import uuid
from datetime import datetime, timezone

from sqlalchemy import and_, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import aliased, selectinload

from app.models.conversation import Conversation
from app.models.conversation_member import ConversationMember
from app.models.user import User
class ConversationService:
    """Conversation (chat) service backed by an async SQLAlchemy session.

    Methods add and flush objects but never commit; the caller owns the
    transaction boundary.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_or_create_private(self, user1_id: str, user2_id: str) -> Conversation:
        """Return the private conversation between two users, creating it if absent.

        Args:
            user1_id: ID of one participant.
            user2_id: ID of the other participant (may equal ``user1_id`` for a
                conversation with oneself).

        Returns:
            The existing or newly created ``Conversation``. A new conversation is
            flushed (so its ID is usable) and two member rows are added to the
            session; nothing is committed.
        """
        # Find a private conversation with a member row for each user in one
        # query. Two aliases are needed because both memberships live in the same
        # table; requiring distinct rows stops the user1 == user2 case from
        # matching any private chat that user1 happens to belong to.
        m1 = aliased(ConversationMember)
        m2 = aliased(ConversationMember)
        result = await self.db.execute(
            select(Conversation)
            .join(m1, m1.conversation_id == Conversation.id)
            .join(m2, m2.conversation_id == Conversation.id)
            .where(
                Conversation.type == "private",
                m1.user_id == user1_id,
                m2.user_id == user2_id,
                m1.id != m2.id,
            )
            .limit(1)
        )
        existing = result.scalars().first()
        if existing is not None:
            return existing

        # No match: create a new private conversation.
        conv = Conversation(id=str(uuid.uuid4()), type="private")
        self.db.add(conv)
        await self.db.flush()
        # Both participants are added with role "member".
        for uid in (user1_id, user2_id):
            self.db.add(ConversationMember(
                id=str(uuid.uuid4()), conversation_id=conv.id, user_id=uid, role="member"
            ))
        return conv

    async def create_group(self, creator_id: str, name: str, member_ids: list[str],
                           description: str | None = None) -> Conversation:
        """Create a group conversation.

        Args:
            creator_id: User creating the group; added with role ``"owner"``.
            name: Group name.
            member_ids: Users to add with role ``"member"``. The creator and any
                repeated IDs are skipped so each user gets exactly one member row.
            description: Optional group description.

        Returns:
            The new ``Conversation`` (flushed, not committed).
        """
        conv = Conversation(
            id=str(uuid.uuid4()),
            type="group",
            name=name,
            description=description,
            creator_id=creator_id,
        )
        self.db.add(conv)
        await self.db.flush()

        # The creator is the owner.
        self.db.add(ConversationMember(
            id=str(uuid.uuid4()), conversation_id=conv.id,
            user_id=creator_id, role="owner",
        ))
        # Everyone else joins as a member; dict.fromkeys drops duplicates while
        # preserving the caller's order.
        for mid in dict.fromkeys(member_ids):
            if mid != creator_id:
                self.db.add(ConversationMember(
                    id=str(uuid.uuid4()), conversation_id=conv.id,
                    user_id=mid, role="member",
                ))
        return conv

    async def get_user_conversations(self, user_id: str) -> list[dict]:
        """List the conversations a user currently belongs to.

        Memberships with ``left_at`` set are excluded. For private conversations
        the display name and avatar come from the other participant.

        Returns:
            One dict per conversation, sorted by ``last_message_at`` (falling
            back to ``created_at``), most recent first.
        """
        # Load each membership together with its conversation in one query rather
        # than one extra SELECT per membership. The inner join skips memberships
        # whose conversation row is missing, as the old per-row check did.
        result = await self.db.execute(
            select(ConversationMember, Conversation)
            .join(Conversation, Conversation.id == ConversationMember.conversation_id)
            .where(ConversationMember.user_id == user_id, ConversationMember.left_at.is_(None))
            .order_by(ConversationMember.joined_at.desc())
        )

        conversations = []
        for member, conv in result.all():
            # Unread count relative to this user's read pointer.
            unread = await self._get_unread_count(conv.id, member.last_read_message_id)

            display_name = conv.name
            display_avatar = conv.avatar_url
            if conv.type == "private":
                other = await self._get_other_member(conv.id, user_id)
                if other:
                    display_name = other.username
                    display_avatar = other.avatar_url

            conversations.append({
                "id": conv.id,
                "type": conv.type,
                "name": display_name,
                "avatar_url": display_avatar,
                "description": conv.description,
                "last_message_preview": conv.last_message_preview,
                "last_message_at": conv.last_message_at,
                "unread_count": unread,
                "created_at": conv.created_at,
            })

        # list.sort is stable (also with reverse=True), so ties keep the
        # joined_at-descending order from the query.
        conversations.sort(key=lambda x: x["last_message_at"] or x["created_at"], reverse=True)
        return conversations

    async def get_conversation_detail(self, conv_id: str, user_id: str) -> dict | None:
        """Return a conversation and its active members, as seen by ``user_id``.

        Returns ``None`` when the conversation does not exist or ``user_id`` is
        not an active (not-left) member. ``unread_count`` is computed from this
        user's ``last_read_message_id``.
        """
        conv_result = await self.db.execute(
            select(Conversation).where(Conversation.id == conv_id)
        )
        conv = conv_result.scalars().first()
        if not conv:
            return None

        # Membership check. The row is kept because its read pointer drives
        # unread_count below.
        member_result = await self.db.execute(
            select(ConversationMember).where(
                ConversationMember.conversation_id == conv_id,
                ConversationMember.user_id == user_id,
                ConversationMember.left_at.is_(None),
            )
        )
        me = member_result.scalars().first()
        if not me:
            return None

        # Active members joined with their user rows in one query; the inner join
        # skips members whose user row is missing.
        members_result = await self.db.execute(
            select(ConversationMember, User)
            .join(User, User.id == ConversationMember.user_id)
            .where(
                ConversationMember.conversation_id == conv_id,
                ConversationMember.left_at.is_(None),
            )
        )
        members = [
            {
                "id": m.id,
                "user_id": user.id,
                "username": user.username,
                # NOTE(review): "nickname" is populated from User.bio — confirm
                # this mapping is intended.
                "nickname": user.bio,
                "avatar_url": user.avatar_url,
                "role": m.role,
                "joined_at": m.joined_at,
            }
            for m, user in members_result.all()
        ]

        return {
            "id": conv.id,
            "type": conv.type,
            "name": conv.name,
            "avatar_url": conv.avatar_url,
            "description": conv.description,
            "last_message_preview": conv.last_message_preview,
            "last_message_at": conv.last_message_at,
            "created_at": conv.created_at,
            "members": members,
            # Use the caller's read pointer; passing None would count every message.
            "unread_count": await self._get_unread_count(conv.id, me.last_read_message_id),
        }

    async def _get_other_member(self, conv_id: str, user_id: str) -> User | None:
        """Return the user in ``conv_id`` other than ``user_id``, or ``None``.

        Intended for private conversations; if several other members exist, an
        arbitrary one is returned.
        """
        result = await self.db.execute(
            select(User)
            .join(ConversationMember, ConversationMember.user_id == User.id)
            .where(
                ConversationMember.conversation_id == conv_id,
                ConversationMember.user_id != user_id,
            )
            .limit(1)
        )
        return result.scalars().first()

    async def _get_unread_count(self, conv_id: str, last_read_id: str | None) -> int:
        """Count non-deleted messages in ``conv_id`` newer than the last-read message.

        If ``last_read_id`` is falsy, or the referenced message cannot be found,
        every non-deleted message in the conversation is counted.
        """
        # Local import, presumably to avoid a circular import with the models package.
        from app.models.message import Message

        query = select(func.count(Message.id)).where(
            Message.conversation_id == conv_id,
            Message.is_deleted.is_(False),
        )
        if last_read_id:
            # Only the last-read message's timestamp is needed, not the full row.
            lr = await self.db.execute(
                select(Message.created_at).where(Message.id == last_read_id)
            )
            last_read_at = lr.scalar()
            if last_read_at is not None:
                query = query.where(Message.created_at > last_read_at)
        result = await self.db.execute(query)
        return result.scalar() or 0