feat: embed permission claims in JWT and add captcha verification

- Add GroupClaims model for JWT permission snapshots
- Add JWTPayload model for typed JWT decoding
- Refactor auth middleware: jwt_required (no DB) -> admin_required (no DB) -> auth_required (DB)
- Add UserBanStore for instant ban enforcement via Redis + memory fallback
- Fix status check bug: StrEnum is always truthy, use explicit != ACTIVE
- Shorten access_token expiry from 3h to 1h
- Add CaptchaScene enum and verify_captcha_if_needed service
- Add require_captcha dependency injection factory
- Add CLA document and new default settings
- Update all tests for new JWT API

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-10 19:07:00 +08:00
parent 209cb24ab4
commit a99091ea7a
20 changed files with 766 additions and 244 deletions

92
docs/CLA.md Normal file
View File

@@ -0,0 +1,92 @@
# DiskNext Contributor License Agreement
Thank you for your interest in contributing to the DiskNext project ("We", "Us", or "Our"). This Contributor License Agreement ("Agreement") is for our mutual protection. It clarifies the intellectual property rights You grant to Us for Your Contributions.
By signing this Agreement, You accept its terms and conditions.
## 1. The Purpose of This Agreement
The DiskNext project is developed with a dual-licensing strategy. We maintain a free, open-source community edition alongside a commercial Pro edition. This model allows Us to support a vibrant community while also funding the project's sustainable development.
To make this model work, We require broad rights to use the code You contribute. This Agreement ensures that We can include Your Contributions in all editions of DiskNext under their respective licenses. By signing this Agreement, You grant Us the rights needed to manage the project effectively, including the right to incorporate Your Contribution into Our commercial products and to transfer the project to another entity.
## 2. Definitions
**"You"** means the individual copyright owner who Submits a Contribution to Us.
**"Contribution"** means any original work of authorship, including any modifications or additions to an existing work, that you intentionally Submit to Us for inclusion in the Material.
**"Material"** means the software and documentation We make available to third parties. Your Contribution may be included in the Material.
**"Submit"** means any form of communication sent to Us (e.g., via a pull request, issue tracker, or email) that is managed by Us for the purpose of discussing and improving the Material, but excluding communication that is conspicuously marked or otherwise designated in writing by You as "Not a Contribution."
**"Copyright"** means all rights protecting works of authorship, including copyright, moral rights, and neighboring rights, for the full term of their existence.
## 3. Copyright License Grant
Subject to the terms and conditions of this Agreement, You hereby grant to Us a worldwide, royalty-free, **non-exclusive**, perpetual, and irrevocable license under the Copyright covering your Contribution. This license includes the right to sublicense and to assign Your Contribution.
This license allows Us to use, reproduce, prepare derivative works of, publicly display, publicly perform, distribute, and publish your Contribution and such derivative works in any form. This includes, without limitation, the right to sell and distribute the Contribution as part of a commercial product under a proprietary license.
You retain full ownership of the Copyright in Your Contribution. Nothing in this Agreement shall be construed to restrict or transfer Your rights to use Your own Contribution for any purpose.
## 4. Patent License Grant
You hereby grant to Us and to recipients of the Material a worldwide, royalty-free, non-exclusive, perpetual, and irrevocable patent license to make, have made, use, sell, offer for sale, import, and otherwise transfer Your Contribution. This license applies to all patents owned or controlled by You, now or in the future, that would be infringed by Your Contribution alone or in combination with the Material.
## 5. Your Representations
You represent and warrant that:
1. The Contribution is Your original work.
2. You are legally entitled to grant the licenses in this Agreement.
3. If Your employer has rights to intellectual property that You create, You have either (i) received permission from Your employer to make the Contribution on behalf of that employer, or (ii) Your employer has waived such rights for the Contribution.
4. To the best of Your knowledge, the Contribution does not violate any third-party rights, including copyright, patent, trademark, or trade secret.
You agree to notify Us of any facts or circumstances of which you become aware that would make these representations inaccurate in any respect.
## 6. Our Licensing Rights
You acknowledge that We may license the Material, including Your Contribution, under different license terms. We intend to distribute a community edition of DiskNext under a free and open-source license. We also reserve the right to distribute a Pro edition and other commercial versions of the Material, including Your Contribution, under a proprietary license at Our sole discretion.
## 7. Disclaimer of Warranty
THE CONTRIBUTION IS PROVIDED "AS IS" AND WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
## 8. Limitation of Liability
TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT WILL YOU OR WE BE LIABLE FOR ANY LOSS OF PROFITS, LOSS OF ANTICIPATED SAVINGS, LOSS OF DATA, OR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THIS AGREEMENT, REGARDLESS OF THE LEGAL THEORY UPON WHICH THE CLAIM IS BASED.
## 9. Term
This Agreement is effective on the date You accept it and shall continue for the full term of the copyrights and patents licensed herein. This Agreement is irrevocable.
## 10. Miscellaneous
**10.1 Governing Law:** This Agreement shall be governed by the laws of the People's Republic of China, excluding its conflict of law provisions.
**10.2 Entire Agreement:** This Agreement sets out the entire agreement between You and Us for Your Contributions and supersedes all prior communications and understandings.
**10.3 Assignment:** We may assign Our rights and obligations under this Agreement at Our sole discretion. This Agreement will be binding upon and will inure to the benefit of the parties, their successors, and permitted assigns.
**10.4 Severability:** If any provision of this Agreement is found to be void or unenforceable, it will be replaced with a provision that comes closest to the meaning of the original and is enforceable.
---
## To Accept This Agreement
Please provide the following information to signify your acceptance.
### Contributor ("You"):
- **Date:**
- **Full Name:**
- **Address:**
- **Email:**
- **GitHub Username (if applicable):**
### For DiskNext ("Us"):
- **Date:**
- **[NAME]**
- **Owner of DiskNext Org**

View File

