feat: embed permission claims in JWT and add captcha verification
- Add GroupClaims model for JWT permission snapshots - Add JWTPayload model for typed JWT decoding - Refactor auth middleware: jwt_required (no DB) -> admin_required (no DB) -> auth_required (DB) - Add UserBanStore for instant ban enforcement via Redis + memory fallback - Fix status check bug: StrEnum is always truthy, use explicit != ACTIVE - Shorten access_token expiry from 3h to 1h - Add CaptchaScene enum and verify_captcha_if_needed service - Add require_captcha dependency injection factory - Add CLA document and new default settings - Update all tests for new JWT API Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -24,12 +24,12 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')
|
||||
|
||||
from main import app
|
||||
from sqlmodels.database import get_session
|
||||
from sqlmodels.group import Group, GroupOptions
|
||||
from sqlmodels.group import Group, GroupClaims, GroupOptions
|
||||
from sqlmodels.migration import migration
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
from sqlmodels.user import User
|
||||
from utils.JWT.JWT import create_access_token
|
||||
from sqlmodels.user import User, UserStatus
|
||||
from utils.JWT import create_access_token
|
||||
from utils.password.pwd import Password
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
email="testuser@test.local",
|
||||
nickname="测试用户",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=100,
|
||||
group_id=group.id,
|
||||
@@ -211,14 +211,24 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
)
|
||||
await root_folder.save(db_session)
|
||||
|
||||
# 构建权限快照
|
||||
group.options = group_options
|
||||
group_claims = GroupClaims.from_group(group)
|
||||
|
||||
# 生成访问令牌
|
||||
access_token, _ = create_access_token({"sub": str(user.id)})
|
||||
from uuid import uuid4
|
||||
access_token_obj = create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"password": password,
|
||||
"token": access_token,
|
||||
"token": access_token_obj.access_token,
|
||||
"group_id": group.id,
|
||||
"policy_id": policy.id,
|
||||
}
|
||||
@@ -270,7 +280,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
email="admin@disknext.local",
|
||||
nickname="管理员",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=9999,
|
||||
group_id=admin_group.id,
|
||||
@@ -288,14 +298,24 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
)
|
||||
await root_folder.save(db_session)
|
||||
|
||||
# 构建权限快照
|
||||
admin_group.options = admin_group_options
|
||||
admin_group_claims = GroupClaims.from_group(admin_group)
|
||||
|
||||
# 生成访问令牌
|
||||
access_token, _ = create_access_token({"sub": str(admin.id)})
|
||||
from uuid import uuid4
|
||||
access_token_obj = create_access_token(
|
||||
sub=admin.id,
|
||||
jti=uuid4(),
|
||||
status=admin.status.value,
|
||||
group=admin_group_claims,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": admin.id,
|
||||
"email": admin.email,
|
||||
"password": password,
|
||||
"token": access_token,
|
||||
"token": access_token_obj.access_token,
|
||||
"group_id": admin_group.id,
|
||||
"policy_id": policy.id,
|
||||
}
|
||||
|
||||
@@ -22,10 +22,11 @@ from sqlalchemy.orm import sessionmaker
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
||||
|
||||
from main import app
|
||||
from sqlmodels import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import Password
|
||||
from utils.JWT import create_access_token
|
||||
from utils.JWT import JWT
|
||||
import utils.JWT as JWT
|
||||
|
||||
|
||||
# ==================== 事件循环配置 ====================
|
||||
@@ -184,12 +185,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
email="testuser@test.local",
|
||||
password=Password.hash("testpass123"),
|
||||
nickname="测试用户",
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=0,
|
||||
group_id=default_group.id,
|
||||
avatar="default",
|
||||
theme="system",
|
||||
)
|
||||
test_session.add(test_user)
|
||||
|
||||
@@ -198,12 +198,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
email="admin@disknext.local",
|
||||
password=Password.hash("adminpass123"),
|
||||
nickname="管理员",
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=0,
|
||||
group_id=admin_group.id,
|
||||
avatar="default",
|
||||
theme="system",
|
||||
)
|
||||
test_session.add(admin_user)
|
||||
|
||||
@@ -212,12 +211,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("banned123"),
|
||||
nickname="封禁用户",
|
||||
status=False, # 封禁状态
|
||||
status=UserStatus.ADMIN_BANNED,
|
||||
storage=0,
|
||||
score=0,
|
||||
group_id=default_group.id,
|
||||
avatar="default",
|
||||
theme="system",
|
||||
)
|
||||
test_session.add(banned_user)
|
||||
|
||||
@@ -256,6 +254,10 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
# 8. 设置JWT密钥(从数据库加载)
|
||||
JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation"
|
||||
|
||||
# 刷新 group options
|
||||
await test_session.refresh(default_group_options)
|
||||
await test_session.refresh(admin_group_options)
|
||||
|
||||
return test_session
|
||||
|
||||
|
||||
@@ -290,34 +292,68 @@ def banned_user_info() -> dict[str, str]:
|
||||
|
||||
# ==================== JWT Token ====================
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_token(test_user_info: dict[str, str]) -> str:
|
||||
def _build_group_claims(group: Group, group_options: GroupOptions | None) -> GroupClaims:
|
||||
"""从 Group 对象构建 GroupClaims"""
|
||||
group.options = group_options
|
||||
return GroupClaims.from_group(group)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_token(initialized_db: AsyncSession) -> str:
|
||||
"""生成测试用户的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": test_user_info["email"]},
|
||||
user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||
group = await Group.get(initialized_db, Group.id == user.group_id)
|
||||
group_options = await GroupOptions.get(initialized_db, GroupOptions.group_id == group.id)
|
||||
group_claims = _build_group_claims(group, group_options)
|
||||
|
||||
result = create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
return token
|
||||
return result.access_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_token(admin_user_info: dict[str, str]) -> str:
|
||||
@pytest_asyncio.fixture
|
||||
async def admin_user_token(initialized_db: AsyncSession) -> str:
|
||||
"""生成管理员的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": admin_user_info["email"]},
|
||||
user = await User.get(initialized_db, User.email == "admin@disknext.local")
|
||||
group = await Group.get(initialized_db, Group.id == user.group_id)
|
||||
group_options = await GroupOptions.get(initialized_db, GroupOptions.group_id == group.id)
|
||||
group_claims = _build_group_claims(group, group_options)
|
||||
|
||||
result = create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
return token
|
||||
return result.access_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expired_token() -> str:
|
||||
"""生成过期的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": "testuser@test.local"},
|
||||
expires_delta=timedelta(seconds=-1), # 已过期
|
||||
group_claims = GroupClaims(
|
||||
id=uuid4(),
|
||||
name="测试组",
|
||||
max_storage=0,
|
||||
share_enabled=False,
|
||||
web_dav_enabled=False,
|
||||
admin=False,
|
||||
speed_limit=0,
|
||||
)
|
||||
return token
|
||||
result = create_access_token(
|
||||
sub=uuid4(),
|
||||
jti=uuid4(),
|
||||
status="active",
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(seconds=-1),
|
||||
)
|
||||
return result.access_token
|
||||
|
||||
|
||||
# ==================== 认证头 ====================
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""
|
||||
认证中间件集成测试
|
||||
"""
|
||||
from datetime import timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from datetime import timedelta
|
||||
|
||||
from utils.JWT import JWT
|
||||
from sqlmodels.group import GroupClaims
|
||||
from utils.JWT import create_access_token, create_refresh_token
|
||||
import utils.JWT as JWT
|
||||
|
||||
|
||||
# ==================== AuthRequired 测试 ====================
|
||||
@@ -66,11 +70,14 @@ async def test_auth_required_valid_token(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_token_without_sub(async_client: AsyncClient):
|
||||
"""测试缺少sub字段的token返回 401"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"other_field": "value"},
|
||||
expires_delta=timedelta(hours=1)
|
||||
)
|
||||
"""测试缺少必要字段的token返回 401"""
|
||||
import jwt as pyjwt
|
||||
# 手动构建一个缺少 status 和 group 的 token
|
||||
payload = {
|
||||
"other_field": "value",
|
||||
"exp": int((__import__('datetime').datetime.now(__import__('datetime').timezone.utc) + timedelta(hours=1)).timestamp()),
|
||||
}
|
||||
token = pyjwt.encode(payload, JWT.SECRET_KEY, algorithm="HS256")
|
||||
|
||||
response = await async_client.get(
|
||||
"/api/user/me",
|
||||
@@ -81,16 +88,29 @@ async def test_auth_required_token_without_sub(async_client: AsyncClient):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_nonexistent_user_token(async_client: AsyncClient):
|
||||
"""测试用户不存在的token返回 401"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": "nonexistent_user@test.local"},
|
||||
expires_delta=timedelta(hours=1)
|
||||
"""测试用户不存在的token返回 403 或 401(取决于 Redis 可用性)"""
|
||||
group_claims = GroupClaims(
|
||||
id=uuid4(),
|
||||
name="测试组",
|
||||
max_storage=0,
|
||||
share_enabled=False,
|
||||
web_dav_enabled=False,
|
||||
admin=False,
|
||||
speed_limit=0,
|
||||
)
|
||||
result = create_access_token(
|
||||
sub=uuid4(), # 不存在的用户 UUID
|
||||
jti=uuid4(),
|
||||
status="active",
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
|
||||
response = await async_client.get(
|
||||
"/api/user/me",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
headers={"Authorization": f"Bearer {result.access_token}"}
|
||||
)
|
||||
# auth_required 会查库,用户不存在时返回 401
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@@ -234,23 +254,36 @@ async def test_auth_on_storage_endpoint(
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_token_format(test_user_info: dict[str, str]):
|
||||
"""测试刷新token格式正确"""
|
||||
refresh_token, _ = JWT.create_refresh_token(
|
||||
data={"sub": test_user_info["email"]},
|
||||
expires_delta=timedelta(days=7)
|
||||
result = create_refresh_token(
|
||||
sub=uuid4(),
|
||||
jti=uuid4(),
|
||||
expires_delta=timedelta(days=7),
|
||||
)
|
||||
|
||||
assert isinstance(refresh_token, str)
|
||||
assert len(refresh_token) > 0
|
||||
assert isinstance(result.refresh_token, str)
|
||||
assert len(result.refresh_token) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_access_token_format(test_user_info: dict[str, str]):
|
||||
"""测试访问token格式正确"""
|
||||
access_token, expires = JWT.create_access_token(
|
||||
data={"sub": test_user_info["email"]},
|
||||
expires_delta=timedelta(hours=1)
|
||||
group_claims = GroupClaims(
|
||||
id=uuid4(),
|
||||
name="测试组",
|
||||
max_storage=0,
|
||||
share_enabled=False,
|
||||
web_dav_enabled=False,
|
||||
admin=False,
|
||||
speed_limit=0,
|
||||
)
|
||||
result = create_access_token(
|
||||
sub=uuid4(),
|
||||
jti=uuid4(),
|
||||
status="active",
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
|
||||
assert isinstance(access_token, str)
|
||||
assert len(access_token) > 0
|
||||
assert expires is not None
|
||||
assert isinstance(result.access_token, str)
|
||||
assert len(result.access_token) > 0
|
||||
assert result.access_expires is not None
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.user import User, ThemeType, UserPublic
|
||||
from sqlmodels.user import User, ThemeType, UserPublic, UserStatus
|
||||
from sqlmodels.group import Group
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ async def test_user_create(db_session: AsyncSession):
|
||||
assert user.id is not None
|
||||
assert user.email == "testuser@test.local"
|
||||
assert user.nickname == "测试用户"
|
||||
assert user.status is True
|
||||
assert user.status == UserStatus.ACTIVE
|
||||
assert user.storage == 0
|
||||
assert user.score == 0
|
||||
|
||||
@@ -131,7 +131,7 @@ async def test_user_status_default(db_session: AsyncSession):
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
assert user.status is True
|
||||
assert user.status == UserStatus.ACTIVE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -4,7 +4,7 @@ Login 服务的单元测试
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.user import User, LoginRequest, TokenResponse
|
||||
from sqlmodels.user import User, LoginRequest, TokenResponse, UserStatus
|
||||
from sqlmodels.group import Group
|
||||
from service.user.login import login
|
||||
from utils.password.pwd import Password
|
||||
@@ -22,7 +22,7 @@ async def setup_user(db_session: AsyncSession):
|
||||
user = User(
|
||||
email="loginuser@test.local",
|
||||
password=Password.hash(plain_password),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -43,7 +43,7 @@ async def setup_banned_user(db_session: AsyncSession):
|
||||
user = User(
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=False, # 封禁状态
|
||||
status=UserStatus.ADMIN_BANNED, # 封禁状态
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -63,7 +63,7 @@ async def setup_2fa_user(db_session: AsyncSession):
|
||||
user = User(
|
||||
email="2fauser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
two_factor=secret,
|
||||
group_id=group.id
|
||||
)
|
||||
|
||||
@@ -1,49 +1,86 @@
|
||||
"""
|
||||
JWT 工具的单元测试
|
||||
"""
|
||||
import time
|
||||
from datetime import timedelta, datetime, timezone
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
import jwt as pyjwt
|
||||
import pytest
|
||||
|
||||
from utils.JWT.JWT import create_access_token, create_refresh_token, SECRET_KEY
|
||||
from sqlmodels.group import GroupClaims
|
||||
from utils.JWT import create_access_token, create_refresh_token, build_token_payload
|
||||
|
||||
|
||||
# 测试用的 GroupClaims
|
||||
def _make_group_claims(admin: bool = False) -> GroupClaims:
|
||||
return GroupClaims(
|
||||
id=uuid4(),
|
||||
name="测试组",
|
||||
max_storage=1073741824,
|
||||
share_enabled=True,
|
||||
web_dav_enabled=False,
|
||||
admin=admin,
|
||||
speed_limit=0,
|
||||
)
|
||||
|
||||
|
||||
# 设置测试用的密钥
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_secret_key():
|
||||
"""为测试设置密钥"""
|
||||
import utils.JWT.JWT as jwt_module
|
||||
import utils.JWT as jwt_module
|
||||
jwt_module.SECRET_KEY = "test_secret_key_for_unit_tests"
|
||||
yield
|
||||
# 测试后恢复(虽然在单元测试中不太重要)
|
||||
|
||||
|
||||
def test_create_access_token():
|
||||
"""测试访问令牌创建"""
|
||||
data = {"sub": "testuser", "role": "user"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
token, expire_time = create_access_token(data)
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
assert isinstance(token, str)
|
||||
assert isinstance(expire_time, datetime)
|
||||
assert isinstance(result.access_token, str)
|
||||
assert isinstance(result.access_expires, datetime)
|
||||
|
||||
# 解码验证
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
assert decoded["sub"] == "testuser"
|
||||
assert decoded["role"] == "user"
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
assert decoded["sub"] == str(sub)
|
||||
assert decoded["jti"] == str(jti)
|
||||
assert decoded["status"] == "active"
|
||||
assert decoded["group"]["admin"] is False
|
||||
assert "exp" in decoded
|
||||
|
||||
|
||||
def test_create_access_token_custom_expiry():
|
||||
"""测试自定义过期时间"""
|
||||
data = {"sub": "testuser"}
|
||||
custom_expiry = timedelta(hours=1)
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
custom_expiry = timedelta(minutes=30)
|
||||
|
||||
token, expire_time = create_access_token(data, expires_delta=custom_expiry)
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group, expires_delta=custom_expiry)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
# 验证过期时间大约是30分钟后
|
||||
exp_timestamp = decoded["exp"]
|
||||
now_timestamp = datetime.now(timezone.utc).timestamp()
|
||||
|
||||
# 允许1秒误差
|
||||
assert abs(exp_timestamp - now_timestamp - 1800) < 1
|
||||
|
||||
|
||||
def test_create_access_token_default_expiry():
|
||||
"""测试访问令牌默认1小时过期"""
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
# 验证过期时间大约是1小时后
|
||||
exp_timestamp = decoded["exp"]
|
||||
@@ -55,27 +92,29 @@ def test_create_access_token_custom_expiry():
|
||||
|
||||
def test_create_refresh_token():
|
||||
"""测试刷新令牌创建"""
|
||||
data = {"sub": "testuser"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
|
||||
token, expire_time = create_refresh_token(data)
|
||||
result = create_refresh_token(sub=sub, jti=jti)
|
||||
|
||||
assert isinstance(token, str)
|
||||
assert isinstance(expire_time, datetime)
|
||||
assert isinstance(result.refresh_token, str)
|
||||
assert isinstance(result.refresh_expires, datetime)
|
||||
|
||||
# 解码验证
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
assert decoded["sub"] == "testuser"
|
||||
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
assert decoded["sub"] == str(sub)
|
||||
assert decoded["token_type"] == "refresh"
|
||||
assert "exp" in decoded
|
||||
|
||||
|
||||
def test_create_refresh_token_default_expiry():
|
||||
"""测试刷新令牌默认30天过期"""
|
||||
data = {"sub": "testuser"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
|
||||
token, expire_time = create_refresh_token(data)
|
||||
result = create_refresh_token(sub=sub, jti=jti)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
# 验证过期时间大约是30天后
|
||||
exp_timestamp = decoded["exp"]
|
||||
@@ -86,78 +125,72 @@ def test_create_refresh_token_default_expiry():
|
||||
assert abs(exp_timestamp - now_timestamp - 2592000) < 1
|
||||
|
||||
|
||||
def test_token_decode():
|
||||
"""测试令牌解码"""
|
||||
data = {"sub": "user123", "email": "user@example.com"}
|
||||
def test_access_token_contains_group_claims():
|
||||
"""测试访问令牌包含完整的 group claims"""
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims(admin=True)
|
||||
|
||||
token, _ = create_access_token(data)
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
# 解码
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
assert decoded["sub"] == "user123"
|
||||
assert decoded["email"] == "user@example.com"
|
||||
|
||||
|
||||
def test_token_expired():
|
||||
"""测试令牌过期"""
|
||||
data = {"sub": "testuser"}
|
||||
|
||||
# 创建一个立即过期的令牌
|
||||
token, _ = create_access_token(data, expires_delta=timedelta(seconds=-1))
|
||||
|
||||
# 尝试解码应该抛出过期异常
|
||||
with pytest.raises(pyjwt.ExpiredSignatureError):
|
||||
pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
|
||||
def test_token_invalid_signature():
|
||||
"""测试无效签名"""
|
||||
data = {"sub": "testuser"}
|
||||
|
||||
token, _ = create_access_token(data)
|
||||
|
||||
# 使用错误的密钥解码
|
||||
with pytest.raises(pyjwt.InvalidSignatureError):
|
||||
pyjwt.decode(token, "wrong_secret_key", algorithms=["HS256"])
|
||||
assert decoded["group"]["admin"] is True
|
||||
assert decoded["group"]["name"] == "测试组"
|
||||
assert decoded["group"]["max_storage"] == 1073741824
|
||||
assert decoded["group"]["share_enabled"] is True
|
||||
|
||||
|
||||
def test_access_token_does_not_have_token_type():
|
||||
"""测试访问令牌不包含 token_type"""
|
||||
data = {"sub": "testuser"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
token, _ = create_access_token(data)
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
assert "token_type" not in decoded
|
||||
|
||||
|
||||
def test_refresh_token_has_token_type():
|
||||
"""测试刷新令牌包含 token_type"""
|
||||
data = {"sub": "testuser"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
|
||||
token, _ = create_refresh_token(data)
|
||||
result = create_refresh_token(sub=sub, jti=jti)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
assert decoded["token_type"] == "refresh"
|
||||
|
||||
|
||||
def test_token_payload_preserved():
|
||||
"""测试自定义负载保留"""
|
||||
data = {
|
||||
"sub": "user123",
|
||||
"name": "Test User",
|
||||
"roles": ["admin", "user"],
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
def test_token_expired():
|
||||
"""测试令牌过期"""
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
token, _ = create_access_token(data)
|
||||
# 创建一个立即过期的令牌
|
||||
result = create_access_token(
|
||||
sub=sub, jti=jti, status="active", group=group,
|
||||
expires_delta=timedelta(seconds=-1),
|
||||
)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
# 尝试解码应该抛出过期异常
|
||||
with pytest.raises(pyjwt.ExpiredSignatureError):
|
||||
pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
assert decoded["sub"] == "user123"
|
||||
assert decoded["name"] == "Test User"
|
||||
assert decoded["roles"] == ["admin", "user"]
|
||||
assert decoded["metadata"] == {"key": "value"}
|
||||
|
||||
def test_token_invalid_signature():
|
||||
"""测试无效签名"""
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
# 使用错误的密钥解码
|
||||
with pytest.raises(pyjwt.InvalidSignatureError):
|
||||
pyjwt.decode(result.access_token, "wrong_secret_key", algorithms=["HS256"])
|
||||
|
||||
Reference in New Issue
Block a user