Files
disknext/tests/unit/utils/test_jwt.py
于小丘 a99091ea7a feat: embed permission claims in JWT and add captcha verification
- Add GroupClaims model for JWT permission snapshots
- Add JWTPayload model for typed JWT decoding
- Refactor auth middleware: jwt_required (no DB) -> admin_required (no DB) -> auth_required (DB)
- Add UserBanStore for instant ban enforcement via Redis + memory fallback
- Fix status check bug: StrEnum is always truthy, use explicit != ACTIVE
- Shorten access_token expiry from 3h to 1h
- Add CaptchaScene enum and verify_captcha_if_needed service
- Add require_captcha dependency injection factory
- Add CLA document and new default settings
- Update all tests for new JWT API

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 19:07:48 +08:00

197 lines
5.8 KiB
Python

"""
JWT 工具的单元测试
"""
from datetime import timedelta, datetime, timezone
from uuid import uuid4, UUID
import jwt as pyjwt
import pytest
from sqlmodels.group import GroupClaims
from utils.JWT import create_access_token, create_refresh_token, build_token_payload
# Shared GroupClaims factory used by the tests in this module.
def _make_group_claims(admin: bool = False) -> GroupClaims:
    """Build a GroupClaims object with fixed field values and a random id.

    Args:
        admin: Value stored in the ``admin`` field of the claims.

    Returns:
        A freshly constructed GroupClaims instance.
    """
    claim_fields = {
        "id": uuid4(),
        "name": "测试组",
        "max_storage": 1073741824,  # 2**30
        "share_enabled": True,
        "web_dav_enabled": False,
        "admin": admin,
        "speed_limit": 0,
    }
    return GroupClaims(**claim_fields)
# Install a fixed signing key for every test in this module.
@pytest.fixture(autouse=True)
def setup_secret_key():
    """Set ``utils.JWT.SECRET_KEY`` to a known test key, then restore it.

    The fixture is autouse, so it runs around every test in this module.
    The previous value of ``SECRET_KEY`` is put back after the test so the
    test key does not leak into tests that run later in the same session.

    Yields:
        None. The fixture exists only for its setup/teardown side effect.
    """
    import utils.JWT as jwt_module

    # Sentinel distinguishes "attribute absent" from "attribute set to None".
    not_set = object()
    previous_key = getattr(jwt_module, "SECRET_KEY", not_set)

    jwt_module.SECRET_KEY = "test_secret_key_for_unit_tests"
    yield

    # Teardown: undo the override exactly.
    if previous_key is not_set:
        del jwt_module.SECRET_KEY
    else:
        jwt_module.SECRET_KEY = previous_key
def test_create_access_token():
    """Check that an access token is created and carries the expected claims."""
    user_id, token_id = uuid4(), uuid4()

    token = create_access_token(
        sub=user_id,
        jti=token_id,
        status="active",
        group=_make_group_claims(),
    )

    # The result exposes the encoded token and its expiry moment.
    assert isinstance(token.access_token, str)
    assert isinstance(token.access_expires, datetime)

    # Decode with the test key and inspect the individual claims.
    claims = pyjwt.decode(
        token.access_token,
        "test_secret_key_for_unit_tests",
        algorithms=["HS256"],
    )
    assert claims["sub"] == str(user_id)
    assert claims["jti"] == str(token_id)
    assert claims["status"] == "active"
    assert claims["group"]["admin"] is False
    assert "exp" in claims
def test_create_access_token_custom_expiry():
    """Check that ``expires_delta`` determines the ``exp`` claim of an access token.

    The token is created with a 30-minute lifetime; the decoded ``exp`` must
    correspond to a creation instant that lies between two clock readings
    taken immediately before and after the call (with 1 second of tolerance).
    """
    sub = uuid4()
    jti = uuid4()
    group = _make_group_claims()
    custom_expiry = timedelta(minutes=30)

    # Bracket the creation with two clock readings. Comparing against a single
    # reading taken afterwards is flaky: ``exp`` is typically stored as whole
    # seconds, so the difference can reach 1s whenever the clock crosses a
    # second boundary between token creation and the reading.
    before = datetime.now(timezone.utc).timestamp()
    result = create_access_token(
        sub=sub, jti=jti, status="active", group=group, expires_delta=custom_expiry,
    )
    after = datetime.now(timezone.utc).timestamp()

    decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])

    # Subtracting the lifetime from ``exp`` recovers the creation instant.
    issued_at = decoded["exp"] - custom_expiry.total_seconds()
    # 1 second of tolerance on each side absorbs truncation/rounding of ``exp``.
    assert before - 1 <= issued_at <= after + 1
