327 lines
12 KiB
Python
327 lines
12 KiB
Python
"""朋友圈服务"""
|
|
|
|
import json
|
|
import uuid
|
|
from datetime import datetime
|
|
|
|
from sqlalchemy import select, func, and_
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.moment import Moment, MomentLike, MomentComment
|
|
from app.models.friend import Friend
|
|
from app.models.user import User
|
|
|
|
|
|
class MomentService:
    """Moments ("朋友圈") service: posting, feeds, likes and comments.

    Every method works through the injected ``AsyncSession``. Writes are done
    with ``add`` / ``flush`` / ``delete`` only; nothing here commits, so the
    transaction boundary belongs to the caller.
    """

    # Visibility values understood by the read paths (feed / profile page).
    _VISIBILITIES = ("public", "friends", "private")

    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_moment(self, user_id: str, content: str,
                            images: list[str] | None = None,
                            visibility: str = "friends") -> Moment:
        """Publish a moment.

        Args:
            user_id: Author id.
            content: Text body.
            images: Optional list of image references (presumably URLs);
                stored as a JSON array string, or ``None`` when empty.
            visibility: One of ``"public"``, ``"friends"``, ``"private"``.

        Returns:
            The new ``Moment``, flushed so its generated id is usable.

        Raises:
            ValueError: If ``visibility`` is not a known value.
        """
        # An unknown value would be shown to friends in the feed (which only
        # excludes "private") yet hidden from everyone, including the author,
        # on the profile page (which only shows known values).
        if visibility not in self._VISIBILITIES:
            raise ValueError(f"无效的可见性: {visibility}")

        moment = Moment(
            id=str(uuid.uuid4()),
            user_id=user_id,
            content=content,
            images=json.dumps(images) if images else None,
            visibility=visibility,
        )
        self.db.add(moment)
        await self.db.flush()
        return moment

    async def get_feed(self, user_id: str, cursor: str | None = None,
                       limit: int = 20) -> list[dict]:
        """Return one page of the viewer's feed: their own moments plus friends'.

        The viewer sees all of their own moments (private ones included);
        friends' moments are included unless marked private.

        Args:
            user_id: The viewer.
            cursor: Id of the last moment on the previous page, or ``None`` for
                the first page. An unknown id is ignored (first page returned).
            limit: Maximum number of moments to return.

        Returns:
            Moment dicts as built by ``_moments_to_dicts``, newest first.
        """
        friend_ids = await self._get_friend_ids(user_id)

        # One predicate instead of two queries merged and re-sorted in Python.
        visible = Moment.user_id == user_id
        if friend_ids:
            visible = visible | (
                Moment.user_id.in_(friend_ids) & (Moment.visibility != "private")
            )

        query = (
            select(Moment)
            .where(visible)
            .order_by(Moment.created_at.desc(), Moment.id.desc())
            .limit(limit)
        )
        query = await self._apply_cursor(query, cursor)

        result = await self.db.execute(query)
        return await self._moments_to_dicts(list(result.scalars().all()), user_id)

    async def get_user_moments(self, user_id: str, viewer_id: str | None = None,
                               cursor: str | None = None, limit: int = 20) -> list[dict]:
        """Return one page of ``user_id``'s moments as seen by ``viewer_id``.

        Visibility rules:
            * the author sees public, friends-only and private moments;
            * a friend (``_are_friends(viewer_id, user_id)``) sees public and
              friends-only moments;
            * anyone else, including an anonymous viewer (``None``), sees
              public moments only.

        Visibility is filtered in SQL *before* ``LIMIT``, so a page is not
        short (or empty) merely because the newest rows are hidden from this
        viewer. The friendship lookup runs at most once per call.

        Args:
            user_id: Whose moments to list.
            viewer_id: Who is looking, or ``None`` when anonymous.
            cursor: Id of the last moment on the previous page, or ``None``.
            limit: Maximum number of moments to return.

        Returns:
            Moment dicts as built by ``_moments_to_dicts``, newest first.
        """
        if viewer_id and viewer_id == user_id:
            allowed = ("public", "friends", "private")
        elif viewer_id and await self._are_friends(viewer_id, user_id):
            allowed = ("public", "friends")
        else:
            allowed = ("public",)

        query = (
            select(Moment)
            .where(Moment.user_id == user_id, Moment.visibility.in_(allowed))
            .order_by(Moment.created_at.desc(), Moment.id.desc())
            .limit(limit)
        )
        query = await self._apply_cursor(query, cursor)

        result = await self.db.execute(query)
        return await self._moments_to_dicts(list(result.scalars().all()), viewer_id)

    async def delete_moment(self, moment_id: str, user_id: str):
        """Delete a moment; only its author may do so.

        Raises:
            ValueError: If the moment does not exist or ``user_id`` is not
                its author.
        """
        result = await self.db.execute(select(Moment).where(Moment.id == moment_id))
        moment = result.scalars().first()
        if not moment:
            raise ValueError("动态不存在")
        if moment.user_id != user_id:
            raise ValueError("只能删除自己的动态")
        await self.db.delete(moment)

    async def toggle_like(self, moment_id: str, user_id: str) -> bool:
        """Like or unlike a moment.

        Returns:
            ``True`` if the moment is now liked by ``user_id``, ``False`` if
            the existing like was removed.

        Raises:
            ValueError: When adding a like to a moment that does not exist.
        """
        result = await self.db.execute(
            select(MomentLike).where(
                MomentLike.moment_id == moment_id,
                MomentLike.user_id == user_id,
            )
        )
        existing = result.scalars().first()
        if existing:
            # Unliking is always allowed, even if the moment is gone.
            await self.db.delete(existing)
            return False

        # Same existence check as add_comment: don't create likes for
        # moments that do not exist.
        moment_result = await self.db.execute(select(Moment.id).where(Moment.id == moment_id))
        if moment_result.scalar() is None:
            raise ValueError("动态不存在")

        self.db.add(MomentLike(
            id=str(uuid.uuid4()),
            moment_id=moment_id,
            user_id=user_id,
        ))
        return True

    async def add_comment(self, moment_id: str, user_id: str, content: str,
                          reply_to_id: str | None = None) -> dict:
        """Add a comment, optionally as a reply to another comment.

        Args:
            moment_id: The moment being commented on.
            user_id: The commenter.
            content: Comment text.
            reply_to_id: Id of the *comment* being replied to, if any.

        Returns:
            The comment dict, in the same shape as ``get_comments`` items.

        Raises:
            ValueError: If the moment does not exist, or ``reply_to_id`` does
                not name a comment on the same moment.
        """
        moment_result = await self.db.execute(select(Moment.id).where(Moment.id == moment_id))
        if moment_result.scalar() is None:
            raise ValueError("动态不存在")

        reply_target = None
        if reply_to_id:
            rt_result = await self.db.execute(
                select(MomentComment).where(MomentComment.id == reply_to_id)
            )
            reply_target = rt_result.scalars().first()
            # Reject dangling replies and replies pointing at a comment on a
            # different moment.
            if reply_target is None or reply_target.moment_id != moment_id:
                raise ValueError("评论不存在")

        comment = MomentComment(
            id=str(uuid.uuid4()),
            moment_id=moment_id,
            user_id=user_id,
            content=content,
            reply_to_id=reply_to_id,
        )
        self.db.add(comment)
        await self.db.flush()

        user_ids = {user_id}
        targets_by_id = {}
        if reply_target is not None:
            user_ids.add(reply_target.user_id)
            targets_by_id[reply_target.id] = reply_target
        users_map = await self._get_users_map(user_ids)
        return self._comment_to_dict(comment, users_map, targets_by_id)

    async def get_comments(self, moment_id: str) -> list[dict]:
        """List a moment's comments, oldest first.

        Users and reply targets are loaded in batches (a constant number of
        queries) rather than with several queries per comment.
        """
        result = await self.db.execute(
            select(MomentComment)
            .where(MomentComment.moment_id == moment_id)
            .order_by(MomentComment.created_at.asc())
        )
        comments = list(result.scalars().all())
        if not comments:
            return []

        # Reply targets usually live in this same list; fetch only the rest.
        comments_by_id = {c.id: c for c in comments}
        missing_targets = {
            c.reply_to_id for c in comments
            if c.reply_to_id and c.reply_to_id not in comments_by_id
        }
        if missing_targets:
            rt_result = await self.db.execute(
                select(MomentComment).where(MomentComment.id.in_(list(missing_targets)))
            )
            comments_by_id.update({c.id: c for c in rt_result.scalars().all()})

        user_ids = {c.user_id for c in comments}
        user_ids.update(
            comments_by_id[c.reply_to_id].user_id
            for c in comments
            if c.reply_to_id and c.reply_to_id in comments_by_id
        )
        users_map = await self._get_users_map(user_ids)

        return [self._comment_to_dict(c, users_map, comments_by_id) for c in comments]

    async def delete_comment(self, comment_id: str, user_id: str):
        """Delete a comment; only its author may do so.

        Raises:
            ValueError: If the comment does not exist or ``user_id`` is not
                its author.
        """
        result = await self.db.execute(select(MomentComment).where(MomentComment.id == comment_id))
        comment = result.scalars().first()
        if not comment:
            raise ValueError("评论不存在")
        if comment.user_id != user_id:
            raise ValueError("只能删除自己的评论")
        await self.db.delete(comment)

    async def _apply_cursor(self, query, cursor: str | None):
        """Restrict a ``select(Moment)`` query to rows after ``cursor``.

        Rows are assumed ordered by ``(created_at DESC, id DESC)``. Comparing
        on the id as well means moments sharing the cursor's timestamp are
        not skipped between pages. An empty or unknown cursor leaves the
        query unchanged.
        """
        if not cursor:
            return query
        cursor_result = await self.db.execute(
            select(Moment.created_at).where(Moment.id == cursor)
        )
        cursor_time = cursor_result.scalar()
        if cursor_time is None:
            return query
        return query.where(
            (Moment.created_at < cursor_time)
            | ((Moment.created_at == cursor_time) & (Moment.id < cursor))
        )

    async def _get_users_map(self, user_ids) -> dict:
        """Load the given users in one query; returns ``{user.id: user}``."""
        if not user_ids:
            return {}
        result = await self.db.execute(select(User).where(User.id.in_(list(user_ids))))
        return {u.id: u for u in result.scalars().all()}

    @staticmethod
    def _comment_to_dict(comment: MomentComment, users_map: dict,
                         comments_by_id: dict) -> dict:
        """Build the response dict for one comment.

        Args:
            comment: The comment row.
            users_map: ``{user_id: User}`` covering the commenter and, when
                present, the author of the replied-to comment.
            comments_by_id: ``{comment_id: MomentComment}`` used to resolve
                ``reply_to_id``. A missing target yields
                ``reply_to_username=None``.
        """
        user = users_map.get(comment.user_id)

        reply_to_username = None
        target = comments_by_id.get(comment.reply_to_id) if comment.reply_to_id else None
        if target is not None:
            target_user = users_map.get(target.user_id)
            reply_to_username = target_user.username if target_user else None

        return {
            "id": comment.id,
            "moment_id": comment.moment_id,
            "user_id": comment.user_id,
            "username": user.username if user else "未知",
            "nickname": user.nickname if user else None,
            "avatar_url": user.avatar_url if user else None,
            "content": comment.content,
            "reply_to_id": comment.reply_to_id,
            "reply_to_username": reply_to_username,
            "created_at": comment.created_at,
        }

    async def _moments_to_dicts(self, moments: list[Moment], viewer_id: str | None) -> list[dict]:
        """Convert ``Moment`` rows into response dicts.

        Uses a fixed number of batched queries (authors, like counts, comment
        counts, and the viewer's own likes) instead of per-row lookups.
        Output order matches ``moments``.
        """
        if not moments:
            return []

        moment_ids = [m.id for m in moments]
        users_map = await self._get_users_map({m.user_id for m in moments})

        like_counts_result = await self.db.execute(
            select(MomentLike.moment_id, func.count(MomentLike.id))
            .where(MomentLike.moment_id.in_(moment_ids))
            .group_by(MomentLike.moment_id)
        )
        like_counts_map = dict(like_counts_result.all())

        comment_counts_result = await self.db.execute(
            select(MomentComment.moment_id, func.count(MomentComment.id))
            .where(MomentComment.moment_id.in_(moment_ids))
            .group_by(MomentComment.moment_id)
        )
        comment_counts_map = dict(comment_counts_result.all())

        liked_moment_ids = set()
        if viewer_id:
            liked_result = await self.db.execute(
                select(MomentLike.moment_id).where(
                    MomentLike.moment_id.in_(moment_ids),
                    MomentLike.user_id == viewer_id,
                )
            )
            liked_moment_ids = {r[0] for r in liked_result.all()}

        result = []
        for moment in moments:
            user = users_map.get(moment.user_id)
            images = []
            if moment.images:
                try:
                    images = json.loads(moment.images)
                except (TypeError, ValueError):
                    # Malformed stored JSON: show no images rather than failing
                    # the whole page. JSONDecodeError subclasses ValueError.
                    images = []

            result.append({
                "id": moment.id,
                "user_id": moment.user_id,
                "username": user.username if user else "未知",
                "nickname": user.nickname if user else None,
                "avatar_url": user.avatar_url if user else None,
                "content": moment.content,
                "images": images,
                "visibility": moment.visibility,
                "like_count": like_counts_map.get(moment.id, 0),
                "is_liked": moment.id in liked_moment_ids,
                "comment_count": comment_counts_map.get(moment.id, 0),
                "created_at": moment.created_at,
            })
        return result

    async def _get_friend_ids(self, user_id: str) -> list[str]:
        """Return ``friend_user_id`` of every ``Friend`` row owned by ``user_id``."""
        result = await self.db.execute(
            select(Friend.friend_user_id).where(Friend.user_id == user_id)
        )
        return [r[0] for r in result.all()]

    async def _are_friends(self, user1_id: str, user2_id: str) -> bool:
        """Return whether ``user1_id`` has a ``Friend`` row pointing at ``user2_id``.

        Only this one direction is checked, matching ``_get_friend_ids``.
        """
        result = await self.db.execute(
            select(Friend).where(
                Friend.user_id == user1_id,
                Friend.friend_user_id == user2_id,
            ).limit(1)
        )
        return result.scalars().first() is not None
|