217 lines
7.9 KiB
Python
217 lines
7.9 KiB
Python
"""好友之树服务"""
|
|
|
|
import uuid
from datetime import datetime, timedelta

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.friendship_tree import FriendshipTree
from app.models.conversation import Conversation
from app.models.conversation_member import ConversationMember
from app.models.message import Message
from app.models.friend import Friend
from app.models.user import User
from app.websocket.manager import manager
|
|
|
|
|
|
# Growth stages as (minimum total score, stage name, emoji), ordered by
# ascending threshold. A score belongs to the last stage whose threshold it
# reaches; the stage *index* is the position in this list.
# NOTE(review): "小树" and "古树" share the 🌲 emoji — confirm this is intended.
STAGES = [
    (0, "种子", "🌱"),
    (11, "萌芽", "🌿"),
    (41, "幼苗", "🪴"),
    (151, "小树", "🌲"),
    (401, "大树", "🌳"),
    (1001, "古树", "🌲"),
]


def stage_for_score(score: int) -> tuple[int, str, str]:
    """Return ``(stage_index, stage_name, stage_emoji)`` for a total score.

    The stage is the last entry of ``STAGES`` whose threshold is ``<= score``.
    Scores below every threshold (e.g. negative values) fall back to stage 0.
    """
    idx = 0
    for i, (threshold, _name, _emoji) in enumerate(STAGES):
        if score >= threshold:
            idx = i
    _threshold, name, emoji = STAGES[idx]
    return idx, name, emoji
