1.0
This commit is contained in:
@@ -1,11 +1,10 @@
|
||||
"""会话服务"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.conversation_member import ConversationMember
|
||||
@@ -18,7 +17,6 @@ class ConversationService:
|
||||
|
||||
async def get_or_create_private(self, user1_id: str, user2_id: str) -> Conversation:
|
||||
"""获取或创建私聊会话"""
|
||||
# 查找已有的私聊
|
||||
result = await self.db.execute(
|
||||
select(Conversation).join(ConversationMember)
|
||||
.where(
|
||||
@@ -36,12 +34,10 @@ class ConversationService:
|
||||
if member_result.scalars().first():
|
||||
return conv
|
||||
|
||||
# 创建新私聊
|
||||
conv = Conversation(id=str(uuid.uuid4()), type="private")
|
||||
self.db.add(conv)
|
||||
await self.db.flush()
|
||||
|
||||
# 添加两个成员
|
||||
self.db.add(ConversationMember(
|
||||
id=str(uuid.uuid4()), conversation_id=conv.id, user_id=user1_id, role="member"
|
||||
))
|
||||
@@ -63,12 +59,10 @@ class ConversationService:
|
||||
self.db.add(conv)
|
||||
await self.db.flush()
|
||||
|
||||
# 创建者为 owner
|
||||
self.db.add(ConversationMember(
|
||||
id=str(uuid.uuid4()), conversation_id=conv.id,
|
||||
user_id=creator_id, role="owner"
|
||||
))
|
||||
# 其他成员
|
||||
for mid in member_ids:
|
||||
if mid != creator_id:
|
||||
self.db.add(ConversationMember(
|
||||
@@ -77,6 +71,86 @@ class ConversationService:
|
||||
))
|
||||
return conv
|
||||
|
||||
async def update_group(self, conv_id: str, user_id: str, **kwargs):
|
||||
"""更新群聊信息(仅群主/管理员)"""
|
||||
conv = await self._get_conv_if_admin(conv_id, user_id)
|
||||
for key, value in kwargs.items():
|
||||
if value is not None and hasattr(conv, key):
|
||||
setattr(conv, key, value)
|
||||
|
||||
async def add_members(self, conv_id: str, user_id: str, new_member_ids: list[str]):
|
||||
"""添加群成员(仅群主/管理员)"""
|
||||
await self._get_conv_if_admin(conv_id, user_id)
|
||||
for mid in new_member_ids:
|
||||
# 检查是否已在群中
|
||||
existing = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conv_id,
|
||||
ConversationMember.user_id == mid,
|
||||
ConversationMember.left_at.is_(None),
|
||||
)
|
||||
)
|
||||
if not existing.scalars().first():
|
||||
self.db.add(ConversationMember(
|
||||
id=str(uuid.uuid4()), conversation_id=conv_id,
|
||||
user_id=mid, role="member"
|
||||
))
|
||||
|
||||
async def remove_member(self, conv_id: str, user_id: str, target_user_id: str):
|
||||
"""移除群成员(仅群主/管理员,不能移除群主)"""
|
||||
await self._get_conv_if_admin(conv_id, user_id)
|
||||
member = await self._get_member(conv_id, target_user_id)
|
||||
if not member:
|
||||
raise ValueError("该用户不在群中")
|
||||
if member.role == "owner":
|
||||
raise ValueError("不能移除群主")
|
||||
member.left_at = datetime.utcnow()
|
||||
|
||||
async def leave_group(self, conv_id: str, user_id: str):
|
||||
"""退出群聊"""
|
||||
member = await self._get_member(conv_id, user_id)
|
||||
if not member:
|
||||
raise ValueError("你不在该群中")
|
||||
if member.role == "owner":
|
||||
raise ValueError("群主不能退出,请先转让群主身份")
|
||||
member.left_at = datetime.utcnow()
|
||||
|
||||
async def update_member_role(self, conv_id: str, user_id: str, target_user_id: str, role: str):
|
||||
"""修改成员角色(仅群主)"""
|
||||
member = await self._get_member(conv_id, user_id)
|
||||
if not member or member.role != "owner":
|
||||
raise ValueError("只有群主可以修改角色")
|
||||
target = await self._get_member(conv_id, target_user_id)
|
||||
if not target:
|
||||
raise ValueError("目标用户不在群中")
|
||||
target.role = role
|
||||
|
||||
async def _get_conv_if_admin(self, conv_id: str, user_id: str) -> Conversation:
|
||||
"""获取会话并验证管理员权限"""
|
||||
conv_result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == conv_id)
|
||||
)
|
||||
conv = conv_result.scalars().first()
|
||||
if not conv:
|
||||
raise ValueError("会话不存在")
|
||||
if conv.type != "group":
|
||||
raise ValueError("仅群聊支持此操作")
|
||||
member = await self._get_member(conv_id, user_id)
|
||||
if not member or member.role not in ("owner", "admin"):
|
||||
raise ValueError("仅群主或管理员可执行此操作")
|
||||
return conv
|
||||
|
||||
async def _get_member(self, conv_id: str, user_id: str) -> ConversationMember | None:
|
||||
"""获取成员记录"""
|
||||
result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conv_id,
|
||||
ConversationMember.user_id == user_id,
|
||||
ConversationMember.left_at.is_(None),
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_user_conversations(self, user_id: str) -> list[dict]:
|
||||
"""获取用户的会话列表"""
|
||||
result = await self.db.execute(
|
||||
@@ -95,17 +169,15 @@ class ConversationService:
|
||||
if not conv:
|
||||
continue
|
||||
|
||||
# 获取未读数
|
||||
unread = await self._get_unread_count(conv.id, member.last_read_message_id)
|
||||
|
||||
# 获取显示信息
|
||||
display_name = conv.name
|
||||
display_avatar = conv.avatar_url
|
||||
|
||||
if conv.type == "private":
|
||||
other = await self._get_other_member(conv.id, user_id)
|
||||
if other:
|
||||
display_name = other.username
|
||||
display_name = other.nickname or other.username
|
||||
display_avatar = other.avatar_url
|
||||
|
||||
conversations.append({
|
||||
@@ -120,7 +192,6 @@ class ConversationService:
|
||||
"created_at": conv.created_at,
|
||||
})
|
||||
|
||||
# 按最后消息时间排序
|
||||
conversations.sort(key=lambda x: x["last_message_at"] or x["created_at"], reverse=True)
|
||||
return conversations
|
||||
|
||||
@@ -133,7 +204,6 @@ class ConversationService:
|
||||
if not conv:
|
||||
return None
|
||||
|
||||
# 验证成员身份
|
||||
member_result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conv_id,
|
||||
@@ -144,7 +214,6 @@ class ConversationService:
|
||||
if not member_result.scalars().first():
|
||||
return None
|
||||
|
||||
# 获取所有成员
|
||||
members_result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conv_id,
|
||||
@@ -160,7 +229,7 @@ class ConversationService:
|
||||
"id": m.id,
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"nickname": user.bio,
|
||||
"nickname": user.nickname,
|
||||
"avatar_url": user.avatar_url,
|
||||
"role": m.role,
|
||||
"joined_at": m.joined_at,
|
||||
@@ -196,12 +265,11 @@ class ConversationService:
|
||||
async def _get_unread_count(self, conv_id: str, last_read_id: str | None) -> int:
|
||||
"""计算未读消息数"""
|
||||
from app.models.message import Message
|
||||
query = select(func := __import__("sqlalchemy").func).count(Message.id).where(
|
||||
query = select(func.count(Message.id)).where(
|
||||
Message.conversation_id == conv_id,
|
||||
Message.is_deleted == False,
|
||||
)
|
||||
if last_read_id:
|
||||
# 获取 last_read 消息的时间
|
||||
lr = await self.db.execute(select(Message).where(Message.id == last_read_id))
|
||||
lr_msg = lr.scalars().first()
|
||||
if lr_msg:
|
||||
|
||||
Reference in New Issue
Block a user