Refactor auth and unify error handling in routers

Renamed AuthRequired/AdminRequired to auth_required/admin_required and updated all references. Replaced direct HTTPException usage with utils.http_exceptions for consistent error handling. Updated router endpoints to use new auth dependency and standardized not implemented responses. Cleaned up unused theme fields in SiteConfigResponse and improved site config endpoint. Minor type and import cleanups across routers and middleware.
This commit is contained in:
2025-12-25 19:08:46 +08:00
parent 5835b4c626
commit abd85e2290
24 changed files with 347 additions and 391 deletions

25
.run/开发模式.run.xml Normal file
View File

@@ -0,0 +1,25 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="开发模式" type="Python.FastAPI">
<option name="appName" value="app" />
<option name="file" value="C:\Users\Administrator\Documents\Code\Server\main.py" />
<module name="Server" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="DEBUG" value="true" />
</envs>
<option name="SDK_HOME" value="$PROJECT_DIR$/.venv/Scripts/python.exe" />
<option name="SDK_NAME" value="uv (Server)" />
<option name="WORKING_DIRECTORY" value="" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="RUN_TOOL" value="" />
<option name="launchJavascriptDebuger" value="false" />
<method v="2">
<option name="LaunchBrowser.Before.Run" url="http://127.0.0.1:8000/docs" />
</method>
</configuration>
</component>

View File

