feat: add S3 storage support, policy migration, and quota enforcement
Some checks failed
Test / test (push) Failing after 2m21s

- Add S3StorageService with AWS Signature V4 signing (URI-encoded for non-ASCII keys)
- Add PATCH /object/{id}/policy endpoint for switching storage policies with background migration
- Implement cross-storage file migration service (local <-> S3)
- Replace deprecated StorageType enum with PolicyType (local/s3)
- Implement GET /user/settings/policies endpoint (was 501 stub)
- Add storage quota pre-allocation on upload session creation to prevent concurrent bypass
- Fix BigInteger for max_storage and user.storage to support >2GB values
- Add policy permission validation on upload and directory creation
- Use group's first policy as default on registration instead of hardcoded name
- Define TaskType.POLICY_MIGRATE and extend TaskProps with migration fields

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-23 13:38:20 +08:00
parent 7200df6d87
commit 3639a31163
19 changed files with 1728 additions and 124 deletions

View File

@@ -8,11 +8,11 @@ from sqlmodel import Field
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from sqlmodels import (
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
ListResponse, Object,
Policy, PolicyCreateRequest, PolicyOptions, PolicyType, PolicySummary,
PolicyUpdateRequest, ResponseBase, ListResponse, Object,
)
from sqlmodel_ext import SQLModelBase
from service.storage import DirectoryCreationError, LocalStorageService
from service.storage import DirectoryCreationError, LocalStorageService, S3StorageService
admin_policy_router = APIRouter(
prefix='/policy',
@@ -67,6 +67,12 @@ class PolicyDetailResponse(SQLModelBase):
base_url: str | None
"""基础URL"""
access_key: str | None
"""Access Key"""
secret_key: str | None
"""Secret Key"""
max_size: int
"""最大文件尺寸"""
@@ -107,9 +113,45 @@ class PolicyTestSlaveRequest(SQLModelBase):
secret: str
"""从机通信密钥"""
class PolicyCreateRequest(PolicyBase):
"""创建存储策略请求 DTO继承 PolicyBase 中的所有字段"""
pass
class PolicyTestS3Request(SQLModelBase):
    """Request DTO for testing an S3 connection."""
    server: str = Field(max_length=255)
    """S3 endpoint address"""
    bucket_name: str = Field(max_length=255)
    """Bucket name"""
    access_key: str
    """Access Key"""
    secret_key: str
    """Secret Key"""
    s3_region: str = Field(default='us-east-1', max_length=64)
    """S3 region"""
    s3_path_style: bool = False
    """Whether to use path-style addressing"""
class PolicyTestS3Response(SQLModelBase):
    """Response DTO for an S3 connection test."""
    is_connected: bool
    """Whether the connection succeeded"""
    message: str
    """Test result message"""
# ==================== Options field set (separates Policy fields from PolicyOptions fields) ====================
# Request fields listed here are persisted on the PolicyOptions row;
# everything else goes onto the Policy row itself (see the update endpoint).
_OPTIONS_FIELDS: set[str] = {
    'token', 'file_type', 'mimetype', 'od_redirect',
    'chunk_size', 's3_path_style', 's3_region',
}
@admin_policy_router.get(
path='/list',
@@ -277,7 +319,20 @@ async def router_policy_add_policy(
raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}")
# 保存到数据库
await policy.save(session)
policy = await policy.save(session)
# 创建策略选项
options = PolicyOptions(
policy_id=policy.id,
token=request.token,
file_type=request.file_type,
mimetype=request.mimetype,
od_redirect=request.od_redirect,
chunk_size=request.chunk_size,
s3_path_style=request.s3_path_style,
s3_region=request.s3_region,
)
await options.save(session)
@admin_policy_router.post(
path='/cors',
@@ -371,6 +426,8 @@ async def router_policy_get_policy(
bucket_name=policy.bucket_name,
is_private=policy.is_private,
base_url=policy.base_url,
access_key=policy.access_key,
secret_key=policy.secret_key,
max_size=policy.max_size,
auto_rename=policy.auto_rename,
dir_name_rule=policy.dir_name_rule,
@@ -417,4 +474,108 @@ async def router_policy_delete_policy(
policy_name = policy.name
await Policy.delete(session, policy)
l.info(f"管理员删除了存储策略: {policy_name}")
l.info(f"管理员删除了存储策略: {policy_name}")
@admin_policy_router.patch(
    path='/{policy_id}',
    summary='更新存储策略',
    description='更新存储策略配置。策略类型创建后不可更改。',
    dependencies=[Depends(admin_required)],
    status_code=204,
)
async def router_policy_update_policy(
    session: SessionDep,
    policy_id: UUID,
    request: PolicyUpdateRequest,
) -> None:
    """
    Update a storage policy.

    Behaviour:
    - Splits the submitted fields into Policy base fields and PolicyOptions
      extension fields (membership in _OPTIONS_FIELDS decides the split)
      and applies each group to its own row.
    - The policy type is not changeable after creation.

    Auth:
    - Admin privileges required (route dependency).

    :param session: database session
    :param policy_id: storage policy UUID
    :param request: update request; only explicitly-set fields are applied
    :raises HTTPException: 404 if the policy does not exist,
        409 if the requested name is already taken by another policy
    """
    # Load the policy together with its options row in one query.
    policy = await Policy.get(session, Policy.id == policy_id, load=Policy.options)
    if not policy:
        raise HTTPException(status_code=404, detail="存储策略不存在")
    # Enforce name uniqueness (only when the name is actually changing).
    if request.name and request.name != policy.name:
        existing = await Policy.get(session, Policy.name == request.name)
        if existing:
            raise HTTPException(status_code=409, detail="策略名称已存在")
    # Split the payload into Policy fields vs. Options fields.
    # exclude_unset=True keeps this a partial update: untouched fields stay as-is.
    all_data = request.model_dump(exclude_unset=True)
    policy_data = {k: v for k, v in all_data.items() if k not in _OPTIONS_FIELDS}
    options_data = {k: v for k, v in all_data.items() if k in _OPTIONS_FIELDS}
    # Apply Policy base-field updates.
    if policy_data:
        for key, value in policy_data.items():
            setattr(policy, key, value)
        # Re-bind: save() returns the (possibly refreshed) instance.
        policy = await policy.save(session)
    # Apply option updates; create the PolicyOptions row if it is missing.
    if options_data:
        if policy.options:
            for key, value in options_data.items():
                setattr(policy.options, key, value)
            await policy.options.save(session)
        else:
            options = PolicyOptions(policy_id=policy.id, **options_data)
            await options.save(session)
    l.info(f"管理员更新了存储策略: {policy_id}")
@admin_policy_router.post(
    path='/test/s3',
    summary='测试 S3 连接',
    description='测试 S3 存储端点的连通性和凭据有效性。',
    dependencies=[Depends(admin_required)],
)
async def router_policy_test_s3(
    request: PolicyTestS3Request,
) -> PolicyTestS3Response:
    """
    Test an S3 connection.

    Sends a HEAD-object request to the S3 endpoint to verify the
    credentials and network reachability.

    :param request: test request with endpoint, bucket and credentials
    :return: test result (connected flag plus a human-readable message)
    """
    from service.storage import S3APIError

    # Build a throwaway Policy purely to parameterize the S3 client;
    # it is never persisted.
    probe_policy = Policy(
        name="__test__",
        type=PolicyType.S3,
        server=request.server,
        bucket_name=request.bucket_name,
        access_key=request.access_key,
        secret_key=request.secret_key,
    )
    client = S3StorageService(
        probe_policy,
        region=request.s3_region,
        is_path_style=request.s3_path_style,
    )

    # A file_exists probe issues a HEAD request; any signed round-trip
    # that completes (even "not found") proves connectivity + credentials.
    try:
        await client.file_exists("__connection_test__")
    except S3APIError as exc:
        return PolicyTestS3Response(is_connected=False, message=f"S3 API 错误: {exc}")
    except Exception as exc:
        return PolicyTestS3Response(is_connected=False, message=f"连接失败: {exc}")
    return PolicyTestS3Response(is_connected=True, message="连接成功")

View File

@@ -57,7 +57,7 @@ async def _get_directory_response(
policy_response = PolicyResponse(
id=policy.id,
name=policy.name,
type=policy.type.value,
type=policy.type,
max_size=policy.max_size,
)
@@ -189,6 +189,14 @@ async def router_directory_create(
raise HTTPException(status_code=409, detail="同名文件或目录已存在")
policy_id = request.policy_id if request.policy_id else parent.policy_id
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
if request.policy_id:
group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
if request.policy_id not in {p.id for p in group.policies}:
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
parent_id = parent.id # 在 save 前保存
new_folder = Object(

View File

@@ -13,9 +13,11 @@ from datetime import datetime, timedelta
from typing import Annotated
from uuid import UUID
import orjson
import whatthepatch
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, RedirectResponse
from starlette.responses import Response
from loguru import logger as l
from sqlmodel_ext import SQLModelBase
from whatthepatch.exceptions import HunkApplyException
@@ -44,7 +46,9 @@ from sqlmodels import (
User,
WopiSessionResponse,
)
from service.storage import LocalStorageService, adjust_user_storage
import orjson
from service.storage import LocalStorageService, S3StorageService, adjust_user_storage
from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL
from utils.JWT.wopi_token import create_wopi_token
from utils import http_exceptions
@@ -184,6 +188,13 @@ async def create_upload_session(
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
if request.policy_id:
group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
if request.policy_id not in {p.id for p in group.policies}:
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
# 验证文件大小限制
_check_policy_size_limit(policy, request.file_size)
@@ -210,6 +221,7 @@ async def create_upload_session(
# 生成存储路径
storage_path: str | None = None
s3_upload_id: str | None = None
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
dir_path, storage_name, full_path = await storage_service.generate_file_path(
@@ -217,8 +229,25 @@ async def create_upload_session(
original_filename=request.file_name,
)
storage_path = full_path
else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
elif policy.type == PolicyType.S3:
s3_service = S3StorageService(
policy,
region=options.s3_region if options else 'us-east-1',
is_path_style=options.s3_path_style if options else False,
)
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
user_id=user.id,
original_filename=request.file_name,
)
# 多分片时创建 multipart upload
if total_chunks > 1:
s3_upload_id = await s3_service.create_multipart_upload(
storage_path, content_type='application/octet-stream',
)
# 预扣存储空间(与创建会话在同一事务中提交,防止并发绕过配额)
if request.file_size > 0:
await adjust_user_storage(session, user.id, request.file_size, commit=False)
# 创建上传会话
upload_session = UploadSession(
@@ -227,6 +256,7 @@ async def create_upload_session(
chunk_size=chunk_size,
total_chunks=total_chunks,
storage_path=storage_path,
s3_upload_id=s3_upload_id,
expires_at=datetime.now() + timedelta(hours=24),
owner_id=user.id,
parent_id=request.parent_id,
@@ -302,8 +332,38 @@ async def upload_chunk(
content,
offset,
)
else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
elif policy.type == PolicyType.S3:
if not upload_session.storage_path:
raise HTTPException(status_code=500, detail="存储路径丢失")
s3_service = await S3StorageService.from_policy(policy)
if upload_session.total_chunks == 1:
# 单分片:直接 PUT 上传
await s3_service.upload_file(upload_session.storage_path, content)
else:
# 多分片UploadPart
if not upload_session.s3_upload_id:
raise HTTPException(status_code=500, detail="S3 分片上传 ID 丢失")
etag = await s3_service.upload_part(
upload_session.storage_path,
upload_session.s3_upload_id,
chunk_index + 1, # S3 part number 从 1 开始
content,
)
# 追加 ETag 到 s3_part_etags
etags: list[list[int | str]] = orjson.loads(upload_session.s3_part_etags or '[]')
etags.append([chunk_index + 1, etag])
upload_session.s3_part_etags = orjson.dumps(etags).decode()
# 在 savecommit前缓存后续需要的属性commit 后 ORM 对象会过期)
policy_type = policy.type
s3_upload_id = upload_session.s3_upload_id
s3_part_etags = upload_session.s3_part_etags
s3_service_for_complete: S3StorageService | None = None
if policy_type == PolicyType.S3:
s3_service_for_complete = await S3StorageService.from_policy(policy)
# 更新会话进度
upload_session.uploaded_chunks += 1
@@ -319,12 +379,26 @@ async def upload_chunk(
if is_complete:
# 保存 upload_session 属性commit 后会过期)
file_name = upload_session.file_name
file_size = upload_session.file_size
uploaded_size = upload_session.uploaded_size
storage_path = upload_session.storage_path
upload_session_id = upload_session.id
parent_id = upload_session.parent_id
policy_id = upload_session.policy_id
# S3 多分片上传完成:合并分片
if (
policy_type == PolicyType.S3
and s3_upload_id
and s3_part_etags
and s3_service_for_complete
):
parts_data: list[list[int | str]] = orjson.loads(s3_part_etags)
parts = [(int(pn), str(et)) for pn, et in parts_data]
await s3_service_for_complete.complete_multipart_upload(
storage_path, s3_upload_id, parts,
)
# 创建 PhysicalFile 记录
physical_file = PhysicalFile(
storage_path=storage_path,
@@ -355,9 +429,10 @@ async def upload_chunk(
commit=False
)
# 更新用户存储配额
if uploaded_size > 0:
await adjust_user_storage(session, user_id, uploaded_size, commit=False)
# 调整存储配额差值(创建会话时已预扣 file_size这里只补差
size_diff = uploaded_size - file_size
if size_diff != 0:
await adjust_user_storage(session, user_id, size_diff, commit=False)
# 统一提交所有更改
await session.commit()
@@ -390,9 +465,25 @@ async def delete_upload_session(
# 删除临时文件
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
if policy and upload_session.storage_path:
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
# 如果有分片上传,先取消
if upload_session.s3_upload_id:
await s3_service.abort_multipart_upload(
upload_session.storage_path, upload_session.s3_upload_id,
)
else:
# 单分片上传已完成的话,删除已上传的文件
if upload_session.uploaded_chunks > 0:
await s3_service.delete_file(upload_session.storage_path)
# 释放预扣的存储空间
if upload_session.file_size > 0:
await adjust_user_storage(session, user.id, -upload_session.file_size)
# 删除会话记录
await UploadSession.delete(session, upload_session)
@@ -422,9 +513,22 @@ async def clear_upload_sessions(
for upload_session in sessions:
# 删除临时文件
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
if policy and upload_session.storage_path:
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
if upload_session.s3_upload_id:
await s3_service.abort_multipart_upload(
upload_session.storage_path, upload_session.s3_upload_id,
)
elif upload_session.uploaded_chunks > 0:
await s3_service.delete_file(upload_session.storage_path)
# 释放预扣的存储空间
if upload_session.file_size > 0:
await adjust_user_storage(session, user.id, -upload_session.file_size)
await UploadSession.delete(session, upload_session)
deleted_count += 1
@@ -486,11 +590,12 @@ async def create_download_token_endpoint(
path='/{token}',
summary='下载文件',
description='使用下载令牌下载文件,令牌在有效期内可重复使用。',
response_model=None,
)
async def download_file(
session: SessionDep,
token: str,
) -> FileResponse:
) -> Response:
"""
下载文件端点
@@ -540,8 +645,15 @@ async def download_file(
filename=file_obj.name,
media_type="application/octet-stream",
)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
# 302 重定向到预签名 URL
presigned_url = s3_service.generate_presigned_url(
storage_path, method='GET', expires_in=3600, filename=file_obj.name,
)
return RedirectResponse(url=presigned_url, status_code=302)
else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
raise HTTPException(status_code=500, detail="不支持的存储类型")
# ==================== 包含子路由 ====================
@@ -613,8 +725,13 @@ async def create_empty_file(
)
await storage_service.create_empty_file(full_path)
storage_path = full_path
else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
user_id=user_id,
original_filename=request.name,
)
await s3_service.upload_file(storage_path, b'')
# 创建 PhysicalFile 记录
physical_file = PhysicalFile(
@@ -798,12 +915,13 @@ async def _validate_source_link(
path='/get/{file_id}/{name}',
summary='文件外链(直接输出文件数据)',
description='通过外链直接获取文件内容,公开访问无需认证。',
response_model=None,
)
async def file_get(
session: SessionDep,
file_id: UUID,
name: str,
) -> FileResponse:
) -> Response:
"""
文件外链端点(直接输出)
@@ -815,13 +933,6 @@ async def file_get(
"""
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
if policy.type != PolicyType.LOCAL:
http_exceptions.raise_not_implemented("S3 存储暂未实现")
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(physical_file.storage_path):
http_exceptions.raise_not_found("物理文件不存在")
# 缓存物理路径save 后对象属性会过期)
file_path = physical_file.storage_path
@@ -829,11 +940,25 @@ async def file_get(
link.downloads += 1
await link.save(session)
return FileResponse(
path=file_path,
filename=name,
media_type="application/octet-stream",
)
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(file_path):
http_exceptions.raise_not_found("物理文件不存在")
return FileResponse(
path=file_path,
filename=name,
media_type="application/octet-stream",
)
elif policy.type == PolicyType.S3:
# S3 外链直接输出302 重定向到预签名 URL
s3_service = await S3StorageService.from_policy(policy)
presigned_url = s3_service.generate_presigned_url(
file_path, method='GET', expires_in=3600, filename=name,
)
return RedirectResponse(url=presigned_url, status_code=302)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
@router.get(
@@ -846,7 +971,7 @@ async def file_source_redirect(
session: SessionDep,
file_id: UUID,
name: str,
) -> FileResponse | RedirectResponse:
) -> Response:
"""
文件外链端点(重定向/直接输出)
@@ -860,13 +985,6 @@ async def file_source_redirect(
"""
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
if policy.type != PolicyType.LOCAL:
http_exceptions.raise_not_implemented("S3 存储暂未实现")
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(physical_file.storage_path):
http_exceptions.raise_not_found("物理文件不存在")
# 缓存所有需要的值save 后对象属性会过期)
file_path = physical_file.storage_path
is_private = policy.is_private
@@ -876,18 +994,36 @@ async def file_source_redirect(
link.downloads += 1
await link.save(session)
# 公有存储302 重定向到 base_url
if not is_private and base_url:
relative_path = storage_service.get_relative_path(file_path)
redirect_url = f"{base_url}/{relative_path}"
return RedirectResponse(url=redirect_url, status_code=302)
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(file_path):
http_exceptions.raise_not_found("物理文件不存在")
# 有存储或 base_url 为空:通过应用代理文件
return FileResponse(
path=file_path,
filename=name,
media_type="application/octet-stream",
)
# 有存储302 重定向到 base_url
if not is_private and base_url:
relative_path = storage_service.get_relative_path(file_path)
redirect_url = f"{base_url}/{relative_path}"
return RedirectResponse(url=redirect_url, status_code=302)
# 私有存储或 base_url 为空:通过应用代理文件
return FileResponse(
path=file_path,
filename=name,
media_type="application/octet-stream",
)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
# 公有存储且有 base_url直接重定向到公开 URL
if not is_private and base_url:
redirect_url = f"{base_url.rstrip('/')}/{file_path}"
return RedirectResponse(url=redirect_url, status_code=302)
# 私有存储:生成预签名 URL 重定向
presigned_url = s3_service.generate_presigned_url(
file_path, method='GET', expires_in=3600, filename=name,
)
return RedirectResponse(url=presigned_url, status_code=302)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
@router.put(
@@ -941,11 +1077,15 @@ async def file_content(
if not policy:
http_exceptions.raise_internal_error("存储策略不存在")
if policy.type != PolicyType.LOCAL:
http_exceptions.raise_not_implemented("S3 存储暂未实现")
storage_service = LocalStorageService(policy)
raw_bytes = await storage_service.read_file(physical_file.storage_path)
# 读取文件内容
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
raw_bytes = await storage_service.read_file(physical_file.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
raw_bytes = await s3_service.download_file(physical_file.storage_path)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
try:
content = raw_bytes.decode('utf-8')
@@ -1011,11 +1151,15 @@ async def patch_file_content(
if not policy:
http_exceptions.raise_internal_error("存储策略不存在")
if policy.type != PolicyType.LOCAL:
http_exceptions.raise_not_implemented("S3 存储暂未实现")
storage_service = LocalStorageService(policy)
raw_bytes = await storage_service.read_file(storage_path)
# 读取文件内容
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
raw_bytes = await storage_service.read_file(storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
raw_bytes = await s3_service.download_file(storage_path)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
# 解码 + 规范化
original_text = raw_bytes.decode('utf-8')
@@ -1049,7 +1193,10 @@ async def patch_file_content(
_check_policy_size_limit(policy, len(new_bytes))
# 写入文件
await storage_service.write_file(storage_path, new_bytes)
if policy.type == PolicyType.LOCAL:
await storage_service.write_file(storage_path, new_bytes)
elif policy.type == PolicyType.S3:
await s3_service.upload_file(storage_path, new_bytes)
# 更新数据库
owner_id = file_obj.owner_id

View File

@@ -8,13 +8,14 @@
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from loguru import logger as l
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from sqlmodels import (
CreateFileRequest,
Group,
Object,
ObjectCopyRequest,
ObjectDeleteRequest,
@@ -22,18 +23,27 @@ from sqlmodels import (
ObjectPropertyDetailResponse,
ObjectPropertyResponse,
ObjectRenameRequest,
ObjectSwitchPolicyRequest,
ObjectType,
PhysicalFile,
Policy,
PolicyType,
Task,
TaskProps,
TaskStatus,
TaskSummaryBase,
TaskType,
User,
)
from service.storage import (
LocalStorageService,
adjust_user_storage,
copy_object_recursive,
migrate_file_with_task,
migrate_directory_files,
)
from service.storage.object import soft_delete_objects
from sqlmodels.database_connection import DatabaseManager
from utils import http_exceptions
object_router = APIRouter(
@@ -575,3 +585,136 @@ async def router_object_property_detail(
response.checksum_md5 = obj.file_metadata.checksum_md5
return response
@object_router.patch(
    path='/{object_id}/policy',
    summary='切换对象存储策略',
)
async def router_object_switch_policy(
    session: SessionDep,
    background_tasks: BackgroundTasks,
    user: Annotated[User, Depends(auth_required)],
    object_id: UUID,
    request: ObjectSwitchPolicyRequest,
) -> TaskSummaryBase:
    """
    Switch the storage policy of an object.

    File: immediately schedules a background migration task that moves the
    file from the source policy to the destination policy.
    Directory: updates the directory's policy_id (new files use the new
    policy); when is_migrate_existing=True, additionally schedules a
    background task that migrates all existing files under it.

    Auth: JWT Bearer Token.

    Errors:
    - 404: object not found / destination policy not found
    - 403: no permission on this object / user group may not use the policy
    - 400: destination equals the current policy / operation on root directory
    """
    user_id = user.id
    # Look up the (non-deleted) object.
    obj = await Object.get(
        session,
        (Object.id == object_id) & (Object.deleted_at == None)
    )
    if not obj:
        http_exceptions.raise_not_found("对象不存在")
    if obj.owner_id != user_id:
        http_exceptions.raise_forbidden("无权操作此对象")
    if obj.is_banned:
        http_exceptions.raise_banned()
    # The root directory cannot be switched directly (operate on children instead).
    if obj.parent_id is None:
        raise HTTPException(status_code=400, detail="不能对根目录切换存储策略,请对子目录操作")
    # Validate that the destination policy exists.
    dest_policy = await Policy.get(session, Policy.id == request.policy_id)
    if not dest_policy:
        http_exceptions.raise_not_found("目标存储策略不存在")
    # Validate that the user's group is allowed to use the destination policy.
    group: Group = await user.awaitable_attrs.group
    await session.refresh(group, ['policies'])
    allowed_ids = {p.id for p in group.policies}
    if request.policy_id not in allowed_ids:
        http_exceptions.raise_forbidden("当前用户组无权使用该存储策略")
    # Switching to the same policy is a no-op and rejected.
    if obj.policy_id == request.policy_id:
        raise HTTPException(status_code=400, detail="目标策略与当前策略相同")
    # Cache needed attributes now: ORM instances expire after save/commit.
    src_policy_id = obj.policy_id
    obj_id = obj.id
    obj_is_file = obj.type == ObjectType.FILE
    dest_policy_id = request.policy_id
    dest_policy_name = dest_policy.name
    # Create the task record that tracks the migration.
    task = Task(
        type=TaskType.POLICY_MIGRATE,
        status=TaskStatus.QUEUED,
        user_id=user_id,
    )
    task = await task.save(session)
    task_id = task.id
    task_props = TaskProps(
        task_id=task_id,
        source_policy_id=src_policy_id,
        dest_policy_id=dest_policy_id,
        object_id=obj_id,
    )
    await task_props.save(session)
    if obj_is_file:
        # File: migrate in the background.
        # NOTE: the closure captures only plain IDs (not ORM objects) and
        # re-fetches rows inside its own session — request-scoped sessions
        # are closed by the time the background task runs.
        async def _run_file_migration() -> None:
            async with DatabaseManager.session() as bg_session:
                bg_obj = await Object.get(bg_session, Object.id == obj_id)
                bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
                bg_task = await Task.get(bg_session, Task.id == task_id)
                await migrate_file_with_task(bg_session, bg_obj, bg_policy, bg_task)
        background_tasks.add_task(_run_file_migration)
    else:
        # Directory: first update the directory's own policy_id.
        obj = await Object.get(session, Object.id == obj_id)
        obj.policy_id = dest_policy_id
        await obj.save(session)
        if request.is_migrate_existing:
            # Migrate all existing files in the background (same ID-capture
            # pattern as the file branch).
            async def _run_dir_migration() -> None:
                async with DatabaseManager.session() as bg_session:
                    bg_folder = await Object.get(bg_session, Object.id == obj_id)
                    bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
                    bg_task = await Task.get(bg_session, Task.id == task_id)
                    await migrate_directory_files(bg_session, bg_folder, bg_policy, bg_task)
            background_tasks.add_task(_run_dir_migration)
        else:
            # No migration of existing files requested: complete the task at once.
            task = await Task.get(session, Task.id == task_id)
            task.status = TaskStatus.COMPLETED
            task.progress = 100
            await task.save(session)
    # Re-fetch the task so the response reflects its latest state.
    task = await Task.get(session, Task.id == task_id)
    l.info(f"用户 {user_id} 请求切换对象 {obj_id} 存储策略 → {dest_policy_name}")
    return TaskSummaryBase(
        id=task.id,
        type=task.type,
        status=task.status,
        progress=task.progress,
        error=task.error,
        user_id=task.user_id,
        created_at=task.created_at,
        updated_at=task.updated_at,
    )

View File

@@ -247,11 +247,12 @@ async def router_user_register(
)
await identity.save(session)
# 8. 创建用户根目录
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
if not default_policy:
logger.error("默认存储策略不存在")
# 8. 创建用户根目录(使用用户组关联的第一个存储策略)
await session.refresh(default_group, ['policies'])
if not default_group.policies:
logger.error("默认用户组未关联任何存储策略")
http_exceptions.raise_internal_error()
default_policy = default_group.policies[0]
await sqlmodels.Object(
name="/",

View File

@@ -13,6 +13,7 @@ from sqlmodels import (
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
ChangePasswordRequest,
AuthnDetailResponse, AuthnRenameRequest,
PolicySummary,
)
from sqlmodels.color import ThemeColorsBase
from sqlmodels.user_authn import UserAuthn
@@ -31,16 +32,25 @@ user_settings_router.include_router(file_viewers_router)
@user_settings_router.get(
path='/policies',
summary='获取用户可选存储策略',
description='Get user selectable storage policies.',
)
def router_user_settings_policies() -> sqlmodels.ResponseBase:
async def router_user_settings_policies(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> list[PolicySummary]:
"""
Get user selectable storage policies.
获取当前用户所在组可选的存储策略列表
Returns:
dict: A dictionary containing available storage policies for the user.
返回用户组关联的所有存储策略的摘要信息。
"""
http_exceptions.raise_not_implemented()
group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
return [
PolicySummary(
id=p.id, name=p.name, type=p.type,
server=p.server, max_size=p.max_size, is_private=p.is_private,
)
for p in group.policies
]
@user_settings_router.get(