Files
2026-06-12 23:14:12 +08:00

97 lines
2.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""FastAPI 公共依赖"""
from typing import AsyncGenerator
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
import redis.asyncio as redis
from app.config import settings
from app.database import async_session
from app.utils.security import decode_access_token
# HTTP Bearer authentication scheme; injected into the dependencies below
# via Depends(security) to obtain the request's bearer credentials.
security = HTTPBearer()
# Module-level Redis client shared by all callers. Starts as None and is
# populated on the first get_redis() call.
_redis_pool: redis.Redis | None = None
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session for the duration of the caller's work.

    The session is committed after the consumer finishes without error. If an
    exception propagates (from the consumer or from the commit itself), the
    session is rolled back and the exception is re-raised.

    Yields:
        The session opened from ``async_session``.
    """
    async with async_session() as db_session:
        try:
            yield db_session
            # Commit stays inside the try so that a failed commit is also
            # rolled back by the handler below.
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
async def get_redis() -> redis.Redis:
    """Return the shared Redis client, creating it on first use.

    Returns:
        The module-level client stored in ``_redis_pool``.
    """
    global _redis_pool
    if _redis_pool is not None:
        return _redis_pool
    # First call: build the client from the configured URL. With
    # decode_responses=True, replies are returned as str instead of bytes.
    _redis_pool = redis.from_url(settings.REDIS_URL, decode_responses=True)
    return _redis_pool
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
):
    """Resolve the authenticated user from the request's bearer token.

    Args:
        credentials: Bearer credentials supplied by the ``security`` scheme.
        db: Database session handed to ``UserService`` to load the user.

    Returns:
        The user object returned by ``UserService.get_by_id``.

    Raises:
        HTTPException: 401 if the token cannot be decoded, carries no ``sub``
            claim, or refers to a user that cannot be found; 403 if the user
            is banned.
    """
    # A 401 response must carry a WWW-Authenticate challenge (RFC 9110
    # section 15.5.2), so every 401 raised below includes this header.
    bearer_challenge = {"WWW-Authenticate": "Bearer"}

    payload = decode_access_token(credentials.credentials)
    if payload is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效或过期的 Token",
            headers=bearer_challenge,
        )

    user_id = payload.get("sub")
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的 Token 格式",
            headers=bearer_challenge,
        )

    # Imported at call time rather than at module level — presumably to
    # avoid a circular import; confirm before hoisting it.
    from app.services.user_service import UserService

    user = await UserService(db).get_by_id(user_id)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在",
            headers=bearer_challenge,
        )

    # The caller is authenticated but not allowed in: 403, not 401.
    if user.is_banned:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="账号已被封禁",
        )
    return user
async def get_admin_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
):
    """Require a bearer token whose ``is_admin`` claim is truthy.

    Args:
        credentials: Bearer credentials supplied by the ``security`` scheme.
        db: Database session; not used by this function.

    Returns:
        The decoded token payload (not a user object).

    Raises:
        HTTPException: 401 if the token cannot be decoded; 403 if the payload
            has no truthy ``is_admin`` claim.
    """
    # NOTE(review): authorization relies solely on the token's claims. The
    # user is never loaded, so ban status or a changed admin flag in the
    # database is not consulted here, and ``db`` goes unused. Confirm this
    # is intended.
    payload = decode_access_token(credentials.credentials)
    if payload is None:
        # A 401 response must carry a WWW-Authenticate challenge (RFC 9110
        # section 15.5.2).
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效或过期的 Token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not payload.get("is_admin"):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="需要管理员权限",
        )
    return payload