169 lines
5.9 KiB
Python
169 lines
5.9 KiB
Python
"""认证路由"""
|
|
|
|
from datetime import datetime, timedelta
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from pydantic import BaseModel
|
|
|
|
from app.dependencies import get_db
|
|
from app.models.user import User
|
|
from app.models.password_reset import PasswordResetToken
|
|
from app.schemas.auth import RegisterRequest, LoginRequest, TokenResponse, RefreshRequest
|
|
from app.services.auth_service import AuthService
|
|
from app.services.email_service import generate_code, hash_code, send_verification_email
|
|
from app.utils.security import decode_refresh_token, hash_password
|
|
|
|
# Router that the authentication endpoints in this module register on via @router.post.
router = APIRouter()
|
|
|
|
|
|
class ForgotRequest(BaseModel):
    """Request body that carries only an email address.

    Shared by the ``/forgot`` and ``/send-verify-email`` endpoints.
    """

    email: str  # plain string; no email-format validation is applied here
|
|
|
|
|
|
class ResetRequest(BaseModel):
    """Request body for ``/reset``: account email, emailed code and new password."""

    email: str  # email of the account whose password is being reset
    code: str  # plaintext one-time code; only its hash is stored server-side
    new_password: str  # plaintext; hashed with hash_password before it is persisted
|
|
|
|
|
|
class VerifyEmailRequest(BaseModel):
    """Request body for ``/verify-email``: the email being verified and its code."""

    email: str  # email address of the account to mark as verified
    code: str  # plaintext one-time code previously sent to that address
|
|
|
|
|
|
@router.post("/register", response_model=TokenResponse)
async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
    """Register a new user and return an access/refresh token pair.

    Delegates to ``AuthService.register``. A ``ValueError`` raised by the
    service (a rejected registration) is reported to the client as HTTP 400
    with the error text as ``detail``.
    """
    service = AuthService(db)
    try:
        result = await service.register(req.username, req.email, req.password)
    except ValueError as e:
        # Chain the cause so server logs still show where the rejection came from.
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Built outside the try on purpose: pydantic's ValidationError subclasses
    # ValueError, and a failure here is a server-side bug, not a client error.
    return TokenResponse(
        access_token=result["access_token"],
        refresh_token=result["refresh_token"],
        user=result["user"],
    )
|
|
|
|
|
|
@router.post("/login", response_model=TokenResponse)
async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
    """Authenticate with username and password and return a token pair.

    Delegates to ``AuthService.login``. A ``ValueError`` raised by the service
    (failed authentication) is reported as HTTP 401 with the error text as
    ``detail``.
    """
    service = AuthService(db)
    try:
        result = await service.login(req.username, req.password)
    except ValueError as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=401, detail=str(e)) from e

    # Outside the try: a response-construction failure (pydantic's
    # ValidationError is a ValueError) must not be disguised as a 401.
    return TokenResponse(
        access_token=result["access_token"],
        refresh_token=result["refresh_token"],
        user=result["user"],
    )
|
|
|
|
|
|
@router.post("/refresh", response_model=dict)
async def refresh_token(req: RefreshRequest, db: AsyncSession = Depends(get_db)):
    """Exchange a refresh token for fresh credentials.

    Returns whatever ``AuthService.refresh_token`` returns for the token's
    subject. Responds with HTTP 401 when the refresh token cannot be decoded,
    when its payload has no ``sub`` claim, or when the service rejects it with
    ``ValueError``.
    """
    payload = decode_refresh_token(req.refresh_token)
    if not payload:
        raise HTTPException(status_code=401, detail="无效的 Refresh Token")

    # A token without a subject cannot identify a user; reject it here instead
    # of handing None to the service.
    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(status_code=401, detail="无效的 Refresh Token")

    service = AuthService(db)
    try:
        return await service.refresh_token(user_id)
    except ValueError as e:
        raise HTTPException(status_code=401, detail=str(e)) from e
|
|
|
|
|
|
@router.post("/forgot")
async def forgot_password(req: ForgotRequest, db: AsyncSession = Depends(get_db)):
    """Start a password reset by issuing a one-time verification code.

    If ``req.email`` belongs to a registered user, a new code is generated, its
    hash is stored in ``PasswordResetToken`` with a 15-minute expiry, and the
    plaintext code is passed to ``send_verification_email``. The response body
    is the same whether or not the email is registered, so callers cannot use
    it to discover which emails exist. (The response text indicates that during
    development the code appears in the backend log.)
    """
    import uuid

    lookup = await db.execute(select(User).where(User.email == req.email))
    account = lookup.scalars().first()

    if account is not None:
        plain_code = generate_code()
        db.add(
            PasswordResetToken(
                id=str(uuid.uuid4()),
                user_id=account.id,
                token_hash=hash_code(plain_code),
                expires_at=datetime.utcnow() + timedelta(minutes=15),
            )
        )
        await db.flush()
        await send_verification_email(req.email, plain_code, "找回密码")

    return {"success": True, "message": "若该邮箱已注册,验证码已发送(开发期见后端日志)"}
