Files
disknext/middleware/auth.py
于小丘 a99091ea7a feat: embed permission claims in JWT and add captcha verification
- Add GroupClaims model for JWT permission snapshots
- Add JWTPayload model for typed JWT decoding
- Refactor auth middleware: jwt_required (no DB) -> admin_required (no DB) -> auth_required (DB)
- Add UserBanStore for instant ban enforcement via Redis + memory fallback
- Fix status check bug: StrEnum is always truthy, use explicit != ACTIVE
- Shorten access_token expiry from 3h to 1h
- Add CaptchaScene enum and verify_captcha_if_needed service
- Add require_captcha dependency injection factory
- Add CLA document and new default settings
- Update all tests for new JWT API

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 19:07:48 +08:00

99 lines
3.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Annotated
from uuid import UUID
from fastapi import Depends
import jwt
from sqlmodels.user import JWTPayload, User, UserStatus
from utils import JWT
from .dependencies import SessionDep
from utils import http_exceptions
from service.redis import RedisManager
from service.redis.user_ban_store import UserBanStore
async def jwt_required(
    session: SessionDep,
    token: Annotated[str, Depends(JWT.oauth2_scheme)],
) -> JWTPayload:
    """
    Validate the bearer JWT and return its claims.

    Ban-check strategy:
    1. Check the ``status`` claim embedded in the JWT (snapshot taken at issue time).
    2. If Redis is available, check the ban blacklist (instant ban enforcement).
    3. If Redis is unavailable, fall back to reading the user from the database
       and checking its ``status``.

    :param session: database session, only used by the Redis-unavailable fallback
    :param token: raw JWT string extracted by ``JWT.oauth2_scheme``
    :return: the decoded claims as a ``JWTPayload``
    """
    claim_keys = ("sub", "jti", "status", "group")
    try:
        decoded = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"])
        # A missing key raises KeyError; a bad value is expected to surface as
        # ValueError from the model. Both are treated as an invalid credential.
        claims = JWTPayload(**{key: decoded[key] for key in claim_keys})
    except (jwt.InvalidTokenError, KeyError, ValueError):
        http_exceptions.raise_unauthorized("凭据过期或无效")

    # 1. Status snapshot carried inside the token itself.
    if claims.status != UserStatus.ACTIVE:
        http_exceptions.raise_forbidden("账户已被禁用")

    # 2./3. Live ban check: Redis blacklist when reachable, database otherwise.
    if RedisManager.is_available():
        blocked = await UserBanStore.is_banned(str(claims.sub))
    else:
        record = await User.get(session, User.id == claims.sub)
        # A missing user counts as blocked, same as a non-active one.
        blocked = record.status != UserStatus.ACTIVE if record else True

    if blocked:
        http_exceptions.raise_forbidden("账户已被禁用")
    return claims
async def admin_required(
    claims: Annotated[JWTPayload, Depends(jwt_required)],
) -> JWTPayload:
    """
    Require administrator permission, reading only the JWT claims (no database query here).

    Usage:
    >>> APIRouter(dependencies=[Depends(admin_required)])

    :param claims: claims already validated by ``jwt_required``
    :return: the same claims, when ``claims.group.admin`` is truthy
    """
    group_claims = claims.group
    if not group_claims.admin:
        http_exceptions.raise_forbidden("Admin Required")
    return claims
async def auth_required(
    session: SessionDep,
    claims: Annotated[JWTPayload, Depends(jwt_required)],
) -> User:
    """
    Validate the JWT, then load the full ``User`` (with its ``group`` relation) from the database.

    :param session: database session used for the lookup
    :param claims: claims already validated by ``jwt_required``
    :return: the ``User`` whose id equals ``claims.sub``
    """
    matches_subject = User.id == claims.sub
    account = await User.get(session, matches_subject, load=User.group)
    if not account:
        # The token was valid but no matching user row was found.
        http_exceptions.raise_unauthorized("用户不存在")
    return account
def verify_download_token(token: str) -> tuple[str, UUID, UUID] | None:
    """
    Validate a download token and return ``(jti, file_id, owner_id)``.

    The token is decoded with HS256 using ``JWT.SECRET_KEY``. It is accepted
    only if its ``type`` claim equals ``"download"``, it has a non-empty
    ``jti``, and its ``file_id`` / ``owner_id`` claims parse as UUIDs.

    Every failure is reported through
    ``http_exceptions.raise_unauthorized("Download token required")``.
    No path returns ``None`` explicitly; the ``| None`` in the annotation is
    kept for interface compatibility.

    :param token: JWT string
    :return: (jti, file_id, owner_id)
    """
    try:
        payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"])
        if payload.get("type") != "download":
            http_exceptions.raise_unauthorized("Download token required")
        jti = payload.get("jti")
        if not jti:
            http_exceptions.raise_unauthorized("Download token required")
        return jti, UUID(payload["file_id"]), UUID(payload["owner_id"])
    except (
        jwt.InvalidTokenError,
        # Correctly signed but malformed claims: a missing file_id/owner_id
        # raises KeyError; a value that is not a UUID string raises ValueError,
        # TypeError (e.g. None) or AttributeError (e.g. an int). Treat them as
        # an invalid token instead of letting them escape as a server error.
        KeyError,
        ValueError,
        TypeError,
        AttributeError,
    ):
        http_exceptions.raise_unauthorized("Download token required")