feat: embed permission claims in JWT and add captcha verification

- Add GroupClaims model for JWT permission snapshots
- Add JWTPayload model for typed JWT decoding
- Refactor auth middleware: jwt_required (no DB) -> admin_required (no DB) -> auth_required (DB)
- Add UserBanStore for instant ban enforcement via Redis + memory fallback
- Fix status check bug: StrEnum is always truthy, use explicit != ACTIVE
- Shorten access_token expiry from 3h to 1h
- Add CaptchaScene enum and verify_captcha_if_needed service
- Add require_captcha dependency injection factory
- Add CLA document and new default settings
- Update all tests for new JWT API

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-10 19:07:00 +08:00
parent 209cb24ab4
commit a99091ea7a
20 changed files with 766 additions and 244 deletions

View File

@@ -22,10 +22,11 @@ from sqlalchemy.orm import sessionmaker
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from main import app
from sqlmodels import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
from sqlmodels.user import UserStatus
from utils import Password
from utils.JWT import create_access_token
from utils.JWT import JWT
import utils.JWT as JWT
# ==================== 事件循环配置 ====================
@@ -184,12 +185,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
email="testuser@test.local",
password=Password.hash("testpass123"),
nickname="测试用户",
status=True,
status=UserStatus.ACTIVE,
storage=0,
score=0,
group_id=default_group.id,
avatar="default",
theme="system",
)
test_session.add(test_user)
@@ -198,12 +198,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
email="admin@disknext.local",
password=Password.hash("adminpass123"),
nickname="管理员",
status=True,
status=UserStatus.ACTIVE,
storage=0,
score=0,
group_id=admin_group.id,
avatar="default",
theme="system",
)
test_session.add(admin_user)
@@ -212,12 +211,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
email="banneduser@test.local",
password=Password.hash("banned123"),
nickname="封禁用户",
status=False, # 封禁状态
status=UserStatus.ADMIN_BANNED,
storage=0,
score=0,
group_id=default_group.id,
avatar="default",
theme="system",
)
test_session.add(banned_user)
@@ -256,6 +254,10 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
# 8. 设置JWT密钥从数据库加载
JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation"
# 刷新 group options
await test_session.refresh(default_group_options)
await test_session.refresh(admin_group_options)
return test_session
@@ -290,34 +292,68 @@ def banned_user_info() -> dict[str, str]:
# ==================== JWT Token ====================
@pytest.fixture
def test_user_token(test_user_info: dict[str, str]) -> str:
def _build_group_claims(group: Group, group_options: GroupOptions | None) -> GroupClaims:
"""从 Group 对象构建 GroupClaims"""
group.options = group_options
return GroupClaims.from_group(group)
@pytest_asyncio.fixture
async def test_user_token(initialized_db: AsyncSession) -> str:
"""生成测试用户的JWT token"""
token, _ = JWT.create_access_token(
data={"sub": test_user_info["email"]},
user = await User.get(initialized_db, User.email == "testuser@test.local")
group = await Group.get(initialized_db, Group.id == user.group_id)
group_options = await GroupOptions.get(initialized_db, GroupOptions.group_id == group.id)
group_claims = _build_group_claims(group, group_options)
result = create_access_token(
sub=user.id,
jti=uuid4(),
status=user.status.value,
group=group_claims,
expires_delta=timedelta(hours=1),
)
return token
return result.access_token
@pytest.fixture
def admin_user_token(admin_user_info: dict[str, str]) -> str:
@pytest_asyncio.fixture
async def admin_user_token(initialized_db: AsyncSession) -> str:
"""生成管理员的JWT token"""
token, _ = JWT.create_access_token(
data={"sub": admin_user_info["email"]},
user = await User.get(initialized_db, User.email == "admin@disknext.local")
group = await Group.get(initialized_db, Group.id == user.group_id)
group_options = await GroupOptions.get(initialized_db, GroupOptions.group_id == group.id)
group_claims = _build_group_claims(group, group_options)
result = create_access_token(
sub=user.id,
jti=uuid4(),
status=user.status.value,
group=group_claims,
expires_delta=timedelta(hours=1),
)
return token
return result.access_token
@pytest.fixture
def expired_token() -> str:
"""生成过期的JWT token"""
token, _ = JWT.create_access_token(
data={"sub": "testuser@test.local"},
expires_delta=timedelta(seconds=-1), # 已过期
group_claims = GroupClaims(
id=uuid4(),
name="测试组",
max_storage=0,
share_enabled=False,
web_dav_enabled=False,
admin=False,
speed_limit=0,
)
return token
result = create_access_token(
sub=uuid4(),
jti=uuid4(),
status="active",
group=group_claims,
expires_delta=timedelta(seconds=-1),
)
return result.access_token
# ==================== 认证头 ====================

View File

@@ -1,11 +1,15 @@
"""
认证中间件集成测试
"""
from datetime import timedelta
from uuid import uuid4
import pytest
from httpx import AsyncClient
from datetime import timedelta
from utils.JWT import JWT
from sqlmodels.group import GroupClaims
from utils.JWT import create_access_token, create_refresh_token
import utils.JWT as JWT
# ==================== AuthRequired 测试 ====================
@@ -66,11 +70,14 @@ async def test_auth_required_valid_token(
@pytest.mark.asyncio
async def test_auth_required_token_without_sub(async_client: AsyncClient):
"""测试缺少sub字段的token返回 401"""
token, _ = JWT.create_access_token(
data={"other_field": "value"},
expires_delta=timedelta(hours=1)
)
"""测试缺少必要字段的token返回 401"""
import jwt as pyjwt
# 手动构建一个缺少 status 和 group 的 token
payload = {
"other_field": "value",
"exp": int((__import__('datetime').datetime.now(__import__('datetime').timezone.utc) + timedelta(hours=1)).timestamp()),
}
token = pyjwt.encode(payload, JWT.SECRET_KEY, algorithm="HS256")
response = await async_client.get(
"/api/user/me",
@@ -81,16 +88,29 @@ async def test_auth_required_token_without_sub(async_client: AsyncClient):
@pytest.mark.asyncio
async def test_auth_required_nonexistent_user_token(async_client: AsyncClient):
"""测试用户不存在的token返回 401"""
token, _ = JWT.create_access_token(
data={"sub": "nonexistent_user@test.local"},
expires_delta=timedelta(hours=1)
"""测试用户不存在的token返回 403 或 401取决于 Redis 可用性)"""
group_claims = GroupClaims(
id=uuid4(),
name="测试组",
max_storage=0,
share_enabled=False,
web_dav_enabled=False,
admin=False,
speed_limit=0,
)
result = create_access_token(
sub=uuid4(), # 不存在的用户 UUID
jti=uuid4(),
status="active",
group=group_claims,
expires_delta=timedelta(hours=1),
)
response = await async_client.get(
"/api/user/me",
headers={"Authorization": f"Bearer {token}"}
headers={"Authorization": f"Bearer {result.access_token}"}
)
# auth_required 会查库,用户不存在时返回 401
assert response.status_code == 401
@@ -234,23 +254,36 @@ async def test_auth_on_storage_endpoint(
@pytest.mark.asyncio
async def test_refresh_token_format(test_user_info: dict[str, str]):
"""测试刷新token格式正确"""
refresh_token, _ = JWT.create_refresh_token(
data={"sub": test_user_info["email"]},
expires_delta=timedelta(days=7)
result = create_refresh_token(
sub=uuid4(),
jti=uuid4(),
expires_delta=timedelta(days=7),
)
assert isinstance(refresh_token, str)
assert len(refresh_token) > 0
assert isinstance(result.refresh_token, str)
assert len(result.refresh_token) > 0
@pytest.mark.asyncio
async def test_access_token_format(test_user_info: dict[str, str]):
"""测试访问token格式正确"""
access_token, expires = JWT.create_access_token(
data={"sub": test_user_info["email"]},
expires_delta=timedelta(hours=1)
group_claims = GroupClaims(
id=uuid4(),
name="测试组",
max_storage=0,
share_enabled=False,
web_dav_enabled=False,
admin=False,
speed_limit=0,
)
result = create_access_token(
sub=uuid4(),
jti=uuid4(),
status="active",
group=group_claims,
expires_delta=timedelta(hours=1),
)
assert isinstance(access_token, str)
assert len(access_token) > 0
assert expires is not None
assert isinstance(result.access_token, str)
assert len(result.access_token) > 0
assert result.access_expires is not None