|
|
|
|
|
|
class TreeService:
    """Friendship-tree service: growth score, watering and heartbeat data.

    A tree's total score is the number of non-deleted messages in the two
    users' private conversation plus ``WATER_SCORE`` points per watering.
    """

    # Score points granted by each watering.
    WATER_SCORE = 5

    def __init__(self, db: AsyncSession):
        self.db = db

    # ------------------------------------------------------------------ rows

    async def _get_or_create_tree_row(self, user_id: str, friend_id: str) -> FriendshipTree:
        """Fetch the tree row for the pair, creating (and flushing) it if missing.

        The pair is normalized so that ``user_a_id < user_b_id``; both
        directions of a friendship therefore map to the same row.
        """
        a, b = (user_id, friend_id) if user_id < friend_id else (friend_id, user_id)
        result = await self.db.execute(
            select(FriendshipTree).where(
                FriendshipTree.user_a_id == a,
                FriendshipTree.user_b_id == b,
            )
        )
        tree = result.scalars().first()
        if tree is None:
            tree = FriendshipTree(id=str(uuid.uuid4()), user_a_id=a, user_b_id=b)
            self.db.add(tree)
            # Flush (not commit) so the new row is visible to later queries
            # in this session.
            await self.db.flush()
        return tree

    async def _get_user(self, user_id: str):
        """Return the ``User`` row for ``user_id``, or None if it does not exist."""
        result = await self.db.execute(select(User).where(User.id == user_id))
        return result.scalars().first()

    @staticmethod
    def _display_name(user) -> str:
        """Nickname, falling back to username; "未知" ("unknown") when user is None."""
        if user is None:
            return "未知"
        return user.nickname or user.username

    # -------------------------------------------------------------- messages

    async def _find_private_conversation_id(self, user_id: str, friend_id: str):
        """Return the id of a private conversation containing both users, or None.

        Done in a single query (conversations of ``user_id`` restricted to
        those that also list ``friend_id`` as a member) instead of one
        membership query per conversation. If several match, which one is
        returned is unspecified, as before.
        """
        friend_conv_ids = select(ConversationMember.conversation_id).where(
            ConversationMember.user_id == friend_id
        )
        result = await self.db.execute(
            select(Conversation.id)
            .join(ConversationMember, ConversationMember.conversation_id == Conversation.id)
            .where(
                Conversation.type == "private",
                ConversationMember.user_id == user_id,
                Conversation.id.in_(friend_conv_ids),
            )
            .limit(1)
        )
        return result.scalars().first()

    async def _count_private_messages(self, user_id: str, friend_id: str, since=None) -> int:
        """Count non-deleted messages in the pair's private conversation.

        Args:
            since: optional datetime; when given, only messages with
                ``created_at >= since`` are counted.

        Returns:
            The message count, or 0 if the users share no private conversation.
        """
        conv_id = await self._find_private_conversation_id(user_id, friend_id)
        if conv_id is None:
            return 0
        conditions = [
            Message.conversation_id == conv_id,
            Message.is_deleted == False,  # noqa: E712 -- SQL expression, not a Python comparison
        ]
        if since is not None:
            conditions.append(Message.created_at >= since)
        count_result = await self.db.execute(
            select(func.count(Message.id)).where(*conditions)
        )
        return count_result.scalar() or 0

    async def _count_messages_between(self, user_id: str, friend_id: str) -> int:
        """Total number of non-deleted messages in the pair's private chat."""
        return await self._count_private_messages(user_id, friend_id)

    async def _count_messages_in_days(self, user_id: str, friend_id: str, days: int = 7) -> int:
        """Number of non-deleted private messages in the last ``days`` days (used for BPM)."""
        # NOTE(review): naive UTC timestamp, kept to match how the original
        # compared against ``Message.created_at`` — confirm column is naive UTC.
        since = datetime.utcnow() - timedelta(days=days)
        return await self._count_private_messages(user_id, friend_id, since=since)

    # ------------------------------------------------------------- heartbeat

    @staticmethod
    def _bpm_for_count(msg_7d: int) -> int:
        """Map a 7-day message count to a heartbeat BPM."""
        if msg_7d == 0:
            return 42  # dormant
        if msg_7d < 30:
            return 54  # calm
        if msg_7d < 100:
            return 66  # normal
        if msg_7d < 300:
            return 78  # active
        return 90  # intense

    async def get_heartbeat(self, user_id: str, friend_id: str) -> dict:
        """Return heartbeat-sync data; the BPM is driven by the last 7 days' message count."""
        msg_7d = await self._count_messages_in_days(user_id, friend_id, 7)
        bpm = self._bpm_for_count(msg_7d)

        friend = await self._get_user(friend_id)
        is_online = manager.is_online(friend_id) if friend else False

        return {
            "friend_id": friend_id,
            "friend_name": self._display_name(friend),
            "friend_avatar": friend.avatar_url if friend else None,
            "bpm": bpm,
            "msg_7d": msg_7d,
            "is_online": is_online,
            # Deterministic 16-char seed used to render the friend's mini leaf.
            "friend_leaf_seed": (friend_id or "0")[:16].ljust(16, '0'),
        }

    # ------------------------------------------------------------------ tree

    def _score(self, msg_count: int, water_count) -> int:
        """Total score: messages plus ``WATER_SCORE`` per watering (None counts as 0)."""
        return msg_count + (water_count or 0) * self.WATER_SCORE

    async def _build_tree_payload(self, tree: FriendshipTree, friend_id: str, msg_count: int) -> dict:
        """Assemble the tree response dict from an already-loaded row and message count."""
        total_score = self._score(msg_count, tree.water_count)
        stage_idx, stage_name, stage_emoji = stage_for_score(total_score)
        # Score needed to reach the next stage; None at the final stage.
        next_threshold = STAGES[stage_idx + 1][0] if stage_idx + 1 < len(STAGES) else None

        friend = await self._get_user(friend_id)

        return {
            "tree_id": tree.id,
            "friend_id": friend_id,
            "friend_name": self._display_name(friend),
            "friend_avatar": friend.avatar_url if friend else None,
            "message_count": msg_count,
            "water_count": tree.water_count,
            "total_score": total_score,
            "stage_index": stage_idx,
            "stage_name": stage_name,
            "stage_emoji": stage_emoji,
            "next_threshold": next_threshold,
            "last_watered_at": tree.last_watered_at,
            "seed": tree.id[:16],  # seed for the procedurally generated tree shape
        }

    async def get_tree(self, user_id: str, friend_id: str) -> dict:
        """Return the friendship tree for the pair, creating its row if needed."""
        tree = await self._get_or_create_tree_row(user_id, friend_id)
        msg_count = await self._count_messages_between(user_id, friend_id)
        return await self._build_tree_payload(tree, friend_id, msg_count)

    async def water(self, user_id: str, friend_id: str) -> dict:
        """Water the tree once; the result carries a ``leveled_up`` flag."""
        tree = await self._get_or_create_tree_row(user_id, friend_id)
        # Watering does not change the message count, so count once and reuse.
        msg_count = await self._count_messages_between(user_id, friend_id)
        old_stage_idx = stage_for_score(self._score(msg_count, tree.water_count))[0]

        tree.water_count = (tree.water_count or 0) + 1
        tree.last_watered_at = datetime.utcnow()
        await self.db.flush()

        result = await self._build_tree_payload(tree, friend_id, msg_count)
        result["leveled_up"] = result["stage_index"] > old_stage_idx
        return result

    async def get_all_trees(self, user_id: str) -> list[dict]:
        """Return every friend's tree (garden overview), highest total score first."""
        friends_result = await self.db.execute(
            select(Friend.friend_user_id).where(Friend.user_id == user_id)
        )
        friend_ids = friends_result.scalars().all()

        # Built sequentially: the shared AsyncSession must not be used by
        # concurrent tasks.
        trees = [await self.get_tree(user_id, fid) for fid in friend_ids]
        trees.sort(key=lambda t: t["total_score"], reverse=True)
        return trees
|