"""Authentication routes: register, login, token refresh, password reset, email verification."""
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel

from app.dependencies import get_db
from app.models.user import User
from app.models.password_reset import PasswordResetToken
from app.schemas.auth import RegisterRequest, LoginRequest, TokenResponse, RefreshRequest
from app.services.auth_service import AuthService
from app.services.email_service import generate_code, hash_code, send_verification_email
from app.utils.security import decode_refresh_token, hash_password

router = APIRouter()

# Lifetime of a one-time code issued by /forgot and /send-verify-email.
_CODE_TTL = timedelta(minutes=15)

# Single error message for every "cannot redeem this code" outcome, so the
# response does not reveal whether the e-mail address belongs to an account.
_INVALID_CODE_DETAIL = "验证码无效或已过期"


class ForgotRequest(BaseModel):
    """Request body carrying only an e-mail address (used by /forgot and /send-verify-email)."""

    email: str


class ResetRequest(BaseModel):
    """Request body for /reset: e-mail, the one-time code, and the new password."""

    email: str
    code: str
    new_password: str


class VerifyEmailRequest(BaseModel):
    """Request body for /verify-email: e-mail and the one-time code."""

    email: str
    code: str


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Equivalent to the deprecated ``datetime.utcnow()``; the value stays naive so
    it compares with ``expires_at`` exactly as before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


async def _get_user_by_email(db: AsyncSession, email: str) -> Optional[User]:
    """Return the first user whose e-mail equals *email*, or None."""
    result = await db.execute(select(User).where(User.email == email))
    return result.scalars().first()


async def _issue_code(db: AsyncSession, user: User, email: str, purpose: str) -> None:
    """Generate a one-time code, store its hash for *user*, and send it to *email*."""
    code = generate_code()
    db.add(
        PasswordResetToken(
            id=str(uuid.uuid4()),
            user_id=user.id,
            token_hash=hash_code(code),  # only the hash is persisted
            expires_at=_utcnow() + _CODE_TTL,
        )
    )
    # NOTE(review): only flush() is called here; the commit is presumably done
    # by the get_db dependency -- confirm.
    await db.flush()
    await send_verification_email(email, code, purpose)


async def _find_valid_token(
    db: AsyncSession, user_id, code: str
) -> Optional[PasswordResetToken]:
    """Return the newest unused, unexpired token of *user_id* matching *code*, or None."""
    result = await db.execute(
        select(PasswordResetToken)
        .where(
            PasswordResetToken.user_id == user_id,
            PasswordResetToken.used == False,  # noqa: E712 - SQL expression, not a Python bool test
        )
        .order_by(PasswordResetToken.created_at.desc())
    )
    # Hash and timestamp are loop-invariant; compute them once.
    expected_hash = hash_code(code)
    now = _utcnow()
    for token in result.scalars().all():
        if token.token_hash == expected_hash and token.expires_at > now:
            return token
    return None


@router.post("/register", response_model=TokenResponse)
async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
    """Register a new user and return access/refresh tokens.

    Raises:
        HTTPException: 400 when the service (or response construction) raises ValueError.
    """
    service = AuthService(db)
    try:
        result = await service.register(req.username, req.email, req.password)
        return TokenResponse(
            access_token=result["access_token"],
            refresh_token=result["refresh_token"],
            user=result["user"],
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e


@router.post("/login", response_model=TokenResponse)
async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
    """Log a user in and return access/refresh tokens.

    Raises:
        HTTPException: 401 when the service (or response construction) raises ValueError.
    """
    service = AuthService(db)
    try:
        result = await service.login(req.username, req.password)
        return TokenResponse(
            access_token=result["access_token"],
            refresh_token=result["refresh_token"],
            user=result["user"],
        )
    except ValueError as e:
        raise HTTPException(status_code=401, detail=str(e)) from e


@router.post("/refresh", response_model=dict)
async def refresh_token(req: RefreshRequest, db: AsyncSession = Depends(get_db)):
    """Exchange a refresh token for whatever ``AuthService.refresh_token`` returns.

    Raises:
        HTTPException: 401 when the token cannot be decoded or the service
            raises ValueError.
    """
    payload = decode_refresh_token(req.refresh_token)
    if not payload:
        raise HTTPException(status_code=401, detail="无效的 Refresh Token")
    service = AuthService(db)
    try:
        return await service.refresh_token(payload.get("sub"))
    except ValueError as e:
        raise HTTPException(status_code=401, detail=str(e)) from e


@router.post("/forgot")
async def forgot_password(req: ForgotRequest, db: AsyncSession = Depends(get_db)):
    """Start a password reset by issuing a one-time code for the given e-mail.

    Always reports success, whether or not the e-mail is registered, so the
    endpoint cannot be used to probe for accounts.
    """
    user = await _get_user_by_email(db, req.email)
    if user:
        await _issue_code(db, user, req.email, "找回密码")
    return {"success": True, "message": "若该邮箱已注册,验证码已发送(开发期见后端日志)"}


@router.post("/reset")
async def reset_password(req: ResetRequest, db: AsyncSession = Depends(get_db)):
    """Reset the password of the account identified by e-mail, using a one-time code.

    Raises:
        HTTPException: 400 with one generic message when the e-mail is unknown
            or the code is wrong, used, or expired. The unknown-e-mail case is
            deliberately indistinguishable, matching /forgot.
    """
    # NOTE(review): codes issued by /send-verify-email live in the same table
    # with no purpose field, so they are accepted here too. Separating them
    # needs a model change.
    user = await _get_user_by_email(db, req.email)
    matched = await _find_valid_token(db, user.id, req.code) if user else None
    if not matched:
        raise HTTPException(status_code=400, detail=_INVALID_CODE_DETAIL)
    user.password_hash = hash_password(req.new_password)
    matched.used = True  # one-time use
    await db.flush()
    return {"success": True, "message": "密码已重置"}


@router.post("/send-verify-email")
async def send_verify_email(req: ForgotRequest, db: AsyncSession = Depends(get_db)):
    """Issue an e-mail verification code; reports success whether or not the e-mail is registered."""
    user = await _get_user_by_email(db, req.email)
    if user:
        await _issue_code(db, user, req.email, "邮箱验证")
    return {"success": True, "message": "验证码已发送(开发期见后端日志)"}


@router.post("/verify-email")
async def verify_email(req: VerifyEmailRequest, db: AsyncSession = Depends(get_db)):
    """Mark the account's e-mail as verified when a valid one-time code is presented.

    Raises:
        HTTPException: 400 with one generic message when the e-mail is unknown
            or the code is wrong, used, or expired.
    """
    user = await _get_user_by_email(db, req.email)
    matched = await _find_valid_token(db, user.id, req.code) if user else None
    if not matched:
        raise HTTPException(status_code=400, detail=_INVALID_CODE_DETAIL)
    matched.used = True  # one-time use
    user.email_verified = True
    await db.flush()
    return {"success": True, "email_verified": True}