feat: implement source link endpoints and enforce policy rules
- Add POST/GET source link endpoints for file sharing via permanent URLs - Enforce max_size check in PATCH /file/content to prevent size limit bypass - Support is_private (proxy) vs public (302 redirect) storage modes - Replace all ResponseBase(data=...) with proper DTOs or 204 responses - Add 18 integration tests for source link and policy rule enforcement Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -283,16 +283,17 @@ async def router_admin_get_settings(
|
|||||||
path='/test',
|
path='/test',
|
||||||
summary='测试 Aria2 连接',
|
summary='测试 Aria2 连接',
|
||||||
description='Test Aria2 RPC connection',
|
description='Test Aria2 RPC connection',
|
||||||
dependencies=[Depends(admin_required)]
|
dependencies=[Depends(admin_required)],
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_admin_aira2_test(
|
async def router_admin_aira2_test(
|
||||||
request: Aria2TestRequest,
|
request: Aria2TestRequest,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
测试 Aria2 RPC 连接。
|
测试 Aria2 RPC 连接。
|
||||||
|
|
||||||
:param request: 测试请求
|
:param request: 测试请求
|
||||||
:return: 测试结果
|
:raises HTTPException: 连接失败时抛出 400
|
||||||
"""
|
"""
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
@@ -307,22 +308,18 @@ async def router_admin_aira2_test(
|
|||||||
async with aiohttp.ClientSession() as client:
|
async with aiohttp.ClientSession() as client:
|
||||||
async with client.post(request.rpc_url, json=payload, timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
async with client.post(request.rpc_url, json=payload, timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
return ResponseBase(
|
raise HTTPException(
|
||||||
code=400,
|
status_code=400,
|
||||||
msg=f"连接失败,HTTP {resp.status}"
|
detail=f"连接失败,HTTP {resp.status}",
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await resp.json()
|
result = await resp.json()
|
||||||
if "error" in result:
|
if "error" in result:
|
||||||
return ResponseBase(
|
raise HTTPException(
|
||||||
code=400,
|
status_code=400,
|
||||||
msg=f"Aria2 错误: {result['error']['message']}"
|
detail=f"Aria2 错误: {result['error']['message']}",
|
||||||
)
|
)
|
||||||
|
except HTTPException:
|
||||||
version = result.get("result", {}).get("version", "unknown")
|
raise
|
||||||
return ResponseBase(data={
|
|
||||||
"connected": True,
|
|
||||||
"version": version,
|
|
||||||
})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ResponseBase(code=400, msg=f"连接失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"连接失败: {str(e)}")
|
||||||
@@ -55,7 +55,7 @@ async def router_admin_get_groups(
|
|||||||
async def router_admin_get_group(
|
async def router_admin_get_group(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
group_id: UUID,
|
group_id: UUID,
|
||||||
) -> ResponseBase:
|
) -> GroupDetailResponse:
|
||||||
"""
|
"""
|
||||||
根据用户组ID获取用户组详细信息。
|
根据用户组ID获取用户组详细信息。
|
||||||
|
|
||||||
@@ -71,9 +71,7 @@ async def router_admin_get_group(
|
|||||||
# 直接访问已加载的关系,无需额外查询
|
# 直接访问已加载的关系,无需额外查询
|
||||||
policies = group.policies
|
policies = group.policies
|
||||||
user_count = await User.count(session, User.group_id == group_id)
|
user_count = await User.count(session, User.group_id == group_id)
|
||||||
response = GroupDetailResponse.from_group(group, user_count, policies)
|
return GroupDetailResponse.from_group(group, user_count, policies)
|
||||||
|
|
||||||
return ResponseBase(data=response.model_dump())
|
|
||||||
|
|
||||||
|
|
||||||
@admin_group_router.get(
|
@admin_group_router.get(
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
@@ -8,7 +9,8 @@ from middleware.auth import admin_required
|
|||||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
from sqlmodels import (
|
from sqlmodels import (
|
||||||
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
|
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
|
||||||
ListResponse, Object, )
|
ListResponse, Object,
|
||||||
|
)
|
||||||
from sqlmodel_ext import SQLModelBase
|
from sqlmodel_ext import SQLModelBase
|
||||||
from service.storage import DirectoryCreationError, LocalStorageService
|
from service.storage import DirectoryCreationError, LocalStorageService
|
||||||
|
|
||||||
@@ -17,6 +19,78 @@ admin_policy_router = APIRouter(
|
|||||||
tags=['admin', 'admin_policy']
|
tags=['admin', 'admin_policy']
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PathTestResponse(SQLModelBase):
|
||||||
|
"""路径测试响应"""
|
||||||
|
|
||||||
|
path: str
|
||||||
|
"""解析后的路径"""
|
||||||
|
|
||||||
|
is_exists: bool
|
||||||
|
"""路径是否存在"""
|
||||||
|
|
||||||
|
is_writable: bool
|
||||||
|
"""路径是否可写"""
|
||||||
|
|
||||||
|
|
||||||
|
class PolicyGroupInfo(SQLModelBase):
|
||||||
|
"""策略关联的用户组信息"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
"""用户组UUID"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""用户组名称"""
|
||||||
|
|
||||||
|
|
||||||
|
class PolicyDetailResponse(SQLModelBase):
|
||||||
|
"""存储策略详情响应"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
"""策略UUID"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""策略名称"""
|
||||||
|
|
||||||
|
type: str
|
||||||
|
"""策略类型"""
|
||||||
|
|
||||||
|
server: str | None
|
||||||
|
"""服务器地址"""
|
||||||
|
|
||||||
|
bucket_name: str | None
|
||||||
|
"""存储桶名称"""
|
||||||
|
|
||||||
|
is_private: bool
|
||||||
|
"""是否私有"""
|
||||||
|
|
||||||
|
base_url: str | None
|
||||||
|
"""基础URL"""
|
||||||
|
|
||||||
|
max_size: int
|
||||||
|
"""最大文件尺寸"""
|
||||||
|
|
||||||
|
auto_rename: bool
|
||||||
|
"""是否自动重命名"""
|
||||||
|
|
||||||
|
dir_name_rule: str | None
|
||||||
|
"""目录命名规则"""
|
||||||
|
|
||||||
|
file_name_rule: str | None
|
||||||
|
"""文件命名规则"""
|
||||||
|
|
||||||
|
is_origin_link_enable: bool
|
||||||
|
"""是否启用外链"""
|
||||||
|
|
||||||
|
options: dict[str, Any] | None
|
||||||
|
"""策略选项"""
|
||||||
|
|
||||||
|
groups: list[PolicyGroupInfo]
|
||||||
|
"""关联的用户组"""
|
||||||
|
|
||||||
|
object_count: int
|
||||||
|
"""使用此策略的对象数量"""
|
||||||
|
|
||||||
class PolicyTestPathRequest(SQLModelBase):
|
class PolicyTestPathRequest(SQLModelBase):
|
||||||
"""测试本地路径请求 DTO"""
|
"""测试本地路径请求 DTO"""
|
||||||
|
|
||||||
@@ -70,7 +144,7 @@ async def router_policy_list(
|
|||||||
)
|
)
|
||||||
async def router_policy_test_path(
|
async def router_policy_test_path(
|
||||||
request: PolicyTestPathRequest,
|
request: PolicyTestPathRequest,
|
||||||
) -> ResponseBase:
|
) -> PathTestResponse:
|
||||||
"""
|
"""
|
||||||
测试本地存储路径是否可用。
|
测试本地存储路径是否可用。
|
||||||
|
|
||||||
@@ -97,22 +171,23 @@ async def router_policy_test_path(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return ResponseBase(data={
|
return PathTestResponse(
|
||||||
"path": str(path),
|
path=str(path),
|
||||||
"exists": is_exists,
|
is_exists=is_exists,
|
||||||
"writable": is_writable,
|
is_writable=is_writable,
|
||||||
})
|
)
|
||||||
|
|
||||||
|
|
||||||
@admin_policy_router.post(
|
@admin_policy_router.post(
|
||||||
path='/test/slave',
|
path='/test/slave',
|
||||||
summary='测试从机通信',
|
summary='测试从机通信',
|
||||||
description='Test slave node communication',
|
description='Test slave node communication',
|
||||||
dependencies=[Depends(admin_required)]
|
dependencies=[Depends(admin_required)],
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_policy_test_slave(
|
async def router_policy_test_slave(
|
||||||
request: PolicyTestSlaveRequest,
|
request: PolicyTestSlaveRequest,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
测试从机RPC通信。
|
测试从机RPC通信。
|
||||||
|
|
||||||
@@ -129,25 +204,28 @@ async def router_policy_test_slave(
|
|||||||
timeout=aiohttp.ClientTimeout(total=10)
|
timeout=aiohttp.ClientTimeout(total=10)
|
||||||
) as resp:
|
) as resp:
|
||||||
if resp.status == 200:
|
if resp.status == 200:
|
||||||
return ResponseBase(data={"connected": True})
|
return
|
||||||
else:
|
else:
|
||||||
return ResponseBase(
|
raise HTTPException(
|
||||||
code=400,
|
status_code=400,
|
||||||
msg=f"从机响应错误,HTTP {resp.status}"
|
detail=f"从机响应错误,HTTP {resp.status}",
|
||||||
)
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ResponseBase(code=400, msg=f"连接失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"连接失败: {str(e)}")
|
||||||
|
|
||||||
@admin_policy_router.post(
|
@admin_policy_router.post(
|
||||||
path='/',
|
path='/',
|
||||||
summary='创建存储策略',
|
summary='创建存储策略',
|
||||||
description='创建新的存储策略。对于本地存储策略,会自动创建物理目录。',
|
description='创建新的存储策略。对于本地存储策略,会自动创建物理目录。',
|
||||||
dependencies=[Depends(admin_required)]
|
dependencies=[Depends(admin_required)],
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_policy_add_policy(
|
async def router_policy_add_policy(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
request: PolicyCreateRequest,
|
request: PolicyCreateRequest,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
创建存储策略端点
|
创建存储策略端点
|
||||||
|
|
||||||
@@ -199,14 +277,7 @@ async def router_policy_add_policy(
|
|||||||
raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}")
|
raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}")
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
policy = await policy.save(session)
|
await policy.save(session)
|
||||||
|
|
||||||
return ResponseBase(data={
|
|
||||||
"id": str(policy.id),
|
|
||||||
"name": policy.name,
|
|
||||||
"type": policy.type.value,
|
|
||||||
"server": policy.server,
|
|
||||||
})
|
|
||||||
|
|
||||||
@admin_policy_router.post(
|
@admin_policy_router.post(
|
||||||
path='/cors',
|
path='/cors',
|
||||||
@@ -274,7 +345,7 @@ async def router_policy_onddrive_oauth(
|
|||||||
async def router_policy_get_policy(
|
async def router_policy_get_policy(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
policy_id: UUID,
|
policy_id: UUID,
|
||||||
) -> ResponseBase:
|
) -> PolicyDetailResponse:
|
||||||
"""
|
"""
|
||||||
获取存储策略详情。
|
获取存储策略详情。
|
||||||
|
|
||||||
@@ -292,35 +363,36 @@ async def router_policy_get_policy(
|
|||||||
# 统计使用此策略的对象数量
|
# 统计使用此策略的对象数量
|
||||||
object_count = await Object.count(session, Object.policy_id == policy_id)
|
object_count = await Object.count(session, Object.policy_id == policy_id)
|
||||||
|
|
||||||
return ResponseBase(data={
|
return PolicyDetailResponse(
|
||||||
"id": str(policy.id),
|
id=str(policy.id),
|
||||||
"name": policy.name,
|
name=policy.name,
|
||||||
"type": policy.type.value,
|
type=policy.type.value,
|
||||||
"server": policy.server,
|
server=policy.server,
|
||||||
"bucket_name": policy.bucket_name,
|
bucket_name=policy.bucket_name,
|
||||||
"is_private": policy.is_private,
|
is_private=policy.is_private,
|
||||||
"base_url": policy.base_url,
|
base_url=policy.base_url,
|
||||||
"max_size": policy.max_size,
|
max_size=policy.max_size,
|
||||||
"auto_rename": policy.auto_rename,
|
auto_rename=policy.auto_rename,
|
||||||
"dir_name_rule": policy.dir_name_rule,
|
dir_name_rule=policy.dir_name_rule,
|
||||||
"file_name_rule": policy.file_name_rule,
|
file_name_rule=policy.file_name_rule,
|
||||||
"is_origin_link_enable": policy.is_origin_link_enable,
|
is_origin_link_enable=policy.is_origin_link_enable,
|
||||||
"options": policy.options.model_dump() if policy.options else None,
|
options=policy.options.model_dump() if policy.options else None,
|
||||||
"groups": [{"id": str(g.id), "name": g.name} for g in groups],
|
groups=[PolicyGroupInfo(id=str(g.id), name=g.name) for g in groups],
|
||||||
"object_count": object_count,
|
object_count=object_count,
|
||||||
})
|
)
|
||||||
|
|
||||||
|
|
||||||
@admin_policy_router.delete(
|
@admin_policy_router.delete(
|
||||||
path='/{policy_id}',
|
path='/{policy_id}',
|
||||||
summary='删除存储策略',
|
summary='删除存储策略',
|
||||||
description='Delete storage policy by ID',
|
description='Delete storage policy by ID',
|
||||||
dependencies=[Depends(admin_required)]
|
dependencies=[Depends(admin_required)],
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_policy_delete_policy(
|
async def router_policy_delete_policy(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
policy_id: UUID,
|
policy_id: UUID,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
删除存储策略。
|
删除存储策略。
|
||||||
|
|
||||||
@@ -346,4 +418,3 @@ async def router_policy_delete_policy(
|
|||||||
await Policy.delete(session, policy)
|
await Policy.delete(session, policy)
|
||||||
|
|
||||||
l.info(f"管理员删除了存储策略: {policy_name}")
|
l.info(f"管理员删除了存储策略: {policy_name}")
|
||||||
return ResponseBase(data={"deleted": True})
|
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
@@ -6,8 +7,53 @@ from loguru import logger as l
|
|||||||
from middleware.auth import admin_required
|
from middleware.auth import admin_required
|
||||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
from sqlmodels import (
|
from sqlmodels import (
|
||||||
ResponseBase, ListResponse,
|
ListResponse,
|
||||||
Share, AdminShareListItem, )
|
Share, AdminShareListItem,
|
||||||
|
)
|
||||||
|
from sqlmodel_ext import SQLModelBase
|
||||||
|
|
||||||
|
|
||||||
|
class ShareDetailResponse(SQLModelBase):
|
||||||
|
"""分享详情响应"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
"""分享UUID"""
|
||||||
|
|
||||||
|
code: str
|
||||||
|
"""分享码"""
|
||||||
|
|
||||||
|
views: int
|
||||||
|
"""浏览次数"""
|
||||||
|
|
||||||
|
downloads: int
|
||||||
|
"""下载次数"""
|
||||||
|
|
||||||
|
remain_downloads: int | None
|
||||||
|
"""剩余下载次数"""
|
||||||
|
|
||||||
|
expires: datetime | None
|
||||||
|
"""过期时间"""
|
||||||
|
|
||||||
|
preview_enabled: bool
|
||||||
|
"""是否启用预览"""
|
||||||
|
|
||||||
|
score: int
|
||||||
|
"""评分"""
|
||||||
|
|
||||||
|
has_password: bool
|
||||||
|
"""是否有密码"""
|
||||||
|
|
||||||
|
user_id: str
|
||||||
|
"""用户UUID"""
|
||||||
|
|
||||||
|
username: str | None
|
||||||
|
"""用户名"""
|
||||||
|
|
||||||
|
object: dict | None
|
||||||
|
"""关联对象信息"""
|
||||||
|
|
||||||
|
created_at: str
|
||||||
|
"""创建时间"""
|
||||||
|
|
||||||
admin_share_router = APIRouter(
|
admin_share_router = APIRouter(
|
||||||
prefix='/share',
|
prefix='/share',
|
||||||
@@ -54,7 +100,7 @@ async def router_admin_get_share_list(
|
|||||||
async def router_admin_get_share(
|
async def router_admin_get_share(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
share_id: UUID,
|
share_id: UUID,
|
||||||
) -> ResponseBase:
|
) -> ShareDetailResponse:
|
||||||
"""
|
"""
|
||||||
获取分享详情。
|
获取分享详情。
|
||||||
|
|
||||||
@@ -69,38 +115,39 @@ async def router_admin_get_share(
|
|||||||
obj = await share.awaitable_attrs.object
|
obj = await share.awaitable_attrs.object
|
||||||
user = await share.awaitable_attrs.user
|
user = await share.awaitable_attrs.user
|
||||||
|
|
||||||
return ResponseBase(data={
|
return ShareDetailResponse(
|
||||||
"id": share.id,
|
id=share.id,
|
||||||
"code": share.code,
|
code=share.code,
|
||||||
"views": share.views,
|
views=share.views,
|
||||||
"downloads": share.downloads,
|
downloads=share.downloads,
|
||||||
"remain_downloads": share.remain_downloads,
|
remain_downloads=share.remain_downloads,
|
||||||
"expires": share.expires.isoformat() if share.expires else None,
|
expires=share.expires,
|
||||||
"preview_enabled": share.preview_enabled,
|
preview_enabled=share.preview_enabled,
|
||||||
"score": share.score,
|
score=share.score,
|
||||||
"has_password": bool(share.password),
|
has_password=bool(share.password),
|
||||||
"user_id": str(share.user_id),
|
user_id=str(share.user_id),
|
||||||
"username": user.email if user else None,
|
username=user.email if user else None,
|
||||||
"object": {
|
object={
|
||||||
"id": str(obj.id),
|
"id": str(obj.id),
|
||||||
"name": obj.name,
|
"name": obj.name,
|
||||||
"type": obj.type.value,
|
"type": obj.type.value,
|
||||||
"size": obj.size,
|
"size": obj.size,
|
||||||
} if obj else None,
|
} if obj else None,
|
||||||
"created_at": share.created_at.isoformat(),
|
created_at=share.created_at.isoformat(),
|
||||||
})
|
)
|
||||||
|
|
||||||
|
|
||||||
@admin_share_router.delete(
|
@admin_share_router.delete(
|
||||||
path='/{share_id}',
|
path='/{share_id}',
|
||||||
summary='删除分享',
|
summary='删除分享',
|
||||||
description='Delete share by ID',
|
description='Delete share by ID',
|
||||||
dependencies=[Depends(admin_required)]
|
dependencies=[Depends(admin_required)],
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_admin_delete_share(
|
async def router_admin_delete_share(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
share_id: UUID,
|
share_id: UUID,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
删除分享。
|
删除分享。
|
||||||
|
|
||||||
@@ -115,4 +162,3 @@ async def router_admin_delete_share(
|
|||||||
await Share.delete(session, share)
|
await Share.delete(session, share)
|
||||||
|
|
||||||
l.info(f"管理员删除了分享: {share.code}")
|
l.info(f"管理员删除了分享: {share.code}")
|
||||||
return ResponseBase(data={"deleted": True})
|
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
@@ -6,9 +7,44 @@ from loguru import logger as l
|
|||||||
from middleware.auth import admin_required
|
from middleware.auth import admin_required
|
||||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
from sqlmodels import (
|
from sqlmodels import (
|
||||||
ResponseBase, ListResponse,
|
ListResponse,
|
||||||
Task, TaskSummary,
|
Task, TaskSummary,
|
||||||
)
|
)
|
||||||
|
from sqlmodel_ext import SQLModelBase
|
||||||
|
|
||||||
|
|
||||||
|
class TaskDetailResponse(SQLModelBase):
|
||||||
|
"""任务详情响应"""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
"""任务ID"""
|
||||||
|
|
||||||
|
status: int
|
||||||
|
"""任务状态"""
|
||||||
|
|
||||||
|
type: int
|
||||||
|
"""任务类型"""
|
||||||
|
|
||||||
|
progress: int
|
||||||
|
"""任务进度"""
|
||||||
|
|
||||||
|
error: str | None
|
||||||
|
"""错误信息"""
|
||||||
|
|
||||||
|
user_id: str
|
||||||
|
"""用户UUID"""
|
||||||
|
|
||||||
|
username: str | None
|
||||||
|
"""用户名"""
|
||||||
|
|
||||||
|
props: dict[str, Any] | None
|
||||||
|
"""任务属性"""
|
||||||
|
|
||||||
|
created_at: str
|
||||||
|
"""创建时间"""
|
||||||
|
|
||||||
|
updated_at: str
|
||||||
|
"""更新时间"""
|
||||||
|
|
||||||
admin_task_router = APIRouter(
|
admin_task_router = APIRouter(
|
||||||
prefix='/task',
|
prefix='/task',
|
||||||
@@ -67,7 +103,7 @@ async def router_admin_get_task_list(
|
|||||||
async def router_admin_get_task(
|
async def router_admin_get_task(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
task_id: int,
|
task_id: int,
|
||||||
) -> ResponseBase:
|
) -> TaskDetailResponse:
|
||||||
"""
|
"""
|
||||||
获取任务详情。
|
获取任务详情。
|
||||||
|
|
||||||
@@ -82,30 +118,31 @@ async def router_admin_get_task(
|
|||||||
user = await task.awaitable_attrs.user
|
user = await task.awaitable_attrs.user
|
||||||
props = await task.awaitable_attrs.props
|
props = await task.awaitable_attrs.props
|
||||||
|
|
||||||
return ResponseBase(data={
|
return TaskDetailResponse(
|
||||||
"id": task.id,
|
id=task.id,
|
||||||
"status": task.status,
|
status=task.status,
|
||||||
"type": task.type,
|
type=task.type,
|
||||||
"progress": task.progress,
|
progress=task.progress,
|
||||||
"error": task.error,
|
error=task.error,
|
||||||
"user_id": str(task.user_id),
|
user_id=str(task.user_id),
|
||||||
"username": user.email if user else None,
|
username=user.email if user else None,
|
||||||
"props": props.model_dump() if props else None,
|
props=props.model_dump() if props else None,
|
||||||
"created_at": task.created_at.isoformat(),
|
created_at=task.created_at.isoformat(),
|
||||||
"updated_at": task.updated_at.isoformat(),
|
updated_at=task.updated_at.isoformat(),
|
||||||
})
|
)
|
||||||
|
|
||||||
|
|
||||||
@admin_task_router.delete(
|
@admin_task_router.delete(
|
||||||
path='/{task_id}',
|
path='/{task_id}',
|
||||||
summary='删除任务',
|
summary='删除任务',
|
||||||
description='Delete task by ID',
|
description='Delete task by ID',
|
||||||
dependencies=[Depends(admin_required)]
|
dependencies=[Depends(admin_required)],
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_admin_delete_task(
|
async def router_admin_delete_task(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
task_id: int,
|
task_id: int,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
删除任务。
|
删除任务。
|
||||||
|
|
||||||
@@ -120,4 +157,3 @@ async def router_admin_delete_task(
|
|||||||
await Task.delete(session, task)
|
await Task.delete(session, task)
|
||||||
|
|
||||||
l.info(f"管理员删除了任务: {task_id}")
|
l.info(f"管理员删除了任务: {task_id}")
|
||||||
return ResponseBase(data={"deleted": True})
|
|
||||||
@@ -15,7 +15,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
import whatthepatch
|
import whatthepatch
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse, RedirectResponse
|
||||||
from loguru import logger as l
|
from loguru import logger as l
|
||||||
from sqlmodel_ext import SQLModelBase
|
from sqlmodel_ext import SQLModelBase
|
||||||
from whatthepatch.exceptions import HunkApplyException
|
from whatthepatch.exceptions import HunkApplyException
|
||||||
@@ -37,6 +37,7 @@ from sqlmodels import (
|
|||||||
ResponseBase,
|
ResponseBase,
|
||||||
Setting,
|
Setting,
|
||||||
SettingsType,
|
SettingsType,
|
||||||
|
SourceLink,
|
||||||
UploadChunkResponse,
|
UploadChunkResponse,
|
||||||
UploadSession,
|
UploadSession,
|
||||||
UploadSessionResponse,
|
UploadSessionResponse,
|
||||||
@@ -94,6 +95,41 @@ class PatchContentResponse(ResponseBase):
|
|||||||
new_size: int
|
new_size: int
|
||||||
"""新文件字节大小"""
|
"""新文件字节大小"""
|
||||||
|
|
||||||
|
|
||||||
|
class SourceLinkResponse(ResponseBase):
|
||||||
|
"""外链响应"""
|
||||||
|
|
||||||
|
url: str
|
||||||
|
"""外链地址(永久有效,/source/ 端点自动 302 适配存储策略)"""
|
||||||
|
|
||||||
|
downloads: int
|
||||||
|
"""历史下载次数"""
|
||||||
|
|
||||||
|
|
||||||
|
def _check_policy_size_limit(policy: Policy, file_size: int) -> None:
|
||||||
|
"""
|
||||||
|
检查文件大小是否超过策略限制
|
||||||
|
|
||||||
|
:param policy: 存储策略
|
||||||
|
:param file_size: 文件大小(字节)
|
||||||
|
:raises HTTPException: 413 Payload Too Large
|
||||||
|
"""
|
||||||
|
if policy.max_size > 0 and file_size > policy.max_size:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413,
|
||||||
|
detail=f"文件大小超过限制 ({policy.max_size} bytes)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_site_url(session: SessionDep) -> str:
|
||||||
|
"""获取站点 URL"""
|
||||||
|
site_url_setting = await Setting.get(
|
||||||
|
session,
|
||||||
|
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"),
|
||||||
|
)
|
||||||
|
return site_url_setting.value if site_url_setting else "http://localhost"
|
||||||
|
|
||||||
|
|
||||||
# ==================== 主路由 ====================
|
# ==================== 主路由 ====================
|
||||||
|
|
||||||
router = APIRouter(prefix="/file", tags=["file"])
|
router = APIRouter(prefix="/file", tags=["file"])
|
||||||
@@ -149,11 +185,7 @@ async def create_upload_session(
|
|||||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||||
|
|
||||||
# 验证文件大小限制
|
# 验证文件大小限制
|
||||||
if policy.max_size > 0 and request.file_size > policy.max_size:
|
_check_policy_size_limit(policy, request.file_size)
|
||||||
raise HTTPException(
|
|
||||||
status_code=413,
|
|
||||||
detail=f"文件大小超过限制 ({policy.max_size} bytes)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查存储配额(auth_required 已预加载 user.group)
|
# 检查存储配额(auth_required 已预加载 user.group)
|
||||||
max_storage = user.group.max_storage
|
max_storage = user.group.max_storage
|
||||||
@@ -344,12 +376,13 @@ async def upload_chunk(
|
|||||||
path='/{session_id}',
|
path='/{session_id}',
|
||||||
summary='删除上传会话',
|
summary='删除上传会话',
|
||||||
description='取消上传并删除会话及已上传的临时文件。',
|
description='取消上传并删除会话及已上传的临时文件。',
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def delete_upload_session(
|
async def delete_upload_session(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[User, Depends(auth_required)],
|
user: Annotated[User, Depends(auth_required)],
|
||||||
session_id: UUID,
|
session_id: UUID,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""删除上传会话端点"""
|
"""删除上传会话端点"""
|
||||||
upload_session = await UploadSession.get(session, UploadSession.id == session_id)
|
upload_session = await UploadSession.get(session, UploadSession.id == session_id)
|
||||||
if not upload_session or upload_session.owner_id != user.id:
|
if not upload_session or upload_session.owner_id != user.id:
|
||||||
@@ -366,18 +399,17 @@ async def delete_upload_session(
|
|||||||
|
|
||||||
l.info(f"删除上传会话: {session_id}")
|
l.info(f"删除上传会话: {session_id}")
|
||||||
|
|
||||||
return ResponseBase(data={"deleted": True})
|
|
||||||
|
|
||||||
|
|
||||||
@_upload_router.delete(
|
@_upload_router.delete(
|
||||||
path='/',
|
path='/',
|
||||||
summary='清除所有上传会话',
|
summary='清除所有上传会话',
|
||||||
description='清除当前用户的所有上传会话。',
|
description='清除当前用户的所有上传会话。',
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def clear_upload_sessions(
|
async def clear_upload_sessions(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[User, Depends(auth_required)],
|
user: Annotated[User, Depends(auth_required)],
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""清除所有上传会话端点"""
|
"""清除所有上传会话端点"""
|
||||||
# 获取所有会话
|
# 获取所有会话
|
||||||
sessions = await UploadSession.get(
|
sessions = await UploadSession.get(
|
||||||
@@ -399,8 +431,6 @@ async def clear_upload_sessions(
|
|||||||
|
|
||||||
l.info(f"清除用户 {user.id} 的所有上传会话,共 {deleted_count} 个")
|
l.info(f"清除用户 {user.id} 的所有上传会话,共 {deleted_count} 个")
|
||||||
|
|
||||||
return ResponseBase(data={"deleted": deleted_count})
|
|
||||||
|
|
||||||
|
|
||||||
@_upload_router.get(
|
@_upload_router.get(
|
||||||
path='/archive/{session_id}/archive.zip',
|
path='/archive/{session_id}/archive.zip',
|
||||||
@@ -527,12 +557,13 @@ router.include_router(viewers_router)
|
|||||||
path='/create',
|
path='/create',
|
||||||
summary='创建空白文件',
|
summary='创建空白文件',
|
||||||
description='在指定目录下创建空白文件。',
|
description='在指定目录下创建空白文件。',
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def create_empty_file(
|
async def create_empty_file(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[User, Depends(auth_required)],
|
user: Annotated[User, Depends(auth_required)],
|
||||||
request: CreateFileRequest,
|
request: CreateFileRequest,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""创建空白文件端点"""
|
"""创建空白文件端点"""
|
||||||
# 存储 user.id,避免后续 save() 导致 user 过期后无法访问
|
# 存储 user.id,避免后续 save() 导致 user 过期后无法访问
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
@@ -608,12 +639,6 @@ async def create_empty_file(
|
|||||||
|
|
||||||
l.info(f"创建空白文件: {file_object.name}, id={file_object.id}")
|
l.info(f"创建空白文件: {file_object.name}, id={file_object.id}")
|
||||||
|
|
||||||
return ResponseBase(data={
|
|
||||||
"id": str(file_object.id),
|
|
||||||
"name": file_object.name,
|
|
||||||
"size": file_object.size,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== WOPI 会话 ====================
|
# ==================== WOPI 会话 ====================
|
||||||
|
|
||||||
@@ -724,28 +749,145 @@ async def create_wopi_session(
|
|||||||
|
|
||||||
# ==================== 文件外链(保留原有端点结构) ====================
|
# ==================== 文件外链(保留原有端点结构) ====================
|
||||||
|
|
||||||
|
async def _validate_source_link(
|
||||||
|
session: SessionDep,
|
||||||
|
file_id: UUID,
|
||||||
|
) -> tuple[Object, SourceLink, PhysicalFile, Policy]:
|
||||||
|
"""
|
||||||
|
验证外链访问的完整链路
|
||||||
|
|
||||||
|
:returns: (file_obj, link, physical_file, policy)
|
||||||
|
:raises HTTPException: 验证失败
|
||||||
|
"""
|
||||||
|
file_obj = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == file_id) & (Object.deleted_at == None),
|
||||||
|
)
|
||||||
|
if not file_obj:
|
||||||
|
http_exceptions.raise_not_found("文件不存在")
|
||||||
|
|
||||||
|
if not file_obj.is_file:
|
||||||
|
http_exceptions.raise_bad_request("对象不是文件")
|
||||||
|
|
||||||
|
if file_obj.is_banned:
|
||||||
|
http_exceptions.raise_banned()
|
||||||
|
|
||||||
|
policy = await Policy.get(session, Policy.id == file_obj.policy_id)
|
||||||
|
if not policy:
|
||||||
|
http_exceptions.raise_internal_error("存储策略不存在")
|
||||||
|
|
||||||
|
if not policy.is_origin_link_enable:
|
||||||
|
http_exceptions.raise_forbidden("当前存储策略未启用外链功能")
|
||||||
|
|
||||||
|
# SourceLink 必须存在(只有主动创建过外链的文件才能通过外链访问)
|
||||||
|
link: SourceLink | None = await SourceLink.get(
|
||||||
|
session,
|
||||||
|
SourceLink.object_id == file_id,
|
||||||
|
)
|
||||||
|
if not link:
|
||||||
|
http_exceptions.raise_not_found("外链不存在")
|
||||||
|
|
||||||
|
physical_file = await file_obj.awaitable_attrs.physical_file
|
||||||
|
if not physical_file or not physical_file.storage_path:
|
||||||
|
http_exceptions.raise_internal_error("文件存储路径丢失")
|
||||||
|
|
||||||
|
return file_obj, link, physical_file, policy
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
path='/get/{id}/{name}',
|
path='/get/{file_id}/{name}',
|
||||||
summary='文件外链(直接输出文件数据)',
|
summary='文件外链(直接输出文件数据)',
|
||||||
description='通过外链直接获取文件内容。',
|
description='通过外链直接获取文件内容,公开访问无需认证。',
|
||||||
)
|
)
|
||||||
async def file_get(
|
async def file_get(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
id: str,
|
file_id: UUID,
|
||||||
name: str,
|
name: str,
|
||||||
) -> FileResponse:
|
) -> FileResponse:
|
||||||
"""文件外链端点(直接输出)"""
|
"""
|
||||||
raise HTTPException(status_code=501, detail="外链功能暂未实现")
|
文件外链端点(直接输出)
|
||||||
|
|
||||||
|
公开访问,无需认证。通过 UUID 定位文件,URL 中的 name 仅用于 Content-Disposition。
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 403: 存储策略未启用外链 / 文件被封禁
|
||||||
|
- 404: 文件不存在 / 外链不存在 / 物理文件不存在
|
||||||
|
"""
|
||||||
|
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
|
||||||
|
|
||||||
|
if policy.type != PolicyType.LOCAL:
|
||||||
|
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||||
|
|
||||||
|
storage_service = LocalStorageService(policy)
|
||||||
|
if not await storage_service.file_exists(physical_file.storage_path):
|
||||||
|
http_exceptions.raise_not_found("物理文件不存在")
|
||||||
|
|
||||||
|
# 缓存物理路径(save 后对象属性会过期)
|
||||||
|
file_path = physical_file.storage_path
|
||||||
|
|
||||||
|
# 递增下载次数
|
||||||
|
link.downloads += 1
|
||||||
|
await link.save(session)
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
path=file_path,
|
||||||
|
filename=name,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
path='/source/{id}/{name}',
|
path='/source/{file_id}/{name}',
|
||||||
summary='文件外链(301跳转)',
|
summary='文件外链(302重定向或直接输出)',
|
||||||
description='通过外链获取文件重定向地址。',
|
description='通过外链获取文件,公有存储 302 重定向,私有存储直接输出。',
|
||||||
|
response_model=None,
|
||||||
)
|
)
|
||||||
async def file_source_redirect(id: str, name: str) -> ResponseBase:
|
async def file_source_redirect(
|
||||||
"""文件外链端点(301跳转)"""
|
session: SessionDep,
|
||||||
raise HTTPException(status_code=501, detail="外链功能暂未实现")
|
file_id: UUID,
|
||||||
|
name: str,
|
||||||
|
) -> FileResponse | RedirectResponse:
|
||||||
|
"""
|
||||||
|
文件外链端点(重定向/直接输出)
|
||||||
|
|
||||||
|
公开访问,无需认证。根据 policy.is_private 决定服务方式:
|
||||||
|
- is_private=False 且 base_url 非空:302 临时重定向
|
||||||
|
- is_private=True 或 base_url 为空:直接返回文件内容
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 403: 存储策略未启用外链 / 文件被封禁
|
||||||
|
- 404: 文件不存在 / 外链不存在 / 物理文件不存在
|
||||||
|
"""
|
||||||
|
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
|
||||||
|
|
||||||
|
if policy.type != PolicyType.LOCAL:
|
||||||
|
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||||
|
|
||||||
|
storage_service = LocalStorageService(policy)
|
||||||
|
if not await storage_service.file_exists(physical_file.storage_path):
|
||||||
|
http_exceptions.raise_not_found("物理文件不存在")
|
||||||
|
|
||||||
|
# 缓存所有需要的值(save 后对象属性会过期)
|
||||||
|
file_path = physical_file.storage_path
|
||||||
|
is_private = policy.is_private
|
||||||
|
base_url = policy.base_url
|
||||||
|
|
||||||
|
# 递增下载次数
|
||||||
|
link.downloads += 1
|
||||||
|
await link.save(session)
|
||||||
|
|
||||||
|
# 公有存储:302 重定向到 base_url
|
||||||
|
if not is_private and base_url:
|
||||||
|
relative_path = storage_service.get_relative_path(file_path)
|
||||||
|
redirect_url = f"{base_url}/{relative_path}"
|
||||||
|
return RedirectResponse(url=redirect_url, status_code=302)
|
||||||
|
|
||||||
|
# 私有存储或 base_url 为空:通过应用代理文件
|
||||||
|
return FileResponse(
|
||||||
|
path=file_path,
|
||||||
|
filename=name,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
@router.put(
|
||||||
@@ -903,6 +1045,9 @@ async def patch_file_content(
|
|||||||
|
|
||||||
new_bytes = new_text.encode('utf-8')
|
new_bytes = new_text.encode('utf-8')
|
||||||
|
|
||||||
|
# 验证文件大小限制
|
||||||
|
_check_policy_size_limit(policy, len(new_bytes))
|
||||||
|
|
||||||
# 写入文件
|
# 写入文件
|
||||||
await storage_service.write_file(storage_path, new_bytes)
|
await storage_service.write_file(storage_path, new_bytes)
|
||||||
|
|
||||||
@@ -939,14 +1084,65 @@ async def file_thumb(id: str) -> ResponseBase:
|
|||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
path='/source/{id}',
|
path='/source/{file_id}',
|
||||||
summary='取得文件外链',
|
summary='创建/获取文件外链',
|
||||||
description='获取文件的外链地址。',
|
description='为指定文件创建或获取已有的外链地址。',
|
||||||
dependencies=[Depends(auth_required)]
|
|
||||||
)
|
)
|
||||||
async def file_source(id: str) -> ResponseBase:
|
async def file_source(
|
||||||
"""获取文件外链"""
|
session: SessionDep,
|
||||||
raise HTTPException(status_code=501, detail="外链功能暂未实现")
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
file_id: UUID,
|
||||||
|
) -> SourceLinkResponse:
|
||||||
|
"""
|
||||||
|
创建/获取文件外链端点
|
||||||
|
|
||||||
|
检查 policy 是否启用外链,查找或创建 SourceLink,返回外链 URL。
|
||||||
|
|
||||||
|
认证:JWT token 必填
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 403: 存储策略未启用外链
|
||||||
|
- 404: 文件不存在
|
||||||
|
"""
|
||||||
|
file_obj = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == file_id) & (Object.deleted_at == None),
|
||||||
|
)
|
||||||
|
if not file_obj or file_obj.owner_id != user.id:
|
||||||
|
http_exceptions.raise_not_found("文件不存在")
|
||||||
|
|
||||||
|
if not file_obj.is_file:
|
||||||
|
http_exceptions.raise_bad_request("对象不是文件")
|
||||||
|
|
||||||
|
if file_obj.is_banned:
|
||||||
|
http_exceptions.raise_banned()
|
||||||
|
|
||||||
|
policy = await Policy.get(session, Policy.id == file_obj.policy_id)
|
||||||
|
if not policy:
|
||||||
|
http_exceptions.raise_internal_error("存储策略不存在")
|
||||||
|
|
||||||
|
if not policy.is_origin_link_enable:
|
||||||
|
http_exceptions.raise_forbidden("当前存储策略未启用外链功能")
|
||||||
|
|
||||||
|
# 缓存文件名(save 后对象属性会过期)
|
||||||
|
file_name = file_obj.name
|
||||||
|
|
||||||
|
# 查找已有 SourceLink
|
||||||
|
link: SourceLink | None = await SourceLink.get(
|
||||||
|
session,
|
||||||
|
(SourceLink.object_id == file_id) & (SourceLink.name == file_name),
|
||||||
|
)
|
||||||
|
if not link:
|
||||||
|
link = SourceLink(
|
||||||
|
name=file_name,
|
||||||
|
object_id=file_id,
|
||||||
|
)
|
||||||
|
link = await link.save(session)
|
||||||
|
|
||||||
|
site_url = await _get_site_url(session)
|
||||||
|
url = f"{site_url}/api/v1/file/source/{file_id}/{file_name}"
|
||||||
|
|
||||||
|
return SourceLinkResponse(url=url, downloads=link.downloads)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from sqlmodels import (
|
|||||||
PhysicalFile,
|
PhysicalFile,
|
||||||
Policy,
|
Policy,
|
||||||
PolicyType,
|
PolicyType,
|
||||||
ResponseBase,
|
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
from service.storage import (
|
from service.storage import (
|
||||||
@@ -439,9 +438,9 @@ async def router_object_rename(
|
|||||||
if '/' in new_name or '\\' in new_name:
|
if '/' in new_name or '\\' in new_name:
|
||||||
raise HTTPException(status_code=400, detail="名称不能包含斜杠")
|
raise HTTPException(status_code=400, detail="名称不能包含斜杠")
|
||||||
|
|
||||||
# 如果名称没有变化,直接返回成功
|
# 如果名称没有变化,直接返回
|
||||||
if obj.name == new_name:
|
if obj.name == new_name:
|
||||||
return ResponseBase(data={"success": True})
|
return # noqa: already 204
|
||||||
|
|
||||||
# 检查同目录下是否存在同名对象(仅检查未删除的)
|
# 检查同目录下是否存在同名对象(仅检查未删除的)
|
||||||
existing = await Object.get(
|
existing = await Object.get(
|
||||||
|
|||||||
@@ -110,6 +110,8 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
|||||||
|
|
||||||
return SiteConfigResponse(
|
return SiteConfigResponse(
|
||||||
title=s.get("siteName") or "DiskNext",
|
title=s.get("siteName") or "DiskNext",
|
||||||
|
logo_light=s.get("logo_light") or None,
|
||||||
|
logo_dark=s.get("logo_dark") or None,
|
||||||
register_enabled=s.get("register_enabled") == "1",
|
register_enabled=s.get("register_enabled") == "1",
|
||||||
login_captcha=s.get("login_captcha") == "1",
|
login_captcha=s.get("login_captcha") == "1",
|
||||||
reg_captcha=s.get("reg_captcha") == "1",
|
reg_captcha=s.get("reg_captcha") == "1",
|
||||||
|
|||||||
@@ -20,15 +20,15 @@ slave_aria2_router = APIRouter(
|
|||||||
summary='测试用路由',
|
summary='测试用路由',
|
||||||
description='Test route for checking connectivity.',
|
description='Test route for checking connectivity.',
|
||||||
)
|
)
|
||||||
def router_slave_ping() -> ResponseBase:
|
def router_slave_ping() -> str:
|
||||||
"""
|
"""
|
||||||
Test route for checking connectivity.
|
Test route for checking connectivity.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ResponseBase: A response model indicating success.
|
str: 后端版本号
|
||||||
"""
|
"""
|
||||||
from utils.conf.appmeta import BackendVersion
|
from utils.conf.appmeta import BackendVersion
|
||||||
return ResponseBase(data=BackendVersion)
|
return BackendVersion
|
||||||
|
|
||||||
@slave_router.post(
|
@slave_router.post(
|
||||||
path='/post',
|
path='/post',
|
||||||
|
|||||||
@@ -109,6 +109,8 @@ default_settings: list[Setting] = [
|
|||||||
Setting(name="pwa_display", value="standalone", type=SettingsType.PWA),
|
Setting(name="pwa_display", value="standalone", type=SettingsType.PWA),
|
||||||
Setting(name="pwa_theme_color", value="#000000", type=SettingsType.PWA),
|
Setting(name="pwa_theme_color", value="#000000", type=SettingsType.PWA),
|
||||||
Setting(name="pwa_background_color", value="#ffffff", type=SettingsType.PWA),
|
Setting(name="pwa_background_color", value="#ffffff", type=SettingsType.PWA),
|
||||||
|
Setting(name="logo_light", value="", type=SettingsType.BASIC),
|
||||||
|
Setting(name="logo_dark", value="", type=SettingsType.BASIC),
|
||||||
# ==================== 认证方式配置 ====================
|
# ==================== 认证方式配置 ====================
|
||||||
Setting(name="auth_email_password_enabled", value="1", type=SettingsType.AUTH),
|
Setting(name="auth_email_password_enabled", value="1", type=SettingsType.AUTH),
|
||||||
Setting(name="auth_phone_sms_enabled", value="0", type=SettingsType.AUTH),
|
Setting(name="auth_phone_sms_enabled", value="0", type=SettingsType.AUTH),
|
||||||
|
|||||||
554
tests/integration/api/test_file_source_link.py
Normal file
554
tests/integration/api/test_file_source_link.py
Normal file
@@ -0,0 +1,554 @@
|
|||||||
|
"""
|
||||||
|
文件外链与 Policy 规则集成测试
|
||||||
|
|
||||||
|
测试端点:
|
||||||
|
- POST /file/source/{file_id} 创建/获取文件外链
|
||||||
|
- GET /file/get/{file_id}/{name} 外链直接输出
|
||||||
|
- GET /file/source/{file_id}/{name} 外链重定向/输出
|
||||||
|
|
||||||
|
测试 Policy 规则:
|
||||||
|
- max_size 在 PATCH /file/content 中的检查
|
||||||
|
- is_origin_link_enable 控制外链创建与访问
|
||||||
|
- is_private + base_url 控制 302 重定向 vs 应用代理
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
from pathlib import Path
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from sqlalchemy import event
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from sqlmodels import Object, ObjectType, PhysicalFile, Policy, PolicyType, SourceLink, User
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _register_sqlite_greatest():
|
||||||
|
"""注册 SQLite 的 greatest 函数以兼容 PostgreSQL 语法"""
|
||||||
|
|
||||||
|
def _on_connect(dbapi_connection, connection_record):
|
||||||
|
if hasattr(dbapi_connection, 'create_function'):
|
||||||
|
dbapi_connection.create_function("greatest", 2, max)
|
||||||
|
|
||||||
|
event.listen(Engine, "connect", _on_connect)
|
||||||
|
yield
|
||||||
|
event.remove(Engine, "connect", _on_connect)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Fixtures ====================
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def source_policy(
|
||||||
|
initialized_db: AsyncSession,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> Policy:
|
||||||
|
"""创建启用外链的本地存储策略"""
|
||||||
|
policy = Policy(
|
||||||
|
id=uuid4(),
|
||||||
|
name="测试外链存储",
|
||||||
|
type=PolicyType.LOCAL,
|
||||||
|
server=str(tmp_path),
|
||||||
|
is_origin_link_enable=True,
|
||||||
|
is_private=True,
|
||||||
|
max_size=0,
|
||||||
|
)
|
||||||
|
initialized_db.add(policy)
|
||||||
|
await initialized_db.commit()
|
||||||
|
await initialized_db.refresh(policy)
|
||||||
|
return policy
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def source_file(
|
||||||
|
initialized_db: AsyncSession,
|
||||||
|
tmp_path: Path,
|
||||||
|
source_policy: Policy,
|
||||||
|
) -> dict[str, str | int]:
|
||||||
|
"""创建一个文本测试文件,关联到启用外链的存储策略"""
|
||||||
|
user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||||
|
root = await Object.get_root(initialized_db, user.id)
|
||||||
|
|
||||||
|
content = "A" * 50
|
||||||
|
content_bytes = content.encode('utf-8')
|
||||||
|
content_hash = hashlib.sha256(content_bytes).hexdigest()
|
||||||
|
|
||||||
|
file_path = tmp_path / "source_test.txt"
|
||||||
|
file_path.write_bytes(content_bytes)
|
||||||
|
|
||||||
|
physical_file = PhysicalFile(
|
||||||
|
id=uuid4(),
|
||||||
|
storage_path=str(file_path),
|
||||||
|
size=len(content_bytes),
|
||||||
|
policy_id=source_policy.id,
|
||||||
|
reference_count=1,
|
||||||
|
)
|
||||||
|
initialized_db.add(physical_file)
|
||||||
|
|
||||||
|
file_obj = Object(
|
||||||
|
id=uuid4(),
|
||||||
|
name="source_test.txt",
|
||||||
|
type=ObjectType.FILE,
|
||||||
|
size=len(content_bytes),
|
||||||
|
physical_file_id=physical_file.id,
|
||||||
|
parent_id=root.id,
|
||||||
|
owner_id=user.id,
|
||||||
|
policy_id=source_policy.id,
|
||||||
|
)
|
||||||
|
initialized_db.add(file_obj)
|
||||||
|
await initialized_db.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(file_obj.id),
|
||||||
|
"name": "source_test.txt",
|
||||||
|
"content": content,
|
||||||
|
"hash": content_hash,
|
||||||
|
"size": len(content_bytes),
|
||||||
|
"path": str(file_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def source_file_with_link(
|
||||||
|
initialized_db: AsyncSession,
|
||||||
|
source_file: dict[str, str | int],
|
||||||
|
) -> dict[str, str | int]:
|
||||||
|
"""创建已有 SourceLink 的测试文件"""
|
||||||
|
link = SourceLink(
|
||||||
|
name=source_file["name"],
|
||||||
|
object_id=UUID(source_file["id"]),
|
||||||
|
downloads=5,
|
||||||
|
)
|
||||||
|
initialized_db.add(link)
|
||||||
|
await initialized_db.commit()
|
||||||
|
await initialized_db.refresh(link)
|
||||||
|
|
||||||
|
return {**source_file, "link_id": link.id, "link_downloads": 5}
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== POST /file/source/{file_id} ====================
|
||||||
|
|
||||||
|
class TestCreateSourceLink:
|
||||||
|
"""POST /file/source/{file_id} 端点测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_source_link_success(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
auth_headers: dict[str, str],
|
||||||
|
source_file: dict[str, str | int],
|
||||||
|
) -> None:
|
||||||
|
"""成功创建外链"""
|
||||||
|
response = await async_client.post(
|
||||||
|
f"/api/v1/file/source/{source_file['id']}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "/api/v1/file/source/" in data["url"]
|
||||||
|
assert source_file["name"] in data["url"]
|
||||||
|
assert data["downloads"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_source_link_idempotent(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
auth_headers: dict[str, str],
|
||||||
|
source_file_with_link: dict[str, str | int],
|
||||||
|
) -> None:
|
||||||
|
"""已有外链时返回现有外链(幂等)"""
|
||||||
|
response = await async_client.post(
|
||||||
|
f"/api/v1/file/source/{source_file_with_link['id']}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["downloads"] == source_file_with_link["link_downloads"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_source_link_disabled_returns_403(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
auth_headers: dict[str, str],
|
||||||
|
source_file: dict[str, str | int],
|
||||||
|
source_policy: Policy,
|
||||||
|
initialized_db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""存储策略未启用外链时返回 403"""
|
||||||
|
source_policy.is_origin_link_enable = False
|
||||||
|
initialized_db.add(source_policy)
|
||||||
|
await initialized_db.commit()
|
||||||
|
|
||||||
|
response = await async_client.post(
|
||||||
|
f"/api/v1/file/source/{source_file['id']}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_source_link_file_not_found(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
auth_headers: dict[str, str],
|
||||||
|
) -> None:
|
||||||
|
"""文件不存在返回 404"""
|
||||||
|
response = await async_client.post(
|
||||||
|
f"/api/v1/file/source/{uuid4()}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_source_link_unauthenticated(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
source_file: dict[str, str | int],
|
||||||
|
) -> None:
|
||||||
|
"""未认证返回 401"""
|
||||||
|
response = await async_client.post(
|
||||||
|
f"/api/v1/file/source/{source_file['id']}",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== GET /file/get/{file_id}/{name} ====================
|
||||||
|
|
||||||
|
class TestFileGetDirect:
|
||||||
|
"""GET /file/get/{file_id}/{name} 端点测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_direct_success(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
source_file_with_link: dict[str, str | int],
|
||||||
|
) -> None:
|
||||||
|
"""成功通过外链直接获取文件(无需认证)"""
|
||||||
|
response = await async_client.get(
|
||||||
|
f"/api/v1/file/get/{source_file_with_link['id']}/{source_file_with_link['name']}",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert source_file_with_link["content"] in response.text
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_direct_increments_download_count(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
auth_headers: dict[str, str],
|
||||||
|
source_file_with_link: dict[str, str | int],
|
||||||
|
initialized_db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""下载后递增计数"""
|
||||||
|
link_before = await SourceLink.get(
|
||||||
|
initialized_db,
|
||||||
|
SourceLink.object_id == UUID(source_file_with_link["id"]),
|
||||||
|
)
|
||||||
|
downloads_before = link_before.downloads
|
||||||
|
|
||||||
|
await async_client.get(
|
||||||
|
f"/api/v1/file/get/{source_file_with_link['id']}/{source_file_with_link['name']}",
|
||||||
|
)
|
||||||
|
|
||||||
|
await initialized_db.refresh(link_before)
|
||||||
|
assert link_before.downloads == downloads_before + 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_direct_no_link_returns_404(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
source_file: dict[str, str | int],
|
||||||
|
) -> None:
|
||||||
|
"""未创建外链的文件返回 404"""
|
||||||
|
response = await async_client.get(
|
||||||
|
f"/api/v1/file/get/{source_file['id']}/{source_file['name']}",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_direct_nonexistent_file_returns_404(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
) -> None:
|
||||||
|
"""文件不存在返回 404"""
|
||||||
|
response = await async_client.get(
|
||||||
|
f"/api/v1/file/get/{uuid4()}/fake.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_direct_disabled_policy_returns_403(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
source_file_with_link: dict[str, str | int],
|
||||||
|
source_policy: Policy,
|
||||||
|
initialized_db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""存储策略禁用外链时返回 403"""
|
||||||
|
source_policy.is_origin_link_enable = False
|
||||||
|
initialized_db.add(source_policy)
|
||||||
|
await initialized_db.commit()
|
||||||
|
|
||||||
|
response = await async_client.get(
|
||||||
|
f"/api/v1/file/get/{source_file_with_link['id']}/{source_file_with_link['name']}",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== GET /file/source/{file_id}/{name} ====================
|
||||||
|
|
||||||
|
class TestFileSourceRedirect:
|
||||||
|
"""GET /file/source/{file_id}/{name} 端点测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_source_private_returns_file_content(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
source_file_with_link: dict[str, str | int],
|
||||||
|
) -> None:
|
||||||
|
"""is_private=True 时直接返回文件内容"""
|
||||||
|
response = await async_client.get(
|
||||||
|
f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}",
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert source_file_with_link["content"] in response.text
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_source_public_redirects_302(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
source_file_with_link: dict[str, str | int],
|
||||||
|
source_policy: Policy,
|
||||||
|
initialized_db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""is_private=False + base_url 时 302 重定向"""
|
||||||
|
source_policy.is_private = False
|
||||||
|
source_policy.base_url = "http://cdn.example.com/storage"
|
||||||
|
initialized_db.add(source_policy)
|
||||||
|
await initialized_db.commit()
|
||||||
|
|
||||||
|
response = await async_client.get(
|
||||||
|
f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}",
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 302
|
||||||
|
location = response.headers["location"]
|
||||||
|
assert "cdn.example.com/storage" in location
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_source_public_no_base_url_fallback(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
source_file_with_link: dict[str, str | int],
|
||||||
|
source_policy: Policy,
|
||||||
|
initialized_db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""is_private=False 但 base_url 为空时降级为直接输出"""
|
||||||
|
source_policy.is_private = False
|
||||||
|
source_policy.base_url = None
|
||||||
|
initialized_db.add(source_policy)
|
||||||
|
await initialized_db.commit()
|
||||||
|
|
||||||
|
response = await async_client.get(
|
||||||
|
f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}",
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert source_file_with_link["content"] in response.text
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_source_increments_download_count(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
source_file_with_link: dict[str, str | int],
|
||||||
|
initialized_db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""访问外链递增下载计数"""
|
||||||
|
link_before = await SourceLink.get(
|
||||||
|
initialized_db,
|
||||||
|
SourceLink.object_id == UUID(source_file_with_link["id"]),
|
||||||
|
)
|
||||||
|
downloads_before = link_before.downloads
|
||||||
|
|
||||||
|
await async_client.get(
|
||||||
|
f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}",
|
||||||
|
)
|
||||||
|
|
||||||
|
await initialized_db.refresh(link_before)
|
||||||
|
assert link_before.downloads == downloads_before + 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_source_no_link_returns_404(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
source_file: dict[str, str | int],
|
||||||
|
) -> None:
|
||||||
|
"""未创建外链的文件返回 404"""
|
||||||
|
response = await async_client.get(
|
||||||
|
f"/api/v1/file/source/{source_file['id']}/{source_file['name']}",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_source_disabled_policy_returns_403(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
source_file_with_link: dict[str, str | int],
|
||||||
|
source_policy: Policy,
|
||||||
|
initialized_db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""存储策略禁用外链时返回 403"""
|
||||||
|
source_policy.is_origin_link_enable = False
|
||||||
|
initialized_db.add(source_policy)
|
||||||
|
await initialized_db.commit()
|
||||||
|
|
||||||
|
response = await async_client.get(
|
||||||
|
f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== max_size 在 PATCH 中的检查 ====================
|
||||||
|
|
||||||
|
class TestPatchMaxSizePolicy:
|
||||||
|
"""PATCH /file/content/{file_id} 的 max_size 策略检查"""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def size_limited_policy(
|
||||||
|
self,
|
||||||
|
initialized_db: AsyncSession,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> Policy:
|
||||||
|
"""创建有大小限制的存储策略(100 bytes)"""
|
||||||
|
policy = Policy(
|
||||||
|
id=uuid4(),
|
||||||
|
name="限制大小存储",
|
||||||
|
type=PolicyType.LOCAL,
|
||||||
|
server=str(tmp_path),
|
||||||
|
max_size=100,
|
||||||
|
)
|
||||||
|
initialized_db.add(policy)
|
||||||
|
await initialized_db.commit()
|
||||||
|
await initialized_db.refresh(policy)
|
||||||
|
return policy
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def small_file(
|
||||||
|
self,
|
||||||
|
initialized_db: AsyncSession,
|
||||||
|
tmp_path: Path,
|
||||||
|
size_limited_policy: Policy,
|
||||||
|
) -> dict[str, str | int]:
|
||||||
|
"""创建一个 50 字节的文本文件(策略限制 100 字节)"""
|
||||||
|
user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||||
|
root = await Object.get_root(initialized_db, user.id)
|
||||||
|
|
||||||
|
content = "A" * 50
|
||||||
|
content_bytes = content.encode('utf-8')
|
||||||
|
content_hash = hashlib.sha256(content_bytes).hexdigest()
|
||||||
|
|
||||||
|
file_path = tmp_path / "small.txt"
|
||||||
|
file_path.write_bytes(content_bytes)
|
||||||
|
|
||||||
|
physical_file = PhysicalFile(
|
||||||
|
id=uuid4(),
|
||||||
|
storage_path=str(file_path),
|
||||||
|
size=len(content_bytes),
|
||||||
|
policy_id=size_limited_policy.id,
|
||||||
|
reference_count=1,
|
||||||
|
)
|
||||||
|
initialized_db.add(physical_file)
|
||||||
|
|
||||||
|
file_obj = Object(
|
||||||
|
id=uuid4(),
|
||||||
|
name="small.txt",
|
||||||
|
type=ObjectType.FILE,
|
||||||
|
size=len(content_bytes),
|
||||||
|
physical_file_id=physical_file.id,
|
||||||
|
parent_id=root.id,
|
||||||
|
owner_id=user.id,
|
||||||
|
policy_id=size_limited_policy.id,
|
||||||
|
)
|
||||||
|
initialized_db.add(file_obj)
|
||||||
|
await initialized_db.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(file_obj.id),
|
||||||
|
"content": content,
|
||||||
|
"hash": content_hash,
|
||||||
|
"size": len(content_bytes),
|
||||||
|
"path": str(file_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_patch_exceeds_max_size_returns_413(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
auth_headers: dict[str, str],
|
||||||
|
small_file: dict[str, str | int],
|
||||||
|
) -> None:
|
||||||
|
"""PATCH 后文件超过 max_size 返回 413"""
|
||||||
|
big_content = "B" * 200
|
||||||
|
patch_text = (
|
||||||
|
"--- a\n"
|
||||||
|
"+++ b\n"
|
||||||
|
"@@ -1 +1 @@\n"
|
||||||
|
f"-{'A' * 50}\n"
|
||||||
|
f"+{big_content}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await async_client.patch(
|
||||||
|
f"/api/v1/file/content/{small_file['id']}",
|
||||||
|
headers=auth_headers,
|
||||||
|
json={
|
||||||
|
"patch": patch_text,
|
||||||
|
"base_hash": small_file["hash"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 413
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_patch_within_max_size_succeeds(
|
||||||
|
self,
|
||||||
|
async_client: AsyncClient,
|
||||||
|
auth_headers: dict[str, str],
|
||||||
|
small_file: dict[str, str | int],
|
||||||
|
) -> None:
|
||||||
|
"""PATCH 后文件未超过 max_size 正常保存"""
|
||||||
|
new_content = "C" * 80 # 80 bytes < 100 bytes limit
|
||||||
|
patch_text = (
|
||||||
|
"--- a\n"
|
||||||
|
"+++ b\n"
|
||||||
|
"@@ -1 +1 @@\n"
|
||||||
|
f"-{'A' * 50}\n"
|
||||||
|
f"+{new_content}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await async_client.patch(
|
||||||
|
f"/api/v1/file/content/{small_file['id']}",
|
||||||
|
headers=auth_headers,
|
||||||
|
json={
|
||||||
|
"patch": patch_text,
|
||||||
|
"base_hash": small_file["hash"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["new_size"] == 80
|
||||||
Reference in New Issue
Block a user