feat: embed permission claims in JWT and add captcha verification

- Add GroupClaims model for JWT permission snapshots
- Add JWTPayload model for typed JWT decoding
- Refactor auth middleware: jwt_required (no DB) -> admin_required (no DB) -> auth_required (DB)
- Add UserBanStore for instant ban enforcement via Redis + memory fallback
- Fix status check bug: StrEnum is always truthy, use explicit != ACTIVE
- Shorten access_token expiry from 3h to 1h
- Add CaptchaScene enum and verify_captcha_if_needed service
- Add require_captcha dependency injection factory
- Add CLA document and new default settings
- Update all tests for new JWT API

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-10 19:07:00 +08:00
parent 209cb24ab4
commit a99091ea7a
20 changed files with 766 additions and 244 deletions

View File

@@ -5,7 +5,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.user import User, ThemeType, UserPublic
from sqlmodels.user import User, ThemeType, UserPublic, UserStatus
from sqlmodels.group import Group
@@ -28,7 +28,7 @@ async def test_user_create(db_session: AsyncSession):
assert user.id is not None
assert user.email == "testuser@test.local"
assert user.nickname == "测试用户"
assert user.status is True
assert user.status == UserStatus.ACTIVE
assert user.storage == 0
assert user.score == 0
@@ -131,7 +131,7 @@ async def test_user_status_default(db_session: AsyncSession):
)
user = await user.save(db_session)
assert user.status is True
assert user.status == UserStatus.ACTIVE
@pytest.mark.asyncio

View File

@@ -4,7 +4,7 @@ Login 服务的单元测试
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.user import User, LoginRequest, TokenResponse
from sqlmodels.user import User, LoginRequest, TokenResponse, UserStatus
from sqlmodels.group import Group
from service.user.login import login
from utils.password.pwd import Password
@@ -22,7 +22,7 @@ async def setup_user(db_session: AsyncSession):
user = User(
email="loginuser@test.local",
password=Password.hash(plain_password),
status=True,
status=UserStatus.ACTIVE,
group_id=group.id
)
user = await user.save(db_session)
@@ -43,7 +43,7 @@ async def setup_banned_user(db_session: AsyncSession):
user = User(
email="banneduser@test.local",
password=Password.hash("password"),
status=False, # 封禁状态
status=UserStatus.ADMIN_BANNED, # 封禁状态
group_id=group.id
)
user = await user.save(db_session)
@@ -63,7 +63,7 @@ async def setup_2fa_user(db_session: AsyncSession):
user = User(
email="2fauser@test.local",
password=Password.hash("password"),
status=True,
status=UserStatus.ACTIVE,
two_factor=secret,
group_id=group.id
)

View File

