1.5
This commit is contained in:
@@ -7,6 +7,7 @@ from app.dependencies import get_db, get_current_user
|
||||
from app.models.user import User
|
||||
from app.schemas.message import MessageSend, MessagePage, MarkReadRequest
|
||||
from app.services.message_service import MessageService
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -27,6 +28,30 @@ async def get_messages(
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{conversation_id}/messages/search")
|
||||
async def search_messages(
|
||||
conversation_id: str,
|
||||
q: str = Query(..., min_length=1, max_length=200),
|
||||
limit: int = Query(20, ge=1, le=50),
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""在会话内搜索消息(用户级)"""
|
||||
# 验证成员身份
|
||||
conv_service = ConversationService(db)
|
||||
detail = await conv_service.get_conversation_detail(conversation_id, user.id)
|
||||
if not detail:
|
||||
raise HTTPException(status_code=403, detail="无权访问该会话")
|
||||
|
||||
service = MessageService(db)
|
||||
results = await service.search_messages(
|
||||
conversation_id=conversation_id,
|
||||
keyword=q,
|
||||
limit=limit,
|
||||
)
|
||||
return {"results": results}
|
||||
|
||||
|
||||
@router.put("/{conversation_id}/messages/{message_id}/read")
|
||||
async def mark_as_read(
|
||||
conversation_id: str,
|
||||
|
||||
@@ -247,10 +247,18 @@ class ConversationService:
|
||||
ConversationMember.left_at.is_(None),
|
||||
)
|
||||
)
|
||||
member_rows = members_result.scalars().all()
|
||||
|
||||
# 批量获取所有成员用户信息
|
||||
member_user_ids = [m.user_id for m in member_rows]
|
||||
users_result = await self.db.execute(
|
||||
select(User).where(User.id.in_(member_user_ids))
|
||||
)
|
||||
users_map = {u.id: u for u in users_result.scalars().all()}
|
||||
|
||||
members = []
|
||||
for m in members_result.scalars().all():
|
||||
user_result = await self.db.execute(select(User).where(User.id == m.user_id))
|
||||
user = user_result.scalars().first()
|
||||
for m in member_rows:
|
||||
user = users_map.get(m.user_id)
|
||||
if user:
|
||||
members.append({
|
||||
"id": m.id,
|
||||
|
||||
@@ -102,12 +102,20 @@ class FriendService:
|
||||
result = await self.db.execute(
|
||||
select(Friend).where(Friend.user_id == user_id)
|
||||
)
|
||||
friendships = result.scalars().all()
|
||||
if not friendships:
|
||||
return []
|
||||
|
||||
# 批量获取所有好友用户
|
||||
friend_user_ids = [f.friend_user_id for f in friendships]
|
||||
users_result = await self.db.execute(
|
||||
select(User).where(User.id.in_(friend_user_ids))
|
||||
)
|
||||
users_map = {u.id: u for u in users_result.scalars().all()}
|
||||
|
||||
friends = []
|
||||
for friendship in result.scalars().all():
|
||||
user_result = await self.db.execute(
|
||||
select(User).where(User.id == friendship.friend_user_id)
|
||||
)
|
||||
user = user_result.scalars().first()
|
||||
for friendship in friendships:
|
||||
user = users_map.get(friendship.friend_user_id)
|
||||
if user:
|
||||
friends.append({
|
||||
"id": friendship.id,
|
||||
@@ -128,10 +136,20 @@ class FriendService:
|
||||
FriendRequest.status == "pending",
|
||||
).order_by(FriendRequest.created_at.desc())
|
||||
)
|
||||
requests_list = result.scalars().all()
|
||||
if not requests_list:
|
||||
return []
|
||||
|
||||
# 批量获取所有发送者
|
||||
from_user_ids = list(set(r.from_user_id for r in requests_list))
|
||||
users_result = await self.db.execute(
|
||||
select(User).where(User.id.in_(from_user_ids))
|
||||
)
|
||||
users_map = {u.id: u for u in users_result.scalars().all()}
|
||||
|
||||
requests = []
|
||||
for req in result.scalars().all():
|
||||
from_user = await self.db.execute(select(User).where(User.id == req.from_user_id))
|
||||
fu = from_user.scalars().first()
|
||||
for req in requests_list:
|
||||
fu = users_map.get(req.from_user_id)
|
||||
requests.append({
|
||||
"id": req.id,
|
||||
"from_user_id": req.from_user_id,
|
||||
|
||||
@@ -77,29 +77,42 @@ class MessageService:
|
||||
has_more = len(messages) > limit
|
||||
messages = messages[:limit]
|
||||
|
||||
# 获取发送者信息
|
||||
# 批量预加载发送者信息
|
||||
from app.models.user import User
|
||||
sender_ids = list(set(m.sender_id for m in messages))
|
||||
senders_result = await self.db.execute(
|
||||
select(User).where(User.id.in_(sender_ids))
|
||||
)
|
||||
senders_map = {u.id: u for u in senders_result.scalars().all()}
|
||||
|
||||
# 批量预加载被引用消息
|
||||
reply_to_ids = list(set(m.reply_to_id for m in messages if m.reply_to_id))
|
||||
reply_msgs_map: dict[str, Message] = {}
|
||||
if reply_to_ids:
|
||||
reply_result = await self.db.execute(
|
||||
select(Message).where(Message.id.in_(reply_to_ids))
|
||||
)
|
||||
reply_msgs_map = {m.id: m for m in reply_result.scalars().all()}
|
||||
|
||||
# 批量预加载被引用消息的发送者
|
||||
reply_sender_ids = list(set(rm.sender_id for rm in reply_msgs_map.values()))
|
||||
reply_senders_result = await self.db.execute(
|
||||
select(User).where(User.id.in_(reply_sender_ids))
|
||||
)
|
||||
reply_senders_map = {u.id: u for u in reply_senders_result.scalars().all()}
|
||||
|
||||
message_list = []
|
||||
for msg in reversed(messages):
|
||||
sender_result = await self.db.execute(
|
||||
select(User).where(User.id == msg.sender_id)
|
||||
)
|
||||
sender = sender_result.scalars().first()
|
||||
sender = senders_map.get(msg.sender_id)
|
||||
|
||||
# 获取被引用消息的信息
|
||||
reply_to_content = None
|
||||
reply_to_sender_name = None
|
||||
if msg.reply_to_id:
|
||||
reply_msg_result = await self.db.execute(
|
||||
select(Message).where(Message.id == msg.reply_to_id)
|
||||
)
|
||||
reply_msg = reply_msg_result.scalars().first()
|
||||
reply_msg = reply_msgs_map.get(msg.reply_to_id)
|
||||
if reply_msg:
|
||||
reply_to_content = reply_msg.content[:200] if reply_msg.content else None
|
||||
reply_sender_result = await self.db.execute(
|
||||
select(User).where(User.id == reply_msg.sender_id)
|
||||
)
|
||||
reply_sender = reply_sender_result.scalars().first()
|
||||
reply_sender = reply_senders_map.get(reply_msg.sender_id)
|
||||
reply_to_sender_name = reply_sender.username if reply_sender else None
|
||||
|
||||
message_list.append({
|
||||
|
||||
@@ -82,7 +82,7 @@ class MomentService:
|
||||
unique.sort(key=lambda x: x.created_at, reverse=True)
|
||||
unique = unique[:limit]
|
||||
|
||||
return [await self._moment_to_dict(m, user_id) for m in unique]
|
||||
return await self._moments_to_dicts(unique, user_id)
|
||||
|
||||
async def get_user_moments(self, user_id: str, viewer_id: str | None = None,
|
||||
cursor: str | None = None, limit: int = 20) -> list[dict]:
|
||||
@@ -116,7 +116,7 @@ class MomentService:
|
||||
if viewer_id == user_id:
|
||||
filtered.append(m)
|
||||
|
||||
return [await self._moment_to_dict(m, viewer_id) for m in filtered]
|
||||
return await self._moments_to_dicts(filtered, viewer_id)
|
||||
|
||||
async def delete_moment(self, moment_id: str, user_id: str):
|
||||
"""删除动态(仅作者)"""
|
||||
@@ -243,56 +243,70 @@ class MomentService:
|
||||
raise ValueError("只能删除自己的评论")
|
||||
await self.db.delete(comment)
|
||||
|
||||
async def _moment_to_dict(self, moment: Moment, viewer_id: str | None) -> dict:
|
||||
"""将 Moment ORM 对象转为前端需要的字典"""
|
||||
user_result = await self.db.execute(select(User).where(User.id == moment.user_id))
|
||||
user = user_result.scalars().first()
|
||||
async def _moments_to_dicts(self, moments: list[Moment], viewer_id: str | None) -> list[dict]:
|
||||
"""批量将 Moment ORM 对象转为前端需要的字典(优化 N+1 查询)"""
|
||||
if not moments:
|
||||
return []
|
||||
|
||||
# 点赞数
|
||||
like_count_result = await self.db.execute(
|
||||
select(func.count(MomentLike.id)).where(MomentLike.moment_id == moment.id)
|
||||
moment_ids = [m.id for m in moments]
|
||||
user_ids = list(set(m.user_id for m in moments))
|
||||
|
||||
# 批量获取所有作者
|
||||
users_result = await self.db.execute(select(User).where(User.id.in_(user_ids)))
|
||||
users_map = {u.id: u for u in users_result.scalars().all()}
|
||||
|
||||
# 批量获取点赞数
|
||||
like_counts_result = await self.db.execute(
|
||||
select(MomentLike.moment_id, func.count(MomentLike.id))
|
||||
.where(MomentLike.moment_id.in_(moment_ids))
|
||||
.group_by(MomentLike.moment_id)
|
||||
)
|
||||
like_count = like_count_result.scalar() or 0
|
||||
like_counts_map = dict(like_counts_result.all())
|
||||
|
||||
# 是否已点赞
|
||||
is_liked = False
|
||||
# 批量获取评论数
|
||||
comment_counts_result = await self.db.execute(
|
||||
select(MomentComment.moment_id, func.count(MomentComment.id))
|
||||
.where(MomentComment.moment_id.in_(moment_ids))
|
||||
.group_by(MomentComment.moment_id)
|
||||
)
|
||||
comment_counts_map = dict(comment_counts_result.all())
|
||||
|
||||
# 批量获取当前用户的点赞状态
|
||||
liked_moment_ids = set()
|
||||
if viewer_id:
|
||||
like_result = await self.db.execute(
|
||||
select(MomentLike).where(
|
||||
MomentLike.moment_id == moment.id,
|
||||
liked_result = await self.db.execute(
|
||||
select(MomentLike.moment_id).where(
|
||||
MomentLike.moment_id.in_(moment_ids),
|
||||
MomentLike.user_id == viewer_id,
|
||||
)
|
||||
)
|
||||
is_liked = like_result.scalars().first() is not None
|
||||
liked_moment_ids = {r[0] for r in liked_result.all()}
|
||||
|
||||
# 评论数
|
||||
comment_count_result = await self.db.execute(
|
||||
select(func.count(MomentComment.id)).where(MomentComment.moment_id == moment.id)
|
||||
)
|
||||
comment_count = comment_count_result.scalar() or 0
|
||||
result = []
|
||||
for moment in moments:
|
||||
user = users_map.get(moment.user_id)
|
||||
images = []
|
||||
if moment.images:
|
||||
try:
|
||||
images = json.loads(moment.images)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 解析图片
|
||||
images = []
|
||||
if moment.images:
|
||||
try:
|
||||
images = json.loads(moment.images)
|
||||
except:
|
||||
pass
|
||||
|
||||
return {
|
||||
"id": moment.id,
|
||||
"user_id": moment.user_id,
|
||||
"username": user.username if user else "未知",
|
||||
"nickname": user.nickname if user else None,
|
||||
"avatar_url": user.avatar_url if user else None,
|
||||
"content": moment.content,
|
||||
"images": images,
|
||||
"visibility": moment.visibility,
|
||||
"like_count": like_count,
|
||||
"is_liked": is_liked,
|
||||
"comment_count": comment_count,
|
||||
"created_at": moment.created_at,
|
||||
}
|
||||
result.append({
|
||||
"id": moment.id,
|
||||
"user_id": moment.user_id,
|
||||
"username": user.username if user else "未知",
|
||||
"nickname": user.nickname if user else None,
|
||||
"avatar_url": user.avatar_url if user else None,
|
||||
"content": moment.content,
|
||||
"images": images,
|
||||
"visibility": moment.visibility,
|
||||
"like_count": like_counts_map.get(moment.id, 0),
|
||||
"is_liked": moment.id in liked_moment_ids,
|
||||
"comment_count": comment_counts_map.get(moment.id, 0),
|
||||
"created_at": moment.created_at,
|
||||
})
|
||||
return result
|
||||
|
||||
async def _get_friend_ids(self, user_id: str) -> list[str]:
|
||||
"""获取好友ID列表"""
|
||||
|
||||
Reference in New Issue
Block a user