1.7
This commit is contained in:
@@ -0,0 +1,161 @@
|
||||
"""聊天气候服务:把对话节奏翻译成季节/温度/天气"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.chat_climate import ChatClimate
|
||||
from app.models.message import Message
|
||||
from app.models.conversation_member import ConversationMember
|
||||
|
||||
|
||||
SEASON_EMOJI = {
|
||||
"spring": "🌸", "summer": "☀️", "autumn": "🍁", "winter": "❄️",
|
||||
}
|
||||
WEATHER_EMOJI = {
|
||||
"sunny": "晴", "cloudy": "多云", "rainy": "雨", "windy": "风", "snowy": "雪",
|
||||
}
|
||||
|
||||
|
||||
class ClimateService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def _verify_member(self, conversation_id: str, user_id: str) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(ConversationMember).where(
|
||||
ConversationMember.conversation_id == conversation_id,
|
||||
ConversationMember.user_id == user_id,
|
||||
)
|
||||
)
|
||||
return result.scalars().first() is not None
|
||||
|
||||
async def compute(self, conversation_id: str, user_id: str) -> dict:
|
||||
"""计算/更新会话气候"""
|
||||
if not await self._verify_member(conversation_id, user_id):
|
||||
raise ValueError("无权访问该会话")
|
||||
|
||||
# 取近 14 天消息
|
||||
since = datetime.utcnow() - timedelta(days=14)
|
||||
result = await self.db.execute(
|
||||
select(Message).where(
|
||||
Message.conversation_id == conversation_id,
|
||||
Message.is_deleted == False,
|
||||
Message.created_at >= since,
|
||||
).order_by(Message.created_at.asc())
|
||||
)
|
||||
msgs = result.scalars().all()
|
||||
|
||||
season, temperature, weather = self._analyze(msgs)
|
||||
emoji = SEASON_EMOJI[season]
|
||||
|
||||
# 更新气候记录
|
||||
cc_result = await self.db.execute(
|
||||
select(ChatClimate).where(ChatClimate.conversation_id == conversation_id)
|
||||
)
|
||||
climate = cc_result.scalars().first()
|
||||
|
||||
# 日历历史
|
||||
history = []
|
||||
if climate and climate.daily_history:
|
||||
try:
|
||||
history = json.loads(climate.daily_history)
|
||||
except Exception:
|
||||
history = []
|
||||
|
||||
today = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
history = [h for h in history if h.get("date") != today]
|
||||
history.append({"date": today, "season": season, "temp": temperature, "emoji": emoji})
|
||||
history = history[-30:] # 保留 30 天
|
||||
|
||||
if climate:
|
||||
climate.season = season
|
||||
climate.temperature = temperature
|
||||
climate.weather = weather
|
||||
climate.emoji = emoji
|
||||
climate.daily_history = json.dumps(history, ensure_ascii=False)
|
||||
else:
|
||||
climate = ChatClimate(
|
||||
conversation_id=conversation_id,
|
||||
season=season, temperature=temperature, weather=weather,
|
||||
emoji=emoji, daily_history=json.dumps(history, ensure_ascii=False),
|
||||
)
|
||||
self.db.add(climate)
|
||||
await self.db.flush()
|
||||
|
||||
return {
|
||||
"conversation_id": conversation_id,
|
||||
"season": season,
|
||||
"temperature": temperature,
|
||||
"weather": weather,
|
||||
"emoji": emoji,
|
||||
"weather_label": WEATHER_EMOJI[weather],
|
||||
"message_count_14d": len(msgs),
|
||||
}
|
||||
|
||||
def _analyze(self, msgs: list[Message]) -> tuple[str, int, str]:
|
||||
"""根据消息列表分析季节/温度/天气"""
|
||||
n = len(msgs)
|
||||
if n == 0:
|
||||
return "winter", -8, "snowy"
|
||||
|
||||
# 14 天内的消息分布 → 温度(活跃度)
|
||||
# 温度 = 消息密度映射到 -10..40
|
||||
density = min(n / 14, 1) # 平均每天消息数(封顶 1 表示满)
|
||||
temperature = round(-8 + density * 46) # -8 .. 38
|
||||
|
||||
# 季节:按温度分段
|
||||
if temperature >= 28:
|
||||
season = "summer"
|
||||
elif temperature >= 15:
|
||||
season = "spring"
|
||||
elif temperature >= 5:
|
||||
season = "autumn"
|
||||
else:
|
||||
season = "winter"
|
||||
|
||||
# 天气:根据回复间隔/连续性 + 平均字数
|
||||
if n >= 2:
|
||||
gaps = []
|
||||
for i in range(1, n):
|
||||
gap = (msgs[i].created_at - msgs[i - 1].created_at).total_seconds()
|
||||
gaps.append(gap)
|
||||
avg_gap = sum(gaps) / len(gaps)
|
||||
avg_len = sum(len(m.content or "") for m in msgs) / n
|
||||
|
||||
# 连续性(gap 小 = 连续)
|
||||
if avg_gap < 120: # 2 分钟内,热烈
|
||||
weather = "sunny"
|
||||
elif avg_gap < 1800: # 半小时内,正常
|
||||
weather = "cloudy"
|
||||
elif avg_gap < 21600: # 6 小时内,稀疏
|
||||
weather = "rainy"
|
||||
elif avg_gap < 86400: # 一天内
|
||||
weather = "windy"
|
||||
else:
|
||||
weather = "snowy"
|
||||
|
||||
# 字数长 = 有深度对话,偏向"雨"(绵绵)
|
||||
if avg_len > 50 and weather in ("sunny", "cloudy"):
|
||||
weather = "rainy"
|
||||
else:
|
||||
weather = "cloudy"
|
||||
|
||||
return season, temperature, weather
|
||||
|
||||
async def get_calendar(self, conversation_id: str, user_id: str) -> list[dict]:
|
||||
"""获取 30 天气候日历"""
|
||||
if not await self._verify_member(conversation_id, user_id):
|
||||
raise ValueError("无权访问该会话")
|
||||
result = await self.db.execute(
|
||||
select(ChatClimate).where(ChatClimate.conversation_id == conversation_id)
|
||||
)
|
||||
climate = result.scalars().first()
|
||||
if not climate or not climate.daily_history:
|
||||
return []
|
||||
try:
|
||||
return json.loads(climate.daily_history)
|
||||
except Exception:
|
||||
return []
|
||||
Reference in New Issue
Block a user