Refactor auth and unify error handling in routers
Renamed AuthRequired/AdminRequired to auth_required/admin_required and updated all references. Replaced direct HTTPException usage with utils.http_exceptions for consistent error handling. Updated router endpoints to use new auth dependency and standardized not implemented responses. Cleaned up unused theme fields in SiteConfigResponse and improved site config endpoint. Minor type and import cleanups across routers and middleware.
This commit is contained in:
25
.run/开发模式.run.xml
Normal file
25
.run/开发模式.run.xml
Normal file
@@ -0,0 +1,25 @@
|
||||
<component name="ProjectRunConfigurationManager">
|
||||
<configuration default="false" name="开发模式" type="Python.FastAPI">
|
||||
<option name="appName" value="app" />
|
||||
<option name="file" value="C:\Users\Administrator\Documents\Code\Server\main.py" />
|
||||
<module name="Server" />
|
||||
<option name="ENV_FILES" value="" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="DEBUG" value="true" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="$PROJECT_DIR$/.venv/Scripts/python.exe" />
|
||||
<option name="SDK_NAME" value="uv (Server)" />
|
||||
<option name="WORKING_DIRECTORY" value="" />
|
||||
<option name="IS_MODULE_SDK" value="false" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="RUN_TOOL" value="" />
|
||||
<option name="launchJavascriptDebuger" value="false" />
|
||||
<method v="2">
|
||||
<option name="LaunchBrowser.Before.Run" url="http://127.0.0.1:8000/docs" />
|
||||
</method>
|
||||
</configuration>
|
||||
</component>
|
||||
@@ -1,20 +1,14 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from jwt import InvalidTokenError
|
||||
from fastapi import Depends
|
||||
import jwt
|
||||
|
||||
from models.user import User
|
||||
from utils.JWT import JWT
|
||||
from .dependencies import SessionDep
|
||||
from utils import http_exceptions
|
||||
|
||||
credentials_exception = HTTPException(
|
||||
status_code=401,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def AuthRequired(
|
||||
async def auth_required(
|
||||
session: SessionDep,
|
||||
token: Annotated[str, Depends(JWT.oauth2_scheme)],
|
||||
) -> User:
|
||||
@@ -26,28 +20,28 @@ async def AuthRequired(
|
||||
username = payload.get("sub")
|
||||
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
http_exceptions.raise_unauthorized("账号或密码错误")
|
||||
|
||||
# 从数据库获取用户信息
|
||||
user = await User.get(session, User.username == username)
|
||||
if not user:
|
||||
raise credentials_exception
|
||||
http_exceptions.raise_unauthorized("账号或密码错误")
|
||||
|
||||
return user
|
||||
|
||||
except InvalidTokenError:
|
||||
raise credentials_exception
|
||||
except jwt.InvalidTokenError:
|
||||
http_exceptions.raise_unauthorized("账号或密码错误")
|
||||
|
||||
async def AdminRequired(
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
async def admin_required(
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> User:
|
||||
"""
|
||||
验证是否为管理员。
|
||||
|
||||
使用方法:
|
||||
>>> APIRouter(dependencies=[Depends(AdminRequired)])
|
||||
>>> APIRouter(dependencies=[Depends(admin_required)])
|
||||
"""
|
||||
group = await user.awaitable_attrs.group
|
||||
if group.admin:
|
||||
return user
|
||||
raise HTTPException(status_code=403, detail="Admin Required")
|
||||
raise http_exceptions.raise_forbidden("Admin Required")
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Annotated, AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -20,11 +20,11 @@ class SiteConfigResponse(SQLModelBase):
|
||||
title: str = "DiskNext"
|
||||
"""网站标题"""
|
||||
|
||||
themes: dict[str, str] = {}
|
||||
"""网站主题配置"""
|
||||
# themes: dict[str, str] = {}
|
||||
# """网站主题配置"""
|
||||
|
||||
default_theme: dict[str, str] = {}
|
||||
"""默认主题RGB色号"""
|
||||
# default_theme: dict[str, str] = {}
|
||||
# """默认主题RGB色号"""
|
||||
|
||||
site_notice: str | None = None
|
||||
"""网站公告"""
|
||||
|
||||
@@ -24,16 +24,9 @@ from .webdav import webdav_router
|
||||
|
||||
router = APIRouter(prefix="/v1")
|
||||
|
||||
router.include_router(admin_router)
|
||||
router.include_router(admin_aria2_router)
|
||||
router.include_router(admin_file_router)
|
||||
router.include_router(admin_group_router)
|
||||
router.include_router(admin_policy_router)
|
||||
router.include_router(admin_share_router)
|
||||
router.include_router(admin_task_router)
|
||||
router.include_router(admin_user_router)
|
||||
router.include_router(admin_vas_router)
|
||||
# [TODO] 如果是主机,导入下面的路由
|
||||
|
||||
router.include_router(admin_router)
|
||||
router.include_router(callback_router)
|
||||
router.include_router(directory_router)
|
||||
router.include_router(download_router)
|
||||
@@ -41,7 +34,9 @@ router.include_router(file_router)
|
||||
router.include_router(object_router)
|
||||
router.include_router(share_router)
|
||||
router.include_router(site_router)
|
||||
router.include_router(slave_router)
|
||||
router.include_router(user_router)
|
||||
router.include_router(vas_router)
|
||||
router.include_router(webdav_router)
|
||||
|
||||
# [TODO] 如果是从机,导入下面的路由
|
||||
router.include_router(slave_router)
|
||||
@@ -8,7 +8,7 @@ from loguru import logger as l
|
||||
from sqlalchemy import func, and_
|
||||
from sqlmodel import Field
|
||||
|
||||
from middleware.auth import AdminRequired
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
Policy, PolicyOptions, PolicyType, User, ResponseBase,
|
||||
@@ -156,7 +156,7 @@ admin_vas_router = APIRouter(
|
||||
path='/summary',
|
||||
summary='获取站点概况',
|
||||
description='Get site summary information',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
def router_admin_get_summary() -> ResponseBase:
|
||||
"""
|
||||
@@ -165,13 +165,13 @@ def router_admin_get_summary() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: 包含站点概况信息的响应模型。
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@admin_router.get(
|
||||
path='/news',
|
||||
summary='获取社区新闻',
|
||||
description='Get community news',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
def router_admin_get_news() -> ResponseBase:
|
||||
"""
|
||||
@@ -180,13 +180,13 @@ def router_admin_get_news() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: 包含社区新闻信息的响应模型。
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@admin_router.patch(
|
||||
path='/settings',
|
||||
summary='更新设置',
|
||||
description='Update settings',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_update_settings(
|
||||
session: SessionDep,
|
||||
@@ -225,7 +225,7 @@ async def router_admin_update_settings(
|
||||
path='/settings',
|
||||
summary='获取设置',
|
||||
description='Get settings',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_get_settings(session: SessionDep) -> ResponseBase:
|
||||
"""
|
||||
@@ -249,7 +249,7 @@ async def router_admin_get_settings(session: SessionDep) -> ResponseBase:
|
||||
path='/',
|
||||
summary='获取用户组列表',
|
||||
description='Get user group list',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_get_groups(
|
||||
session: SessionDep,
|
||||
@@ -314,7 +314,7 @@ async def router_admin_get_groups(
|
||||
path='/{group_id}',
|
||||
summary='获取用户组信息',
|
||||
description='Get user group information by ID',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_get_group(
|
||||
session: SessionDep,
|
||||
@@ -366,7 +366,7 @@ async def router_admin_get_group(
|
||||
path='/list/{group_id}',
|
||||
summary='获取用户组成员列表',
|
||||
description='Get user group member list by group ID',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_get_group_members(
|
||||
session: SessionDep,
|
||||
@@ -410,7 +410,7 @@ async def router_admin_get_group_members(
|
||||
path='/',
|
||||
summary='创建用户组',
|
||||
description='Create a new user group',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_create_group(
|
||||
session: SessionDep,
|
||||
@@ -469,7 +469,7 @@ async def router_admin_create_group(
|
||||
path='/{group_id}',
|
||||
summary='更新用户组信息',
|
||||
description='Update user group information by ID',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_update_group(
|
||||
session: SessionDep,
|
||||
@@ -539,7 +539,7 @@ async def router_admin_update_group(
|
||||
path='/{group_id}',
|
||||
summary='删除用户组',
|
||||
description='Delete user group by ID',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_delete_group(
|
||||
session: SessionDep,
|
||||
@@ -576,7 +576,7 @@ async def router_admin_delete_group(
|
||||
path='/info/{user_id}',
|
||||
summary='获取用户信息',
|
||||
description='Get user information by ID',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_get_user(session: SessionDep, user_id: int) -> ResponseBase:
|
||||
"""
|
||||
@@ -596,7 +596,7 @@ async def router_admin_get_user(session: SessionDep, user_id: int) -> ResponseBa
|
||||
path='/list',
|
||||
summary='获取用户列表',
|
||||
description='Get user list',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_get_users(
|
||||
session: SessionDep,
|
||||
@@ -630,7 +630,7 @@ async def router_admin_get_users(
|
||||
path='/create',
|
||||
summary='创建用户',
|
||||
description='Create a new user',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_create_user(
|
||||
session: SessionDep,
|
||||
@@ -655,7 +655,7 @@ async def router_admin_create_user(
|
||||
path='/{user_id}',
|
||||
summary='更新用户信息',
|
||||
description='Update user information by ID',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_update_user(
|
||||
session: SessionDep,
|
||||
@@ -700,7 +700,7 @@ async def router_admin_update_user(
|
||||
path='/{user_id}',
|
||||
summary='删除用户',
|
||||
description='Delete user by ID',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_delete_user(
|
||||
session: SessionDep,
|
||||
@@ -730,7 +730,7 @@ async def router_admin_delete_user(
|
||||
path='/calibrate/{user_id}',
|
||||
summary='校准用户存储容量',
|
||||
description='Calibrate the user storage.',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_calibrate_storage(
|
||||
session: SessionDep,
|
||||
@@ -784,7 +784,7 @@ async def router_admin_calibrate_storage(
|
||||
path='/list',
|
||||
summary='获取文件列表',
|
||||
description='Get file list',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_get_file_list(
|
||||
session: SessionDep,
|
||||
@@ -858,7 +858,7 @@ async def router_admin_get_file_list(
|
||||
path='/preview/{file_id}',
|
||||
summary='预览文件',
|
||||
description='Preview file by ID',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_preview_file(
|
||||
session: SessionDep,
|
||||
@@ -904,13 +904,13 @@ async def router_admin_preview_file(
|
||||
path='/ban/{file_id}',
|
||||
summary='封禁/解禁文件',
|
||||
description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_ban_file(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
request: FileBanRequest,
|
||||
admin: Annotated[User, Depends(AdminRequired)],
|
||||
admin: Annotated[User, Depends(admin_required)],
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
封禁或解禁文件。封禁后用户无法访问该文件。
|
||||
@@ -949,7 +949,7 @@ async def router_admin_ban_file(
|
||||
path='/{file_id}',
|
||||
summary='删除文件',
|
||||
description='Delete file by ID',
|
||||
dependencies=[Depends(AdminRequired)],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_delete_file(
|
||||
session: SessionDep,
|
||||
@@ -1002,7 +1002,7 @@ async def router_admin_delete_file(
|
||||
path='/test',
|
||||
summary='测试 Aria2 连接',
|
||||
description='Test Aria2 RPC connection',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_aira2_test(
|
||||
request: Aria2TestRequest,
|
||||
@@ -1050,7 +1050,7 @@ async def router_admin_aira2_test(
|
||||
path='/list',
|
||||
summary='列出存储策略',
|
||||
description='List all storage policies',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_policy_list(
|
||||
session: SessionDep,
|
||||
@@ -1097,7 +1097,7 @@ async def router_policy_list(
|
||||
path='/test/path',
|
||||
summary='测试本地路径可用性',
|
||||
description='Test local path availability',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_policy_test_path(
|
||||
request: PolicyTestPathRequest,
|
||||
@@ -1139,7 +1139,7 @@ async def router_policy_test_path(
|
||||
path='/test/slave',
|
||||
summary='测试从机通信',
|
||||
description='Test slave node communication',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_policy_test_slave(
|
||||
request: PolicyTestSlaveRequest,
|
||||
@@ -1173,7 +1173,7 @@ async def router_policy_test_slave(
|
||||
path='/',
|
||||
summary='创建存储策略',
|
||||
description='创建新的存储策略。对于本地存储策略,会自动创建物理目录。',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_policy_add_policy(
|
||||
session: SessionDep,
|
||||
@@ -1243,7 +1243,7 @@ async def router_policy_add_policy(
|
||||
path='/cors',
|
||||
summary='创建跨域策略',
|
||||
description='Create CORS policy for S3 storage',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_policy_add_cors() -> ResponseBase:
|
||||
"""
|
||||
@@ -1259,7 +1259,7 @@ async def router_policy_add_cors() -> ResponseBase:
|
||||
path='/scf',
|
||||
summary='创建COS回调函数',
|
||||
description='Create COS callback function',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_policy_add_scf() -> ResponseBase:
|
||||
"""
|
||||
@@ -1275,7 +1275,7 @@ async def router_policy_add_scf() -> ResponseBase:
|
||||
path='/{policy_id}/oauth',
|
||||
summary='获取 OneDrive OAuth URL',
|
||||
description='Get OneDrive OAuth URL',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_policy_onddrive_oauth(
|
||||
session: SessionDep,
|
||||
@@ -1300,7 +1300,7 @@ async def router_policy_onddrive_oauth(
|
||||
path='/{policy_id}',
|
||||
summary='获取存储策略',
|
||||
description='Get storage policy by ID',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_policy_get_policy(
|
||||
session: SessionDep,
|
||||
@@ -1346,7 +1346,7 @@ async def router_policy_get_policy(
|
||||
path='/{policy_id}',
|
||||
summary='删除存储策略',
|
||||
description='Delete storage policy by ID',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_policy_delete_policy(
|
||||
session: SessionDep,
|
||||
@@ -1386,7 +1386,7 @@ async def router_policy_delete_policy(
|
||||
path='/list',
|
||||
summary='获取分享列表',
|
||||
description='Get share list',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_get_share_list(
|
||||
session: SessionDep,
|
||||
@@ -1443,7 +1443,7 @@ async def router_admin_get_share_list(
|
||||
path='/{share_id}',
|
||||
summary='获取分享详情',
|
||||
description='Get share detail by ID',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_get_share(
|
||||
session: SessionDep,
|
||||
@@ -1489,7 +1489,7 @@ async def router_admin_get_share(
|
||||
path='/{share_id}',
|
||||
summary='删除分享',
|
||||
description='Delete share by ID',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_delete_share(
|
||||
session: SessionDep,
|
||||
@@ -1518,7 +1518,7 @@ async def router_admin_delete_share(
|
||||
path='/list',
|
||||
summary='获取任务列表',
|
||||
description='Get task list',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_get_task_list(
|
||||
session: SessionDep,
|
||||
@@ -1580,7 +1580,7 @@ async def router_admin_get_task_list(
|
||||
path='/{task_id}',
|
||||
summary='获取任务详情',
|
||||
description='Get task detail by ID',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_get_task(
|
||||
session: SessionDep,
|
||||
@@ -1618,7 +1618,7 @@ async def router_admin_get_task(
|
||||
path='/{task_id}',
|
||||
summary='删除任务',
|
||||
description='Delete task by ID',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_delete_task(
|
||||
session: SessionDep,
|
||||
@@ -1647,7 +1647,7 @@ async def router_admin_delete_task(
|
||||
path='/list',
|
||||
summary='获取增值服务列表',
|
||||
description='Get VAS list (orders and storage packs)',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_get_vas_list(
|
||||
session: SessionDep,
|
||||
@@ -1673,7 +1673,7 @@ async def router_admin_get_vas_list(
|
||||
path='/{vas_id}',
|
||||
summary='获取增值服务详情',
|
||||
description='Get VAS detail by ID',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_get_vas(
|
||||
session: SessionDep,
|
||||
@@ -1694,7 +1694,7 @@ async def router_admin_get_vas(
|
||||
path='/{vas_id}',
|
||||
summary='删除增值服务',
|
||||
description='Delete VAS by ID',
|
||||
dependencies=[Depends(AdminRequired)]
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_delete_vas(
|
||||
session: SessionDep,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import PlainTextResponse, RedirectResponse
|
||||
from middleware.auth import AuthRequired
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
from models import ResponseBase
|
||||
import service.oauth
|
||||
from utils import http_exceptions
|
||||
|
||||
callback_router = APIRouter(
|
||||
prefix='/callback',
|
||||
@@ -40,7 +41,7 @@ def router_callback_qq() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the QQ OAuth callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@oauth_router.get(
|
||||
path='/github',
|
||||
@@ -86,7 +87,7 @@ def router_callback_alipay() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Alipay payment callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@pay_router.post(
|
||||
path='/wechat',
|
||||
@@ -100,7 +101,7 @@ def router_callback_wechat() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the WeChat Pay payment callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@pay_router.post(
|
||||
path='/stripe',
|
||||
@@ -114,7 +115,7 @@ def router_callback_stripe() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Stripe payment callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@pay_router.get(
|
||||
path='/easypay',
|
||||
@@ -128,7 +129,7 @@ def router_callback_easypay() -> PlainTextResponse:
|
||||
Returns:
|
||||
PlainTextResponse: A response containing the payment status for the EasyPay payment callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
# return PlainTextResponse("success", status_code=200)
|
||||
|
||||
@pay_router.get(
|
||||
@@ -147,7 +148,7 @@ def router_callback_custom(order_no: str, id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the custom payment callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@upload_router.post(
|
||||
path='/remote/{session_id}/{key}',
|
||||
@@ -165,7 +166,7 @@ def router_callback_remote(session_id: str, key: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the remote upload callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@upload_router.post(
|
||||
path='/qiniu/{session_id}',
|
||||
@@ -182,7 +183,7 @@ def router_callback_qiniu(session_id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Qiniu Cloud upload callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@upload_router.post(
|
||||
path='/tencent/{session_id}',
|
||||
@@ -199,7 +200,7 @@ def router_callback_tencent(session_id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Tencent Cloud upload callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@upload_router.post(
|
||||
path='/aliyun/{session_id}',
|
||||
@@ -216,7 +217,7 @@ def router_callback_aliyun(session_id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Aliyun upload callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@upload_router.post(
|
||||
path='/upyun/{session_id}',
|
||||
@@ -233,7 +234,7 @@ def router_callback_upyun(session_id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Upyun upload callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@upload_router.post(
|
||||
path='/aws/{session_id}',
|
||||
@@ -250,7 +251,7 @@ def router_callback_aws(session_id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the AWS S3 upload callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@upload_router.post(
|
||||
path='/onedrive/finish/{session_id}',
|
||||
@@ -267,7 +268,7 @@ def router_callback_onedrive_finish(session_id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the OneDrive upload completion callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@upload_router.get(
|
||||
path='/ondrive/auth',
|
||||
@@ -281,7 +282,7 @@ def router_callback_onedrive_auth() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the OneDrive authorization callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@upload_router.get(
|
||||
path='/google/auth',
|
||||
@@ -295,4 +296,4 @@ def router_callback_google_auth() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Google OAuth completion callback.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
@@ -2,7 +2,7 @@ from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from middleware.auth import AuthRequired
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
DirectoryCreateRequest,
|
||||
@@ -26,7 +26,7 @@ directory_router = APIRouter(
|
||||
)
|
||||
async def router_directory_get(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
path: str
|
||||
) -> DirectoryResponse:
|
||||
"""
|
||||
@@ -94,7 +94,7 @@ async def router_directory_get(
|
||||
)
|
||||
async def router_directory_create(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: DirectoryCreateRequest
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from middleware.auth import AuthRequired
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
download_router = APIRouter(
|
||||
prefix="/download",
|
||||
@@ -18,7 +20,7 @@ download_router.include_router(aria2_router)
|
||||
path='/url',
|
||||
summary='创建URL下载任务',
|
||||
description='Create a URL download task endpoint.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_aria2_url() -> ResponseBase:
|
||||
"""
|
||||
@@ -27,13 +29,13 @@ def router_aria2_url() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the URL download task.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@aria2_router.post(
|
||||
path='/torrent/{id}',
|
||||
summary='创建种子下载任务',
|
||||
description='Create a torrent download task endpoint.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_aria2_torrent(id: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -45,13 +47,13 @@ def router_aria2_torrent(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the torrent download task.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@aria2_router.put(
|
||||
path='/select/{gid}',
|
||||
summary='重新选择要下载的文件',
|
||||
description='Re-select files to download endpoint.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_aria2_select(gid: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -63,13 +65,13 @@ def router_aria2_select(gid: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the re-selection of files.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@aria2_router.delete(
|
||||
path='/task/{gid}',
|
||||
summary='取消或删除下载任务',
|
||||
description='Delete a download task endpoint.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_aria2_delete(gid: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -81,13 +83,13 @@ def router_aria2_delete(gid: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the deletion of the download task.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@aria2_router.get(
|
||||
'/downloading',
|
||||
summary='获取正在下载中的任务',
|
||||
description='Get currently downloading tasks endpoint.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_aria2_downloading() -> ResponseBase:
|
||||
"""
|
||||
@@ -96,13 +98,13 @@ def router_aria2_downloading() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for currently downloading tasks.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@aria2_router.get(
|
||||
path='/finished',
|
||||
summary='获取已完成的任务',
|
||||
description='Get finished tasks endpoint.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_aria2_finished() -> ResponseBase:
|
||||
"""
|
||||
@@ -111,4 +113,4 @@ def router_aria2_finished() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for finished tasks.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
@@ -17,7 +17,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger as l
|
||||
|
||||
from middleware.auth import AuthRequired
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
CreateFileRequest,
|
||||
@@ -35,6 +35,7 @@ from models import (
|
||||
)
|
||||
from service.storage import LocalStorageService
|
||||
from utils.JWT import SECRET_KEY
|
||||
from utils import http_exceptions
|
||||
|
||||
|
||||
# ==================== 下载令牌管理 ====================
|
||||
@@ -88,7 +89,7 @@ _upload_router = APIRouter(prefix="/upload")
|
||||
)
|
||||
async def create_upload_session(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: CreateUploadSessionRequest,
|
||||
) -> UploadSessionResponse:
|
||||
"""
|
||||
@@ -187,7 +188,7 @@ async def create_upload_session(
|
||||
)
|
||||
async def upload_chunk(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
session_id: UUID,
|
||||
chunk_index: int,
|
||||
file: UploadFile = File(...),
|
||||
@@ -291,7 +292,7 @@ async def upload_chunk(
|
||||
)
|
||||
async def delete_upload_session(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
session_id: UUID,
|
||||
) -> ResponseBase:
|
||||
"""删除上传会话端点"""
|
||||
@@ -320,7 +321,7 @@ async def delete_upload_session(
|
||||
)
|
||||
async def clear_upload_sessions(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> ResponseBase:
|
||||
"""清除所有上传会话端点"""
|
||||
# 获取所有会话
|
||||
@@ -368,7 +369,7 @@ _download_router = APIRouter(prefix="/download")
|
||||
)
|
||||
async def create_download_token(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
file_id: UUID,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
@@ -456,7 +457,7 @@ router.include_router(_download_router)
|
||||
)
|
||||
async def create_empty_file(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: CreateFileRequest,
|
||||
) -> ResponseBase:
|
||||
"""创建空白文件端点"""
|
||||
@@ -564,7 +565,7 @@ async def file_source_redirect(id: str, name: str) -> ResponseBase:
|
||||
path='/update/{id}',
|
||||
summary='更新文件',
|
||||
description='更新文件内容。',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
async def file_update(id: str) -> ResponseBase:
|
||||
"""更新文件内容"""
|
||||
@@ -575,7 +576,7 @@ async def file_update(id: str) -> ResponseBase:
|
||||
path='/preview/{id}',
|
||||
summary='预览文件',
|
||||
description='获取文件预览。',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
async def file_preview(id: str) -> ResponseBase:
|
||||
"""预览文件"""
|
||||
@@ -586,7 +587,7 @@ async def file_preview(id: str) -> ResponseBase:
|
||||
path='/content/{id}',
|
||||
summary='获取文本文件内容',
|
||||
description='获取文本文件内容。',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
async def file_content(id: str) -> ResponseBase:
|
||||
"""获取文本文件内容"""
|
||||
@@ -597,7 +598,7 @@ async def file_content(id: str) -> ResponseBase:
|
||||
path='/doc/{id}',
|
||||
summary='获取Office文档预览地址',
|
||||
description='获取Office文档在线预览地址。',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
async def file_doc(id: str) -> ResponseBase:
|
||||
"""获取Office文档预览地址"""
|
||||
@@ -608,7 +609,7 @@ async def file_doc(id: str) -> ResponseBase:
|
||||
path='/thumb/{id}',
|
||||
summary='获取文件缩略图',
|
||||
description='获取文件缩略图。',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
async def file_thumb(id: str) -> ResponseBase:
|
||||
"""获取文件缩略图"""
|
||||
@@ -619,7 +620,7 @@ async def file_thumb(id: str) -> ResponseBase:
|
||||
path='/source/{id}',
|
||||
summary='取得文件外链',
|
||||
description='获取文件的外链地址。',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
async def file_source(id: str) -> ResponseBase:
|
||||
"""获取文件外链"""
|
||||
@@ -630,7 +631,7 @@ async def file_source(id: str) -> ResponseBase:
|
||||
path='/archive',
|
||||
summary='打包要下载的文件',
|
||||
description='将多个文件打包下载。',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
async def file_archive() -> ResponseBase:
|
||||
"""打包文件"""
|
||||
@@ -641,7 +642,7 @@ async def file_archive() -> ResponseBase:
|
||||
path='/compress',
|
||||
summary='创建文件压缩任务',
|
||||
description='创建文件压缩任务。',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
async def file_compress() -> ResponseBase:
|
||||
"""创建压缩任务"""
|
||||
@@ -652,7 +653,7 @@ async def file_compress() -> ResponseBase:
|
||||
path='/decompress',
|
||||
summary='创建文件解压任务',
|
||||
description='创建文件解压任务。',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
async def file_decompress() -> ResponseBase:
|
||||
"""创建解压任务"""
|
||||
@@ -663,7 +664,7 @@ async def file_decompress() -> ResponseBase:
|
||||
path='/relocate',
|
||||
summary='创建文件转移任务',
|
||||
description='创建文件转移任务。',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
async def file_relocate() -> ResponseBase:
|
||||
"""创建转移任务"""
|
||||
@@ -674,7 +675,7 @@ async def file_relocate() -> ResponseBase:
|
||||
path='/search/{type}/{keyword}',
|
||||
summary='搜索文件',
|
||||
description='按关键字搜索文件。',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
async def file_search(type: str, keyword: str) -> ResponseBase:
|
||||
"""搜索文件"""
|
||||
|
||||
@@ -12,7 +12,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger as l
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from middleware.auth import AuthRequired
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
Object,
|
||||
@@ -171,7 +171,7 @@ async def _copy_object_recursive(
|
||||
)
|
||||
async def router_object_delete(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: ObjectDeleteRequest,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
@@ -224,7 +224,7 @@ async def router_object_delete(
|
||||
)
|
||||
async def router_object_move(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: ObjectMoveRequest,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
@@ -302,7 +302,7 @@ async def router_object_move(
|
||||
)
|
||||
async def router_object_copy(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: ObjectCopyRequest,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
@@ -394,7 +394,7 @@ async def router_object_copy(
|
||||
)
|
||||
async def router_object_rename(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: ObjectRenameRequest,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
@@ -465,7 +465,7 @@ async def router_object_rename(
|
||||
)
|
||||
async def router_object_property(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
id: UUID,
|
||||
) -> ObjectPropertyResponse:
|
||||
"""
|
||||
@@ -501,7 +501,7 @@ async def router_object_property(
|
||||
)
|
||||
async def router_object_property_detail(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
id: UUID,
|
||||
) -> ObjectPropertyDetailResponse:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from middleware.auth import AuthRequired
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
share_router = APIRouter(
|
||||
prefix='/share',
|
||||
@@ -23,7 +25,7 @@ def router_share_get(info: str, id: str) -> ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing shared content information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.put(
|
||||
path='/download/{id}',
|
||||
@@ -40,7 +42,7 @@ def router_share_download(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing download session information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='preview/{id}',
|
||||
@@ -57,7 +59,7 @@ def router_share_preview(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing preview information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/doc/{id}',
|
||||
@@ -74,7 +76,7 @@ def router_share_doc(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing the document preview URL.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/content/{id}',
|
||||
@@ -91,7 +93,7 @@ def router_share_content(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
str: The content of the text file.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/list/{id}/{path:path}',
|
||||
@@ -109,7 +111,7 @@ def router_share_list(id: str, path: str = '') -> ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing directory listing information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/search/{id}/{type}/{keywords}',
|
||||
@@ -128,7 +130,7 @@ def router_share_search(id: str, type: str, keywords: str) -> ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing search results.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.post(
|
||||
path='/archive/{id}',
|
||||
@@ -145,7 +147,7 @@ def router_share_archive(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing archive download information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/readme/{id}',
|
||||
@@ -162,7 +164,7 @@ def router_share_readme(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
str: The content of the README file.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/thumb/{id}/{file}',
|
||||
@@ -180,7 +182,7 @@ def router_share_thumb(id: str, file: str) -> ResponseBase:
|
||||
Returns:
|
||||
str: A Base64 encoded string of the thumbnail image.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.post(
|
||||
path='/report/{id}',
|
||||
@@ -197,7 +199,7 @@ def router_share_report(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing report submission information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/search',
|
||||
@@ -215,7 +217,7 @@ def router_share_search_public(keywords: str, type: str = 'all') -> ResponseBase
|
||||
Returns:
|
||||
dict: A dictionary containing search results for public shares.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
#####################
|
||||
# 需要登录的接口
|
||||
@@ -225,7 +227,7 @@ def router_share_search_public(keywords: str, type: str = 'all') -> ResponseBase
|
||||
path='/',
|
||||
summary='创建新分享',
|
||||
description='Create a new share endpoint.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_share_create() -> ResponseBase:
|
||||
"""
|
||||
@@ -234,13 +236,13 @@ def router_share_create() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the new share creation.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/',
|
||||
summary='列出我的分享',
|
||||
description='Get a list of shares.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_share_list() -> ResponseBase:
|
||||
"""
|
||||
@@ -249,13 +251,13 @@ def router_share_list() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the list of shares.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.post(
|
||||
path='/save/{id}',
|
||||
summary='转存他人分享',
|
||||
description='Save another user\'s share by ID.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_share_save(id: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -267,13 +269,13 @@ def router_share_save(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the saved share.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.patch(
|
||||
path='/{id}',
|
||||
summary='更新分享信息',
|
||||
description='Update share information by ID.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_share_update(id: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -285,13 +287,13 @@ def router_share_update(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the updated share.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.delete(
|
||||
path='/{id}',
|
||||
summary='删除分享',
|
||||
description='Delete a share by ID.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_share_delete(id: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -303,4 +305,4 @@ def router_share_delete(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the deleted share.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
@@ -1,49 +1,29 @@
|
||||
from fastapi import APIRouter
|
||||
from sqlalchemy import and_
|
||||
import json
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import ResponseBase
|
||||
from models.setting import Setting
|
||||
from models import ResponseBase, Setting, SettingsType, SiteConfigResponse
|
||||
from utils import http_exceptions
|
||||
|
||||
site_router = APIRouter(
|
||||
prefix="/site",
|
||||
tags=["site"],
|
||||
)
|
||||
|
||||
|
||||
async def _get_setting(session: SessionDep, type_: str, name: str) -> str | None:
|
||||
"""获取设置值"""
|
||||
setting = await Setting.get(session, and_(Setting.type == type_, Setting.name == name))
|
||||
return setting.value if setting else None
|
||||
|
||||
|
||||
async def _get_setting_bool(session: SessionDep, type_: str, name: str) -> bool:
|
||||
"""获取布尔类型设置值"""
|
||||
value = await _get_setting(session, type_, name)
|
||||
return value == "1" if value else False
|
||||
|
||||
async def _get_setting_json(session: SessionDep, type_: str, name: str) -> dict | list | None:
|
||||
"""获取 JSON 类型设置值"""
|
||||
value = await _get_setting(session, type_, name)
|
||||
return json.loads(value) if value else None
|
||||
|
||||
|
||||
@site_router.get(
|
||||
path="/ping",
|
||||
summary="测试用路由",
|
||||
description="A simple endpoint to check if the site is up and running.",
|
||||
response_model=ResponseBase,
|
||||
)
|
||||
def router_site_ping():
|
||||
def router_site_ping() -> ResponseBase:
|
||||
"""
|
||||
Ping the site to check if it is up and running.
|
||||
|
||||
Returns:
|
||||
str: A message indicating the site is running.
|
||||
"""
|
||||
from utils.conf.appmeta import BackendVersion
|
||||
return ResponseBase(data=BackendVersion)
|
||||
return ResponseBase()
|
||||
|
||||
|
||||
@site_router.get(
|
||||
@@ -59,7 +39,7 @@ def router_site_captcha():
|
||||
Returns:
|
||||
str: A Base64 encoded string of the captcha image.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@site_router.get(
|
||||
@@ -68,38 +48,13 @@ def router_site_captcha():
|
||||
description='Get the configuration file.',
|
||||
response_model=ResponseBase,
|
||||
)
|
||||
async def router_site_config(session: SessionDep):
|
||||
async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
"""
|
||||
Get the configuration file.
|
||||
|
||||
Returns:
|
||||
dict: The site configuration.
|
||||
"""
|
||||
return ResponseBase(
|
||||
data={
|
||||
"title": await _get_setting(session, "basic", "siteName"),
|
||||
"loginCaptcha": await _get_setting_bool(session, "login", "login_captcha"),
|
||||
"regCaptcha": await _get_setting_bool(session, "login", "reg_captcha"),
|
||||
"forgetCaptcha": await _get_setting_bool(session, "login", "forget_captcha"),
|
||||
"emailActive": await _get_setting_bool(session, "login", "email_active"),
|
||||
"QQLogin": None,
|
||||
"themes": await _get_setting_json(session, "basic", "themes"),
|
||||
"defaultTheme": await _get_setting(session, "basic", "defaultTheme"),
|
||||
"score_enabled": None,
|
||||
"share_score_rate": None,
|
||||
"home_view_method": await _get_setting(session, "view", "home_view_method"),
|
||||
"share_view_method": await _get_setting(session, "view", "share_view_method"),
|
||||
"authn": await _get_setting_bool(session, "authn", "authn_enabled"),
|
||||
"user": {},
|
||||
"captcha_type": None,
|
||||
"captcha_ReCaptchaKey": await _get_setting(session, "captcha", "captcha_ReCaptchaKey"),
|
||||
"captcha_CloudflareKey": await _get_setting(session, "captcha", "captcha_CloudflareKey"),
|
||||
"captcha_tcaptcha_appid": None,
|
||||
"site_notice": None,
|
||||
"registerEnabled": await _get_setting_bool(session, "register", "register_enabled"),
|
||||
"app_promotion": None,
|
||||
"wopi_exts": None,
|
||||
"app_feedback": None,
|
||||
"app_forum": None,
|
||||
}
|
||||
return SiteConfigResponse(
|
||||
title=await Setting.get(session, and_(Setting.type == SettingsType.BASIC, Setting.name == "siteName")),
|
||||
)
|
||||
@@ -1,7 +1,9 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import FileResponse
|
||||
from middleware.auth import AuthRequired
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
slave_router = APIRouter(
|
||||
prefix="/slave",
|
||||
@@ -32,7 +34,7 @@ def router_slave_ping() -> ResponseBase:
|
||||
path='/post',
|
||||
summary='上传',
|
||||
description='Upload data to the server.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_slave_post(data: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -44,7 +46,7 @@ def router_slave_post(data: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A response model indicating success.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@slave_router.get(
|
||||
path='/get/{speed}/{path}/{name}',
|
||||
@@ -62,13 +64,13 @@ def router_slave_download(speed: int, path: str, name: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A response model containing download information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@slave_router.get(
|
||||
path='/download/{sign}',
|
||||
summary='根据签名下载文件',
|
||||
description='Download a file based on its signature.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_slave_download_by_sign(sign: str) -> FileResponse:
|
||||
"""
|
||||
@@ -80,13 +82,13 @@ def router_slave_download_by_sign(sign: str) -> FileResponse:
|
||||
Returns:
|
||||
FileResponse: A response containing the file to be downloaded.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@slave_router.get(
|
||||
path='/source/{speed}/{path}/{name}',
|
||||
summary='获取文件外链',
|
||||
description='Get the external link for a file based on its signature.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_slave_source(speed: int, path: str, name: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -100,13 +102,13 @@ def router_slave_source(speed: int, path: str, name: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A response model containing the external link for the file.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@slave_router.get(
|
||||
path='/source/{sign}',
|
||||
summary='根据签名获取文件',
|
||||
description='Get a file based on its signature.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_slave_source_by_sign(sign: str) -> FileResponse:
|
||||
"""
|
||||
@@ -118,13 +120,13 @@ def router_slave_source_by_sign(sign: str) -> FileResponse:
|
||||
Returns:
|
||||
FileResponse: A response containing the file to be retrieved.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@slave_router.get(
|
||||
path='/thumb/{id}',
|
||||
summary='获取缩略图',
|
||||
description='Get a thumbnail image based on its ID.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_slave_thumb(id: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -136,13 +138,13 @@ def router_slave_thumb(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A response model containing the Base64 encoded thumbnail image.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@slave_router.delete(
|
||||
path='/delete',
|
||||
summary='删除文件',
|
||||
description='Delete a file from the server.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_slave_delete(path: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -154,25 +156,25 @@ def router_slave_delete(path: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A response model indicating success or failure of the deletion.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@slave_aria2_router.post(
|
||||
path='/test',
|
||||
summary='测试从机连接Aria2服务',
|
||||
description='Test the connection to the Aria2 service from the slave.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_slave_aria2_test() -> ResponseBase:
|
||||
"""
|
||||
Test the connection to the Aria2 service from the slave.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@slave_aria2_router.get(
|
||||
path='/get/{gid}',
|
||||
summary='获取Aria2任务信息',
|
||||
description='Get information about an Aria2 task by its GID.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_slave_aria2_get(gid: str = None) -> ResponseBase:
|
||||
"""
|
||||
@@ -184,13 +186,13 @@ def router_slave_aria2_get(gid: str = None) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A response model containing the task information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@slave_aria2_router.post(
|
||||
path='/add',
|
||||
summary='添加Aria2任务',
|
||||
description='Add a new Aria2 task.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_slave_aria2_add(gid: str, url: str, options: dict = None) -> ResponseBase:
|
||||
"""
|
||||
@@ -204,13 +206,13 @@ def router_slave_aria2_add(gid: str, url: str, options: dict = None) -> Response
|
||||
Returns:
|
||||
ResponseBase: A response model indicating success or failure of the task addition.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@slave_aria2_router.delete(
|
||||
path='/remove/{gid}',
|
||||
summary='删除Aria2任务',
|
||||
description='Remove an Aria2 task by its GID.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_slave_aria2_remove(gid: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -222,4 +224,4 @@ def router_slave_aria2_remove(gid: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A response model indicating success or failure of the task removal.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
@@ -1,6 +1,8 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from middleware.auth import AuthRequired
|
||||
from middleware.auth import auth_required
|
||||
|
||||
from models import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
tag_router = APIRouter(
|
||||
prefix='/tag',
|
||||
@@ -11,7 +13,7 @@ tag_router = APIRouter(
|
||||
path='/filter',
|
||||
summary='创建文件分类标签',
|
||||
description='Create a file classification tag.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_tag_create_filter() -> ResponseBase:
|
||||
"""
|
||||
@@ -20,13 +22,13 @@ def router_tag_create_filter() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the created tag.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@tag_router.post(
|
||||
path='/link',
|
||||
summary='创建目录快捷方式标签',
|
||||
description='Create a directory shortcut tag.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_tag_create_link() -> ResponseBase:
|
||||
"""
|
||||
@@ -35,13 +37,13 @@ def router_tag_create_link() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the created tag.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@tag_router.delete(
|
||||
path='/{id}',
|
||||
summary='删除标签',
|
||||
description='Delete a tag by its ID.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_tag_delete(id: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -53,4 +55,4 @@ def router_tag_delete(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the deletion operation.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
@@ -7,13 +7,14 @@ from sqlalchemy import and_
|
||||
from webauthn import generate_registration_options
|
||||
from webauthn.helpers import options_to_json_dict
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
from loguru import logger
|
||||
|
||||
import models
|
||||
import service
|
||||
from middleware.auth import AuthRequired
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from utils.JWT.JWT import SECRET_KEY
|
||||
from utils import Password
|
||||
from utils import Password, http_exceptions
|
||||
|
||||
user_router = APIRouter(
|
||||
prefix="/user",
|
||||
@@ -23,7 +24,7 @@ user_router = APIRouter(
|
||||
user_settings_router = APIRouter(
|
||||
prefix='/user/settings',
|
||||
tags=["user", "user_settings"],
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
|
||||
@user_router.post(
|
||||
@@ -42,11 +43,6 @@ async def router_user_session(
|
||||
当用户启用两步验证时,需要在表单中传入 otp 参数(通过 scopes 字段传递)。
|
||||
|
||||
OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码
|
||||
|
||||
:raises HTTPException 401: 用户名或密码错误
|
||||
:raises HTTPException 403: 用户账号被封禁或未完成注册
|
||||
:raises HTTPException 428: 需要两步验证但未提供验证码
|
||||
:raises HTTPException 400: 两步验证码无效
|
||||
"""
|
||||
username = form_data.username
|
||||
password = form_data.password
|
||||
@@ -62,7 +58,7 @@ async def router_user_session(
|
||||
otp_code = scope
|
||||
break
|
||||
|
||||
result = await service.user.Login(
|
||||
result = await service.user.login(
|
||||
session,
|
||||
models.LoginRequest(
|
||||
username=username,
|
||||
@@ -71,22 +67,7 @@ async def router_user_session(
|
||||
),
|
||||
)
|
||||
|
||||
if isinstance(result, models.TokenResponse):
|
||||
return result
|
||||
elif result is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||
elif result is False:
|
||||
raise HTTPException(status_code=403, detail="User account is banned or not fully registered")
|
||||
elif result == "2fa_required":
|
||||
raise HTTPException(
|
||||
status_code=428,
|
||||
detail="Two-factor authentication required",
|
||||
headers={"X-2FA-Required": "true"},
|
||||
)
|
||||
elif result == "2fa_invalid":
|
||||
raise HTTPException(status_code=400, detail="Invalid two-factor authentication code")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Internal server error during login")
|
||||
|
||||
@user_router.post(
|
||||
path='/session/refresh',
|
||||
@@ -97,7 +78,7 @@ async def router_user_session_refresh(
|
||||
session: SessionDep,
|
||||
request, # RefreshTokenRequest
|
||||
) -> models.TokenResponse:
|
||||
...
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.post(
|
||||
path='/',
|
||||
@@ -137,12 +118,14 @@ async def router_user_register(
|
||||
and_(models.Setting.type == models.SettingsType.REGISTER, models.Setting.name == "default_group")
|
||||
)
|
||||
if default_group_setting is None or not default_group_setting.value:
|
||||
raise HTTPException(status_code=500, detail="默认用户组设置不存在")
|
||||
logger.error("默认用户组不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
default_group_id = UUID(default_group_setting.value)
|
||||
default_group = await models.Group.get(session, models.Group.id == default_group_id)
|
||||
if not default_group:
|
||||
raise HTTPException(status_code=500, detail="默认用户组不存在")
|
||||
logger.error("默认用户组不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
# 3. 创建用户
|
||||
hashed_password = Password.hash(request.password)
|
||||
@@ -158,7 +141,8 @@ async def router_user_register(
|
||||
# 4. 创建以用户名命名的根目录
|
||||
default_policy = await models.Policy.get(session, models.Policy.name == "本地存储")
|
||||
if not default_policy:
|
||||
raise HTTPException(status_code=500, detail="默认存储策略不存在")
|
||||
logger.error("默认存储策略不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
await models.Object(
|
||||
name=new_user_username,
|
||||
@@ -190,7 +174,7 @@ def router_user_email_code(
|
||||
Returns:
|
||||
dict: A dictionary containing information about the password reset email.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.get(
|
||||
path='/qq',
|
||||
@@ -204,7 +188,7 @@ def router_user_qq() -> models.ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing QQ login initialization information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.get(
|
||||
path='authn/{username}',
|
||||
@@ -213,7 +197,7 @@ def router_user_qq() -> models.ResponseBase:
|
||||
)
|
||||
async def router_user_authn(username: str) -> models.ResponseBase:
|
||||
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.post(
|
||||
path='authn/finish/{username}',
|
||||
@@ -230,7 +214,7 @@ def router_user_authn_finish(username: str) -> models.ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn login information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.get(
|
||||
path='/profile/{id}',
|
||||
@@ -247,7 +231,7 @@ def router_user_profile(id: str) -> models.ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing user profile information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.get(
|
||||
path='/avatar/{id}/{size}',
|
||||
@@ -265,7 +249,7 @@ def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
|
||||
Returns:
|
||||
str: A Base64 encoded string of the user avatar image.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
#####################
|
||||
# 需要登录的接口
|
||||
@@ -275,12 +259,12 @@ def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
|
||||
path='/me',
|
||||
summary='获取用户信息',
|
||||
description='Get user information.',
|
||||
dependencies=[Depends(dependency=AuthRequired)],
|
||||
dependencies=[Depends(dependency=auth_required)],
|
||||
response_model=models.ResponseBase,
|
||||
)
|
||||
async def router_user_me(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.User, Depends(AuthRequired)],
|
||||
user: Annotated[models.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
"""
|
||||
获取用户信息.
|
||||
@@ -319,11 +303,11 @@ async def router_user_me(
|
||||
path='/storage',
|
||||
summary='存储信息',
|
||||
description='Get user storage information.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_storage(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(AuthRequired)],
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
"""
|
||||
获取用户存储空间信息。
|
||||
@@ -353,11 +337,11 @@ async def router_user_storage(
|
||||
path='/authn/start',
|
||||
summary='WebAuthn登录初始化',
|
||||
description='Initialize WebAuthn login for a user.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_authn_start(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(AuthRequired)],
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
"""
|
||||
Initialize WebAuthn login for a user.
|
||||
@@ -395,7 +379,7 @@ async def router_user_authn_start(
|
||||
path='/authn/finish',
|
||||
summary='WebAuthn登录',
|
||||
description='Finish WebAuthn login for a user.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_authn_finish() -> models.ResponseBase:
|
||||
"""
|
||||
@@ -404,7 +388,7 @@ def router_user_authn_finish() -> models.ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn login information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/policies',
|
||||
@@ -418,13 +402,13 @@ def router_user_settings_policies() -> models.ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing available storage policies for the user.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/nodes',
|
||||
summary='获取用户可选节点',
|
||||
description='Get user selectable nodes.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_nodes() -> models.ResponseBase:
|
||||
"""
|
||||
@@ -433,13 +417,13 @@ def router_user_settings_nodes() -> models.ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing available nodes for the user.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/tasks',
|
||||
summary='任务队列',
|
||||
description='Get user task queue.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_tasks() -> models.ResponseBase:
|
||||
"""
|
||||
@@ -448,13 +432,13 @@ def router_user_settings_tasks() -> models.ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing the user's task queue information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/',
|
||||
summary='获取当前用户设定',
|
||||
description='Get current user settings.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings() -> models.ResponseBase:
|
||||
"""
|
||||
@@ -469,7 +453,7 @@ def router_user_settings() -> models.ResponseBase:
|
||||
path='/avatar',
|
||||
summary='从文件上传头像',
|
||||
description='Upload user avatar from file.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar() -> models.ResponseBase:
|
||||
"""
|
||||
@@ -478,13 +462,13 @@ def router_user_settings_avatar() -> models.ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the avatar upload.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.put(
|
||||
path='/avatar',
|
||||
summary='设定为Gravatar头像',
|
||||
description='Set user avatar to Gravatar.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar_gravatar() -> models.ResponseBase:
|
||||
"""
|
||||
@@ -493,13 +477,13 @@ def router_user_settings_avatar_gravatar() -> models.ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing the result of setting the Gravatar avatar.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.patch(
|
||||
path='/{option}',
|
||||
summary='更新用户设定',
|
||||
description='Update user settings.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_patch(option: str) -> models.ResponseBase:
|
||||
"""
|
||||
@@ -511,16 +495,16 @@ def router_user_settings_patch(option: str) -> models.ResponseBase:
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the settings update.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/2fa',
|
||||
summary='获取两步验证初始化信息',
|
||||
description='Get two-factor authentication initialization information.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa(
|
||||
user: Annotated[models.user.User, Depends(AuthRequired)],
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
"""
|
||||
Get two-factor authentication initialization information.
|
||||
@@ -537,11 +521,11 @@ async def router_user_settings_2fa(
|
||||
path='/2fa',
|
||||
summary='启用两步验证',
|
||||
description='Enable two-factor authentication.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa_enable(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(AuthRequired)],
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
setup_token: str,
|
||||
code: str,
|
||||
) -> models.ResponseBase:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from middleware.auth import AuthRequired
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
vas_router = APIRouter(
|
||||
prefix="/vas",
|
||||
@@ -11,7 +13,7 @@ vas_router = APIRouter(
|
||||
path='/pack',
|
||||
summary='获取容量包及配额信息',
|
||||
description='Get information about storage packs and quotas.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_vas_pack() -> ResponseBase:
|
||||
"""
|
||||
@@ -20,13 +22,13 @@ def router_vas_pack() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for storage packs and quotas.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.get(
|
||||
path='/product',
|
||||
summary='获取商品信息,同时返回支付信息',
|
||||
description='Get product information along with payment details.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_vas_product() -> ResponseBase:
|
||||
"""
|
||||
@@ -35,13 +37,13 @@ def router_vas_product() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for products and payment information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.post(
|
||||
path='/order',
|
||||
summary='新建支付订单',
|
||||
description='Create an order for a product.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_vas_order() -> ResponseBase:
|
||||
"""
|
||||
@@ -50,13 +52,13 @@ def router_vas_order() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the created order.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.get(
|
||||
path='/order/{id}',
|
||||
summary='查询订单状态',
|
||||
description='Get information about a specific payment order by ID.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_vas_order_get(id: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -68,13 +70,13 @@ def router_vas_order_get(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the specified order.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.get(
|
||||
path='/redeem',
|
||||
summary='获取兑换码信息',
|
||||
description='Get information about a specific redemption code.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_vas_redeem(code: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -86,13 +88,13 @@ def router_vas_redeem(code: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the specified redemption code.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.post(
|
||||
path='/redeem',
|
||||
summary='执行兑换',
|
||||
description='Redeem a redemption code for a product or service.',
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_vas_redeem_post() -> ResponseBase:
|
||||
"""
|
||||
@@ -101,4 +103,4 @@ def router_vas_redeem_post() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the redeemed code.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
@@ -1,6 +1,8 @@
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from middleware.auth import AuthRequired
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
# WebDAV 管理路由
|
||||
webdav_router = APIRouter(
|
||||
@@ -12,7 +14,7 @@ webdav_router = APIRouter(
|
||||
path='/accounts',
|
||||
summary='获取账号信息',
|
||||
description='Get account information for WebDAV.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_accounts() -> ResponseBase:
|
||||
"""
|
||||
@@ -21,13 +23,13 @@ def router_webdav_accounts() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the account information.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@webdav_router.post(
|
||||
path='/accounts',
|
||||
summary='新建账号',
|
||||
description='Create a new WebDAV account.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_create_account() -> ResponseBase:
|
||||
"""
|
||||
@@ -36,13 +38,13 @@ def router_webdav_create_account() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the created account.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@webdav_router.delete(
|
||||
path='/accounts/{id}',
|
||||
summary='删除账号',
|
||||
description='Delete a WebDAV account by its ID.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_delete_account(id: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -54,13 +56,13 @@ def router_webdav_delete_account(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the deletion operation.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@webdav_router.post(
|
||||
path='/mount',
|
||||
summary='新建目录挂载',
|
||||
description='Create a new WebDAV mount point.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_create_mount() -> ResponseBase:
|
||||
"""
|
||||
@@ -69,13 +71,13 @@ def router_webdav_create_mount() -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the created mount point.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@webdav_router.delete(
|
||||
path='/mount/{id}',
|
||||
summary='删除目录挂载',
|
||||
description='Delete a WebDAV mount point by its ID.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_delete_mount(id: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -87,13 +89,13 @@ def router_webdav_delete_mount(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the deletion operation.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@webdav_router.patch(
|
||||
path='accounts/{id}',
|
||||
summary='更新账号信息',
|
||||
description='Update WebDAV account information by ID.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_update_account(id: str) -> ResponseBase:
|
||||
"""
|
||||
@@ -105,4 +107,4 @@ def router_webdav_update_account(id: str) -> ResponseBase:
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the updated account.
|
||||
"""
|
||||
pass
|
||||
http_exceptions.raise_not_implemented()
|
||||
@@ -1 +1 @@
|
||||
from .login import Login
|
||||
from .login import login
|
||||
@@ -1,25 +1,19 @@
|
||||
from typing import Literal
|
||||
|
||||
from loguru import logger as log
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from loguru import logger
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import LoginRequest, TokenResponse, User
|
||||
from utils import http_exceptions
|
||||
from utils.JWT.JWT import create_access_token, create_refresh_token
|
||||
from utils.password.pwd import Password, PasswordStatus
|
||||
|
||||
|
||||
async def Login(
|
||||
session: AsyncSession,
|
||||
async def login(
|
||||
session: SessionDep,
|
||||
login_request: LoginRequest,
|
||||
) -> TokenResponse | bool | Literal["2fa_required", "2fa_invalid"] | None:
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
根据账号密码进行登录。
|
||||
|
||||
如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。
|
||||
如果登录异常,返回 `False`(未完成注册或账号被封禁)。
|
||||
如果登录失败,返回 `None`。
|
||||
如果需要两步验证但未提供验证码,返回 `"2fa_required"`。
|
||||
如果两步验证码无效,返回 `"2fa_invalid"`。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param login_request: 登录请求
|
||||
@@ -38,30 +32,29 @@ async def Login(
|
||||
|
||||
# 验证用户是否存在
|
||||
if not current_user:
|
||||
log.debug(f"Cannot find user with username: {login_request.username}")
|
||||
return None
|
||||
logger.debug(f"Cannot find user with username: {login_request.username}")
|
||||
http_exceptions.raise_unauthorized("Invalid username or password")
|
||||
|
||||
# 验证密码是否正确
|
||||
if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID:
|
||||
log.debug(f"Password verification failed for user: {login_request.username}")
|
||||
return None
|
||||
logger.debug(f"Password verification failed for user: {login_request.username}")
|
||||
http_exceptions.raise_unauthorized("Invalid username or password")
|
||||
|
||||
# 验证用户是否可登录
|
||||
if not current_user.status:
|
||||
# 未完成注册 or 账号已被封禁
|
||||
return False
|
||||
http_exceptions.raise_forbidden("Your account is disabled")
|
||||
|
||||
# 检查两步验证
|
||||
if current_user.two_factor:
|
||||
# 用户已启用两步验证
|
||||
if not login_request.two_fa_code:
|
||||
log.debug(f"2FA required for user: {login_request.username}")
|
||||
return "2fa_required"
|
||||
logger.debug(f"2FA required for user: {login_request.username}")
|
||||
http_exceptions.raise_precondition_required("2FA required")
|
||||
|
||||
# 验证 OTP 码
|
||||
if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID:
|
||||
log.debug(f"Invalid 2FA code for user: {login_request.username}")
|
||||
return "2fa_invalid"
|
||||
logger.debug(f"Invalid 2FA code for user: {login_request.username}")
|
||||
http_exceptions.raise_unauthorized("Invalid 2FA code")
|
||||
|
||||
# 创建令牌
|
||||
access_token, access_expire = create_access_token(data={'sub': current_user.username})
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User, LoginRequest, TokenResponse
|
||||
from models.group import Group
|
||||
from service.user.login import Login
|
||||
from service.user.login import login
|
||||
from utils.password.pwd import Password
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ async def test_login_success(db_session: AsyncSession, setup_user):
|
||||
password=user_data["password"]
|
||||
)
|
||||
|
||||
result = await Login(db_session, login_request)
|
||||
result = await login(db_session, login_request)
|
||||
|
||||
assert isinstance(result, TokenResponse)
|
||||
assert result.access_token is not None
|
||||
@@ -103,7 +103,7 @@ async def test_login_user_not_found(db_session: AsyncSession):
|
||||
password="any_password"
|
||||
)
|
||||
|
||||
result = await Login(db_session, login_request)
|
||||
result = await login(db_session, login_request)
|
||||
|
||||
assert result is None
|
||||
|
||||
@@ -116,7 +116,7 @@ async def test_login_wrong_password(db_session: AsyncSession, setup_user):
|
||||
password="wrong_password"
|
||||
)
|
||||
|
||||
result = await Login(db_session, login_request)
|
||||
result = await login(db_session, login_request)
|
||||
|
||||
assert result is None
|
||||
|
||||
@@ -129,7 +129,7 @@ async def test_login_user_banned(db_session: AsyncSession, setup_banned_user):
|
||||
password="password"
|
||||
)
|
||||
|
||||
result = await Login(db_session, login_request)
|
||||
result = await login(db_session, login_request)
|
||||
|
||||
assert result is False
|
||||
|
||||
@@ -145,7 +145,7 @@ async def test_login_2fa_required(db_session: AsyncSession, setup_2fa_user):
|
||||
# 未提供 two_fa_code
|
||||
)
|
||||
|
||||
result = await Login(db_session, login_request)
|
||||
result = await login(db_session, login_request)
|
||||
|
||||
assert result == "2fa_required"
|
||||
|
||||
@@ -161,7 +161,7 @@ async def test_login_2fa_invalid(db_session: AsyncSession, setup_2fa_user):
|
||||
two_fa_code="000000" # 错误的验证码
|
||||
)
|
||||
|
||||
result = await Login(db_session, login_request)
|
||||
result = await login(db_session, login_request)
|
||||
|
||||
assert result == "2fa_invalid"
|
||||
|
||||
@@ -184,7 +184,7 @@ async def test_login_2fa_success(db_session: AsyncSession, setup_2fa_user):
|
||||
two_fa_code=valid_code
|
||||
)
|
||||
|
||||
result = await Login(db_session, login_request)
|
||||
result = await login(db_session, login_request)
|
||||
|
||||
assert isinstance(result, TokenResponse)
|
||||
assert result.access_token is not None
|
||||
@@ -202,7 +202,7 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
|
||||
password=user_data["password"]
|
||||
)
|
||||
|
||||
result = await Login(db_session, login_request)
|
||||
result = await login(db_session, login_request)
|
||||
|
||||
assert isinstance(result, TokenResponse)
|
||||
|
||||
@@ -227,7 +227,7 @@ async def test_login_case_sensitive_username(db_session: AsyncSession, setup_use
|
||||
password=user_data["password"]
|
||||
)
|
||||
|
||||
result = await Login(db_session, login_request)
|
||||
result = await login(db_session, login_request)
|
||||
|
||||
# 应该失败,因为用户名大小写不匹配
|
||||
assert result is None
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from .password.pwd import Password, PasswordStatus
|
||||
from .http import http_exceptions
|
||||
@@ -1,20 +1,6 @@
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from starlette.status import (
|
||||
HTTP_400_BAD_REQUEST,
|
||||
HTTP_401_UNAUTHORIZED,
|
||||
HTTP_402_PAYMENT_REQUIRED,
|
||||
HTTP_403_FORBIDDEN,
|
||||
HTTP_404_NOT_FOUND,
|
||||
HTTP_409_CONFLICT,
|
||||
HTTP_429_TOO_MANY_REQUESTS,
|
||||
HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
HTTP_501_NOT_IMPLEMENTED,
|
||||
HTTP_503_SERVICE_UNAVAILABLE,
|
||||
HTTP_504_GATEWAY_TIMEOUT,
|
||||
)
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
# --- 400 ---
|
||||
|
||||
@@ -24,50 +10,54 @@ def ensure_request_param(to_check: Any, detail: str) -> None:
|
||||
This function returns None if the check passes.
|
||||
"""
|
||||
if not to_check:
|
||||
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)
|
||||
|
||||
def raise_bad_request(detail: str = '') -> NoReturn:
|
||||
"""Raises an HTTP 400 Bad Request exception."""
|
||||
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)
|
||||
|
||||
def raise_unauthorized(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 401 Unauthorized exception."""
|
||||
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=detail)
|
||||
|
||||
def raise_insufficient_quota(detail: str = "积分不足,请充值") -> NoReturn:
|
||||
"""Raises an HTTP 402 Payment Required exception."""
|
||||
raise HTTPException(status_code=HTTP_402_PAYMENT_REQUIRED, detail=detail)
|
||||
raise HTTPException(status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=detail)
|
||||
|
||||
def raise_forbidden(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 403 Forbidden exception."""
|
||||
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail)
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
def raise_not_found(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 404 Not Found exception."""
|
||||
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail)
|
||||
|
||||
def raise_conflict(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 409 Conflict exception."""
|
||||
raise HTTPException(status_code=HTTP_409_CONFLICT, detail=detail)
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=detail)
|
||||
|
||||
def raise_precondition_required(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 428 Precondition required exception."""
|
||||
raise HTTPException(status_code=status.HTTP_428_PRECONDITION_REQUIRED, detail=detail)
|
||||
|
||||
def raise_too_many_requests(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 429 Too Many Requests exception."""
|
||||
raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail=detail)
|
||||
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=detail)
|
||||
|
||||
# --- 500 ---
|
||||
|
||||
def raise_internal_error(detail: str = "服务器出现故障,请稍后再试或联系管理员") -> NoReturn:
|
||||
"""Raises an HTTP 500 Internal Server Error exception."""
|
||||
raise HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
|
||||
|
||||
def raise_not_implemented(detail: str = "尚未支持这种方法") -> NoReturn:
|
||||
"""Raises an HTTP 501 Not Implemented exception."""
|
||||
raise HTTPException(status_code=HTTP_501_NOT_IMPLEMENTED, detail=detail)
|
||||
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=detail)
|
||||
|
||||
def raise_service_unavailable(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 503 Service Unavailable exception."""
|
||||
raise HTTPException(status_code=HTTP_503_SERVICE_UNAVAILABLE, detail=detail)
|
||||
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=detail)
|
||||
|
||||
def raise_gateway_timeout(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 504 Gateway Timeout exception."""
|
||||
raise HTTPException(status_code=HTTP_504_GATEWAY_TIMEOUT, detail=detail)
|
||||
raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail=detail)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import secrets
|
||||
|
||||
from loguru import logger
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import VerifyMismatchError
|
||||
@@ -104,10 +105,11 @@ class Password:
|
||||
|
||||
@staticmethod
|
||||
async def generate_totp(
|
||||
username: str
|
||||
*args, **kwargs
|
||||
) -> TwoFactorResponse:
|
||||
"""
|
||||
生成 TOTP 密钥和对应的 URI,用于两步验证。
|
||||
所有的参数将会给到 `pyotp.totp.TOTP`
|
||||
|
||||
:return: 包含 TOTP 密钥和 URI 的元组
|
||||
"""
|
||||
@@ -121,8 +123,7 @@ class Password:
|
||||
salt="2fa-setup-salt"
|
||||
)
|
||||
|
||||
otp_uri = pyotp.totp.TOTP(secret).provisioning_uri(
|
||||
name=username,
|
||||
otp_uri = pyotp.totp.TOTP(secret, *args, **kwargs).provisioning_uri(
|
||||
issuer_name=appmeta.APP_NAME
|
||||
)
|
||||
|
||||
@@ -134,17 +135,21 @@ class Password:
|
||||
@staticmethod
|
||||
def verify_totp(
|
||||
secret: str,
|
||||
code: str
|
||||
code: int,
|
||||
*args, **kwargs
|
||||
) -> PasswordStatus:
|
||||
"""
|
||||
验证 TOTP 验证码。
|
||||
|
||||
:param secret: TOTP 密钥(Base32 编码)
|
||||
:param code: 用户输入的 6 位验证码
|
||||
:param args: 传入 `totp.verify` 的参数
|
||||
:param kwargs: 传入 `totp.verify` 的参数
|
||||
|
||||
:return: 验证是否成功
|
||||
"""
|
||||
totp = pyotp.TOTP(secret)
|
||||
if totp.verify(code):
|
||||
if totp.verify(otp=str(code), *args, **kwargs):
|
||||
return PasswordStatus.VALID
|
||||
else:
|
||||
return PasswordStatus.INVALID
|
||||
Reference in New Issue
Block a user