"""好友之树服务""" import uuid from datetime import datetime from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession from app.models.friendship_tree import FriendshipTree from app.models.conversation import Conversation from app.models.conversation_member import ConversationMember from app.models.message import Message from app.models.friend import Friend from app.models.user import User from app.websocket.manager import manager # 阶段定义:分数 -> (阶段索引, 名称, emoji) STAGES = [ (0, "种子", "🌱"), (11, "萌芽", "🌿"), (41, "幼苗", "🪴"), (151, "小树", "🌲"), (401, "大树", "🌳"), (1001, "古树", "🌲"), ] def stage_for_score(score: int) -> tuple[int, str, str]: """根据分数返回 (阶段索引, 名称, emoji)""" idx = 0 for i, (threshold, name, emoji) in enumerate(STAGES): if score >= threshold: idx = i return idx, STAGES[idx][1], STAGES[idx][2] class TreeService: def __init__(self, db: AsyncSession): self.db = db async def _get_or_create_tree_row(self, user_id: str, friend_id: str) -> FriendshipTree: """规范化(a < b)后查/建树行""" a, b = (user_id, friend_id) if user_id < friend_id else (friend_id, user_id) result = await self.db.execute( select(FriendshipTree).where( FriendshipTree.user_a_id == a, FriendshipTree.user_b_id == b, ) ) tree = result.scalars().first() if not tree: tree = FriendshipTree( id=str(uuid.uuid4()), user_a_id=a, user_b_id=b, ) self.db.add(tree) await self.db.flush() return tree async def _count_messages_between(self, user_id: str, friend_id: str) -> int: """统计两人私聊会话中的消息数(复用 get_or_create_private 的查找逻辑)""" # 找到两人共有的私聊会话 result = await self.db.execute( select(Conversation).join(ConversationMember) .where( Conversation.type == "private", ConversationMember.user_id == user_id, ) ) conv_id = None for conv in result.scalars().all(): member_result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conv.id, ConversationMember.user_id == friend_id, ) ) if member_result.scalars().first(): conv_id = conv.id break if not conv_id: return 0 count_result = await self.db.execute( select(func.count(Message.id)).where( Message.conversation_id == conv_id, Message.is_deleted == False, ) ) return count_result.scalar() or 0 async def _count_messages_in_days(self, user_id: str, friend_id: str, days: int = 7) -> int: """统计近 N 天两人私聊消息数(心跳 BPM 用)""" result = await self.db.execute( select(Conversation).join(ConversationMember) .where( Conversation.type == "private", ConversationMember.user_id == user_id, ) ) conv_id = None for conv in result.scalars().all(): member_result = await self.db.execute( select(ConversationMember).where( ConversationMember.conversation_id == conv.id, ConversationMember.user_id == friend_id, ) ) if member_result.scalars().first(): conv_id = conv.id break if not conv_id: return 0 from datetime import datetime, timedelta since = datetime.utcnow() - timedelta(days=days) count_result = await self.db.execute( select(func.count(Message.id)).where( Message.conversation_id == conv_id, Message.is_deleted == False, Message.created_at >= since, ) ) return count_result.scalar() or 0 async def get_heartbeat(self, user_id: str, friend_id: str) -> dict: """获取心跳同步数据:BPM 由近7天消息数决定""" from app.models.user import User msg_7d = await self._count_messages_in_days(user_id, friend_id, 7) # BPM 映射 if msg_7d == 0: bpm = 42 # 沉睡 elif msg_7d < 30: bpm = 54 # 平静 elif msg_7d < 100: bpm = 66 # 正常 elif msg_7d < 300: bpm = 78 # 活跃 else: bpm = 90 # 热烈 # 对方信息 friend_result = await self.db.execute(select(User).where(User.id == friend_id)) friend = friend_result.scalars().first() is_online = manager.is_online(friend_id) if friend else False return { "friend_id": friend_id, "friend_name": friend.nickname or friend.username if friend else "未知", "friend_avatar": friend.avatar_url if friend else None, "bpm": bpm, "msg_7d": msg_7d, "is_online": is_online, # leaf_seed 用于渲染对方的迷你叶(确定性) "friend_leaf_seed": (friend_id or "0")[:16].ljust(16, '0'), } async def get_tree(self, user_id: str, friend_id: str) -> dict: """获取好友之树""" tree = await self._get_or_create_tree_row(user_id, friend_id) msg_count = await self._count_messages_between(user_id, friend_id) total_score = msg_count + tree.water_count * 5 stage_idx, stage_name, stage_emoji = stage_for_score(total_score) # 下个阶段的分数门槛 next_threshold = STAGES[stage_idx + 1][0] if stage_idx + 1 < len(STAGES) else None # 好友信息 friend_result = await self.db.execute(select(User).where(User.id == friend_id)) friend = friend_result.scalars().first() return { "tree_id": tree.id, "friend_id": friend_id, "friend_name": friend.nickname or friend.username if friend else "未知", "friend_avatar": friend.avatar_url if friend else None, "message_count": msg_count, "water_count": tree.water_count, "total_score": total_score, "stage_index": stage_idx, "stage_name": stage_name, "stage_emoji": stage_emoji, "next_threshold": next_threshold, "last_watered_at": tree.last_watered_at, "seed": tree.id[:16], # 用于程序化树形态 } async def water(self, user_id: str, friend_id: str) -> dict: """浇水,返回带 leveled_up 标志""" tree = await self._get_or_create_tree_row(user_id, friend_id) old_score = (await self._count_messages_between(user_id, friend_id)) + tree.water_count * 5 old_stage_idx = stage_for_score(old_score)[0] tree.water_count += 1 tree.last_watered_at = datetime.utcnow() await self.db.flush() result = await self.get_tree(user_id, friend_id) result["leveled_up"] = result["stage_index"] > old_stage_idx return result async def get_all_trees(self, user_id: str) -> list[dict]: """获取所有好友的树(花园概览用)""" # 获取好友列表 friends_result = await self.db.execute( select(Friend.friend_user_id).where(Friend.user_id == user_id) ) friend_ids = [r[0] for r in friends_result.all()] trees = [] for fid in friend_ids: trees.append(await self.get_tree(user_id, fid)) # 按分数排序 trees.sort(key=lambda t: t["total_score"], reverse=True) return trees