feat: 更新验证码请求模型,添加 Google reCAPTCHA 和 Cloudflare Turnstile 验证功能

refactor: 修改用户状态字段类型,优化用户模型
fix: 修复启动服务的错误提示信息
refactor: 统一认证依赖,替换为 AuthRequired
docs: 添加用户会话刷新接口
This commit is contained in:
2025-12-25 10:26:45 +08:00
parent 16cec42181
commit 44a8959aa5
21 changed files with 138 additions and 83 deletions

View File

@@ -30,5 +30,5 @@ app.include_router(router)
# 防止直接运行 main.py
if __name__ == "__main__":
from loguru import logger
logger.error("请用 fastapi ['dev', 'main'] 命令启动服务")
logger.error("请用 fastapi ['dev', 'run'] 命令启动服务")
exit(1)

View File

@@ -259,7 +259,6 @@ async def init_default_user() -> None:
admin_user = User(
username="admin",
nickname="admin",
status=True,
group_id=admin_group.id,
password=hashed_admin_password,
)

View File

@@ -5,6 +5,12 @@ from sqlmodel import Field
from .base import SQLModelBase
class ResponseBase(SQLModelBase):
"""通用响应模型"""
instance_id: uuid.UUID = Field(default_factory=uuid.uuid4)
"""实例ID用于标识请求的唯一性"""
class MCPMethod(StrEnum):
"""MCP 方法枚举"""
@@ -30,10 +36,4 @@ class MCPResponseBase(MCPBase):
"""MCP 响应模型基础类"""
result: str
"""方法返回结果"""
class ResponseBase(SQLModelBase):
"""通用响应模型"""
instance_id: uuid.UUID = Field(default_factory=uuid.uuid4)
"""实例ID用于标识请求的唯一性"""
"""方法返回结果"""

View File

@@ -69,10 +69,10 @@ class ObjectBase(SQLModelBase):
"""对象名称(文件名或目录名)"""
type: ObjectType
"""对象类型file 或 folder"""
"""对象类型"""
size: int = 0
"""文件大小(字节),目录为 0"""
size: int | None = None
"""文件大小(字节),目录为 None"""
# ==================== DTO 模型 ====================
@@ -93,7 +93,7 @@ class DirectoryCreateRequest(SQLModelBase):
class ObjectMoveRequest(SQLModelBase):
"""移动对象请求 DTO"""
src_ids: list[UUID]
src_ids: UUID | list[UUID]
"""源对象UUID列表"""
dst_id: UUID
@@ -103,7 +103,7 @@ class ObjectMoveRequest(SQLModelBase):
class ObjectDeleteRequest(SQLModelBase):
"""删除对象请求 DTO"""
ids: list[UUID]
ids: UUID | list[UUID]
"""待删除对象UUID列表"""

View File

@@ -1,11 +1,16 @@
from typing import Literal
from enum import StrEnum
from sqlmodel import Field, UniqueConstraint
from sqlmodel import UniqueConstraint
from .base import SQLModelBase
from .mixin import TableBaseMixin
from .user import UserResponse
class CaptchaType(StrEnum):
"""验证码类型枚举"""
DEFAULT = "default"
GCAPTCHA = "gcaptcha"
CLOUD_FLARE_TURNSTILE = "cloudflare turnstile"
# ==================== DTO 模型 ====================
@@ -24,7 +29,7 @@ class SiteConfigResponse(SQLModelBase):
site_notice: str | None = None
"""网站公告"""
user: dict[str, str | int | bool] = {}
user: UserResponse
"""用户信息"""
logo_light: str | None = None
@@ -33,7 +38,7 @@ class SiteConfigResponse(SQLModelBase):
logo_dark: str | None = None
"""网站Logo URL深色模式"""
captcha_type: Literal["none", "default", "gcaptcha", "cloudflare turnstile"] = "none"
captcha_type: CaptchaType | None = None
"""验证码类型"""
captcha_key: str | None = None
@@ -104,6 +109,11 @@ class Setting(SQLModelBase, TableBaseMixin):
__table_args__ = (UniqueConstraint("type", "name", name="uq_setting_type_name"),)
type: SettingsType = Field(max_length=255, description="设置类型/分组")
name: str = Field(max_length=255, description="设置项名称")
value: str | None = Field(default=None, description="设置值")
type: SettingsType
"""设置类型/分组"""
name: str
"""设置项名称"""
value: str | None
"""设置值"""