|
|
|
|
|
|
@router.post("/reset")
async def reset_password(req: ResetRequest, db: AsyncSession = Depends(get_db)):
    """Reset a user's password with a previously issued verification code.

    The code must match (by ``hash_code``) an unused, unexpired
    ``PasswordResetToken`` belonging to the user registered under
    ``req.email``. On success the user's password hash is replaced and the
    matched token is marked used.

    Raises:
        HTTPException: 400 when the email is unknown or no valid token matches.
            Both cases use the same message so this endpoint cannot be used to
            probe which emails are registered, consistent with ``/forgot``.
    """
    result = await db.execute(select(User).where(User.email == req.email))
    user = result.scalars().first()
    if not user:
        # Same response as a wrong code: do not reveal whether the email exists.
        raise HTTPException(status_code=400, detail="验证码无效或已过期")

    # Let the database select the newest unused, unexpired token whose hash
    # matches, rather than loading every token and re-hashing the code per row.
    tokens = await db.execute(
        select(PasswordResetToken)
        .where(
            PasswordResetToken.user_id == user.id,
            PasswordResetToken.used.is_(False),
            PasswordResetToken.token_hash == hash_code(req.code),
            PasswordResetToken.expires_at > datetime.utcnow(),
        )
        .order_by(PasswordResetToken.created_at.desc())
        .limit(1)
    )
    matched = tokens.scalars().first()
    if not matched:
        raise HTTPException(status_code=400, detail="验证码无效或已过期")

    # NOTE(review): codes issued by /send-verify-email are stored in the same
    # table and are accepted here as well, and nothing limits guessing attempts
    # per token — consider a purpose column and an attempt counter.
    user.password_hash = hash_password(req.new_password)
    matched.used = True
    await db.flush()
    return {"success": True, "message": "密码已重置"}
|
|
|
|
|
|
@router.post("/send-verify-email")
async def send_verify_email(req: ForgotRequest, db: AsyncSession = Depends(get_db)):
    """Issue an email-verification code for the account registered under ``req.email``.

    When such a user exists, a new code is generated, its hash is stored as a
    ``PasswordResetToken`` valid for 15 minutes, and the plaintext code is
    passed to ``send_verification_email``. The same success body is returned
    whether or not the user exists.
    """
    import uuid

    lookup = await db.execute(select(User).where(User.email == req.email))
    found = lookup.scalars().first()

    if found is not None:
        verification_code = generate_code()
        record = PasswordResetToken(
            id=str(uuid.uuid4()),
            user_id=found.id,
            token_hash=hash_code(verification_code),
            expires_at=datetime.utcnow() + timedelta(minutes=15),
        )
        db.add(record)
        await db.flush()
        await send_verification_email(req.email, verification_code, "邮箱验证")

    return {"success": True, "message": "验证码已发送(开发期见后端日志)"}
|
|
|
|
|
|
@router.post("/verify-email")
async def verify_email(req: VerifyEmailRequest, db: AsyncSession = Depends(get_db)):
    """Mark a user's email address as verified using a previously issued code.

    The code must match (by ``hash_code``) an unused, unexpired
    ``PasswordResetToken`` of the user registered under ``req.email``. On
    success the token is marked used and ``user.email_verified`` is set.

    Raises:
        HTTPException: 400 when the email is unknown or no valid token matches.
            Both cases share one message so the endpoint does not reveal which
            emails are registered, consistent with ``/send-verify-email``.
    """
    result = await db.execute(select(User).where(User.email == req.email))
    user = result.scalars().first()
    if not user:
        # Same response as a wrong code: do not reveal whether the email exists.
        raise HTTPException(status_code=400, detail="验证码无效或已过期")

    # Newest unused, unexpired token with a matching hash, selected in SQL
    # instead of loading all tokens and re-hashing the code for each one.
    tokens = await db.execute(
        select(PasswordResetToken)
        .where(
            PasswordResetToken.user_id == user.id,
            PasswordResetToken.used.is_(False),
            PasswordResetToken.token_hash == hash_code(req.code),
            PasswordResetToken.expires_at > datetime.utcnow(),
        )
        .order_by(PasswordResetToken.created_at.desc())
        .limit(1)
    )
    matched = tokens.scalars().first()
    if not matched:
        raise HTTPException(status_code=400, detail="验证码无效或已过期")

    matched.used = True
    user.email_verified = True
    await db.flush()
    return {"success": True, "email_verified": True}
|