def test_create_access_token_default_expiry():
    """Check that an access token expires one hour after creation by default.

    No ``expires_delta`` is passed; the decoded ``exp`` must correspond to a
    creation instant that lies between two clock readings taken immediately
    before and after the call (with 1 second of tolerance).
    """
    sub = uuid4()
    jti = uuid4()
    group = _make_group_claims()
    expected_lifetime = timedelta(hours=1)

    # Bracket the creation with two clock readings. Comparing against a single
    # reading taken afterwards is flaky: ``exp`` is typically stored as whole
    # seconds, so the difference can reach 1s whenever the clock crosses a
    # second boundary between token creation and the reading.
    before = datetime.now(timezone.utc).timestamp()
    result = create_access_token(sub=sub, jti=jti, status="active", group=group)
    after = datetime.now(timezone.utc).timestamp()

    decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])

    # Subtracting the lifetime from ``exp`` recovers the creation instant.
    issued_at = decoded["exp"] - expected_lifetime.total_seconds()
    # 1 second of tolerance on each side absorbs truncation/rounding of ``exp``.
    assert before - 1 <= issued_at <= after + 1
def test_create_refresh_token():
    """Check that a refresh token is created and carries the expected claims."""
    user_id, token_id = uuid4(), uuid4()

    token = create_refresh_token(sub=user_id, jti=token_id)

    # The result exposes the encoded token and its expiry moment.
    assert isinstance(token.refresh_token, str)
    assert isinstance(token.refresh_expires, datetime)

    # Decode with the test key and inspect the individual claims.
    claims = pyjwt.decode(
        token.refresh_token,
        "test_secret_key_for_unit_tests",
        algorithms=["HS256"],
    )
    assert claims["sub"] == str(user_id)
    assert claims["token_type"] == "refresh"
    assert "exp" in claims
def test_create_refresh_token_default_expiry():
    """Check that a refresh token expires 30 days after creation by default.

    The decoded ``exp`` must correspond to a creation instant that lies
    between two clock readings taken immediately before and after the call
    (with 1 second of tolerance).
    """
    sub = uuid4()
    jti = uuid4()
    # 30 days = 30 * 24 * 3600 = 2592000 seconds.
    expected_lifetime = timedelta(days=30)

    # Bracket the creation with two clock readings. Comparing against a single
    # reading taken afterwards is flaky: ``exp`` is typically stored as whole
    # seconds, so the difference can reach 1s whenever the clock crosses a
    # second boundary between token creation and the reading.
    before = datetime.now(timezone.utc).timestamp()
    result = create_refresh_token(sub=sub, jti=jti)
    after = datetime.now(timezone.utc).timestamp()

    decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])

    # Subtracting the lifetime from ``exp`` recovers the creation instant.
    issued_at = decoded["exp"] - expected_lifetime.total_seconds()
    # 1 second of tolerance on each side absorbs truncation/rounding of ``exp``.
    assert before - 1 <= issued_at <= after + 1
def test_access_token_contains_group_claims():
    """Check that the access token embeds the group claims it was given."""
    user_id, token_id = uuid4(), uuid4()

    token = create_access_token(
        sub=user_id,
        jti=token_id,
        status="active",
        group=_make_group_claims(admin=True),
    )
    claims = pyjwt.decode(
        token.access_token,
        "test_secret_key_for_unit_tests",
        algorithms=["HS256"],
    )

    # Values below mirror what the group-claims helper puts into the model.
    group_claims = claims["group"]
    assert group_claims["admin"] is True
    assert group_claims["name"] == "测试组"
    assert group_claims["max_storage"] == 1073741824
    assert group_claims["share_enabled"] is True
def test_access_token_does_not_have_token_type():
    """Check that an access token carries no ``token_type`` claim."""
    token = create_access_token(
        sub=uuid4(),
        jti=uuid4(),
        status="active",
        group=_make_group_claims(),
    )

    claims = pyjwt.decode(
        token.access_token,
        "test_secret_key_for_unit_tests",
        algorithms=["HS256"],
    )
    assert "token_type" not in claims
def test_refresh_token_has_token_type():
    """Check that a refresh token carries ``token_type == "refresh"``."""
    token = create_refresh_token(sub=uuid4(), jti=uuid4())

    claims = pyjwt.decode(
        token.refresh_token,
        "test_secret_key_for_unit_tests",
        algorithms=["HS256"],
    )
    assert claims["token_type"] == "refresh"
def test_token_expired():
    """Check that decoding an already-expired access token raises."""
    user_id, token_id = uuid4(), uuid4()
    claims_group = _make_group_claims()

    # A negative lifetime yields a token that is expired on creation.
    already_expired = timedelta(seconds=-1)
    token = create_access_token(
        sub=user_id,
        jti=token_id,
        status="active",
        group=claims_group,
        expires_delta=already_expired,
    )

    # Decoding must fail with the expiry-specific exception.
    with pytest.raises(pyjwt.ExpiredSignatureError):
        pyjwt.decode(
            token.access_token,
            "test_secret_key_for_unit_tests",
            algorithms=["HS256"],
        )
def test_token_invalid_signature():
    """Check that decoding with the wrong key raises a signature error."""
    token = create_access_token(
        sub=uuid4(),
        jti=uuid4(),
        status="active",
        group=_make_group_claims(),
    )

    # Verify against a key that differs from the one used for signing.
    with pytest.raises(pyjwt.InvalidSignatureError):
        pyjwt.decode(
            token.access_token,
            "wrong_secret_key",
            algorithms=["HS256"],
        )