diff --git a/.run/开发模式.run.xml b/.run/开发模式.run.xml new file mode 100644 index 0000000..e3531e7 --- /dev/null +++ b/.run/开发模式.run.xml @@ -0,0 +1,25 @@ + + + + \ No newline at end of file diff --git a/middleware/auth.py b/middleware/auth.py index 22d8513..fb51b16 100644 --- a/middleware/auth.py +++ b/middleware/auth.py @@ -1,20 +1,14 @@ from typing import Annotated -from fastapi import Depends, HTTPException -from jwt import InvalidTokenError +from fastapi import Depends import jwt from models.user import User from utils.JWT import JWT from .dependencies import SessionDep +from utils import http_exceptions -credentials_exception = HTTPException( - status_code=401, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, -) - -async def AuthRequired( +async def auth_required( session: SessionDep, token: Annotated[str, Depends(JWT.oauth2_scheme)], ) -> User: @@ -26,28 +20,28 @@ async def AuthRequired( username = payload.get("sub") if username is None: - raise credentials_exception + http_exceptions.raise_unauthorized("账号或密码错误") # 从数据库获取用户信息 user = await User.get(session, User.username == username) if not user: - raise credentials_exception + http_exceptions.raise_unauthorized("账号或密码错误") return user - except InvalidTokenError: - raise credentials_exception + except jwt.InvalidTokenError: + http_exceptions.raise_unauthorized("账号或密码错误") -async def AdminRequired( - user: Annotated[User, Depends(AuthRequired)], +async def admin_required( + user: Annotated[User, Depends(auth_required)], ) -> User: """ 验证是否为管理员。 使用方法: - >>> APIRouter(dependencies=[Depends(AdminRequired)]) + >>> APIRouter(dependencies=[Depends(admin_required)]) """ group = await user.awaitable_attrs.group if group.admin: return user - raise HTTPException(status_code=403, detail="Admin Required") \ No newline at end of file + raise http_exceptions.raise_forbidden("Admin Required") \ No newline at end of file diff --git a/middleware/dependencies.py b/middleware/dependencies.py index dd55103..3c97006 100644 --- a/middleware/dependencies.py +++ b/middleware/dependencies.py @@ -1,4 +1,4 @@ -from typing import Annotated, AsyncGenerator +from typing import Annotated from fastapi import Depends from sqlmodel.ext.asyncio.session import AsyncSession diff --git a/models/setting.py b/models/setting.py index 5819a57..51d8070 100644 --- a/models/setting.py +++ b/models/setting.py @@ -20,11 +20,11 @@ class SiteConfigResponse(SQLModelBase): title: str = "DiskNext" """网站标题""" - themes: dict[str, str] = {} - """网站主题配置""" + # themes: dict[str, str] = {} + # """网站主题配置""" - default_theme: dict[str, str] = {} - """默认主题RGB色号""" + # default_theme: dict[str, str] = {} + # """默认主题RGB色号""" site_notice: str | None = None """网站公告""" diff --git a/routers/api/v1/__init__.py b/routers/api/v1/__init__.py index ec2d6b9..3641cff 100644 --- a/routers/api/v1/__init__.py +++ b/routers/api/v1/__init__.py @@ -24,16 +24,9 @@ from .webdav import webdav_router router = APIRouter(prefix="/v1") -router.include_router(admin_router) -router.include_router(admin_aria2_router) -router.include_router(admin_file_router) -router.include_router(admin_group_router) -router.include_router(admin_policy_router) -router.include_router(admin_share_router) -router.include_router(admin_task_router) -router.include_router(admin_user_router) -router.include_router(admin_vas_router) +# [TODO] 如果是主机,导入下面的路由 +router.include_router(admin_router) router.include_router(callback_router) router.include_router(directory_router) router.include_router(download_router) @@ -41,7 +34,9 @@ router.include_router(file_router) router.include_router(object_router) router.include_router(share_router) router.include_router(site_router) -router.include_router(slave_router) router.include_router(user_router) router.include_router(vas_router) router.include_router(webdav_router) + +# [TODO] 如果是从机,导入下面的路由 +router.include_router(slave_router) \ No newline at end of file diff --git a/routers/api/v1/admin/__init__.py b/routers/api/v1/admin/__init__.py index 7e5860a..3203819 100644 --- a/routers/api/v1/admin/__init__.py +++ b/routers/api/v1/admin/__init__.py @@ -8,7 +8,7 @@ from loguru import logger as l from sqlalchemy import func, and_ from sqlmodel import Field -from middleware.auth import AdminRequired +from middleware.auth import admin_required from middleware.dependencies import SessionDep from models import ( Policy, PolicyOptions, PolicyType, User, ResponseBase, @@ -156,7 +156,7 @@ admin_vas_router = APIRouter( path='/summary', summary='获取站点概况', description='Get site summary information', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) def router_admin_get_summary() -> ResponseBase: """ @@ -165,13 +165,13 @@ def router_admin_get_summary() -> ResponseBase: Returns: ResponseBase: 包含站点概况信息的响应模型。 """ - pass + http_exceptions.raise_not_implemented() @admin_router.get( path='/news', summary='获取社区新闻', description='Get community news', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) def router_admin_get_news() -> ResponseBase: """ @@ -180,13 +180,13 @@ def router_admin_get_news() -> ResponseBase: Returns: ResponseBase: 包含社区新闻信息的响应模型。 """ - pass + http_exceptions.raise_not_implemented() @admin_router.patch( path='/settings', summary='更新设置', description='Update settings', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_update_settings( session: SessionDep, @@ -225,7 +225,7 @@ async def router_admin_update_settings( path='/settings', summary='获取设置', description='Get settings', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_get_settings(session: SessionDep) -> ResponseBase: """ @@ -249,7 +249,7 @@ async def router_admin_get_settings(session: SessionDep) -> ResponseBase: path='/', summary='获取用户组列表', description='Get user group list', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_get_groups( session: SessionDep, @@ -314,7 +314,7 @@ async def router_admin_get_groups( path='/{group_id}', summary='获取用户组信息', description='Get user group information by ID', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_get_group( session: SessionDep, @@ -366,7 +366,7 @@ async def router_admin_get_group( path='/list/{group_id}', summary='获取用户组成员列表', description='Get user group member list by group ID', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_get_group_members( session: SessionDep, @@ -410,7 +410,7 @@ async def router_admin_get_group_members( path='/', summary='创建用户组', description='Create a new user group', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_create_group( session: SessionDep, @@ -469,7 +469,7 @@ async def router_admin_create_group( path='/{group_id}', summary='更新用户组信息', description='Update user group information by ID', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_update_group( session: SessionDep, @@ -539,7 +539,7 @@ async def router_admin_update_group( path='/{group_id}', summary='删除用户组', description='Delete user group by ID', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_delete_group( session: SessionDep, @@ -576,7 +576,7 @@ async def router_admin_delete_group( path='/info/{user_id}', summary='获取用户信息', description='Get user information by ID', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_get_user(session: SessionDep, user_id: int) -> ResponseBase: """ @@ -596,7 +596,7 @@ async def router_admin_get_user(session: SessionDep, user_id: int) -> ResponseBa path='/list', summary='获取用户列表', description='Get user list', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_get_users( session: SessionDep, @@ -630,7 +630,7 @@ async def router_admin_get_users( path='/create', summary='创建用户', description='Create a new user', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_create_user( session: SessionDep, @@ -655,7 +655,7 @@ async def router_admin_create_user( path='/{user_id}', summary='更新用户信息', description='Update user information by ID', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_update_user( session: SessionDep, @@ -700,7 +700,7 @@ async def router_admin_update_user( path='/{user_id}', summary='删除用户', description='Delete user by ID', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_delete_user( session: SessionDep, @@ -730,7 +730,7 @@ async def router_admin_delete_user( path='/calibrate/{user_id}', summary='校准用户存储容量', description='Calibrate the user storage.', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_admin_calibrate_storage( session: SessionDep, @@ -784,7 +784,7 @@ async def router_admin_calibrate_storage( path='/list', summary='获取文件列表', description='Get file list', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_get_file_list( session: SessionDep, @@ -858,7 +858,7 @@ async def router_admin_get_file_list( path='/preview/{file_id}', summary='预览文件', description='Preview file by ID', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_preview_file( session: SessionDep, @@ -904,13 +904,13 @@ async def router_admin_preview_file( path='/ban/{file_id}', summary='封禁/解禁文件', description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_ban_file( session: SessionDep, file_id: UUID, request: FileBanRequest, - admin: Annotated[User, Depends(AdminRequired)], + admin: Annotated[User, Depends(admin_required)], ) -> ResponseBase: """ 封禁或解禁文件。封禁后用户无法访问该文件。 @@ -949,7 +949,7 @@ async def router_admin_ban_file( path='/{file_id}', summary='删除文件', description='Delete file by ID', - dependencies=[Depends(AdminRequired)], + dependencies=[Depends(admin_required)], ) async def router_admin_delete_file( session: SessionDep, @@ -1002,7 +1002,7 @@ async def router_admin_delete_file( path='/test', summary='测试 Aria2 连接', description='Test Aria2 RPC connection', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_admin_aira2_test( request: Aria2TestRequest, @@ -1050,7 +1050,7 @@ async def router_admin_aira2_test( path='/list', summary='列出存储策略', description='List all storage policies', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_policy_list( session: SessionDep, @@ -1097,7 +1097,7 @@ async def router_policy_list( path='/test/path', summary='测试本地路径可用性', description='Test local path availability', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_policy_test_path( request: PolicyTestPathRequest, @@ -1139,7 +1139,7 @@ async def router_policy_test_path( path='/test/slave', summary='测试从机通信', description='Test slave node communication', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_policy_test_slave( request: PolicyTestSlaveRequest, @@ -1173,7 +1173,7 @@ async def router_policy_test_slave( path='/', summary='创建存储策略', description='创建新的存储策略。对于本地存储策略,会自动创建物理目录。', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_policy_add_policy( session: SessionDep, @@ -1243,7 +1243,7 @@ async def router_policy_add_policy( path='/cors', summary='创建跨域策略', description='Create CORS policy for S3 storage', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_policy_add_cors() -> ResponseBase: """ @@ -1259,7 +1259,7 @@ async def router_policy_add_cors() -> ResponseBase: path='/scf', summary='创建COS回调函数', description='Create COS callback function', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_policy_add_scf() -> ResponseBase: """ @@ -1275,7 +1275,7 @@ async def router_policy_add_scf() -> ResponseBase: path='/{policy_id}/oauth', summary='获取 OneDrive OAuth URL', description='Get OneDrive OAuth URL', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_policy_onddrive_oauth( session: SessionDep, @@ -1300,7 +1300,7 @@ async def router_policy_onddrive_oauth( path='/{policy_id}', summary='获取存储策略', description='Get storage policy by ID', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_policy_get_policy( session: SessionDep, @@ -1346,7 +1346,7 @@ async def router_policy_get_policy( path='/{policy_id}', summary='删除存储策略', description='Delete storage policy by ID', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_policy_delete_policy( session: SessionDep, @@ -1386,7 +1386,7 @@ async def router_policy_delete_policy( path='/list', summary='获取分享列表', description='Get share list', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_admin_get_share_list( session: SessionDep, @@ -1443,7 +1443,7 @@ async def router_admin_get_share_list( path='/{share_id}', summary='获取分享详情', description='Get share detail by ID', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_admin_get_share( session: SessionDep, @@ -1489,7 +1489,7 @@ async def router_admin_get_share( path='/{share_id}', summary='删除分享', description='Delete share by ID', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_admin_delete_share( session: SessionDep, @@ -1518,7 +1518,7 @@ async def router_admin_delete_share( path='/list', summary='获取任务列表', description='Get task list', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_admin_get_task_list( session: SessionDep, @@ -1580,7 +1580,7 @@ async def router_admin_get_task_list( path='/{task_id}', summary='获取任务详情', description='Get task detail by ID', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_admin_get_task( session: SessionDep, @@ -1618,7 +1618,7 @@ async def router_admin_get_task( path='/{task_id}', summary='删除任务', description='Delete task by ID', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_admin_delete_task( session: SessionDep, @@ -1647,7 +1647,7 @@ async def router_admin_delete_task( path='/list', summary='获取增值服务列表', description='Get VAS list (orders and storage packs)', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_admin_get_vas_list( session: SessionDep, @@ -1673,7 +1673,7 @@ async def router_admin_get_vas_list( path='/{vas_id}', summary='获取增值服务详情', description='Get VAS detail by ID', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_admin_get_vas( session: SessionDep, @@ -1694,7 +1694,7 @@ async def router_admin_get_vas( path='/{vas_id}', summary='删除增值服务', description='Delete VAS by ID', - dependencies=[Depends(AdminRequired)] + dependencies=[Depends(admin_required)] ) async def router_admin_delete_vas( session: SessionDep, diff --git a/routers/api/v1/callback/__init__.py b/routers/api/v1/callback/__init__.py index 681f54f..38aaac7 100644 --- a/routers/api/v1/callback/__init__.py +++ b/routers/api/v1/callback/__init__.py @@ -1,8 +1,9 @@ -from fastapi import APIRouter, Depends, Query -from fastapi.responses import PlainTextResponse, RedirectResponse -from middleware.auth import AuthRequired +from fastapi import APIRouter, Query +from fastapi.responses import PlainTextResponse + from models import ResponseBase import service.oauth +from utils import http_exceptions callback_router = APIRouter( prefix='/callback', @@ -40,7 +41,7 @@ def router_callback_qq() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the QQ OAuth callback. """ - pass + http_exceptions.raise_not_implemented() @oauth_router.get( path='/github', @@ -86,7 +87,7 @@ def router_callback_alipay() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the Alipay payment callback. """ - pass + http_exceptions.raise_not_implemented() @pay_router.post( path='/wechat', @@ -100,7 +101,7 @@ def router_callback_wechat() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the WeChat Pay payment callback. """ - pass + http_exceptions.raise_not_implemented() @pay_router.post( path='/stripe', @@ -114,7 +115,7 @@ def router_callback_stripe() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the Stripe payment callback. """ - pass + http_exceptions.raise_not_implemented() @pay_router.get( path='/easypay', @@ -128,7 +129,7 @@ def router_callback_easypay() -> PlainTextResponse: Returns: PlainTextResponse: A response containing the payment status for the EasyPay payment callback. """ - pass + http_exceptions.raise_not_implemented() # return PlainTextResponse("success", status_code=200) @pay_router.get( @@ -147,7 +148,7 @@ def router_callback_custom(order_no: str, id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the custom payment callback. """ - pass + http_exceptions.raise_not_implemented() @upload_router.post( path='/remote/{session_id}/{key}', @@ -165,7 +166,7 @@ def router_callback_remote(session_id: str, key: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the remote upload callback. """ - pass + http_exceptions.raise_not_implemented() @upload_router.post( path='/qiniu/{session_id}', @@ -182,7 +183,7 @@ def router_callback_qiniu(session_id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the Qiniu Cloud upload callback. """ - pass + http_exceptions.raise_not_implemented() @upload_router.post( path='/tencent/{session_id}', @@ -199,7 +200,7 @@ def router_callback_tencent(session_id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the Tencent Cloud upload callback. """ - pass + http_exceptions.raise_not_implemented() @upload_router.post( path='/aliyun/{session_id}', @@ -216,7 +217,7 @@ def router_callback_aliyun(session_id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the Aliyun upload callback. """ - pass + http_exceptions.raise_not_implemented() @upload_router.post( path='/upyun/{session_id}', @@ -233,7 +234,7 @@ def router_callback_upyun(session_id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the Upyun upload callback. """ - pass + http_exceptions.raise_not_implemented() @upload_router.post( path='/aws/{session_id}', @@ -250,7 +251,7 @@ def router_callback_aws(session_id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the AWS S3 upload callback. """ - pass + http_exceptions.raise_not_implemented() @upload_router.post( path='/onedrive/finish/{session_id}', @@ -267,7 +268,7 @@ def router_callback_onedrive_finish(session_id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the OneDrive upload completion callback. """ - pass + http_exceptions.raise_not_implemented() @upload_router.get( path='/ondrive/auth', @@ -281,7 +282,7 @@ def router_callback_onedrive_auth() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the OneDrive authorization callback. """ - pass + http_exceptions.raise_not_implemented() @upload_router.get( path='/google/auth', @@ -295,4 +296,4 @@ def router_callback_google_auth() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the Google OAuth completion callback. """ - pass \ No newline at end of file + http_exceptions.raise_not_implemented() \ No newline at end of file diff --git a/routers/api/v1/directory/__init__.py b/routers/api/v1/directory/__init__.py index db34710..d7b3e5c 100644 --- a/routers/api/v1/directory/__init__.py +++ b/routers/api/v1/directory/__init__.py @@ -2,7 +2,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException -from middleware.auth import AuthRequired +from middleware.auth import auth_required from middleware.dependencies import SessionDep from models import ( DirectoryCreateRequest, @@ -26,7 +26,7 @@ directory_router = APIRouter( ) async def router_directory_get( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], path: str ) -> DirectoryResponse: """ @@ -94,7 +94,7 @@ async def router_directory_get( ) async def router_directory_create( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], request: DirectoryCreateRequest ) -> ResponseBase: """ diff --git a/routers/api/v1/download/__init__.py b/routers/api/v1/download/__init__.py index d47ffdd..f6bb6db 100644 --- a/routers/api/v1/download/__init__.py +++ b/routers/api/v1/download/__init__.py @@ -1,6 +1,8 @@ from fastapi import APIRouter, Depends -from middleware.auth import AuthRequired + +from middleware.auth import auth_required from models import ResponseBase +from utils import http_exceptions download_router = APIRouter( prefix="/download", @@ -18,7 +20,7 @@ download_router.include_router(aria2_router) path='/url', summary='创建URL下载任务', description='Create a URL download task endpoint.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_aria2_url() -> ResponseBase: """ @@ -27,13 +29,13 @@ def router_aria2_url() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the URL download task. """ - pass + http_exceptions.raise_not_implemented() @aria2_router.post( path='/torrent/{id}', summary='创建种子下载任务', description='Create a torrent download task endpoint.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_aria2_torrent(id: str) -> ResponseBase: """ @@ -45,13 +47,13 @@ def router_aria2_torrent(id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the torrent download task. """ - pass + http_exceptions.raise_not_implemented() @aria2_router.put( path='/select/{gid}', summary='重新选择要下载的文件', description='Re-select files to download endpoint.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_aria2_select(gid: str) -> ResponseBase: """ @@ -63,13 +65,13 @@ def router_aria2_select(gid: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the re-selection of files. """ - pass + http_exceptions.raise_not_implemented() @aria2_router.delete( path='/task/{gid}', summary='取消或删除下载任务', description='Delete a download task endpoint.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_aria2_delete(gid: str) -> ResponseBase: """ @@ -81,13 +83,13 @@ def router_aria2_delete(gid: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the deletion of the download task. """ - pass + http_exceptions.raise_not_implemented() @aria2_router.get( '/downloading', summary='获取正在下载中的任务', description='Get currently downloading tasks endpoint.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_aria2_downloading() -> ResponseBase: """ @@ -96,13 +98,13 @@ def router_aria2_downloading() -> ResponseBase: Returns: ResponseBase: A model containing the response data for currently downloading tasks. """ - pass + http_exceptions.raise_not_implemented() @aria2_router.get( path='/finished', summary='获取已完成的任务', description='Get finished tasks endpoint.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_aria2_finished() -> ResponseBase: """ @@ -111,4 +113,4 @@ def router_aria2_finished() -> ResponseBase: Returns: ResponseBase: A model containing the response data for finished tasks. """ - pass \ No newline at end of file + http_exceptions.raise_not_implemented() \ No newline at end of file diff --git a/routers/api/v1/file/__init__.py b/routers/api/v1/file/__init__.py index 0c2dd95..4344a76 100644 --- a/routers/api/v1/file/__init__.py +++ b/routers/api/v1/file/__init__.py @@ -17,7 +17,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile from fastapi.responses import FileResponse from loguru import logger as l -from middleware.auth import AuthRequired +from middleware.auth import auth_required from middleware.dependencies import SessionDep from models import ( CreateFileRequest, @@ -35,6 +35,7 @@ from models import ( ) from service.storage import LocalStorageService from utils.JWT import SECRET_KEY +from utils import http_exceptions # ==================== 下载令牌管理 ==================== @@ -88,7 +89,7 @@ _upload_router = APIRouter(prefix="/upload") ) async def create_upload_session( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], request: CreateUploadSessionRequest, ) -> UploadSessionResponse: """ @@ -187,7 +188,7 @@ async def create_upload_session( ) async def upload_chunk( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], session_id: UUID, chunk_index: int, file: UploadFile = File(...), @@ -291,7 +292,7 @@ async def upload_chunk( ) async def delete_upload_session( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], session_id: UUID, ) -> ResponseBase: """删除上传会话端点""" @@ -320,7 +321,7 @@ async def delete_upload_session( ) async def clear_upload_sessions( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], ) -> ResponseBase: """清除所有上传会话端点""" # 获取所有会话 @@ -368,7 +369,7 @@ _download_router = APIRouter(prefix="/download") ) async def create_download_token( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], file_id: UUID, ) -> ResponseBase: """ @@ -456,7 +457,7 @@ router.include_router(_download_router) ) async def create_empty_file( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], request: CreateFileRequest, ) -> ResponseBase: """创建空白文件端点""" @@ -564,7 +565,7 @@ async def file_source_redirect(id: str, name: str) -> ResponseBase: path='/update/{id}', summary='更新文件', description='更新文件内容。', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) async def file_update(id: str) -> ResponseBase: """更新文件内容""" @@ -575,7 +576,7 @@ async def file_update(id: str) -> ResponseBase: path='/preview/{id}', summary='预览文件', description='获取文件预览。', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) async def file_preview(id: str) -> ResponseBase: """预览文件""" @@ -586,7 +587,7 @@ async def file_preview(id: str) -> ResponseBase: path='/content/{id}', summary='获取文本文件内容', description='获取文本文件内容。', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) async def file_content(id: str) -> ResponseBase: """获取文本文件内容""" @@ -597,7 +598,7 @@ async def file_content(id: str) -> ResponseBase: path='/doc/{id}', summary='获取Office文档预览地址', description='获取Office文档在线预览地址。', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) async def file_doc(id: str) -> ResponseBase: """获取Office文档预览地址""" @@ -608,7 +609,7 @@ async def file_doc(id: str) -> ResponseBase: path='/thumb/{id}', summary='获取文件缩略图', description='获取文件缩略图。', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) async def file_thumb(id: str) -> ResponseBase: """获取文件缩略图""" @@ -619,7 +620,7 @@ async def file_thumb(id: str) -> ResponseBase: path='/source/{id}', summary='取得文件外链', description='获取文件的外链地址。', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) async def file_source(id: str) -> ResponseBase: """获取文件外链""" @@ -630,7 +631,7 @@ async def file_source(id: str) -> ResponseBase: path='/archive', summary='打包要下载的文件', description='将多个文件打包下载。', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) async def file_archive() -> ResponseBase: """打包文件""" @@ -641,7 +642,7 @@ async def file_archive() -> ResponseBase: path='/compress', summary='创建文件压缩任务', description='创建文件压缩任务。', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) async def file_compress() -> ResponseBase: """创建压缩任务""" @@ -652,7 +653,7 @@ async def file_compress() -> ResponseBase: path='/decompress', summary='创建文件解压任务', description='创建文件解压任务。', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) async def file_decompress() -> ResponseBase: """创建解压任务""" @@ -663,7 +664,7 @@ async def file_decompress() -> ResponseBase: path='/relocate', summary='创建文件转移任务', description='创建文件转移任务。', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) async def file_relocate() -> ResponseBase: """创建转移任务""" @@ -674,7 +675,7 @@ async def file_relocate() -> ResponseBase: path='/search/{type}/{keyword}', summary='搜索文件', description='按关键字搜索文件。', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) async def file_search(type: str, keyword: str) -> ResponseBase: """搜索文件""" diff --git a/routers/api/v1/object/__init__.py b/routers/api/v1/object/__init__.py index 567b441..e1114af 100644 --- a/routers/api/v1/object/__init__.py +++ b/routers/api/v1/object/__init__.py @@ -12,7 +12,7 @@ from fastapi import APIRouter, Depends, HTTPException from loguru import logger as l from sqlmodel.ext.asyncio.session import AsyncSession -from middleware.auth import AuthRequired +from middleware.auth import auth_required from middleware.dependencies import SessionDep from models import ( Object, @@ -171,7 +171,7 @@ async def _copy_object_recursive( ) async def router_object_delete( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], request: ObjectDeleteRequest, ) -> ResponseBase: """ @@ -224,7 +224,7 @@ async def router_object_delete( ) async def router_object_move( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], request: ObjectMoveRequest, ) -> ResponseBase: """ @@ -302,7 +302,7 @@ async def router_object_move( ) async def router_object_copy( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], request: ObjectCopyRequest, ) -> ResponseBase: """ @@ -394,7 +394,7 @@ async def router_object_copy( ) async def router_object_rename( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], request: ObjectRenameRequest, ) -> ResponseBase: """ @@ -465,7 +465,7 @@ async def router_object_rename( ) async def router_object_property( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], id: UUID, ) -> ObjectPropertyResponse: """ @@ -501,7 +501,7 @@ async def router_object_property( ) async def router_object_property_detail( session: SessionDep, - user: Annotated[User, Depends(AuthRequired)], + user: Annotated[User, Depends(auth_required)], id: UUID, ) -> ObjectPropertyDetailResponse: """ diff --git a/routers/api/v1/share/__init__.py b/routers/api/v1/share/__init__.py index 1f00e9c..79590df 100644 --- a/routers/api/v1/share/__init__.py +++ b/routers/api/v1/share/__init__.py @@ -1,6 +1,8 @@ from fastapi import APIRouter, Depends -from middleware.auth import AuthRequired + +from middleware.auth import auth_required from models import ResponseBase +from utils import http_exceptions share_router = APIRouter( prefix='/share', @@ -23,7 +25,7 @@ def router_share_get(info: str, id: str) -> ResponseBase: Returns: dict: A dictionary containing shared content information. """ - pass + http_exceptions.raise_not_implemented() @share_router.put( path='/download/{id}', @@ -40,7 +42,7 @@ def router_share_download(id: str) -> ResponseBase: Returns: dict: A dictionary containing download session information. """ - pass + http_exceptions.raise_not_implemented() @share_router.get( path='preview/{id}', @@ -57,7 +59,7 @@ def router_share_preview(id: str) -> ResponseBase: Returns: dict: A dictionary containing preview information. """ - pass + http_exceptions.raise_not_implemented() @share_router.get( path='/doc/{id}', @@ -74,7 +76,7 @@ def router_share_doc(id: str) -> ResponseBase: Returns: dict: A dictionary containing the document preview URL. """ - pass + http_exceptions.raise_not_implemented() @share_router.get( path='/content/{id}', @@ -91,7 +93,7 @@ def router_share_content(id: str) -> ResponseBase: Returns: str: The content of the text file. """ - pass + http_exceptions.raise_not_implemented() @share_router.get( path='/list/{id}/{path:path}', @@ -109,7 +111,7 @@ def router_share_list(id: str, path: str = '') -> ResponseBase: Returns: dict: A dictionary containing directory listing information. """ - pass + http_exceptions.raise_not_implemented() @share_router.get( path='/search/{id}/{type}/{keywords}', @@ -128,7 +130,7 @@ def router_share_search(id: str, type: str, keywords: str) -> ResponseBase: Returns: dict: A dictionary containing search results. """ - pass + http_exceptions.raise_not_implemented() @share_router.post( path='/archive/{id}', @@ -145,7 +147,7 @@ def router_share_archive(id: str) -> ResponseBase: Returns: dict: A dictionary containing archive download information. """ - pass + http_exceptions.raise_not_implemented() @share_router.get( path='/readme/{id}', @@ -162,7 +164,7 @@ def router_share_readme(id: str) -> ResponseBase: Returns: str: The content of the README file. """ - pass + http_exceptions.raise_not_implemented() @share_router.get( path='/thumb/{id}/{file}', @@ -180,7 +182,7 @@ def router_share_thumb(id: str, file: str) -> ResponseBase: Returns: str: A Base64 encoded string of the thumbnail image. """ - pass + http_exceptions.raise_not_implemented() @share_router.post( path='/report/{id}', @@ -197,7 +199,7 @@ def router_share_report(id: str) -> ResponseBase: Returns: dict: A dictionary containing report submission information. """ - pass + http_exceptions.raise_not_implemented() @share_router.get( path='/search', @@ -215,7 +217,7 @@ def router_share_search_public(keywords: str, type: str = 'all') -> ResponseBase Returns: dict: A dictionary containing search results for public shares. """ - pass + http_exceptions.raise_not_implemented() ##################### # 需要登录的接口 @@ -225,7 +227,7 @@ def router_share_search_public(keywords: str, type: str = 'all') -> ResponseBase path='/', summary='创建新分享', description='Create a new share endpoint.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_share_create() -> ResponseBase: """ @@ -234,13 +236,13 @@ def router_share_create() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the new share creation. """ - pass + http_exceptions.raise_not_implemented() @share_router.get( path='/', summary='列出我的分享', description='Get a list of shares.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_share_list() -> ResponseBase: """ @@ -249,13 +251,13 @@ def router_share_list() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the list of shares. """ - pass + http_exceptions.raise_not_implemented() @share_router.post( path='/save/{id}', summary='转存他人分享', description='Save another user\'s share by ID.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_share_save(id: str) -> ResponseBase: """ @@ -267,13 +269,13 @@ def router_share_save(id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the saved share. """ - pass + http_exceptions.raise_not_implemented() @share_router.patch( path='/{id}', summary='更新分享信息', description='Update share information by ID.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_share_update(id: str) -> ResponseBase: """ @@ -285,13 +287,13 @@ def router_share_update(id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the updated share. """ - pass + http_exceptions.raise_not_implemented() @share_router.delete( path='/{id}', summary='删除分享', description='Delete a share by ID.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_share_delete(id: str) -> ResponseBase: """ @@ -303,4 +305,4 @@ def router_share_delete(id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the deleted share. """ - pass \ No newline at end of file + http_exceptions.raise_not_implemented() \ No newline at end of file diff --git a/routers/api/v1/site/__init__.py b/routers/api/v1/site/__init__.py index ace0b1a..c4f1f2e 100644 --- a/routers/api/v1/site/__init__.py +++ b/routers/api/v1/site/__init__.py @@ -1,49 +1,29 @@ from fastapi import APIRouter from sqlalchemy import and_ -import json from middleware.dependencies import SessionDep -from models import ResponseBase -from models.setting import Setting +from models import ResponseBase, Setting, SettingsType, SiteConfigResponse +from utils import http_exceptions site_router = APIRouter( prefix="/site", tags=["site"], ) - -async def _get_setting(session: SessionDep, type_: str, name: str) -> str | None: - """获取设置值""" - setting = await Setting.get(session, and_(Setting.type == type_, Setting.name == name)) - return setting.value if setting else None - - -async def _get_setting_bool(session: SessionDep, type_: str, name: str) -> bool: - """获取布尔类型设置值""" - value = await _get_setting(session, type_, name) - return value == "1" if value else False - -async def _get_setting_json(session: SessionDep, type_: str, name: str) -> dict | list | None: - """获取 JSON 类型设置值""" - value = await _get_setting(session, type_, name) - return json.loads(value) if value else None - - @site_router.get( path="/ping", summary="测试用路由", description="A simple endpoint to check if the site is up and running.", response_model=ResponseBase, ) -def router_site_ping(): +def router_site_ping() -> ResponseBase: """ Ping the site to check if it is up and running. Returns: str: A message indicating the site is running. """ - from utils.conf.appmeta import BackendVersion - return ResponseBase(data=BackendVersion) + return ResponseBase() @site_router.get( @@ -59,7 +39,7 @@ def router_site_captcha(): Returns: str: A Base64 encoded string of the captcha image. """ - pass + http_exceptions.raise_not_implemented() @site_router.get( @@ -68,38 +48,13 @@ def router_site_captcha(): description='Get the configuration file.', response_model=ResponseBase, ) -async def router_site_config(session: SessionDep): +async def router_site_config(session: SessionDep) -> SiteConfigResponse: """ Get the configuration file. Returns: dict: The site configuration. """ - return ResponseBase( - data={ - "title": await _get_setting(session, "basic", "siteName"), - "loginCaptcha": await _get_setting_bool(session, "login", "login_captcha"), - "regCaptcha": await _get_setting_bool(session, "login", "reg_captcha"), - "forgetCaptcha": await _get_setting_bool(session, "login", "forget_captcha"), - "emailActive": await _get_setting_bool(session, "login", "email_active"), - "QQLogin": None, - "themes": await _get_setting_json(session, "basic", "themes"), - "defaultTheme": await _get_setting(session, "basic", "defaultTheme"), - "score_enabled": None, - "share_score_rate": None, - "home_view_method": await _get_setting(session, "view", "home_view_method"), - "share_view_method": await _get_setting(session, "view", "share_view_method"), - "authn": await _get_setting_bool(session, "authn", "authn_enabled"), - "user": {}, - "captcha_type": None, - "captcha_ReCaptchaKey": await _get_setting(session, "captcha", "captcha_ReCaptchaKey"), - "captcha_CloudflareKey": await _get_setting(session, "captcha", "captcha_CloudflareKey"), - "captcha_tcaptcha_appid": None, - "site_notice": None, - "registerEnabled": await _get_setting_bool(session, "register", "register_enabled"), - "app_promotion": None, - "wopi_exts": None, - "app_feedback": None, - "app_forum": None, - } + return SiteConfigResponse( + title=await Setting.get(session, and_(Setting.type == SettingsType.BASIC, Setting.name == "siteName")), ) \ No newline at end of file diff --git a/routers/api/v1/slave/__init__.py b/routers/api/v1/slave/__init__.py index f62c684..ad751cb 100644 --- a/routers/api/v1/slave/__init__.py +++ b/routers/api/v1/slave/__init__.py @@ -1,7 +1,9 @@ from fastapi import APIRouter, Depends from fastapi.responses import FileResponse -from middleware.auth import AuthRequired + +from middleware.auth import auth_required from models import ResponseBase +from utils import http_exceptions slave_router = APIRouter( prefix="/slave", @@ -32,7 +34,7 @@ def router_slave_ping() -> ResponseBase: path='/post', summary='上传', description='Upload data to the server.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_slave_post(data: str) -> ResponseBase: """ @@ -44,7 +46,7 @@ def router_slave_post(data: str) -> ResponseBase: Returns: ResponseBase: A response model indicating success. """ - pass + http_exceptions.raise_not_implemented() @slave_router.get( path='/get/{speed}/{path}/{name}', @@ -62,13 +64,13 @@ def router_slave_download(speed: int, path: str, name: str) -> ResponseBase: Returns: ResponseBase: A response model containing download information. """ - pass + http_exceptions.raise_not_implemented() @slave_router.get( path='/download/{sign}', summary='根据签名下载文件', description='Download a file based on its signature.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_slave_download_by_sign(sign: str) -> FileResponse: """ @@ -80,13 +82,13 @@ def router_slave_download_by_sign(sign: str) -> FileResponse: Returns: FileResponse: A response containing the file to be downloaded. """ - pass + http_exceptions.raise_not_implemented() @slave_router.get( path='/source/{speed}/{path}/{name}', summary='获取文件外链', description='Get the external link for a file based on its signature.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_slave_source(speed: int, path: str, name: str) -> ResponseBase: """ @@ -100,13 +102,13 @@ def router_slave_source(speed: int, path: str, name: str) -> ResponseBase: Returns: ResponseBase: A response model containing the external link for the file. """ - pass + http_exceptions.raise_not_implemented() @slave_router.get( path='/source/{sign}', summary='根据签名获取文件', description='Get a file based on its signature.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_slave_source_by_sign(sign: str) -> FileResponse: """ @@ -118,13 +120,13 @@ def router_slave_source_by_sign(sign: str) -> FileResponse: Returns: FileResponse: A response containing the file to be retrieved. """ - pass + http_exceptions.raise_not_implemented() @slave_router.get( path='/thumb/{id}', summary='获取缩略图', description='Get a thumbnail image based on its ID.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_slave_thumb(id: str) -> ResponseBase: """ @@ -136,13 +138,13 @@ def router_slave_thumb(id: str) -> ResponseBase: Returns: ResponseBase: A response model containing the Base64 encoded thumbnail image. """ - pass + http_exceptions.raise_not_implemented() @slave_router.delete( path='/delete', summary='删除文件', description='Delete a file from the server.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_slave_delete(path: str) -> ResponseBase: """ @@ -154,25 +156,25 @@ def router_slave_delete(path: str) -> ResponseBase: Returns: ResponseBase: A response model indicating success or failure of the deletion. """ - pass + http_exceptions.raise_not_implemented() @slave_aria2_router.post( path='/test', summary='测试从机连接Aria2服务', description='Test the connection to the Aria2 service from the slave.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_slave_aria2_test() -> ResponseBase: """ Test the connection to the Aria2 service from the slave. """ - pass + http_exceptions.raise_not_implemented() @slave_aria2_router.get( path='/get/{gid}', summary='获取Aria2任务信息', description='Get information about an Aria2 task by its GID.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_slave_aria2_get(gid: str = None) -> ResponseBase: """ @@ -184,13 +186,13 @@ def router_slave_aria2_get(gid: str = None) -> ResponseBase: Returns: ResponseBase: A response model containing the task information. """ - pass + http_exceptions.raise_not_implemented() @slave_aria2_router.post( path='/add', summary='添加Aria2任务', description='Add a new Aria2 task.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_slave_aria2_add(gid: str, url: str, options: dict = None) -> ResponseBase: """ @@ -204,13 +206,13 @@ def router_slave_aria2_add(gid: str, url: str, options: dict = None) -> Response Returns: ResponseBase: A response model indicating success or failure of the task addition. """ - pass + http_exceptions.raise_not_implemented() @slave_aria2_router.delete( path='/remove/{gid}', summary='删除Aria2任务', description='Remove an Aria2 task by its GID.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_slave_aria2_remove(gid: str) -> ResponseBase: """ @@ -222,4 +224,4 @@ def router_slave_aria2_remove(gid: str) -> ResponseBase: Returns: ResponseBase: A response model indicating success or failure of the task removal. """ - pass \ No newline at end of file + http_exceptions.raise_not_implemented() \ No newline at end of file diff --git a/routers/api/v1/tag/__init__.py b/routers/api/v1/tag/__init__.py index 23c02b4..edc1c7a 100644 --- a/routers/api/v1/tag/__init__.py +++ b/routers/api/v1/tag/__init__.py @@ -1,6 +1,8 @@ from fastapi import APIRouter, Depends -from middleware.auth import AuthRequired +from middleware.auth import auth_required + from models import ResponseBase +from utils import http_exceptions tag_router = APIRouter( prefix='/tag', @@ -11,7 +13,7 @@ tag_router = APIRouter( path='/filter', summary='创建文件分类标签', description='Create a file classification tag.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_tag_create_filter() -> ResponseBase: """ @@ -20,13 +22,13 @@ def router_tag_create_filter() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the created tag. """ - pass + http_exceptions.raise_not_implemented() @tag_router.post( path='/link', summary='创建目录快捷方式标签', description='Create a directory shortcut tag.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_tag_create_link() -> ResponseBase: """ @@ -35,13 +37,13 @@ def router_tag_create_link() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the created tag. """ - pass + http_exceptions.raise_not_implemented() @tag_router.delete( path='/{id}', summary='删除标签', description='Delete a tag by its ID.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_tag_delete(id: str) -> ResponseBase: """ @@ -53,4 +55,4 @@ def router_tag_delete(id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the deletion operation. """ - pass \ No newline at end of file + http_exceptions.raise_not_implemented() \ No newline at end of file diff --git a/routers/api/v1/user/__init__.py b/routers/api/v1/user/__init__.py index 793c1d1..0892ae6 100644 --- a/routers/api/v1/user/__init__.py +++ b/routers/api/v1/user/__init__.py @@ -7,13 +7,14 @@ from sqlalchemy import and_ from webauthn import generate_registration_options from webauthn.helpers import options_to_json_dict from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired +from loguru import logger import models import service -from middleware.auth import AuthRequired +from middleware.auth import auth_required from middleware.dependencies import SessionDep from utils.JWT.JWT import SECRET_KEY -from utils import Password +from utils import Password, http_exceptions user_router = APIRouter( prefix="/user", @@ -23,7 +24,7 @@ user_router = APIRouter( user_settings_router = APIRouter( prefix='/user/settings', tags=["user", "user_settings"], - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) @user_router.post( @@ -42,11 +43,6 @@ async def router_user_session( 当用户启用两步验证时,需要在表单中传入 otp 参数(通过 scopes 字段传递)。 OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码 - - :raises HTTPException 401: 用户名或密码错误 - :raises HTTPException 403: 用户账号被封禁或未完成注册 - :raises HTTPException 428: 需要两步验证但未提供验证码 - :raises HTTPException 400: 两步验证码无效 """ username = form_data.username password = form_data.password @@ -62,7 +58,7 @@ async def router_user_session( otp_code = scope break - result = await service.user.Login( + result = await service.user.login( session, models.LoginRequest( username=username, @@ -71,22 +67,7 @@ async def router_user_session( ), ) - if isinstance(result, models.TokenResponse): - return result - elif result is None: - raise HTTPException(status_code=401, detail="Invalid username or password") - elif result is False: - raise HTTPException(status_code=403, detail="User account is banned or not fully registered") - elif result == "2fa_required": - raise HTTPException( - status_code=428, - detail="Two-factor authentication required", - headers={"X-2FA-Required": "true"}, - ) - elif result == "2fa_invalid": - raise HTTPException(status_code=400, detail="Invalid two-factor authentication code") - else: - raise HTTPException(status_code=500, detail="Internal server error during login") + return result @user_router.post( path='/session/refresh', @@ -97,7 +78,7 @@ async def router_user_session_refresh( session: SessionDep, request, # RefreshTokenRequest ) -> models.TokenResponse: - ... + http_exceptions.raise_not_implemented() @user_router.post( path='/', @@ -137,12 +118,14 @@ async def router_user_register( and_(models.Setting.type == models.SettingsType.REGISTER, models.Setting.name == "default_group") ) if default_group_setting is None or not default_group_setting.value: - raise HTTPException(status_code=500, detail="默认用户组设置不存在") + logger.error("默认用户组不存在") + http_exceptions.raise_internal_error() default_group_id = UUID(default_group_setting.value) default_group = await models.Group.get(session, models.Group.id == default_group_id) if not default_group: - raise HTTPException(status_code=500, detail="默认用户组不存在") + logger.error("默认用户组不存在") + http_exceptions.raise_internal_error() # 3. 创建用户 hashed_password = Password.hash(request.password) @@ -158,7 +141,8 @@ async def router_user_register( # 4. 创建以用户名命名的根目录 default_policy = await models.Policy.get(session, models.Policy.name == "本地存储") if not default_policy: - raise HTTPException(status_code=500, detail="默认存储策略不存在") + logger.error("默认存储策略不存在") + http_exceptions.raise_internal_error() await models.Object( name=new_user_username, @@ -190,7 +174,7 @@ def router_user_email_code( Returns: dict: A dictionary containing information about the password reset email. """ - pass + http_exceptions.raise_not_implemented() @user_router.get( path='/qq', @@ -204,7 +188,7 @@ def router_user_qq() -> models.ResponseBase: Returns: dict: A dictionary containing QQ login initialization information. """ - pass + http_exceptions.raise_not_implemented() @user_router.get( path='authn/{username}', @@ -213,7 +197,7 @@ def router_user_qq() -> models.ResponseBase: ) async def router_user_authn(username: str) -> models.ResponseBase: - pass + http_exceptions.raise_not_implemented() @user_router.post( path='authn/finish/{username}', @@ -230,7 +214,7 @@ def router_user_authn_finish(username: str) -> models.ResponseBase: Returns: dict: A dictionary containing WebAuthn login information. """ - pass + http_exceptions.raise_not_implemented() @user_router.get( path='/profile/{id}', @@ -247,7 +231,7 @@ def router_user_profile(id: str) -> models.ResponseBase: Returns: dict: A dictionary containing user profile information. """ - pass + http_exceptions.raise_not_implemented() @user_router.get( path='/avatar/{id}/{size}', @@ -265,7 +249,7 @@ def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase: Returns: str: A Base64 encoded string of the user avatar image. """ - pass + http_exceptions.raise_not_implemented() ##################### # 需要登录的接口 @@ -275,12 +259,12 @@ def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase: path='/me', summary='获取用户信息', description='Get user information.', - dependencies=[Depends(dependency=AuthRequired)], + dependencies=[Depends(dependency=auth_required)], response_model=models.ResponseBase, ) async def router_user_me( session: SessionDep, - user: Annotated[models.User, Depends(AuthRequired)], + user: Annotated[models.User, Depends(auth_required)], ) -> models.ResponseBase: """ 获取用户信息. @@ -319,11 +303,11 @@ async def router_user_me( path='/storage', summary='存储信息', description='Get user storage information.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) async def router_user_storage( session: SessionDep, - user: Annotated[models.user.User, Depends(AuthRequired)], + user: Annotated[models.user.User, Depends(auth_required)], ) -> models.ResponseBase: """ 获取用户存储空间信息。 @@ -353,11 +337,11 @@ async def router_user_storage( path='/authn/start', summary='WebAuthn登录初始化', description='Initialize WebAuthn login for a user.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) async def router_user_authn_start( session: SessionDep, - user: Annotated[models.user.User, Depends(AuthRequired)], + user: Annotated[models.user.User, Depends(auth_required)], ) -> models.ResponseBase: """ Initialize WebAuthn login for a user. @@ -395,7 +379,7 @@ async def router_user_authn_start( path='/authn/finish', summary='WebAuthn登录', description='Finish WebAuthn login for a user.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_user_authn_finish() -> models.ResponseBase: """ @@ -404,7 +388,7 @@ def router_user_authn_finish() -> models.ResponseBase: Returns: dict: A dictionary containing WebAuthn login information. """ - pass + http_exceptions.raise_not_implemented() @user_settings_router.get( path='/policies', @@ -418,13 +402,13 @@ def router_user_settings_policies() -> models.ResponseBase: Returns: dict: A dictionary containing available storage policies for the user. """ - pass + http_exceptions.raise_not_implemented() @user_settings_router.get( path='/nodes', summary='获取用户可选节点', description='Get user selectable nodes.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_user_settings_nodes() -> models.ResponseBase: """ @@ -433,13 +417,13 @@ def router_user_settings_nodes() -> models.ResponseBase: Returns: dict: A dictionary containing available nodes for the user. """ - pass + http_exceptions.raise_not_implemented() @user_settings_router.get( path='/tasks', summary='任务队列', description='Get user task queue.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_user_settings_tasks() -> models.ResponseBase: """ @@ -448,13 +432,13 @@ def router_user_settings_tasks() -> models.ResponseBase: Returns: dict: A dictionary containing the user's task queue information. """ - pass + http_exceptions.raise_not_implemented() @user_settings_router.get( path='/', summary='获取当前用户设定', description='Get current user settings.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_user_settings() -> models.ResponseBase: """ @@ -469,7 +453,7 @@ def router_user_settings() -> models.ResponseBase: path='/avatar', summary='从文件上传头像', description='Upload user avatar from file.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_user_settings_avatar() -> models.ResponseBase: """ @@ -478,13 +462,13 @@ def router_user_settings_avatar() -> models.ResponseBase: Returns: dict: A dictionary containing the result of the avatar upload. """ - pass + http_exceptions.raise_not_implemented() @user_settings_router.put( path='/avatar', summary='设定为Gravatar头像', description='Set user avatar to Gravatar.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_user_settings_avatar_gravatar() -> models.ResponseBase: """ @@ -493,13 +477,13 @@ def router_user_settings_avatar_gravatar() -> models.ResponseBase: Returns: dict: A dictionary containing the result of setting the Gravatar avatar. """ - pass + http_exceptions.raise_not_implemented() @user_settings_router.patch( path='/{option}', summary='更新用户设定', description='Update user settings.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_user_settings_patch(option: str) -> models.ResponseBase: """ @@ -511,16 +495,16 @@ def router_user_settings_patch(option: str) -> models.ResponseBase: Returns: dict: A dictionary containing the result of the settings update. """ - pass + http_exceptions.raise_not_implemented() @user_settings_router.get( path='/2fa', summary='获取两步验证初始化信息', description='Get two-factor authentication initialization information.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) async def router_user_settings_2fa( - user: Annotated[models.user.User, Depends(AuthRequired)], + user: Annotated[models.user.User, Depends(auth_required)], ) -> models.ResponseBase: """ Get two-factor authentication initialization information. @@ -537,11 +521,11 @@ async def router_user_settings_2fa( path='/2fa', summary='启用两步验证', description='Enable two-factor authentication.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) async def router_user_settings_2fa_enable( session: SessionDep, - user: Annotated[models.user.User, Depends(AuthRequired)], + user: Annotated[models.user.User, Depends(auth_required)], setup_token: str, code: str, ) -> models.ResponseBase: diff --git a/routers/api/v1/vas/__init__.py b/routers/api/v1/vas/__init__.py index 3fd42f1..c2c26ef 100644 --- a/routers/api/v1/vas/__init__.py +++ b/routers/api/v1/vas/__init__.py @@ -1,6 +1,8 @@ from fastapi import APIRouter, Depends -from middleware.auth import AuthRequired + +from middleware.auth import auth_required from models import ResponseBase +from utils import http_exceptions vas_router = APIRouter( prefix="/vas", @@ -11,7 +13,7 @@ vas_router = APIRouter( path='/pack', summary='获取容量包及配额信息', description='Get information about storage packs and quotas.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_vas_pack() -> ResponseBase: """ @@ -20,13 +22,13 @@ def router_vas_pack() -> ResponseBase: Returns: ResponseBase: A model containing the response data for storage packs and quotas. """ - pass + http_exceptions.raise_not_implemented() @vas_router.get( path='/product', summary='获取商品信息,同时返回支付信息', description='Get product information along with payment details.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_vas_product() -> ResponseBase: """ @@ -35,13 +37,13 @@ def router_vas_product() -> ResponseBase: Returns: ResponseBase: A model containing the response data for products and payment information. """ - pass + http_exceptions.raise_not_implemented() @vas_router.post( path='/order', summary='新建支付订单', description='Create an order for a product.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_vas_order() -> ResponseBase: """ @@ -50,13 +52,13 @@ def router_vas_order() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the created order. """ - pass + http_exceptions.raise_not_implemented() @vas_router.get( path='/order/{id}', summary='查询订单状态', description='Get information about a specific payment order by ID.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_vas_order_get(id: str) -> ResponseBase: """ @@ -68,13 +70,13 @@ def router_vas_order_get(id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the specified order. """ - pass + http_exceptions.raise_not_implemented() @vas_router.get( path='/redeem', summary='获取兑换码信息', description='Get information about a specific redemption code.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_vas_redeem(code: str) -> ResponseBase: """ @@ -86,13 +88,13 @@ def router_vas_redeem(code: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the specified redemption code. """ - pass + http_exceptions.raise_not_implemented() @vas_router.post( path='/redeem', summary='执行兑换', description='Redeem a redemption code for a product or service.', - dependencies=[Depends(AuthRequired)] + dependencies=[Depends(auth_required)] ) def router_vas_redeem_post() -> ResponseBase: """ @@ -101,4 +103,4 @@ def router_vas_redeem_post() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the redeemed code. """ - pass \ No newline at end of file + http_exceptions.raise_not_implemented() \ No newline at end of file diff --git a/routers/api/v1/webdav/__init__.py b/routers/api/v1/webdav/__init__.py index 6448c26..d9047a9 100644 --- a/routers/api/v1/webdav/__init__.py +++ b/routers/api/v1/webdav/__init__.py @@ -1,6 +1,8 @@ -from fastapi import APIRouter, Depends, Request -from middleware.auth import AuthRequired +from fastapi import APIRouter, Depends + +from middleware.auth import auth_required from models import ResponseBase +from utils import http_exceptions # WebDAV 管理路由 webdav_router = APIRouter( @@ -12,7 +14,7 @@ webdav_router = APIRouter( path='/accounts', summary='获取账号信息', description='Get account information for WebDAV.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_webdav_accounts() -> ResponseBase: """ @@ -21,13 +23,13 @@ def router_webdav_accounts() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the account information. """ - pass + http_exceptions.raise_not_implemented() @webdav_router.post( path='/accounts', summary='新建账号', description='Create a new WebDAV account.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_webdav_create_account() -> ResponseBase: """ @@ -36,13 +38,13 @@ def router_webdav_create_account() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the created account. """ - pass + http_exceptions.raise_not_implemented() @webdav_router.delete( path='/accounts/{id}', summary='删除账号', description='Delete a WebDAV account by its ID.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_webdav_delete_account(id: str) -> ResponseBase: """ @@ -54,13 +56,13 @@ def router_webdav_delete_account(id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the deletion operation. """ - pass + http_exceptions.raise_not_implemented() @webdav_router.post( path='/mount', summary='新建目录挂载', description='Create a new WebDAV mount point.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_webdav_create_mount() -> ResponseBase: """ @@ -69,13 +71,13 @@ def router_webdav_create_mount() -> ResponseBase: Returns: ResponseBase: A model containing the response data for the created mount point. """ - pass + http_exceptions.raise_not_implemented() @webdav_router.delete( path='/mount/{id}', summary='删除目录挂载', description='Delete a WebDAV mount point by its ID.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_webdav_delete_mount(id: str) -> ResponseBase: """ @@ -87,13 +89,13 @@ def router_webdav_delete_mount(id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the deletion operation. """ - pass + http_exceptions.raise_not_implemented() @webdav_router.patch( path='accounts/{id}', summary='更新账号信息', description='Update WebDAV account information by ID.', - dependencies=[Depends(AuthRequired)], + dependencies=[Depends(auth_required)], ) def router_webdav_update_account(id: str) -> ResponseBase: """ @@ -105,4 +107,4 @@ def router_webdav_update_account(id: str) -> ResponseBase: Returns: ResponseBase: A model containing the response data for the updated account. """ - pass \ No newline at end of file + http_exceptions.raise_not_implemented() \ No newline at end of file diff --git a/service/user/__init__.py b/service/user/__init__.py index 751474d..ee9b18b 100644 --- a/service/user/__init__.py +++ b/service/user/__init__.py @@ -1 +1 @@ -from .login import Login \ No newline at end of file +from .login import login \ No newline at end of file diff --git a/service/user/login.py b/service/user/login.py index 6ca2527..3f9e7e5 100644 --- a/service/user/login.py +++ b/service/user/login.py @@ -1,25 +1,19 @@ -from typing import Literal - -from loguru import logger as log -from sqlmodel.ext.asyncio.session import AsyncSession +from loguru import logger +from middleware.dependencies import SessionDep from models import LoginRequest, TokenResponse, User +from utils import http_exceptions from utils.JWT.JWT import create_access_token, create_refresh_token from utils.password.pwd import Password, PasswordStatus -async def Login( - session: AsyncSession, +async def login( + session: SessionDep, login_request: LoginRequest, -) -> TokenResponse | bool | Literal["2fa_required", "2fa_invalid"] | None: +) -> TokenResponse: """ 根据账号密码进行登录。 - 如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。 - 如果登录异常,返回 `False`(未完成注册或账号被封禁)。 - 如果登录失败,返回 `None`。 - 如果需要两步验证但未提供验证码,返回 `"2fa_required"`。 - 如果两步验证码无效,返回 `"2fa_invalid"`。 :param session: 数据库会话 :param login_request: 登录请求 @@ -38,30 +32,29 @@ async def Login( # 验证用户是否存在 if not current_user: - log.debug(f"Cannot find user with username: {login_request.username}") - return None + logger.debug(f"Cannot find user with username: {login_request.username}") + http_exceptions.raise_unauthorized("Invalid username or password") # 验证密码是否正确 if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID: - log.debug(f"Password verification failed for user: {login_request.username}") - return None + logger.debug(f"Password verification failed for user: {login_request.username}") + http_exceptions.raise_unauthorized("Invalid username or password") # 验证用户是否可登录 if not current_user.status: - # 未完成注册 or 账号已被封禁 - return False + http_exceptions.raise_forbidden("Your account is disabled") # 检查两步验证 if current_user.two_factor: # 用户已启用两步验证 if not login_request.two_fa_code: - log.debug(f"2FA required for user: {login_request.username}") - return "2fa_required" + logger.debug(f"2FA required for user: {login_request.username}") + http_exceptions.raise_precondition_required("2FA required") # 验证 OTP 码 if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID: - log.debug(f"Invalid 2FA code for user: {login_request.username}") - return "2fa_invalid" + logger.debug(f"Invalid 2FA code for user: {login_request.username}") + http_exceptions.raise_unauthorized("Invalid 2FA code") # 创建令牌 access_token, access_expire = create_access_token(data={'sub': current_user.username}) diff --git a/tests/unit/service/test_login.py b/tests/unit/service/test_login.py index f72b90f..e57f34b 100644 --- a/tests/unit/service/test_login.py +++ b/tests/unit/service/test_login.py @@ -6,7 +6,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from models.user import User, LoginRequest, TokenResponse from models.group import Group -from service.user.login import Login +from service.user.login import login from utils.password.pwd import Password @@ -86,7 +86,7 @@ async def test_login_success(db_session: AsyncSession, setup_user): password=user_data["password"] ) - result = await Login(db_session, login_request) + result = await login(db_session, login_request) assert isinstance(result, TokenResponse) assert result.access_token is not None @@ -103,7 +103,7 @@ async def test_login_user_not_found(db_session: AsyncSession): password="any_password" ) - result = await Login(db_session, login_request) + result = await login(db_session, login_request) assert result is None @@ -116,7 +116,7 @@ async def test_login_wrong_password(db_session: AsyncSession, setup_user): password="wrong_password" ) - result = await Login(db_session, login_request) + result = await login(db_session, login_request) assert result is None @@ -129,7 +129,7 @@ async def test_login_user_banned(db_session: AsyncSession, setup_banned_user): password="password" ) - result = await Login(db_session, login_request) + result = await login(db_session, login_request) assert result is False @@ -145,7 +145,7 @@ async def test_login_2fa_required(db_session: AsyncSession, setup_2fa_user): # 未提供 two_fa_code ) - result = await Login(db_session, login_request) + result = await login(db_session, login_request) assert result == "2fa_required" @@ -161,7 +161,7 @@ async def test_login_2fa_invalid(db_session: AsyncSession, setup_2fa_user): two_fa_code="000000" # 错误的验证码 ) - result = await Login(db_session, login_request) + result = await login(db_session, login_request) assert result == "2fa_invalid" @@ -184,7 +184,7 @@ async def test_login_2fa_success(db_session: AsyncSession, setup_2fa_user): two_fa_code=valid_code ) - result = await Login(db_session, login_request) + result = await login(db_session, login_request) assert isinstance(result, TokenResponse) assert result.access_token is not None @@ -202,7 +202,7 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user): password=user_data["password"] ) - result = await Login(db_session, login_request) + result = await login(db_session, login_request) assert isinstance(result, TokenResponse) @@ -227,7 +227,7 @@ async def test_login_case_sensitive_username(db_session: AsyncSession, setup_use password=user_data["password"] ) - result = await Login(db_session, login_request) + result = await login(db_session, login_request) # 应该失败,因为用户名大小写不匹配 assert result is None diff --git a/utils/__init__.py b/utils/__init__.py index 48cbfe3..c07d153 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1 +1,2 @@ -from .password.pwd import Password, PasswordStatus \ No newline at end of file +from .password.pwd import Password, PasswordStatus +from .http import http_exceptions \ No newline at end of file diff --git a/utils/http/http_exceptions.py b/utils/http/http_exceptions.py index a5d837c..708ad50 100644 --- a/utils/http/http_exceptions.py +++ b/utils/http/http_exceptions.py @@ -1,20 +1,6 @@ from typing import Any, NoReturn -from fastapi import HTTPException - -from starlette.status import ( - HTTP_400_BAD_REQUEST, - HTTP_401_UNAUTHORIZED, - HTTP_402_PAYMENT_REQUIRED, - HTTP_403_FORBIDDEN, - HTTP_404_NOT_FOUND, - HTTP_409_CONFLICT, - HTTP_429_TOO_MANY_REQUESTS, - HTTP_500_INTERNAL_SERVER_ERROR, - HTTP_501_NOT_IMPLEMENTED, - HTTP_503_SERVICE_UNAVAILABLE, - HTTP_504_GATEWAY_TIMEOUT, -) +from fastapi import HTTPException, status # --- 400 --- @@ -24,50 +10,54 @@ def ensure_request_param(to_check: Any, detail: str) -> None: This function returns None if the check passes. """ if not to_check: - raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail) def raise_bad_request(detail: str = '') -> NoReturn: """Raises an HTTP 400 Bad Request exception.""" - raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail) def raise_unauthorized(detail: str) -> NoReturn: """Raises an HTTP 401 Unauthorized exception.""" - raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=detail) def raise_insufficient_quota(detail: str = "积分不足,请充值") -> NoReturn: """Raises an HTTP 402 Payment Required exception.""" - raise HTTPException(status_code=HTTP_402_PAYMENT_REQUIRED, detail=detail) + raise HTTPException(status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=detail) def raise_forbidden(detail: str) -> NoReturn: """Raises an HTTP 403 Forbidden exception.""" - raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail) def raise_not_found(detail: str) -> NoReturn: """Raises an HTTP 404 Not Found exception.""" - raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail) def raise_conflict(detail: str) -> NoReturn: """Raises an HTTP 409 Conflict exception.""" - raise HTTPException(status_code=HTTP_409_CONFLICT, detail=detail) + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=detail) + +def raise_precondition_required(detail: str) -> NoReturn: + """Raises an HTTP 428 Precondition required exception.""" + raise HTTPException(status_code=status.HTTP_428_PRECONDITION_REQUIRED, detail=detail) def raise_too_many_requests(detail: str) -> NoReturn: """Raises an HTTP 429 Too Many Requests exception.""" - raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail=detail) + raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=detail) # --- 500 --- def raise_internal_error(detail: str = "服务器出现故障,请稍后再试或联系管理员") -> NoReturn: """Raises an HTTP 500 Internal Server Error exception.""" - raise HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail) + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=detail) def raise_not_implemented(detail: str = "尚未支持这种方法") -> NoReturn: """Raises an HTTP 501 Not Implemented exception.""" - raise HTTPException(status_code=HTTP_501_NOT_IMPLEMENTED, detail=detail) + raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=detail) def raise_service_unavailable(detail: str) -> NoReturn: """Raises an HTTP 503 Service Unavailable exception.""" - raise HTTPException(status_code=HTTP_503_SERVICE_UNAVAILABLE, detail=detail) + raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=detail) def raise_gateway_timeout(detail: str) -> NoReturn: """Raises an HTTP 504 Gateway Timeout exception.""" - raise HTTPException(status_code=HTTP_504_GATEWAY_TIMEOUT, detail=detail) + raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail=detail) diff --git a/utils/password/pwd.py b/utils/password/pwd.py index 99cdda3..282bb30 100644 --- a/utils/password/pwd.py +++ b/utils/password/pwd.py @@ -1,4 +1,5 @@ import secrets + from loguru import logger from argon2 import PasswordHasher from argon2.exceptions import VerifyMismatchError @@ -104,10 +105,11 @@ class Password: @staticmethod async def generate_totp( - username: str + *args, **kwargs ) -> TwoFactorResponse: """ 生成 TOTP 密钥和对应的 URI,用于两步验证。 + 所有的参数将会给到 `pyotp.totp.TOTP` :return: 包含 TOTP 密钥和 URI 的元组 """ @@ -121,8 +123,7 @@ class Password: salt="2fa-setup-salt" ) - otp_uri = pyotp.totp.TOTP(secret).provisioning_uri( - name=username, + otp_uri = pyotp.totp.TOTP(secret, *args, **kwargs).provisioning_uri( issuer_name=appmeta.APP_NAME ) @@ -134,17 +135,21 @@ class Password: @staticmethod def verify_totp( secret: str, - code: str + code: int, + *args, **kwargs ) -> PasswordStatus: """ 验证 TOTP 验证码。 :param secret: TOTP 密钥(Base32 编码) :param code: 用户输入的 6 位验证码 + :param args: 传入 `totp.verify` 的参数 + :param kwargs: 传入 `totp.verify` 的参数 + :return: 验证是否成功 """ totp = pyotp.TOTP(secret) - if totp.verify(code): + if totp.verify(otp=str(code), *args, **kwargs): return PasswordStatus.VALID else: return PasswordStatus.INVALID \ No newline at end of file