feat: embed permission claims in JWT and add captcha verification

- Add GroupClaims model for JWT permission snapshots
- Add JWTPayload model for typed JWT decoding
- Refactor auth middleware: jwt_required (no DB) -> admin_required (no DB) -> auth_required (DB)
- Add UserBanStore for instant ban enforcement via Redis + memory fallback
- Fix status check bug: StrEnum is always truthy, use explicit != ACTIVE
- Shorten access_token expiry from 3h to 1h
- Add CaptchaScene enum and verify_captcha_if_needed service
- Add require_captcha dependency injection factory
- Add CLA document and new default settings
- Update all tests for new JWT API

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-10 19:07:00 +08:00
parent 209cb24ab4
commit a99091ea7a
20 changed files with 766 additions and 244 deletions

View File

@@ -24,12 +24,12 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')
from main import app
from sqlmodels.database import get_session
from sqlmodels.group import Group, GroupOptions
from sqlmodels.group import Group, GroupClaims, GroupOptions
from sqlmodels.migration import migration
from sqlmodels.object import Object, ObjectType
from sqlmodels.policy import Policy, PolicyType
from sqlmodels.user import User
from utils.JWT.JWT import create_access_token
from sqlmodels.user import User, UserStatus
from utils.JWT import create_access_token
from utils.password.pwd import Password
@@ -193,7 +193,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
email="testuser@test.local",
nickname="测试用户",
password=Password.hash(password),
status=True,
status=UserStatus.ACTIVE,
storage=0,
score=100,
group_id=group.id,
@@ -211,14 +211,24 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
)
await root_folder.save(db_session)
# 构建权限快照
group.options = group_options
group_claims = GroupClaims.from_group(group)
# 生成访问令牌
access_token, _ = create_access_token({"sub": str(user.id)})
from uuid import uuid4
access_token_obj = create_access_token(
sub=user.id,
jti=uuid4(),
status=user.status.value,
group=group_claims,
)
return {
"id": user.id,
"email": user.email,
"password": password,
"token": access_token,
"token": access_token_obj.access_token,
"group_id": group.id,
"policy_id": policy.id,
}
@@ -270,7 +280,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
email="admin@disknext.local",
nickname="管理员",
password=Password.hash(password),
status=True,
status=UserStatus.ACTIVE,
storage=0,
score=9999,
group_id=admin_group.id,
@@ -288,14 +298,24 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
)
await root_folder.save(db_session)
# 构建权限快照
admin_group.options = admin_group_options
admin_group_claims = GroupClaims.from_group(admin_group)
# 生成访问令牌
access_token, _ = create_access_token({"sub": str(admin.id)})
from uuid import uuid4
access_token_obj = create_access_token(
sub=admin.id,
jti=uuid4(),
status=admin.status.value,
group=admin_group_claims,
)
return {
"id": admin.id,
"email": admin.email,
"password": password,
"token": access_token,
"token": access_token_obj.access_token,
"group_id": admin_group.id,
"policy_id": policy.id,
}

View File

