feat: add S3 storage support, policy migration, and quota enforcement
Some checks failed
Test / test (push) Failing after 2m21s

- Add S3StorageService with AWS Signature V4 signing (URI-encoded for non-ASCII keys)
- Add PATCH /object/{id}/policy endpoint for switching storage policies with background migration
- Implement cross-storage file migration service (local <-> S3)
- Replace deprecated StorageType enum with PolicyType (local/s3)
- Implement GET /user/settings/policies endpoint (was 501 stub)
- Add storage quota pre-allocation on upload session creation to prevent quota bypass by concurrent uploads
- Fix BigInteger for max_storage and user.storage to support >2GB values
- Add policy permission validation on upload and directory creation
- Use group's first policy as default on registration instead of hardcoded name
- Define TaskType.POLICY_MIGRATE and extend TaskProps with migration fields

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-23 13:38:20 +08:00
parent 7200df6d87
commit 3639a31163
19 changed files with 1728 additions and 124 deletions

View File

@@ -8,6 +8,7 @@ from routers import router
from routers.dav import dav_app from routers.dav import dav_app
from routers.dav.provider import EventLoopRef from routers.dav.provider import EventLoopRef
from service.redis import RedisManager from service.redis import RedisManager
from service.storage import S3StorageService
from sqlmodels.database_connection import DatabaseManager from sqlmodels.database_connection import DatabaseManager
from sqlmodels.migration import migration from sqlmodels.migration import migration
from utils import JWT from utils import JWT
@@ -50,8 +51,10 @@ lifespan.add_startup(_init_db)
lifespan.add_startup(migration) lifespan.add_startup(migration)
lifespan.add_startup(JWT.load_secret_key) lifespan.add_startup(JWT.load_secret_key)
lifespan.add_startup(RedisManager.connect) lifespan.add_startup(RedisManager.connect)
lifespan.add_startup(S3StorageService.initialize_session)
# 添加关闭项 # 添加关闭项
lifespan.add_shutdown(S3StorageService.close_session)
lifespan.add_shutdown(DatabaseManager.close) lifespan.add_shutdown(DatabaseManager.close)
lifespan.add_shutdown(RedisManager.disconnect) lifespan.add_shutdown(RedisManager.disconnect)

View File