@@ -4,49 +4,79 @@ from uuid import UUID
from fastapi import Depends from fastapi import Depends
import jwt import jwt
from sqlmodels.user import User from sqlmodels.user import JWTPayload, User, UserStatus
from utils import JWT from utils import JWT
from .dependencies import SessionDep from .dependencies import SessionDep
from utils import http_exceptions from utils import http_exceptions
from service.redis import RedisManager
from service.redis.user_ban_store import UserBanStore
async def auth_required(
async def jwt_required(
session: SessionDep, session: SessionDep,
token: Annotated[str, Depends(JWT.oauth2_scheme)], token: Annotated[str, Depends(JWT.oauth2_scheme)],
) -> User: ) -> JWTPayload:
""" """
AuthRequired 需要登录 验证 JWT 并返回 claims。
封禁检查策略:
1. JWT 内嵌 status 检查(签发时快照)
2. Redis 黑名单检查(即时封禁,如果 Redis 可用)
3. Redis 不可用时查库检查 status降级方案
""" """
try: try:
payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"]) payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"])
user_id = payload.get("sub") claims = JWTPayload(
sub=payload["sub"],
if user_id is None: jti=payload["jti"],
http_exceptions.raise_unauthorized("账号或密码错误") status=payload["status"],
group=payload["group"],
user_id = UUID(user_id) )
except (jwt.InvalidTokenError, KeyError, ValueError):
# 从数据库获取用户信息(预加载 group 关系)
user = await User.get(session, User.id == user_id, load=User.group)
if not user:
http_exceptions.raise_unauthorized("账号或密码错误")
return user
except jwt.InvalidTokenError:
http_exceptions.raise_unauthorized("凭据过期或无效") http_exceptions.raise_unauthorized("凭据过期或无效")
# 1. JWT 内嵌 status 检查
if claims.status != UserStatus.ACTIVE:
http_exceptions.raise_forbidden("账户已被禁用")
# 2. 即时封禁检查
user_id_str = str(claims.sub)
if RedisManager.is_available():
# Redis 可用:查黑名单
if await UserBanStore.is_banned(user_id_str):
http_exceptions.raise_forbidden("账户已被禁用")
else:
# Redis 不可用:查库(仅 status 字段,不加载关系)
user = await User.get(session, User.id == claims.sub)
if not user or user.status != UserStatus.ACTIVE:
http_exceptions.raise_forbidden("账户已被禁用")
return claims
async def admin_required( async def admin_required(
user: Annotated[User, Depends(auth_required)], claims: Annotated[JWTPayload, Depends(jwt_required)],
) -> User: ) -> JWTPayload:
""" """
验证是否为管理员。 验证管理员权限(仅读取 JWT claims不查库
使用方法: 使用方法:
>>> APIRouter(dependencies=[Depends(admin_required)]) >>> APIRouter(dependencies=[Depends(admin_required)])
""" """
if user.group.admin: if not claims.group.admin:
return user http_exceptions.raise_forbidden("Admin Required")
raise http_exceptions.raise_forbidden("Admin Required") return claims
async def auth_required(
session: SessionDep,
claims: Annotated[JWTPayload, Depends(jwt_required)],
) -> User:
"""验证 JWT + 从数据库加载完整 User含 group 关系)"""
user = await User.get(session, User.id == claims.sub, load=User.group)
if not user:
http_exceptions.raise_unauthorized("用户不存在")
return user
def verify_download_token(token: str) -> tuple[str, UUID, UUID] | None: def verify_download_token(token: str) -> tuple[str, UUID, UUID] | None:

View File

@@ -6,12 +6,14 @@ FastAPI 依赖注入
- TimeFilterRequestDep: 时间筛选查询依赖(用于 count 等统计接口) - TimeFilterRequestDep: 时间筛选查询依赖(用于 count 等统计接口)
- TableViewRequestDep: 分页排序查询依赖(包含时间筛选 + 分页排序) - TableViewRequestDep: 分页排序查询依赖(包含时间筛选 + 分页排序)
- UserFilterParamsDep: 用户筛选参数依赖(用于管理员用户列表) - UserFilterParamsDep: 用户筛选参数依赖(用于管理员用户列表)
- require_captcha: 验证码校验依赖注入工厂
""" """
from collections.abc import Awaitable, Callable
from datetime import datetime from datetime import datetime
from typing import Annotated, Literal, TypeAlias from typing import Annotated, Literal, TypeAlias
from uuid import UUID from uuid import UUID
from fastapi import Depends, Query from fastapi import Depends, Form, Query
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.database_connection import DatabaseManager from sqlmodels.database_connection import DatabaseManager
@@ -94,3 +96,30 @@ async def _get_user_filter_params(
UserFilterParamsDep: TypeAlias = Annotated[UserFilterParams, Depends(_get_user_filter_params)] UserFilterParamsDep: TypeAlias = Annotated[UserFilterParams, Depends(_get_user_filter_params)]
"""获取用户筛选参数的依赖(用于管理员用户列表)""" """获取用户筛选参数的依赖(用于管理员用户列表)"""
# --- 验证码校验依赖 ---
def require_captcha(scene: 'CaptchaScene') -> Callable[..., Awaitable[None]]:
"""
验证码校验依赖注入工厂。
根据场景查询数据库设置,判断是否需要验证码。
需要则校验前端提交的 captcha_code失败则抛出异常。
使用方式::
@router.post('/session', dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))])
async def login(...): ...
:param scene: 验证码使用场景LOGIN / REGISTER / FORGET
"""
from service.captcha import CaptchaScene, verify_captcha_if_needed
async def _verify_captcha(
session: SessionDep,
captcha_code: Annotated[str | None, Form()] = None,
) -> None:
await verify_captcha_if_needed(session, scene, captcha_code)
return _verify_captcha

View File

@@ -10,7 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep from middleware.dependencies import SessionDep, TableViewRequestDep
from sqlmodels import ( from sqlmodels import (
Policy, PolicyType, User, ListResponse, JWTPayload, Policy, PolicyType, User, ListResponse,
Object, ObjectType, AdminFileResponse, FileBanRequest, ) Object, ObjectType, AdminFileResponse, FileBanRequest, )
from service.storage import LocalStorageService from service.storage import LocalStorageService
@@ -164,14 +164,13 @@ async def router_admin_preview_file(
path='/ban/{file_id}', path='/ban/{file_id}',
summary='封禁/解禁文件', summary='封禁/解禁文件',
description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.', description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.',
dependencies=[Depends(admin_required)],
status_code=204, status_code=204,
) )
async def router_admin_ban_file( async def router_admin_ban_file(
session: SessionDep, session: SessionDep,
file_id: UUID, file_id: UUID,
request: FileBanRequest, request: FileBanRequest,
admin: Annotated[User, Depends(admin_required)], claims: Annotated[JWTPayload, Depends(admin_required)],
) -> None: ) -> None:
""" """
封禁或解禁文件/文件夹。封禁后用户无法访问该文件。 封禁或解禁文件/文件夹。封禁后用户无法访问该文件。
@@ -180,14 +179,14 @@ async def router_admin_ban_file(
:param session: 数据库会话 :param session: 数据库会话
:param file_id: 文件UUID :param file_id: 文件UUID
:param request: 封禁请求 :param request: 封禁请求
:param admin: 当前管理员 :param claims: 当前管理员 JWT claims
:return: 封禁结果 :return: 封禁结果
""" """
file_obj = await Object.get(session, Object.id == file_id) file_obj = await Object.get(session, Object.id == file_id)
if not file_obj: if not file_obj:
raise HTTPException(status_code=404, detail="文件不存在") raise HTTPException(status_code=404, detail="文件不存在")
count = await _set_ban_recursive(session, file_obj, request.ban, admin.id, request.reason) count = await _set_ban_recursive(session, file_obj, request.ban, claims.sub, request.reason)
action = "封禁" if request.ban else "解禁" action = "封禁" if request.ban else "解禁"
l.info(f"管理员{action}了对象: {file_obj.name},共影响 {count} 个对象") l.info(f"管理员{action}了对象: {file_obj.name},共影响 {count} 个对象")

View File

@@ -6,13 +6,14 @@ from sqlalchemy import func
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep, UserFilterParamsDep from middleware.dependencies import SessionDep, TableViewRequestDep, UserFilterParamsDep
from service.redis.user_ban_store import UserBanStore
from sqlmodels import ( from sqlmodels import (
User, ResponseBase, UserPublic, ListResponse, User, ResponseBase, UserPublic, ListResponse,
Group, Object, ObjectType, Setting, SettingsType, Group, Object, ObjectType, Setting, SettingsType,
BatchDeleteRequest, BatchDeleteRequest,
) )
from sqlmodels.user import ( from sqlmodels.user import (
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, UserStatus,
) )
from utils import Password, http_exceptions from utils import Password, http_exceptions
@@ -159,11 +160,21 @@ async def router_admin_update_user(
if len(update_data['two_factor']) != 32: if len(update_data['two_factor']) != 32:
raise HTTPException(status_code=400, detail="两步验证密钥必须为32位字符串") raise HTTPException(status_code=400, detail="两步验证密钥必须为32位字符串")
# 记录旧 status 以便检测变更
old_status = user.status
# 更新字段 # 更新字段
for key, value in update_data.items(): for key, value in update_data.items():
setattr(user, key, value) setattr(user, key, value)
user = await user.save(session) user = await user.save(session)
# 封禁状态变更 → 更新 BanStore
new_status = user.status
if old_status == UserStatus.ACTIVE and new_status != UserStatus.ACTIVE:
await UserBanStore.ban(str(user_id))
elif old_status != UserStatus.ACTIVE and new_status == UserStatus.ACTIVE:
await UserBanStore.unban(str(user_id))
l.info(f"管理员更新了用户: {request.email}") l.info(f"管理员更新了用户: {request.email}")

View File

@@ -2,8 +2,7 @@ from typing import Annotated, Literal
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import jwt import jwt
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, Form, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from loguru import logger from loguru import logger
from webauthn import generate_registration_options from webauthn import generate_registration_options
from webauthn.helpers import options_to_json_dict from webauthn.helpers import options_to_json_dict
@@ -11,7 +10,9 @@ from webauthn.helpers import options_to_json_dict
import service import service
import sqlmodels import sqlmodels
from middleware.auth import auth_required from middleware.auth import auth_required
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep, require_captcha
from service.captcha import CaptchaScene
from sqlmodels.user import UserStatus
from utils import JWT, Password, http_exceptions from utils import JWT, Password, http_exceptions
from .settings import user_settings_router from .settings import user_settings_router
@@ -22,48 +23,60 @@ user_router = APIRouter(
user_router.include_router(user_settings_router) user_router.include_router(user_settings_router)
class OAuth2PasswordWithExtrasForm:
"""
扩展 OAuth2 密码表单。
在标准 username/password 基础上添加 otp_code 字段。
captcha_code 由 require_captcha 依赖注入单独处理。
"""
def __init__(
self,
*,
username: Annotated[str, Form()],
password: Annotated[str, Form()],
otp_code: Annotated[str | None, Form(min_length=6, max_length=6)] = None,
):
self.username = username
self.password = password
self.otp_code = otp_code
@user_router.post( @user_router.post(
path='/session', path='/session',
summary='用户登录', summary='用户登录',
description='User login endpoint. 当用户启用两步验证时,需要传入 otp 参数', description='用户登录端点,支持验证码校验和两步验证',
dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))],
) )
async def router_user_session( async def router_user_session(
session: SessionDep, session: SessionDep,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], form_data: Annotated[OAuth2PasswordWithExtrasForm, Depends()],
) -> sqlmodels.TokenResponse: ) -> sqlmodels.TokenResponse:
""" """
用户登录端点 用户登录端点
根据 OAuth2.1 规范,使用 password grant type 进行登录。 表单字段:
当用户启用两步验证时,需要在表单中传入 otp 参数(通过 scopes 字段传递)。 - username: 用户邮箱
- password: 用户密码
- captcha_code: 验证码 token可选由 require_captcha 依赖校验)
- otp_code: 两步验证码(可选,仅在用户启用 2FA 时需要)
OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码 错误处理:
- 400: 需要验证码但未提供
- 401: 邮箱/密码错误,或 2FA 验证码错误
- 403: 账户已禁用 / 验证码验证失败
- 428: 需要两步验证但未提供 otp_code
""" """
email = form_data.username # OAuth2 表单字段名为 username实际传入的是 email return await service.user.login(
password = form_data.password
# 从 scopes 中提取 OTP 验证码OAuth2.1 扩展方式)
# scopes 格式可以是 ["otp:123456"] 或 ["123456"]
otp_code: str | None = None
for scope in form_data.scopes:
if scope.startswith("otp:"):
otp_code = scope[4:]
break
elif scope.isdigit() and len(scope) == 6:
otp_code = scope
break
result = await service.user.login(
session, session,
sqlmodels.LoginRequest( sqlmodels.LoginRequest(
email=email, email=form_data.username,
password=password, password=form_data.password,
two_fa_code=otp_code, two_fa_code=form_data.otp_code,
), ),
) )
return result
@user_router.post( @user_router.post(
path='/session/refresh', path='/session/refresh',
summary="用刷新令牌刷新会话", summary="用刷新令牌刷新会话",
@@ -101,17 +114,27 @@ async def router_user_session_refresh(
http_exceptions.raise_unauthorized("令牌缺少用户标识") http_exceptions.raise_unauthorized("令牌缺少用户标识")
user_id = UUID(user_id_str) user_id = UUID(user_id_str)
user = await sqlmodels.User.get(session, sqlmodels.User.id == user_id) user = await sqlmodels.User.get(session, sqlmodels.User.id == user_id, load=sqlmodels.User.group)
if not user: if not user:
http_exceptions.raise_unauthorized("用户不存在") http_exceptions.raise_unauthorized("用户不存在")
if not user.status: if user.status != UserStatus.ACTIVE:
http_exceptions.raise_forbidden("账户已被禁用") http_exceptions.raise_forbidden("账户已被禁用")
# 加载 GroupOptions获取最新权限
group_options = await sqlmodels.GroupOptions.get(
session,
sqlmodels.GroupOptions.group_id == user.group_id,
)
user.group.options = group_options
group_claims = sqlmodels.GroupClaims.from_group(user.group)
# 签发新令牌 # 签发新令牌
access_token = JWT.create_access_token( access_token = JWT.create_access_token(
sub=user.id, sub=user.id,
jti=uuid4(), jti=uuid4(),
status=user.status.value,
group=group_claims,
) )
refresh_token = JWT.create_refresh_token( refresh_token = JWT.create_refresh_token(
sub=user.id, sub=user.id,

View File

@@ -1,18 +1,20 @@
import abc import abc
from enum import StrEnum
import aiohttp import aiohttp
from loguru import logger as l
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel.ext.asyncio.session import AsyncSession
from .gcaptcha import GCaptcha
from .turnstile import TurnstileCaptcha
class CaptchaRequestBase(BaseModel): class CaptchaRequestBase(BaseModel):
"""验证码验证请求""" """验证码验证请求"""
token: str
"""验证 token""" response: str
"""用户的验证码 response token"""
secret: str secret: str
"""验证密钥""" """服务端密钥"""
class CaptchaBase(abc.ABC): class CaptchaBase(abc.ABC):
@@ -30,10 +32,89 @@ class CaptchaBase(abc.ABC):
""" """
payload = request.model_dump() payload = request.model_dump()
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as client_session:
async with session.post(self.verify_url, data=payload) as response: async with client_session.post(self.verify_url, data=payload) as resp:
if response.status != 200: if resp.status != 200:
return False return False
result = await response.json() result = await resp.json()
return result.get('success', False) return result.get('success', False)
# 子类导入必须在 CaptchaBase 定义之后gcaptcha.py / turnstile.py 依赖 CaptchaBase
from .gcaptcha import GCaptcha # noqa: E402
from .turnstile import TurnstileCaptcha # noqa: E402
class CaptchaScene(StrEnum):
"""验证码使用场景value 对应 Setting 表中的 name"""
LOGIN = "login_captcha"
REGISTER = "reg_captcha"
FORGET = "forget_captcha"
async def verify_captcha_if_needed(
session: AsyncSession,
scene: CaptchaScene,
captcha_code: str | None,
) -> None:
"""
通用验证码校验:查询设置判断是否需要,需要则校验。
:param session: 数据库异步会话
:param scene: 验证码使用场景
:param captcha_code: 用户提交的验证码 response token
:raises HTTPException 400: 需要验证码但未提供
:raises HTTPException 403: 验证码验证失败
:raises HTTPException 500: 验证码密钥未配置
"""
from sqlmodels import Setting, SettingsType
from sqlmodels.setting import CaptchaType
from utils import http_exceptions
# 1. 查询该场景是否需要验证码
scene_setting = await Setting.get(
session,
(Setting.type == SettingsType.LOGIN) & (Setting.name == scene.value),
)
if not scene_setting or scene_setting.value != "1":
return
# 2. 需要但未提供
if not captcha_code:
http_exceptions.raise_bad_request(detail="请完成验证码验证")
# 3. 查询验证码类型和密钥
captcha_settings: list[Setting] = await Setting.get(
session, Setting.type == SettingsType.CAPTCHA, fetch_mode="all",
)
s: dict[str, str | None] = {item.name: item.value for item in captcha_settings}
captcha_type = CaptchaType(s.get("captcha_type") or "default")
# 4. DEFAULT 图片验证码尚未实现,跳过
if captcha_type == CaptchaType.DEFAULT:
l.warning("DEFAULT 图片验证码尚未实现,跳过验证")
return
# 5. 选择验证器和密钥
if captcha_type == CaptchaType.GCAPTCHA:
secret = s.get("captcha_ReCaptchaSecret")
verifier: CaptchaBase = GCaptcha()
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
secret = s.get("captcha_CloudflareSecret")
verifier = TurnstileCaptcha()
else:
l.error(f"未知的验证码类型: {captcha_type}")
http_exceptions.raise_internal_error()
if not secret:
l.error(f"验证码密钥未配置: captcha_type={captcha_type}")
http_exceptions.raise_internal_error()
# 6. 调用第三方 API 校验
is_valid = await verifier.verify_captcha(
CaptchaRequestBase(response=captcha_code, secret=secret)
)
if not is_valid:
http_exceptions.raise_forbidden(detail="验证码验证失败")

View File

@@ -0,0 +1,72 @@
"""
用户封禁状态存储
用于 JWT 模式下的即时封禁生效。
支持 Redis首选和内存缓存降级两种存储后端。
"""
from typing import ClassVar
from cachetools import TTLCache
from loguru import logger as l
from . import RedisManager
# access_token 有效期(秒)
_BAN_TTL: int = 3600
class UserBanStore:
"""
用户封禁状态存储
管理员封禁用户时调用 ban()jwt_required 每次请求调用 is_banned() 检查。
TTL 与 access_token 有效期一致1h过期后旧 token 自然失效,无需继续记录。
"""
_memory_cache: ClassVar[TTLCache[str, bool]] = TTLCache(maxsize=10000, ttl=_BAN_TTL)
"""内存缓存降级方案"""
@classmethod
async def ban(cls, user_id: str) -> None:
"""
标记用户为已封禁。
:param user_id: 用户 UUID 字符串
"""
client = RedisManager.get_client()
if client is not None:
key = f"user_ban:{user_id}"
await client.set(key, "1", ex=_BAN_TTL)
else:
cls._memory_cache[user_id] = True
l.info(f"用户 {user_id} 已加入封禁黑名单")
@classmethod
async def unban(cls, user_id: str) -> None:
"""
移除用户封禁标记(解封时调用)。
:param user_id: 用户 UUID 字符串
"""
client = RedisManager.get_client()
if client is not None:
key = f"user_ban:{user_id}"
await client.delete(key)
else:
cls._memory_cache.pop(user_id, None)
l.info(f"用户 {user_id} 已从封禁黑名单移除")
@classmethod
async def is_banned(cls, user_id: str) -> bool:
"""
检查用户是否在封禁黑名单中。
:param user_id: 用户 UUID 字符串
:return: True 表示已封禁
"""
client = RedisManager.get_client()
if client is not None:
key = f"user_ban:{user_id}"
return await client.exists(key) > 0
else:
return user_id in cls._memory_cache

View File

@@ -4,6 +4,8 @@ from loguru import logger
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from sqlmodels import LoginRequest, TokenResponse, User from sqlmodels import LoginRequest, TokenResponse, User
from sqlmodels.group import GroupClaims, GroupOptions
from sqlmodels.user import UserStatus
from utils import http_exceptions from utils import http_exceptions
from utils.JWT import create_access_token, create_refresh_token from utils.JWT import create_access_token, create_refresh_token
from utils.password.pwd import Password, PasswordStatus from utils.password.pwd import Password, PasswordStatus
@@ -22,15 +24,13 @@ async def login(
:return: TokenResponse 对象或状态码或 None :return: TokenResponse 对象或状态码或 None
""" """
# TODO: 验证码校验 # 获取用户信息(预加载 group 关系)
# captcha_setting = await Setting.get( current_user: User = await User.get(
# session, session,
# (Setting.type == "auth") & (Setting.name == "login_captcha") User.email == login_request.email,
# ) fetch_mode="first",
# is_captcha_required = captcha_setting and captcha_setting.value == "1" load=User.group,
) #type: ignore
# 获取用户信息
current_user: User = await User.get(session, User.email == login_request.email, fetch_mode="first") #type: ignore
# 验证用户是否存在 # 验证用户是否存在
if not current_user: if not current_user:
@@ -42,8 +42,8 @@ async def login(
logger.debug(f"Password verification failed for user: {login_request.email}") logger.debug(f"Password verification failed for user: {login_request.email}")
http_exceptions.raise_unauthorized("Invalid email or password") http_exceptions.raise_unauthorized("Invalid email or password")
# 验证用户是否可登录 # 验证用户是否可登录修复显式枚举比较StrEnum 永远 truthy
if not current_user.status: if current_user.status != UserStatus.ACTIVE:
http_exceptions.raise_forbidden("Your account is disabled") http_exceptions.raise_forbidden("Your account is disabled")
# 检查两步验证 # 检查两步验证
@@ -58,10 +58,22 @@ async def login(
logger.debug(f"Invalid 2FA code for user: {login_request.email}") logger.debug(f"Invalid 2FA code for user: {login_request.email}")
http_exceptions.raise_unauthorized("Invalid 2FA code") http_exceptions.raise_unauthorized("Invalid 2FA code")
# 加载 GroupOptions
group_options: GroupOptions | None = await GroupOptions.get(
session,
GroupOptions.group_id == current_user.group_id,
)
# 构建权限快照
current_user.group.options = group_options
group_claims = GroupClaims.from_group(current_user.group)
# 创建令牌 # 创建令牌
access_token = create_access_token( access_token = create_access_token(
sub=current_user.id, sub=current_user.id,
jti=uuid4() jti=uuid4(),
status=current_user.status.value,
group=group_claims,
) )
refresh_token = create_refresh_token( refresh_token = create_refresh_token(
sub=current_user.id, sub=current_user.id,

View File

@@ -1,5 +1,6 @@
from .user import ( from .user import (
BatchDeleteRequest, BatchDeleteRequest,
JWTPayload,
LoginRequest, LoginRequest,
RefreshTokenRequest, RefreshTokenRequest,
RegisterRequest, RegisterRequest,
@@ -37,7 +38,7 @@ from .node import (
NodeType, NodeType,
) )
from .group import ( from .group import (
Group, GroupBase, GroupOptions, GroupOptionsBase, GroupAllOptionsBase, GroupResponse, Group, GroupBase, GroupClaims, GroupOptions, GroupOptionsBase, GroupAllOptionsBase, GroupResponse,
# 管理员DTO # 管理员DTO
GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, GroupListResponse, GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, GroupListResponse,
) )

View File

@@ -188,6 +188,28 @@ class GroupListResponse(SQLModelBase):
"""总数""" """总数"""
class GroupClaims(GroupCoreBase, GroupAllOptionsBase):
"""
JWT 中的用户组权限快照。
复用 GroupCoreBaseid, name, max_storage, share_enabled, web_dav_enabled, admin, speed_limit
和 GroupAllOptionsBaseshare_download, share_free, ... 共 11 个功能开关)。
"""
@classmethod
def from_group(cls, group: "Group") -> "GroupClaims":
"""
从 Group ORM 对象(需预加载 options 关系)构建权限快照。
:param group: 已加载 options 的 Group 对象
"""
opts = group.options
return cls(
**GroupCoreBase.model_validate(group, from_attributes=True).model_dump(),
**(GroupAllOptionsBase.model_validate(opts, from_attributes=True).model_dump() if opts else {}),
)
class GroupResponse(GroupBase, GroupOptionsBase): class GroupResponse(GroupBase, GroupOptionsBase):
"""用户组响应 DTO""" """用户组响应 DTO"""

View File

@@ -29,6 +29,10 @@ default_settings: list[Setting] = [
Setting(name="siteKeywords", value="网盘,网盘", type=SettingsType.BASIC), Setting(name="siteKeywords", value="网盘,网盘", type=SettingsType.BASIC),
Setting(name="siteDes", value="DiskNext", type=SettingsType.BASIC), Setting(name="siteDes", value="DiskNext", type=SettingsType.BASIC),
Setting(name="siteTitle", value="云星启智", type=SettingsType.BASIC), Setting(name="siteTitle", value="云星启智", type=SettingsType.BASIC),
Setting(name="site_notice", value="", type=SettingsType.BASIC),
Setting(name="footer_code", value="", type=SettingsType.BASIC),
Setting(name="tos_url", value="", type=SettingsType.BASIC),
Setting(name="privacy_url", value="", type=SettingsType.BASIC),
Setting(name="fromName", value="DiskNext", type=SettingsType.MAIL), Setting(name="fromName", value="DiskNext", type=SettingsType.MAIL),
Setting(name="mail_keepalive", value="30", type=SettingsType.MAIL), Setting(name="mail_keepalive", value="30", type=SettingsType.MAIL),
Setting(name="fromAdress", value="no-reply@yxqi.cn", type=SettingsType.MAIL), Setting(name="fromAdress", value="no-reply@yxqi.cn", type=SettingsType.MAIL),

View File

@@ -99,7 +99,7 @@ class LoginRequest(SQLModelBase):
captcha: str | None = None captcha: str | None = None
"""验证码""" """验证码"""
two_fa_code: int | None = Field(min_length=6, max_length=6) two_fa_code: int | None = Field(default=None, min_length=6, max_length=6)
"""两步验证代码""" """两步验证代码"""
@@ -151,6 +151,22 @@ class WebAuthnInfo(SQLModelBase):
transports: list[str] transports: list[str]
"""支持的传输方式""" """支持的传输方式"""
class JWTPayload(SQLModelBase):
"""JWT 访问令牌解析后的 claims"""
sub: UUID
"""用户 ID"""
jti: UUID
"""令牌唯一标识符"""
status: UserStatus
"""用户状态"""
group: "GroupClaims"
"""用户组权限快照"""
class AccessTokenBase(BaseModel): class AccessTokenBase(BaseModel):
"""访问令牌响应 DTO""" """访问令牌响应 DTO"""
@@ -370,10 +386,11 @@ class UserAdminDetailResponse(UserPublic):
# 前向引用导入 # 前向引用导入
from .group import GroupResponse # noqa: E402 from .group import GroupClaims, GroupResponse # noqa: E402
from .user_authn import AuthnResponse # noqa: E402 from .user_authn import AuthnResponse # noqa: E402
# 更新前向引用 # 更新前向引用
JWTPayload.model_rebuild()
UserResponse.model_rebuild() UserResponse.model_rebuild()
UserSettingResponse.model_rebuild() UserSettingResponse.model_rebuild()

View File

@@ -24,12 +24,12 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')
from main import app from main import app
from sqlmodels.database import get_session from sqlmodels.database import get_session
from sqlmodels.group import Group, GroupOptions from sqlmodels.group import Group, GroupClaims, GroupOptions
from sqlmodels.migration import migration from sqlmodels.migration import migration
from sqlmodels.object import Object, ObjectType from sqlmodels.object import Object, ObjectType
from sqlmodels.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
from sqlmodels.user import User from sqlmodels.user import User, UserStatus
from utils.JWT.JWT import create_access_token from utils.JWT import create_access_token
from utils.password.pwd import Password from utils.password.pwd import Password
@@ -193,7 +193,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
email="testuser@test.local", email="testuser@test.local",
nickname="测试用户", nickname="测试用户",
password=Password.hash(password), password=Password.hash(password),
status=True, status=UserStatus.ACTIVE,
storage=0, storage=0,
score=100, score=100,
group_id=group.id, group_id=group.id,
@@ -211,14 +211,24 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
) )
await root_folder.save(db_session) await root_folder.save(db_session)
# 构建权限快照
group.options = group_options
group_claims = GroupClaims.from_group(group)
# 生成访问令牌 # 生成访问令牌
access_token, _ = create_access_token({"sub": str(user.id)}) from uuid import uuid4
access_token_obj = create_access_token(
sub=user.id,
jti=uuid4(),
status=user.status.value,
group=group_claims,
)
return { return {
"id": user.id, "id": user.id,
"email": user.email, "email": user.email,
"password": password, "password": password,
"token": access_token, "token": access_token_obj.access_token,
"group_id": group.id, "group_id": group.id,
"policy_id": policy.id, "policy_id": policy.id,
} }
@@ -270,7 +280,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
email="admin@disknext.local", email="admin@disknext.local",
nickname="管理员", nickname="管理员",
password=Password.hash(password), password=Password.hash(password),
status=True, status=UserStatus.ACTIVE,
storage=0, storage=0,
score=9999, score=9999,
group_id=admin_group.id, group_id=admin_group.id,
@@ -288,14 +298,24 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
) )
await root_folder.save(db_session) await root_folder.save(db_session)
# 构建权限快照
admin_group.options = admin_group_options
admin_group_claims = GroupClaims.from_group(admin_group)
# 生成访问令牌 # 生成访问令牌
access_token, _ = create_access_token({"sub": str(admin.id)}) from uuid import uuid4
access_token_obj = create_access_token(
sub=admin.id,
jti=uuid4(),
status=admin.status.value,
group=admin_group_claims,
)
return { return {
"id": admin.id, "id": admin.id,
"email": admin.email, "email": admin.email,
"password": password, "password": password,
"token": access_token, "token": access_token_obj.access_token,
"group_id": admin_group.id, "group_id": admin_group.id,
"policy_id": policy.id, "policy_id": policy.id,
} }

View File

@@ -22,10 +22,11 @@ from sqlalchemy.orm import sessionmaker
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from main import app from main import app
from sqlmodels import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
from sqlmodels.user import UserStatus
from utils import Password from utils import Password
from utils.JWT import create_access_token from utils.JWT import create_access_token
from utils.JWT import JWT import utils.JWT as JWT
# ==================== 事件循环配置 ==================== # ==================== 事件循环配置 ====================
@@ -184,12 +185,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
email="testuser@test.local", email="testuser@test.local",
password=Password.hash("testpass123"), password=Password.hash("testpass123"),
nickname="测试用户", nickname="测试用户",
status=True, status=UserStatus.ACTIVE,
storage=0, storage=0,
score=0, score=0,
group_id=default_group.id, group_id=default_group.id,
avatar="default", avatar="default",
theme="system",
) )
test_session.add(test_user) test_session.add(test_user)
@@ -198,12 +198,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
email="admin@disknext.local", email="admin@disknext.local",
password=Password.hash("adminpass123"), password=Password.hash("adminpass123"),
nickname="管理员", nickname="管理员",
status=True, status=UserStatus.ACTIVE,
storage=0, storage=0,
score=0, score=0,
group_id=admin_group.id, group_id=admin_group.id,
avatar="default", avatar="default",
theme="system",
) )
test_session.add(admin_user) test_session.add(admin_user)
@@ -212,12 +211,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
email="banneduser@test.local", email="banneduser@test.local",
password=Password.hash("banned123"), password=Password.hash("banned123"),
nickname="封禁用户", nickname="封禁用户",
status=False, # 封禁状态 status=UserStatus.ADMIN_BANNED,
storage=0, storage=0,
score=0, score=0,
group_id=default_group.id, group_id=default_group.id,
avatar="default", avatar="default",
theme="system",
) )
test_session.add(banned_user) test_session.add(banned_user)
@@ -256,6 +254,10 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
# 8. 设置JWT密钥从数据库加载 # 8. 设置JWT密钥从数据库加载
JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation" JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation"
# 刷新 group options
await test_session.refresh(default_group_options)
await test_session.refresh(admin_group_options)
return test_session return test_session
@@ -290,34 +292,68 @@ def banned_user_info() -> dict[str, str]:
# ==================== JWT Token ==================== # ==================== JWT Token ====================
@pytest.fixture def _build_group_claims(group: Group, group_options: GroupOptions | None) -> GroupClaims:
def test_user_token(test_user_info: dict[str, str]) -> str: """从 Group 对象构建 GroupClaims"""
group.options = group_options
return GroupClaims.from_group(group)
@pytest_asyncio.fixture
async def test_user_token(initialized_db: AsyncSession) -> str:
"""生成测试用户的JWT token""" """生成测试用户的JWT token"""
token, _ = JWT.create_access_token( user = await User.get(initialized_db, User.email == "testuser@test.local")
data={"sub": test_user_info["email"]}, group = await Group.get(initialized_db, Group.id == user.group_id)
group_options = await GroupOptions.get(initialized_db, GroupOptions.group_id == group.id)
group_claims = _build_group_claims(group, group_options)
result = create_access_token(
sub=user.id,
jti=uuid4(),
status=user.status.value,
group=group_claims,
expires_delta=timedelta(hours=1), expires_delta=timedelta(hours=1),
) )
return token return result.access_token
@pytest.fixture @pytest_asyncio.fixture
def admin_user_token(admin_user_info: dict[str, str]) -> str: async def admin_user_token(initialized_db: AsyncSession) -> str:
"""生成管理员的JWT token""" """生成管理员的JWT token"""
token, _ = JWT.create_access_token( user = await User.get(initialized_db, User.email == "admin@disknext.local")
data={"sub": admin_user_info["email"]}, group = await Group.get(initialized_db, Group.id == user.group_id)
group_options = await GroupOptions.get(initialized_db, GroupOptions.group_id == group.id)
group_claims = _build_group_claims(group, group_options)
result = create_access_token(
sub=user.id,
jti=uuid4(),
status=user.status.value,
group=group_claims,
expires_delta=timedelta(hours=1), expires_delta=timedelta(hours=1),
) )
return token return result.access_token
@pytest.fixture @pytest.fixture
def expired_token() -> str: def expired_token() -> str:
"""生成过期的JWT token""" """生成过期的JWT token"""
token, _ = JWT.create_access_token( group_claims = GroupClaims(
data={"sub": "testuser@test.local"}, id=uuid4(),
expires_delta=timedelta(seconds=-1), # 已过期 name="测试组",
max_storage=0,
share_enabled=False,
web_dav_enabled=False,
admin=False,
speed_limit=0,
) )
return token result = create_access_token(
sub=uuid4(),
jti=uuid4(),
status="active",
group=group_claims,
expires_delta=timedelta(seconds=-1),
)
return result.access_token
# ==================== 认证头 ==================== # ==================== 认证头 ====================

View File

@@ -1,11 +1,15 @@
""" """
认证中间件集成测试 认证中间件集成测试
""" """
from datetime import timedelta
from uuid import uuid4
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from datetime import timedelta
from utils.JWT import JWT from sqlmodels.group import GroupClaims
from utils.JWT import create_access_token, create_refresh_token
import utils.JWT as JWT
# ==================== AuthRequired 测试 ==================== # ==================== AuthRequired 测试 ====================
@@ -66,11 +70,14 @@ async def test_auth_required_valid_token(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_auth_required_token_without_sub(async_client: AsyncClient): async def test_auth_required_token_without_sub(async_client: AsyncClient):
"""测试缺少sub字段的token返回 401""" """测试缺少必要字段的token返回 401"""
token, _ = JWT.create_access_token( import jwt as pyjwt
data={"other_field": "value"}, # 手动构建一个缺少 status 和 group 的 token
expires_delta=timedelta(hours=1) payload = {
) "other_field": "value",
"exp": int((__import__('datetime').datetime.now(__import__('datetime').timezone.utc) + timedelta(hours=1)).timestamp()),
}
token = pyjwt.encode(payload, JWT.SECRET_KEY, algorithm="HS256")
response = await async_client.get( response = await async_client.get(
"/api/user/me", "/api/user/me",
@@ -81,16 +88,29 @@ async def test_auth_required_token_without_sub(async_client: AsyncClient):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_auth_required_nonexistent_user_token(async_client: AsyncClient): async def test_auth_required_nonexistent_user_token(async_client: AsyncClient):
"""测试用户不存在的token返回 401""" """测试用户不存在的token返回 403 或 401取决于 Redis 可用性)"""
token, _ = JWT.create_access_token( group_claims = GroupClaims(
data={"sub": "nonexistent_user@test.local"}, id=uuid4(),
expires_delta=timedelta(hours=1) name="测试组",
max_storage=0,
share_enabled=False,
web_dav_enabled=False,
admin=False,
speed_limit=0,
)
result = create_access_token(
sub=uuid4(), # 不存在的用户 UUID
jti=uuid4(),
status="active",
group=group_claims,
expires_delta=timedelta(hours=1),
) )
response = await async_client.get( response = await async_client.get(
"/api/user/me", "/api/user/me",
headers={"Authorization": f"Bearer {token}"} headers={"Authorization": f"Bearer {result.access_token}"}
) )
# auth_required 会查库,用户不存在时返回 401
assert response.status_code == 401 assert response.status_code == 401
@@ -234,23 +254,36 @@ async def test_auth_on_storage_endpoint(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_refresh_token_format(test_user_info: dict[str, str]): async def test_refresh_token_format(test_user_info: dict[str, str]):
"""测试刷新token格式正确""" """测试刷新token格式正确"""
refresh_token, _ = JWT.create_refresh_token( result = create_refresh_token(
data={"sub": test_user_info["email"]}, sub=uuid4(),
expires_delta=timedelta(days=7) jti=uuid4(),
expires_delta=timedelta(days=7),
) )
assert isinstance(refresh_token, str) assert isinstance(result.refresh_token, str)
assert len(refresh_token) > 0 assert len(result.refresh_token) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_access_token_format(test_user_info: dict[str, str]): async def test_access_token_format(test_user_info: dict[str, str]):
"""测试访问token格式正确""" """测试访问token格式正确"""
access_token, expires = JWT.create_access_token( group_claims = GroupClaims(
data={"sub": test_user_info["email"]}, id=uuid4(),
expires_delta=timedelta(hours=1) name="测试组",
max_storage=0,
share_enabled=False,
web_dav_enabled=False,
admin=False,
speed_limit=0,
)
result = create_access_token(
sub=uuid4(),
jti=uuid4(),
status="active",
group=group_claims,
expires_delta=timedelta(hours=1),
) )
assert isinstance(access_token, str) assert isinstance(result.access_token, str)
assert len(access_token) > 0 assert len(result.access_token) > 0
assert expires is not None assert result.access_expires is not None

View File

@@ -5,7 +5,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.user import User, ThemeType, UserPublic from sqlmodels.user import User, ThemeType, UserPublic, UserStatus
from sqlmodels.group import Group from sqlmodels.group import Group
@@ -28,7 +28,7 @@ async def test_user_create(db_session: AsyncSession):
assert user.id is not None assert user.id is not None
assert user.email == "testuser@test.local" assert user.email == "testuser@test.local"
assert user.nickname == "测试用户" assert user.nickname == "测试用户"
assert user.status is True assert user.status == UserStatus.ACTIVE
assert user.storage == 0 assert user.storage == 0
assert user.score == 0 assert user.score == 0
@@ -131,7 +131,7 @@ async def test_user_status_default(db_session: AsyncSession):
) )
user = await user.save(db_session) user = await user.save(db_session)
assert user.status is True assert user.status == UserStatus.ACTIVE
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -4,7 +4,7 @@ Login 服务的单元测试
import pytest import pytest
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.user import User, LoginRequest, TokenResponse from sqlmodels.user import User, LoginRequest, TokenResponse, UserStatus
from sqlmodels.group import Group from sqlmodels.group import Group
from service.user.login import login from service.user.login import login
from utils.password.pwd import Password from utils.password.pwd import Password
@@ -22,7 +22,7 @@ async def setup_user(db_session: AsyncSession):
user = User( user = User(
email="loginuser@test.local", email="loginuser@test.local",
password=Password.hash(plain_password), password=Password.hash(plain_password),
status=True, status=UserStatus.ACTIVE,
group_id=group.id group_id=group.id
) )
user = await user.save(db_session) user = await user.save(db_session)
@@ -43,7 +43,7 @@ async def setup_banned_user(db_session: AsyncSession):
user = User( user = User(
email="banneduser@test.local", email="banneduser@test.local",
password=Password.hash("password"), password=Password.hash("password"),
status=False, # 封禁状态 status=UserStatus.ADMIN_BANNED, # 封禁状态
group_id=group.id group_id=group.id
) )
user = await user.save(db_session) user = await user.save(db_session)
@@ -63,7 +63,7 @@ async def setup_2fa_user(db_session: AsyncSession):
user = User( user = User(
email="2fauser@test.local", email="2fauser@test.local",
password=Password.hash("password"), password=Password.hash("password"),
status=True, status=UserStatus.ACTIVE,
two_factor=secret, two_factor=secret,
group_id=group.id group_id=group.id
) )

View File

@@ -1,49 +1,86 @@
""" """
JWT 工具的单元测试 JWT 工具的单元测试
""" """
import time
from datetime import timedelta, datetime, timezone from datetime import timedelta, datetime, timezone
from uuid import uuid4, UUID
import jwt as pyjwt import jwt as pyjwt
import pytest import pytest
from utils.JWT.JWT import create_access_token, create_refresh_token, SECRET_KEY from sqlmodels.group import GroupClaims
from utils.JWT import create_access_token, create_refresh_token, build_token_payload
# 测试用的 GroupClaims
def _make_group_claims(admin: bool = False) -> GroupClaims:
return GroupClaims(
id=uuid4(),
name="测试组",
max_storage=1073741824,
share_enabled=True,
web_dav_enabled=False,
admin=admin,
speed_limit=0,
)
# 设置测试用的密钥 # 设置测试用的密钥
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_secret_key(): def setup_secret_key():
"""为测试设置密钥""" """为测试设置密钥"""
import utils.JWT.JWT as jwt_module import utils.JWT as jwt_module
jwt_module.SECRET_KEY = "test_secret_key_for_unit_tests" jwt_module.SECRET_KEY = "test_secret_key_for_unit_tests"
yield yield
# 测试后恢复(虽然在单元测试中不太重要)
def test_create_access_token(): def test_create_access_token():
"""测试访问令牌创建""" """测试访问令牌创建"""
data = {"sub": "testuser", "role": "user"} sub = uuid4()
jti = uuid4()
group = _make_group_claims()
token, expire_time = create_access_token(data) result = create_access_token(sub=sub, jti=jti, status="active", group=group)
assert isinstance(token, str) assert isinstance(result.access_token, str)
assert isinstance(expire_time, datetime) assert isinstance(result.access_expires, datetime)
# 解码验证 # 解码验证
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"]) decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == "testuser" assert decoded["sub"] == str(sub)
assert decoded["role"] == "user" assert decoded["jti"] == str(jti)
assert decoded["status"] == "active"
assert decoded["group"]["admin"] is False
assert "exp" in decoded assert "exp" in decoded
def test_create_access_token_custom_expiry(): def test_create_access_token_custom_expiry():
"""测试自定义过期时间""" """测试自定义过期时间"""
data = {"sub": "testuser"} sub = uuid4()
custom_expiry = timedelta(hours=1) jti = uuid4()
group = _make_group_claims()
custom_expiry = timedelta(minutes=30)
token, expire_time = create_access_token(data, expires_delta=custom_expiry) result = create_access_token(sub=sub, jti=jti, status="active", group=group, expires_delta=custom_expiry)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"]) decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
# 验证过期时间大约是30分钟后
exp_timestamp = decoded["exp"]
now_timestamp = datetime.now(timezone.utc).timestamp()
# 允许1秒误差
assert abs(exp_timestamp - now_timestamp - 1800) < 1
def test_create_access_token_default_expiry():
"""测试访问令牌默认1小时过期"""
sub = uuid4()
jti = uuid4()
group = _make_group_claims()
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
# 验证过期时间大约是1小时后 # 验证过期时间大约是1小时后
exp_timestamp = decoded["exp"] exp_timestamp = decoded["exp"]
@@ -55,27 +92,29 @@ def test_create_access_token_custom_expiry():
def test_create_refresh_token(): def test_create_refresh_token():
"""测试刷新令牌创建""" """测试刷新令牌创建"""
data = {"sub": "testuser"} sub = uuid4()
jti = uuid4()
token, expire_time = create_refresh_token(data) result = create_refresh_token(sub=sub, jti=jti)
assert isinstance(token, str) assert isinstance(result.refresh_token, str)
assert isinstance(expire_time, datetime) assert isinstance(result.refresh_expires, datetime)
# 解码验证 # 解码验证
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"]) decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == "testuser" assert decoded["sub"] == str(sub)
assert decoded["token_type"] == "refresh" assert decoded["token_type"] == "refresh"
assert "exp" in decoded assert "exp" in decoded
def test_create_refresh_token_default_expiry(): def test_create_refresh_token_default_expiry():
"""测试刷新令牌默认30天过期""" """测试刷新令牌默认30天过期"""
data = {"sub": "testuser"} sub = uuid4()
jti = uuid4()
token, expire_time = create_refresh_token(data) result = create_refresh_token(sub=sub, jti=jti)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"]) decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
# 验证过期时间大约是30天后 # 验证过期时间大约是30天后
exp_timestamp = decoded["exp"] exp_timestamp = decoded["exp"]
@@ -86,78 +125,72 @@ def test_create_refresh_token_default_expiry():
assert abs(exp_timestamp - now_timestamp - 2592000) < 1 assert abs(exp_timestamp - now_timestamp - 2592000) < 1
def test_token_decode(): def test_access_token_contains_group_claims():
"""测试令牌解码""" """测试访问令牌包含完整的 group claims"""
data = {"sub": "user123", "email": "user@example.com"} sub = uuid4()
jti = uuid4()
group = _make_group_claims(admin=True)
token, _ = create_access_token(data) result = create_access_token(sub=sub, jti=jti, status="active", group=group)
# 解码 decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == "user123" assert decoded["group"]["admin"] is True
assert decoded["email"] == "user@example.com" assert decoded["group"]["name"] == "测试组"
assert decoded["group"]["max_storage"] == 1073741824
assert decoded["group"]["share_enabled"] is True
def test_token_expired():
"""测试令牌过期"""
data = {"sub": "testuser"}
# 创建一个立即过期的令牌
token, _ = create_access_token(data, expires_delta=timedelta(seconds=-1))
# 尝试解码应该抛出过期异常
with pytest.raises(pyjwt.ExpiredSignatureError):
pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
def test_token_invalid_signature():
"""测试无效签名"""
data = {"sub": "testuser"}
token, _ = create_access_token(data)
# 使用错误的密钥解码
with pytest.raises(pyjwt.InvalidSignatureError):
pyjwt.decode(token, "wrong_secret_key", algorithms=["HS256"])
def test_access_token_does_not_have_token_type(): def test_access_token_does_not_have_token_type():
"""测试访问令牌不包含 token_type""" """测试访问令牌不包含 token_type"""
data = {"sub": "testuser"} sub = uuid4()
jti = uuid4()
group = _make_group_claims()
token, _ = create_access_token(data) result = create_access_token(sub=sub, jti=jti, status="active", group=group)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"]) decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert "token_type" not in decoded assert "token_type" not in decoded
def test_refresh_token_has_token_type(): def test_refresh_token_has_token_type():
"""测试刷新令牌包含 token_type""" """测试刷新令牌包含 token_type"""
data = {"sub": "testuser"} sub = uuid4()
jti = uuid4()
token, _ = create_refresh_token(data) result = create_refresh_token(sub=sub, jti=jti)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"]) decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["token_type"] == "refresh" assert decoded["token_type"] == "refresh"
def test_token_payload_preserved(): def test_token_expired():
"""测试自定义负载保留""" """测试令牌过期"""
data = { sub = uuid4()
"sub": "user123", jti = uuid4()
"name": "Test User", group = _make_group_claims()
"roles": ["admin", "user"],
"metadata": {"key": "value"}
}
token, _ = create_access_token(data) # 创建一个立即过期的令牌
result = create_access_token(
sub=sub, jti=jti, status="active", group=group,
expires_delta=timedelta(seconds=-1),
)
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"]) # 尝试解码应该抛出过期异常
with pytest.raises(pyjwt.ExpiredSignatureError):
pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
assert decoded["sub"] == "user123"
assert decoded["name"] == "Test User" def test_token_invalid_signature():
assert decoded["roles"] == ["admin", "user"] """测试无效签名"""
assert decoded["metadata"] == {"key": "value"} sub = uuid4()
jti = uuid4()
group = _make_group_claims()
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
# 使用错误的密钥解码
with pytest.raises(pyjwt.InvalidSignatureError):
pyjwt.decode(result.access_token, "wrong_secret_key", algorithms=["HS256"])

View File

@@ -1,4 +1,5 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import jwt import jwt
@@ -6,6 +7,9 @@ from fastapi.security import OAuth2PasswordBearer
from sqlmodels import AccessTokenBase, RefreshTokenBase, TokenResponse from sqlmodels import AccessTokenBase, RefreshTokenBase, TokenResponse
if TYPE_CHECKING:
from sqlmodels.group import GroupClaims
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(
scheme_name='获取 JWT Bearer 令牌', scheme_name='获取 JWT Bearer 令牌',
description='用于获取 JWT Bearer 令牌,需要以表单的形式提交', description='用于获取 JWT Bearer 令牌,需要以表单的形式提交',
@@ -59,7 +63,7 @@ def build_token_payload(
elif is_refresh: elif is_refresh:
expire = datetime.now(timezone.utc) + timedelta(days=30) expire = datetime.now(timezone.utc) + timedelta(days=30)
else: else:
expire = datetime.now(timezone.utc) + timedelta(hours=3) expire = datetime.now(timezone.utc) + timedelta(hours=1)
to_encode.update({ to_encode.update({
"iat": int(datetime.now(timezone.utc).timestamp()), "iat": int(datetime.now(timezone.utc).timestamp()),
"exp": int(expire.timestamp()) "exp": int(expire.timestamp())
@@ -71,33 +75,36 @@ def build_token_payload(
def create_access_token( def create_access_token(
sub: UUID, sub: UUID,
jti: UUID, jti: UUID,
*,
status: str,
group: "GroupClaims",
expires_delta: timedelta | None = None, expires_delta: timedelta | None = None,
algorithm: str = "HS256", algorithm: str = "HS256",
**kwargs
) -> AccessTokenBase: ) -> AccessTokenBase:
""" """
生成访问令牌,默认有效期 3 小时。 生成访问令牌,默认有效期 1 小时。
:param sub: 令牌的主题,通常是用户 ID。 :param sub: 令牌的主题,通常是用户 ID。
:param jti: 令牌的唯一标识符,通常是一个 UUID。 :param jti: 令牌的唯一标识符,通常是一个 UUID。
:param expires_delta: 过期时间, 缺省时为 3 小时 :param status: 用户状态字符串
:param group: 用户组权限快照。
:param expires_delta: 过期时间, 缺省时为 1 小时。
:param algorithm: JWT 密钥强度,缺省时为 HS256 :param algorithm: JWT 密钥强度,缺省时为 HS256
:param kwargs: 需要放进 JWT Payload 的字段。
:return: 包含密钥本身和过期时间的 `AccessTokenBase` :return: 包含密钥本身和过期时间的 `AccessTokenBase`
""" """
data = {
data = {"sub": str(sub), "jti": str(jti)} "sub": str(sub),
"jti": str(jti),
# 将额外的字段添加到 Payload 中 "status": status,
for key, value in kwargs.items(): "group": group.model_dump(mode="json"),
data[key] = value }
access_token, expire_at = build_token_payload( access_token, expire_at = build_token_payload(
data, data,
False, False,
algorithm, algorithm,
expires_delta expires_delta,
) )
return AccessTokenBase( return AccessTokenBase(
access_token=access_token, access_token=access_token,