@@ -22,10 +22,11 @@ from sqlalchemy.orm import sessionmaker
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from main import app
from sqlmodels import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
from sqlmodels.user import UserStatus
from utils import Password
from utils.JWT import create_access_token
from utils.JWT import JWT
import utils.JWT as JWT
# ==================== 事件循环配置 ====================
@@ -184,12 +185,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
email="testuser@test.local",
password=Password.hash("testpass123"),
nickname="测试用户",
status=True,
status=UserStatus.ACTIVE,
storage=0,
score=0,
group_id=default_group.id,
avatar="default",
theme="system",
)
test_session.add(test_user)
@@ -198,12 +198,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
email="admin@disknext.local",
password=Password.hash("adminpass123"),
nickname="管理员",
status=True,
status=UserStatus.ACTIVE,
storage=0,
score=0,
group_id=admin_group.id,
avatar="default",
theme="system",
)
test_session.add(admin_user)
@@ -212,12 +211,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
email="banneduser@test.local",
password=Password.hash("banned123"),
nickname="封禁用户",
status=False, # 封禁状态
status=UserStatus.ADMIN_BANNED,
storage=0,
score=0,
group_id=default_group.id,
avatar="default",
theme="system",
)
test_session.add(banned_user)
@@ -256,6 +254,10 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
# 8. 设置JWT密钥从数据库加载
JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation"
# 刷新 group options
await test_session.refresh(default_group_options)
await test_session.refresh(admin_group_options)
return test_session
@@ -290,34 +292,68 @@ def banned_user_info() -> dict[str, str]:
# ==================== JWT Token ====================
@pytest.fixture
def test_user_token(test_user_info: dict[str, str]) -> str:
def _build_group_claims(group: Group, group_options: GroupOptions | None) -> GroupClaims:
"""从 Group 对象构建 GroupClaims"""
group.options = group_options
return GroupClaims.from_group(group)
@pytest_asyncio.fixture
async def test_user_token(initialized_db: AsyncSession) -> str:
"""生成测试用户的JWT token"""
token, _ = JWT.create_access_token(
data={"sub": test_user_info["email"]},
user = await User.get(initialized_db, User.email == "testuser@test.local")
group = await Group.get(initialized_db, Group.id == user.group_id)
group_options = await GroupOptions.get(initialized_db, GroupOptions.group_id == group.id)
group_claims = _build_group_claims(group, group_options)
result = create_access_token(
sub=user.id,
jti=uuid4(),
status=user.status.value,
group=group_claims,
expires_delta=timedelta(hours=1),
)
return token
return result.access_token
@pytest.fixture
def admin_user_token(admin_user_info: dict[str, str]) -> str:
@pytest_asyncio.fixture
async def admin_user_token(initialized_db: AsyncSession) -> str:
"""生成管理员的JWT token"""
token, _ = JWT.create_access_token(
data={"sub": admin_user_info["email"]},
user = await User.get(initialized_db, User.email == "admin@disknext.local")
group = await Group.get(initialized_db, Group.id == user.group_id)
group_options = await GroupOptions.get(initialized_db, GroupOptions.group_id == group.id)
group_claims = _build_group_claims(group, group_options)
result = create_access_token(
sub=user.id,
jti=uuid4(),
status=user.status.value,
group=group_claims,
expires_delta=timedelta(hours=1),
)
return token
return result.access_token
@pytest.fixture
def expired_token() -> str:
"""生成过期的JWT token"""
token, _ = JWT.create_access_token(
data={"sub": "testuser@test.local"},
expires_delta=timedelta(seconds=-1), # 已过期
group_claims = GroupClaims(
id=uuid4(),
name="测试组",
max_storage=0,
share_enabled=False,
web_dav_enabled=False,
admin=False,
speed_limit=0,
)
return token
result = create_access_token(
sub=uuid4(),
jti=uuid4(),
status="active",
group=group_claims,
expires_delta=timedelta(seconds=-1),
)
return result.access_token
# ==================== 认证头 ====================

View File

@@ -1,11 +1,15 @@
"""
认证中间件集成测试
"""
from datetime import timedelta
from uuid import uuid4
import pytest
from httpx import AsyncClient
from datetime import timedelta
from utils.JWT import JWT
from sqlmodels.group import GroupClaims
from utils.JWT import create_access_token, create_refresh_token
import utils.JWT as JWT
# ==================== AuthRequired 测试 ====================
@@ -66,11 +70,14 @@ async def test_auth_required_valid_token(
@pytest.mark.asyncio
async def test_auth_required_token_without_sub(async_client: AsyncClient):
"""测试缺少sub字段的token返回 401"""
token, _ = JWT.create_access_token(
data={"other_field": "value"},
expires_delta=timedelta(hours=1)
)
"""测试缺少必要字段的token返回 401"""
import jwt as pyjwt
# 手动构建一个缺少 status 和 group 的 token
payload = {
"other_field": "value",
"exp": int((__import__('datetime').datetime.now(__import__('datetime').timezone.utc) + timedelta(hours=1)).timestamp()),
}
token = pyjwt.encode(payload, JWT.SECRET_KEY, algorithm="HS256")
response = await async_client.get(
"/api/user/me",
@@ -81,16 +88,29 @@ async def test_auth_required_token_without_sub(async_client: AsyncClient):
@pytest.mark.asyncio
async def test_auth_required_nonexistent_user_token(async_client: AsyncClient):
"""测试用户不存在的token返回 401"""
token, _ = JWT.create_access_token(
data={"sub": "nonexistent_user@test.local"},
expires_delta=timedelta(hours=1)
"""测试用户不存在的token返回 403 或 401取决于 Redis 可用性)"""
group_claims = GroupClaims(
id=uuid4(),
name="测试组",
max_storage=0,
share_enabled=False,
web_dav_enabled=False,
admin=False,
speed_limit=0,
)
result = create_access_token(
sub=uuid4(), # 不存在的用户 UUID
jti=uuid4(),
status="active",
group=group_claims,
expires_delta=timedelta(hours=1),
)
response = await async_client.get(
"/api/user/me",
headers={"Authorization": f"Bearer {token}"}
headers={"Authorization": f"Bearer {result.access_token}"}
)
# auth_required 会查库,用户不存在时返回 401
assert response.status_code == 401
@@ -234,23 +254,36 @@ async def test_auth_on_storage_endpoint(
@pytest.mark.asyncio
async def test_refresh_token_format(test_user_info: dict[str, str]):
"""测试刷新token格式正确"""
refresh_token, _ = JWT.create_refresh_token(
data={"sub": test_user_info["email"]},
expires_delta=timedelta(days=7)
result = create_refresh_token(
sub=uuid4(),
jti=uuid4(),
expires_delta=timedelta(days=7),
)
assert isinstance(refresh_token, str)
assert len(refresh_token) > 0
assert isinstance(result.refresh_token, str)
assert len(result.refresh_token) > 0
@pytest.mark.asyncio
async def test_access_token_format(test_user_info: dict[str, str]):
"""测试访问token格式正确"""
access_token, expires = JWT.create_access_token(
data={"sub": test_user_info["email"]},
expires_delta=timedelta(hours=1)
group_claims = GroupClaims(
id=uuid4(),
name="测试组",
max_storage=0,
share_enabled=False,
web_dav_enabled=False,
admin=False,
speed_limit=0,
)
result = create_access_token(
sub=uuid4(),
jti=uuid4(),
status="active",
group=group_claims,
expires_delta=timedelta(hours=1),
)
assert isinstance(access_token, str)
assert len(access_token) > 0
assert expires is not None
assert isinstance(result.access_token, str)
assert len(result.access_token) > 0
assert result.access_expires is not None

View File

@@ -5,7 +5,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.user import User, ThemeType, UserPublic
from sqlmodels.user import User, ThemeType, UserPublic, UserStatus
from sqlmodels.group import Group
@@ -28,7 +28,7 @@ async def test_user_create(db_session: AsyncSession):
assert user.id is not None
assert user.email == "testuser@test.local"
assert user.nickname == "测试用户"
assert user.status is True
assert user.status == UserStatus.ACTIVE
assert user.storage == 0
assert user.score == 0
@@ -131,7 +131,7 @@ async def test_user_status_default(db_session: AsyncSession):
)
user = await user.save(db_session)
assert user.status is True
assert user.status == UserStatus.ACTIVE
@pytest.mark.asyncio

View File

@@ -4,7 +4,7 @@ Login 服务的单元测试
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.user import User, LoginRequest, TokenResponse
from sqlmodels.user import User, LoginRequest, TokenResponse, UserStatus
from sqlmodels.group import Group
from service.user.login import login
from utils.password.pwd import Password
@@ -22,7 +22,7 @@ async def setup_user(db_session: AsyncSession):
user = User(
email="loginuser@test.local",
password=Password.hash(plain_password),
status=True,
status=UserStatus.ACTIVE,
group_id=group.id
)
user = await user.save(db_session)
@@ -43,7 +43,7 @@ async def setup_banned_user(db_session: AsyncSession):
user = User(
email="banneduser@test.local",
password=Password.hash("password"),
status=False, # 封禁状态
status=UserStatus.ADMIN_BANNED, # 封禁状态
group_id=group.id
)
user = await user.save(db_session)
@@ -63,7 +63,7 @@ async def setup_2fa_user(db_session: AsyncSession):
user = User(
email="2fauser@test.local",
password=Password.hash("password"),
status=True,
status=UserStatus.ACTIVE,
two_factor=secret,
group_id=group.id
)

View File

@@ -1,49 +1,86 @@
"""
JWT 工具的单元测试
"""
import time
from datetime import timedelta, datetime, timezone
from uuid import uuid4, UUID
import jwt as pyjwt
import pytest
from utils.JWT.JWT import create_access_token, create_refresh_token, SECRET_KEY
from sqlmodels.group import GroupClaims
from utils.JWT import create_access_token, create_refresh_token, build_token_payload
# 测试用的 GroupClaims
def _make_group_claims(admin: bool = False) -> GroupClaims:
return GroupClaims(
id=uuid4(),
name="测试组",
max_storage=1073741824,
share_enabled=True,
web_dav_enabled=False,
admin=admin,
speed_limit=0,
)
# 设置测试用的密钥
@pytest.fixture(autouse=True)
def setup_secret_key():
"""为测试设置密钥"""
import utils.JWT.JWT as jwt_module
import utils.JWT as jwt_module
jwt_module.SECRET_KEY = "test_secret_key_for_unit_tests"
yield
# 测试后恢复(虽然在单元测试中不太重要)
def test_create_access_token():
"""测试访问令牌创建"""
data = {"sub": "testuser", "role": "user"}
sub = uuid4()
jti = uuid4()
group = _make_group_claims()
token, expire_time = create_access_token(data)
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
assert isinstance(token, str)
assert isinstance(expire_time, datetime)
assert isinstance(result.access_token, str)
assert isinstance(result.access_expires, datetime)
# 解码验证
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == "testuser"
assert decoded["role"] == "user"
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == str(sub)
assert decoded["jti"] == str(jti)
assert decoded["status"] == "active"
assert decoded["group"]["admin"] is False
assert "exp" in decoded
def test_create_access_token_custom_expiry():
"""测试自定义过期时间"""
data = {"sub": "testuser"}
custom_expiry = timedelta(hours=1)
sub = uuid4()
jti = uuid4()
group = _make_group_claims()
custom_expiry = timedelta(minutes=30)
token, expire_time = create_access_token(data, expires_delta=custom_expiry)
result = create_access_token(sub=sub, jti=jti, status="active", group=group, expires_delta=custom_expiry)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
# 验证过期时间大约是30分钟后
exp_timestamp = decoded["exp"]
now_timestamp = datetime.now(timezone.utc).timestamp()
# 允许1秒误差
assert abs(exp_timestamp - now_timestamp - 1800) < 1
def test_create_access_token_default_expiry():
"""测试访问令牌默认1小时过期"""
sub = uuid4()
jti = uuid4()
group = _make_group_claims()
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
# 验证过期时间大约是1小时后
exp_timestamp = decoded["exp"]
@@ -55,27 +92,29 @@ def test_create_access_token_custom_expiry():
def test_create_refresh_token():
"""测试刷新令牌创建"""
data = {"sub": "testuser"}
sub = uuid4()
jti = uuid4()
token, expire_time = create_refresh_token(data)
result = create_refresh_token(sub=sub, jti=jti)
assert isinstance(token, str)
assert isinstance(expire_time, datetime)
assert isinstance(result.refresh_token, str)
assert isinstance(result.refresh_expires, datetime)
# 解码验证
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == "testuser"
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == str(sub)
assert decoded["token_type"] == "refresh"
assert "exp" in decoded
def test_create_refresh_token_default_expiry():
"""测试刷新令牌默认30天过期"""
data = {"sub": "testuser"}
sub = uuid4()
jti = uuid4()
token, expire_time = create_refresh_token(data)
result = create_refresh_token(sub=sub, jti=jti)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
# 验证过期时间大约是30天后
exp_timestamp = decoded["exp"]
@@ -86,78 +125,72 @@ def test_create_refresh_token_default_expiry():
assert abs(exp_timestamp - now_timestamp - 2592000) < 1
def test_token_decode():
"""测试令牌解码"""
data = {"sub": "user123", "email": "user@example.com"}
def test_access_token_contains_group_claims():
"""测试访问令牌包含完整的 group claims"""
sub = uuid4()
jti = uuid4()
group = _make_group_claims(admin=True)
token, _ = create_access_token(data)
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
# 解码
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == "user123"
assert decoded["email"] == "user@example.com"
def test_token_expired():
"""测试令牌过期"""
data = {"sub": "testuser"}
# 创建一个立即过期的令牌
token, _ = create_access_token(data, expires_delta=timedelta(seconds=-1))
# 尝试解码应该抛出过期异常
with pytest.raises(pyjwt.ExpiredSignatureError):
pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
def test_token_invalid_signature():
"""测试无效签名"""
data = {"sub": "testuser"}
token, _ = create_access_token(data)
# 使用错误的密钥解码
with pytest.raises(pyjwt.InvalidSignatureError):
pyjwt.decode(token, "wrong_secret_key", algorithms=["HS256"])
assert decoded["group"]["admin"] is True
assert decoded["group"]["name"] == "测试组"
assert decoded["group"]["max_storage"] == 1073741824
assert decoded["group"]["share_enabled"] is True
def test_access_token_does_not_have_token_type():
"""测试访问令牌不包含 token_type"""
data = {"sub": "testuser"}
sub = uuid4()
jti = uuid4()
group = _make_group_claims()
token, _ = create_access_token(data)
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert "token_type" not in decoded
def test_refresh_token_has_token_type():
"""测试刷新令牌包含 token_type"""
data = {"sub": "testuser"}
sub = uuid4()
jti = uuid4()
token, _ = create_refresh_token(data)
result = create_refresh_token(sub=sub, jti=jti)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["token_type"] == "refresh"
def test_token_payload_preserved():
"""测试自定义负载保留"""
data = {
"sub": "user123",
"name": "Test User",
"roles": ["admin", "user"],
"metadata": {"key": "value"}
}
def test_token_expired():
"""测试令牌过期"""
sub = uuid4()
jti = uuid4()
group = _make_group_claims()
token, _ = create_access_token(data)
# 创建一个立即过期的令牌
result = create_access_token(
sub=sub, jti=jti, status="active", group=group,
expires_delta=timedelta(seconds=-1),
)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
# 尝试解码应该抛出过期异常
with pytest.raises(pyjwt.ExpiredSignatureError):
pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == "user123"
assert decoded["name"] == "Test User"
assert decoded["roles"] == ["admin", "user"]
assert decoded["metadata"] == {"key": "value"}
def test_token_invalid_signature():
"""测试无效签名"""
sub = uuid4()
jti = uuid4()
group = _make_group_claims()
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
# 使用错误的密钥解码
with pytest.raises(pyjwt.InvalidSignatureError):
pyjwt.decode(result.access_token, "wrong_secret_key", algorithms=["HS256"])