@@ -1,49 +1,86 @@
"""
JWT 工具的单元测试
"""
import time
from datetime import timedelta, datetime, timezone
from uuid import uuid4, UUID
import jwt as pyjwt
import pytest
from utils.JWT.JWT import create_access_token, create_refresh_token, SECRET_KEY
from sqlmodels.group import GroupClaims
from utils.JWT import create_access_token, create_refresh_token, build_token_payload
# 测试用的 GroupClaims
def _make_group_claims(admin: bool = False) -> GroupClaims:
return GroupClaims(
id=uuid4(),
name="测试组",
max_storage=1073741824,
share_enabled=True,
web_dav_enabled=False,
admin=admin,
speed_limit=0,
)
# 设置测试用的密钥
@pytest.fixture(autouse=True)
def setup_secret_key():
"""为测试设置密钥"""
import utils.JWT.JWT as jwt_module
import utils.JWT as jwt_module
jwt_module.SECRET_KEY = "test_secret_key_for_unit_tests"
yield
# 测试后恢复(虽然在单元测试中不太重要)
def test_create_access_token():
"""测试访问令牌创建"""
data = {"sub": "testuser", "role": "user"}
sub = uuid4()
jti = uuid4()
group = _make_group_claims()
token, expire_time = create_access_token(data)
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
assert isinstance(token, str)
assert isinstance(expire_time, datetime)
assert isinstance(result.access_token, str)
assert isinstance(result.access_expires, datetime)
# 解码验证
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == "testuser"
assert decoded["role"] == "user"
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == str(sub)
assert decoded["jti"] == str(jti)
assert decoded["status"] == "active"
assert decoded["group"]["admin"] is False
assert "exp" in decoded
def test_create_access_token_custom_expiry():
"""测试自定义过期时间"""
data = {"sub": "testuser"}
custom_expiry = timedelta(hours=1)
sub = uuid4()
jti = uuid4()
group = _make_group_claims()
custom_expiry = timedelta(minutes=30)
token, expire_time = create_access_token(data, expires_delta=custom_expiry)
result = create_access_token(sub=sub, jti=jti, status="active", group=group, expires_delta=custom_expiry)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
# 验证过期时间大约是30分钟后
exp_timestamp = decoded["exp"]
now_timestamp = datetime.now(timezone.utc).timestamp()
# 允许1秒误差
assert abs(exp_timestamp - now_timestamp - 1800) < 1
def test_create_access_token_default_expiry():
"""测试访问令牌默认1小时过期"""
sub = uuid4()
jti = uuid4()
group = _make_group_claims()
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
# 验证过期时间大约是1小时后
exp_timestamp = decoded["exp"]
@@ -55,27 +92,29 @@ def test_create_access_token_custom_expiry():
def test_create_refresh_token():
"""测试刷新令牌创建"""
data = {"sub": "testuser"}
sub = uuid4()
jti = uuid4()
token, expire_time = create_refresh_token(data)
result = create_refresh_token(sub=sub, jti=jti)
assert isinstance(token, str)
assert isinstance(expire_time, datetime)
assert isinstance(result.refresh_token, str)
assert isinstance(result.refresh_expires, datetime)
# 解码验证
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == "testuser"
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == str(sub)
assert decoded["token_type"] == "refresh"
assert "exp" in decoded
def test_create_refresh_token_default_expiry():
"""测试刷新令牌默认30天过期"""
data = {"sub": "testuser"}
sub = uuid4()
jti = uuid4()
token, expire_time = create_refresh_token(data)
result = create_refresh_token(sub=sub, jti=jti)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
# 验证过期时间大约是30天后
exp_timestamp = decoded["exp"]
@@ -86,78 +125,72 @@ def test_create_refresh_token_default_expiry():
assert abs(exp_timestamp - now_timestamp - 2592000) < 1
def test_token_decode():
"""测试令牌解码"""
data = {"sub": "user123", "email": "user@example.com"}
def test_access_token_contains_group_claims():
"""测试访问令牌包含完整的 group claims"""
sub = uuid4()
jti = uuid4()
group = _make_group_claims(admin=True)
token, _ = create_access_token(data)
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
# 解码
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == "user123"
assert decoded["email"] == "user@example.com"
def test_token_expired():
"""测试令牌过期"""
data = {"sub": "testuser"}
# 创建一个立即过期的令牌
token, _ = create_access_token(data, expires_delta=timedelta(seconds=-1))
# 尝试解码应该抛出过期异常
with pytest.raises(pyjwt.ExpiredSignatureError):
pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
def test_token_invalid_signature():
"""测试无效签名"""
data = {"sub": "testuser"}
token, _ = create_access_token(data)
# 使用错误的密钥解码
with pytest.raises(pyjwt.InvalidSignatureError):
pyjwt.decode(token, "wrong_secret_key", algorithms=["HS256"])
assert decoded["group"]["admin"] is True
assert decoded["group"]["name"] == "测试组"
assert decoded["group"]["max_storage"] == 1073741824
assert decoded["group"]["share_enabled"] is True
def test_access_token_does_not_have_token_type():
"""测试访问令牌不包含 token_type"""
data = {"sub": "testuser"}
sub = uuid4()
jti = uuid4()
group = _make_group_claims()
token, _ = create_access_token(data)
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert "token_type" not in decoded
def test_refresh_token_has_token_type():
"""测试刷新令牌包含 token_type"""
data = {"sub": "testuser"}
sub = uuid4()
jti = uuid4()
token, _ = create_refresh_token(data)
result = create_refresh_token(sub=sub, jti=jti)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["token_type"] == "refresh"
def test_token_payload_preserved():
"""测试自定义负载保留"""
data = {
"sub": "user123",
"name": "Test User",
"roles": ["admin", "user"],
"metadata": {"key": "value"}
}
def test_token_expired():
"""测试令牌过期"""
sub = uuid4()
jti = uuid4()
group = _make_group_claims()
token, _ = create_access_token(data)
# 创建一个立即过期的令牌
result = create_access_token(
sub=sub, jti=jti, status="active", group=group,
expires_delta=timedelta(seconds=-1),
)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
# 尝试解码应该抛出过期异常
with pytest.raises(pyjwt.ExpiredSignatureError):
pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == "user123"
assert decoded["name"] == "Test User"
assert decoded["roles"] == ["admin", "user"]
assert decoded["metadata"] == {"key": "value"}
def test_token_invalid_signature():
"""测试无效签名"""
sub = uuid4()
jti = uuid4()
group = _make_group_claims()
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
# 使用错误的密钥解码
with pytest.raises(pyjwt.InvalidSignatureError):
pyjwt.decode(result.access_token, "wrong_secret_key", algorithms=["HS256"])