"""Shared FastAPI dependencies: DB session, Redis client and auth guards."""

from typing import AsyncGenerator

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
import redis.asyncio as redis

from app.config import settings
from app.database import async_session
from app.utils.security import decode_access_token

# HTTP Bearer authentication scheme used by the auth dependencies below.
security = HTTPBearer()

# Module-wide Redis client, created lazily on the first get_redis() call.
_redis_pool: redis.Redis | None = None

# A 401 response must carry a WWW-Authenticate challenge (RFC 7235, 3.1).
_BEARER_CHALLENGE = {"WWW-Authenticate": "Bearer"}


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 HTTPException carrying the Bearer challenge header."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        # Copy so no response can mutate the shared module-level dict.
        headers=dict(_BEARER_CHALLENGE),
    )


def _decode_payload_or_401(credentials: HTTPAuthorizationCredentials):
    """Decode the bearer token and return its payload.

    Raises:
        HTTPException: 401 when ``decode_access_token`` returns ``None``.
    """
    payload = decode_access_token(credentials.credentials)
    if payload is None:
        raise _unauthorized("无效或过期的 Token")
    return payload


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session for the duration of one request.

    The session is committed once the consumer finishes without raising.
    If an exception propagates back into this generator, the session is
    rolled back and the exception is re-raised.
    """
    async with async_session() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise


async def get_redis() -> redis.Redis:
    """Return the shared Redis client, creating it on first use.

    The client is built from ``settings.REDIS_URL`` with
    ``decode_responses=True`` and cached in the module-level ``_redis_pool``.
    """
    global _redis_pool
    if _redis_pool is None:
        _redis_pool = redis.from_url(settings.REDIS_URL, decode_responses=True)
    return _redis_pool


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
):
    """Resolve the bearer token to the authenticated user.

    Returns:
        The object returned by ``UserService.get_by_id`` for the token's
        ``sub`` claim.

    Raises:
        HTTPException: 401 if the token cannot be decoded, has no ``sub``
            claim, or no matching user is found; 403 if the user is banned.
    """
    payload = _decode_payload_or_401(credentials)

    user_id = payload.get("sub")
    if user_id is None:
        raise _unauthorized("无效的 Token 格式")

    # Imported at call time rather than at module level -- presumably to
    # avoid a circular import; confirm before hoisting it to the top.
    from app.services.user_service import UserService

    user = await UserService(db).get_by_id(user_id)
    if user is None:
        raise _unauthorized("用户不存在")

    if user.is_banned:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="账号已被封禁",
        )
    return user


async def get_admin_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
):
    """Require a bearer token whose payload has a truthy ``is_admin`` claim.

    Returns:
        The decoded token payload (not a user object).

    Raises:
        HTTPException: 401 if the token cannot be decoded; 403 if the
            ``is_admin`` claim is missing or falsy.
    """
    # NOTE(review): admin status comes from the token claim alone; ``db`` is
    # accepted but unused, so no existence or ban check is done here. Confirm
    # whether admin tokens map to user rows before adding a lookup.
    payload = _decode_payload_or_401(credentials)

    if not payload.get("is_admin"):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="需要管理员权限",
        )
    return payload