feat: embed permission claims in JWT and add captcha verification
- Add GroupClaims model for JWT permission snapshots - Add JWTPayload model for typed JWT decoding - Refactor auth middleware: jwt_required (no DB) -> admin_required (no DB) -> auth_required (DB) - Add UserBanStore for instant ban enforcement via Redis + memory fallback - Fix status check bug: StrEnum is always truthy, use explicit != ACTIVE - Shorten access_token expiry from 3h to 1h - Add CaptchaScene enum and verify_captcha_if_needed service - Add require_captcha dependency injection factory - Add CLA document and new default settings - Update all tests for new JWT API Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
92
docs/CLA.md
Normal file
92
docs/CLA.md
Normal file
@@ -0,0 +1,92 @@
|
||||
# DiskNext Contributor License Agreement
|
||||
|
||||
Thank you for your interest in contributing to the DiskNext project ("We", "Us", or "Our"). This Contributor License Agreement ("Agreement") is for our mutual protection. It clarifies the intellectual property rights You grant to Us for Your Contributions.
|
||||
|
||||
By signing this Agreement, You accept its terms and conditions.
|
||||
|
||||
## 1. The Purpose of This Agreement
|
||||
|
||||
The DiskNext project is developed with a dual-licensing strategy. We maintain a free, open-source community edition alongside a commercial Pro edition. This model allows Us to support a vibrant community while also funding the project's sustainable development.
|
||||
|
||||
To make this model work, We require broad rights to use the code You contribute. This Agreement ensures that We can include Your Contributions in all editions of DiskNext under their respective licenses. By signing this Agreement, You grant Us the rights needed to manage the project effectively, including the right to incorporate Your Contribution into Our commercial products and to transfer the project to another entity.
|
||||
|
||||
## 2. Definitions
|
||||
|
||||
**"You"** means the individual copyright owner who Submits a Contribution to Us.
|
||||
|
||||
**"Contribution"** means any original work of authorship, including any modifications or additions to an existing work, that you intentionally Submit to Us for inclusion in the Material.
|
||||
|
||||
**"Material"** means the software and documentation We make available to third parties. Your Contribution may be included in the Material.
|
||||
|
||||
**"Submit"** means any form of communication sent to Us (e.g., via a pull request, issue tracker, or email) that is managed by Us for the purpose of discussing and improving the Material, but excluding communication that is conspicuously marked or otherwise designated in writing by You as "Not a Contribution."
|
||||
|
||||
**"Copyright"** means all rights protecting works of authorship, including copyright, moral rights, and neighboring rights, for the full term of their existence.
|
||||
|
||||
## 3. Copyright License Grant
|
||||
|
||||
Subject to the terms and conditions of this Agreement, You hereby grant to Us a worldwide, royalty-free, **non-exclusive**, perpetual, and irrevocable license under the Copyright covering your Contribution. This license includes the right to sublicense and to assign Your Contribution.
|
||||
|
||||
This license allows Us to use, reproduce, prepare derivative works of, publicly display, publicly perform, distribute, and publish your Contribution and such derivative works in any form. This includes, without limitation, the right to sell and distribute the Contribution as part of a commercial product under a proprietary license.
|
||||
|
||||
You retain full ownership of the Copyright in Your Contribution. Nothing in this Agreement shall be construed to restrict or transfer Your rights to use Your own Contribution for any purpose.
|
||||
|
||||
## 4. Patent License Grant
|
||||
|
||||
You hereby grant to Us and to recipients of the Material a worldwide, royalty-free, non-exclusive, perpetual, and irrevocable patent license to make, have made, use, sell, offer for sale, import, and otherwise transfer Your Contribution. This license applies to all patents owned or controlled by You, now or in the future, that would be infringed by Your Contribution alone or in combination with the Material.
|
||||
|
||||
## 5. Your Representations
|
||||
|
||||
You represent and warrant that:
|
||||
|
||||
1. The Contribution is Your original work.
|
||||
2. You are legally entitled to grant the licenses in this Agreement.
|
||||
3. If Your employer has rights to intellectual property that You create, You have either (i) received permission from Your employer to make the Contribution on behalf of that employer, or (ii) Your employer has waived such rights for the Contribution.
|
||||
4. To the best of Your knowledge, the Contribution does not violate any third-party rights, including copyright, patent, trademark, or trade secret.
|
||||
|
||||
You agree to notify Us of any facts or circumstances of which you become aware that would make these representations inaccurate in any respect.
|
||||
|
||||
## 6. Our Licensing Rights
|
||||
|
||||
You acknowledge that We may license the Material, including Your Contribution, under different license terms. We intend to distribute a community edition of DiskNext under a free and open-source license. We also reserve the right to distribute a Pro edition and other commercial versions of the Material, including Your Contribution, under a proprietary license at Our sole discretion.
|
||||
|
||||
## 7. Disclaimer of Warranty
|
||||
|
||||
THE CONTRIBUTION IS PROVIDED "AS IS" AND WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
|
||||
## 8. Limitation of Liability
|
||||
|
||||
TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT WILL YOU OR WE BE LIABLE FOR ANY LOSS OF PROFITS, LOSS OF ANTICIPATED SAVINGS, LOSS OF DATA, OR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THIS AGREEMENT, REGARDLESS OF THE LEGAL THEORY UPON WHICH THE CLAIM IS BASED.
|
||||
|
||||
## 9. Term
|
||||
|
||||
This Agreement is effective on the date You accept it and shall continue for the full term of the copyrights and patents licensed herein. This Agreement is irrevocable.
|
||||
|
||||
## 10. Miscellaneous
|
||||
|
||||
**10.1 Governing Law:** This Agreement shall be governed by the laws of the People's Republic of China, excluding its conflict of law provisions.
|
||||
|
||||
**10.2 Entire Agreement:** This Agreement sets out the entire agreement between You and Us for Your Contributions and supersedes all prior communications and understandings.
|
||||
|
||||
**10.3 Assignment:** We may assign Our rights and obligations under this Agreement at Our sole discretion. This Agreement will be binding upon and will inure to the benefit of the parties, their successors, and permitted assigns.
|
||||
|
||||
**10.4 Severability:** If any provision of this Agreement is found to be void or unenforceable, it will be replaced with a provision that comes closest to the meaning of the original and is enforceable.
|
||||
|
||||
---
|
||||
|
||||
## To Accept This Agreement
|
||||
|
||||
Please provide the following information to signify your acceptance.
|
||||
|
||||
### Contributor ("You"):
|
||||
|
||||
- **Date:**
|
||||
- **Full Name:**
|
||||
- **Address:**
|
||||
- **Email:**
|
||||
- **GitHub Username (if applicable):**
|
||||
|
||||
### For DiskNext ("Us"):
|
||||
|
||||
- **Date:**
|
||||
- **[NAME]**
|
||||
- **Owner of DiskNext Org**
|
||||
@@ -4,49 +4,79 @@ from uuid import UUID
|
||||
from fastapi import Depends
|
||||
import jwt
|
||||
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.user import JWTPayload, User, UserStatus
|
||||
from utils import JWT
|
||||
from .dependencies import SessionDep
|
||||
from utils import http_exceptions
|
||||
from service.redis import RedisManager
|
||||
from service.redis.user_ban_store import UserBanStore
|
||||
|
||||
async def auth_required(
|
||||
|
||||
async def jwt_required(
|
||||
session: SessionDep,
|
||||
token: Annotated[str, Depends(JWT.oauth2_scheme)],
|
||||
) -> User:
|
||||
) -> JWTPayload:
|
||||
"""
|
||||
AuthRequired 需要登录
|
||||
验证 JWT 并返回 claims。
|
||||
|
||||
封禁检查策略:
|
||||
1. JWT 内嵌 status 检查(签发时快照)
|
||||
2. Redis 黑名单检查(即时封禁,如果 Redis 可用)
|
||||
3. Redis 不可用时查库检查 status(降级方案)
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"])
|
||||
user_id = payload.get("sub")
|
||||
|
||||
if user_id is None:
|
||||
http_exceptions.raise_unauthorized("账号或密码错误")
|
||||
|
||||
user_id = UUID(user_id)
|
||||
|
||||
# 从数据库获取用户信息(预加载 group 关系)
|
||||
user = await User.get(session, User.id == user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("账号或密码错误")
|
||||
|
||||
return user
|
||||
|
||||
except jwt.InvalidTokenError:
|
||||
claims = JWTPayload(
|
||||
sub=payload["sub"],
|
||||
jti=payload["jti"],
|
||||
status=payload["status"],
|
||||
group=payload["group"],
|
||||
)
|
||||
except (jwt.InvalidTokenError, KeyError, ValueError):
|
||||
http_exceptions.raise_unauthorized("凭据过期或无效")
|
||||
|
||||
# 1. JWT 内嵌 status 检查
|
||||
if claims.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
# 2. 即时封禁检查
|
||||
user_id_str = str(claims.sub)
|
||||
if RedisManager.is_available():
|
||||
# Redis 可用:查黑名单
|
||||
if await UserBanStore.is_banned(user_id_str):
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
else:
|
||||
# Redis 不可用:查库(仅 status 字段,不加载关系)
|
||||
user = await User.get(session, User.id == claims.sub)
|
||||
if not user or user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
return claims
|
||||
|
||||
|
||||
async def admin_required(
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> User:
|
||||
claims: Annotated[JWTPayload, Depends(jwt_required)],
|
||||
) -> JWTPayload:
|
||||
"""
|
||||
验证是否为管理员。
|
||||
验证管理员权限(仅读取 JWT claims,不查库)。
|
||||
|
||||
使用方法:
|
||||
>>> APIRouter(dependencies=[Depends(admin_required)])
|
||||
"""
|
||||
if user.group.admin:
|
||||
if not claims.group.admin:
|
||||
http_exceptions.raise_forbidden("Admin Required")
|
||||
return claims
|
||||
|
||||
|
||||
async def auth_required(
|
||||
session: SessionDep,
|
||||
claims: Annotated[JWTPayload, Depends(jwt_required)],
|
||||
) -> User:
|
||||
"""验证 JWT + 从数据库加载完整 User(含 group 关系)"""
|
||||
user = await User.get(session, User.id == claims.sub, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
return user
|
||||
raise http_exceptions.raise_forbidden("Admin Required")
|
||||
|
||||
|
||||
def verify_download_token(token: str) -> tuple[str, UUID, UUID] | None:
|
||||
|
||||
@@ -6,12 +6,14 @@ FastAPI 依赖注入
|
||||
- TimeFilterRequestDep: 时间筛选查询依赖(用于 count 等统计接口)
|
||||
- TableViewRequestDep: 分页排序查询依赖(包含时间筛选 + 分页排序)
|
||||
- UserFilterParamsDep: 用户筛选参数依赖(用于管理员用户列表)
|
||||
- require_captcha: 验证码校验依赖注入工厂
|
||||
"""
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Query
|
||||
from fastapi import Depends, Form, Query
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
@@ -94,3 +96,30 @@ async def _get_user_filter_params(
|
||||
|
||||
UserFilterParamsDep: TypeAlias = Annotated[UserFilterParams, Depends(_get_user_filter_params)]
|
||||
"""获取用户筛选参数的依赖(用于管理员用户列表)"""
|
||||
|
||||
|
||||
# --- 验证码校验依赖 ---
|
||||
|
||||
def require_captcha(scene: 'CaptchaScene') -> Callable[..., Awaitable[None]]:
|
||||
"""
|
||||
验证码校验依赖注入工厂。
|
||||
|
||||
根据场景查询数据库设置,判断是否需要验证码。
|
||||
需要则校验前端提交的 captcha_code,失败则抛出异常。
|
||||
|
||||
使用方式::
|
||||
|
||||
@router.post('/session', dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))])
|
||||
async def login(...): ...
|
||||
|
||||
:param scene: 验证码使用场景(LOGIN / REGISTER / FORGET)
|
||||
"""
|
||||
from service.captcha import CaptchaScene, verify_captcha_if_needed
|
||||
|
||||
async def _verify_captcha(
|
||||
session: SessionDep,
|
||||
captcha_code: Annotated[str | None, Form()] = None,
|
||||
) -> None:
|
||||
await verify_captcha_if_needed(session, scene, captcha_code)
|
||||
|
||||
return _verify_captcha
|
||||
|
||||
@@ -10,7 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from sqlmodels import (
|
||||
Policy, PolicyType, User, ListResponse,
|
||||
JWTPayload, Policy, PolicyType, User, ListResponse,
|
||||
Object, ObjectType, AdminFileResponse, FileBanRequest, )
|
||||
from service.storage import LocalStorageService
|
||||
|
||||
@@ -164,14 +164,13 @@ async def router_admin_preview_file(
|
||||
path='/ban/{file_id}',
|
||||
summary='封禁/解禁文件',
|
||||
description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_ban_file(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
request: FileBanRequest,
|
||||
admin: Annotated[User, Depends(admin_required)],
|
||||
claims: Annotated[JWTPayload, Depends(admin_required)],
|
||||
) -> None:
|
||||
"""
|
||||
封禁或解禁文件/文件夹。封禁后用户无法访问该文件。
|
||||
@@ -180,14 +179,14 @@ async def router_admin_ban_file(
|
||||
:param session: 数据库会话
|
||||
:param file_id: 文件UUID
|
||||
:param request: 封禁请求
|
||||
:param admin: 当前管理员
|
||||
:param claims: 当前管理员 JWT claims
|
||||
:return: 封禁结果
|
||||
"""
|
||||
file_obj = await Object.get(session, Object.id == file_id)
|
||||
if not file_obj:
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
|
||||
count = await _set_ban_recursive(session, file_obj, request.ban, admin.id, request.reason)
|
||||
count = await _set_ban_recursive(session, file_obj, request.ban, claims.sub, request.reason)
|
||||
|
||||
action = "封禁" if request.ban else "解禁"
|
||||
l.info(f"管理员{action}了对象: {file_obj.name},共影响 {count} 个对象")
|
||||
|
||||
@@ -6,13 +6,14 @@ from sqlalchemy import func
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep, UserFilterParamsDep
|
||||
from service.redis.user_ban_store import UserBanStore
|
||||
from sqlmodels import (
|
||||
User, ResponseBase, UserPublic, ListResponse,
|
||||
Group, Object, ObjectType, Setting, SettingsType,
|
||||
BatchDeleteRequest,
|
||||
)
|
||||
from sqlmodels.user import (
|
||||
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse,
|
||||
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, UserStatus,
|
||||
)
|
||||
from utils import Password, http_exceptions
|
||||
|
||||
@@ -159,11 +160,21 @@ async def router_admin_update_user(
|
||||
if len(update_data['two_factor']) != 32:
|
||||
raise HTTPException(status_code=400, detail="两步验证密钥必须为32位字符串")
|
||||
|
||||
# 记录旧 status 以便检测变更
|
||||
old_status = user.status
|
||||
|
||||
# 更新字段
|
||||
for key, value in update_data.items():
|
||||
setattr(user, key, value)
|
||||
user = await user.save(session)
|
||||
|
||||
# 封禁状态变更 → 更新 BanStore
|
||||
new_status = user.status
|
||||
if old_status == UserStatus.ACTIVE and new_status != UserStatus.ACTIVE:
|
||||
await UserBanStore.ban(str(user_id))
|
||||
elif old_status != UserStatus.ACTIVE and new_status == UserStatus.ACTIVE:
|
||||
await UserBanStore.unban(str(user_id))
|
||||
|
||||
l.info(f"管理员更新了用户: {request.email}")
|
||||
|
||||
|
||||
|
||||
@@ -2,8 +2,7 @@ from typing import Annotated, Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from loguru import logger
|
||||
from webauthn import generate_registration_options
|
||||
from webauthn.helpers import options_to_json_dict
|
||||
@@ -11,7 +10,9 @@ from webauthn.helpers import options_to_json_dict
|
||||
import service
|
||||
import sqlmodels
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from middleware.dependencies import SessionDep, require_captcha
|
||||
from service.captcha import CaptchaScene
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import JWT, Password, http_exceptions
|
||||
from .settings import user_settings_router
|
||||
|
||||
@@ -22,48 +23,60 @@ user_router = APIRouter(
|
||||
|
||||
user_router.include_router(user_settings_router)
|
||||
|
||||
class OAuth2PasswordWithExtrasForm:
|
||||
"""
|
||||
扩展 OAuth2 密码表单。
|
||||
|
||||
在标准 username/password 基础上添加 otp_code 字段。
|
||||
captcha_code 由 require_captcha 依赖注入单独处理。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
username: Annotated[str, Form()],
|
||||
password: Annotated[str, Form()],
|
||||
otp_code: Annotated[str | None, Form(min_length=6, max_length=6)] = None,
|
||||
):
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.otp_code = otp_code
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/session',
|
||||
summary='用户登录',
|
||||
description='User login endpoint. 当用户启用两步验证时,需要传入 otp 参数。',
|
||||
description='用户登录端点,支持验证码校验和两步验证。',
|
||||
dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))],
|
||||
)
|
||||
async def router_user_session(
|
||||
session: SessionDep,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
form_data: Annotated[OAuth2PasswordWithExtrasForm, Depends()],
|
||||
) -> sqlmodels.TokenResponse:
|
||||
"""
|
||||
用户登录端点。
|
||||
用户登录端点
|
||||
|
||||
根据 OAuth2.1 规范,使用 password grant type 进行登录。
|
||||
当用户启用两步验证时,需要在表单中传入 otp 参数(通过 scopes 字段传递)。
|
||||
表单字段:
|
||||
- username: 用户邮箱
|
||||
- password: 用户密码
|
||||
- captcha_code: 验证码 token(可选,由 require_captcha 依赖校验)
|
||||
- otp_code: 两步验证码(可选,仅在用户启用 2FA 时需要)
|
||||
|
||||
OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码
|
||||
错误处理:
|
||||
- 400: 需要验证码但未提供
|
||||
- 401: 邮箱/密码错误,或 2FA 验证码错误
|
||||
- 403: 账户已禁用 / 验证码验证失败
|
||||
- 428: 需要两步验证但未提供 otp_code
|
||||
"""
|
||||
email = form_data.username # OAuth2 表单字段名为 username,实际传入的是 email
|
||||
password = form_data.password
|
||||
|
||||
# 从 scopes 中提取 OTP 验证码(OAuth2.1 扩展方式)
|
||||
# scopes 格式可以是 ["otp:123456"] 或 ["123456"]
|
||||
otp_code: str | None = None
|
||||
for scope in form_data.scopes:
|
||||
if scope.startswith("otp:"):
|
||||
otp_code = scope[4:]
|
||||
break
|
||||
elif scope.isdigit() and len(scope) == 6:
|
||||
otp_code = scope
|
||||
break
|
||||
|
||||
result = await service.user.login(
|
||||
return await service.user.login(
|
||||
session,
|
||||
sqlmodels.LoginRequest(
|
||||
email=email,
|
||||
password=password,
|
||||
two_fa_code=otp_code,
|
||||
email=form_data.username,
|
||||
password=form_data.password,
|
||||
two_fa_code=form_data.otp_code,
|
||||
),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@user_router.post(
|
||||
path='/session/refresh',
|
||||
summary="用刷新令牌刷新会话",
|
||||
@@ -101,17 +114,27 @@ async def router_user_session_refresh(
|
||||
http_exceptions.raise_unauthorized("令牌缺少用户标识")
|
||||
|
||||
user_id = UUID(user_id_str)
|
||||
user = await sqlmodels.User.get(session, sqlmodels.User.id == user_id)
|
||||
user = await sqlmodels.User.get(session, sqlmodels.User.id == user_id, load=sqlmodels.User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
|
||||
if not user.status:
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
# 加载 GroupOptions(获取最新权限)
|
||||
group_options = await sqlmodels.GroupOptions.get(
|
||||
session,
|
||||
sqlmodels.GroupOptions.group_id == user.group_id,
|
||||
)
|
||||
user.group.options = group_options
|
||||
group_claims = sqlmodels.GroupClaims.from_group(user.group)
|
||||
|
||||
# 签发新令牌
|
||||
access_token = JWT.create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
)
|
||||
refresh_token = JWT.create_refresh_token(
|
||||
sub=user.id,
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
import abc
|
||||
from enum import StrEnum
|
||||
|
||||
import aiohttp
|
||||
|
||||
from loguru import logger as l
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .gcaptcha import GCaptcha
|
||||
from .turnstile import TurnstileCaptcha
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class CaptchaRequestBase(BaseModel):
|
||||
"""验证码验证请求"""
|
||||
token: str
|
||||
"""验证 token"""
|
||||
|
||||
response: str
|
||||
"""用户的验证码 response token"""
|
||||
|
||||
secret: str
|
||||
"""验证密钥"""
|
||||
"""服务端密钥"""
|
||||
|
||||
|
||||
class CaptchaBase(abc.ABC):
|
||||
@@ -30,10 +32,89 @@ class CaptchaBase(abc.ABC):
|
||||
"""
|
||||
payload = request.model_dump()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(self.verify_url, data=payload) as response:
|
||||
if response.status != 200:
|
||||
async with aiohttp.ClientSession() as client_session:
|
||||
async with client_session.post(self.verify_url, data=payload) as resp:
|
||||
if resp.status != 200:
|
||||
return False
|
||||
|
||||
result = await response.json()
|
||||
result = await resp.json()
|
||||
return result.get('success', False)
|
||||
|
||||
|
||||
# 子类导入必须在 CaptchaBase 定义之后(gcaptcha.py / turnstile.py 依赖 CaptchaBase)
|
||||
from .gcaptcha import GCaptcha # noqa: E402
|
||||
from .turnstile import TurnstileCaptcha # noqa: E402
|
||||
|
||||
|
||||
class CaptchaScene(StrEnum):
|
||||
"""验证码使用场景,value 对应 Setting 表中的 name"""
|
||||
|
||||
LOGIN = "login_captcha"
|
||||
REGISTER = "reg_captcha"
|
||||
FORGET = "forget_captcha"
|
||||
|
||||
|
||||
async def verify_captcha_if_needed(
|
||||
session: AsyncSession,
|
||||
scene: CaptchaScene,
|
||||
captcha_code: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
通用验证码校验:查询设置判断是否需要,需要则校验。
|
||||
|
||||
:param session: 数据库异步会话
|
||||
:param scene: 验证码使用场景
|
||||
:param captcha_code: 用户提交的验证码 response token
|
||||
:raises HTTPException 400: 需要验证码但未提供
|
||||
:raises HTTPException 403: 验证码验证失败
|
||||
:raises HTTPException 500: 验证码密钥未配置
|
||||
"""
|
||||
from sqlmodels import Setting, SettingsType
|
||||
from sqlmodels.setting import CaptchaType
|
||||
from utils import http_exceptions
|
||||
|
||||
# 1. 查询该场景是否需要验证码
|
||||
scene_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.LOGIN) & (Setting.name == scene.value),
|
||||
)
|
||||
if not scene_setting or scene_setting.value != "1":
|
||||
return
|
||||
|
||||
# 2. 需要但未提供
|
||||
if not captcha_code:
|
||||
http_exceptions.raise_bad_request(detail="请完成验证码验证")
|
||||
|
||||
# 3. 查询验证码类型和密钥
|
||||
captcha_settings: list[Setting] = await Setting.get(
|
||||
session, Setting.type == SettingsType.CAPTCHA, fetch_mode="all",
|
||||
)
|
||||
s: dict[str, str | None] = {item.name: item.value for item in captcha_settings}
|
||||
captcha_type = CaptchaType(s.get("captcha_type") or "default")
|
||||
|
||||
# 4. DEFAULT 图片验证码尚未实现,跳过
|
||||
if captcha_type == CaptchaType.DEFAULT:
|
||||
l.warning("DEFAULT 图片验证码尚未实现,跳过验证")
|
||||
return
|
||||
|
||||
# 5. 选择验证器和密钥
|
||||
if captcha_type == CaptchaType.GCAPTCHA:
|
||||
secret = s.get("captcha_ReCaptchaSecret")
|
||||
verifier: CaptchaBase = GCaptcha()
|
||||
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
|
||||
secret = s.get("captcha_CloudflareSecret")
|
||||
verifier = TurnstileCaptcha()
|
||||
else:
|
||||
l.error(f"未知的验证码类型: {captcha_type}")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
if not secret:
|
||||
l.error(f"验证码密钥未配置: captcha_type={captcha_type}")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
# 6. 调用第三方 API 校验
|
||||
is_valid = await verifier.verify_captcha(
|
||||
CaptchaRequestBase(response=captcha_code, secret=secret)
|
||||
)
|
||||
if not is_valid:
|
||||
http_exceptions.raise_forbidden(detail="验证码验证失败")
|
||||
|
||||
72
service/redis/user_ban_store.py
Normal file
72
service/redis/user_ban_store.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
用户封禁状态存储
|
||||
|
||||
用于 JWT 模式下的即时封禁生效。
|
||||
支持 Redis(首选)和内存缓存(降级)两种存储后端。
|
||||
"""
|
||||
from typing import ClassVar
|
||||
|
||||
from cachetools import TTLCache
|
||||
from loguru import logger as l
|
||||
|
||||
from . import RedisManager
|
||||
|
||||
# access_token 有效期(秒)
|
||||
_BAN_TTL: int = 3600
|
||||
|
||||
|
||||
class UserBanStore:
|
||||
"""
|
||||
用户封禁状态存储
|
||||
|
||||
管理员封禁用户时调用 ban(),jwt_required 每次请求调用 is_banned() 检查。
|
||||
TTL 与 access_token 有效期一致(1h),过期后旧 token 自然失效,无需继续记录。
|
||||
"""
|
||||
|
||||
_memory_cache: ClassVar[TTLCache[str, bool]] = TTLCache(maxsize=10000, ttl=_BAN_TTL)
|
||||
"""内存缓存降级方案"""
|
||||
|
||||
@classmethod
|
||||
async def ban(cls, user_id: str) -> None:
|
||||
"""
|
||||
标记用户为已封禁。
|
||||
|
||||
:param user_id: 用户 UUID 字符串
|
||||
"""
|
||||
client = RedisManager.get_client()
|
||||
if client is not None:
|
||||
key = f"user_ban:{user_id}"
|
||||
await client.set(key, "1", ex=_BAN_TTL)
|
||||
else:
|
||||
cls._memory_cache[user_id] = True
|
||||
l.info(f"用户 {user_id} 已加入封禁黑名单")
|
||||
|
||||
@classmethod
|
||||
async def unban(cls, user_id: str) -> None:
|
||||
"""
|
||||
移除用户封禁标记(解封时调用)。
|
||||
|
||||
:param user_id: 用户 UUID 字符串
|
||||
"""
|
||||
client = RedisManager.get_client()
|
||||
if client is not None:
|
||||
key = f"user_ban:{user_id}"
|
||||
await client.delete(key)
|
||||
else:
|
||||
cls._memory_cache.pop(user_id, None)
|
||||
l.info(f"用户 {user_id} 已从封禁黑名单移除")
|
||||
|
||||
@classmethod
|
||||
async def is_banned(cls, user_id: str) -> bool:
|
||||
"""
|
||||
检查用户是否在封禁黑名单中。
|
||||
|
||||
:param user_id: 用户 UUID 字符串
|
||||
:return: True 表示已封禁
|
||||
"""
|
||||
client = RedisManager.get_client()
|
||||
if client is not None:
|
||||
key = f"user_ban:{user_id}"
|
||||
return await client.exists(key) > 0
|
||||
else:
|
||||
return user_id in cls._memory_cache
|
||||
@@ -4,6 +4,8 @@ from loguru import logger
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import LoginRequest, TokenResponse, User
|
||||
from sqlmodels.group import GroupClaims, GroupOptions
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import http_exceptions
|
||||
from utils.JWT import create_access_token, create_refresh_token
|
||||
from utils.password.pwd import Password, PasswordStatus
|
||||
@@ -22,15 +24,13 @@ async def login(
|
||||
|
||||
:return: TokenResponse 对象或状态码或 None
|
||||
"""
|
||||
# TODO: 验证码校验
|
||||
# captcha_setting = await Setting.get(
|
||||
# session,
|
||||
# (Setting.type == "auth") & (Setting.name == "login_captcha")
|
||||
# )
|
||||
# is_captcha_required = captcha_setting and captcha_setting.value == "1"
|
||||
|
||||
# 获取用户信息
|
||||
current_user: User = await User.get(session, User.email == login_request.email, fetch_mode="first") #type: ignore
|
||||
# 获取用户信息(预加载 group 关系)
|
||||
current_user: User = await User.get(
|
||||
session,
|
||||
User.email == login_request.email,
|
||||
fetch_mode="first",
|
||||
load=User.group,
|
||||
) #type: ignore
|
||||
|
||||
# 验证用户是否存在
|
||||
if not current_user:
|
||||
@@ -42,8 +42,8 @@ async def login(
|
||||
logger.debug(f"Password verification failed for user: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid email or password")
|
||||
|
||||
# 验证用户是否可登录
|
||||
if not current_user.status:
|
||||
# 验证用户是否可登录(修复:显式枚举比较,StrEnum 永远 truthy)
|
||||
if current_user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("Your account is disabled")
|
||||
|
||||
# 检查两步验证
|
||||
@@ -58,10 +58,22 @@ async def login(
|
||||
logger.debug(f"Invalid 2FA code for user: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid 2FA code")
|
||||
|
||||
# 加载 GroupOptions
|
||||
group_options: GroupOptions | None = await GroupOptions.get(
|
||||
session,
|
||||
GroupOptions.group_id == current_user.group_id,
|
||||
)
|
||||
|
||||
# 构建权限快照
|
||||
current_user.group.options = group_options
|
||||
group_claims = GroupClaims.from_group(current_user.group)
|
||||
|
||||
# 创建令牌
|
||||
access_token = create_access_token(
|
||||
sub=current_user.id,
|
||||
jti=uuid4()
|
||||
jti=uuid4(),
|
||||
status=current_user.status.value,
|
||||
group=group_claims,
|
||||
)
|
||||
refresh_token = create_refresh_token(
|
||||
sub=current_user.id,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from .user import (
|
||||
BatchDeleteRequest,
|
||||
JWTPayload,
|
||||
LoginRequest,
|
||||
RefreshTokenRequest,
|
||||
RegisterRequest,
|
||||
@@ -37,7 +38,7 @@ from .node import (
|
||||
NodeType,
|
||||
)
|
||||
from .group import (
|
||||
Group, GroupBase, GroupOptions, GroupOptionsBase, GroupAllOptionsBase, GroupResponse,
|
||||
Group, GroupBase, GroupClaims, GroupOptions, GroupOptionsBase, GroupAllOptionsBase, GroupResponse,
|
||||
# 管理员DTO
|
||||
GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, GroupListResponse,
|
||||
)
|
||||
|
||||
@@ -188,6 +188,28 @@ class GroupListResponse(SQLModelBase):
|
||||
"""总数"""
|
||||
|
||||
|
||||
class GroupClaims(GroupCoreBase, GroupAllOptionsBase):
|
||||
"""
|
||||
JWT 中的用户组权限快照。
|
||||
|
||||
复用 GroupCoreBase(id, name, max_storage, share_enabled, web_dav_enabled, admin, speed_limit)
|
||||
和 GroupAllOptionsBase(share_download, share_free, ... 共 11 个功能开关)。
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_group(cls, group: "Group") -> "GroupClaims":
|
||||
"""
|
||||
从 Group ORM 对象(需预加载 options 关系)构建权限快照。
|
||||
|
||||
:param group: 已加载 options 的 Group 对象
|
||||
"""
|
||||
opts = group.options
|
||||
return cls(
|
||||
**GroupCoreBase.model_validate(group, from_attributes=True).model_dump(),
|
||||
**(GroupAllOptionsBase.model_validate(opts, from_attributes=True).model_dump() if opts else {}),
|
||||
)
|
||||
|
||||
|
||||
class GroupResponse(GroupBase, GroupOptionsBase):
|
||||
"""用户组响应 DTO"""
|
||||
|
||||
|
||||
@@ -29,6 +29,10 @@ default_settings: list[Setting] = [
|
||||
Setting(name="siteKeywords", value="网盘,网盘", type=SettingsType.BASIC),
|
||||
Setting(name="siteDes", value="DiskNext", type=SettingsType.BASIC),
|
||||
Setting(name="siteTitle", value="云星启智", type=SettingsType.BASIC),
|
||||
Setting(name="site_notice", value="", type=SettingsType.BASIC),
|
||||
Setting(name="footer_code", value="", type=SettingsType.BASIC),
|
||||
Setting(name="tos_url", value="", type=SettingsType.BASIC),
|
||||
Setting(name="privacy_url", value="", type=SettingsType.BASIC),
|
||||
Setting(name="fromName", value="DiskNext", type=SettingsType.MAIL),
|
||||
Setting(name="mail_keepalive", value="30", type=SettingsType.MAIL),
|
||||
Setting(name="fromAdress", value="no-reply@yxqi.cn", type=SettingsType.MAIL),
|
||||
|
||||
@@ -99,7 +99,7 @@ class LoginRequest(SQLModelBase):
|
||||
captcha: str | None = None
|
||||
"""验证码"""
|
||||
|
||||
two_fa_code: int | None = Field(min_length=6, max_length=6)
|
||||
two_fa_code: int | None = Field(default=None, min_length=6, max_length=6)
|
||||
"""两步验证代码"""
|
||||
|
||||
|
||||
@@ -151,6 +151,22 @@ class WebAuthnInfo(SQLModelBase):
|
||||
transports: list[str]
|
||||
"""支持的传输方式"""
|
||||
|
||||
class JWTPayload(SQLModelBase):
|
||||
"""JWT 访问令牌解析后的 claims"""
|
||||
|
||||
sub: UUID
|
||||
"""用户 ID"""
|
||||
|
||||
jti: UUID
|
||||
"""令牌唯一标识符"""
|
||||
|
||||
status: UserStatus
|
||||
"""用户状态"""
|
||||
|
||||
group: "GroupClaims"
|
||||
"""用户组权限快照"""
|
||||
|
||||
|
||||
class AccessTokenBase(BaseModel):
|
||||
"""访问令牌响应 DTO"""
|
||||
|
||||
@@ -370,10 +386,11 @@ class UserAdminDetailResponse(UserPublic):
|
||||
|
||||
|
||||
# 前向引用导入
|
||||
from .group import GroupResponse # noqa: E402
|
||||
from .group import GroupClaims, GroupResponse # noqa: E402
|
||||
from .user_authn import AuthnResponse # noqa: E402
|
||||
|
||||
# 更新前向引用
|
||||
JWTPayload.model_rebuild()
|
||||
UserResponse.model_rebuild()
|
||||
UserSettingResponse.model_rebuild()
|
||||
|
||||
|
||||
@@ -24,12 +24,12 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')
|
||||
|
||||
from main import app
|
||||
from sqlmodels.database import get_session
|
||||
from sqlmodels.group import Group, GroupOptions
|
||||
from sqlmodels.group import Group, GroupClaims, GroupOptions
|
||||
from sqlmodels.migration import migration
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
from sqlmodels.user import User
|
||||
from utils.JWT.JWT import create_access_token
|
||||
from sqlmodels.user import User, UserStatus
|
||||
from utils.JWT import create_access_token
|
||||
from utils.password.pwd import Password
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
email="testuser@test.local",
|
||||
nickname="测试用户",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=100,
|
||||
group_id=group.id,
|
||||
@@ -211,14 +211,24 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
)
|
||||
await root_folder.save(db_session)
|
||||
|
||||
# 构建权限快照
|
||||
group.options = group_options
|
||||
group_claims = GroupClaims.from_group(group)
|
||||
|
||||
# 生成访问令牌
|
||||
access_token, _ = create_access_token({"sub": str(user.id)})
|
||||
from uuid import uuid4
|
||||
access_token_obj = create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"password": password,
|
||||
"token": access_token,
|
||||
"token": access_token_obj.access_token,
|
||||
"group_id": group.id,
|
||||
"policy_id": policy.id,
|
||||
}
|
||||
@@ -270,7 +280,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
email="admin@disknext.local",
|
||||
nickname="管理员",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=9999,
|
||||
group_id=admin_group.id,
|
||||
@@ -288,14 +298,24 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
)
|
||||
await root_folder.save(db_session)
|
||||
|
||||
# 构建权限快照
|
||||
admin_group.options = admin_group_options
|
||||
admin_group_claims = GroupClaims.from_group(admin_group)
|
||||
|
||||
# 生成访问令牌
|
||||
access_token, _ = create_access_token({"sub": str(admin.id)})
|
||||
from uuid import uuid4
|
||||
access_token_obj = create_access_token(
|
||||
sub=admin.id,
|
||||
jti=uuid4(),
|
||||
status=admin.status.value,
|
||||
group=admin_group_claims,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": admin.id,
|
||||
"email": admin.email,
|
||||
"password": password,
|
||||
"token": access_token,
|
||||
"token": access_token_obj.access_token,
|
||||
"group_id": admin_group.id,
|
||||
"policy_id": policy.id,
|
||||
}
|
||||
|
||||
@@ -22,10 +22,11 @@ from sqlalchemy.orm import sessionmaker
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
||||
|
||||
from main import app
|
||||
from sqlmodels import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import Password
|
||||
from utils.JWT import create_access_token
|
||||
from utils.JWT import JWT
|
||||
import utils.JWT as JWT
|
||||
|
||||
|
||||
# ==================== 事件循环配置 ====================
|
||||
@@ -184,12 +185,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
email="testuser@test.local",
|
||||
password=Password.hash("testpass123"),
|
||||
nickname="测试用户",
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=0,
|
||||
group_id=default_group.id,
|
||||
avatar="default",
|
||||
theme="system",
|
||||
)
|
||||
test_session.add(test_user)
|
||||
|
||||
@@ -198,12 +198,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
email="admin@disknext.local",
|
||||
password=Password.hash("adminpass123"),
|
||||
nickname="管理员",
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=0,
|
||||
group_id=admin_group.id,
|
||||
avatar="default",
|
||||
theme="system",
|
||||
)
|
||||
test_session.add(admin_user)
|
||||
|
||||
@@ -212,12 +211,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("banned123"),
|
||||
nickname="封禁用户",
|
||||
status=False, # 封禁状态
|
||||
status=UserStatus.ADMIN_BANNED,
|
||||
storage=0,
|
||||
score=0,
|
||||
group_id=default_group.id,
|
||||
avatar="default",
|
||||
theme="system",
|
||||
)
|
||||
test_session.add(banned_user)
|
||||
|
||||
@@ -256,6 +254,10 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
# 8. 设置JWT密钥(从数据库加载)
|
||||
JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation"
|
||||
|
||||
# 刷新 group options
|
||||
await test_session.refresh(default_group_options)
|
||||
await test_session.refresh(admin_group_options)
|
||||
|
||||
return test_session
|
||||
|
||||
|
||||
@@ -290,34 +292,68 @@ def banned_user_info() -> dict[str, str]:
|
||||
|
||||
# ==================== JWT Token ====================
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_token(test_user_info: dict[str, str]) -> str:
|
||||
def _build_group_claims(group: Group, group_options: GroupOptions | None) -> GroupClaims:
|
||||
"""从 Group 对象构建 GroupClaims"""
|
||||
group.options = group_options
|
||||
return GroupClaims.from_group(group)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_token(initialized_db: AsyncSession) -> str:
|
||||
"""生成测试用户的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": test_user_info["email"]},
|
||||
user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||
group = await Group.get(initialized_db, Group.id == user.group_id)
|
||||
group_options = await GroupOptions.get(initialized_db, GroupOptions.group_id == group.id)
|
||||
group_claims = _build_group_claims(group, group_options)
|
||||
|
||||
result = create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
return token
|
||||
return result.access_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_token(admin_user_info: dict[str, str]) -> str:
|
||||
@pytest_asyncio.fixture
|
||||
async def admin_user_token(initialized_db: AsyncSession) -> str:
|
||||
"""生成管理员的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": admin_user_info["email"]},
|
||||
user = await User.get(initialized_db, User.email == "admin@disknext.local")
|
||||
group = await Group.get(initialized_db, Group.id == user.group_id)
|
||||
group_options = await GroupOptions.get(initialized_db, GroupOptions.group_id == group.id)
|
||||
group_claims = _build_group_claims(group, group_options)
|
||||
|
||||
result = create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
return token
|
||||
return result.access_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expired_token() -> str:
|
||||
"""生成过期的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": "testuser@test.local"},
|
||||
expires_delta=timedelta(seconds=-1), # 已过期
|
||||
group_claims = GroupClaims(
|
||||
id=uuid4(),
|
||||
name="测试组",
|
||||
max_storage=0,
|
||||
share_enabled=False,
|
||||
web_dav_enabled=False,
|
||||
admin=False,
|
||||
speed_limit=0,
|
||||
)
|
||||
return token
|
||||
result = create_access_token(
|
||||
sub=uuid4(),
|
||||
jti=uuid4(),
|
||||
status="active",
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(seconds=-1),
|
||||
)
|
||||
return result.access_token
|
||||
|
||||
|
||||
# ==================== 认证头 ====================
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""
|
||||
认证中间件集成测试
|
||||
"""
|
||||
from datetime import timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from datetime import timedelta
|
||||
|
||||
from utils.JWT import JWT
|
||||
from sqlmodels.group import GroupClaims
|
||||
from utils.JWT import create_access_token, create_refresh_token
|
||||
import utils.JWT as JWT
|
||||
|
||||
|
||||
# ==================== AuthRequired 测试 ====================
|
||||
@@ -66,11 +70,14 @@ async def test_auth_required_valid_token(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_token_without_sub(async_client: AsyncClient):
|
||||
"""测试缺少sub字段的token返回 401"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"other_field": "value"},
|
||||
expires_delta=timedelta(hours=1)
|
||||
)
|
||||
"""测试缺少必要字段的token返回 401"""
|
||||
import jwt as pyjwt
|
||||
# 手动构建一个缺少 status 和 group 的 token
|
||||
payload = {
|
||||
"other_field": "value",
|
||||
"exp": int((__import__('datetime').datetime.now(__import__('datetime').timezone.utc) + timedelta(hours=1)).timestamp()),
|
||||
}
|
||||
token = pyjwt.encode(payload, JWT.SECRET_KEY, algorithm="HS256")
|
||||
|
||||
response = await async_client.get(
|
||||
"/api/user/me",
|
||||
@@ -81,16 +88,29 @@ async def test_auth_required_token_without_sub(async_client: AsyncClient):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_nonexistent_user_token(async_client: AsyncClient):
|
||||
"""测试用户不存在的token返回 401"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": "nonexistent_user@test.local"},
|
||||
expires_delta=timedelta(hours=1)
|
||||
"""测试用户不存在的token返回 403 或 401(取决于 Redis 可用性)"""
|
||||
group_claims = GroupClaims(
|
||||
id=uuid4(),
|
||||
name="测试组",
|
||||
max_storage=0,
|
||||
share_enabled=False,
|
||||
web_dav_enabled=False,
|
||||
admin=False,
|
||||
speed_limit=0,
|
||||
)
|
||||
result = create_access_token(
|
||||
sub=uuid4(), # 不存在的用户 UUID
|
||||
jti=uuid4(),
|
||||
status="active",
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
|
||||
response = await async_client.get(
|
||||
"/api/user/me",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
headers={"Authorization": f"Bearer {result.access_token}"}
|
||||
)
|
||||
# auth_required 会查库,用户不存在时返回 401
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@@ -234,23 +254,36 @@ async def test_auth_on_storage_endpoint(
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_token_format(test_user_info: dict[str, str]):
|
||||
"""测试刷新token格式正确"""
|
||||
refresh_token, _ = JWT.create_refresh_token(
|
||||
data={"sub": test_user_info["email"]},
|
||||
expires_delta=timedelta(days=7)
|
||||
result = create_refresh_token(
|
||||
sub=uuid4(),
|
||||
jti=uuid4(),
|
||||
expires_delta=timedelta(days=7),
|
||||
)
|
||||
|
||||
assert isinstance(refresh_token, str)
|
||||
assert len(refresh_token) > 0
|
||||
assert isinstance(result.refresh_token, str)
|
||||
assert len(result.refresh_token) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_access_token_format(test_user_info: dict[str, str]):
|
||||
"""测试访问token格式正确"""
|
||||
access_token, expires = JWT.create_access_token(
|
||||
data={"sub": test_user_info["email"]},
|
||||
expires_delta=timedelta(hours=1)
|
||||
group_claims = GroupClaims(
|
||||
id=uuid4(),
|
||||
name="测试组",
|
||||
max_storage=0,
|
||||
share_enabled=False,
|
||||
web_dav_enabled=False,
|
||||
admin=False,
|
||||
speed_limit=0,
|
||||
)
|
||||
result = create_access_token(
|
||||
sub=uuid4(),
|
||||
jti=uuid4(),
|
||||
status="active",
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
|
||||
assert isinstance(access_token, str)
|
||||
assert len(access_token) > 0
|
||||
assert expires is not None
|
||||
assert isinstance(result.access_token, str)
|
||||
assert len(result.access_token) > 0
|
||||
assert result.access_expires is not None
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.user import User, ThemeType, UserPublic
|
||||
from sqlmodels.user import User, ThemeType, UserPublic, UserStatus
|
||||
from sqlmodels.group import Group
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ async def test_user_create(db_session: AsyncSession):
|
||||
assert user.id is not None
|
||||
assert user.email == "testuser@test.local"
|
||||
assert user.nickname == "测试用户"
|
||||
assert user.status is True
|
||||
assert user.status == UserStatus.ACTIVE
|
||||
assert user.storage == 0
|
||||
assert user.score == 0
|
||||
|
||||
@@ -131,7 +131,7 @@ async def test_user_status_default(db_session: AsyncSession):
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
assert user.status is True
|
||||
assert user.status == UserStatus.ACTIVE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -4,7 +4,7 @@ Login 服务的单元测试
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.user import User, LoginRequest, TokenResponse
|
||||
from sqlmodels.user import User, LoginRequest, TokenResponse, UserStatus
|
||||
from sqlmodels.group import Group
|
||||
from service.user.login import login
|
||||
from utils.password.pwd import Password
|
||||
@@ -22,7 +22,7 @@ async def setup_user(db_session: AsyncSession):
|
||||
user = User(
|
||||
email="loginuser@test.local",
|
||||
password=Password.hash(plain_password),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -43,7 +43,7 @@ async def setup_banned_user(db_session: AsyncSession):
|
||||
user = User(
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=False, # 封禁状态
|
||||
status=UserStatus.ADMIN_BANNED, # 封禁状态
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -63,7 +63,7 @@ async def setup_2fa_user(db_session: AsyncSession):
|
||||
user = User(
|
||||
email="2fauser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
two_factor=secret,
|
||||
group_id=group.id
|
||||
)
|
||||
|
||||
@@ -1,49 +1,86 @@
|
||||
"""
|
||||
JWT 工具的单元测试
|
||||
"""
|
||||
import time
|
||||
from datetime import timedelta, datetime, timezone
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
import jwt as pyjwt
|
||||
import pytest
|
||||
|
||||
from utils.JWT.JWT import create_access_token, create_refresh_token, SECRET_KEY
|
||||
from sqlmodels.group import GroupClaims
|
||||
from utils.JWT import create_access_token, create_refresh_token, build_token_payload
|
||||
|
||||
|
||||
# 测试用的 GroupClaims
|
||||
def _make_group_claims(admin: bool = False) -> GroupClaims:
|
||||
return GroupClaims(
|
||||
id=uuid4(),
|
||||
name="测试组",
|
||||
max_storage=1073741824,
|
||||
share_enabled=True,
|
||||
web_dav_enabled=False,
|
||||
admin=admin,
|
||||
speed_limit=0,
|
||||
)
|
||||
|
||||
|
||||
# 设置测试用的密钥
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_secret_key():
|
||||
"""为测试设置密钥"""
|
||||
import utils.JWT.JWT as jwt_module
|
||||
import utils.JWT as jwt_module
|
||||
jwt_module.SECRET_KEY = "test_secret_key_for_unit_tests"
|
||||
yield
|
||||
# 测试后恢复(虽然在单元测试中不太重要)
|
||||
|
||||
|
||||
def test_create_access_token():
|
||||
"""测试访问令牌创建"""
|
||||
data = {"sub": "testuser", "role": "user"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
token, expire_time = create_access_token(data)
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
assert isinstance(token, str)
|
||||
assert isinstance(expire_time, datetime)
|
||||
assert isinstance(result.access_token, str)
|
||||
assert isinstance(result.access_expires, datetime)
|
||||
|
||||
# 解码验证
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
assert decoded["sub"] == "testuser"
|
||||
assert decoded["role"] == "user"
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
assert decoded["sub"] == str(sub)
|
||||
assert decoded["jti"] == str(jti)
|
||||
assert decoded["status"] == "active"
|
||||
assert decoded["group"]["admin"] is False
|
||||
assert "exp" in decoded
|
||||
|
||||
|
||||
def test_create_access_token_custom_expiry():
|
||||
"""测试自定义过期时间"""
|
||||
data = {"sub": "testuser"}
|
||||
custom_expiry = timedelta(hours=1)
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
custom_expiry = timedelta(minutes=30)
|
||||
|
||||
token, expire_time = create_access_token(data, expires_delta=custom_expiry)
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group, expires_delta=custom_expiry)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
# 验证过期时间大约是30分钟后
|
||||
exp_timestamp = decoded["exp"]
|
||||
now_timestamp = datetime.now(timezone.utc).timestamp()
|
||||
|
||||
# 允许1秒误差
|
||||
assert abs(exp_timestamp - now_timestamp - 1800) < 1
|
||||
|
||||
|
||||
def test_create_access_token_default_expiry():
|
||||
"""测试访问令牌默认1小时过期"""
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
# 验证过期时间大约是1小时后
|
||||
exp_timestamp = decoded["exp"]
|
||||
@@ -55,27 +92,29 @@ def test_create_access_token_custom_expiry():
|
||||
|
||||
def test_create_refresh_token():
|
||||
"""测试刷新令牌创建"""
|
||||
data = {"sub": "testuser"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
|
||||
token, expire_time = create_refresh_token(data)
|
||||
result = create_refresh_token(sub=sub, jti=jti)
|
||||
|
||||
assert isinstance(token, str)
|
||||
assert isinstance(expire_time, datetime)
|
||||
assert isinstance(result.refresh_token, str)
|
||||
assert isinstance(result.refresh_expires, datetime)
|
||||
|
||||
# 解码验证
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
assert decoded["sub"] == "testuser"
|
||||
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
assert decoded["sub"] == str(sub)
|
||||
assert decoded["token_type"] == "refresh"
|
||||
assert "exp" in decoded
|
||||
|
||||
|
||||
def test_create_refresh_token_default_expiry():
|
||||
"""测试刷新令牌默认30天过期"""
|
||||
data = {"sub": "testuser"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
|
||||
token, expire_time = create_refresh_token(data)
|
||||
result = create_refresh_token(sub=sub, jti=jti)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
# 验证过期时间大约是30天后
|
||||
exp_timestamp = decoded["exp"]
|
||||
@@ -86,78 +125,72 @@ def test_create_refresh_token_default_expiry():
|
||||
assert abs(exp_timestamp - now_timestamp - 2592000) < 1
|
||||
|
||||
|
||||
def test_token_decode():
|
||||
"""测试令牌解码"""
|
||||
data = {"sub": "user123", "email": "user@example.com"}
|
||||
def test_access_token_contains_group_claims():
|
||||
"""测试访问令牌包含完整的 group claims"""
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims(admin=True)
|
||||
|
||||
token, _ = create_access_token(data)
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
# 解码
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
assert decoded["sub"] == "user123"
|
||||
assert decoded["email"] == "user@example.com"
|
||||
|
||||
|
||||
def test_token_expired():
|
||||
"""测试令牌过期"""
|
||||
data = {"sub": "testuser"}
|
||||
|
||||
# 创建一个立即过期的令牌
|
||||
token, _ = create_access_token(data, expires_delta=timedelta(seconds=-1))
|
||||
|
||||
# 尝试解码应该抛出过期异常
|
||||
with pytest.raises(pyjwt.ExpiredSignatureError):
|
||||
pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
|
||||
def test_token_invalid_signature():
|
||||
"""测试无效签名"""
|
||||
data = {"sub": "testuser"}
|
||||
|
||||
token, _ = create_access_token(data)
|
||||
|
||||
# 使用错误的密钥解码
|
||||
with pytest.raises(pyjwt.InvalidSignatureError):
|
||||
pyjwt.decode(token, "wrong_secret_key", algorithms=["HS256"])
|
||||
assert decoded["group"]["admin"] is True
|
||||
assert decoded["group"]["name"] == "测试组"
|
||||
assert decoded["group"]["max_storage"] == 1073741824
|
||||
assert decoded["group"]["share_enabled"] is True
|
||||
|
||||
|
||||
def test_access_token_does_not_have_token_type():
|
||||
"""测试访问令牌不包含 token_type"""
|
||||
data = {"sub": "testuser"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
token, _ = create_access_token(data)
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
assert "token_type" not in decoded
|
||||
|
||||
|
||||
def test_refresh_token_has_token_type():
|
||||
"""测试刷新令牌包含 token_type"""
|
||||
data = {"sub": "testuser"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
|
||||
token, _ = create_refresh_token(data)
|
||||
result = create_refresh_token(sub=sub, jti=jti)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
assert decoded["token_type"] == "refresh"
|
||||
|
||||
|
||||
def test_token_payload_preserved():
|
||||
"""测试自定义负载保留"""
|
||||
data = {
|
||||
"sub": "user123",
|
||||
"name": "Test User",
|
||||
"roles": ["admin", "user"],
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
def test_token_expired():
|
||||
"""测试令牌过期"""
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
token, _ = create_access_token(data)
|
||||
# 创建一个立即过期的令牌
|
||||
result = create_access_token(
|
||||
sub=sub, jti=jti, status="active", group=group,
|
||||
expires_delta=timedelta(seconds=-1),
|
||||
)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
# 尝试解码应该抛出过期异常
|
||||
with pytest.raises(pyjwt.ExpiredSignatureError):
|
||||
pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
assert decoded["sub"] == "user123"
|
||||
assert decoded["name"] == "Test User"
|
||||
assert decoded["roles"] == ["admin", "user"]
|
||||
assert decoded["metadata"] == {"key": "value"}
|
||||
|
||||
def test_token_invalid_signature():
|
||||
"""测试无效签名"""
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
# 使用错误的密钥解码
|
||||
with pytest.raises(pyjwt.InvalidSignatureError):
|
||||
pyjwt.decode(result.access_token, "wrong_secret_key", algorithms=["HS256"])
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import jwt
|
||||
@@ -6,6 +7,9 @@ from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
from sqlmodels import AccessTokenBase, RefreshTokenBase, TokenResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodels.group import GroupClaims
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
scheme_name='获取 JWT Bearer 令牌',
|
||||
description='用于获取 JWT Bearer 令牌,需要以表单的形式提交',
|
||||
@@ -59,7 +63,7 @@ def build_token_payload(
|
||||
elif is_refresh:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=3)
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
to_encode.update({
|
||||
"iat": int(datetime.now(timezone.utc).timestamp()),
|
||||
"exp": int(expire.timestamp())
|
||||
@@ -71,33 +75,36 @@ def build_token_payload(
|
||||
def create_access_token(
|
||||
sub: UUID,
|
||||
jti: UUID,
|
||||
*,
|
||||
status: str,
|
||||
group: "GroupClaims",
|
||||
expires_delta: timedelta | None = None,
|
||||
algorithm: str = "HS256",
|
||||
**kwargs
|
||||
) -> AccessTokenBase:
|
||||
"""
|
||||
生成访问令牌,默认有效期 3 小时。
|
||||
生成访问令牌,默认有效期 1 小时。
|
||||
|
||||
:param sub: 令牌的主题,通常是用户 ID。
|
||||
:param jti: 令牌的唯一标识符,通常是一个 UUID。
|
||||
:param expires_delta: 过期时间, 缺省时为 3 小时。
|
||||
:param status: 用户状态字符串。
|
||||
:param group: 用户组权限快照。
|
||||
:param expires_delta: 过期时间, 缺省时为 1 小时。
|
||||
:param algorithm: JWT 密钥强度,缺省时为 HS256
|
||||
:param kwargs: 需要放进 JWT Payload 的字段。
|
||||
|
||||
:return: 包含密钥本身和过期时间的 `AccessTokenBase`
|
||||
"""
|
||||
|
||||
data = {"sub": str(sub), "jti": str(jti)}
|
||||
|
||||
# 将额外的字段添加到 Payload 中
|
||||
for key, value in kwargs.items():
|
||||
data[key] = value
|
||||
data = {
|
||||
"sub": str(sub),
|
||||
"jti": str(jti),
|
||||
"status": status,
|
||||
"group": group.model_dump(mode="json"),
|
||||
}
|
||||
|
||||
access_token, expire_at = build_token_payload(
|
||||
data,
|
||||
False,
|
||||
algorithm,
|
||||
expires_delta
|
||||
expires_delta,
|
||||
)
|
||||
return AccessTokenBase(
|
||||
access_token=access_token,
|
||||
|
||||
Reference in New Issue
Block a user