From 1ecc0fdc1ca918b70dc39c6fb1e25daaa248b578 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=8E=E5=B0=8F=E4=B8=98?= Date: Sun, 15 Feb 2026 17:07:20 +0800 Subject: [PATCH] feat: implement source link endpoints and enforce policy rules - Add POST/GET source link endpoints for file sharing via permanent URLs - Enforce max_size check in PATCH /file/content to prevent size limit bypass - Support is_private (proxy) vs public (302 redirect) storage modes - Replace all ResponseBase(data=...) with proper DTOs or 204 responses - Add 18 integration tests for source link and policy rule enforcement Co-Authored-By: Claude Opus 4.6 --- routers/api/v1/admin/__init__.py | 29 +- routers/api/v1/admin/group/__init__.py | 6 +- routers/api/v1/admin/policy/__init__.py | 163 ++++-- routers/api/v1/admin/share/__init__.py | 90 ++- routers/api/v1/admin/task/__init__.py | 72 ++- routers/api/v1/file/__init__.py | 270 +++++++-- routers/api/v1/object/__init__.py | 5 +- routers/api/v1/site/__init__.py | 2 + routers/api/v1/slave/__init__.py | 8 +- sqlmodels/migration.py | 2 + .../integration/api/test_file_source_link.py | 554 ++++++++++++++++++ 11 files changed, 1051 insertions(+), 150 deletions(-) create mode 100644 tests/integration/api/test_file_source_link.py diff --git a/routers/api/v1/admin/__init__.py b/routers/api/v1/admin/__init__.py index d0043e4..e9fdd5d 100644 --- a/routers/api/v1/admin/__init__.py +++ b/routers/api/v1/admin/__init__.py @@ -283,16 +283,17 @@ async def router_admin_get_settings( path='/test', summary='测试 Aria2 连接', description='Test Aria2 RPC connection', - dependencies=[Depends(admin_required)] + dependencies=[Depends(admin_required)], + status_code=204, ) async def router_admin_aira2_test( request: Aria2TestRequest, -) -> ResponseBase: +) -> None: """ 测试 Aria2 RPC 连接。 :param request: 测试请求 - :return: 测试结果 + :raises HTTPException: 连接失败时抛出 400 """ import aiohttp @@ -307,22 +308,18 @@ async def router_admin_aira2_test( async with aiohttp.ClientSession() as client: async with client.post(request.rpc_url, json=payload, timeout=aiohttp.ClientTimeout(total=10)) as resp: if resp.status != 200: - return ResponseBase( - code=400, - msg=f"连接失败,HTTP {resp.status}" + raise HTTPException( + status_code=400, + detail=f"连接失败,HTTP {resp.status}", ) result = await resp.json() if "error" in result: - return ResponseBase( - code=400, - msg=f"Aria2 错误: {result['error']['message']}" + raise HTTPException( + status_code=400, + detail=f"Aria2 错误: {result['error']['message']}", ) - - version = result.get("result", {}).get("version", "unknown") - return ResponseBase(data={ - "connected": True, - "version": version, - }) + except HTTPException: + raise except Exception as e: - return ResponseBase(code=400, msg=f"连接失败: {str(e)}") \ No newline at end of file + raise HTTPException(status_code=400, detail=f"连接失败: {str(e)}") \ No newline at end of file diff --git a/routers/api/v1/admin/group/__init__.py b/routers/api/v1/admin/group/__init__.py index dbfba35..fd18dbc 100644 --- a/routers/api/v1/admin/group/__init__.py +++ b/routers/api/v1/admin/group/__init__.py @@ -55,7 +55,7 @@ async def router_admin_get_groups( async def router_admin_get_group( session: SessionDep, group_id: UUID, -) -> ResponseBase: +) -> GroupDetailResponse: """ 根据用户组ID获取用户组详细信息。 @@ -71,9 +71,7 @@ async def router_admin_get_group( # 直接访问已加载的关系,无需额外查询 policies = group.policies user_count = await User.count(session, User.group_id == group_id) - response = GroupDetailResponse.from_group(group, user_count, policies) - - return ResponseBase(data=response.model_dump()) + return GroupDetailResponse.from_group(group, user_count, policies) @admin_group_router.get( diff --git a/routers/api/v1/admin/policy/__init__.py b/routers/api/v1/admin/policy/__init__.py index 81772f1..3eaf0ad 100644 --- a/routers/api/v1/admin/policy/__init__.py +++ b/routers/api/v1/admin/policy/__init__.py @@ -1,3 +1,4 @@ +from typing import Any from uuid import UUID from fastapi import APIRouter, Depends, HTTPException @@ -8,7 +9,8 @@ from middleware.auth import admin_required from middleware.dependencies import SessionDep, TableViewRequestDep from sqlmodels import ( Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase, - ListResponse, Object, ) + ListResponse, Object, +) from sqlmodel_ext import SQLModelBase from service.storage import DirectoryCreationError, LocalStorageService @@ -17,6 +19,78 @@ admin_policy_router = APIRouter( tags=['admin', 'admin_policy'] ) + +class PathTestResponse(SQLModelBase): + """路径测试响应""" + + path: str + """解析后的路径""" + + is_exists: bool + """路径是否存在""" + + is_writable: bool + """路径是否可写""" + + +class PolicyGroupInfo(SQLModelBase): + """策略关联的用户组信息""" + + id: str + """用户组UUID""" + + name: str + """用户组名称""" + + +class PolicyDetailResponse(SQLModelBase): + """存储策略详情响应""" + + id: str + """策略UUID""" + + name: str + """策略名称""" + + type: str + """策略类型""" + + server: str | None + """服务器地址""" + + bucket_name: str | None + """存储桶名称""" + + is_private: bool + """是否私有""" + + base_url: str | None + """基础URL""" + + max_size: int + """最大文件尺寸""" + + auto_rename: bool + """是否自动重命名""" + + dir_name_rule: str | None + """目录命名规则""" + + file_name_rule: str | None + """文件命名规则""" + + is_origin_link_enable: bool + """是否启用外链""" + + options: dict[str, Any] | None + """策略选项""" + + groups: list[PolicyGroupInfo] + """关联的用户组""" + + object_count: int + """使用此策略的对象数量""" + class PolicyTestPathRequest(SQLModelBase): """测试本地路径请求 DTO""" @@ -70,7 +144,7 @@ async def router_policy_list( ) async def router_policy_test_path( request: PolicyTestPathRequest, -) -> ResponseBase: +) -> PathTestResponse: """ 测试本地存储路径是否可用。 @@ -97,22 +171,23 @@ async def router_policy_test_path( except Exception: pass - return ResponseBase(data={ - "path": str(path), - "exists": is_exists, - "writable": is_writable, - }) + return PathTestResponse( + path=str(path), + is_exists=is_exists, + is_writable=is_writable, + ) @admin_policy_router.post( path='/test/slave', summary='测试从机通信', description='Test slave node communication', - dependencies=[Depends(admin_required)] + dependencies=[Depends(admin_required)], + status_code=204, ) async def router_policy_test_slave( request: PolicyTestSlaveRequest, -) -> ResponseBase: +) -> None: """ 测试从机RPC通信。 @@ -129,25 +204,28 @@ async def router_policy_test_slave( timeout=aiohttp.ClientTimeout(total=10) ) as resp: if resp.status == 200: - return ResponseBase(data={"connected": True}) + return else: - return ResponseBase( - code=400, - msg=f"从机响应错误,HTTP {resp.status}" + raise HTTPException( + status_code=400, + detail=f"从机响应错误,HTTP {resp.status}", ) + except HTTPException: + raise except Exception as e: - return ResponseBase(code=400, msg=f"连接失败: {str(e)}") + raise HTTPException(status_code=400, detail=f"连接失败: {str(e)}") @admin_policy_router.post( path='/', summary='创建存储策略', description='创建新的存储策略。对于本地存储策略,会自动创建物理目录。', - dependencies=[Depends(admin_required)] + dependencies=[Depends(admin_required)], + status_code=204, ) async def router_policy_add_policy( session: SessionDep, request: PolicyCreateRequest, -) -> ResponseBase: +) -> None: """ 创建存储策略端点 @@ -199,14 +277,7 @@ async def router_policy_add_policy( raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}") # 保存到数据库 - policy = await policy.save(session) - - return ResponseBase(data={ - "id": str(policy.id), - "name": policy.name, - "type": policy.type.value, - "server": policy.server, - }) + await policy.save(session) @admin_policy_router.post( path='/cors', @@ -274,7 +345,7 @@ async def router_policy_onddrive_oauth( async def router_policy_get_policy( session: SessionDep, policy_id: UUID, -) -> ResponseBase: +) -> PolicyDetailResponse: """ 获取存储策略详情。 @@ -292,35 +363,36 @@ async def router_policy_get_policy( # 统计使用此策略的对象数量 object_count = await Object.count(session, Object.policy_id == policy_id) - return ResponseBase(data={ - "id": str(policy.id), - "name": policy.name, - "type": policy.type.value, - "server": policy.server, - "bucket_name": policy.bucket_name, - "is_private": policy.is_private, - "base_url": policy.base_url, - "max_size": policy.max_size, - "auto_rename": policy.auto_rename, - "dir_name_rule": policy.dir_name_rule, - "file_name_rule": policy.file_name_rule, - "is_origin_link_enable": policy.is_origin_link_enable, - "options": policy.options.model_dump() if policy.options else None, - "groups": [{"id": str(g.id), "name": g.name} for g in groups], - "object_count": object_count, - }) + return PolicyDetailResponse( + id=str(policy.id), + name=policy.name, + type=policy.type.value, + server=policy.server, + bucket_name=policy.bucket_name, + is_private=policy.is_private, + base_url=policy.base_url, + max_size=policy.max_size, + auto_rename=policy.auto_rename, + dir_name_rule=policy.dir_name_rule, + file_name_rule=policy.file_name_rule, + is_origin_link_enable=policy.is_origin_link_enable, + options=policy.options.model_dump() if policy.options else None, + groups=[PolicyGroupInfo(id=str(g.id), name=g.name) for g in groups], + object_count=object_count, + ) @admin_policy_router.delete( path='/{policy_id}', summary='删除存储策略', description='Delete storage policy by ID', - dependencies=[Depends(admin_required)] + dependencies=[Depends(admin_required)], + status_code=204, ) async def router_policy_delete_policy( session: SessionDep, policy_id: UUID, -) -> ResponseBase: +) -> None: """ 删除存储策略。 @@ -345,5 +417,4 @@ async def router_policy_delete_policy( policy_name = policy.name await Policy.delete(session, policy) - l.info(f"管理员删除了存储策略: {policy_name}") - return ResponseBase(data={"deleted": True}) \ No newline at end of file + l.info(f"管理员删除了存储策略: {policy_name}") \ No newline at end of file diff --git a/routers/api/v1/admin/share/__init__.py b/routers/api/v1/admin/share/__init__.py index daeecfb..9fa17dd 100644 --- a/routers/api/v1/admin/share/__init__.py +++ b/routers/api/v1/admin/share/__init__.py @@ -1,3 +1,4 @@ +from datetime import datetime from uuid import UUID from fastapi import APIRouter, Depends, HTTPException @@ -6,8 +7,53 @@ from loguru import logger as l from middleware.auth import admin_required from middleware.dependencies import SessionDep, TableViewRequestDep from sqlmodels import ( - ResponseBase, ListResponse, - Share, AdminShareListItem, ) + ListResponse, + Share, AdminShareListItem, +) +from sqlmodel_ext import SQLModelBase + + +class ShareDetailResponse(SQLModelBase): + """分享详情响应""" + + id: UUID + """分享UUID""" + + code: str + """分享码""" + + views: int + """浏览次数""" + + downloads: int + """下载次数""" + + remain_downloads: int | None + """剩余下载次数""" + + expires: datetime | None + """过期时间""" + + preview_enabled: bool + """是否启用预览""" + + score: int + """评分""" + + has_password: bool + """是否有密码""" + + user_id: str + """用户UUID""" + + username: str | None + """用户名""" + + object: dict | None + """关联对象信息""" + + created_at: str + """创建时间""" admin_share_router = APIRouter( prefix='/share', @@ -54,7 +100,7 @@ async def router_admin_get_share_list( async def router_admin_get_share( session: SessionDep, share_id: UUID, -) -> ResponseBase: +) -> ShareDetailResponse: """ 获取分享详情。 @@ -69,38 +115,39 @@ async def router_admin_get_share( obj = await share.awaitable_attrs.object user = await share.awaitable_attrs.user - return ResponseBase(data={ - "id": share.id, - "code": share.code, - "views": share.views, - "downloads": share.downloads, - "remain_downloads": share.remain_downloads, - "expires": share.expires.isoformat() if share.expires else None, - "preview_enabled": share.preview_enabled, - "score": share.score, - "has_password": bool(share.password), - "user_id": str(share.user_id), - "username": user.email if user else None, - "object": { + return ShareDetailResponse( + id=share.id, + code=share.code, + views=share.views, + downloads=share.downloads, + remain_downloads=share.remain_downloads, + expires=share.expires, + preview_enabled=share.preview_enabled, + score=share.score, + has_password=bool(share.password), + user_id=str(share.user_id), + username=user.email if user else None, + object={ "id": str(obj.id), "name": obj.name, "type": obj.type.value, "size": obj.size, } if obj else None, - "created_at": share.created_at.isoformat(), - }) + created_at=share.created_at.isoformat(), + ) @admin_share_router.delete( path='/{share_id}', summary='删除分享', description='Delete share by ID', - dependencies=[Depends(admin_required)] + dependencies=[Depends(admin_required)], + status_code=204, ) async def router_admin_delete_share( session: SessionDep, share_id: UUID, -) -> ResponseBase: +) -> None: """ 删除分享。 @@ -114,5 +161,4 @@ async def router_admin_delete_share( await Share.delete(session, share) - l.info(f"管理员删除了分享: {share.code}") - return ResponseBase(data={"deleted": True}) \ No newline at end of file + l.info(f"管理员删除了分享: {share.code}") \ No newline at end of file diff --git a/routers/api/v1/admin/task/__init__.py b/routers/api/v1/admin/task/__init__.py index f32246f..0f035e6 100644 --- a/routers/api/v1/admin/task/__init__.py +++ b/routers/api/v1/admin/task/__init__.py @@ -1,3 +1,4 @@ +from typing import Any from uuid import UUID from fastapi import APIRouter, Depends, HTTPException @@ -6,9 +7,44 @@ from loguru import logger as l from middleware.auth import admin_required from middleware.dependencies import SessionDep, TableViewRequestDep from sqlmodels import ( - ResponseBase, ListResponse, + ListResponse, Task, TaskSummary, ) +from sqlmodel_ext import SQLModelBase + + +class TaskDetailResponse(SQLModelBase): + """任务详情响应""" + + id: int + """任务ID""" + + status: int + """任务状态""" + + type: int + """任务类型""" + + progress: int + """任务进度""" + + error: str | None + """错误信息""" + + user_id: str + """用户UUID""" + + username: str | None + """用户名""" + + props: dict[str, Any] | None + """任务属性""" + + created_at: str + """创建时间""" + + updated_at: str + """更新时间""" admin_task_router = APIRouter( prefix='/task', @@ -67,7 +103,7 @@ async def router_admin_get_task_list( async def router_admin_get_task( session: SessionDep, task_id: int, -) -> ResponseBase: +) -> TaskDetailResponse: """ 获取任务详情。 @@ -82,30 +118,31 @@ async def router_admin_get_task( user = await task.awaitable_attrs.user props = await task.awaitable_attrs.props - return ResponseBase(data={ - "id": task.id, - "status": task.status, - "type": task.type, - "progress": task.progress, - "error": task.error, - "user_id": str(task.user_id), - "username": user.email if user else None, - "props": props.model_dump() if props else None, - "created_at": task.created_at.isoformat(), - "updated_at": task.updated_at.isoformat(), - }) + return TaskDetailResponse( + id=task.id, + status=task.status, + type=task.type, + progress=task.progress, + error=task.error, + user_id=str(task.user_id), + username=user.email if user else None, + props=props.model_dump() if props else None, + created_at=task.created_at.isoformat(), + updated_at=task.updated_at.isoformat(), + ) @admin_task_router.delete( path='/{task_id}', summary='删除任务', description='Delete task by ID', - dependencies=[Depends(admin_required)] + dependencies=[Depends(admin_required)], + status_code=204, ) async def router_admin_delete_task( session: SessionDep, task_id: int, -) -> ResponseBase: +) -> None: """ 删除任务。 @@ -119,5 +156,4 @@ async def router_admin_delete_task( await Task.delete(session, task) - l.info(f"管理员删除了任务: {task_id}") - return ResponseBase(data={"deleted": True}) \ No newline at end of file + l.info(f"管理员删除了任务: {task_id}") \ No newline at end of file diff --git a/routers/api/v1/file/__init__.py b/routers/api/v1/file/__init__.py index e675d85..ddb1c05 100644 --- a/routers/api/v1/file/__init__.py +++ b/routers/api/v1/file/__init__.py @@ -15,7 +15,7 @@ from uuid import UUID import whatthepatch from fastapi import APIRouter, Depends, File, HTTPException, UploadFile -from fastapi.responses import FileResponse +from fastapi.responses import FileResponse, RedirectResponse from loguru import logger as l from sqlmodel_ext import SQLModelBase from whatthepatch.exceptions import HunkApplyException @@ -37,6 +37,7 @@ from sqlmodels import ( ResponseBase, Setting, SettingsType, + SourceLink, UploadChunkResponse, UploadSession, UploadSessionResponse, @@ -94,6 +95,41 @@ class PatchContentResponse(ResponseBase): new_size: int """新文件字节大小""" + +class SourceLinkResponse(ResponseBase): + """外链响应""" + + url: str + """外链地址(永久有效,/source/ 端点自动 302 适配存储策略)""" + + downloads: int + """历史下载次数""" + + +def _check_policy_size_limit(policy: Policy, file_size: int) -> None: + """ + 检查文件大小是否超过策略限制 + + :param policy: 存储策略 + :param file_size: 文件大小(字节) + :raises HTTPException: 413 Payload Too Large + """ + if policy.max_size > 0 and file_size > policy.max_size: + raise HTTPException( + status_code=413, + detail=f"文件大小超过限制 ({policy.max_size} bytes)", + ) + + +async def _get_site_url(session: SessionDep) -> str: + """获取站点 URL""" + site_url_setting = await Setting.get( + session, + (Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"), + ) + return site_url_setting.value if site_url_setting else "http://localhost" + + # ==================== 主路由 ==================== router = APIRouter(prefix="/file", tags=["file"]) @@ -149,11 +185,7 @@ async def create_upload_session( raise HTTPException(status_code=404, detail="存储策略不存在") # 验证文件大小限制 - if policy.max_size > 0 and request.file_size > policy.max_size: - raise HTTPException( - status_code=413, - detail=f"文件大小超过限制 ({policy.max_size} bytes)" - ) + _check_policy_size_limit(policy, request.file_size) # 检查存储配额(auth_required 已预加载 user.group) max_storage = user.group.max_storage @@ -344,12 +376,13 @@ async def upload_chunk( path='/{session_id}', summary='删除上传会话', description='取消上传并删除会话及已上传的临时文件。', + status_code=204, ) async def delete_upload_session( session: SessionDep, user: Annotated[User, Depends(auth_required)], session_id: UUID, -) -> ResponseBase: +) -> None: """删除上传会话端点""" upload_session = await UploadSession.get(session, UploadSession.id == session_id) if not upload_session or upload_session.owner_id != user.id: @@ -366,18 +399,17 @@ async def delete_upload_session( l.info(f"删除上传会话: {session_id}") - return ResponseBase(data={"deleted": True}) - @_upload_router.delete( path='/', summary='清除所有上传会话', description='清除当前用户的所有上传会话。', + status_code=204, ) async def clear_upload_sessions( session: SessionDep, user: Annotated[User, Depends(auth_required)], -) -> ResponseBase: +) -> None: """清除所有上传会话端点""" # 获取所有会话 sessions = await UploadSession.get( @@ -399,8 +431,6 @@ async def clear_upload_sessions( l.info(f"清除用户 {user.id} 的所有上传会话,共 {deleted_count} 个") - return ResponseBase(data={"deleted": deleted_count}) - @_upload_router.get( path='/archive/{session_id}/archive.zip', @@ -527,12 +557,13 @@ router.include_router(viewers_router) path='/create', summary='创建空白文件', description='在指定目录下创建空白文件。', + status_code=204, ) async def create_empty_file( session: SessionDep, user: Annotated[User, Depends(auth_required)], request: CreateFileRequest, -) -> ResponseBase: +) -> None: """创建空白文件端点""" # 存储 user.id,避免后续 save() 导致 user 过期后无法访问 user_id = user.id @@ -608,12 +639,6 @@ async def create_empty_file( l.info(f"创建空白文件: {file_object.name}, id={file_object.id}") - return ResponseBase(data={ - "id": str(file_object.id), - "name": file_object.name, - "size": file_object.size, - }) - # ==================== WOPI 会话 ==================== @@ -724,28 +749,145 @@ async def create_wopi_session( # ==================== 文件外链(保留原有端点结构) ==================== +async def _validate_source_link( + session: SessionDep, + file_id: UUID, +) -> tuple[Object, SourceLink, PhysicalFile, Policy]: + """ + 验证外链访问的完整链路 + + :returns: (file_obj, link, physical_file, policy) + :raises HTTPException: 验证失败 + """ + file_obj = await Object.get( + session, + (Object.id == file_id) & (Object.deleted_at == None), + ) + if not file_obj: + http_exceptions.raise_not_found("文件不存在") + + if not file_obj.is_file: + http_exceptions.raise_bad_request("对象不是文件") + + if file_obj.is_banned: + http_exceptions.raise_banned() + + policy = await Policy.get(session, Policy.id == file_obj.policy_id) + if not policy: + http_exceptions.raise_internal_error("存储策略不存在") + + if not policy.is_origin_link_enable: + http_exceptions.raise_forbidden("当前存储策略未启用外链功能") + + # SourceLink 必须存在(只有主动创建过外链的文件才能通过外链访问) + link: SourceLink | None = await SourceLink.get( + session, + SourceLink.object_id == file_id, + ) + if not link: + http_exceptions.raise_not_found("外链不存在") + + physical_file = await file_obj.awaitable_attrs.physical_file + if not physical_file or not physical_file.storage_path: + http_exceptions.raise_internal_error("文件存储路径丢失") + + return file_obj, link, physical_file, policy + + @router.get( - path='/get/{id}/{name}', + path='/get/{file_id}/{name}', summary='文件外链(直接输出文件数据)', - description='通过外链直接获取文件内容。', + description='通过外链直接获取文件内容,公开访问无需认证。', ) async def file_get( session: SessionDep, - id: str, + file_id: UUID, name: str, ) -> FileResponse: - """文件外链端点(直接输出)""" - raise HTTPException(status_code=501, detail="外链功能暂未实现") + """ + 文件外链端点(直接输出) + + 公开访问,无需认证。通过 UUID 定位文件,URL 中的 name 仅用于 Content-Disposition。 + + 错误处理: + - 403: 存储策略未启用外链 / 文件被封禁 + - 404: 文件不存在 / 外链不存在 / 物理文件不存在 + """ + file_obj, link, physical_file, policy = await _validate_source_link(session, file_id) + + if policy.type != PolicyType.LOCAL: + http_exceptions.raise_not_implemented("S3 存储暂未实现") + + storage_service = LocalStorageService(policy) + if not await storage_service.file_exists(physical_file.storage_path): + http_exceptions.raise_not_found("物理文件不存在") + + # 缓存物理路径(save 后对象属性会过期) + file_path = physical_file.storage_path + + # 递增下载次数 + link.downloads += 1 + await link.save(session) + + return FileResponse( + path=file_path, + filename=name, + media_type="application/octet-stream", + ) @router.get( - path='/source/{id}/{name}', - summary='文件外链(301跳转)', - description='通过外链获取文件重定向地址。', + path='/source/{file_id}/{name}', + summary='文件外链(302重定向或直接输出)', + description='通过外链获取文件,公有存储 302 重定向,私有存储直接输出。', + response_model=None, ) -async def file_source_redirect(id: str, name: str) -> ResponseBase: - """文件外链端点(301跳转)""" - raise HTTPException(status_code=501, detail="外链功能暂未实现") +async def file_source_redirect( + session: SessionDep, + file_id: UUID, + name: str, +) -> FileResponse | RedirectResponse: + """ + 文件外链端点(重定向/直接输出) + + 公开访问,无需认证。根据 policy.is_private 决定服务方式: + - is_private=False 且 base_url 非空:302 临时重定向 + - is_private=True 或 base_url 为空:直接返回文件内容 + + 错误处理: + - 403: 存储策略未启用外链 / 文件被封禁 + - 404: 文件不存在 / 外链不存在 / 物理文件不存在 + """ + file_obj, link, physical_file, policy = await _validate_source_link(session, file_id) + + if policy.type != PolicyType.LOCAL: + http_exceptions.raise_not_implemented("S3 存储暂未实现") + + storage_service = LocalStorageService(policy) + if not await storage_service.file_exists(physical_file.storage_path): + http_exceptions.raise_not_found("物理文件不存在") + + # 缓存所有需要的值(save 后对象属性会过期) + file_path = physical_file.storage_path + is_private = policy.is_private + base_url = policy.base_url + + # 递增下载次数 + link.downloads += 1 + await link.save(session) + + # 公有存储:302 重定向到 base_url + if not is_private and base_url: + relative_path = storage_service.get_relative_path(file_path) + redirect_url = f"{base_url}/{relative_path}" + return RedirectResponse(url=redirect_url, status_code=302) + + # 私有存储或 base_url 为空:通过应用代理文件 + return FileResponse( + path=file_path, + filename=name, + media_type="application/octet-stream", + ) @router.put( @@ -903,6 +1045,9 @@ async def patch_file_content( new_bytes = new_text.encode('utf-8') + # 验证文件大小限制 + _check_policy_size_limit(policy, len(new_bytes)) + # 写入文件 await storage_service.write_file(storage_path, new_bytes) @@ -939,14 +1084,65 @@ async def file_thumb(id: str) -> ResponseBase: @router.post( - path='/source/{id}', - summary='取得文件外链', - description='获取文件的外链地址。', - dependencies=[Depends(auth_required)] + path='/source/{file_id}', + summary='创建/获取文件外链', + description='为指定文件创建或获取已有的外链地址。', ) -async def file_source(id: str) -> ResponseBase: - """获取文件外链""" - raise HTTPException(status_code=501, detail="外链功能暂未实现") +async def file_source( + session: SessionDep, + user: Annotated[User, Depends(auth_required)], + file_id: UUID, +) -> SourceLinkResponse: + """ + 创建/获取文件外链端点 + + 检查 policy 是否启用外链,查找或创建 SourceLink,返回外链 URL。 + + 认证:JWT token 必填 + + 错误处理: + - 403: 存储策略未启用外链 + - 404: 文件不存在 + """ + file_obj = await Object.get( + session, + (Object.id == file_id) & (Object.deleted_at == None), + ) + if not file_obj or file_obj.owner_id != user.id: + http_exceptions.raise_not_found("文件不存在") + + if not file_obj.is_file: + http_exceptions.raise_bad_request("对象不是文件") + + if file_obj.is_banned: + http_exceptions.raise_banned() + + policy = await Policy.get(session, Policy.id == file_obj.policy_id) + if not policy: + http_exceptions.raise_internal_error("存储策略不存在") + + if not policy.is_origin_link_enable: + http_exceptions.raise_forbidden("当前存储策略未启用外链功能") + + # 缓存文件名(save 后对象属性会过期) + file_name = file_obj.name + + # 查找已有 SourceLink + link: SourceLink | None = await SourceLink.get( + session, + (SourceLink.object_id == file_id) & (SourceLink.name == file_name), + ) + if not link: + link = SourceLink( + name=file_name, + object_id=file_id, + ) + link = await link.save(session) + + site_url = await _get_site_url(session) + url = f"{site_url}/api/v1/file/source/{file_id}/{file_name}" + + return SourceLinkResponse(url=url, downloads=link.downloads) @router.post( diff --git a/routers/api/v1/object/__init__.py b/routers/api/v1/object/__init__.py index 560bc0e..7e7fcc4 100644 --- a/routers/api/v1/object/__init__.py +++ b/routers/api/v1/object/__init__.py @@ -26,7 +26,6 @@ from sqlmodels import ( PhysicalFile, Policy, PolicyType, - ResponseBase, User, ) from service.storage import ( @@ -439,9 +438,9 @@ async def router_object_rename( if '/' in new_name or '\\' in new_name: raise HTTPException(status_code=400, detail="名称不能包含斜杠") - # 如果名称没有变化,直接返回成功 + # 如果名称没有变化,直接返回 if obj.name == new_name: - return ResponseBase(data={"success": True}) + return # noqa: already 204 # 检查同目录下是否存在同名对象(仅检查未删除的) existing = await Object.get( diff --git a/routers/api/v1/site/__init__.py b/routers/api/v1/site/__init__.py index 3e5206d..484ec44 100644 --- a/routers/api/v1/site/__init__.py +++ b/routers/api/v1/site/__init__.py @@ -110,6 +110,8 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse: return SiteConfigResponse( title=s.get("siteName") or "DiskNext", + logo_light=s.get("logo_light") or None, + logo_dark=s.get("logo_dark") or None, register_enabled=s.get("register_enabled") == "1", login_captcha=s.get("login_captcha") == "1", reg_captcha=s.get("reg_captcha") == "1", diff --git a/routers/api/v1/slave/__init__.py b/routers/api/v1/slave/__init__.py index 37f5e0a..da0c3b8 100644 --- a/routers/api/v1/slave/__init__.py +++ b/routers/api/v1/slave/__init__.py @@ -20,15 +20,15 @@ slave_aria2_router = APIRouter( summary='测试用路由', description='Test route for checking connectivity.', ) -def router_slave_ping() -> ResponseBase: +def router_slave_ping() -> str: """ Test route for checking connectivity. - + Returns: - ResponseBase: A response model indicating success. + str: 后端版本号 """ from utils.conf.appmeta import BackendVersion - return ResponseBase(data=BackendVersion) + return BackendVersion @slave_router.post( path='/post', diff --git a/sqlmodels/migration.py b/sqlmodels/migration.py index 59aa43c..2f11347 100644 --- a/sqlmodels/migration.py +++ b/sqlmodels/migration.py @@ -109,6 +109,8 @@ default_settings: list[Setting] = [ Setting(name="pwa_display", value="standalone", type=SettingsType.PWA), Setting(name="pwa_theme_color", value="#000000", type=SettingsType.PWA), Setting(name="pwa_background_color", value="#ffffff", type=SettingsType.PWA), + Setting(name="logo_light", value="", type=SettingsType.BASIC), + Setting(name="logo_dark", value="", type=SettingsType.BASIC), # ==================== 认证方式配置 ==================== Setting(name="auth_email_password_enabled", value="1", type=SettingsType.AUTH), Setting(name="auth_phone_sms_enabled", value="0", type=SettingsType.AUTH), diff --git a/tests/integration/api/test_file_source_link.py b/tests/integration/api/test_file_source_link.py new file mode 100644 index 0000000..fefbf80 --- /dev/null +++ b/tests/integration/api/test_file_source_link.py @@ -0,0 +1,554 @@ +""" +文件外链与 Policy 规则集成测试 + +测试端点: +- POST /file/source/{file_id} 创建/获取文件外链 +- GET /file/get/{file_id}/{name} 外链直接输出 +- GET /file/source/{file_id}/{name} 外链重定向/输出 + +测试 Policy 规则: +- max_size 在 PATCH /file/content 中的检查 +- is_origin_link_enable 控制外链创建与访问 +- is_private + base_url 控制 302 重定向 vs 应用代理 +""" +import hashlib +from pathlib import Path +from uuid import UUID, uuid4 + +import pytest +import pytest_asyncio +from httpx import AsyncClient +from sqlalchemy import event +from sqlalchemy.engine import Engine +from sqlmodel.ext.asyncio.session import AsyncSession + +from sqlmodels import Object, ObjectType, PhysicalFile, Policy, PolicyType, SourceLink, User + + +@pytest.fixture(autouse=True) +def _register_sqlite_greatest(): + """注册 SQLite 的 greatest 函数以兼容 PostgreSQL 语法""" + + def _on_connect(dbapi_connection, connection_record): + if hasattr(dbapi_connection, 'create_function'): + dbapi_connection.create_function("greatest", 2, max) + + event.listen(Engine, "connect", _on_connect) + yield + event.remove(Engine, "connect", _on_connect) + + +# ==================== Fixtures ==================== + +@pytest_asyncio.fixture +async def source_policy( + initialized_db: AsyncSession, + tmp_path: Path, +) -> Policy: + """创建启用外链的本地存储策略""" + policy = Policy( + id=uuid4(), + name="测试外链存储", + type=PolicyType.LOCAL, + server=str(tmp_path), + is_origin_link_enable=True, + is_private=True, + max_size=0, + ) + initialized_db.add(policy) + await initialized_db.commit() + await initialized_db.refresh(policy) + return policy + + +@pytest_asyncio.fixture +async def source_file( + initialized_db: AsyncSession, + tmp_path: Path, + source_policy: Policy, +) -> dict[str, str | int]: + """创建一个文本测试文件,关联到启用外链的存储策略""" + user = await User.get(initialized_db, User.email == "testuser@test.local") + root = await Object.get_root(initialized_db, user.id) + + content = "A" * 50 + content_bytes = content.encode('utf-8') + content_hash = hashlib.sha256(content_bytes).hexdigest() + + file_path = tmp_path / "source_test.txt" + file_path.write_bytes(content_bytes) + + physical_file = PhysicalFile( + id=uuid4(), + storage_path=str(file_path), + size=len(content_bytes), + policy_id=source_policy.id, + reference_count=1, + ) + initialized_db.add(physical_file) + + file_obj = Object( + id=uuid4(), + name="source_test.txt", + type=ObjectType.FILE, + size=len(content_bytes), + physical_file_id=physical_file.id, + parent_id=root.id, + owner_id=user.id, + policy_id=source_policy.id, + ) + initialized_db.add(file_obj) + await initialized_db.commit() + + return { + "id": str(file_obj.id), + "name": "source_test.txt", + "content": content, + "hash": content_hash, + "size": len(content_bytes), + "path": str(file_path), + } + + +@pytest_asyncio.fixture +async def source_file_with_link( + initialized_db: AsyncSession, + source_file: dict[str, str | int], +) -> dict[str, str | int]: + """创建已有 SourceLink 的测试文件""" + link = SourceLink( + name=source_file["name"], + object_id=UUID(source_file["id"]), + downloads=5, + ) + initialized_db.add(link) + await initialized_db.commit() + await initialized_db.refresh(link) + + return {**source_file, "link_id": link.id, "link_downloads": 5} + + +# ==================== POST /file/source/{file_id} ==================== + +class TestCreateSourceLink: + """POST /file/source/{file_id} 端点测试""" + + @pytest.mark.asyncio + async def test_create_source_link_success( + self, + async_client: AsyncClient, + auth_headers: dict[str, str], + source_file: dict[str, str | int], + ) -> None: + """成功创建外链""" + response = await async_client.post( + f"/api/v1/file/source/{source_file['id']}", + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert "/api/v1/file/source/" in data["url"] + assert source_file["name"] in data["url"] + assert data["downloads"] == 0 + + @pytest.mark.asyncio + async def test_create_source_link_idempotent( + self, + async_client: AsyncClient, + auth_headers: dict[str, str], + source_file_with_link: dict[str, str | int], + ) -> None: + """已有外链时返回现有外链(幂等)""" + response = await async_client.post( + f"/api/v1/file/source/{source_file_with_link['id']}", + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["downloads"] == source_file_with_link["link_downloads"] + + @pytest.mark.asyncio + async def test_create_source_link_disabled_returns_403( + self, + async_client: AsyncClient, + auth_headers: dict[str, str], + source_file: dict[str, str | int], + source_policy: Policy, + initialized_db: AsyncSession, + ) -> None: + """存储策略未启用外链时返回 403""" + source_policy.is_origin_link_enable = False + initialized_db.add(source_policy) + await initialized_db.commit() + + response = await async_client.post( + f"/api/v1/file/source/{source_file['id']}", + headers=auth_headers, + ) + + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_create_source_link_file_not_found( + self, + async_client: AsyncClient, + auth_headers: dict[str, str], + ) -> None: + """文件不存在返回 404""" + response = await async_client.post( + f"/api/v1/file/source/{uuid4()}", + headers=auth_headers, + ) + + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_create_source_link_unauthenticated( + self, + async_client: AsyncClient, + source_file: dict[str, str | int], + ) -> None: + """未认证返回 401""" + response = await async_client.post( + f"/api/v1/file/source/{source_file['id']}", + ) + + assert response.status_code == 401 + + +# ==================== GET /file/get/{file_id}/{name} ==================== + +class TestFileGetDirect: + """GET /file/get/{file_id}/{name} 端点测试""" + + @pytest.mark.asyncio + async def test_get_direct_success( + self, + async_client: AsyncClient, + source_file_with_link: dict[str, str | int], + ) -> None: + """成功通过外链直接获取文件(无需认证)""" + response = await async_client.get( + f"/api/v1/file/get/{source_file_with_link['id']}/{source_file_with_link['name']}", + ) + + assert response.status_code == 200 + assert source_file_with_link["content"] in response.text + + @pytest.mark.asyncio + async def test_get_direct_increments_download_count( + self, + async_client: AsyncClient, + auth_headers: dict[str, str], + source_file_with_link: dict[str, str | int], + initialized_db: AsyncSession, + ) -> None: + """下载后递增计数""" + link_before = await SourceLink.get( + initialized_db, + SourceLink.object_id == UUID(source_file_with_link["id"]), + ) + downloads_before = link_before.downloads + + await async_client.get( + f"/api/v1/file/get/{source_file_with_link['id']}/{source_file_with_link['name']}", + ) + + await initialized_db.refresh(link_before) + assert link_before.downloads == downloads_before + 1 + + @pytest.mark.asyncio + async def test_get_direct_no_link_returns_404( + self, + async_client: AsyncClient, + source_file: dict[str, str | int], + ) -> None: + """未创建外链的文件返回 404""" + response = await async_client.get( + f"/api/v1/file/get/{source_file['id']}/{source_file['name']}", + ) + + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_get_direct_nonexistent_file_returns_404( + self, + async_client: AsyncClient, + ) -> None: + """文件不存在返回 404""" + response = await async_client.get( + f"/api/v1/file/get/{uuid4()}/fake.txt", + ) + + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_get_direct_disabled_policy_returns_403( + self, + async_client: AsyncClient, + source_file_with_link: dict[str, str | int], + source_policy: Policy, + initialized_db: AsyncSession, + ) -> None: + """存储策略禁用外链时返回 403""" + source_policy.is_origin_link_enable = False + initialized_db.add(source_policy) + await initialized_db.commit() + + response = await async_client.get( + f"/api/v1/file/get/{source_file_with_link['id']}/{source_file_with_link['name']}", + ) + + assert response.status_code == 403 + + +# ==================== GET /file/source/{file_id}/{name} ==================== + +class TestFileSourceRedirect: + """GET /file/source/{file_id}/{name} 端点测试""" + + @pytest.mark.asyncio + async def test_source_private_returns_file_content( + self, + async_client: AsyncClient, + source_file_with_link: dict[str, str | int], + ) -> None: + """is_private=True 时直接返回文件内容""" + response = await async_client.get( + f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}", + follow_redirects=False, + ) + + assert response.status_code == 200 + assert source_file_with_link["content"] in response.text + + @pytest.mark.asyncio + async def test_source_public_redirects_302( + self, + async_client: AsyncClient, + source_file_with_link: dict[str, str | int], + source_policy: Policy, + initialized_db: AsyncSession, + ) -> None: + """is_private=False + base_url 时 302 重定向""" + source_policy.is_private = False + source_policy.base_url = "http://cdn.example.com/storage" + initialized_db.add(source_policy) + await initialized_db.commit() + + response = await async_client.get( + f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}", + follow_redirects=False, + ) + + assert response.status_code == 302 + location = response.headers["location"] + assert "cdn.example.com/storage" in location + + @pytest.mark.asyncio + async def test_source_public_no_base_url_fallback( + self, + async_client: AsyncClient, + source_file_with_link: dict[str, str | int], + source_policy: Policy, + initialized_db: AsyncSession, + ) -> None: + """is_private=False 但 base_url 为空时降级为直接输出""" + source_policy.is_private = False + source_policy.base_url = None + initialized_db.add(source_policy) + await initialized_db.commit() + + response = await async_client.get( + f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}", + follow_redirects=False, + ) + + assert response.status_code == 200 + assert source_file_with_link["content"] in response.text + + @pytest.mark.asyncio + async def test_source_increments_download_count( + self, + async_client: AsyncClient, + source_file_with_link: dict[str, str | int], + initialized_db: AsyncSession, + ) -> None: + """访问外链递增下载计数""" + link_before = await SourceLink.get( + initialized_db, + SourceLink.object_id == UUID(source_file_with_link["id"]), + ) + downloads_before = link_before.downloads + + await async_client.get( + f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}", + ) + + await initialized_db.refresh(link_before) + assert link_before.downloads == downloads_before + 1 + + @pytest.mark.asyncio + async def test_source_no_link_returns_404( + self, + async_client: AsyncClient, + source_file: dict[str, str | int], + ) -> None: + """未创建外链的文件返回 404""" + response = await async_client.get( + f"/api/v1/file/source/{source_file['id']}/{source_file['name']}", + ) + + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_source_disabled_policy_returns_403( + self, + async_client: AsyncClient, + source_file_with_link: dict[str, str | int], + source_policy: Policy, + initialized_db: AsyncSession, + ) -> None: + """存储策略禁用外链时返回 403""" + source_policy.is_origin_link_enable = False + initialized_db.add(source_policy) + await initialized_db.commit() + + response = await async_client.get( + f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}", + ) + + assert response.status_code == 403 + + +# ==================== max_size 在 PATCH 中的检查 ==================== + +class TestPatchMaxSizePolicy: + """PATCH /file/content/{file_id} 的 max_size 策略检查""" + + @pytest_asyncio.fixture + async def size_limited_policy( + self, + initialized_db: AsyncSession, + tmp_path: Path, + ) -> Policy: + """创建有大小限制的存储策略(100 bytes)""" + policy = Policy( + id=uuid4(), + name="限制大小存储", + type=PolicyType.LOCAL, + server=str(tmp_path), + max_size=100, + ) + initialized_db.add(policy) + await initialized_db.commit() + await initialized_db.refresh(policy) + return policy + + @pytest_asyncio.fixture + async def small_file( + self, + initialized_db: AsyncSession, + tmp_path: Path, + size_limited_policy: Policy, + ) -> dict[str, str | int]: + """创建一个 50 字节的文本文件(策略限制 100 字节)""" + user = await User.get(initialized_db, User.email == "testuser@test.local") + root = await Object.get_root(initialized_db, user.id) + + content = "A" * 50 + content_bytes = content.encode('utf-8') + content_hash = hashlib.sha256(content_bytes).hexdigest() + + file_path = tmp_path / "small.txt" + file_path.write_bytes(content_bytes) + + physical_file = PhysicalFile( + id=uuid4(), + storage_path=str(file_path), + size=len(content_bytes), + policy_id=size_limited_policy.id, + reference_count=1, + ) + initialized_db.add(physical_file) + + file_obj = Object( + id=uuid4(), + name="small.txt", + type=ObjectType.FILE, + size=len(content_bytes), + physical_file_id=physical_file.id, + parent_id=root.id, + owner_id=user.id, + policy_id=size_limited_policy.id, + ) + initialized_db.add(file_obj) + await initialized_db.commit() + + return { + "id": str(file_obj.id), + "content": content, + "hash": content_hash, + "size": len(content_bytes), + "path": str(file_path), + } + + @pytest.mark.asyncio + async def test_patch_exceeds_max_size_returns_413( + self, + async_client: AsyncClient, + auth_headers: dict[str, str], + small_file: dict[str, str | int], + ) -> None: + """PATCH 后文件超过 max_size 返回 413""" + big_content = "B" * 200 + patch_text = ( + "--- a\n" + "+++ b\n" + "@@ -1 +1 @@\n" + f"-{'A' * 50}\n" + f"+{big_content}\n" + ) + + response = await async_client.patch( + f"/api/v1/file/content/{small_file['id']}", + headers=auth_headers, + json={ + "patch": patch_text, + "base_hash": small_file["hash"], + }, + ) + + assert response.status_code == 413 + + @pytest.mark.asyncio + async def test_patch_within_max_size_succeeds( + self, + async_client: AsyncClient, + auth_headers: dict[str, str], + small_file: dict[str, str | int], + ) -> None: + """PATCH 后文件未超过 max_size 正常保存""" + new_content = "C" * 80 # 80 bytes < 100 bytes limit + patch_text = ( + "--- a\n" + "+++ b\n" + "@@ -1 +1 @@\n" + f"-{'A' * 50}\n" + f"+{new_content}\n" + ) + + response = await async_client.patch( + f"/api/v1/file/content/{small_file['id']}", + headers=auth_headers, + json={ + "patch": patch_text, + "base_hash": small_file["hash"], + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["new_size"] == 80