@@ -1,20 +1,14 @@
from typing import Annotated
from fastapi import Depends, HTTPException
from jwt import InvalidTokenError
from fastapi import Depends
import jwt
from models.user import User
from utils.JWT import JWT
from .dependencies import SessionDep
from utils import http_exceptions
credentials_exception = HTTPException(
status_code=401,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
async def AuthRequired(
async def auth_required(
session: SessionDep,
token: Annotated[str, Depends(JWT.oauth2_scheme)],
) -> User:
@@ -26,28 +20,28 @@ async def AuthRequired(
username = payload.get("sub")
if username is None:
raise credentials_exception
http_exceptions.raise_unauthorized("账号或密码错误")
# 从数据库获取用户信息
user = await User.get(session, User.username == username)
if not user:
raise credentials_exception
http_exceptions.raise_unauthorized("账号或密码错误")
return user
except InvalidTokenError:
raise credentials_exception
except jwt.InvalidTokenError:
http_exceptions.raise_unauthorized("账号或密码错误")
async def AdminRequired(
user: Annotated[User, Depends(AuthRequired)],
async def admin_required(
user: Annotated[User, Depends(auth_required)],
) -> User:
"""
验证是否为管理员。
使用方法:
>>> APIRouter(dependencies=[Depends(AdminRequired)])
>>> APIRouter(dependencies=[Depends(admin_required)])
"""
group = await user.awaitable_attrs.group
if group.admin:
return user
raise HTTPException(status_code=403, detail="Admin Required")
raise http_exceptions.raise_forbidden("Admin Required")

View File

@@ -1,4 +1,4 @@
from typing import Annotated, AsyncGenerator
from typing import Annotated
from fastapi import Depends
from sqlmodel.ext.asyncio.session import AsyncSession

View File

@@ -20,11 +20,11 @@ class SiteConfigResponse(SQLModelBase):
title: str = "DiskNext"
"""网站标题"""
themes: dict[str, str] = {}
"""网站主题配置"""
# themes: dict[str, str] = {}
# """网站主题配置"""
default_theme: dict[str, str] = {}
"""默认主题RGB色号"""
# default_theme: dict[str, str] = {}
# """默认主题RGB色号"""
site_notice: str | None = None
"""网站公告"""

View File

@@ -24,16 +24,9 @@ from .webdav import webdav_router
router = APIRouter(prefix="/v1")
router.include_router(admin_router)
router.include_router(admin_aria2_router)
router.include_router(admin_file_router)
router.include_router(admin_group_router)
router.include_router(admin_policy_router)
router.include_router(admin_share_router)
router.include_router(admin_task_router)
router.include_router(admin_user_router)
router.include_router(admin_vas_router)
# [TODO] 如果是主机,导入下面的路由
router.include_router(admin_router)
router.include_router(callback_router)
router.include_router(directory_router)
router.include_router(download_router)
@@ -41,7 +34,9 @@ router.include_router(file_router)
router.include_router(object_router)
router.include_router(share_router)
router.include_router(site_router)
router.include_router(slave_router)
router.include_router(user_router)
router.include_router(vas_router)
router.include_router(webdav_router)
# [TODO] 如果是从机,导入下面的路由
router.include_router(slave_router)

View File

@@ -8,7 +8,7 @@ from loguru import logger as l
from sqlalchemy import func, and_
from sqlmodel import Field
from middleware.auth import AdminRequired
from middleware.auth import admin_required
from middleware.dependencies import SessionDep
from models import (
Policy, PolicyOptions, PolicyType, User, ResponseBase,
@@ -156,7 +156,7 @@ admin_vas_router = APIRouter(
path='/summary',
summary='获取站点概况',
description='Get site summary information',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
def router_admin_get_summary() -> ResponseBase:
"""
@@ -165,13 +165,13 @@ def router_admin_get_summary() -> ResponseBase:
Returns:
ResponseBase: 包含站点概况信息的响应模型。
"""
pass
http_exceptions.raise_not_implemented()
@admin_router.get(
path='/news',
summary='获取社区新闻',
description='Get community news',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
def router_admin_get_news() -> ResponseBase:
"""
@@ -180,13 +180,13 @@ def router_admin_get_news() -> ResponseBase:
Returns:
ResponseBase: 包含社区新闻信息的响应模型。
"""
pass
http_exceptions.raise_not_implemented()
@admin_router.patch(
path='/settings',
summary='更新设置',
description='Update settings',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_update_settings(
session: SessionDep,
@@ -225,7 +225,7 @@ async def router_admin_update_settings(
path='/settings',
summary='获取设置',
description='Get settings',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_get_settings(session: SessionDep) -> ResponseBase:
"""
@@ -249,7 +249,7 @@ async def router_admin_get_settings(session: SessionDep) -> ResponseBase:
path='/',
summary='获取用户组列表',
description='Get user group list',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_get_groups(
session: SessionDep,
@@ -314,7 +314,7 @@ async def router_admin_get_groups(
path='/{group_id}',
summary='获取用户组信息',
description='Get user group information by ID',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_get_group(
session: SessionDep,
@@ -366,7 +366,7 @@ async def router_admin_get_group(
path='/list/{group_id}',
summary='获取用户组成员列表',
description='Get user group member list by group ID',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_get_group_members(
session: SessionDep,
@@ -410,7 +410,7 @@ async def router_admin_get_group_members(
path='/',
summary='创建用户组',
description='Create a new user group',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_create_group(
session: SessionDep,
@@ -469,7 +469,7 @@ async def router_admin_create_group(
path='/{group_id}',
summary='更新用户组信息',
description='Update user group information by ID',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_update_group(
session: SessionDep,
@@ -539,7 +539,7 @@ async def router_admin_update_group(
path='/{group_id}',
summary='删除用户组',
description='Delete user group by ID',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_delete_group(
session: SessionDep,
@@ -576,7 +576,7 @@ async def router_admin_delete_group(
path='/info/{user_id}',
summary='获取用户信息',
description='Get user information by ID',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_get_user(session: SessionDep, user_id: int) -> ResponseBase:
"""
@@ -596,7 +596,7 @@ async def router_admin_get_user(session: SessionDep, user_id: int) -> ResponseBa
path='/list',
summary='获取用户列表',
description='Get user list',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_get_users(
session: SessionDep,
@@ -630,7 +630,7 @@ async def router_admin_get_users(
path='/create',
summary='创建用户',
description='Create a new user',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_create_user(
session: SessionDep,
@@ -655,7 +655,7 @@ async def router_admin_create_user(
path='/{user_id}',
summary='更新用户信息',
description='Update user information by ID',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_update_user(
session: SessionDep,
@@ -700,7 +700,7 @@ async def router_admin_update_user(
path='/{user_id}',
summary='删除用户',
description='Delete user by ID',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_delete_user(
session: SessionDep,
@@ -730,7 +730,7 @@ async def router_admin_delete_user(
path='/calibrate/{user_id}',
summary='校准用户存储容量',
description='Calibrate the user storage.',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_admin_calibrate_storage(
session: SessionDep,
@@ -784,7 +784,7 @@ async def router_admin_calibrate_storage(
path='/list',
summary='获取文件列表',
description='Get file list',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_get_file_list(
session: SessionDep,
@@ -858,7 +858,7 @@ async def router_admin_get_file_list(
path='/preview/{file_id}',
summary='预览文件',
description='Preview file by ID',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_preview_file(
session: SessionDep,
@@ -904,13 +904,13 @@ async def router_admin_preview_file(
path='/ban/{file_id}',
summary='封禁/解禁文件',
description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_ban_file(
session: SessionDep,
file_id: UUID,
request: FileBanRequest,
admin: Annotated[User, Depends(AdminRequired)],
admin: Annotated[User, Depends(admin_required)],
) -> ResponseBase:
"""
封禁或解禁文件。封禁后用户无法访问该文件。
@@ -949,7 +949,7 @@ async def router_admin_ban_file(
path='/{file_id}',
summary='删除文件',
description='Delete file by ID',
dependencies=[Depends(AdminRequired)],
dependencies=[Depends(admin_required)],
)
async def router_admin_delete_file(
session: SessionDep,
@@ -1002,7 +1002,7 @@ async def router_admin_delete_file(
path='/test',
summary='测试 Aria2 连接',
description='Test Aria2 RPC connection',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_admin_aira2_test(
request: Aria2TestRequest,
@@ -1050,7 +1050,7 @@ async def router_admin_aira2_test(
path='/list',
summary='列出存储策略',
description='List all storage policies',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_policy_list(
session: SessionDep,
@@ -1097,7 +1097,7 @@ async def router_policy_list(
path='/test/path',
summary='测试本地路径可用性',
description='Test local path availability',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_policy_test_path(
request: PolicyTestPathRequest,
@@ -1139,7 +1139,7 @@ async def router_policy_test_path(
path='/test/slave',
summary='测试从机通信',
description='Test slave node communication',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_policy_test_slave(
request: PolicyTestSlaveRequest,
@@ -1173,7 +1173,7 @@ async def router_policy_test_slave(
path='/',
summary='创建存储策略',
description='创建新的存储策略。对于本地存储策略,会自动创建物理目录。',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_policy_add_policy(
session: SessionDep,
@@ -1243,7 +1243,7 @@ async def router_policy_add_policy(
path='/cors',
summary='创建跨域策略',
description='Create CORS policy for S3 storage',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_policy_add_cors() -> ResponseBase:
"""
@@ -1259,7 +1259,7 @@ async def router_policy_add_cors() -> ResponseBase:
path='/scf',
summary='创建COS回调函数',
description='Create COS callback function',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_policy_add_scf() -> ResponseBase:
"""
@@ -1275,7 +1275,7 @@ async def router_policy_add_scf() -> ResponseBase:
path='/{policy_id}/oauth',
summary='获取 OneDrive OAuth URL',
description='Get OneDrive OAuth URL',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_policy_onddrive_oauth(
session: SessionDep,
@@ -1300,7 +1300,7 @@ async def router_policy_onddrive_oauth(
path='/{policy_id}',
summary='获取存储策略',
description='Get storage policy by ID',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_policy_get_policy(
session: SessionDep,
@@ -1346,7 +1346,7 @@ async def router_policy_get_policy(
path='/{policy_id}',
summary='删除存储策略',
description='Delete storage policy by ID',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_policy_delete_policy(
session: SessionDep,
@@ -1386,7 +1386,7 @@ async def router_policy_delete_policy(
path='/list',
summary='获取分享列表',
description='Get share list',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_admin_get_share_list(
session: SessionDep,
@@ -1443,7 +1443,7 @@ async def router_admin_get_share_list(
path='/{share_id}',
summary='获取分享详情',
description='Get share detail by ID',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_admin_get_share(
session: SessionDep,
@@ -1489,7 +1489,7 @@ async def router_admin_get_share(
path='/{share_id}',
summary='删除分享',
description='Delete share by ID',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_admin_delete_share(
session: SessionDep,
@@ -1518,7 +1518,7 @@ async def router_admin_delete_share(
path='/list',
summary='获取任务列表',
description='Get task list',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_admin_get_task_list(
session: SessionDep,
@@ -1580,7 +1580,7 @@ async def router_admin_get_task_list(
path='/{task_id}',
summary='获取任务详情',
description='Get task detail by ID',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_admin_get_task(
session: SessionDep,
@@ -1618,7 +1618,7 @@ async def router_admin_get_task(
path='/{task_id}',
summary='删除任务',
description='Delete task by ID',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_admin_delete_task(
session: SessionDep,
@@ -1647,7 +1647,7 @@ async def router_admin_delete_task(
path='/list',
summary='获取增值服务列表',
description='Get VAS list (orders and storage packs)',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_admin_get_vas_list(
session: SessionDep,
@@ -1673,7 +1673,7 @@ async def router_admin_get_vas_list(
path='/{vas_id}',
summary='获取增值服务详情',
description='Get VAS detail by ID',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_admin_get_vas(
session: SessionDep,
@@ -1694,7 +1694,7 @@ async def router_admin_get_vas(
path='/{vas_id}',
summary='删除增值服务',
description='Delete VAS by ID',
dependencies=[Depends(AdminRequired)]
dependencies=[Depends(admin_required)]
)
async def router_admin_delete_vas(
session: SessionDep,

View File

@@ -1,8 +1,9 @@
from fastapi import APIRouter, Depends, Query
from fastapi.responses import PlainTextResponse, RedirectResponse
from middleware.auth import AuthRequired
from fastapi import APIRouter, Query
from fastapi.responses import PlainTextResponse
from models import ResponseBase
import service.oauth
from utils import http_exceptions
callback_router = APIRouter(
prefix='/callback',
@@ -40,7 +41,7 @@ def router_callback_qq() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the QQ OAuth callback.
"""
pass
http_exceptions.raise_not_implemented()
@oauth_router.get(
path='/github',
@@ -86,7 +87,7 @@ def router_callback_alipay() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the Alipay payment callback.
"""
pass
http_exceptions.raise_not_implemented()
@pay_router.post(
path='/wechat',
@@ -100,7 +101,7 @@ def router_callback_wechat() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the WeChat Pay payment callback.
"""
pass
http_exceptions.raise_not_implemented()
@pay_router.post(
path='/stripe',
@@ -114,7 +115,7 @@ def router_callback_stripe() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the Stripe payment callback.
"""
pass
http_exceptions.raise_not_implemented()
@pay_router.get(
path='/easypay',
@@ -128,7 +129,7 @@ def router_callback_easypay() -> PlainTextResponse:
Returns:
PlainTextResponse: A response containing the payment status for the EasyPay payment callback.
"""
pass
http_exceptions.raise_not_implemented()
# return PlainTextResponse("success", status_code=200)
@pay_router.get(
@@ -147,7 +148,7 @@ def router_callback_custom(order_no: str, id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the custom payment callback.
"""
pass
http_exceptions.raise_not_implemented()
@upload_router.post(
path='/remote/{session_id}/{key}',
@@ -165,7 +166,7 @@ def router_callback_remote(session_id: str, key: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the remote upload callback.
"""
pass
http_exceptions.raise_not_implemented()
@upload_router.post(
path='/qiniu/{session_id}',
@@ -182,7 +183,7 @@ def router_callback_qiniu(session_id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the Qiniu Cloud upload callback.
"""
pass
http_exceptions.raise_not_implemented()
@upload_router.post(
path='/tencent/{session_id}',
@@ -199,7 +200,7 @@ def router_callback_tencent(session_id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the Tencent Cloud upload callback.
"""
pass
http_exceptions.raise_not_implemented()
@upload_router.post(
path='/aliyun/{session_id}',
@@ -216,7 +217,7 @@ def router_callback_aliyun(session_id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the Aliyun upload callback.
"""
pass
http_exceptions.raise_not_implemented()
@upload_router.post(
path='/upyun/{session_id}',
@@ -233,7 +234,7 @@ def router_callback_upyun(session_id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the Upyun upload callback.
"""
pass
http_exceptions.raise_not_implemented()
@upload_router.post(
path='/aws/{session_id}',
@@ -250,7 +251,7 @@ def router_callback_aws(session_id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the AWS S3 upload callback.
"""
pass
http_exceptions.raise_not_implemented()
@upload_router.post(
path='/onedrive/finish/{session_id}',
@@ -267,7 +268,7 @@ def router_callback_onedrive_finish(session_id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the OneDrive upload completion callback.
"""
pass
http_exceptions.raise_not_implemented()
@upload_router.get(
path='/ondrive/auth',
@@ -281,7 +282,7 @@ def router_callback_onedrive_auth() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the OneDrive authorization callback.
"""
pass
http_exceptions.raise_not_implemented()
@upload_router.get(
path='/google/auth',
@@ -295,4 +296,4 @@ def router_callback_google_auth() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the Google OAuth completion callback.
"""
pass
http_exceptions.raise_not_implemented()

View File

@@ -2,7 +2,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from middleware.auth import AuthRequired
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from models import (
DirectoryCreateRequest,
@@ -26,7 +26,7 @@ directory_router = APIRouter(
)
async def router_directory_get(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
path: str
) -> DirectoryResponse:
"""
@@ -94,7 +94,7 @@ async def router_directory_get(
)
async def router_directory_create(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
request: DirectoryCreateRequest
) -> ResponseBase:
"""

View File

@@ -1,6 +1,8 @@
from fastapi import APIRouter, Depends
from middleware.auth import AuthRequired
from middleware.auth import auth_required
from models import ResponseBase
from utils import http_exceptions
download_router = APIRouter(
prefix="/download",
@@ -18,7 +20,7 @@ download_router.include_router(aria2_router)
path='/url',
summary='创建URL下载任务',
description='Create a URL download task endpoint.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_aria2_url() -> ResponseBase:
"""
@@ -27,13 +29,13 @@ def router_aria2_url() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the URL download task.
"""
pass
http_exceptions.raise_not_implemented()
@aria2_router.post(
path='/torrent/{id}',
summary='创建种子下载任务',
description='Create a torrent download task endpoint.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_aria2_torrent(id: str) -> ResponseBase:
"""
@@ -45,13 +47,13 @@ def router_aria2_torrent(id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the torrent download task.
"""
pass
http_exceptions.raise_not_implemented()
@aria2_router.put(
path='/select/{gid}',
summary='重新选择要下载的文件',
description='Re-select files to download endpoint.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_aria2_select(gid: str) -> ResponseBase:
"""
@@ -63,13 +65,13 @@ def router_aria2_select(gid: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the re-selection of files.
"""
pass
http_exceptions.raise_not_implemented()
@aria2_router.delete(
path='/task/{gid}',
summary='取消或删除下载任务',
description='Delete a download task endpoint.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_aria2_delete(gid: str) -> ResponseBase:
"""
@@ -81,13 +83,13 @@ def router_aria2_delete(gid: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the deletion of the download task.
"""
pass
http_exceptions.raise_not_implemented()
@aria2_router.get(
'/downloading',
summary='获取正在下载中的任务',
description='Get currently downloading tasks endpoint.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_aria2_downloading() -> ResponseBase:
"""
@@ -96,13 +98,13 @@ def router_aria2_downloading() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for currently downloading tasks.
"""
pass
http_exceptions.raise_not_implemented()
@aria2_router.get(
path='/finished',
summary='获取已完成的任务',
description='Get finished tasks endpoint.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_aria2_finished() -> ResponseBase:
"""
@@ -111,4 +113,4 @@ def router_aria2_finished() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for finished tasks.
"""
pass
http_exceptions.raise_not_implemented()

View File

@@ -17,7 +17,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from fastapi.responses import FileResponse
from loguru import logger as l
from middleware.auth import AuthRequired
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from models import (
CreateFileRequest,
@@ -35,6 +35,7 @@ from models import (
)
from service.storage import LocalStorageService
from utils.JWT import SECRET_KEY
from utils import http_exceptions
# ==================== 下载令牌管理 ====================
@@ -88,7 +89,7 @@ _upload_router = APIRouter(prefix="/upload")
)
async def create_upload_session(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
request: CreateUploadSessionRequest,
) -> UploadSessionResponse:
"""
@@ -187,7 +188,7 @@ async def create_upload_session(
)
async def upload_chunk(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
session_id: UUID,
chunk_index: int,
file: UploadFile = File(...),
@@ -291,7 +292,7 @@ async def upload_chunk(
)
async def delete_upload_session(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
session_id: UUID,
) -> ResponseBase:
"""删除上传会话端点"""
@@ -320,7 +321,7 @@ async def delete_upload_session(
)
async def clear_upload_sessions(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
) -> ResponseBase:
"""清除所有上传会话端点"""
# 获取所有会话
@@ -368,7 +369,7 @@ _download_router = APIRouter(prefix="/download")
)
async def create_download_token(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
file_id: UUID,
) -> ResponseBase:
"""
@@ -456,7 +457,7 @@ router.include_router(_download_router)
)
async def create_empty_file(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
request: CreateFileRequest,
) -> ResponseBase:
"""创建空白文件端点"""
@@ -564,7 +565,7 @@ async def file_source_redirect(id: str, name: str) -> ResponseBase:
path='/update/{id}',
summary='更新文件',
description='更新文件内容。',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
async def file_update(id: str) -> ResponseBase:
"""更新文件内容"""
@@ -575,7 +576,7 @@ async def file_update(id: str) -> ResponseBase:
path='/preview/{id}',
summary='预览文件',
description='获取文件预览。',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
async def file_preview(id: str) -> ResponseBase:
"""预览文件"""
@@ -586,7 +587,7 @@ async def file_preview(id: str) -> ResponseBase:
path='/content/{id}',
summary='获取文本文件内容',
description='获取文本文件内容。',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
async def file_content(id: str) -> ResponseBase:
"""获取文本文件内容"""
@@ -597,7 +598,7 @@ async def file_content(id: str) -> ResponseBase:
path='/doc/{id}',
summary='获取Office文档预览地址',
description='获取Office文档在线预览地址。',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
async def file_doc(id: str) -> ResponseBase:
"""获取Office文档预览地址"""
@@ -608,7 +609,7 @@ async def file_doc(id: str) -> ResponseBase:
path='/thumb/{id}',
summary='获取文件缩略图',
description='获取文件缩略图。',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
async def file_thumb(id: str) -> ResponseBase:
"""获取文件缩略图"""
@@ -619,7 +620,7 @@ async def file_thumb(id: str) -> ResponseBase:
path='/source/{id}',
summary='取得文件外链',
description='获取文件的外链地址。',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
async def file_source(id: str) -> ResponseBase:
"""获取文件外链"""
@@ -630,7 +631,7 @@ async def file_source(id: str) -> ResponseBase:
path='/archive',
summary='打包要下载的文件',
description='将多个文件打包下载。',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
async def file_archive() -> ResponseBase:
"""打包文件"""
@@ -641,7 +642,7 @@ async def file_archive() -> ResponseBase:
path='/compress',
summary='创建文件压缩任务',
description='创建文件压缩任务。',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
async def file_compress() -> ResponseBase:
"""创建压缩任务"""
@@ -652,7 +653,7 @@ async def file_compress() -> ResponseBase:
path='/decompress',
summary='创建文件解压任务',
description='创建文件解压任务。',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
async def file_decompress() -> ResponseBase:
"""创建解压任务"""
@@ -663,7 +664,7 @@ async def file_decompress() -> ResponseBase:
path='/relocate',
summary='创建文件转移任务',
description='创建文件转移任务。',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
async def file_relocate() -> ResponseBase:
"""创建转移任务"""
@@ -674,7 +675,7 @@ async def file_relocate() -> ResponseBase:
path='/search/{type}/{keyword}',
summary='搜索文件',
description='按关键字搜索文件。',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
async def file_search(type: str, keyword: str) -> ResponseBase:
"""搜索文件"""

View File

@@ -12,7 +12,7 @@ from fastapi import APIRouter, Depends, HTTPException
from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession
from middleware.auth import AuthRequired
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from models import (
Object,
@@ -171,7 +171,7 @@ async def _copy_object_recursive(
)
async def router_object_delete(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
request: ObjectDeleteRequest,
) -> ResponseBase:
"""
@@ -224,7 +224,7 @@ async def router_object_delete(
)
async def router_object_move(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
request: ObjectMoveRequest,
) -> ResponseBase:
"""
@@ -302,7 +302,7 @@ async def router_object_move(
)
async def router_object_copy(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
request: ObjectCopyRequest,
) -> ResponseBase:
"""
@@ -394,7 +394,7 @@ async def router_object_copy(
)
async def router_object_rename(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
request: ObjectRenameRequest,
) -> ResponseBase:
"""
@@ -465,7 +465,7 @@ async def router_object_rename(
)
async def router_object_property(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
id: UUID,
) -> ObjectPropertyResponse:
"""
@@ -501,7 +501,7 @@ async def router_object_property(
)
async def router_object_property_detail(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
user: Annotated[User, Depends(auth_required)],
id: UUID,
) -> ObjectPropertyDetailResponse:
"""

View File

@@ -1,6 +1,8 @@
from fastapi import APIRouter, Depends
from middleware.auth import AuthRequired
from middleware.auth import auth_required
from models import ResponseBase
from utils import http_exceptions
share_router = APIRouter(
prefix='/share',
@@ -23,7 +25,7 @@ def router_share_get(info: str, id: str) -> ResponseBase:
Returns:
dict: A dictionary containing shared content information.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.put(
path='/download/{id}',
@@ -40,7 +42,7 @@ def router_share_download(id: str) -> ResponseBase:
Returns:
dict: A dictionary containing download session information.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.get(
path='preview/{id}',
@@ -57,7 +59,7 @@ def router_share_preview(id: str) -> ResponseBase:
Returns:
dict: A dictionary containing preview information.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.get(
path='/doc/{id}',
@@ -74,7 +76,7 @@ def router_share_doc(id: str) -> ResponseBase:
Returns:
dict: A dictionary containing the document preview URL.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.get(
path='/content/{id}',
@@ -91,7 +93,7 @@ def router_share_content(id: str) -> ResponseBase:
Returns:
str: The content of the text file.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.get(
path='/list/{id}/{path:path}',
@@ -109,7 +111,7 @@ def router_share_list(id: str, path: str = '') -> ResponseBase:
Returns:
dict: A dictionary containing directory listing information.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.get(
path='/search/{id}/{type}/{keywords}',
@@ -128,7 +130,7 @@ def router_share_search(id: str, type: str, keywords: str) -> ResponseBase:
Returns:
dict: A dictionary containing search results.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.post(
path='/archive/{id}',
@@ -145,7 +147,7 @@ def router_share_archive(id: str) -> ResponseBase:
Returns:
dict: A dictionary containing archive download information.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.get(
path='/readme/{id}',
@@ -162,7 +164,7 @@ def router_share_readme(id: str) -> ResponseBase:
Returns:
str: The content of the README file.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.get(
path='/thumb/{id}/{file}',
@@ -180,7 +182,7 @@ def router_share_thumb(id: str, file: str) -> ResponseBase:
Returns:
str: A Base64 encoded string of the thumbnail image.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.post(
path='/report/{id}',
@@ -197,7 +199,7 @@ def router_share_report(id: str) -> ResponseBase:
Returns:
dict: A dictionary containing report submission information.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.get(
path='/search',
@@ -215,7 +217,7 @@ def router_share_search_public(keywords: str, type: str = 'all') -> ResponseBase
Returns:
dict: A dictionary containing search results for public shares.
"""
pass
http_exceptions.raise_not_implemented()
#####################
# 需要登录的接口
@@ -225,7 +227,7 @@ def router_share_search_public(keywords: str, type: str = 'all') -> ResponseBase
path='/',
summary='创建新分享',
description='Create a new share endpoint.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_share_create() -> ResponseBase:
"""
@@ -234,13 +236,13 @@ def router_share_create() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the new share creation.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.get(
path='/',
summary='列出我的分享',
description='Get a list of shares.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_share_list() -> ResponseBase:
"""
@@ -249,13 +251,13 @@ def router_share_list() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the list of shares.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.post(
path='/save/{id}',
summary='转存他人分享',
description='Save another user\'s share by ID.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_share_save(id: str) -> ResponseBase:
"""
@@ -267,13 +269,13 @@ def router_share_save(id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the saved share.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.patch(
path='/{id}',
summary='更新分享信息',
description='Update share information by ID.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_share_update(id: str) -> ResponseBase:
"""
@@ -285,13 +287,13 @@ def router_share_update(id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the updated share.
"""
pass
http_exceptions.raise_not_implemented()
@share_router.delete(
path='/{id}',
summary='删除分享',
description='Delete a share by ID.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_share_delete(id: str) -> ResponseBase:
"""
@@ -303,4 +305,4 @@ def router_share_delete(id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the deleted share.
"""
pass
http_exceptions.raise_not_implemented()

View File

@@ -1,49 +1,29 @@
from fastapi import APIRouter
from sqlalchemy import and_
import json
from middleware.dependencies import SessionDep
from models import ResponseBase
from models.setting import Setting
from models import ResponseBase, Setting, SettingsType, SiteConfigResponse
from utils import http_exceptions
site_router = APIRouter(
prefix="/site",
tags=["site"],
)
async def _get_setting(session: SessionDep, type_: str, name: str) -> str | None:
"""获取设置值"""
setting = await Setting.get(session, and_(Setting.type == type_, Setting.name == name))
return setting.value if setting else None
async def _get_setting_bool(session: SessionDep, type_: str, name: str) -> bool:
"""获取布尔类型设置值"""
value = await _get_setting(session, type_, name)
return value == "1" if value else False
async def _get_setting_json(session: SessionDep, type_: str, name: str) -> dict | list | None:
"""获取 JSON 类型设置值"""
value = await _get_setting(session, type_, name)
return json.loads(value) if value else None
@site_router.get(
path="/ping",
summary="测试用路由",
description="A simple endpoint to check if the site is up and running.",
response_model=ResponseBase,
)
def router_site_ping():
def router_site_ping() -> ResponseBase:
"""
Ping the site to check if it is up and running.
Returns:
str: A message indicating the site is running.
"""
from utils.conf.appmeta import BackendVersion
return ResponseBase(data=BackendVersion)
return ResponseBase()
@site_router.get(
@@ -59,7 +39,7 @@ def router_site_captcha():
Returns:
str: A Base64 encoded string of the captcha image.
"""
pass
http_exceptions.raise_not_implemented()
@site_router.get(
@@ -68,38 +48,13 @@ def router_site_captcha():
description='Get the configuration file.',
response_model=ResponseBase,
)
async def router_site_config(session: SessionDep):
async def router_site_config(session: SessionDep) -> SiteConfigResponse:
"""
Get the configuration file.
Returns:
dict: The site configuration.
"""
return ResponseBase(
data={
"title": await _get_setting(session, "basic", "siteName"),
"loginCaptcha": await _get_setting_bool(session, "login", "login_captcha"),
"regCaptcha": await _get_setting_bool(session, "login", "reg_captcha"),
"forgetCaptcha": await _get_setting_bool(session, "login", "forget_captcha"),
"emailActive": await _get_setting_bool(session, "login", "email_active"),
"QQLogin": None,
"themes": await _get_setting_json(session, "basic", "themes"),
"defaultTheme": await _get_setting(session, "basic", "defaultTheme"),
"score_enabled": None,
"share_score_rate": None,
"home_view_method": await _get_setting(session, "view", "home_view_method"),
"share_view_method": await _get_setting(session, "view", "share_view_method"),
"authn": await _get_setting_bool(session, "authn", "authn_enabled"),
"user": {},
"captcha_type": None,
"captcha_ReCaptchaKey": await _get_setting(session, "captcha", "captcha_ReCaptchaKey"),
"captcha_CloudflareKey": await _get_setting(session, "captcha", "captcha_CloudflareKey"),
"captcha_tcaptcha_appid": None,
"site_notice": None,
"registerEnabled": await _get_setting_bool(session, "register", "register_enabled"),
"app_promotion": None,
"wopi_exts": None,
"app_feedback": None,
"app_forum": None,
}
return SiteConfigResponse(
title=await Setting.get(session, and_(Setting.type == SettingsType.BASIC, Setting.name == "siteName")),
)

View File

@@ -1,7 +1,9 @@
from fastapi import APIRouter, Depends
from fastapi.responses import FileResponse
from middleware.auth import AuthRequired
from middleware.auth import auth_required
from models import ResponseBase
from utils import http_exceptions
slave_router = APIRouter(
prefix="/slave",
@@ -32,7 +34,7 @@ def router_slave_ping() -> ResponseBase:
path='/post',
summary='上传',
description='Upload data to the server.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_slave_post(data: str) -> ResponseBase:
"""
@@ -44,7 +46,7 @@ def router_slave_post(data: str) -> ResponseBase:
Returns:
ResponseBase: A response model indicating success.
"""
pass
http_exceptions.raise_not_implemented()
@slave_router.get(
path='/get/{speed}/{path}/{name}',
@@ -62,13 +64,13 @@ def router_slave_download(speed: int, path: str, name: str) -> ResponseBase:
Returns:
ResponseBase: A response model containing download information.
"""
pass
http_exceptions.raise_not_implemented()
@slave_router.get(
path='/download/{sign}',
summary='根据签名下载文件',
description='Download a file based on its signature.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_slave_download_by_sign(sign: str) -> FileResponse:
"""
@@ -80,13 +82,13 @@ def router_slave_download_by_sign(sign: str) -> FileResponse:
Returns:
FileResponse: A response containing the file to be downloaded.
"""
pass
http_exceptions.raise_not_implemented()
@slave_router.get(
path='/source/{speed}/{path}/{name}',
summary='获取文件外链',
description='Get the external link for a file based on its signature.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_slave_source(speed: int, path: str, name: str) -> ResponseBase:
"""
@@ -100,13 +102,13 @@ def router_slave_source(speed: int, path: str, name: str) -> ResponseBase:
Returns:
ResponseBase: A response model containing the external link for the file.
"""
pass
http_exceptions.raise_not_implemented()
@slave_router.get(
path='/source/{sign}',
summary='根据签名获取文件',
description='Get a file based on its signature.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_slave_source_by_sign(sign: str) -> FileResponse:
"""
@@ -118,13 +120,13 @@ def router_slave_source_by_sign(sign: str) -> FileResponse:
Returns:
FileResponse: A response containing the file to be retrieved.
"""
pass
http_exceptions.raise_not_implemented()
@slave_router.get(
path='/thumb/{id}',
summary='获取缩略图',
description='Get a thumbnail image based on its ID.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_slave_thumb(id: str) -> ResponseBase:
"""
@@ -136,13 +138,13 @@ def router_slave_thumb(id: str) -> ResponseBase:
Returns:
ResponseBase: A response model containing the Base64 encoded thumbnail image.
"""
pass
http_exceptions.raise_not_implemented()
@slave_router.delete(
path='/delete',
summary='删除文件',
description='Delete a file from the server.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_slave_delete(path: str) -> ResponseBase:
"""
@@ -154,25 +156,25 @@ def router_slave_delete(path: str) -> ResponseBase:
Returns:
ResponseBase: A response model indicating success or failure of the deletion.
"""
pass
http_exceptions.raise_not_implemented()
@slave_aria2_router.post(
path='/test',
summary='测试从机连接Aria2服务',
description='Test the connection to the Aria2 service from the slave.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_slave_aria2_test() -> ResponseBase:
"""
Test the connection to the Aria2 service from the slave.
"""
pass
http_exceptions.raise_not_implemented()
@slave_aria2_router.get(
path='/get/{gid}',
summary='获取Aria2任务信息',
description='Get information about an Aria2 task by its GID.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_slave_aria2_get(gid: str = None) -> ResponseBase:
"""
@@ -184,13 +186,13 @@ def router_slave_aria2_get(gid: str = None) -> ResponseBase:
Returns:
ResponseBase: A response model containing the task information.
"""
pass
http_exceptions.raise_not_implemented()
@slave_aria2_router.post(
path='/add',
summary='添加Aria2任务',
description='Add a new Aria2 task.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_slave_aria2_add(gid: str, url: str, options: dict = None) -> ResponseBase:
"""
@@ -204,13 +206,13 @@ def router_slave_aria2_add(gid: str, url: str, options: dict = None) -> Response
Returns:
ResponseBase: A response model indicating success or failure of the task addition.
"""
pass
http_exceptions.raise_not_implemented()
@slave_aria2_router.delete(
path='/remove/{gid}',
summary='删除Aria2任务',
description='Remove an Aria2 task by its GID.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_slave_aria2_remove(gid: str) -> ResponseBase:
"""
@@ -222,4 +224,4 @@ def router_slave_aria2_remove(gid: str) -> ResponseBase:
Returns:
ResponseBase: A response model indicating success or failure of the task removal.
"""
pass
http_exceptions.raise_not_implemented()

View File

@@ -1,6 +1,8 @@
from fastapi import APIRouter, Depends
from middleware.auth import AuthRequired
from middleware.auth import auth_required
from models import ResponseBase
from utils import http_exceptions
tag_router = APIRouter(
prefix='/tag',
@@ -11,7 +13,7 @@ tag_router = APIRouter(
path='/filter',
summary='创建文件分类标签',
description='Create a file classification tag.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_tag_create_filter() -> ResponseBase:
"""
@@ -20,13 +22,13 @@ def router_tag_create_filter() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the created tag.
"""
pass
http_exceptions.raise_not_implemented()
@tag_router.post(
path='/link',
summary='创建目录快捷方式标签',
description='Create a directory shortcut tag.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_tag_create_link() -> ResponseBase:
"""
@@ -35,13 +37,13 @@ def router_tag_create_link() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the created tag.
"""
pass
http_exceptions.raise_not_implemented()
@tag_router.delete(
path='/{id}',
summary='删除标签',
description='Delete a tag by its ID.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_tag_delete(id: str) -> ResponseBase:
"""
@@ -53,4 +55,4 @@ def router_tag_delete(id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the deletion operation.
"""
pass
http_exceptions.raise_not_implemented()

View File

@@ -7,13 +7,14 @@ from sqlalchemy import and_
from webauthn import generate_registration_options
from webauthn.helpers import options_to_json_dict
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from loguru import logger
import models
import service
from middleware.auth import AuthRequired
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from utils.JWT.JWT import SECRET_KEY
from utils import Password
from utils import Password, http_exceptions
user_router = APIRouter(
prefix="/user",
@@ -23,7 +24,7 @@ user_router = APIRouter(
user_settings_router = APIRouter(
prefix='/user/settings',
tags=["user", "user_settings"],
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
@user_router.post(
@@ -42,11 +43,6 @@ async def router_user_session(
当用户启用两步验证时,需要在表单中传入 otp 参数(通过 scopes 字段传递)。
OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码
:raises HTTPException 401: 用户名或密码错误
:raises HTTPException 403: 用户账号被封禁或未完成注册
:raises HTTPException 428: 需要两步验证但未提供验证码
:raises HTTPException 400: 两步验证码无效
"""
username = form_data.username
password = form_data.password
@@ -62,7 +58,7 @@ async def router_user_session(
otp_code = scope
break
result = await service.user.Login(
result = await service.user.login(
session,
models.LoginRequest(
username=username,
@@ -71,22 +67,7 @@ async def router_user_session(
),
)
if isinstance(result, models.TokenResponse):
return result
elif result is None:
raise HTTPException(status_code=401, detail="Invalid username or password")
elif result is False:
raise HTTPException(status_code=403, detail="User account is banned or not fully registered")
elif result == "2fa_required":
raise HTTPException(
status_code=428,
detail="Two-factor authentication required",
headers={"X-2FA-Required": "true"},
)
elif result == "2fa_invalid":
raise HTTPException(status_code=400, detail="Invalid two-factor authentication code")
else:
raise HTTPException(status_code=500, detail="Internal server error during login")
return result
@user_router.post(
path='/session/refresh',
@@ -97,7 +78,7 @@ async def router_user_session_refresh(
session: SessionDep,
request, # RefreshTokenRequest
) -> models.TokenResponse:
...
http_exceptions.raise_not_implemented()
@user_router.post(
path='/',
@@ -137,12 +118,14 @@ async def router_user_register(
and_(models.Setting.type == models.SettingsType.REGISTER, models.Setting.name == "default_group")
)
if default_group_setting is None or not default_group_setting.value:
raise HTTPException(status_code=500, detail="默认用户组设置不存在")
logger.error("默认用户组不存在")
http_exceptions.raise_internal_error()
default_group_id = UUID(default_group_setting.value)
default_group = await models.Group.get(session, models.Group.id == default_group_id)
if not default_group:
raise HTTPException(status_code=500, detail="默认用户组不存在")
logger.error("默认用户组不存在")
http_exceptions.raise_internal_error()
# 3. 创建用户
hashed_password = Password.hash(request.password)
@@ -158,7 +141,8 @@ async def router_user_register(
# 4. 创建以用户名命名的根目录
default_policy = await models.Policy.get(session, models.Policy.name == "本地存储")
if not default_policy:
raise HTTPException(status_code=500, detail="默认存储策略不存在")
logger.error("默认存储策略不存在")
http_exceptions.raise_internal_error()
await models.Object(
name=new_user_username,
@@ -190,7 +174,7 @@ def router_user_email_code(
Returns:
dict: A dictionary containing information about the password reset email.
"""
pass
http_exceptions.raise_not_implemented()
@user_router.get(
path='/qq',
@@ -204,7 +188,7 @@ def router_user_qq() -> models.ResponseBase:
Returns:
dict: A dictionary containing QQ login initialization information.
"""
pass
http_exceptions.raise_not_implemented()
@user_router.get(
path='authn/{username}',
@@ -213,7 +197,7 @@ def router_user_qq() -> models.ResponseBase:
)
async def router_user_authn(username: str) -> models.ResponseBase:
pass
http_exceptions.raise_not_implemented()
@user_router.post(
path='authn/finish/{username}',
@@ -230,7 +214,7 @@ def router_user_authn_finish(username: str) -> models.ResponseBase:
Returns:
dict: A dictionary containing WebAuthn login information.
"""
pass
http_exceptions.raise_not_implemented()
@user_router.get(
path='/profile/{id}',
@@ -247,7 +231,7 @@ def router_user_profile(id: str) -> models.ResponseBase:
Returns:
dict: A dictionary containing user profile information.
"""
pass
http_exceptions.raise_not_implemented()
@user_router.get(
path='/avatar/{id}/{size}',
@@ -265,7 +249,7 @@ def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
Returns:
str: A Base64 encoded string of the user avatar image.
"""
pass
http_exceptions.raise_not_implemented()
#####################
# 需要登录的接口
@@ -275,12 +259,12 @@ def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
path='/me',
summary='获取用户信息',
description='Get user information.',
dependencies=[Depends(dependency=AuthRequired)],
dependencies=[Depends(dependency=auth_required)],
response_model=models.ResponseBase,
)
async def router_user_me(
session: SessionDep,
user: Annotated[models.User, Depends(AuthRequired)],
user: Annotated[models.User, Depends(auth_required)],
) -> models.ResponseBase:
"""
获取用户信息.
@@ -319,11 +303,11 @@ async def router_user_me(
path='/storage',
summary='存储信息',
description='Get user storage information.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
async def router_user_storage(
session: SessionDep,
user: Annotated[models.user.User, Depends(AuthRequired)],
user: Annotated[models.user.User, Depends(auth_required)],
) -> models.ResponseBase:
"""
获取用户存储空间信息。
@@ -353,11 +337,11 @@ async def router_user_storage(
path='/authn/start',
summary='WebAuthn登录初始化',
description='Initialize WebAuthn login for a user.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
async def router_user_authn_start(
session: SessionDep,
user: Annotated[models.user.User, Depends(AuthRequired)],
user: Annotated[models.user.User, Depends(auth_required)],
) -> models.ResponseBase:
"""
Initialize WebAuthn login for a user.
@@ -395,7 +379,7 @@ async def router_user_authn_start(
path='/authn/finish',
summary='WebAuthn登录',
description='Finish WebAuthn login for a user.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_user_authn_finish() -> models.ResponseBase:
"""
@@ -404,7 +388,7 @@ def router_user_authn_finish() -> models.ResponseBase:
Returns:
dict: A dictionary containing WebAuthn login information.
"""
pass
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/policies',
@@ -418,13 +402,13 @@ def router_user_settings_policies() -> models.ResponseBase:
Returns:
dict: A dictionary containing available storage policies for the user.
"""
pass
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/nodes',
summary='获取用户可选节点',
description='Get user selectable nodes.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_user_settings_nodes() -> models.ResponseBase:
"""
@@ -433,13 +417,13 @@ def router_user_settings_nodes() -> models.ResponseBase:
Returns:
dict: A dictionary containing available nodes for the user.
"""
pass
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/tasks',
summary='任务队列',
description='Get user task queue.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_user_settings_tasks() -> models.ResponseBase:
"""
@@ -448,13 +432,13 @@ def router_user_settings_tasks() -> models.ResponseBase:
Returns:
dict: A dictionary containing the user's task queue information.
"""
pass
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/',
summary='获取当前用户设定',
description='Get current user settings.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_user_settings() -> models.ResponseBase:
"""
@@ -469,7 +453,7 @@ def router_user_settings() -> models.ResponseBase:
path='/avatar',
summary='从文件上传头像',
description='Upload user avatar from file.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_user_settings_avatar() -> models.ResponseBase:
"""
@@ -478,13 +462,13 @@ def router_user_settings_avatar() -> models.ResponseBase:
Returns:
dict: A dictionary containing the result of the avatar upload.
"""
pass
http_exceptions.raise_not_implemented()
@user_settings_router.put(
path='/avatar',
summary='设定为Gravatar头像',
description='Set user avatar to Gravatar.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_user_settings_avatar_gravatar() -> models.ResponseBase:
"""
@@ -493,13 +477,13 @@ def router_user_settings_avatar_gravatar() -> models.ResponseBase:
Returns:
dict: A dictionary containing the result of setting the Gravatar avatar.
"""
pass
http_exceptions.raise_not_implemented()
@user_settings_router.patch(
path='/{option}',
summary='更新用户设定',
description='Update user settings.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_user_settings_patch(option: str) -> models.ResponseBase:
"""
@@ -511,16 +495,16 @@ def router_user_settings_patch(option: str) -> models.ResponseBase:
Returns:
dict: A dictionary containing the result of the settings update.
"""
pass
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/2fa',
summary='获取两步验证初始化信息',
description='Get two-factor authentication initialization information.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
async def router_user_settings_2fa(
user: Annotated[models.user.User, Depends(AuthRequired)],
user: Annotated[models.user.User, Depends(auth_required)],
) -> models.ResponseBase:
"""
Get two-factor authentication initialization information.
@@ -537,11 +521,11 @@ async def router_user_settings_2fa(
path='/2fa',
summary='启用两步验证',
description='Enable two-factor authentication.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
async def router_user_settings_2fa_enable(
session: SessionDep,
user: Annotated[models.user.User, Depends(AuthRequired)],
user: Annotated[models.user.User, Depends(auth_required)],
setup_token: str,
code: str,
) -> models.ResponseBase:

View File

@@ -1,6 +1,8 @@
from fastapi import APIRouter, Depends
from middleware.auth import AuthRequired
from middleware.auth import auth_required
from models import ResponseBase
from utils import http_exceptions
vas_router = APIRouter(
prefix="/vas",
@@ -11,7 +13,7 @@ vas_router = APIRouter(
path='/pack',
summary='获取容量包及配额信息',
description='Get information about storage packs and quotas.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_vas_pack() -> ResponseBase:
"""
@@ -20,13 +22,13 @@ def router_vas_pack() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for storage packs and quotas.
"""
pass
http_exceptions.raise_not_implemented()
@vas_router.get(
path='/product',
summary='获取商品信息,同时返回支付信息',
description='Get product information along with payment details.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_vas_product() -> ResponseBase:
"""
@@ -35,13 +37,13 @@ def router_vas_product() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for products and payment information.
"""
pass
http_exceptions.raise_not_implemented()
@vas_router.post(
path='/order',
summary='新建支付订单',
description='Create an order for a product.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_vas_order() -> ResponseBase:
"""
@@ -50,13 +52,13 @@ def router_vas_order() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the created order.
"""
pass
http_exceptions.raise_not_implemented()
@vas_router.get(
path='/order/{id}',
summary='查询订单状态',
description='Get information about a specific payment order by ID.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_vas_order_get(id: str) -> ResponseBase:
"""
@@ -68,13 +70,13 @@ def router_vas_order_get(id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the specified order.
"""
pass
http_exceptions.raise_not_implemented()
@vas_router.get(
path='/redeem',
summary='获取兑换码信息',
description='Get information about a specific redemption code.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_vas_redeem(code: str) -> ResponseBase:
"""
@@ -86,13 +88,13 @@ def router_vas_redeem(code: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the specified redemption code.
"""
pass
http_exceptions.raise_not_implemented()
@vas_router.post(
path='/redeem',
summary='执行兑换',
description='Redeem a redemption code for a product or service.',
dependencies=[Depends(AuthRequired)]
dependencies=[Depends(auth_required)]
)
def router_vas_redeem_post() -> ResponseBase:
"""
@@ -101,4 +103,4 @@ def router_vas_redeem_post() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the redeemed code.
"""
pass
http_exceptions.raise_not_implemented()

View File

@@ -1,6 +1,8 @@
from fastapi import APIRouter, Depends, Request
from middleware.auth import AuthRequired
from fastapi import APIRouter, Depends
from middleware.auth import auth_required
from models import ResponseBase
from utils import http_exceptions
# WebDAV 管理路由
webdav_router = APIRouter(
@@ -12,7 +14,7 @@ webdav_router = APIRouter(
path='/accounts',
summary='获取账号信息',
description='Get account information for WebDAV.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_webdav_accounts() -> ResponseBase:
"""
@@ -21,13 +23,13 @@ def router_webdav_accounts() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the account information.
"""
pass
http_exceptions.raise_not_implemented()
@webdav_router.post(
path='/accounts',
summary='新建账号',
description='Create a new WebDAV account.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_webdav_create_account() -> ResponseBase:
"""
@@ -36,13 +38,13 @@ def router_webdav_create_account() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the created account.
"""
pass
http_exceptions.raise_not_implemented()
@webdav_router.delete(
path='/accounts/{id}',
summary='删除账号',
description='Delete a WebDAV account by its ID.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_webdav_delete_account(id: str) -> ResponseBase:
"""
@@ -54,13 +56,13 @@ def router_webdav_delete_account(id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the deletion operation.
"""
pass
http_exceptions.raise_not_implemented()
@webdav_router.post(
path='/mount',
summary='新建目录挂载',
description='Create a new WebDAV mount point.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_webdav_create_mount() -> ResponseBase:
"""
@@ -69,13 +71,13 @@ def router_webdav_create_mount() -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the created mount point.
"""
pass
http_exceptions.raise_not_implemented()
@webdav_router.delete(
path='/mount/{id}',
summary='删除目录挂载',
description='Delete a WebDAV mount point by its ID.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_webdav_delete_mount(id: str) -> ResponseBase:
"""
@@ -87,13 +89,13 @@ def router_webdav_delete_mount(id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the deletion operation.
"""
pass
http_exceptions.raise_not_implemented()
@webdav_router.patch(
path='accounts/{id}',
summary='更新账号信息',
description='Update WebDAV account information by ID.',
dependencies=[Depends(AuthRequired)],
dependencies=[Depends(auth_required)],
)
def router_webdav_update_account(id: str) -> ResponseBase:
"""
@@ -105,4 +107,4 @@ def router_webdav_update_account(id: str) -> ResponseBase:
Returns:
ResponseBase: A model containing the response data for the updated account.
"""
pass
http_exceptions.raise_not_implemented()

View File

@@ -1 +1 @@
from .login import Login
from .login import login

View File

@@ -1,25 +1,19 @@
from typing import Literal
from loguru import logger as log
from sqlmodel.ext.asyncio.session import AsyncSession
from loguru import logger
from middleware.dependencies import SessionDep
from models import LoginRequest, TokenResponse, User
from utils import http_exceptions
from utils.JWT.JWT import create_access_token, create_refresh_token
from utils.password.pwd import Password, PasswordStatus
async def Login(
session: AsyncSession,
async def login(
session: SessionDep,
login_request: LoginRequest,
) -> TokenResponse | bool | Literal["2fa_required", "2fa_invalid"] | None:
) -> TokenResponse:
"""
根据账号密码进行登录。
如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。
如果登录异常,返回 `False`(未完成注册或账号被封禁)。
如果登录失败,返回 `None`。
如果需要两步验证但未提供验证码,返回 `"2fa_required"`。
如果两步验证码无效,返回 `"2fa_invalid"`。
:param session: 数据库会话
:param login_request: 登录请求
@@ -38,30 +32,29 @@ async def Login(
# 验证用户是否存在
if not current_user:
log.debug(f"Cannot find user with username: {login_request.username}")
return None
logger.debug(f"Cannot find user with username: {login_request.username}")
http_exceptions.raise_unauthorized("Invalid username or password")
# 验证密码是否正确
if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID:
log.debug(f"Password verification failed for user: {login_request.username}")
return None
logger.debug(f"Password verification failed for user: {login_request.username}")
http_exceptions.raise_unauthorized("Invalid username or password")
# 验证用户是否可登录
if not current_user.status:
# 未完成注册 or 账号已被封禁
return False
http_exceptions.raise_forbidden("Your account is disabled")
# 检查两步验证
if current_user.two_factor:
# 用户已启用两步验证
if not login_request.two_fa_code:
log.debug(f"2FA required for user: {login_request.username}")
return "2fa_required"
logger.debug(f"2FA required for user: {login_request.username}")
http_exceptions.raise_precondition_required("2FA required")
# 验证 OTP 码
if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID:
log.debug(f"Invalid 2FA code for user: {login_request.username}")
return "2fa_invalid"
logger.debug(f"Invalid 2FA code for user: {login_request.username}")
http_exceptions.raise_unauthorized("Invalid 2FA code")
# 创建令牌
access_token, access_expire = create_access_token(data={'sub': current_user.username})

View File

@@ -6,7 +6,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from models.user import User, LoginRequest, TokenResponse
from models.group import Group
from service.user.login import Login
from service.user.login import login
from utils.password.pwd import Password
@@ -86,7 +86,7 @@ async def test_login_success(db_session: AsyncSession, setup_user):
password=user_data["password"]
)
result = await Login(db_session, login_request)
result = await login(db_session, login_request)
assert isinstance(result, TokenResponse)
assert result.access_token is not None
@@ -103,7 +103,7 @@ async def test_login_user_not_found(db_session: AsyncSession):
password="any_password"
)
result = await Login(db_session, login_request)
result = await login(db_session, login_request)
assert result is None
@@ -116,7 +116,7 @@ async def test_login_wrong_password(db_session: AsyncSession, setup_user):
password="wrong_password"
)
result = await Login(db_session, login_request)
result = await login(db_session, login_request)
assert result is None
@@ -129,7 +129,7 @@ async def test_login_user_banned(db_session: AsyncSession, setup_banned_user):
password="password"
)
result = await Login(db_session, login_request)
result = await login(db_session, login_request)
assert result is False
@@ -145,7 +145,7 @@ async def test_login_2fa_required(db_session: AsyncSession, setup_2fa_user):
# 未提供 two_fa_code
)
result = await Login(db_session, login_request)
result = await login(db_session, login_request)
assert result == "2fa_required"
@@ -161,7 +161,7 @@ async def test_login_2fa_invalid(db_session: AsyncSession, setup_2fa_user):
two_fa_code="000000" # 错误的验证码
)
result = await Login(db_session, login_request)
result = await login(db_session, login_request)
assert result == "2fa_invalid"
@@ -184,7 +184,7 @@ async def test_login_2fa_success(db_session: AsyncSession, setup_2fa_user):
two_fa_code=valid_code
)
result = await Login(db_session, login_request)
result = await login(db_session, login_request)
assert isinstance(result, TokenResponse)
assert result.access_token is not None
@@ -202,7 +202,7 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
password=user_data["password"]
)
result = await Login(db_session, login_request)
result = await login(db_session, login_request)
assert isinstance(result, TokenResponse)
@@ -227,7 +227,7 @@ async def test_login_case_sensitive_username(db_session: AsyncSession, setup_use
password=user_data["password"]
)
result = await Login(db_session, login_request)
result = await login(db_session, login_request)
# 应该失败,因为用户名大小写不匹配
assert result is None

View File

@@ -1 +1,2 @@
from .password.pwd import Password, PasswordStatus
from .http import http_exceptions

View File

@@ -1,20 +1,6 @@
from typing import Any, NoReturn
from fastapi import HTTPException
from starlette.status import (
HTTP_400_BAD_REQUEST,
HTTP_401_UNAUTHORIZED,
HTTP_402_PAYMENT_REQUIRED,
HTTP_403_FORBIDDEN,
HTTP_404_NOT_FOUND,
HTTP_409_CONFLICT,
HTTP_429_TOO_MANY_REQUESTS,
HTTP_500_INTERNAL_SERVER_ERROR,
HTTP_501_NOT_IMPLEMENTED,
HTTP_503_SERVICE_UNAVAILABLE,
HTTP_504_GATEWAY_TIMEOUT,
)
from fastapi import HTTPException, status
# --- 400 ---
@@ -24,50 +10,54 @@ def ensure_request_param(to_check: Any, detail: str) -> None:
This function returns None if the check passes.
"""
if not to_check:
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)
def raise_bad_request(detail: str = '') -> NoReturn:
"""Raises an HTTP 400 Bad Request exception."""
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)
def raise_unauthorized(detail: str) -> NoReturn:
"""Raises an HTTP 401 Unauthorized exception."""
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=detail)
def raise_insufficient_quota(detail: str = "积分不足,请充值") -> NoReturn:
"""Raises an HTTP 402 Payment Required exception."""
raise HTTPException(status_code=HTTP_402_PAYMENT_REQUIRED, detail=detail)
raise HTTPException(status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=detail)
def raise_forbidden(detail: str) -> NoReturn:
"""Raises an HTTP 403 Forbidden exception."""
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def raise_not_found(detail: str) -> NoReturn:
"""Raises an HTTP 404 Not Found exception."""
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail)
def raise_conflict(detail: str) -> NoReturn:
"""Raises an HTTP 409 Conflict exception."""
raise HTTPException(status_code=HTTP_409_CONFLICT, detail=detail)
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=detail)
def raise_precondition_required(detail: str) -> NoReturn:
"""Raises an HTTP 428 Precondition required exception."""
raise HTTPException(status_code=status.HTTP_428_PRECONDITION_REQUIRED, detail=detail)
def raise_too_many_requests(detail: str) -> NoReturn:
"""Raises an HTTP 429 Too Many Requests exception."""
raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail=detail)
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=detail)
# --- 500 ---
def raise_internal_error(detail: str = "服务器出现故障,请稍后再试或联系管理员") -> NoReturn:
"""Raises an HTTP 500 Internal Server Error exception."""
raise HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
def raise_not_implemented(detail: str = "尚未支持这种方法") -> NoReturn:
"""Raises an HTTP 501 Not Implemented exception."""
raise HTTPException(status_code=HTTP_501_NOT_IMPLEMENTED, detail=detail)
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=detail)
def raise_service_unavailable(detail: str) -> NoReturn:
"""Raises an HTTP 503 Service Unavailable exception."""
raise HTTPException(status_code=HTTP_503_SERVICE_UNAVAILABLE, detail=detail)
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=detail)
def raise_gateway_timeout(detail: str) -> NoReturn:
"""Raises an HTTP 504 Gateway Timeout exception."""
raise HTTPException(status_code=HTTP_504_GATEWAY_TIMEOUT, detail=detail)
raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail=detail)

View File

@@ -1,4 +1,5 @@
import secrets
from loguru import logger
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
@@ -104,10 +105,11 @@ class Password:
@staticmethod
async def generate_totp(
username: str
*args, **kwargs
) -> TwoFactorResponse:
"""
生成 TOTP 密钥和对应的 URI用于两步验证。
所有的参数将会给到 `pyotp.totp.TOTP`
:return: 包含 TOTP 密钥和 URI 的元组
"""
@@ -121,8 +123,7 @@ class Password:
salt="2fa-setup-salt"
)
otp_uri = pyotp.totp.TOTP(secret).provisioning_uri(
name=username,
otp_uri = pyotp.totp.TOTP(secret, *args, **kwargs).provisioning_uri(
issuer_name=appmeta.APP_NAME
)
@@ -134,17 +135,21 @@ class Password:
@staticmethod
def verify_totp(
secret: str,
code: str
code: int,
*args, **kwargs
) -> PasswordStatus:
"""
验证 TOTP 验证码。
:param secret: TOTP 密钥Base32 编码)
:param code: 用户输入的 6 位验证码
:param args: 传入 `totp.verify` 的参数
:param kwargs: 传入 `totp.verify` 的参数
:return: 验证是否成功
"""
totp = pyotp.TOTP(secret)
if totp.verify(code):
if totp.verify(otp=str(code), *args, **kwargs):
return PasswordStatus.VALID
else:
return PasswordStatus.INVALID