This commit is contained in:
2026-06-13 11:02:47 +08:00
parent 318ddd85a5
commit 68678304ff
15 changed files with 659 additions and 78 deletions
+25
View File
@@ -7,6 +7,7 @@ from app.dependencies import get_db, get_current_user
from app.models.user import User
from app.schemas.message import MessageSend, MessagePage, MarkReadRequest
from app.services.message_service import MessageService
from app.services.conversation_service import ConversationService
router = APIRouter()
@@ -27,6 +28,30 @@ async def get_messages(
raise HTTPException(status_code=403, detail=str(e))
@router.get("/{conversation_id}/messages/search")
async def search_messages(
conversation_id: str,
q: str = Query(..., min_length=1, max_length=200),
limit: int = Query(20, ge=1, le=50),
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""在会话内搜索消息(用户级)"""
# 验证成员身份
conv_service = ConversationService(db)
detail = await conv_service.get_conversation_detail(conversation_id, user.id)
if not detail:
raise HTTPException(status_code=403, detail="无权访问该会话")
service = MessageService(db)
results = await service.search_messages(
conversation_id=conversation_id,
keyword=q,
limit=limit,
)
return {"results": results}
@router.put("/{conversation_id}/messages/{message_id}/read")
async def mark_as_read(
conversation_id: str,
+11 -3
View File
@@ -247,10 +247,18 @@ class ConversationService:
ConversationMember.left_at.is_(None),
)
)
member_rows = members_result.scalars().all()
# 批量获取所有成员用户信息
member_user_ids = [m.user_id for m in member_rows]
users_result = await self.db.execute(
select(User).where(User.id.in_(member_user_ids))
)
users_map = {u.id: u for u in users_result.scalars().all()}
members = []
for m in members_result.scalars().all():
user_result = await self.db.execute(select(User).where(User.id == m.user_id))
user = user_result.scalars().first()
for m in member_rows:
user = users_map.get(m.user_id)
if user:
members.append({
"id": m.id,
+26 -8
View File
@@ -102,12 +102,20 @@ class FriendService:
result = await self.db.execute(
select(Friend).where(Friend.user_id == user_id)
)
friendships = result.scalars().all()
if not friendships:
return []
# 批量获取所有好友用户
friend_user_ids = [f.friend_user_id for f in friendships]
users_result = await self.db.execute(
select(User).where(User.id.in_(friend_user_ids))
)
users_map = {u.id: u for u in users_result.scalars().all()}
friends = []
for friendship in result.scalars().all():
user_result = await self.db.execute(
select(User).where(User.id == friendship.friend_user_id)
)
user = user_result.scalars().first()
for friendship in friendships:
user = users_map.get(friendship.friend_user_id)
if user:
friends.append({
"id": friendship.id,
@@ -128,10 +136,20 @@ class FriendService:
FriendRequest.status == "pending",
).order_by(FriendRequest.created_at.desc())
)
requests_list = result.scalars().all()
if not requests_list:
return []
# 批量获取所有发送者
from_user_ids = list(set(r.from_user_id for r in requests_list))
users_result = await self.db.execute(
select(User).where(User.id.in_(from_user_ids))
)
users_map = {u.id: u for u in users_result.scalars().all()}
requests = []
for req in result.scalars().all():
from_user = await self.db.execute(select(User).where(User.id == req.from_user_id))
fu = from_user.scalars().first()
for req in requests_list:
fu = users_map.get(req.from_user_id)
requests.append({
"id": req.id,
"from_user_id": req.from_user_id,
+26 -13
View File
@@ -77,29 +77,42 @@ class MessageService:
has_more = len(messages) > limit
messages = messages[:limit]
# 获取发送者信息
# 批量预加载发送者信息
from app.models.user import User
sender_ids = list(set(m.sender_id for m in messages))
senders_result = await self.db.execute(
select(User).where(User.id.in_(sender_ids))
)
senders_map = {u.id: u for u in senders_result.scalars().all()}
# 批量预加载被引用消息
reply_to_ids = list(set(m.reply_to_id for m in messages if m.reply_to_id))
reply_msgs_map: dict[str, Message] = {}
if reply_to_ids:
reply_result = await self.db.execute(
select(Message).where(Message.id.in_(reply_to_ids))
)
reply_msgs_map = {m.id: m for m in reply_result.scalars().all()}
# 批量预加载被引用消息的发送者
reply_sender_ids = list(set(rm.sender_id for rm in reply_msgs_map.values()))
reply_senders_result = await self.db.execute(
select(User).where(User.id.in_(reply_sender_ids))
)
reply_senders_map = {u.id: u for u in reply_senders_result.scalars().all()}
message_list = []
for msg in reversed(messages):
sender_result = await self.db.execute(
select(User).where(User.id == msg.sender_id)
)
sender = sender_result.scalars().first()
sender = senders_map.get(msg.sender_id)
# 获取被引用消息的信息
reply_to_content = None
reply_to_sender_name = None
if msg.reply_to_id:
reply_msg_result = await self.db.execute(
select(Message).where(Message.id == msg.reply_to_id)
)
reply_msg = reply_msg_result.scalars().first()
reply_msg = reply_msgs_map.get(msg.reply_to_id)
if reply_msg:
reply_to_content = reply_msg.content[:200] if reply_msg.content else None
reply_sender_result = await self.db.execute(
select(User).where(User.id == reply_msg.sender_id)
)
reply_sender = reply_sender_result.scalars().first()
reply_sender = reply_senders_map.get(reply_msg.sender_id)
reply_to_sender_name = reply_sender.username if reply_sender else None
message_list.append({
+57 -43
View File
@@ -82,7 +82,7 @@ class MomentService:
unique.sort(key=lambda x: x.created_at, reverse=True)
unique = unique[:limit]
return [await self._moment_to_dict(m, user_id) for m in unique]
return await self._moments_to_dicts(unique, user_id)
async def get_user_moments(self, user_id: str, viewer_id: str | None = None,
cursor: str | None = None, limit: int = 20) -> list[dict]:
@@ -116,7 +116,7 @@ class MomentService:
if viewer_id == user_id:
filtered.append(m)
return [await self._moment_to_dict(m, viewer_id) for m in filtered]
return await self._moments_to_dicts(filtered, viewer_id)
async def delete_moment(self, moment_id: str, user_id: str):
"""删除动态(仅作者)"""
@@ -243,56 +243,70 @@ class MomentService:
raise ValueError("只能删除自己的评论")
await self.db.delete(comment)
async def _moment_to_dict(self, moment: Moment, viewer_id: str | None) -> dict:
"""将 Moment ORM 对象转为前端需要的字典"""
user_result = await self.db.execute(select(User).where(User.id == moment.user_id))
user = user_result.scalars().first()
async def _moments_to_dicts(self, moments: list[Moment], viewer_id: str | None) -> list[dict]:
"""批量将 Moment ORM 对象转为前端需要的字典(优化 N+1 查询)"""
if not moments:
return []
# 点赞数
like_count_result = await self.db.execute(
select(func.count(MomentLike.id)).where(MomentLike.moment_id == moment.id)
moment_ids = [m.id for m in moments]
user_ids = list(set(m.user_id for m in moments))
# 批量获取所有作者
users_result = await self.db.execute(select(User).where(User.id.in_(user_ids)))
users_map = {u.id: u for u in users_result.scalars().all()}
# 批量获取点赞数
like_counts_result = await self.db.execute(
select(MomentLike.moment_id, func.count(MomentLike.id))
.where(MomentLike.moment_id.in_(moment_ids))
.group_by(MomentLike.moment_id)
)
like_count = like_count_result.scalar() or 0
like_counts_map = dict(like_counts_result.all())
# 是否已点赞
is_liked = False
# 批量获取评论数
comment_counts_result = await self.db.execute(
select(MomentComment.moment_id, func.count(MomentComment.id))
.where(MomentComment.moment_id.in_(moment_ids))
.group_by(MomentComment.moment_id)
)
comment_counts_map = dict(comment_counts_result.all())
# 批量获取当前用户的点赞状态
liked_moment_ids = set()
if viewer_id:
like_result = await self.db.execute(
select(MomentLike).where(
MomentLike.moment_id == moment.id,
liked_result = await self.db.execute(
select(MomentLike.moment_id).where(
MomentLike.moment_id.in_(moment_ids),
MomentLike.user_id == viewer_id,
)
)
is_liked = like_result.scalars().first() is not None
liked_moment_ids = {r[0] for r in liked_result.all()}
# 评论数
comment_count_result = await self.db.execute(
select(func.count(MomentComment.id)).where(MomentComment.moment_id == moment.id)
)
comment_count = comment_count_result.scalar() or 0
result = []
for moment in moments:
user = users_map.get(moment.user_id)
images = []
if moment.images:
try:
images = json.loads(moment.images)
except Exception:
pass
# 解析图片
images = []
if moment.images:
try:
images = json.loads(moment.images)
except:
pass
return {
"id": moment.id,
"user_id": moment.user_id,
"username": user.username if user else "未知",
"nickname": user.nickname if user else None,
"avatar_url": user.avatar_url if user else None,
"content": moment.content,
"images": images,
"visibility": moment.visibility,
"like_count": like_count,
"is_liked": is_liked,
"comment_count": comment_count,
"created_at": moment.created_at,
}
result.append({
"id": moment.id,
"user_id": moment.user_id,
"username": user.username if user else "未知",
"nickname": user.nickname if user else None,
"avatar_url": user.avatar_url if user else None,
"content": moment.content,
"images": images,
"visibility": moment.visibility,
"like_count": like_counts_map.get(moment.id, 0),
"is_liked": moment.id in liked_moment_ids,
"comment_count": comment_counts_map.get(moment.id, 0),
"created_at": moment.created_at,
})
return result
async def _get_friend_ids(self, user_id: str) -> list[str]:
"""获取好友ID列表"""