View File

@@ -6,6 +6,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship
from .base import SQLModelBase
from .model_base import ResponseBase
from .mixin import UUIDTableBaseMixin
if TYPE_CHECKING:
@@ -110,7 +111,7 @@ class WebAuthnInfo(SQLModelBase):
"""支持的传输方式"""
class TokenResponse(SQLModelBase):
class TokenResponse(ResponseBase):
"""访问令牌响应 DTO"""
access_expires: datetime
@@ -126,7 +127,7 @@ class TokenResponse(SQLModelBase):
"""刷新令牌"""
class UserResponse(UserBase):
class UserResponse(ResponseBase):
"""用户响应 DTO"""
id: UUID
@@ -215,7 +216,7 @@ class UserAdminUpdateRequest(SQLModelBase):
group_id: UUID | None = None
"""用户组UUID"""
status: bool | None = None
status: UserStatus = UserStatus.ACTIVE
"""用户状态"""
score: int | None = Field(default=None, ge=0)
@@ -286,8 +287,8 @@ class User(UserBase, UUIDTableBaseMixin):
password: str = Field(max_length=255)
"""用户密码(加密后)"""
status: bool = Field(default=True, sa_column_kwargs={"server_default": "true"})
"""用户状态: True=正常, False=封禁"""
status: UserStatus = UserStatus.ACTIVE
"""用户状态"""
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
"""已用存储空间(字节)"""
@@ -350,7 +351,10 @@ class User(UserBase, UUIDTableBaseMixin):
)
objects: list["Object"] = Relationship(
back_populates="owner",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
sa_relationship_kwargs={
"cascade": "all, delete-orphan",
"foreign_keys": "[Object.owner_id]"
}
)
"""用户的所有对象(文件和目录)"""
orders: list["Order"] = Relationship(

View File

@@ -25,7 +25,7 @@ from models.setting import SettingsUpdateRequest, SettingsGetResponse
from models.object import AdminFileResponse, AdminFileListResponse, FileBanRequest
from models.policy import GroupPolicyLink
from service.storage import DirectoryCreationError, LocalStorageService
from service.password import Password
from utils import Password
class PolicyTestPathRequest(SQLModelBase):

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends, Query
from fastapi.responses import PlainTextResponse, RedirectResponse
from middleware.auth import SignRequired
from middleware.auth import AuthRequired
from models import ResponseBase
import service.oauth

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends
from middleware.auth import SignRequired
from middleware.auth import AuthRequired
from models import ResponseBase
download_router = APIRouter(
@@ -18,7 +18,7 @@ download_router.include_router(aria2_router)
path='/url',
summary='创建URL下载任务',
description='Create a URL download task endpoint.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_aria2_url() -> ResponseBase:
"""
@@ -33,7 +33,7 @@ def router_aria2_url() -> ResponseBase:
path='/torrent/{id}',
summary='创建种子下载任务',
description='Create a torrent download task endpoint.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_aria2_torrent(id: str) -> ResponseBase:
"""
@@ -51,7 +51,7 @@ def router_aria2_torrent(id: str) -> ResponseBase:
path='/select/{gid}',
summary='重新选择要下载的文件',
description='Re-select files to download endpoint.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_aria2_select(gid: str) -> ResponseBase:
"""
@@ -69,7 +69,7 @@ def router_aria2_select(gid: str) -> ResponseBase:
path='/task/{gid}',
summary='取消或删除下载任务',
description='Delete a download task endpoint.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_aria2_delete(gid: str) -> ResponseBase:
"""
@@ -87,7 +87,7 @@ def router_aria2_delete(gid: str) -> ResponseBase:
'/downloading',
summary='获取正在下载中的任务',
description='Get currently downloading tasks endpoint.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_aria2_downloading() -> ResponseBase:
"""
@@ -102,7 +102,7 @@ def router_aria2_downloading() -> ResponseBase:
path='/finished',
summary='获取已完成的任务',
description='Get finished tasks endpoint.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_aria2_finished() -> ResponseBase:
"""

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends
from middleware.auth import SignRequired
from middleware.auth import AuthRequired
from models import ResponseBase
share_router = APIRouter(
@@ -225,7 +225,7 @@ def router_share_search_public(keywords: str, type: str = 'all') -> ResponseBase
path='/',
summary='创建新分享',
description='Create a new share endpoint.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_share_create() -> ResponseBase:
"""
@@ -240,7 +240,7 @@ def router_share_create() -> ResponseBase:
path='/',
summary='列出我的分享',
description='Get a list of shares.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_share_list() -> ResponseBase:
"""
@@ -255,7 +255,7 @@ def router_share_list() -> ResponseBase:
path='/save/{id}',
summary='转存他人分享',
description='Save another user\'s share by ID.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_share_save(id: str) -> ResponseBase:
"""
@@ -273,7 +273,7 @@ def router_share_save(id: str) -> ResponseBase:
path='/{id}',
summary='更新分享信息',
description='Update share information by ID.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_share_update(id: str) -> ResponseBase:
"""
@@ -291,7 +291,7 @@ def router_share_update(id: str) -> ResponseBase:
path='/{id}',
summary='删除分享',
description='Delete a share by ID.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_share_delete(id: str) -> ResponseBase:
"""

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends
from fastapi.responses import FileResponse
from middleware.auth import SignRequired
from middleware.auth import AuthRequired
from models import ResponseBase
slave_router = APIRouter(
@@ -32,7 +32,7 @@ def router_slave_ping() -> ResponseBase:
path='/post',
summary='上传',
description='Upload data to the server.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_slave_post(data: str) -> ResponseBase:
"""
@@ -68,7 +68,7 @@ def router_slave_download(speed: int, path: str, name: str) -> ResponseBase:
path='/download/{sign}',
summary='根据签名下载文件',
description='Download a file based on its signature.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_slave_download_by_sign(sign: str) -> FileResponse:
"""
@@ -86,7 +86,7 @@ def router_slave_download_by_sign(sign: str) -> FileResponse:
path='/source/{speed}/{path}/{name}',
summary='获取文件外链',
description='Get the external link for a file based on its signature.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_slave_source(speed: int, path: str, name: str) -> ResponseBase:
"""
@@ -106,7 +106,7 @@ def router_slave_source(speed: int, path: str, name: str) -> ResponseBase:
path='/source/{sign}',
summary='根据签名获取文件',
description='Get a file based on its signature.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_slave_source_by_sign(sign: str) -> FileResponse:
"""
@@ -124,7 +124,7 @@ def router_slave_source_by_sign(sign: str) -> FileResponse:
path='/thumb/{id}',
summary='获取缩略图',
description='Get a thumbnail image based on its ID.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_slave_thumb(id: str) -> ResponseBase:
"""
@@ -142,7 +142,7 @@ def router_slave_thumb(id: str) -> ResponseBase:
path='/delete',
summary='删除文件',
description='Delete a file from the server.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_slave_delete(path: str) -> ResponseBase:
"""
@@ -160,7 +160,7 @@ def router_slave_delete(path: str) -> ResponseBase:
path='/test',
summary='测试从机连接Aria2服务',
description='Test the connection to the Aria2 service from the slave.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_slave_aria2_test() -> ResponseBase:
"""
@@ -172,7 +172,7 @@ def router_slave_aria2_test() -> ResponseBase:
path='/get/{gid}',
summary='获取Aria2任务信息',
description='Get information about an Aria2 task by its GID.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_slave_aria2_get(gid: str = None) -> ResponseBase:
"""
@@ -190,7 +190,7 @@ def router_slave_aria2_get(gid: str = None) -> ResponseBase:
path='/add',
summary='添加Aria2任务',
description='Add a new Aria2 task.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_slave_aria2_add(gid: str, url: str, options: dict = None) -> ResponseBase:
"""
@@ -210,7 +210,7 @@ def router_slave_aria2_add(gid: str, url: str, options: dict = None) -> Response
path='/remove/{gid}',
summary='删除Aria2任务',
description='Remove an Aria2 task by its GID.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_slave_aria2_remove(gid: str) -> ResponseBase:
"""

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends
from middleware.auth import SignRequired
from middleware.auth import AuthRequired
from models import ResponseBase
tag_router = APIRouter(
@@ -11,7 +11,7 @@ tag_router = APIRouter(
path='/filter',
summary='创建文件分类标签',
description='Create a file classification tag.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_tag_create_filter() -> ResponseBase:
"""
@@ -26,7 +26,7 @@ def router_tag_create_filter() -> ResponseBase:
path='/link',
summary='创建目录快捷方式标签',
description='Create a directory shortcut tag.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_tag_create_link() -> ResponseBase:
"""
@@ -41,7 +41,7 @@ def router_tag_create_link() -> ResponseBase:
path='/{id}',
summary='删除标签',
description='Delete a tag by its ID.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_tag_delete(id: str) -> ResponseBase:
"""

View File

@@ -88,6 +88,17 @@ async def router_user_session(
else:
raise HTTPException(status_code=500, detail="Internal server error during login")
@user_router.post(
path='/session/refresh',
summary="用刷新令牌刷新会话",
description="Refresh the user session using a refresh token."
)
async def router_user_session_refresh(
session: SessionDep,
request, # RefreshTokenRequest
) -> models.TokenResponse:
...
@user_router.post(
path='/',
summary='用户注册',

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends
from middleware.auth import SignRequired
from middleware.auth import AuthRequired
from models import ResponseBase
vas_router = APIRouter(
@@ -11,7 +11,7 @@ vas_router = APIRouter(
path='/pack',
summary='获取容量包及配额信息',
description='Get information about storage packs and quotas.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_vas_pack() -> ResponseBase:
"""
@@ -26,7 +26,7 @@ def router_vas_pack() -> ResponseBase:
path='/product',
summary='获取商品信息,同时返回支付信息',
description='Get product information along with payment details.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_vas_product() -> ResponseBase:
"""
@@ -41,7 +41,7 @@ def router_vas_product() -> ResponseBase:
path='/order',
summary='新建支付订单',
description='Create an order for a product.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_vas_order() -> ResponseBase:
"""
@@ -56,7 +56,7 @@ def router_vas_order() -> ResponseBase:
path='/order/{id}',
summary='查询订单状态',
description='Get information about a specific payment order by ID.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_vas_order_get(id: str) -> ResponseBase:
"""
@@ -74,7 +74,7 @@ def router_vas_order_get(id: str) -> ResponseBase:
path='/redeem',
summary='获取兑换码信息',
description='Get information about a specific redemption code.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_vas_redeem(code: str) -> ResponseBase:
"""
@@ -92,7 +92,7 @@ def router_vas_redeem(code: str) -> ResponseBase:
path='/redeem',
summary='执行兑换',
description='Redeem a redemption code for a product or service.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_vas_redeem_post() -> ResponseBase:
"""

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends, Request
from middleware.auth import SignRequired
from middleware.auth import AuthRequired
from models import ResponseBase
# WebDAV 管理路由
@@ -12,7 +12,7 @@ webdav_router = APIRouter(
path='/accounts',
summary='获取账号信息',
description='Get account information for WebDAV.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_webdav_accounts() -> ResponseBase:
"""
@@ -27,7 +27,7 @@ def router_webdav_accounts() -> ResponseBase:
path='/accounts',
summary='新建账号',
description='Create a new WebDAV account.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_webdav_create_account() -> ResponseBase:
"""
@@ -42,7 +42,7 @@ def router_webdav_create_account() -> ResponseBase:
path='/accounts/{id}',
summary='删除账号',
description='Delete a WebDAV account by its ID.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_webdav_delete_account(id: str) -> ResponseBase:
"""
@@ -60,7 +60,7 @@ def router_webdav_delete_account(id: str) -> ResponseBase:
path='/mount',
summary='新建目录挂载',
description='Create a new WebDAV mount point.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_webdav_create_mount() -> ResponseBase:
"""
@@ -75,7 +75,7 @@ def router_webdav_create_mount() -> ResponseBase:
path='/mount/{id}',
summary='删除目录挂载',
description='Delete a WebDAV mount point by its ID.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_webdav_delete_mount(id: str) -> ResponseBase:
"""
@@ -93,7 +93,7 @@ def router_webdav_delete_mount(id: str) -> ResponseBase:
path='accounts/{id}',
summary='更新账号信息',
description='Update WebDAV account information by ID.',
dependencies=[Depends(SignRequired)],
dependencies=[Depends(AuthRequired)],
)
def router_webdav_update_account(id: str) -> ResponseBase:
"""

View File

@@ -0,0 +1,5 @@
from pydantic import BaseModel
class CaptchaRequestBase(BaseModel):
token: str
secret: str

View File

@@ -1,6 +1,8 @@
import aiohttp
async def verify_captcha(token: str, secret_key: str) -> bool:
from . import CaptchaRequestBase
async def verify_captcha(request: CaptchaRequestBase) -> bool:
"""
验证 Google reCAPTCHA v2/v3 的 token 是否有效。
@@ -13,10 +15,7 @@ async def verify_captcha(token: str, secret_key: str) -> bool:
:rtype: bool
"""
verify_url = "https://www.google.com/recaptcha/api/siteverify"
payload = {
'secret': secret_key,
'response': token
}
payload = request.model_dump()
async with aiohttp.ClientSession() as session:
async with session.post(verify_url, data=payload) as response:

View File

@@ -0,0 +1,26 @@
import aiohttp
from . import CaptchaRequestBase
async def verify_captcha(request: CaptchaRequestBase) -> bool:
"""
验证 Turnstile 的 token 是否有效。
:param token: 用户提交的 Turnstile token
:type token: str
:param secret_key: Turnstile 的密钥
:type secret_key: str
:return: 如果验证成功返回 True否则返回 False
:rtype: bool
"""
verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
payload = request.model_dump()
async with aiohttp.ClientSession() as session:
async with session.post(verify_url, data=payload) as response:
if response.status != 200:
return False
result = await response.json()
return result.get('success', False)

View File

@@ -54,7 +54,6 @@ async def get_access_token(code: str) -> GithubAccessToken:
'code': code
},
headers={'accept': 'application/json'},
proxy='socks5://127.0.0.1:7890'
) as access_resp:
access_data = await access_resp.json()
return GithubAccessToken(
@@ -73,7 +72,6 @@ async def get_user_info(access_token: str | GithubAccessToken) -> GithubUserInfo
headers={
'accept': 'application/json',
'Authorization': f'token {access_token}'},
proxy='socks5://127.0.0.1:7890'
) as resp:
user_data = await resp.json()
return GithubUserInfoResponse(**user_data)

View File

@@ -7,6 +7,7 @@ oauth2_scheme = OAuth2PasswordBearer(
scheme_name='获取 JWT Bearer 令牌',
description='用于获取 JWT Bearer 令牌,需要以表单的形式提交',
tokenUrl="/api/v1/user/session",
refreshUrl="/api/v1/user/session/refresh",
)
SECRET_KEY = ''

View File

@@ -1,10 +1,12 @@
from fastapi import FastAPI
from typing import Callable
from contextlib import asynccontextmanager
__on_startup: list[callable] = []
__on_shutdown: list[callable] = []
from fastapi import FastAPI
def add_startup(func: callable):
__on_startup: list[Callable] = []
__on_shutdown: list[Callable] = []
def add_startup(func: Callable):
"""
注册一个函数,在应用启动时调用。
@@ -12,7 +14,7 @@ def add_startup(func: callable):
"""
__on_startup.append(func)
def add_shutdown(func: callable):
def add_shutdown(func: Callable):
"""
注册一个函数,在应用关闭时调用。