@@ -8,11 +8,11 @@ from sqlmodel import Field
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep from middleware.dependencies import SessionDep, TableViewRequestDep
from sqlmodels import ( from sqlmodels import (
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase, Policy, PolicyCreateRequest, PolicyOptions, PolicyType, PolicySummary,
ListResponse, Object, PolicyUpdateRequest, ResponseBase, ListResponse, Object,
) )
from sqlmodel_ext import SQLModelBase from sqlmodel_ext import SQLModelBase
from service.storage import DirectoryCreationError, LocalStorageService from service.storage import DirectoryCreationError, LocalStorageService, S3StorageService
admin_policy_router = APIRouter( admin_policy_router = APIRouter(
prefix='/policy', prefix='/policy',
@@ -67,6 +67,12 @@ class PolicyDetailResponse(SQLModelBase):
base_url: str | None base_url: str | None
"""基础URL""" """基础URL"""
access_key: str | None
"""Access Key"""
secret_key: str | None
"""Secret Key"""
max_size: int max_size: int
"""最大文件尺寸""" """最大文件尺寸"""
@@ -107,9 +113,45 @@ class PolicyTestSlaveRequest(SQLModelBase):
secret: str secret: str
"""从机通信密钥""" """从机通信密钥"""
class PolicyCreateRequest(PolicyBase): class PolicyTestS3Request(SQLModelBase):
"""创建存储策略请求 DTO继承 PolicyBase 中的所有字段""" """测试 S3 连接请求 DTO"""
pass
server: str = Field(max_length=255)
"""S3 端点地址"""
bucket_name: str = Field(max_length=255)
"""存储桶名称"""
access_key: str
"""Access Key"""
secret_key: str
"""Secret Key"""
s3_region: str = Field(default='us-east-1', max_length=64)
"""S3 区域"""
s3_path_style: bool = False
"""是否使用路径风格"""
class PolicyTestS3Response(SQLModelBase):
"""S3 连接测试响应"""
is_connected: bool
"""连接是否成功"""
message: str
"""测试结果消息"""
# ==================== Options 字段集合(用于分离 Policy 与 Options 字段) ====================
_OPTIONS_FIELDS: set[str] = {
'token', 'file_type', 'mimetype', 'od_redirect',
'chunk_size', 's3_path_style', 's3_region',
}
@admin_policy_router.get( @admin_policy_router.get(
path='/list', path='/list',
@@ -277,7 +319,20 @@ async def router_policy_add_policy(
raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}") raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}")
# 保存到数据库 # 保存到数据库
await policy.save(session) policy = await policy.save(session)
# 创建策略选项
options = PolicyOptions(
policy_id=policy.id,
token=request.token,
file_type=request.file_type,
mimetype=request.mimetype,
od_redirect=request.od_redirect,
chunk_size=request.chunk_size,
s3_path_style=request.s3_path_style,
s3_region=request.s3_region,
)
await options.save(session)
@admin_policy_router.post( @admin_policy_router.post(
path='/cors', path='/cors',
@@ -371,6 +426,8 @@ async def router_policy_get_policy(
bucket_name=policy.bucket_name, bucket_name=policy.bucket_name,
is_private=policy.is_private, is_private=policy.is_private,
base_url=policy.base_url, base_url=policy.base_url,
access_key=policy.access_key,
secret_key=policy.secret_key,
max_size=policy.max_size, max_size=policy.max_size,
auto_rename=policy.auto_rename, auto_rename=policy.auto_rename,
dir_name_rule=policy.dir_name_rule, dir_name_rule=policy.dir_name_rule,
@@ -417,4 +474,108 @@ async def router_policy_delete_policy(
policy_name = policy.name policy_name = policy.name
await Policy.delete(session, policy) await Policy.delete(session, policy)
l.info(f"管理员删除了存储策略: {policy_name}") l.info(f"管理员删除了存储策略: {policy_name}")
@admin_policy_router.patch(
path='/{policy_id}',
summary='更新存储策略',
description='更新存储策略配置。策略类型创建后不可更改。',
dependencies=[Depends(admin_required)],
status_code=204,
)
async def router_policy_update_policy(
session: SessionDep,
policy_id: UUID,
request: PolicyUpdateRequest,
) -> None:
"""
Update a storage policy.

Behaviour:
- Applies a partial update: only fields the client explicitly set are
  written (``exclude_unset=True`` below).
- Fields are split between the ``Policy`` row and its related
  ``PolicyOptions`` row; the options row is created on demand.
- The policy type is documented as immutable after creation; this handler
  does not re-check that here — presumably the request DTO carries no
  ``type`` field (TODO confirm against PolicyUpdateRequest).

Auth:
- Administrator only (enforced by the route-level dependency).

:param session: database session
:param policy_id: UUID of the storage policy to update
:param request: partial-update request DTO
"""
# Eager-load the related options row so it can be updated in place.
policy = await Policy.get(session, Policy.id == policy_id, load=Policy.options)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
# Enforce name uniqueness only when the name is actually being changed.
if request.name and request.name != policy.name:
existing = await Policy.get(session, Policy.name == request.name)
if existing:
raise HTTPException(status_code=409, detail="策略名称已存在")
# Split incoming fields into Policy columns vs PolicyOptions columns;
# _OPTIONS_FIELDS is the module-level set of option field names.
all_data = request.model_dump(exclude_unset=True)
policy_data = {k: v for k, v in all_data.items() if k not in _OPTIONS_FIELDS}
options_data = {k: v for k, v in all_data.items() if k in _OPTIONS_FIELDS}
# Apply base Policy fields.
if policy_data:
for key, value in policy_data.items():
setattr(policy, key, value)
policy = await policy.save(session)
# Apply option fields: update the existing PolicyOptions row, or create
# one if this policy has none yet.
if options_data:
if policy.options:
for key, value in options_data.items():
setattr(policy.options, key, value)
await policy.options.save(session)
else:
options = PolicyOptions(policy_id=policy.id, **options_data)
await options.save(session)
l.info(f"管理员更新了存储策略: {policy_id}")
@admin_policy_router.post(
path='/test/s3',
summary='测试 S3 连接',
description='测试 S3 存储端点的连通性和凭据有效性。',
dependencies=[Depends(admin_required)],
)
async def router_policy_test_s3(
request: PolicyTestS3Request,
) -> PolicyTestS3Response:
"""
Test connectivity and credential validity of an S3 endpoint.

Sends a HEAD request for a sentinel object key (``__connection_test__``)
via ``file_exists`` to exercise signing, credentials, and network
reachability. NOTE(review): this is an object-level HEAD, not HEAD
Bucket; whether a healthy bucket with no such key reports success
depends on ``file_exists`` treating 404 as "not found" rather than
raising — confirm against S3StorageService.

:param request: test request carrying endpoint, bucket, and credentials
:return: test result (success flag plus human-readable message)
"""
# Local import to avoid widening the module-level import surface.
from service.storage import S3APIError
# Build a throwaway, never-persisted Policy just to construct the service.
temp_policy = Policy(
name="__test__",
type=PolicyType.S3,
server=request.server,
bucket_name=request.bucket_name,
access_key=request.access_key,
secret_key=request.secret_key,
)
s3_service = S3StorageService(
temp_policy,
region=request.s3_region,
is_path_style=request.s3_path_style,
)
try:
# HEAD the sentinel key; any signed round-trip that completes counts
# as a successful connection.
await s3_service.file_exists("__connection_test__")
return PolicyTestS3Response(is_connected=True, message="连接成功")
except S3APIError as e:
# The endpoint answered but rejected the request (bad credentials,
# missing bucket, etc.).
return PolicyTestS3Response(is_connected=False, message=f"S3 API 错误: {e}")
except Exception as e:
# Network-level failure (DNS, TLS, timeout, ...).
return PolicyTestS3Response(is_connected=False, message=f"连接失败: {e}")

View File

@@ -57,7 +57,7 @@ async def _get_directory_response(
policy_response = PolicyResponse( policy_response = PolicyResponse(
id=policy.id, id=policy.id,
name=policy.name, name=policy.name,
type=policy.type.value, type=policy.type,
max_size=policy.max_size, max_size=policy.max_size,
) )
@@ -189,6 +189,14 @@ async def router_directory_create(
raise HTTPException(status_code=409, detail="同名文件或目录已存在") raise HTTPException(status_code=409, detail="同名文件或目录已存在")
policy_id = request.policy_id if request.policy_id else parent.policy_id policy_id = request.policy_id if request.policy_id else parent.policy_id
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
if request.policy_id:
group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
if request.policy_id not in {p.id for p in group.policies}:
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
parent_id = parent.id # 在 save 前保存 parent_id = parent.id # 在 save 前保存
new_folder = Object( new_folder = Object(

View File

@@ -13,9 +13,11 @@ from datetime import datetime, timedelta
from typing import Annotated from typing import Annotated
from uuid import UUID from uuid import UUID
import orjson
import whatthepatch import whatthepatch
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, RedirectResponse from fastapi.responses import FileResponse, RedirectResponse
from starlette.responses import Response
from loguru import logger as l from loguru import logger as l
from sqlmodel_ext import SQLModelBase from sqlmodel_ext import SQLModelBase
from whatthepatch.exceptions import HunkApplyException from whatthepatch.exceptions import HunkApplyException
@@ -44,7 +46,9 @@ from sqlmodels import (
User, User,
WopiSessionResponse, WopiSessionResponse,
) )
from service.storage import LocalStorageService, adjust_user_storage import orjson
from service.storage import LocalStorageService, S3StorageService, adjust_user_storage
from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL
from utils.JWT.wopi_token import create_wopi_token from utils.JWT.wopi_token import create_wopi_token
from utils import http_exceptions from utils import http_exceptions
@@ -184,6 +188,13 @@ async def create_upload_session(
if not policy: if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在") raise HTTPException(status_code=404, detail="存储策略不存在")
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
if request.policy_id:
group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
if request.policy_id not in {p.id for p in group.policies}:
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
# 验证文件大小限制 # 验证文件大小限制
_check_policy_size_limit(policy, request.file_size) _check_policy_size_limit(policy, request.file_size)
@@ -210,6 +221,7 @@ async def create_upload_session(
# 生成存储路径 # 生成存储路径
storage_path: str | None = None storage_path: str | None = None
s3_upload_id: str | None = None
if policy.type == PolicyType.LOCAL: if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy) storage_service = LocalStorageService(policy)
dir_path, storage_name, full_path = await storage_service.generate_file_path( dir_path, storage_name, full_path = await storage_service.generate_file_path(
@@ -217,8 +229,25 @@ async def create_upload_session(
original_filename=request.file_name, original_filename=request.file_name,
) )
storage_path = full_path storage_path = full_path
else: elif policy.type == PolicyType.S3:
raise HTTPException(status_code=501, detail="S3 存储暂未实现") s3_service = S3StorageService(
policy,
region=options.s3_region if options else 'us-east-1',
is_path_style=options.s3_path_style if options else False,
)
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
user_id=user.id,
original_filename=request.file_name,
)
# 多分片时创建 multipart upload
if total_chunks > 1:
s3_upload_id = await s3_service.create_multipart_upload(
storage_path, content_type='application/octet-stream',
)
# 预扣存储空间(与创建会话在同一事务中提交,防止并发绕过配额)
if request.file_size > 0:
await adjust_user_storage(session, user.id, request.file_size, commit=False)
# 创建上传会话 # 创建上传会话
upload_session = UploadSession( upload_session = UploadSession(
@@ -227,6 +256,7 @@ async def create_upload_session(
chunk_size=chunk_size, chunk_size=chunk_size,
total_chunks=total_chunks, total_chunks=total_chunks,
storage_path=storage_path, storage_path=storage_path,
s3_upload_id=s3_upload_id,
expires_at=datetime.now() + timedelta(hours=24), expires_at=datetime.now() + timedelta(hours=24),
owner_id=user.id, owner_id=user.id,
parent_id=request.parent_id, parent_id=request.parent_id,
@@ -302,8 +332,38 @@ async def upload_chunk(
content, content,
offset, offset,
) )
else: elif policy.type == PolicyType.S3:
raise HTTPException(status_code=501, detail="S3 存储暂未实现") if not upload_session.storage_path:
raise HTTPException(status_code=500, detail="存储路径丢失")
s3_service = await S3StorageService.from_policy(policy)
if upload_session.total_chunks == 1:
# 单分片:直接 PUT 上传
await s3_service.upload_file(upload_session.storage_path, content)
else:
# 多分片UploadPart
if not upload_session.s3_upload_id:
raise HTTPException(status_code=500, detail="S3 分片上传 ID 丢失")
etag = await s3_service.upload_part(
upload_session.storage_path,
upload_session.s3_upload_id,
chunk_index + 1, # S3 part number 从 1 开始
content,
)
# 追加 ETag 到 s3_part_etags
etags: list[list[int | str]] = orjson.loads(upload_session.s3_part_etags or '[]')
etags.append([chunk_index + 1, etag])
upload_session.s3_part_etags = orjson.dumps(etags).decode()
# 在 savecommit前缓存后续需要的属性commit 后 ORM 对象会过期)
policy_type = policy.type
s3_upload_id = upload_session.s3_upload_id
s3_part_etags = upload_session.s3_part_etags
s3_service_for_complete: S3StorageService | None = None
if policy_type == PolicyType.S3:
s3_service_for_complete = await S3StorageService.from_policy(policy)
# 更新会话进度 # 更新会话进度
upload_session.uploaded_chunks += 1 upload_session.uploaded_chunks += 1
@@ -319,12 +379,26 @@ async def upload_chunk(
if is_complete: if is_complete:
# 保存 upload_session 属性commit 后会过期) # 保存 upload_session 属性commit 后会过期)
file_name = upload_session.file_name file_name = upload_session.file_name
file_size = upload_session.file_size
uploaded_size = upload_session.uploaded_size uploaded_size = upload_session.uploaded_size
storage_path = upload_session.storage_path storage_path = upload_session.storage_path
upload_session_id = upload_session.id upload_session_id = upload_session.id
parent_id = upload_session.parent_id parent_id = upload_session.parent_id
policy_id = upload_session.policy_id policy_id = upload_session.policy_id
# S3 多分片上传完成:合并分片
if (
policy_type == PolicyType.S3
and s3_upload_id
and s3_part_etags
and s3_service_for_complete
):
parts_data: list[list[int | str]] = orjson.loads(s3_part_etags)
parts = [(int(pn), str(et)) for pn, et in parts_data]
await s3_service_for_complete.complete_multipart_upload(
storage_path, s3_upload_id, parts,
)
# 创建 PhysicalFile 记录 # 创建 PhysicalFile 记录
physical_file = PhysicalFile( physical_file = PhysicalFile(
storage_path=storage_path, storage_path=storage_path,
@@ -355,9 +429,10 @@ async def upload_chunk(
commit=False commit=False
) )
# 更新用户存储配额 # 调整存储配额差值(创建会话时已预扣 file_size这里只补差
if uploaded_size > 0: size_diff = uploaded_size - file_size
await adjust_user_storage(session, user_id, uploaded_size, commit=False) if size_diff != 0:
await adjust_user_storage(session, user_id, size_diff, commit=False)
# 统一提交所有更改 # 统一提交所有更改
await session.commit() await session.commit()
@@ -390,9 +465,25 @@ async def delete_upload_session(
# 删除临时文件 # 删除临时文件
policy = await Policy.get(session, Policy.id == upload_session.policy_id) policy = await Policy.get(session, Policy.id == upload_session.policy_id)
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path: if policy and upload_session.storage_path:
storage_service = LocalStorageService(policy) if policy.type == PolicyType.LOCAL:
await storage_service.delete_file(upload_session.storage_path) storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
# 如果有分片上传,先取消
if upload_session.s3_upload_id:
await s3_service.abort_multipart_upload(
upload_session.storage_path, upload_session.s3_upload_id,
)
else:
# 单分片上传已完成的话,删除已上传的文件
if upload_session.uploaded_chunks > 0:
await s3_service.delete_file(upload_session.storage_path)
# 释放预扣的存储空间
if upload_session.file_size > 0:
await adjust_user_storage(session, user.id, -upload_session.file_size)
# 删除会话记录 # 删除会话记录
await UploadSession.delete(session, upload_session) await UploadSession.delete(session, upload_session)
@@ -422,9 +513,22 @@ async def clear_upload_sessions(
for upload_session in sessions: for upload_session in sessions:
# 删除临时文件 # 删除临时文件
policy = await Policy.get(session, Policy.id == upload_session.policy_id) policy = await Policy.get(session, Policy.id == upload_session.policy_id)
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path: if policy and upload_session.storage_path:
storage_service = LocalStorageService(policy) if policy.type == PolicyType.LOCAL:
await storage_service.delete_file(upload_session.storage_path) storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
if upload_session.s3_upload_id:
await s3_service.abort_multipart_upload(
upload_session.storage_path, upload_session.s3_upload_id,
)
elif upload_session.uploaded_chunks > 0:
await s3_service.delete_file(upload_session.storage_path)
# 释放预扣的存储空间
if upload_session.file_size > 0:
await adjust_user_storage(session, user.id, -upload_session.file_size)
await UploadSession.delete(session, upload_session) await UploadSession.delete(session, upload_session)
deleted_count += 1 deleted_count += 1
@@ -486,11 +590,12 @@ async def create_download_token_endpoint(
path='/{token}', path='/{token}',
summary='下载文件', summary='下载文件',
description='使用下载令牌下载文件,令牌在有效期内可重复使用。', description='使用下载令牌下载文件,令牌在有效期内可重复使用。',
response_model=None,
) )
async def download_file( async def download_file(
session: SessionDep, session: SessionDep,
token: str, token: str,
) -> FileResponse: ) -> Response:
""" """
下载文件端点 下载文件端点
@@ -540,8 +645,15 @@ async def download_file(
filename=file_obj.name, filename=file_obj.name,
media_type="application/octet-stream", media_type="application/octet-stream",
) )
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
# 302 重定向到预签名 URL
presigned_url = s3_service.generate_presigned_url(
storage_path, method='GET', expires_in=3600, filename=file_obj.name,
)
return RedirectResponse(url=presigned_url, status_code=302)
else: else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现") raise HTTPException(status_code=500, detail="不支持的存储类型")
# ==================== 包含子路由 ==================== # ==================== 包含子路由 ====================
@@ -613,8 +725,13 @@ async def create_empty_file(
) )
await storage_service.create_empty_file(full_path) await storage_service.create_empty_file(full_path)
storage_path = full_path storage_path = full_path
else: elif policy.type == PolicyType.S3:
raise HTTPException(status_code=501, detail="S3 存储暂未实现") s3_service = await S3StorageService.from_policy(policy)
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
user_id=user_id,
original_filename=request.name,
)
await s3_service.upload_file(storage_path, b'')
# 创建 PhysicalFile 记录 # 创建 PhysicalFile 记录
physical_file = PhysicalFile( physical_file = PhysicalFile(
@@ -798,12 +915,13 @@ async def _validate_source_link(
path='/get/{file_id}/{name}', path='/get/{file_id}/{name}',
summary='文件外链(直接输出文件数据)', summary='文件外链(直接输出文件数据)',
description='通过外链直接获取文件内容,公开访问无需认证。', description='通过外链直接获取文件内容,公开访问无需认证。',
response_model=None,
) )
async def file_get( async def file_get(
session: SessionDep, session: SessionDep,
file_id: UUID, file_id: UUID,
name: str, name: str,
) -> FileResponse: ) -> Response:
""" """
文件外链端点(直接输出) 文件外链端点(直接输出)
@@ -815,13 +933,6 @@ async def file_get(
""" """
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id) file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
if policy.type != PolicyType.LOCAL:
http_exceptions.raise_not_implemented("S3 存储暂未实现")
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(physical_file.storage_path):
http_exceptions.raise_not_found("物理文件不存在")
# 缓存物理路径save 后对象属性会过期) # 缓存物理路径save 后对象属性会过期)
file_path = physical_file.storage_path file_path = physical_file.storage_path
@@ -829,11 +940,25 @@ async def file_get(
link.downloads += 1 link.downloads += 1
await link.save(session) await link.save(session)
return FileResponse( if policy.type == PolicyType.LOCAL:
path=file_path, storage_service = LocalStorageService(policy)
filename=name, if not await storage_service.file_exists(file_path):
media_type="application/octet-stream", http_exceptions.raise_not_found("物理文件不存在")
)
return FileResponse(
path=file_path,
filename=name,
media_type="application/octet-stream",
)
elif policy.type == PolicyType.S3:
# S3 外链直接输出302 重定向到预签名 URL
s3_service = await S3StorageService.from_policy(policy)
presigned_url = s3_service.generate_presigned_url(
file_path, method='GET', expires_in=3600, filename=name,
)
return RedirectResponse(url=presigned_url, status_code=302)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
@router.get( @router.get(
@@ -846,7 +971,7 @@ async def file_source_redirect(
session: SessionDep, session: SessionDep,
file_id: UUID, file_id: UUID,
name: str, name: str,
) -> FileResponse | RedirectResponse: ) -> Response:
""" """
文件外链端点(重定向/直接输出) 文件外链端点(重定向/直接输出)
@@ -860,13 +985,6 @@ async def file_source_redirect(
""" """
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id) file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
if policy.type != PolicyType.LOCAL:
http_exceptions.raise_not_implemented("S3 存储暂未实现")
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(physical_file.storage_path):
http_exceptions.raise_not_found("物理文件不存在")
# 缓存所有需要的值save 后对象属性会过期) # 缓存所有需要的值save 后对象属性会过期)
file_path = physical_file.storage_path file_path = physical_file.storage_path
is_private = policy.is_private is_private = policy.is_private
@@ -876,18 +994,36 @@ async def file_source_redirect(
link.downloads += 1 link.downloads += 1
await link.save(session) await link.save(session)
# 公有存储302 重定向到 base_url if policy.type == PolicyType.LOCAL:
if not is_private and base_url: storage_service = LocalStorageService(policy)
relative_path = storage_service.get_relative_path(file_path) if not await storage_service.file_exists(file_path):
redirect_url = f"{base_url}/{relative_path}" http_exceptions.raise_not_found("物理文件不存在")
return RedirectResponse(url=redirect_url, status_code=302)
# 有存储或 base_url 为空:通过应用代理文件 # 有存储302 重定向到 base_url
return FileResponse( if not is_private and base_url:
path=file_path, relative_path = storage_service.get_relative_path(file_path)
filename=name, redirect_url = f"{base_url}/{relative_path}"
media_type="application/octet-stream", return RedirectResponse(url=redirect_url, status_code=302)
)
# 私有存储或 base_url 为空:通过应用代理文件
return FileResponse(
path=file_path,
filename=name,
media_type="application/octet-stream",
)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
# 公有存储且有 base_url直接重定向到公开 URL
if not is_private and base_url:
redirect_url = f"{base_url.rstrip('/')}/{file_path}"
return RedirectResponse(url=redirect_url, status_code=302)
# 私有存储:生成预签名 URL 重定向
presigned_url = s3_service.generate_presigned_url(
file_path, method='GET', expires_in=3600, filename=name,
)
return RedirectResponse(url=presigned_url, status_code=302)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
@router.put( @router.put(
@@ -941,11 +1077,15 @@ async def file_content(
if not policy: if not policy:
http_exceptions.raise_internal_error("存储策略不存在") http_exceptions.raise_internal_error("存储策略不存在")
if policy.type != PolicyType.LOCAL: # 读取文件内容
http_exceptions.raise_not_implemented("S3 存储暂未实现") if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
storage_service = LocalStorageService(policy) raw_bytes = await storage_service.read_file(physical_file.storage_path)
raw_bytes = await storage_service.read_file(physical_file.storage_path) elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
raw_bytes = await s3_service.download_file(physical_file.storage_path)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
try: try:
content = raw_bytes.decode('utf-8') content = raw_bytes.decode('utf-8')
@@ -1011,11 +1151,15 @@ async def patch_file_content(
if not policy: if not policy:
http_exceptions.raise_internal_error("存储策略不存在") http_exceptions.raise_internal_error("存储策略不存在")
if policy.type != PolicyType.LOCAL: # 读取文件内容
http_exceptions.raise_not_implemented("S3 存储暂未实现") if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
storage_service = LocalStorageService(policy) raw_bytes = await storage_service.read_file(storage_path)
raw_bytes = await storage_service.read_file(storage_path) elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
raw_bytes = await s3_service.download_file(storage_path)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
# 解码 + 规范化 # 解码 + 规范化
original_text = raw_bytes.decode('utf-8') original_text = raw_bytes.decode('utf-8')
@@ -1049,7 +1193,10 @@ async def patch_file_content(
_check_policy_size_limit(policy, len(new_bytes)) _check_policy_size_limit(policy, len(new_bytes))
# 写入文件 # 写入文件
await storage_service.write_file(storage_path, new_bytes) if policy.type == PolicyType.LOCAL:
await storage_service.write_file(storage_path, new_bytes)
elif policy.type == PolicyType.S3:
await s3_service.upload_file(storage_path, new_bytes)
# 更新数据库 # 更新数据库
owner_id = file_obj.owner_id owner_id = file_obj.owner_id

View File

@@ -8,13 +8,14 @@
from typing import Annotated from typing import Annotated
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from loguru import logger as l from loguru import logger as l
from middleware.auth import auth_required from middleware.auth import auth_required
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from sqlmodels import ( from sqlmodels import (
CreateFileRequest, CreateFileRequest,
Group,
Object, Object,
ObjectCopyRequest, ObjectCopyRequest,
ObjectDeleteRequest, ObjectDeleteRequest,
@@ -22,18 +23,27 @@ from sqlmodels import (
ObjectPropertyDetailResponse, ObjectPropertyDetailResponse,
ObjectPropertyResponse, ObjectPropertyResponse,
ObjectRenameRequest, ObjectRenameRequest,
ObjectSwitchPolicyRequest,
ObjectType, ObjectType,
PhysicalFile, PhysicalFile,
Policy, Policy,
PolicyType, PolicyType,
Task,
TaskProps,
TaskStatus,
TaskSummaryBase,
TaskType,
User, User,
) )
from service.storage import ( from service.storage import (
LocalStorageService, LocalStorageService,
adjust_user_storage, adjust_user_storage,
copy_object_recursive, copy_object_recursive,
migrate_file_with_task,
migrate_directory_files,
) )
from service.storage.object import soft_delete_objects from service.storage.object import soft_delete_objects
from sqlmodels.database_connection import DatabaseManager
from utils import http_exceptions from utils import http_exceptions
object_router = APIRouter( object_router = APIRouter(
@@ -575,3 +585,136 @@ async def router_object_property_detail(
response.checksum_md5 = obj.file_metadata.checksum_md5 response.checksum_md5 = obj.file_metadata.checksum_md5
return response return response
@object_router.patch(
path='/{object_id}/policy',
summary='切换对象存储策略',
)
async def router_object_switch_policy(
session: SessionDep,
background_tasks: BackgroundTasks,
user: Annotated[User, Depends(auth_required)],
object_id: UUID,
request: ObjectSwitchPolicyRequest,
) -> TaskSummaryBase:
"""
Switch the storage policy of an object.

File: immediately creates a background migration task that moves the
file from the source policy to the target policy.
Directory: updates the directory's ``policy_id`` (new files use the new
policy); if ``is_migrate_existing`` is true, additionally schedules a
background task migrating all existing files under it.

Auth: JWT Bearer token.

Errors:
- 404: object not found / target policy not found
- 403: not the owner / user's group may not use the target policy
- 400: target policy equals current policy / operation on root directory

:return: a summary of the migration task that was created
"""
user_id = user.id
# Look up the (non-deleted) object. ``== None`` is the SQLAlchemy
# IS NULL idiom, intentional here.
obj = await Object.get(
session,
(Object.id == object_id) & (Object.deleted_at == None)
)
if not obj:
http_exceptions.raise_not_found("对象不存在")
if obj.owner_id != user_id:
http_exceptions.raise_forbidden("无权操作此对象")
if obj.is_banned:
http_exceptions.raise_banned()
# The root directory cannot be switched directly; operate on children.
if obj.parent_id is None:
raise HTTPException(status_code=400, detail="不能对根目录切换存储策略,请对子目录操作")
# The target policy must exist.
dest_policy = await Policy.get(session, Policy.id == request.policy_id)
if not dest_policy:
http_exceptions.raise_not_found("目标存储策略不存在")
# The user's group must be allowed to use the target policy.
group: Group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
allowed_ids = {p.id for p in group.policies}
if request.policy_id not in allowed_ids:
http_exceptions.raise_forbidden("当前用户组无权使用该存储策略")
# Switching to the same policy is a no-op and rejected.
if obj.policy_id == request.policy_id:
raise HTTPException(status_code=400, detail="目标策略与当前策略相同")
# Cache attributes now: ORM objects expire after save/commit below.
src_policy_id = obj.policy_id
obj_id = obj.id
obj_is_file = obj.type == ObjectType.FILE
dest_policy_id = request.policy_id
dest_policy_name = dest_policy.name
# Create the task record tracking this migration.
task = Task(
type=TaskType.POLICY_MIGRATE,
status=TaskStatus.QUEUED,
user_id=user_id,
)
task = await task.save(session)
task_id = task.id
task_props = TaskProps(
task_id=task_id,
source_policy_id=src_policy_id,
dest_policy_id=dest_policy_id,
object_id=obj_id,
)
await task_props.save(session)
if obj_is_file:
# File: migrate in the background with a fresh session, re-fetching
# rows by id so nothing from the request session leaks across.
# NOTE(review): bg_obj/bg_policy/bg_task are passed on without None
# checks — if the object or policy is deleted before the background
# task runs, migrate_file_with_task receives None; confirm it guards.
async def _run_file_migration() -> None:
async with DatabaseManager.session() as bg_session:
bg_obj = await Object.get(bg_session, Object.id == obj_id)
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
bg_task = await Task.get(bg_session, Task.id == task_id)
await migrate_file_with_task(bg_session, bg_obj, bg_policy, bg_task)
background_tasks.add_task(_run_file_migration)
else:
# Directory: first update the directory's own policy_id so newly
# uploaded files use the new policy (re-fetch: prior saves expired it).
obj = await Object.get(session, Object.id == obj_id)
obj.policy_id = dest_policy_id
await obj.save(session)
if request.is_migrate_existing:
# Also migrate all existing files under the directory in the
# background (same fresh-session pattern as the file branch).
async def _run_dir_migration() -> None:
async with DatabaseManager.session() as bg_session:
bg_folder = await Object.get(bg_session, Object.id == obj_id)
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
bg_task = await Task.get(bg_session, Task.id == task_id)
await migrate_directory_files(bg_session, bg_folder, bg_policy, bg_task)
background_tasks.add_task(_run_dir_migration)
else:
# No existing files to migrate: mark the task complete right away.
task = await Task.get(session, Task.id == task_id)
task.status = TaskStatus.COMPLETED
task.progress = 100
await task.save(session)
# Re-fetch the task so the response reflects its latest state.
task = await Task.get(session, Task.id == task_id)
l.info(f"用户 {user_id} 请求切换对象 {obj_id} 存储策略 → {dest_policy_name}")
return TaskSummaryBase(
id=task.id,
type=task.type,
status=task.status,
progress=task.progress,
error=task.error,
user_id=task.user_id,
created_at=task.created_at,
updated_at=task.updated_at,
)

View File

@@ -247,11 +247,12 @@ async def router_user_register(
) )
await identity.save(session) await identity.save(session)
# 8. 创建用户根目录 # 8. 创建用户根目录(使用用户组关联的第一个存储策略)
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储") await session.refresh(default_group, ['policies'])
if not default_policy: if not default_group.policies:
logger.error("默认存储策略不存在") logger.error("默认用户组未关联任何存储策略")
http_exceptions.raise_internal_error() http_exceptions.raise_internal_error()
default_policy = default_group.policies[0]
await sqlmodels.Object( await sqlmodels.Object(
name="/", name="/",

View File

@@ -13,6 +13,7 @@ from sqlmodels import (
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest, AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
ChangePasswordRequest, ChangePasswordRequest,
AuthnDetailResponse, AuthnRenameRequest, AuthnDetailResponse, AuthnRenameRequest,
PolicySummary,
) )
from sqlmodels.color import ThemeColorsBase from sqlmodels.color import ThemeColorsBase
from sqlmodels.user_authn import UserAuthn from sqlmodels.user_authn import UserAuthn
@@ -31,16 +32,25 @@ user_settings_router.include_router(file_viewers_router)
@user_settings_router.get( @user_settings_router.get(
path='/policies', path='/policies',
summary='获取用户可选存储策略', summary='获取用户可选存储策略',
description='Get user selectable storage policies.',
) )
def router_user_settings_policies() -> sqlmodels.ResponseBase: async def router_user_settings_policies(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> list[PolicySummary]:
""" """
Get user selectable storage policies. 获取当前用户所在组可选的存储策略列表
Returns: 返回用户组关联的所有存储策略的摘要信息。
dict: A dictionary containing available storage policies for the user.
""" """
http_exceptions.raise_not_implemented() group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
return [
PolicySummary(
id=p.id, name=p.name, type=p.type,
server=p.server, max_size=p.max_size, is_private=p.is_private,
)
for p in group.policies
]
@user_settings_router.get( @user_settings_router.get(

View File

@@ -3,6 +3,7 @@
提供文件存储相关的服务,包括: 提供文件存储相关的服务,包括:
- 本地存储服务 - 本地存储服务
- S3 存储服务
- 命名规则解析器 - 命名规则解析器
- 存储异常定义 - 存储异常定义
""" """
@@ -11,6 +12,8 @@ from .exceptions import (
FileReadError, FileReadError,
FileWriteError, FileWriteError,
InvalidPathError, InvalidPathError,
S3APIError,
S3MultipartUploadError,
StorageException, StorageException,
StorageFileNotFoundError, StorageFileNotFoundError,
UploadSessionExpiredError, UploadSessionExpiredError,
@@ -25,4 +28,6 @@ from .object import (
permanently_delete_objects, permanently_delete_objects,
restore_objects, restore_objects,
soft_delete_objects, soft_delete_objects,
) )
from .migrate import migrate_file_with_task, migrate_directory_files
from .s3_storage import S3StorageService

View File

@@ -43,3 +43,13 @@ class UploadSessionExpiredError(StorageException):
class InvalidPathError(StorageException): class InvalidPathError(StorageException):
"""无效的路径""" """无效的路径"""
pass pass
class S3APIError(StorageException):
    """Raised when an S3 API request fails (transport error or non-success status)."""
    pass
class S3MultipartUploadError(S3APIError):
    """Raised when an S3 multipart-upload step (create/part/complete/abort) fails."""
    pass

291
service/storage/migrate.py Normal file
View File

@@ -0,0 +1,291 @@
"""
存储策略迁移服务
提供跨存储策略的文件迁移功能:
- 单文件迁移:从源策略下载 → 上传到目标策略 → 更新数据库记录
- 目录批量迁移:递归遍历目录下所有文件逐个迁移,同时更新子目录的 policy_id
"""
from uuid import UUID
from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.object import Object, ObjectType
from sqlmodels.physical_file import PhysicalFile
from sqlmodels.policy import Policy, PolicyType
from sqlmodels.task import Task, TaskStatus
from .local_storage import LocalStorageService
from .s3_storage import S3StorageService
async def _get_storage_service(
policy: Policy,
) -> LocalStorageService | S3StorageService:
"""
根据策略类型创建对应的存储服务实例
:param policy: 存储策略
:return: 存储服务实例
"""
if policy.type == PolicyType.LOCAL:
return LocalStorageService(policy)
elif policy.type == PolicyType.S3:
return await S3StorageService.from_policy(policy)
else:
raise ValueError(f"不支持的存储策略类型: {policy.type}")
async def _read_file_from_storage(
    service: LocalStorageService | S3StorageService,
    storage_path: str,
) -> bytes:
    """
    Fetch a file's raw bytes from the given storage backend.

    Local storage exposes ``read_file``; the S3 backend exposes
    ``download_file`` — both return the full content as bytes.

    :param service: storage service instance
    :param storage_path: path/key of the file inside the backend
    :return: file content
    """
    if not isinstance(service, LocalStorageService):
        return await service.download_file(storage_path)
    return await service.read_file(storage_path)
async def _write_file_to_storage(
    service: LocalStorageService | S3StorageService,
    storage_path: str,
    data: bytes,
) -> None:
    """
    Persist *data* at *storage_path* on the given storage backend.

    Local storage exposes ``write_file``; the S3 backend uses ``upload_file``.

    :param service: storage service instance
    :param storage_path: destination path/key
    :param data: file content
    """
    if not isinstance(service, LocalStorageService):
        await service.upload_file(storage_path, data)
        return
    await service.write_file(storage_path, data)
async def _delete_file_from_storage(
service: LocalStorageService | S3StorageService,
storage_path: str,
) -> None:
"""
从存储服务删除文件
:param service: 存储服务实例
:param storage_path: 文件存储路径
"""
if isinstance(service, LocalStorageService):
await service.delete_file(storage_path)
else:
await service.delete_file(storage_path)
async def migrate_single_file(
    session: AsyncSession,
    obj: Object,
    dest_policy: Policy,
) -> None:
    """
    Move one file object from its current storage policy to *dest_policy*.

    Steps:
    1. resolve the source policy and physical file
    2. read the file bytes from the source backend
    3. generate a path in the destination backend and write the bytes
    4. create a new ``PhysicalFile`` row for the copy
    5. point the ``Object`` at the new policy/physical file
    6. decrement the old ``PhysicalFile`` reference count; when it reaches
       zero, delete the source bytes and the row

    No-ops (with a log) when the object has no physical file or is already
    stored under the destination policy.

    :param session: database session
    :param obj: object to migrate (must be a file)
    :param dest_policy: destination storage policy
    :raises ValueError: if *obj* is not a file object
    """
    if obj.type != ObjectType.FILE:
        raise ValueError(f"只能迁移文件对象,当前类型: {obj.type}")
    # Resolve the source policy and the backing physical file.
    src_policy: Policy = await obj.awaitable_attrs.policy
    old_physical: PhysicalFile | None = await obj.awaitable_attrs.physical_file
    if not old_physical:
        l.warning(f"文件 {obj.id} 没有关联物理文件,跳过迁移")
        return
    if src_policy.id == dest_policy.id:
        l.debug(f"文件 {obj.id} 已在目标策略中,跳过")
        return
    # 1. Read the file content from the source backend.
    src_service = await _get_storage_service(src_policy)
    data = await _read_file_from_storage(src_service, old_physical.storage_path)
    # 2. Generate a destination path and write the content there.
    dest_service = await _get_storage_service(dest_policy)
    _dir_path, _storage_name, new_storage_path = await dest_service.generate_file_path(
        user_id=obj.owner_id,
        original_filename=obj.name,
    )
    await _write_file_to_storage(dest_service, new_storage_path, data)
    # 3. Record the copy as a new PhysicalFile.
    new_physical = PhysicalFile(
        storage_path=new_storage_path,
        size=old_physical.size,
        checksum_md5=old_physical.checksum_md5,
        policy_id=dest_policy.id,
        reference_count=1,
    )
    new_physical = await new_physical.save(session)
    # 4. Re-point the Object at the destination policy and copy.
    obj.policy_id = dest_policy.id
    obj.physical_file_id = new_physical.id
    await obj.save(session)
    # 5. Drop one reference from the old PhysicalFile.
    old_physical.decrement_reference()
    if old_physical.can_be_deleted:
        # Last reference gone: remove the source bytes as well.
        try:
            await _delete_file_from_storage(src_service, old_physical.storage_path)
        except Exception as e:
            # Best-effort cleanup — a leftover source file does not invalidate the migration.
            l.warning(f"删除源文件失败(不影响迁移结果): {old_physical.storage_path}: {e}")
        await PhysicalFile.delete(session, old_physical)
    else:
        await old_physical.save(session)
    l.info(f"文件迁移完成: {obj.name} ({obj.id}), {src_policy.name} → {dest_policy.name}")
async def migrate_file_with_task(
    session: AsyncSession,
    obj: Object,
    dest_policy: Policy,
    task: Task,
) -> None:
    """
    Run a single-file migration while mirroring its state on *task*.

    The task is marked RUNNING before the copy, COMPLETED (100%) on success,
    and ERROR with a truncated message on any failure. Exceptions are not
    re-raised: this runs as a fire-and-forget background job.

    :param session: database session
    :param obj: file object to migrate
    :param dest_policy: destination storage policy
    :param task: task row tracking the migration
    """
    try:
        task.status = TaskStatus.RUNNING
        task.progress = 0
        task = await task.save(session)
        await migrate_single_file(session, obj, dest_policy)
        task.status = TaskStatus.COMPLETED
        task.progress = 100
        await task.save(session)
    except Exception as e:
        l.error(f"文件迁移任务失败: {obj.id}: {e}")
        task.status = TaskStatus.ERROR
        # Truncate so the message fits the error column.
        task.error = str(e)[:500]
        await task.save(session)
async def migrate_directory_files(
    session: AsyncSession,
    folder: Object,
    dest_policy: Policy,
    task: Task,
) -> None:
    """
    Migrate every file beneath *folder* to *dest_policy*, updating *task*.

    Walks the directory tree, migrating files one by one, then re-points all
    sub-folders' ``policy_id``. Progress is updated proportionally to the
    number of files processed (capped at 99 until finalisation). Individual
    file failures are collected and reported (first five) in the task error
    instead of aborting the whole batch.

    :param session: database session
    :param folder: directory object whose subtree is migrated
    :param dest_policy: destination storage policy
    :param task: task row tracking the migration
    """
    try:
        task.status = TaskStatus.RUNNING
        task.progress = 0
        task = await task.save(session)
        # Collect the full subtree up front so total progress is known.
        files_to_migrate: list[Object] = []
        folders_to_update: list[Object] = []
        await _collect_objects_recursive(session, folder, files_to_migrate, folders_to_update)
        total = len(files_to_migrate)
        migrated = 0
        errors: list[str] = []
        for file_obj in files_to_migrate:
            try:
                await migrate_single_file(session, file_obj, dest_policy)
                migrated += 1
            except Exception as e:
                error_msg = f"{file_obj.name}: {e}"
                l.error(f"迁移文件失败: {error_msg}")
                errors.append(error_msg)
            # Progress caps at 99 until the folder updates below finish.
            if total > 0:
                task.progress = min(99, int(migrated / total * 100))
                task = await task.save(session)
        # Re-point all sub-folders at the destination policy.
        for sub_folder in folders_to_update:
            sub_folder.policy_id = dest_policy.id
            await sub_folder.save(session)
        # Finalise the task.
        if errors:
            task.status = TaskStatus.ERROR
            task.error = f"部分文件迁移失败 ({len(errors)}/{total}): " + "; ".join(errors[:5])
        else:
            task.status = TaskStatus.COMPLETED
            task.progress = 100
        await task.save(session)
        l.info(
            f"目录迁移完成: {folder.name} ({folder.id}), "
            f"成功 {migrated}/{total}, 错误 {len(errors)}"
        )
    except Exception as e:
        l.error(f"目录迁移任务失败: {folder.id}: {e}")
        task.status = TaskStatus.ERROR
        task.error = str(e)[:500]
        await task.save(session)
async def _collect_objects_recursive(
session: AsyncSession,
folder: Object,
files: list[Object],
folders: list[Object],
) -> None:
"""
递归收集目录下所有文件和子目录
:param session: 数据库会话
:param folder: 当前目录
:param files: 文件列表(输出)
:param folders: 子目录列表(输出)
"""
children: list[Object] = await Object.get_children(session, folder.owner_id, folder.id)
for child in children:
if child.type == ObjectType.FILE:
files.append(child)
elif child.type == ObjectType.FOLDER:
folders.append(child)
await _collect_objects_recursive(session, child, files, folders)

View File

@@ -6,7 +6,8 @@ from sqlalchemy import update as sql_update
from sqlalchemy.sql.functions import func from sqlalchemy.sql.functions import func
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from service.storage import LocalStorageService from .local_storage import LocalStorageService
from .s3_storage import S3StorageService
from sqlmodels import ( from sqlmodels import (
Object, Object,
PhysicalFile, PhysicalFile,
@@ -271,10 +272,14 @@ async def permanently_delete_objects(
if physical_file.can_be_deleted: if physical_file.can_be_deleted:
# 物理删除文件 # 物理删除文件
policy = await Policy.get(session, Policy.id == physical_file.policy_id) policy = await Policy.get(session, Policy.id == physical_file.policy_id)
if policy and policy.type == PolicyType.LOCAL: if policy:
try: try:
storage_service = LocalStorageService(policy) if policy.type == PolicyType.LOCAL:
await storage_service.delete_file(physical_file.storage_path) storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
await s3_service.delete_file(physical_file.storage_path)
l.debug(f"物理文件已删除: {obj_name}") l.debug(f"物理文件已删除: {obj_name}")
except Exception as e: except Exception as e:
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}") l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
@@ -374,10 +379,19 @@ async def delete_object_recursive(
if physical_file.can_be_deleted: if physical_file.can_be_deleted:
# 物理删除文件 # 物理删除文件
policy = await Policy.get(session, Policy.id == physical_file.policy_id) policy = await Policy.get(session, Policy.id == physical_file.policy_id)
if policy and policy.type == PolicyType.LOCAL: if policy:
try: try:
storage_service = LocalStorageService(policy) if policy.type == PolicyType.LOCAL:
await storage_service.delete_file(physical_file.storage_path) storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
elif policy.type == PolicyType.S3:
options = await policy.awaitable_attrs.options
s3_service = S3StorageService(
policy,
region=options.s3_region if options else 'us-east-1',
is_path_style=options.s3_path_style if options else False,
)
await s3_service.delete_file(physical_file.storage_path)
l.debug(f"物理文件已删除: {obj_name}") l.debug(f"物理文件已删除: {obj_name}")
except Exception as e: except Exception as e:
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}") l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")

View File

@@ -0,0 +1,709 @@
"""
S3 存储服务
使用 AWS Signature V4 签名的异步 S3 API 客户端。
从 Policy 配置中读取 S3 连接信息,提供文件上传/下载/删除及分片上传功能。
移植自 foxline-pro-backend-server 项目的 S3APIClient
适配 DiskNext 现有的 Service 架构(与 LocalStorageService 平行)。
"""
import hashlib
import hmac
import xml.etree.ElementTree as ET
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from typing import ClassVar, Literal
from urllib.parse import quote, urlencode
from uuid import UUID
import aiohttp
from yarl import URL
from loguru import logger as l
from sqlmodels.policy import Policy
from .exceptions import S3APIError, S3MultipartUploadError
from .naming_rule import NamingContext, NamingRuleParser
def _sign(key: bytes, msg: str) -> bytes:
"""HMAC-SHA256 签名"""
return hmac.new(key, msg.encode(), hashlib.sha256).digest()
_NS_AWS = "http://s3.amazonaws.com/doc/2006-03-01/"
class S3StorageService:
"""
S3 存储服务
使用 AWS Signature V4 签名的异步 S3 API 客户端。
从 Policy 配置中读取 S3 连接信息。
使用示例::
service = S3StorageService(policy, region='us-east-1')
await service.upload_file('path/to/file.txt', b'content')
data = await service.download_file('path/to/file.txt')
"""
_http_session: ClassVar[aiohttp.ClientSession | None] = None
    def __init__(
        self,
        policy: Policy,
        region: str = 'us-east-1',
        is_path_style: bool = False,
    ):
        """
        Build an S3 client from a storage policy.

        :param policy: storage policy; supplies server (endpoint URL),
            bucket_name, access_key and secret_key
        :param region: S3 region used in the SigV4 credential scope
        :param is_path_style: use path-style URLs (bucket in the path) instead
            of virtual-hosted style (bucket in the hostname)
        :raises S3APIError: if any required policy field is missing
        """
        if not policy.server:
            raise S3APIError("S3 策略必须指定 server (endpoint URL)")
        if not policy.bucket_name:
            raise S3APIError("S3 策略必须指定 bucket_name")
        if not policy.access_key:
            raise S3APIError("S3 策略必须指定 access_key")
        if not policy.secret_key:
            raise S3APIError("S3 策略必须指定 secret_key")
        self._policy = policy
        self._endpoint_url = policy.server.rstrip("/")
        self._bucket_name = policy.bucket_name
        self._access_key = policy.access_key
        self._secret_key = policy.secret_key
        self._region = region
        self._is_path_style = is_path_style
        self._base_url = policy.base_url
        # Host header value: strip the scheme and any path from the endpoint URL.
        self._host = self._endpoint_url.replace("https://", "").replace("http://", "").split("/")[0]
# ==================== 工厂方法 ====================
@classmethod
async def from_policy(cls, policy: Policy) -> 'S3StorageService':
"""
根据 Policy 异步创建 S3StorageService自动加载 options
:param policy: 存储策略
:return: S3StorageService 实例
"""
options = await policy.awaitable_attrs.options
region = options.s3_region if options else 'us-east-1'
is_path_style = options.s3_path_style if options else False
return cls(policy, region=region, is_path_style=is_path_style)
# ==================== HTTP Session 管理 ====================
    @classmethod
    async def initialize_session(cls) -> None:
        """Create the process-wide aiohttp ClientSession (idempotent; called at startup)."""
        if cls._http_session is None or cls._http_session.closed:
            cls._http_session = aiohttp.ClientSession()
            l.info("S3StorageService HTTP session 已初始化")
@classmethod
async def close_session(cls) -> None:
"""关闭全局 aiohttp ClientSession"""
if cls._http_session and not cls._http_session.closed:
await cls._http_session.close()
cls._http_session = None
l.info("S3StorageService HTTP session 已关闭")
    @classmethod
    def _get_session(cls) -> aiohttp.ClientSession:
        """
        Return the shared aiohttp session, creating it on demand.

        Lazy creation covers the case where ``initialize_session`` was never
        called during application startup.
        """
        if cls._http_session is None or cls._http_session.closed:
            # Lazily create in case initialize_session was not called.
            cls._http_session = aiohttp.ClientSession()
        return cls._http_session
# ==================== AWS Signature V4 签名 ====================
    def _get_signature_key(self, date_stamp: str) -> bytes:
        """
        Derive the AWS Signature V4 signing key for the given date.

        Chain per the SigV4 key-derivation spec (service fixed to ``s3``):
        kSecret -> kDate -> kRegion -> kService -> kSigning.

        :param date_stamp: UTC date as ``YYYYMMDD``
        :return: binary signing key
        """
        k_date = _sign(f"AWS4{self._secret_key}".encode(), date_stamp)
        k_region = _sign(k_date, self._region)
        k_service = _sign(k_region, "s3")
        return _sign(k_service, "aws4_request")
def _create_authorization_header(
self,
method: str,
uri: str,
query_string: str,
headers: dict[str, str],
payload_hash: str,
amz_date: str,
date_stamp: str,
) -> str:
"""创建 AWS Signature V4 授权头"""
signed_headers = ";".join(sorted(k.lower() for k in headers.keys()))
canonical_headers = "".join(
f"{k.lower()}:{v.strip()}\n" for k, v in sorted(headers.items())
)
canonical_request = (
f"{method}\n{uri}\n{query_string}\n{canonical_headers}\n"
f"{signed_headers}\n{payload_hash}"
)
algorithm = "AWS4-HMAC-SHA256"
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
string_to_sign = (
f"{algorithm}\n{amz_date}\n{credential_scope}\n"
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
)
signing_key = self._get_signature_key(date_stamp)
signature = hmac.new(
signing_key, string_to_sign.encode(), hashlib.sha256
).hexdigest()
return (
f"{algorithm} Credential={self._access_key}/{credential_scope}, "
f"SignedHeaders={signed_headers}, Signature={signature}"
)
    def _build_headers(
        self,
        method: str,
        uri: str,
        query_string: str = "",
        payload: bytes = b"",
        content_type: str | None = None,
        extra_headers: dict[str, str] | None = None,
        payload_hash: str | None = None,
        host: str | None = None,
    ) -> dict[str, str]:
        """
        Build the full header set, including the AWS SigV4 Authorization header.

        :param method: HTTP verb
        :param uri: canonical request URI
        :param query_string: canonical query string
        :param payload: request body bytes (hashed when *payload_hash* is absent)
        :param content_type: optional Content-Type header
        :param extra_headers: additional headers to include in the signature
        :param payload_hash: precomputed payload hash; pass ``"UNSIGNED-PAYLOAD"``
            for streaming uploads
        :param host: Host header override (defaults to ``self._host``)
        :return: headers dict with ``Authorization`` included
        """
        now_utc = datetime.now(timezone.utc)
        amz_date = now_utc.strftime("%Y%m%dT%H%M%SZ")
        date_stamp = now_utc.strftime("%Y%m%d")
        if payload_hash is None:
            payload_hash = hashlib.sha256(payload).hexdigest()
        effective_host = host or self._host
        headers: dict[str, str] = {
            "Host": effective_host,
            "X-Amz-Date": amz_date,
            "X-Amz-Content-Sha256": payload_hash,
        }
        if content_type:
            headers["Content-Type"] = content_type
        if extra_headers:
            headers.update(extra_headers)
        authorization = self._create_authorization_header(
            method, uri, query_string, headers, payload_hash, amz_date, date_stamp
        )
        headers["Authorization"] = authorization
        return headers
# ==================== 内部请求方法 ====================
def _build_uri(self, key: str | None = None) -> str:
"""
构建请求 URI
按 AWS S3 Signature V4 规范对路径进行 URI 编码S3 仅需一次)。
斜杠作为路径分隔符保留不编码。
"""
if self._is_path_style:
if key:
return f"/{self._bucket_name}/{quote(key, safe='/')}"
return f"/{self._bucket_name}"
else:
if key:
return f"/{quote(key, safe='/')}"
return "/"
def _build_url(self, uri: str, query_string: str = "") -> str:
"""构建完整请求 URL"""
if self._is_path_style:
base = self._endpoint_url
else:
# 虚拟主机风格bucket.endpoint
protocol = "https://" if self._endpoint_url.startswith("https://") else "http://"
base = f"{protocol}{self._bucket_name}.{self._host}"
url = f"{base}{uri}"
if query_string:
url = f"{url}?{query_string}"
return url
def _get_effective_host(self) -> str:
"""获取实际请求的 Host 头"""
if self._is_path_style:
return self._host
return f"{self._bucket_name}.{self._host}"
async def _request(
self,
method: str,
key: str | None = None,
query_params: dict[str, str] | None = None,
payload: bytes = b"",
content_type: str | None = None,
extra_headers: dict[str, str] | None = None,
) -> aiohttp.ClientResponse:
"""发送签名请求"""
uri = self._build_uri(key)
query_string = urlencode(sorted(query_params.items())) if query_params else ""
effective_host = self._get_effective_host()
headers = self._build_headers(
method, uri, query_string, payload, content_type,
extra_headers, host=effective_host,
)
url = self._build_url(uri, query_string)
try:
response = await self._get_session().request(
method, URL(url, encoded=True),
headers=headers, data=payload if payload else None,
)
return response
except Exception as e:
raise S3APIError(f"S3 请求失败: {method} {url}: {e}") from e
    async def _request_streaming(
        self,
        method: str,
        key: str,
        data_stream: AsyncIterator[bytes],
        content_length: int,
        content_type: str | None = None,
    ) -> aiohttp.ClientResponse:
        """
        Send a signed streaming request (large uploads without buffering).

        The body is not hashed: the signature uses ``UNSIGNED-PAYLOAD`` plus an
        explicit Content-Length header instead.

        :param method: HTTP verb
        :param key: object key
        :param data_stream: async iterator yielding body chunks
        :param content_length: exact total body length in bytes
        :param content_type: optional Content-Type
        :return: aiohttp response
        :raises S3APIError: when the transport request fails
        """
        uri = self._build_uri(key)
        effective_host = self._get_effective_host()
        headers = self._build_headers(
            method,
            uri,
            query_string="",
            content_type=content_type,
            extra_headers={"Content-Length": str(content_length)},
            payload_hash="UNSIGNED-PAYLOAD",
            host=effective_host,
        )
        url = self._build_url(uri)
        try:
            response = await self._get_session().request(
                method, URL(url, encoded=True),
                headers=headers, data=data_stream,
            )
            return response
        except Exception as e:
            raise S3APIError(f"S3 流式请求失败: {method} {url}: {e}") from e
# ==================== 文件操作 ====================
    async def upload_file(
        self,
        key: str,
        data: bytes,
        content_type: str = 'application/octet-stream',
    ) -> None:
        """
        Upload an object in a single PUT request.

        :param key: S3 object key
        :param data: full file content
        :param content_type: MIME type
        :raises S3APIError: on a response status other than 200/201
        """
        async with await self._request(
            "PUT", key=key, payload=data, content_type=content_type,
        ) as response:
            if response.status not in (200, 201):
                body = await response.text()
                raise S3APIError(
                    f"上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
            l.debug(f"S3 上传成功: {self._bucket_name}/{key}")
    async def upload_file_streaming(
        self,
        key: str,
        data_stream: AsyncIterator[bytes],
        content_length: int,
        content_type: str | None = None,
    ) -> None:
        """
        Upload a large object by streaming chunks, avoiding full in-memory buffering.

        :param key: S3 object key
        :param data_stream: async iterator yielding the body
        :param content_length: exact total length in bytes (must be accurate)
        :param content_type: MIME type
        :raises S3APIError: on a response status other than 200/201
        """
        async with await self._request_streaming(
            "PUT", key=key, data_stream=data_stream,
            content_length=content_length, content_type=content_type,
        ) as response:
            if response.status not in (200, 201):
                body = await response.text()
                raise S3APIError(
                    f"流式上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
            l.debug(f"S3 流式上传成功: {self._bucket_name}/{key}, 大小: {content_length}")
    async def download_file(self, key: str) -> bytes:
        """
        Download an object's full content.

        :param key: S3 object key
        :return: object bytes
        :raises S3APIError: on a non-200 response
        """
        async with await self._request("GET", key=key) as response:
            if response.status != 200:
                body = await response.text()
                raise S3APIError(
                    f"下载失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
            data = await response.read()
            l.debug(f"S3 下载成功: {self._bucket_name}/{key}, 大小: {len(data)}")
            return data
async def delete_file(self, key: str) -> None:
"""
删除文件
:param key: S3 对象键
"""
async with await self._request("DELETE", key=key) as response:
if response.status in (200, 204):
l.debug(f"S3 删除成功: {self._bucket_name}/{key}")
else:
body = await response.text()
raise S3APIError(
f"删除失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
async def file_exists(self, key: str) -> bool:
"""
检查文件是否存在
:param key: S3 对象键
:return: 是否存在
"""
async with await self._request("HEAD", key=key) as response:
if response.status == 200:
return True
elif response.status == 404:
return False
else:
raise S3APIError(
f"检查文件存在性失败: {self._bucket_name}/{key}, 状态: {response.status}"
)
    async def get_file_size(self, key: str) -> int:
        """
        Return an object's size in bytes via a HEAD request.

        :param key: S3 object key
        :return: Content-Length of the object (0 when the header is absent)
        :raises S3APIError: on a non-200 response
        """
        async with await self._request("HEAD", key=key) as response:
            if response.status != 200:
                raise S3APIError(
                    f"获取文件信息失败: {self._bucket_name}/{key}, 状态: {response.status}"
                )
            return int(response.headers.get("Content-Length", 0))
# ==================== Multipart Upload ====================
    async def create_multipart_upload(
        self,
        key: str,
        content_type: str = 'application/octet-stream',
    ) -> str:
        """
        Start a multipart upload and return its upload ID.

        :param key: S3 object key
        :param content_type: MIME type of the final object
        :return: upload ID for subsequent ``upload_part`` calls
        :raises S3MultipartUploadError: on a non-200 status, or when the
            response XML lacks an ``UploadId`` element
        """
        async with await self._request(
            "POST",
            key=key,
            query_params={"uploads": ""},
            content_type=content_type,
        ) as response:
            if response.status != 200:
                body = await response.text()
                raise S3MultipartUploadError(
                    f"创建分片上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
            body = await response.text()
            root = ET.fromstring(body)
            # Look up UploadId both without and with the AWS XML namespace,
            # since S3-compatible services differ in whether they emit it.
            upload_id_elem = root.find("UploadId")
            if upload_id_elem is None:
                upload_id_elem = root.find(f"{{{_NS_AWS}}}UploadId")
            if upload_id_elem is None or not upload_id_elem.text:
                raise S3MultipartUploadError(
                    f"创建分片上传响应中未找到 UploadId: {body}"
                )
            upload_id = upload_id_elem.text
            l.debug(f"S3 分片上传已创建: {self._bucket_name}/{key}, upload_id={upload_id}")
            return upload_id
    async def upload_part(
        self,
        key: str,
        upload_id: str,
        part_number: int,
        data: bytes,
    ) -> str:
        """
        Upload one part of a multipart upload.

        :param key: S3 object key
        :param upload_id: multipart upload ID
        :param part_number: 1-based part index
        :param data: part payload
        :return: the part's ETag with surrounding quotes stripped
        :raises S3MultipartUploadError: on a non-200 status
        """
        async with await self._request(
            "PUT",
            key=key,
            query_params={
                "partNumber": str(part_number),
                "uploadId": upload_id,
            },
            payload=data,
        ) as response:
            if response.status != 200:
                body = await response.text()
                raise S3MultipartUploadError(
                    f"上传分片失败: {self._bucket_name}/{key}, "
                    f"part={part_number}, 状态: {response.status}, {body}"
                )
            etag = response.headers.get("ETag", "").strip('"')
            l.debug(
                f"S3 分片上传成功: {self._bucket_name}/{key}, "
                f"part={part_number}, etag={etag}"
            )
            return etag
async def complete_multipart_upload(
self,
key: str,
upload_id: str,
parts: list[tuple[int, str]],
) -> None:
"""
完成分片上传
:param key: S3 对象键
:param upload_id: 分片上传 ID
:param parts: 分片列表 [(part_number, etag)]
"""
# 按 part_number 排序
parts_sorted = sorted(parts, key=lambda p: p[0])
# 构建 CompleteMultipartUpload XML
xml_parts = ''.join(
f"<Part><PartNumber>{pn}</PartNumber><ETag>{etag}</ETag></Part>"
for pn, etag in parts_sorted
)
payload = f'<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUpload>{xml_parts}</CompleteMultipartUpload>'
payload_bytes = payload.encode('utf-8')
async with await self._request(
"POST",
key=key,
query_params={"uploadId": upload_id},
payload=payload_bytes,
content_type="application/xml",
) as response:
if response.status != 200:
body = await response.text()
raise S3MultipartUploadError(
f"完成分片上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
l.info(
f"S3 分片上传已完成: {self._bucket_name}/{key}, "
f"{len(parts)} 个分片"
)
    async def abort_multipart_upload(self, key: str, upload_id: str) -> None:
        """
        Abort a multipart upload, discarding already-uploaded parts.

        Failures are logged but never raised: aborting is best-effort cleanup.

        :param key: S3 object key
        :param upload_id: multipart upload ID
        """
        async with await self._request(
            "DELETE",
            key=key,
            query_params={"uploadId": upload_id},
        ) as response:
            if response.status in (200, 204):
                l.debug(f"S3 分片上传已取消: {self._bucket_name}/{key}")
            else:
                body = await response.text()
                l.warning(
                    f"取消分片上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
# ==================== 预签名 URL ====================
    def generate_presigned_url(
        self,
        key: str,
        method: Literal['GET', 'PUT'] = 'GET',
        expires_in: int = 3600,
        filename: str | None = None,
    ) -> str:
        """
        Create a presigned URL (AWS Signature V4, query-string variant).

        Only the ``Host`` header is signed and the payload is
        ``UNSIGNED-PAYLOAD``, so the URL works for browser GETs and direct PUT
        uploads alike.

        :param key: S3 object key
        :param method: HTTP verb the URL authorises (GET download, PUT upload)
        :param expires_in: validity window in seconds
        :param filename: for GET, sets a Content-Disposition attachment name
        :return: presigned URL
        """
        current_time = datetime.now(timezone.utc)
        amz_date = current_time.strftime("%Y%m%dT%H%M%SZ")
        date_stamp = current_time.strftime("%Y%m%d")
        credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
        credential = f"{self._access_key}/{credential_scope}"
        uri = self._build_uri(key)
        effective_host = self._get_effective_host()
        query_params: dict[str, str] = {
            'X-Amz-Algorithm': 'AWS4-HMAC-SHA256',
            'X-Amz-Credential': credential,
            'X-Amz-Date': amz_date,
            'X-Amz-Expires': str(expires_in),
            'X-Amz-SignedHeaders': 'host',
        }
        # For GET, ask S3 to serve the object as an attachment, with the
        # filename carried in RFC 5987 encoded form.
        if method == "GET" and filename:
            encoded_filename = quote(filename, safe='')
            query_params['response-content-disposition'] = (
                f"attachment; filename*=UTF-8''{encoded_filename}"
            )
        canonical_query_string = "&".join(
            f"{quote(k, safe='')}={quote(v, safe='')}"
            for k, v in sorted(query_params.items())
        )
        canonical_headers = f"host:{effective_host}\n"
        signed_headers = "host"
        payload_hash = "UNSIGNED-PAYLOAD"
        canonical_request = (
            f"{method}\n"
            f"{uri}\n"
            f"{canonical_query_string}\n"
            f"{canonical_headers}\n"
            f"{signed_headers}\n"
            f"{payload_hash}"
        )
        algorithm = "AWS4-HMAC-SHA256"
        string_to_sign = (
            f"{algorithm}\n"
            f"{amz_date}\n"
            f"{credential_scope}\n"
            f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
        )
        signing_key = self._get_signature_key(date_stamp)
        signature = hmac.new(
            signing_key, string_to_sign.encode(), hashlib.sha256
        ).hexdigest()
        base_url = self._build_url(uri)
        return (
            f"{base_url}?"
            f"{canonical_query_string}&"
            f"X-Amz-Signature={signature}"
        )
# ==================== 路径生成 ====================
    async def generate_file_path(
        self,
        user_id: UUID,
        original_filename: str,
    ) -> tuple[str, str, str]:
        """
        Generate the S3 object key for a new file from the policy's naming rules.

        Mirrors ``LocalStorageService.generate_file_path`` so callers can use
        either backend interchangeably; S3 has no directories, so the path is
        plain string concatenation with no mkdir step.

        :param user_id: owner's UUID (made available to the naming rules)
        :param original_filename: client-provided file name
        :return: (relative directory path, stored file name, full storage key)
        """
        context = NamingContext(
            user_id=user_id,
            original_filename=original_filename,
        )
        # Directory part from the policy's directory naming rule, if any.
        dir_path = ""
        if self._policy.dir_name_rule:
            dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context)
        # File name: renamed by rule when auto_rename is on, else kept as-is.
        if self._policy.auto_rename and self._policy.file_name_rule:
            storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context)
            # Preserve the original extension if the rule produced none.
            if '.' in original_filename and '.' not in storage_name:
                ext = original_filename.rsplit('.', 1)[1]
                storage_name = f"{storage_name}.{ext}"
        else:
            storage_name = original_filename
        # No directory creation needed on S3 — just join the key.
        if dir_path:
            storage_path = f"{dir_path}/{storage_name}"
        else:
            storage_path = storage_name
        return dir_path, storage_name, storage_path

View File

@@ -954,18 +954,11 @@ class PolicyType(StrEnum):
S3 = "s3" # S3 兼容存储 S3 = "s3" # S3 兼容存储
``` ```
### StorageType ### PolicyType
```python ```python
class StorageType(StrEnum): class PolicyType(StrEnum):
LOCAL = "local" # 本地存储 LOCAL = "local" # 本地存储
QINIU = "qiniu" # 七牛云 S3 = "s3" # S3 兼容存储
TENCENT = "tencent" # 腾讯云
ALIYUN = "aliyun" # 阿里云
ONEDRIVE = "onedrive" # OneDrive
GOOGLE_DRIVE = "google_drive" # Google Drive
DROPBOX = "dropbox" # Dropbox
WEBDAV = "webdav" # WebDAV
REMOTE = "remote" # 远程存储
``` ```
### UserStatus ### UserStatus

View File

@@ -82,6 +82,7 @@ from .object import (
ObjectPropertyResponse, ObjectPropertyResponse,
ObjectRenameRequest, ObjectRenameRequest,
ObjectResponse, ObjectResponse,
ObjectSwitchPolicyRequest,
ObjectType, ObjectType,
PolicyResponse, PolicyResponse,
UploadChunkResponse, UploadChunkResponse,
@@ -100,7 +101,10 @@ from .object import (
from .physical_file import PhysicalFile, PhysicalFileBase from .physical_file import PhysicalFile, PhysicalFileBase
from .uri import DiskNextURI, FileSystemNamespace from .uri import DiskNextURI, FileSystemNamespace
from .order import Order, OrderStatus, OrderType from .order import Order, OrderStatus, OrderType
from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary from .policy import (
Policy, PolicyBase, PolicyCreateRequest, PolicyOptions, PolicyOptionsBase,
PolicyType, PolicySummary, PolicyUpdateRequest,
)
from .redeem import Redeem, RedeemType from .redeem import Redeem, RedeemType
from .report import Report, ReportReason from .report import Report, ReportReason
from .setting import ( from .setting import (
@@ -116,7 +120,7 @@ from .share import (
from .source_link import SourceLink from .source_link import SourceLink
from .storage_pack import StoragePack from .storage_pack import StoragePack
from .tag import Tag, TagType from .tag import Tag, TagType
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary, TaskSummaryBase
from .webdav import ( from .webdav import (
WebDAV, WebDAVBase, WebDAV, WebDAVBase,
WebDAVCreateRequest, WebDAVUpdateRequest, WebDAVAccountResponse, WebDAVCreateRequest, WebDAVUpdateRequest, WebDAVAccountResponse,

View File

@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, text from sqlmodel import Field, Relationship, text
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin
@@ -260,7 +261,7 @@ class Group(GroupBase, UUIDTableBaseMixin):
name: str = Field(max_length=255, unique=True) name: str = Field(max_length=255, unique=True)
"""用户组名""" """用户组名"""
max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) max_storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
"""最大存储空间(字节)""" """最大存储空间(字节)"""
share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}) share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})

View File

@@ -9,6 +9,8 @@ from sqlmodel import Field, Relationship, CheckConstraint, Index, text
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
from .policy import PolicyType
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import User from .user import User
from .policy import Policy from .policy import Policy
@@ -23,18 +25,6 @@ class ObjectType(StrEnum):
FILE = "file" FILE = "file"
FOLDER = "folder" FOLDER = "folder"
class StorageType(StrEnum):
"""存储类型枚举"""
LOCAL = "local"
QINIU = "qiniu"
TENCENT = "tencent"
ALIYUN = "aliyun"
ONEDRIVE = "onedrive"
GOOGLE_DRIVE = "google_drive"
DROPBOX = "dropbox"
WEBDAV = "webdav"
REMOTE = "remote"
class FileMetadataBase(SQLModelBase): class FileMetadataBase(SQLModelBase):
"""文件元数据基础模型""" """文件元数据基础模型"""
@@ -156,7 +146,7 @@ class PolicyResponse(SQLModelBase):
name: str name: str
"""策略名称""" """策略名称"""
type: StorageType type: PolicyType
"""存储类型""" """存储类型"""
max_size: int = Field(ge=0, default=0, sa_type=BigInteger) max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
@@ -624,6 +614,12 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
storage_path: str | None = Field(default=None, max_length=512) storage_path: str | None = Field(default=None, max_length=512)
"""文件存储路径""" """文件存储路径"""
s3_upload_id: str | None = Field(default=None, max_length=256)
"""S3 Multipart Upload ID仅 S3 策略使用)"""
s3_part_etags: str | None = None
"""S3 已上传分片的 ETag 列表JSON 格式 [[1,"etag1"],[2,"etag2"]](仅 S3 策略使用)"""
expires_at: datetime expires_at: datetime
"""会话过期时间""" """会话过期时间"""
@@ -732,6 +728,16 @@ class CreateFileRequest(SQLModelBase):
"""存储策略UUID不指定则使用父目录的策略""" """存储策略UUID不指定则使用父目录的策略"""
class ObjectSwitchPolicyRequest(SQLModelBase):
"""切换对象存储策略请求"""
policy_id: UUID
"""目标存储策略UUID"""
is_migrate_existing: bool = False
"""(仅目录)是否迁移已有文件,默认 false 只影响新文件"""
# ==================== 对象操作相关 DTO ==================== # ==================== 对象操作相关 DTO ====================
class ObjectCopyRequest(SQLModelBase): class ObjectCopyRequest(SQLModelBase):

View File

@@ -102,6 +102,94 @@ class PolicySummary(SQLModelBase):
"""是否私有""" """是否私有"""
class PolicyCreateRequest(PolicyBase):
"""创建存储策略请求 DTO包含 PolicyOptions 扁平字段"""
# PolicyOptions 字段(平铺到请求体中,与 GroupCreateRequest 模式一致)
token: str | None = None
"""访问令牌"""
file_type: str | None = None
"""允许的文件类型"""
mimetype: str | None = Field(default=None, max_length=127)
"""MIME类型"""
od_redirect: str | None = Field(default=None, max_length=255)
"""OneDrive重定向地址"""
chunk_size: int = Field(default=52428800, ge=1)
"""分片上传大小字节默认50MB"""
s3_path_style: bool = False
"""是否使用S3路径风格"""
s3_region: str = Field(default='us-east-1', max_length=64)
"""S3 区域(如 us-east-1、ap-southeast-1仅 S3 策略使用"""
class PolicyUpdateRequest(SQLModelBase):
"""更新存储策略请求 DTO所有字段可选"""
name: str | None = Field(default=None, max_length=255)
"""策略名称"""
server: str | None = Field(default=None, max_length=255)
"""服务器地址"""
bucket_name: str | None = Field(default=None, max_length=255)
"""存储桶名称"""
is_private: bool | None = None
"""是否为私有空间"""
base_url: str | None = Field(default=None, max_length=255)
"""访问文件的基础URL"""
access_key: str | None = None
"""Access Key"""
secret_key: str | None = None
"""Secret Key"""
max_size: int | None = Field(default=None, ge=0)
"""允许上传的最大文件尺寸(字节)"""
auto_rename: bool | None = None
"""是否自动重命名"""
dir_name_rule: str | None = Field(default=None, max_length=255)
"""目录命名规则"""
file_name_rule: str | None = Field(default=None, max_length=255)
"""文件命名规则"""
is_origin_link_enable: bool | None = None
"""是否开启源链接访问"""
# PolicyOptions 字段
token: str | None = None
"""访问令牌"""
file_type: str | None = None
"""允许的文件类型"""
mimetype: str | None = Field(default=None, max_length=127)
"""MIME类型"""
od_redirect: str | None = Field(default=None, max_length=255)
"""OneDrive重定向地址"""
chunk_size: int | None = Field(default=None, ge=1)
"""分片上传大小(字节)"""
s3_path_style: bool | None = None
"""是否使用S3路径风格"""
s3_region: str | None = Field(default=None, max_length=64)
"""S3 区域"""
# ==================== 数据库模型 ==================== # ==================== 数据库模型 ====================
@@ -126,6 +214,9 @@ class PolicyOptionsBase(SQLModelBase):
s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}) s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否使用S3路径风格""" """是否使用S3路径风格"""
s3_region: str = Field(default='us-east-1', max_length=64, sa_column_kwargs={"server_default": "'us-east-1'"})
"""S3 区域(如 us-east-1、ap-southeast-1仅 S3 策略使用"""
class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin): class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin):
"""存储策略选项模型与Policy一对一关联""" """存储策略选项模型与Policy一对一关联"""

View File

@@ -26,8 +26,8 @@ class TaskStatus(StrEnum):
class TaskType(StrEnum): class TaskType(StrEnum):
"""任务类型枚举""" """任务类型枚举"""
# [TODO] 补充具体任务类型 POLICY_MIGRATE = "policy_migrate"
pass """存储策略迁移"""
# ==================== DTO 模型 ==================== # ==================== DTO 模型 ====================
@@ -39,7 +39,7 @@ class TaskSummaryBase(SQLModelBase):
id: int id: int
"""任务ID""" """任务ID"""
type: int type: TaskType
"""任务类型""" """任务类型"""
status: TaskStatus status: TaskStatus
@@ -91,7 +91,14 @@ class TaskPropsBase(SQLModelBase):
file_ids: str | None = None file_ids: str | None = None
"""文件ID列表逗号分隔""" """文件ID列表逗号分隔"""
# [TODO] 根据业务需求补充更多字段 source_policy_id: UUID | None = None
"""源存储策略UUID"""
dest_policy_id: UUID | None = None
"""目标存储策略UUID"""
object_id: UUID | None = None
"""关联的对象UUID"""
class TaskProps(TaskPropsBase, TableBaseMixin): class TaskProps(TaskPropsBase, TableBaseMixin):
@@ -99,7 +106,7 @@ class TaskProps(TaskPropsBase, TableBaseMixin):
task_id: int = Field( task_id: int = Field(
foreign_key="task.id", foreign_key="task.id",
primary_key=True, unique=True,
ondelete="CASCADE" ondelete="CASCADE"
) )
"""关联的任务ID""" """关联的任务ID"""
@@ -121,8 +128,8 @@ class Task(SQLModelBase, TableBaseMixin):
status: TaskStatus = Field(default=TaskStatus.QUEUED) status: TaskStatus = Field(default=TaskStatus.QUEUED)
"""任务状态""" """任务状态"""
type: int = Field(default=0) type: TaskType
"""任务类型 [TODO] 待定义枚举""" """任务类型"""
progress: int = Field(default=0, ge=0, le=100) progress: int = Field(default=0, ge=0, le=100)
"""任务进度0-100""" """任务进度0-100"""

View File

@@ -4,7 +4,7 @@ from typing import Literal, TYPE_CHECKING, TypeVar
from uuid import UUID from uuid import UUID
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import BinaryExpression, ClauseElement, and_ from sqlalchemy import BigInteger, BinaryExpression, ClauseElement, and_
from sqlmodel import Field, Relationship from sqlmodel import Field, Relationship
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import RelationshipInfo from sqlmodel.main import RelationshipInfo
@@ -473,7 +473,7 @@ class User(UserBase, UUIDTableBaseMixin):
status: UserStatus = UserStatus.ACTIVE status: UserStatus = UserStatus.ACTIVE
"""用户状态""" """用户状态"""
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0) storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"}, ge=0)
"""已用存储空间(字节)""" """已用存储空间(字节)"""
avatar: str = Field(default="default", max_length=255) avatar: str = Field(default="default", max_length=255)