diff --git a/main.py b/main.py
index 83a6a58..25b24e0 100644
--- a/main.py
+++ b/main.py
@@ -8,6 +8,7 @@ from routers import router
from routers.dav import dav_app
from routers.dav.provider import EventLoopRef
from service.redis import RedisManager
+from service.storage import S3StorageService
from sqlmodels.database_connection import DatabaseManager
from sqlmodels.migration import migration
from utils import JWT
@@ -50,8 +51,10 @@ lifespan.add_startup(_init_db)
lifespan.add_startup(migration)
lifespan.add_startup(JWT.load_secret_key)
lifespan.add_startup(RedisManager.connect)
+lifespan.add_startup(S3StorageService.initialize_session)
# 添加关闭项
+lifespan.add_shutdown(S3StorageService.close_session)
lifespan.add_shutdown(DatabaseManager.close)
lifespan.add_shutdown(RedisManager.disconnect)
diff --git a/routers/api/v1/admin/policy/__init__.py b/routers/api/v1/admin/policy/__init__.py
index 3eaf0ad..8746d73 100644
--- a/routers/api/v1/admin/policy/__init__.py
+++ b/routers/api/v1/admin/policy/__init__.py
@@ -8,11 +8,11 @@ from sqlmodel import Field
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from sqlmodels import (
- Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
- ListResponse, Object,
+ Policy, PolicyCreateRequest, PolicyOptions, PolicyType, PolicySummary,
+ PolicyUpdateRequest, ResponseBase, ListResponse, Object,
)
from sqlmodel_ext import SQLModelBase
-from service.storage import DirectoryCreationError, LocalStorageService
+from service.storage import DirectoryCreationError, LocalStorageService, S3StorageService
admin_policy_router = APIRouter(
prefix='/policy',
@@ -67,6 +67,12 @@ class PolicyDetailResponse(SQLModelBase):
base_url: str | None
"""基础URL"""
+ access_key: str | None
+ """Access Key"""
+
+ secret_key: str | None
+ """Secret Key"""
+
max_size: int
"""最大文件尺寸"""
@@ -107,9 +113,45 @@ class PolicyTestSlaveRequest(SQLModelBase):
secret: str
"""从机通信密钥"""
-class PolicyCreateRequest(PolicyBase):
- """创建存储策略请求 DTO,继承 PolicyBase 中的所有字段"""
- pass
+class PolicyTestS3Request(SQLModelBase):
+ """测试 S3 连接请求 DTO"""
+
+ server: str = Field(max_length=255)
+ """S3 端点地址"""
+
+ bucket_name: str = Field(max_length=255)
+ """存储桶名称"""
+
+ access_key: str
+ """Access Key"""
+
+ secret_key: str
+ """Secret Key"""
+
+ s3_region: str = Field(default='us-east-1', max_length=64)
+ """S3 区域"""
+
+ s3_path_style: bool = False
+ """是否使用路径风格"""
+
+
+class PolicyTestS3Response(SQLModelBase):
+ """S3 连接测试响应"""
+
+ is_connected: bool
+ """连接是否成功"""
+
+ message: str
+ """测试结果消息"""
+
+
+# ==================== Options 字段集合(用于分离 Policy 与 Options 字段) ====================
+
+_OPTIONS_FIELDS: set[str] = {
+ 'token', 'file_type', 'mimetype', 'od_redirect',
+ 'chunk_size', 's3_path_style', 's3_region',
+}
+
@admin_policy_router.get(
path='/list',
@@ -277,7 +319,20 @@ async def router_policy_add_policy(
raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}")
# 保存到数据库
- await policy.save(session)
+ policy = await policy.save(session)
+
+ # 创建策略选项
+ options = PolicyOptions(
+ policy_id=policy.id,
+ token=request.token,
+ file_type=request.file_type,
+ mimetype=request.mimetype,
+ od_redirect=request.od_redirect,
+ chunk_size=request.chunk_size,
+ s3_path_style=request.s3_path_style,
+ s3_region=request.s3_region,
+ )
+ await options.save(session)
@admin_policy_router.post(
path='/cors',
@@ -371,6 +426,8 @@ async def router_policy_get_policy(
bucket_name=policy.bucket_name,
is_private=policy.is_private,
base_url=policy.base_url,
+ access_key=policy.access_key,
+ secret_key=policy.secret_key,
max_size=policy.max_size,
auto_rename=policy.auto_rename,
dir_name_rule=policy.dir_name_rule,
@@ -417,4 +474,108 @@ async def router_policy_delete_policy(
policy_name = policy.name
await Policy.delete(session, policy)
- l.info(f"管理员删除了存储策略: {policy_name}")
\ No newline at end of file
+ l.info(f"管理员删除了存储策略: {policy_name}")
+
+
+@admin_policy_router.patch(
+ path='/{policy_id}',
+ summary='更新存储策略',
+ description='更新存储策略配置。策略类型创建后不可更改。',
+ dependencies=[Depends(admin_required)],
+ status_code=204,
+)
+async def router_policy_update_policy(
+ session: SessionDep,
+ policy_id: UUID,
+ request: PolicyUpdateRequest,
+) -> None:
+ """
+ 更新存储策略端点
+
+ 功能:
+ - 更新策略基础字段和扩展选项
+ - 策略类型(type)不可更改
+
+ 认证:
+ - 需要管理员权限
+
+ :param session: 数据库会话
+ :param policy_id: 存储策略UUID
+ :param request: 更新请求
+ """
+ policy = await Policy.get(session, Policy.id == policy_id, load=Policy.options)
+ if not policy:
+ raise HTTPException(status_code=404, detail="存储策略不存在")
+
+ # 检查名称唯一性(如果要更新名称)
+ if request.name and request.name != policy.name:
+ existing = await Policy.get(session, Policy.name == request.name)
+ if existing:
+ raise HTTPException(status_code=409, detail="策略名称已存在")
+
+ # 分离 Policy 字段和 Options 字段
+ all_data = request.model_dump(exclude_unset=True)
+ policy_data = {k: v for k, v in all_data.items() if k not in _OPTIONS_FIELDS}
+ options_data = {k: v for k, v in all_data.items() if k in _OPTIONS_FIELDS}
+
+ # 更新 Policy 基础字段
+ if policy_data:
+ for key, value in policy_data.items():
+ setattr(policy, key, value)
+ policy = await policy.save(session)
+
+ # 更新或创建 PolicyOptions
+ if options_data:
+ if policy.options:
+ for key, value in options_data.items():
+ setattr(policy.options, key, value)
+ await policy.options.save(session)
+ else:
+ options = PolicyOptions(policy_id=policy.id, **options_data)
+ await options.save(session)
+
+ l.info(f"管理员更新了存储策略: {policy_id}")
+
+
+@admin_policy_router.post(
+ path='/test/s3',
+ summary='测试 S3 连接',
+ description='测试 S3 存储端点的连通性和凭据有效性。',
+ dependencies=[Depends(admin_required)],
+)
+async def router_policy_test_s3(
+ request: PolicyTestS3Request,
+) -> PolicyTestS3Response:
+ """
+ 测试 S3 连接端点
+
+ 通过向 S3 端点发送 HEAD Bucket 请求,验证凭据和网络连通性。
+
+ :param request: 测试请求
+ :return: 测试结果
+ """
+ from service.storage import S3APIError
+
+ # 构造临时 Policy 对象用于创建 S3StorageService
+ temp_policy = Policy(
+ name="__test__",
+ type=PolicyType.S3,
+ server=request.server,
+ bucket_name=request.bucket_name,
+ access_key=request.access_key,
+ secret_key=request.secret_key,
+ )
+ s3_service = S3StorageService(
+ temp_policy,
+ region=request.s3_region,
+ is_path_style=request.s3_path_style,
+ )
+
+ try:
+ # 使用 file_exists 发送 HEAD 请求来验证连通性
+ await s3_service.file_exists("__connection_test__")
+ return PolicyTestS3Response(is_connected=True, message="连接成功")
+ except S3APIError as e:
+ return PolicyTestS3Response(is_connected=False, message=f"S3 API 错误: {e}")
+ except Exception as e:
+ return PolicyTestS3Response(is_connected=False, message=f"连接失败: {e}")
\ No newline at end of file
diff --git a/routers/api/v1/directory/__init__.py b/routers/api/v1/directory/__init__.py
index c703137..ca7218f 100644
--- a/routers/api/v1/directory/__init__.py
+++ b/routers/api/v1/directory/__init__.py
@@ -57,7 +57,7 @@ async def _get_directory_response(
policy_response = PolicyResponse(
id=policy.id,
name=policy.name,
- type=policy.type.value,
+ type=policy.type,
max_size=policy.max_size,
)
@@ -189,6 +189,14 @@ async def router_directory_create(
raise HTTPException(status_code=409, detail="同名文件或目录已存在")
policy_id = request.policy_id if request.policy_id else parent.policy_id
+
+ # 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
+ if request.policy_id:
+ group = await user.awaitable_attrs.group
+ await session.refresh(group, ['policies'])
+ if request.policy_id not in {p.id for p in group.policies}:
+ raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
+
parent_id = parent.id # 在 save 前保存
new_folder = Object(
diff --git a/routers/api/v1/file/__init__.py b/routers/api/v1/file/__init__.py
index ddb1c05..72e54f2 100644
--- a/routers/api/v1/file/__init__.py
+++ b/routers/api/v1/file/__init__.py
@@ -13,9 +13,11 @@ from datetime import datetime, timedelta
from typing import Annotated
from uuid import UUID
+import orjson
import whatthepatch
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, RedirectResponse
+from starlette.responses import Response
from loguru import logger as l
from sqlmodel_ext import SQLModelBase
from whatthepatch.exceptions import HunkApplyException
@@ -44,7 +46,9 @@ from sqlmodels import (
User,
WopiSessionResponse,
)
-from service.storage import LocalStorageService, adjust_user_storage
+import orjson
+
+from service.storage import LocalStorageService, S3StorageService, adjust_user_storage
from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL
from utils.JWT.wopi_token import create_wopi_token
from utils import http_exceptions
@@ -184,6 +188,13 @@ async def create_upload_session(
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
+ # 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
+ if request.policy_id:
+ group = await user.awaitable_attrs.group
+ await session.refresh(group, ['policies'])
+ if request.policy_id not in {p.id for p in group.policies}:
+ raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
+
# 验证文件大小限制
_check_policy_size_limit(policy, request.file_size)
@@ -210,6 +221,7 @@ async def create_upload_session(
# 生成存储路径
storage_path: str | None = None
+ s3_upload_id: str | None = None
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
dir_path, storage_name, full_path = await storage_service.generate_file_path(
@@ -217,8 +229,25 @@ async def create_upload_session(
original_filename=request.file_name,
)
storage_path = full_path
- else:
- raise HTTPException(status_code=501, detail="S3 存储暂未实现")
+ elif policy.type == PolicyType.S3:
+ s3_service = S3StorageService(
+ policy,
+ region=options.s3_region if options else 'us-east-1',
+ is_path_style=options.s3_path_style if options else False,
+ )
+ dir_path, storage_name, storage_path = await s3_service.generate_file_path(
+ user_id=user.id,
+ original_filename=request.file_name,
+ )
+ # 多分片时创建 multipart upload
+ if total_chunks > 1:
+ s3_upload_id = await s3_service.create_multipart_upload(
+ storage_path, content_type='application/octet-stream',
+ )
+
+ # 预扣存储空间(与创建会话在同一事务中提交,防止并发绕过配额)
+ if request.file_size > 0:
+ await adjust_user_storage(session, user.id, request.file_size, commit=False)
# 创建上传会话
upload_session = UploadSession(
@@ -227,6 +256,7 @@ async def create_upload_session(
chunk_size=chunk_size,
total_chunks=total_chunks,
storage_path=storage_path,
+ s3_upload_id=s3_upload_id,
expires_at=datetime.now() + timedelta(hours=24),
owner_id=user.id,
parent_id=request.parent_id,
@@ -302,8 +332,38 @@ async def upload_chunk(
content,
offset,
)
- else:
- raise HTTPException(status_code=501, detail="S3 存储暂未实现")
+ elif policy.type == PolicyType.S3:
+ if not upload_session.storage_path:
+ raise HTTPException(status_code=500, detail="存储路径丢失")
+
+ s3_service = await S3StorageService.from_policy(policy)
+
+ if upload_session.total_chunks == 1:
+ # 单分片:直接 PUT 上传
+ await s3_service.upload_file(upload_session.storage_path, content)
+ else:
+ # 多分片:UploadPart
+ if not upload_session.s3_upload_id:
+ raise HTTPException(status_code=500, detail="S3 分片上传 ID 丢失")
+
+ etag = await s3_service.upload_part(
+ upload_session.storage_path,
+ upload_session.s3_upload_id,
+ chunk_index + 1, # S3 part number 从 1 开始
+ content,
+ )
+ # 追加 ETag 到 s3_part_etags
+ etags: list[list[int | str]] = orjson.loads(upload_session.s3_part_etags or '[]')
+ etags.append([chunk_index + 1, etag])
+ upload_session.s3_part_etags = orjson.dumps(etags).decode()
+
+ # 在 save(commit)前缓存后续需要的属性(commit 后 ORM 对象会过期)
+ policy_type = policy.type
+ s3_upload_id = upload_session.s3_upload_id
+ s3_part_etags = upload_session.s3_part_etags
+ s3_service_for_complete: S3StorageService | None = None
+ if policy_type == PolicyType.S3:
+ s3_service_for_complete = await S3StorageService.from_policy(policy)
# 更新会话进度
upload_session.uploaded_chunks += 1
@@ -319,12 +379,26 @@ async def upload_chunk(
if is_complete:
# 保存 upload_session 属性(commit 后会过期)
file_name = upload_session.file_name
+ file_size = upload_session.file_size
uploaded_size = upload_session.uploaded_size
storage_path = upload_session.storage_path
upload_session_id = upload_session.id
parent_id = upload_session.parent_id
policy_id = upload_session.policy_id
+ # S3 多分片上传完成:合并分片
+ if (
+ policy_type == PolicyType.S3
+ and s3_upload_id
+ and s3_part_etags
+ and s3_service_for_complete
+ ):
+ parts_data: list[list[int | str]] = orjson.loads(s3_part_etags)
+ parts = [(int(pn), str(et)) for pn, et in parts_data]
+ await s3_service_for_complete.complete_multipart_upload(
+ storage_path, s3_upload_id, parts,
+ )
+
# 创建 PhysicalFile 记录
physical_file = PhysicalFile(
storage_path=storage_path,
@@ -355,9 +429,10 @@ async def upload_chunk(
commit=False
)
- # 更新用户存储配额
- if uploaded_size > 0:
- await adjust_user_storage(session, user_id, uploaded_size, commit=False)
+ # 调整存储配额差值(创建会话时已预扣 file_size,这里只补差)
+ size_diff = uploaded_size - file_size
+ if size_diff != 0:
+ await adjust_user_storage(session, user_id, size_diff, commit=False)
# 统一提交所有更改
await session.commit()
@@ -390,9 +465,25 @@ async def delete_upload_session(
# 删除临时文件
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
- if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
- storage_service = LocalStorageService(policy)
- await storage_service.delete_file(upload_session.storage_path)
+ if policy and upload_session.storage_path:
+ if policy.type == PolicyType.LOCAL:
+ storage_service = LocalStorageService(policy)
+ await storage_service.delete_file(upload_session.storage_path)
+ elif policy.type == PolicyType.S3:
+ s3_service = await S3StorageService.from_policy(policy)
+ # 如果有分片上传,先取消
+ if upload_session.s3_upload_id:
+ await s3_service.abort_multipart_upload(
+ upload_session.storage_path, upload_session.s3_upload_id,
+ )
+ else:
+ # 单分片上传已完成的话,删除已上传的文件
+ if upload_session.uploaded_chunks > 0:
+ await s3_service.delete_file(upload_session.storage_path)
+
+ # 释放预扣的存储空间
+ if upload_session.file_size > 0:
+ await adjust_user_storage(session, user.id, -upload_session.file_size)
# 删除会话记录
await UploadSession.delete(session, upload_session)
@@ -422,9 +513,22 @@ async def clear_upload_sessions(
for upload_session in sessions:
# 删除临时文件
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
- if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
- storage_service = LocalStorageService(policy)
- await storage_service.delete_file(upload_session.storage_path)
+ if policy and upload_session.storage_path:
+ if policy.type == PolicyType.LOCAL:
+ storage_service = LocalStorageService(policy)
+ await storage_service.delete_file(upload_session.storage_path)
+ elif policy.type == PolicyType.S3:
+ s3_service = await S3StorageService.from_policy(policy)
+ if upload_session.s3_upload_id:
+ await s3_service.abort_multipart_upload(
+ upload_session.storage_path, upload_session.s3_upload_id,
+ )
+ elif upload_session.uploaded_chunks > 0:
+ await s3_service.delete_file(upload_session.storage_path)
+
+ # 释放预扣的存储空间
+ if upload_session.file_size > 0:
+ await adjust_user_storage(session, user.id, -upload_session.file_size)
await UploadSession.delete(session, upload_session)
deleted_count += 1
@@ -486,11 +590,12 @@ async def create_download_token_endpoint(
path='/{token}',
summary='下载文件',
description='使用下载令牌下载文件,令牌在有效期内可重复使用。',
+ response_model=None,
)
async def download_file(
session: SessionDep,
token: str,
-) -> FileResponse:
+) -> Response:
"""
下载文件端点
@@ -540,8 +645,15 @@ async def download_file(
filename=file_obj.name,
media_type="application/octet-stream",
)
+ elif policy.type == PolicyType.S3:
+ s3_service = await S3StorageService.from_policy(policy)
+ # 302 重定向到预签名 URL
+ presigned_url = s3_service.generate_presigned_url(
+ storage_path, method='GET', expires_in=3600, filename=file_obj.name,
+ )
+ return RedirectResponse(url=presigned_url, status_code=302)
else:
- raise HTTPException(status_code=501, detail="S3 存储暂未实现")
+ raise HTTPException(status_code=500, detail="不支持的存储类型")
# ==================== 包含子路由 ====================
@@ -613,8 +725,13 @@ async def create_empty_file(
)
await storage_service.create_empty_file(full_path)
storage_path = full_path
- else:
- raise HTTPException(status_code=501, detail="S3 存储暂未实现")
+ elif policy.type == PolicyType.S3:
+ s3_service = await S3StorageService.from_policy(policy)
+ dir_path, storage_name, storage_path = await s3_service.generate_file_path(
+ user_id=user_id,
+ original_filename=request.name,
+ )
+ await s3_service.upload_file(storage_path, b'')
# 创建 PhysicalFile 记录
physical_file = PhysicalFile(
@@ -798,12 +915,13 @@ async def _validate_source_link(
path='/get/{file_id}/{name}',
summary='文件外链(直接输出文件数据)',
description='通过外链直接获取文件内容,公开访问无需认证。',
+ response_model=None,
)
async def file_get(
session: SessionDep,
file_id: UUID,
name: str,
-) -> FileResponse:
+) -> Response:
"""
文件外链端点(直接输出)
@@ -815,13 +933,6 @@ async def file_get(
"""
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
- if policy.type != PolicyType.LOCAL:
- http_exceptions.raise_not_implemented("S3 存储暂未实现")
-
- storage_service = LocalStorageService(policy)
- if not await storage_service.file_exists(physical_file.storage_path):
- http_exceptions.raise_not_found("物理文件不存在")
-
# 缓存物理路径(save 后对象属性会过期)
file_path = physical_file.storage_path
@@ -829,11 +940,25 @@ async def file_get(
link.downloads += 1
await link.save(session)
- return FileResponse(
- path=file_path,
- filename=name,
- media_type="application/octet-stream",
- )
+ if policy.type == PolicyType.LOCAL:
+ storage_service = LocalStorageService(policy)
+ if not await storage_service.file_exists(file_path):
+ http_exceptions.raise_not_found("物理文件不存在")
+
+ return FileResponse(
+ path=file_path,
+ filename=name,
+ media_type="application/octet-stream",
+ )
+ elif policy.type == PolicyType.S3:
+ # S3 外链直接输出:302 重定向到预签名 URL
+ s3_service = await S3StorageService.from_policy(policy)
+ presigned_url = s3_service.generate_presigned_url(
+ file_path, method='GET', expires_in=3600, filename=name,
+ )
+ return RedirectResponse(url=presigned_url, status_code=302)
+ else:
+ http_exceptions.raise_internal_error("不支持的存储类型")
@router.get(
@@ -846,7 +971,7 @@ async def file_source_redirect(
session: SessionDep,
file_id: UUID,
name: str,
-) -> FileResponse | RedirectResponse:
+) -> Response:
"""
文件外链端点(重定向/直接输出)
@@ -860,13 +985,6 @@ async def file_source_redirect(
"""
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
- if policy.type != PolicyType.LOCAL:
- http_exceptions.raise_not_implemented("S3 存储暂未实现")
-
- storage_service = LocalStorageService(policy)
- if not await storage_service.file_exists(physical_file.storage_path):
- http_exceptions.raise_not_found("物理文件不存在")
-
# 缓存所有需要的值(save 后对象属性会过期)
file_path = physical_file.storage_path
is_private = policy.is_private
@@ -876,18 +994,36 @@ async def file_source_redirect(
link.downloads += 1
await link.save(session)
- # 公有存储:302 重定向到 base_url
- if not is_private and base_url:
- relative_path = storage_service.get_relative_path(file_path)
- redirect_url = f"{base_url}/{relative_path}"
- return RedirectResponse(url=redirect_url, status_code=302)
+ if policy.type == PolicyType.LOCAL:
+ storage_service = LocalStorageService(policy)
+ if not await storage_service.file_exists(file_path):
+ http_exceptions.raise_not_found("物理文件不存在")
- # 私有存储或 base_url 为空:通过应用代理文件
- return FileResponse(
- path=file_path,
- filename=name,
- media_type="application/octet-stream",
- )
+ # 公有存储:302 重定向到 base_url
+ if not is_private and base_url:
+ relative_path = storage_service.get_relative_path(file_path)
+ redirect_url = f"{base_url}/{relative_path}"
+ return RedirectResponse(url=redirect_url, status_code=302)
+
+ # 私有存储或 base_url 为空:通过应用代理文件
+ return FileResponse(
+ path=file_path,
+ filename=name,
+ media_type="application/octet-stream",
+ )
+ elif policy.type == PolicyType.S3:
+ s3_service = await S3StorageService.from_policy(policy)
+ # 公有存储且有 base_url:直接重定向到公开 URL
+ if not is_private and base_url:
+ redirect_url = f"{base_url.rstrip('/')}/{file_path}"
+ return RedirectResponse(url=redirect_url, status_code=302)
+ # 私有存储:生成预签名 URL 重定向
+ presigned_url = s3_service.generate_presigned_url(
+ file_path, method='GET', expires_in=3600, filename=name,
+ )
+ return RedirectResponse(url=presigned_url, status_code=302)
+ else:
+ http_exceptions.raise_internal_error("不支持的存储类型")
@router.put(
@@ -941,11 +1077,15 @@ async def file_content(
if not policy:
http_exceptions.raise_internal_error("存储策略不存在")
- if policy.type != PolicyType.LOCAL:
- http_exceptions.raise_not_implemented("S3 存储暂未实现")
-
- storage_service = LocalStorageService(policy)
- raw_bytes = await storage_service.read_file(physical_file.storage_path)
+ # 读取文件内容
+ if policy.type == PolicyType.LOCAL:
+ storage_service = LocalStorageService(policy)
+ raw_bytes = await storage_service.read_file(physical_file.storage_path)
+ elif policy.type == PolicyType.S3:
+ s3_service = await S3StorageService.from_policy(policy)
+ raw_bytes = await s3_service.download_file(physical_file.storage_path)
+ else:
+ http_exceptions.raise_internal_error("不支持的存储类型")
try:
content = raw_bytes.decode('utf-8')
@@ -1011,11 +1151,15 @@ async def patch_file_content(
if not policy:
http_exceptions.raise_internal_error("存储策略不存在")
- if policy.type != PolicyType.LOCAL:
- http_exceptions.raise_not_implemented("S3 存储暂未实现")
-
- storage_service = LocalStorageService(policy)
- raw_bytes = await storage_service.read_file(storage_path)
+ # 读取文件内容
+ if policy.type == PolicyType.LOCAL:
+ storage_service = LocalStorageService(policy)
+ raw_bytes = await storage_service.read_file(storage_path)
+ elif policy.type == PolicyType.S3:
+ s3_service = await S3StorageService.from_policy(policy)
+ raw_bytes = await s3_service.download_file(storage_path)
+ else:
+ http_exceptions.raise_internal_error("不支持的存储类型")
# 解码 + 规范化
original_text = raw_bytes.decode('utf-8')
@@ -1049,7 +1193,10 @@ async def patch_file_content(
_check_policy_size_limit(policy, len(new_bytes))
# 写入文件
- await storage_service.write_file(storage_path, new_bytes)
+ if policy.type == PolicyType.LOCAL:
+ await storage_service.write_file(storage_path, new_bytes)
+ elif policy.type == PolicyType.S3:
+ await s3_service.upload_file(storage_path, new_bytes)
# 更新数据库
owner_id = file_obj.owner_id
diff --git a/routers/api/v1/object/__init__.py b/routers/api/v1/object/__init__.py
index 7e7fcc4..59f352a 100644
--- a/routers/api/v1/object/__init__.py
+++ b/routers/api/v1/object/__init__.py
@@ -8,13 +8,14 @@
from typing import Annotated
from uuid import UUID
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from loguru import logger as l
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from sqlmodels import (
CreateFileRequest,
+ Group,
Object,
ObjectCopyRequest,
ObjectDeleteRequest,
@@ -22,18 +23,27 @@ from sqlmodels import (
ObjectPropertyDetailResponse,
ObjectPropertyResponse,
ObjectRenameRequest,
+ ObjectSwitchPolicyRequest,
ObjectType,
PhysicalFile,
Policy,
PolicyType,
+ Task,
+ TaskProps,
+ TaskStatus,
+ TaskSummaryBase,
+ TaskType,
User,
)
from service.storage import (
LocalStorageService,
adjust_user_storage,
copy_object_recursive,
+ migrate_file_with_task,
+ migrate_directory_files,
)
from service.storage.object import soft_delete_objects
+from sqlmodels.database_connection import DatabaseManager
from utils import http_exceptions
object_router = APIRouter(
@@ -575,3 +585,136 @@ async def router_object_property_detail(
response.checksum_md5 = obj.file_metadata.checksum_md5
return response
+
+
+@object_router.patch(
+ path='/{object_id}/policy',
+ summary='切换对象存储策略',
+)
+async def router_object_switch_policy(
+ session: SessionDep,
+ background_tasks: BackgroundTasks,
+ user: Annotated[User, Depends(auth_required)],
+ object_id: UUID,
+ request: ObjectSwitchPolicyRequest,
+) -> TaskSummaryBase:
+ """
+ 切换对象的存储策略
+
+ 文件:立即创建后台迁移任务,将文件从源策略搬到目标策略。
+ 目录:更新目录 policy_id(新文件使用新策略);
+ 若 is_migrate_existing=True,额外创建后台任务迁移所有已有文件。
+
+ 认证:JWT Bearer Token
+
+ 错误处理:
+ - 404: 对象不存在
+ - 403: 无权操作此对象 / 用户组无权使用目标策略
+ - 400: 目标策略与当前相同 / 不能对根目录操作
+ """
+ user_id = user.id
+
+ # 查找对象
+ obj = await Object.get(
+ session,
+ (Object.id == object_id) & (Object.deleted_at == None)
+ )
+ if not obj:
+ http_exceptions.raise_not_found("对象不存在")
+ if obj.owner_id != user_id:
+ http_exceptions.raise_forbidden("无权操作此对象")
+ if obj.is_banned:
+ http_exceptions.raise_banned()
+
+ # 根目录不能直接切换策略(应通过子对象或子目录操作)
+ if obj.parent_id is None:
+ raise HTTPException(status_code=400, detail="不能对根目录切换存储策略,请对子目录操作")
+
+ # 校验目标策略存在
+ dest_policy = await Policy.get(session, Policy.id == request.policy_id)
+ if not dest_policy:
+ http_exceptions.raise_not_found("目标存储策略不存在")
+
+ # 校验用户组权限
+ group: Group = await user.awaitable_attrs.group
+ await session.refresh(group, ['policies'])
+ allowed_ids = {p.id for p in group.policies}
+ if request.policy_id not in allowed_ids:
+ http_exceptions.raise_forbidden("当前用户组无权使用该存储策略")
+
+ # 不能切换到相同策略
+ if obj.policy_id == request.policy_id:
+ raise HTTPException(status_code=400, detail="目标策略与当前策略相同")
+
+ # 保存必要的属性,避免 save 后对象过期
+ src_policy_id = obj.policy_id
+ obj_id = obj.id
+ obj_is_file = obj.type == ObjectType.FILE
+ dest_policy_id = request.policy_id
+ dest_policy_name = dest_policy.name
+
+ # 创建任务记录
+ task = Task(
+ type=TaskType.POLICY_MIGRATE,
+ status=TaskStatus.QUEUED,
+ user_id=user_id,
+ )
+ task = await task.save(session)
+ task_id = task.id
+
+ task_props = TaskProps(
+ task_id=task_id,
+ source_policy_id=src_policy_id,
+ dest_policy_id=dest_policy_id,
+ object_id=obj_id,
+ )
+ await task_props.save(session)
+
+ if obj_is_file:
+ # 文件:后台迁移
+ async def _run_file_migration() -> None:
+ async with DatabaseManager.session() as bg_session:
+ bg_obj = await Object.get(bg_session, Object.id == obj_id)
+ bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
+ bg_task = await Task.get(bg_session, Task.id == task_id)
+ await migrate_file_with_task(bg_session, bg_obj, bg_policy, bg_task)
+
+ background_tasks.add_task(_run_file_migration)
+ else:
+ # 目录:先更新目录自身的 policy_id
+ obj = await Object.get(session, Object.id == obj_id)
+ obj.policy_id = dest_policy_id
+ await obj.save(session)
+
+ if request.is_migrate_existing:
+ # 后台迁移所有已有文件
+ async def _run_dir_migration() -> None:
+ async with DatabaseManager.session() as bg_session:
+ bg_folder = await Object.get(bg_session, Object.id == obj_id)
+ bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
+ bg_task = await Task.get(bg_session, Task.id == task_id)
+ await migrate_directory_files(bg_session, bg_folder, bg_policy, bg_task)
+
+ background_tasks.add_task(_run_dir_migration)
+ else:
+ # 不迁移已有文件,直接完成任务
+ task = await Task.get(session, Task.id == task_id)
+ task.status = TaskStatus.COMPLETED
+ task.progress = 100
+ await task.save(session)
+
+ # 重新获取 task 以读取最新状态
+ task = await Task.get(session, Task.id == task_id)
+
+ l.info(f"用户 {user_id} 请求切换对象 {obj_id} 存储策略 → {dest_policy_name}")
+
+ return TaskSummaryBase(
+ id=task.id,
+ type=task.type,
+ status=task.status,
+ progress=task.progress,
+ error=task.error,
+ user_id=task.user_id,
+ created_at=task.created_at,
+ updated_at=task.updated_at,
+ )
diff --git a/routers/api/v1/user/__init__.py b/routers/api/v1/user/__init__.py
index 2e829c0..9e722f8 100644
--- a/routers/api/v1/user/__init__.py
+++ b/routers/api/v1/user/__init__.py
@@ -247,11 +247,12 @@ async def router_user_register(
)
await identity.save(session)
- # 8. 创建用户根目录
- default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
- if not default_policy:
- logger.error("默认存储策略不存在")
+ # 8. 创建用户根目录(使用用户组关联的第一个存储策略)
+ await session.refresh(default_group, ['policies'])
+ if not default_group.policies:
+ logger.error("默认用户组未关联任何存储策略")
http_exceptions.raise_internal_error()
+ default_policy = default_group.policies[0]
await sqlmodels.Object(
name="/",
diff --git a/routers/api/v1/user/settings/__init__.py b/routers/api/v1/user/settings/__init__.py
index 19a39c6..cd7f342 100644
--- a/routers/api/v1/user/settings/__init__.py
+++ b/routers/api/v1/user/settings/__init__.py
@@ -13,6 +13,7 @@ from sqlmodels import (
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
ChangePasswordRequest,
AuthnDetailResponse, AuthnRenameRequest,
+ PolicySummary,
)
from sqlmodels.color import ThemeColorsBase
from sqlmodels.user_authn import UserAuthn
@@ -31,16 +32,25 @@ user_settings_router.include_router(file_viewers_router)
@user_settings_router.get(
path='/policies',
summary='获取用户可选存储策略',
- description='Get user selectable storage policies.',
)
-def router_user_settings_policies() -> sqlmodels.ResponseBase:
+async def router_user_settings_policies(
+ session: SessionDep,
+ user: Annotated[sqlmodels.user.User, Depends(auth_required)],
+) -> list[PolicySummary]:
"""
- Get user selectable storage policies.
+ 获取当前用户所在组可选的存储策略列表
- Returns:
- dict: A dictionary containing available storage policies for the user.
+ 返回用户组关联的所有存储策略的摘要信息。
"""
- http_exceptions.raise_not_implemented()
+ group = await user.awaitable_attrs.group
+ await session.refresh(group, ['policies'])
+ return [
+ PolicySummary(
+ id=p.id, name=p.name, type=p.type,
+ server=p.server, max_size=p.max_size, is_private=p.is_private,
+ )
+ for p in group.policies
+ ]
@user_settings_router.get(
diff --git a/service/storage/__init__.py b/service/storage/__init__.py
index 0a0c3a5..de96449 100644
--- a/service/storage/__init__.py
+++ b/service/storage/__init__.py
@@ -3,6 +3,7 @@
提供文件存储相关的服务,包括:
- 本地存储服务
+- S3 存储服务
- 命名规则解析器
- 存储异常定义
"""
@@ -11,6 +12,8 @@ from .exceptions import (
FileReadError,
FileWriteError,
InvalidPathError,
+ S3APIError,
+ S3MultipartUploadError,
StorageException,
StorageFileNotFoundError,
UploadSessionExpiredError,
@@ -25,4 +28,6 @@ from .object import (
permanently_delete_objects,
restore_objects,
soft_delete_objects,
-)
\ No newline at end of file
+)
+from .migrate import migrate_file_with_task, migrate_directory_files
+from .s3_storage import S3StorageService
\ No newline at end of file
diff --git a/service/storage/exceptions.py b/service/storage/exceptions.py
index ae1e4e3..0ba9251 100644
--- a/service/storage/exceptions.py
+++ b/service/storage/exceptions.py
@@ -43,3 +43,13 @@ class UploadSessionExpiredError(StorageException):
class InvalidPathError(StorageException):
"""无效的路径"""
pass
+
+
+class S3APIError(StorageException):
+ """S3 API 请求错误"""
+ pass
+
+
+class S3MultipartUploadError(S3APIError):
+ """S3 分片上传错误"""
+ pass
diff --git a/service/storage/migrate.py b/service/storage/migrate.py
new file mode 100644
index 0000000..0cbab60
--- /dev/null
+++ b/service/storage/migrate.py
@@ -0,0 +1,291 @@
+"""
+存储策略迁移服务
+
+提供跨存储策略的文件迁移功能:
+- 单文件迁移:从源策略下载 → 上传到目标策略 → 更新数据库记录
+- 目录批量迁移:递归遍历目录下所有文件逐个迁移,同时更新子目录的 policy_id
+"""
+from uuid import UUID
+
+from loguru import logger as l
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+from sqlmodels.object import Object, ObjectType
+from sqlmodels.physical_file import PhysicalFile
+from sqlmodels.policy import Policy, PolicyType
+from sqlmodels.task import Task, TaskStatus
+
+from .local_storage import LocalStorageService
+from .s3_storage import S3StorageService
+
+
+async def _get_storage_service(
+ policy: Policy,
+) -> LocalStorageService | S3StorageService:
+ """
+ 根据策略类型创建对应的存储服务实例
+
+ :param policy: 存储策略
+ :return: 存储服务实例
+ """
+ if policy.type == PolicyType.LOCAL:
+ return LocalStorageService(policy)
+ elif policy.type == PolicyType.S3:
+ return await S3StorageService.from_policy(policy)
+ else:
+ raise ValueError(f"不支持的存储策略类型: {policy.type}")
+
+
+async def _read_file_from_storage(
+ service: LocalStorageService | S3StorageService,
+ storage_path: str,
+) -> bytes:
+ """
+ 从存储服务读取文件内容
+
+ :param service: 存储服务实例
+ :param storage_path: 文件存储路径
+ :return: 文件二进制内容
+ """
+ if isinstance(service, LocalStorageService):
+ return await service.read_file(storage_path)
+ else:
+ return await service.download_file(storage_path)
+
+
+async def _write_file_to_storage(
+ service: LocalStorageService | S3StorageService,
+ storage_path: str,
+ data: bytes,
+) -> None:
+ """
+ 将文件内容写入存储服务
+
+ :param service: 存储服务实例
+ :param storage_path: 文件存储路径
+ :param data: 文件二进制内容
+ """
+ if isinstance(service, LocalStorageService):
+ await service.write_file(storage_path, data)
+ else:
+ await service.upload_file(storage_path, data)
+
+
+async def _delete_file_from_storage(
+ service: LocalStorageService | S3StorageService,
+ storage_path: str,
+) -> None:
+ """
+ 从存储服务删除文件
+
+ :param service: 存储服务实例
+ :param storage_path: 文件存储路径
+ """
+ if isinstance(service, LocalStorageService):
+ await service.delete_file(storage_path)
+ else:
+ await service.delete_file(storage_path)
+
+
+async def migrate_single_file(
+ session: AsyncSession,
+ obj: Object,
+ dest_policy: Policy,
+) -> None:
+ """
+ 将单个文件对象从当前存储策略迁移到目标策略
+
+ 流程:
+ 1. 获取源物理文件和存储服务
+ 2. 读取源文件内容
+ 3. 在目标存储中生成新路径并写入
+ 4. 创建新的 PhysicalFile 记录
+ 5. 更新 Object 的 policy_id 和 physical_file_id
+ 6. 旧 PhysicalFile 引用计数 -1,如为 0 则删除源物理文件
+
+ :param session: 数据库会话
+ :param obj: 待迁移的文件对象(必须为文件类型)
+ :param dest_policy: 目标存储策略
+ """
+ if obj.type != ObjectType.FILE:
+ raise ValueError(f"只能迁移文件对象,当前类型: {obj.type}")
+
+ # 获取源策略和物理文件
+ src_policy: Policy = await obj.awaitable_attrs.policy
+ old_physical: PhysicalFile | None = await obj.awaitable_attrs.physical_file
+
+ if not old_physical:
+ l.warning(f"文件 {obj.id} 没有关联物理文件,跳过迁移")
+ return
+
+ if src_policy.id == dest_policy.id:
+ l.debug(f"文件 {obj.id} 已在目标策略中,跳过")
+ return
+
+ # 1. 从源存储读取文件
+ src_service = await _get_storage_service(src_policy)
+ data = await _read_file_from_storage(src_service, old_physical.storage_path)
+
+ # 2. 在目标存储生成新路径并写入
+ dest_service = await _get_storage_service(dest_policy)
+ _dir_path, _storage_name, new_storage_path = await dest_service.generate_file_path(
+ user_id=obj.owner_id,
+ original_filename=obj.name,
+ )
+ await _write_file_to_storage(dest_service, new_storage_path, data)
+
+ # 3. 创建新的 PhysicalFile
+ new_physical = PhysicalFile(
+ storage_path=new_storage_path,
+ size=old_physical.size,
+ checksum_md5=old_physical.checksum_md5,
+ policy_id=dest_policy.id,
+ reference_count=1,
+ )
+ new_physical = await new_physical.save(session)
+
+ # 4. 更新 Object
+ obj.policy_id = dest_policy.id
+ obj.physical_file_id = new_physical.id
+ await obj.save(session)
+
+ # 5. 旧 PhysicalFile 引用计数 -1
+ old_physical.decrement_reference()
+ if old_physical.can_be_deleted:
+ # 删除源存储中的物理文件
+ try:
+ await _delete_file_from_storage(src_service, old_physical.storage_path)
+ except Exception as e:
+ l.warning(f"删除源文件失败(不影响迁移结果): {old_physical.storage_path}: {e}")
+ await PhysicalFile.delete(session, old_physical)
+ else:
+ await old_physical.save(session)
+
+ l.info(f"文件迁移完成: {obj.name} ({obj.id}), {src_policy.name} → {dest_policy.name}")
+
+
+async def migrate_file_with_task(
+ session: AsyncSession,
+ obj: Object,
+ dest_policy: Policy,
+ task: Task,
+) -> None:
+ """
+ 迁移单个文件并更新任务状态
+
+ :param session: 数据库会话
+ :param obj: 待迁移的文件对象
+ :param dest_policy: 目标存储策略
+ :param task: 关联的任务记录
+ """
+ try:
+ task.status = TaskStatus.RUNNING
+ task.progress = 0
+ task = await task.save(session)
+
+ await migrate_single_file(session, obj, dest_policy)
+
+ task.status = TaskStatus.COMPLETED
+ task.progress = 100
+ await task.save(session)
+ except Exception as e:
+ l.error(f"文件迁移任务失败: {obj.id}: {e}")
+ task.status = TaskStatus.ERROR
+ task.error = str(e)[:500]
+ await task.save(session)
+
+
+async def migrate_directory_files(
+ session: AsyncSession,
+ folder: Object,
+ dest_policy: Policy,
+ task: Task,
+) -> None:
+ """
+ 迁移目录下所有文件到目标存储策略
+
+ 递归遍历目录树,将所有文件迁移到目标策略。
+ 子目录的 policy_id 同步更新。
+ 任务进度按文件数比例更新。
+
+ :param session: 数据库会话
+ :param folder: 目录对象
+ :param dest_policy: 目标存储策略
+ :param task: 关联的任务记录
+ """
+ try:
+ task.status = TaskStatus.RUNNING
+ task.progress = 0
+ task = await task.save(session)
+
+ # 收集所有需要迁移的文件
+ files_to_migrate: list[Object] = []
+ folders_to_update: list[Object] = []
+ await _collect_objects_recursive(session, folder, files_to_migrate, folders_to_update)
+
+ total = len(files_to_migrate)
+ migrated = 0
+ errors: list[str] = []
+
+ for file_obj in files_to_migrate:
+ try:
+ await migrate_single_file(session, file_obj, dest_policy)
+ migrated += 1
+ except Exception as e:
+ error_msg = f"{file_obj.name}: {e}"
+ l.error(f"迁移文件失败: {error_msg}")
+ errors.append(error_msg)
+
+ # 更新进度
+ if total > 0:
+ task.progress = min(99, int(migrated / total * 100))
+ task = await task.save(session)
+
+ # 更新所有子目录的 policy_id
+ for sub_folder in folders_to_update:
+ sub_folder.policy_id = dest_policy.id
+ await sub_folder.save(session)
+
+ # 完成任务
+ if errors:
+ task.status = TaskStatus.ERROR
+ task.error = f"部分文件迁移失败 ({len(errors)}/{total}): " + "; ".join(errors[:5])
+ else:
+ task.status = TaskStatus.COMPLETED
+
+ task.progress = 100
+ await task.save(session)
+
+ l.info(
+ f"目录迁移完成: {folder.name} ({folder.id}), "
+ f"成功 {migrated}/{total}, 错误 {len(errors)}"
+ )
+ except Exception as e:
+ l.error(f"目录迁移任务失败: {folder.id}: {e}")
+ task.status = TaskStatus.ERROR
+ task.error = str(e)[:500]
+ await task.save(session)
+
+
+async def _collect_objects_recursive(
+ session: AsyncSession,
+ folder: Object,
+ files: list[Object],
+ folders: list[Object],
+) -> None:
+ """
+ 递归收集目录下所有文件和子目录
+
+ :param session: 数据库会话
+ :param folder: 当前目录
+ :param files: 文件列表(输出)
+ :param folders: 子目录列表(输出)
+ """
+ children: list[Object] = await Object.get_children(session, folder.owner_id, folder.id)
+
+ for child in children:
+ if child.type == ObjectType.FILE:
+ files.append(child)
+ elif child.type == ObjectType.FOLDER:
+ folders.append(child)
+ await _collect_objects_recursive(session, child, files, folders)
diff --git a/service/storage/object.py b/service/storage/object.py
index 057954e..23f3803 100644
--- a/service/storage/object.py
+++ b/service/storage/object.py
@@ -6,7 +6,8 @@ from sqlalchemy import update as sql_update
from sqlalchemy.sql.functions import func
from middleware.dependencies import SessionDep
-from service.storage import LocalStorageService
+from .local_storage import LocalStorageService
+from .s3_storage import S3StorageService
from sqlmodels import (
Object,
PhysicalFile,
@@ -271,10 +272,14 @@ async def permanently_delete_objects(
if physical_file.can_be_deleted:
# 物理删除文件
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
- if policy and policy.type == PolicyType.LOCAL:
+ if policy:
try:
- storage_service = LocalStorageService(policy)
- await storage_service.delete_file(physical_file.storage_path)
+ if policy.type == PolicyType.LOCAL:
+ storage_service = LocalStorageService(policy)
+ await storage_service.delete_file(physical_file.storage_path)
+ elif policy.type == PolicyType.S3:
+ s3_service = await S3StorageService.from_policy(policy)
+ await s3_service.delete_file(physical_file.storage_path)
l.debug(f"物理文件已删除: {obj_name}")
except Exception as e:
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
@@ -374,10 +379,19 @@ async def delete_object_recursive(
if physical_file.can_be_deleted:
# 物理删除文件
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
- if policy and policy.type == PolicyType.LOCAL:
+ if policy:
try:
- storage_service = LocalStorageService(policy)
- await storage_service.delete_file(physical_file.storage_path)
+ if policy.type == PolicyType.LOCAL:
+ storage_service = LocalStorageService(policy)
+ await storage_service.delete_file(physical_file.storage_path)
+ elif policy.type == PolicyType.S3:
+ options = await policy.awaitable_attrs.options
+ s3_service = S3StorageService(
+ policy,
+ region=options.s3_region if options else 'us-east-1',
+ is_path_style=options.s3_path_style if options else False,
+ )
+ await s3_service.delete_file(physical_file.storage_path)
l.debug(f"物理文件已删除: {obj_name}")
except Exception as e:
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
diff --git a/service/storage/s3_storage.py b/service/storage/s3_storage.py
new file mode 100644
index 0000000..35b3f8f
--- /dev/null
+++ b/service/storage/s3_storage.py
@@ -0,0 +1,709 @@
+"""
+S3 存储服务
+
+使用 AWS Signature V4 签名的异步 S3 API 客户端。
+从 Policy 配置中读取 S3 连接信息,提供文件上传/下载/删除及分片上传功能。
+
+移植自 foxline-pro-backend-server 项目的 S3APIClient,
+适配 DiskNext 现有的 Service 架构(与 LocalStorageService 平行)。
+"""
+import hashlib
+import hmac
+import xml.etree.ElementTree as ET
+from collections.abc import AsyncIterator
+from datetime import datetime, timezone
+from typing import ClassVar, Literal
+from urllib.parse import quote, urlencode
+from uuid import UUID
+
+import aiohttp
+from yarl import URL
+from loguru import logger as l
+
+from sqlmodels.policy import Policy
+from .exceptions import S3APIError, S3MultipartUploadError
+from .naming_rule import NamingContext, NamingRuleParser
+
+
+def _sign(key: bytes, msg: str) -> bytes:
+ """HMAC-SHA256 签名"""
+ return hmac.new(key, msg.encode(), hashlib.sha256).digest()
+
+
+_NS_AWS = "http://s3.amazonaws.com/doc/2006-03-01/"
+
+
+class S3StorageService:
+ """
+ S3 存储服务
+
+ 使用 AWS Signature V4 签名的异步 S3 API 客户端。
+ 从 Policy 配置中读取 S3 连接信息。
+
+ 使用示例::
+
+ service = S3StorageService(policy, region='us-east-1')
+ await service.upload_file('path/to/file.txt', b'content')
+ data = await service.download_file('path/to/file.txt')
+ """
+
+ _http_session: ClassVar[aiohttp.ClientSession | None] = None
+
+ def __init__(
+ self,
+ policy: Policy,
+ region: str = 'us-east-1',
+ is_path_style: bool = False,
+ ):
+ """
+ :param policy: 存储策略(server=endpoint_url, bucket_name, access_key, secret_key)
+ :param region: S3 区域
+ :param is_path_style: 是否使用路径风格 URL
+ """
+ if not policy.server:
+ raise S3APIError("S3 策略必须指定 server (endpoint URL)")
+ if not policy.bucket_name:
+ raise S3APIError("S3 策略必须指定 bucket_name")
+ if not policy.access_key:
+ raise S3APIError("S3 策略必须指定 access_key")
+ if not policy.secret_key:
+ raise S3APIError("S3 策略必须指定 secret_key")
+
+ self._policy = policy
+ self._endpoint_url = policy.server.rstrip("/")
+ self._bucket_name = policy.bucket_name
+ self._access_key = policy.access_key
+ self._secret_key = policy.secret_key
+ self._region = region
+ self._is_path_style = is_path_style
+ self._base_url = policy.base_url
+
+ # 从 endpoint_url 提取 host
+ self._host = self._endpoint_url.replace("https://", "").replace("http://", "").split("/")[0]
+
+ # ==================== 工厂方法 ====================
+
+ @classmethod
+ async def from_policy(cls, policy: Policy) -> 'S3StorageService':
+ """
+ 根据 Policy 异步创建 S3StorageService(自动加载 options)
+
+ :param policy: 存储策略
+ :return: S3StorageService 实例
+ """
+ options = await policy.awaitable_attrs.options
+ region = options.s3_region if options else 'us-east-1'
+ is_path_style = options.s3_path_style if options else False
+ return cls(policy, region=region, is_path_style=is_path_style)
+
+ # ==================== HTTP Session 管理 ====================
+
+ @classmethod
+ async def initialize_session(cls) -> None:
+ """初始化全局 aiohttp ClientSession"""
+ if cls._http_session is None or cls._http_session.closed:
+ cls._http_session = aiohttp.ClientSession()
+ l.info("S3StorageService HTTP session 已初始化")
+
+ @classmethod
+ async def close_session(cls) -> None:
+ """关闭全局 aiohttp ClientSession"""
+ if cls._http_session and not cls._http_session.closed:
+ await cls._http_session.close()
+ cls._http_session = None
+ l.info("S3StorageService HTTP session 已关闭")
+
+ @classmethod
+ def _get_session(cls) -> aiohttp.ClientSession:
+ """获取 HTTP session"""
+ if cls._http_session is None or cls._http_session.closed:
+ # 懒初始化,以防 initialize_session 未被调用
+ cls._http_session = aiohttp.ClientSession()
+ return cls._http_session
+
+ # ==================== AWS Signature V4 签名 ====================
+
+ def _get_signature_key(self, date_stamp: str) -> bytes:
+ """生成 AWS Signature V4 签名密钥"""
+ k_date = _sign(f"AWS4{self._secret_key}".encode(), date_stamp)
+ k_region = _sign(k_date, self._region)
+ k_service = _sign(k_region, "s3")
+ return _sign(k_service, "aws4_request")
+
+ def _create_authorization_header(
+ self,
+ method: str,
+ uri: str,
+ query_string: str,
+ headers: dict[str, str],
+ payload_hash: str,
+ amz_date: str,
+ date_stamp: str,
+ ) -> str:
+ """创建 AWS Signature V4 授权头"""
+ signed_headers = ";".join(sorted(k.lower() for k in headers.keys()))
+ canonical_headers = "".join(
+ f"{k.lower()}:{v.strip()}\n" for k, v in sorted(headers.items())
+ )
+ canonical_request = (
+ f"{method}\n{uri}\n{query_string}\n{canonical_headers}\n"
+ f"{signed_headers}\n{payload_hash}"
+ )
+
+ algorithm = "AWS4-HMAC-SHA256"
+ credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
+ string_to_sign = (
+ f"{algorithm}\n{amz_date}\n{credential_scope}\n"
+ f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
+ )
+
+ signing_key = self._get_signature_key(date_stamp)
+ signature = hmac.new(
+ signing_key, string_to_sign.encode(), hashlib.sha256
+ ).hexdigest()
+
+ return (
+ f"{algorithm} Credential={self._access_key}/{credential_scope}, "
+ f"SignedHeaders={signed_headers}, Signature={signature}"
+ )
+
+ def _build_headers(
+ self,
+ method: str,
+ uri: str,
+ query_string: str = "",
+ payload: bytes = b"",
+ content_type: str | None = None,
+ extra_headers: dict[str, str] | None = None,
+ payload_hash: str | None = None,
+ host: str | None = None,
+ ) -> dict[str, str]:
+ """
+ 构建包含 AWS V4 签名的完整请求头
+
+ :param method: HTTP 方法
+ :param uri: 请求 URI
+ :param query_string: 查询字符串
+ :param payload: 请求体字节(用于计算哈希)
+ :param content_type: Content-Type
+ :param extra_headers: 额外请求头
+ :param payload_hash: 预计算的 payload 哈希,流式上传时传 "UNSIGNED-PAYLOAD"
+ :param host: Host 头(默认使用 self._host)
+ """
+ now_utc = datetime.now(timezone.utc)
+ amz_date = now_utc.strftime("%Y%m%dT%H%M%SZ")
+ date_stamp = now_utc.strftime("%Y%m%d")
+
+ if payload_hash is None:
+ payload_hash = hashlib.sha256(payload).hexdigest()
+
+ effective_host = host or self._host
+
+ headers: dict[str, str] = {
+ "Host": effective_host,
+ "X-Amz-Date": amz_date,
+ "X-Amz-Content-Sha256": payload_hash,
+ }
+ if content_type:
+ headers["Content-Type"] = content_type
+ if extra_headers:
+ headers.update(extra_headers)
+
+ authorization = self._create_authorization_header(
+ method, uri, query_string, headers, payload_hash, amz_date, date_stamp
+ )
+ headers["Authorization"] = authorization
+ return headers
+
+ # ==================== 内部请求方法 ====================
+
+ def _build_uri(self, key: str | None = None) -> str:
+ """
+ 构建请求 URI
+
+ 按 AWS S3 Signature V4 规范对路径进行 URI 编码(S3 仅需一次)。
+ 斜杠作为路径分隔符保留不编码。
+ """
+ if self._is_path_style:
+ if key:
+ return f"/{self._bucket_name}/{quote(key, safe='/')}"
+ return f"/{self._bucket_name}"
+ else:
+ if key:
+ return f"/{quote(key, safe='/')}"
+ return "/"
+
+ def _build_url(self, uri: str, query_string: str = "") -> str:
+ """构建完整请求 URL"""
+ if self._is_path_style:
+ base = self._endpoint_url
+ else:
+ # 虚拟主机风格:bucket.endpoint
+ protocol = "https://" if self._endpoint_url.startswith("https://") else "http://"
+ base = f"{protocol}{self._bucket_name}.{self._host}"
+
+ url = f"{base}{uri}"
+ if query_string:
+ url = f"{url}?{query_string}"
+ return url
+
+ def _get_effective_host(self) -> str:
+ """获取实际请求的 Host 头"""
+ if self._is_path_style:
+ return self._host
+ return f"{self._bucket_name}.{self._host}"
+
+ async def _request(
+ self,
+ method: str,
+ key: str | None = None,
+ query_params: dict[str, str] | None = None,
+ payload: bytes = b"",
+ content_type: str | None = None,
+ extra_headers: dict[str, str] | None = None,
+ ) -> aiohttp.ClientResponse:
+ """发送签名请求"""
+ uri = self._build_uri(key)
+ query_string = urlencode(sorted(query_params.items())) if query_params else ""
+ effective_host = self._get_effective_host()
+
+ headers = self._build_headers(
+ method, uri, query_string, payload, content_type,
+ extra_headers, host=effective_host,
+ )
+
+ url = self._build_url(uri, query_string)
+
+ try:
+ response = await self._get_session().request(
+ method, URL(url, encoded=True),
+ headers=headers, data=payload if payload else None,
+ )
+ return response
+ except Exception as e:
+ raise S3APIError(f"S3 请求失败: {method} {url}: {e}") from e
+
+ async def _request_streaming(
+ self,
+ method: str,
+ key: str,
+ data_stream: AsyncIterator[bytes],
+ content_length: int,
+ content_type: str | None = None,
+ ) -> aiohttp.ClientResponse:
+ """
+ 发送流式签名请求(大文件上传)
+
+ 使用 UNSIGNED-PAYLOAD 作为 payload hash。
+ """
+ uri = self._build_uri(key)
+ effective_host = self._get_effective_host()
+
+ headers = self._build_headers(
+ method,
+ uri,
+ query_string="",
+ content_type=content_type,
+ extra_headers={"Content-Length": str(content_length)},
+ payload_hash="UNSIGNED-PAYLOAD",
+ host=effective_host,
+ )
+
+ url = self._build_url(uri)
+
+ try:
+ response = await self._get_session().request(
+ method, URL(url, encoded=True),
+ headers=headers, data=data_stream,
+ )
+ return response
+ except Exception as e:
+ raise S3APIError(f"S3 流式请求失败: {method} {url}: {e}") from e
+
+ # ==================== 文件操作 ====================
+
+ async def upload_file(
+ self,
+ key: str,
+ data: bytes,
+ content_type: str = 'application/octet-stream',
+ ) -> None:
+ """
+ 上传文件
+
+ :param key: S3 对象键
+ :param data: 文件内容
+ :param content_type: MIME 类型
+ """
+ async with await self._request(
+ "PUT", key=key, payload=data, content_type=content_type,
+ ) as response:
+ if response.status not in (200, 201):
+ body = await response.text()
+ raise S3APIError(
+ f"上传失败: {self._bucket_name}/{key}, "
+ f"状态: {response.status}, {body}"
+ )
+ l.debug(f"S3 上传成功: {self._bucket_name}/{key}")
+
+ async def upload_file_streaming(
+ self,
+ key: str,
+ data_stream: AsyncIterator[bytes],
+ content_length: int,
+ content_type: str | None = None,
+ ) -> None:
+ """
+ 流式上传文件(大文件,避免全部加载到内存)
+
+ :param key: S3 对象键
+ :param data_stream: 异步字节流迭代器
+ :param content_length: 数据总长度(必须准确)
+ :param content_type: MIME 类型
+ """
+ async with await self._request_streaming(
+ "PUT", key=key, data_stream=data_stream,
+ content_length=content_length, content_type=content_type,
+ ) as response:
+ if response.status not in (200, 201):
+ body = await response.text()
+ raise S3APIError(
+ f"流式上传失败: {self._bucket_name}/{key}, "
+ f"状态: {response.status}, {body}"
+ )
+ l.debug(f"S3 流式上传成功: {self._bucket_name}/{key}, 大小: {content_length}")
+
+ async def download_file(self, key: str) -> bytes:
+ """
+ 下载文件
+
+ :param key: S3 对象键
+ :return: 文件内容
+ """
+ async with await self._request("GET", key=key) as response:
+ if response.status != 200:
+ body = await response.text()
+ raise S3APIError(
+ f"下载失败: {self._bucket_name}/{key}, "
+ f"状态: {response.status}, {body}"
+ )
+ data = await response.read()
+ l.debug(f"S3 下载成功: {self._bucket_name}/{key}, 大小: {len(data)}")
+ return data
+
+ async def delete_file(self, key: str) -> None:
+ """
+ 删除文件
+
+ :param key: S3 对象键
+ """
+ async with await self._request("DELETE", key=key) as response:
+ if response.status in (200, 204):
+ l.debug(f"S3 删除成功: {self._bucket_name}/{key}")
+ else:
+ body = await response.text()
+ raise S3APIError(
+ f"删除失败: {self._bucket_name}/{key}, "
+ f"状态: {response.status}, {body}"
+ )
+
+ async def file_exists(self, key: str) -> bool:
+ """
+ 检查文件是否存在
+
+ :param key: S3 对象键
+ :return: 是否存在
+ """
+ async with await self._request("HEAD", key=key) as response:
+ if response.status == 200:
+ return True
+ elif response.status == 404:
+ return False
+ else:
+ raise S3APIError(
+ f"检查文件存在性失败: {self._bucket_name}/{key}, 状态: {response.status}"
+ )
+
+ async def get_file_size(self, key: str) -> int:
+ """
+ 获取文件大小
+
+ :param key: S3 对象键
+ :return: 文件大小(字节)
+ """
+ async with await self._request("HEAD", key=key) as response:
+ if response.status != 200:
+ raise S3APIError(
+ f"获取文件信息失败: {self._bucket_name}/{key}, 状态: {response.status}"
+ )
+ return int(response.headers.get("Content-Length", 0))
+
+ # ==================== Multipart Upload ====================
+
+ async def create_multipart_upload(
+ self,
+ key: str,
+ content_type: str = 'application/octet-stream',
+ ) -> str:
+ """
+ 创建分片上传任务
+
+ :param key: S3 对象键
+ :param content_type: MIME 类型
+ :return: Upload ID
+ """
+ async with await self._request(
+ "POST",
+ key=key,
+ query_params={"uploads": ""},
+ content_type=content_type,
+ ) as response:
+ if response.status != 200:
+ body = await response.text()
+ raise S3MultipartUploadError(
+ f"创建分片上传失败: {self._bucket_name}/{key}, "
+ f"状态: {response.status}, {body}"
+ )
+
+ body = await response.text()
+ root = ET.fromstring(body)
+
+ # 查找 UploadId 元素(支持命名空间)
+ upload_id_elem = root.find("UploadId")
+ if upload_id_elem is None:
+ upload_id_elem = root.find(f"{{{_NS_AWS}}}UploadId")
+ if upload_id_elem is None or not upload_id_elem.text:
+ raise S3MultipartUploadError(
+ f"创建分片上传响应中未找到 UploadId: {body}"
+ )
+
+ upload_id = upload_id_elem.text
+ l.debug(f"S3 分片上传已创建: {self._bucket_name}/{key}, upload_id={upload_id}")
+ return upload_id
+
+ async def upload_part(
+ self,
+ key: str,
+ upload_id: str,
+ part_number: int,
+ data: bytes,
+ ) -> str:
+ """
+ 上传单个分片
+
+ :param key: S3 对象键
+ :param upload_id: 分片上传 ID
+ :param part_number: 分片编号(从 1 开始)
+ :param data: 分片数据
+ :return: ETag
+ """
+ async with await self._request(
+ "PUT",
+ key=key,
+ query_params={
+ "partNumber": str(part_number),
+ "uploadId": upload_id,
+ },
+ payload=data,
+ ) as response:
+ if response.status != 200:
+ body = await response.text()
+ raise S3MultipartUploadError(
+ f"上传分片失败: {self._bucket_name}/{key}, "
+ f"part={part_number}, 状态: {response.status}, {body}"
+ )
+
+ etag = response.headers.get("ETag", "").strip('"')
+ l.debug(
+ f"S3 分片上传成功: {self._bucket_name}/{key}, "
+ f"part={part_number}, etag={etag}"
+ )
+ return etag
+
+ async def complete_multipart_upload(
+ self,
+ key: str,
+ upload_id: str,
+ parts: list[tuple[int, str]],
+ ) -> None:
+ """
+ 完成分片上传
+
+ :param key: S3 对象键
+ :param upload_id: 分片上传 ID
+ :param parts: 分片列表 [(part_number, etag)]
+ """
+ # 按 part_number 排序
+ parts_sorted = sorted(parts, key=lambda p: p[0])
+
+ # 构建 CompleteMultipartUpload XML
+ xml_parts = ''.join(
+ f"{pn}{etag}"
+ for pn, etag in parts_sorted
+ )
+ payload = f'{xml_parts}'
+ payload_bytes = payload.encode('utf-8')
+
+ async with await self._request(
+ "POST",
+ key=key,
+ query_params={"uploadId": upload_id},
+ payload=payload_bytes,
+ content_type="application/xml",
+ ) as response:
+ if response.status != 200:
+ body = await response.text()
+ raise S3MultipartUploadError(
+ f"完成分片上传失败: {self._bucket_name}/{key}, "
+ f"状态: {response.status}, {body}"
+ )
+ l.info(
+ f"S3 分片上传已完成: {self._bucket_name}/{key}, "
+ f"共 {len(parts)} 个分片"
+ )
+
+ async def abort_multipart_upload(self, key: str, upload_id: str) -> None:
+ """
+ 取消分片上传
+
+ :param key: S3 对象键
+ :param upload_id: 分片上传 ID
+ """
+ async with await self._request(
+ "DELETE",
+ key=key,
+ query_params={"uploadId": upload_id},
+ ) as response:
+ if response.status in (200, 204):
+ l.debug(f"S3 分片上传已取消: {self._bucket_name}/{key}")
+ else:
+ body = await response.text()
+ l.warning(
+ f"取消分片上传失败: {self._bucket_name}/{key}, "
+ f"状态: {response.status}, {body}"
+ )
+
+ # ==================== 预签名 URL ====================
+
+ def generate_presigned_url(
+ self,
+ key: str,
+ method: Literal['GET', 'PUT'] = 'GET',
+ expires_in: int = 3600,
+ filename: str | None = None,
+ ) -> str:
+ """
+ 生成 S3 预签名 URL(AWS Signature V4 Query String)
+
+ :param key: S3 对象键
+ :param method: HTTP 方法(GET 下载,PUT 上传)
+ :param expires_in: URL 有效期(秒)
+ :param filename: 文件名(GET 请求时设置 Content-Disposition)
+ :return: 预签名 URL
+ """
+ current_time = datetime.now(timezone.utc)
+ amz_date = current_time.strftime("%Y%m%dT%H%M%SZ")
+ date_stamp = current_time.strftime("%Y%m%d")
+
+ credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
+ credential = f"{self._access_key}/{credential_scope}"
+
+ uri = self._build_uri(key)
+ effective_host = self._get_effective_host()
+
+ query_params: dict[str, str] = {
+ 'X-Amz-Algorithm': 'AWS4-HMAC-SHA256',
+ 'X-Amz-Credential': credential,
+ 'X-Amz-Date': amz_date,
+ 'X-Amz-Expires': str(expires_in),
+ 'X-Amz-SignedHeaders': 'host',
+ }
+
+ # GET 请求时添加 Content-Disposition
+ if method == "GET" and filename:
+ encoded_filename = quote(filename, safe='')
+ query_params['response-content-disposition'] = (
+ f"attachment; filename*=UTF-8''{encoded_filename}"
+ )
+
+ canonical_query_string = "&".join(
+ f"{quote(k, safe='')}={quote(v, safe='')}"
+ for k, v in sorted(query_params.items())
+ )
+
+ canonical_headers = f"host:{effective_host}\n"
+ signed_headers = "host"
+ payload_hash = "UNSIGNED-PAYLOAD"
+
+ canonical_request = (
+ f"{method}\n"
+ f"{uri}\n"
+ f"{canonical_query_string}\n"
+ f"{canonical_headers}\n"
+ f"{signed_headers}\n"
+ f"{payload_hash}"
+ )
+
+ algorithm = "AWS4-HMAC-SHA256"
+ string_to_sign = (
+ f"{algorithm}\n"
+ f"{amz_date}\n"
+ f"{credential_scope}\n"
+ f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
+ )
+
+ signing_key = self._get_signature_key(date_stamp)
+ signature = hmac.new(
+ signing_key, string_to_sign.encode(), hashlib.sha256
+ ).hexdigest()
+
+ base_url = self._build_url(uri)
+ return (
+ f"{base_url}?"
+ f"{canonical_query_string}&"
+ f"X-Amz-Signature={signature}"
+ )
+
+ # ==================== 路径生成 ====================
+
+ async def generate_file_path(
+ self,
+ user_id: UUID,
+ original_filename: str,
+ ) -> tuple[str, str, str]:
+ """
+ 根据命名规则生成 S3 文件存储路径
+
+ 与 LocalStorageService.generate_file_path 接口一致。
+
+ :param user_id: 用户UUID
+ :param original_filename: 原始文件名
+ :return: (相对目录路径, 存储文件名, 完整存储路径)
+ """
+ context = NamingContext(
+ user_id=user_id,
+ original_filename=original_filename,
+ )
+
+ # 解析目录规则
+ dir_path = ""
+ if self._policy.dir_name_rule:
+ dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context)
+
+ # 解析文件名规则
+ if self._policy.auto_rename and self._policy.file_name_rule:
+ storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context)
+ # 确保有扩展名
+ if '.' in original_filename and '.' not in storage_name:
+ ext = original_filename.rsplit('.', 1)[1]
+ storage_name = f"{storage_name}.{ext}"
+ else:
+ storage_name = original_filename
+
+ # S3 不需要创建目录,直接拼接路径
+ if dir_path:
+ storage_path = f"{dir_path}/{storage_name}"
+ else:
+ storage_path = storage_name
+
+ return dir_path, storage_name, storage_path
diff --git a/sqlmodels/README.md b/sqlmodels/README.md
index 6855961..5550d73 100644
--- a/sqlmodels/README.md
+++ b/sqlmodels/README.md
@@ -954,18 +954,11 @@ class PolicyType(StrEnum):
S3 = "s3" # S3 兼容存储
```
-### StorageType
+### PolicyType
```python
-class StorageType(StrEnum):
+class PolicyType(StrEnum):
LOCAL = "local" # 本地存储
- QINIU = "qiniu" # 七牛云
- TENCENT = "tencent" # 腾讯云
- ALIYUN = "aliyun" # 阿里云
- ONEDRIVE = "onedrive" # OneDrive
- GOOGLE_DRIVE = "google_drive" # Google Drive
- DROPBOX = "dropbox" # Dropbox
- WEBDAV = "webdav" # WebDAV
- REMOTE = "remote" # 远程存储
+ S3 = "s3" # S3 兼容存储
```
### UserStatus
diff --git a/sqlmodels/__init__.py b/sqlmodels/__init__.py
index 97898f5..9feec90 100644
--- a/sqlmodels/__init__.py
+++ b/sqlmodels/__init__.py
@@ -82,6 +82,7 @@ from .object import (
ObjectPropertyResponse,
ObjectRenameRequest,
ObjectResponse,
+ ObjectSwitchPolicyRequest,
ObjectType,
PolicyResponse,
UploadChunkResponse,
@@ -100,7 +101,10 @@ from .object import (
from .physical_file import PhysicalFile, PhysicalFileBase
from .uri import DiskNextURI, FileSystemNamespace
from .order import Order, OrderStatus, OrderType
-from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary
+from .policy import (
+ Policy, PolicyBase, PolicyCreateRequest, PolicyOptions, PolicyOptionsBase,
+ PolicyType, PolicySummary, PolicyUpdateRequest,
+)
from .redeem import Redeem, RedeemType
from .report import Report, ReportReason
from .setting import (
@@ -116,7 +120,7 @@ from .share import (
from .source_link import SourceLink
from .storage_pack import StoragePack
from .tag import Tag, TagType
-from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
+from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary, TaskSummaryBase
from .webdav import (
WebDAV, WebDAVBase,
WebDAVCreateRequest, WebDAVUpdateRequest, WebDAVAccountResponse,
diff --git a/sqlmodels/group.py b/sqlmodels/group.py
index 8bea70d..3b7f9a6 100644
--- a/sqlmodels/group.py
+++ b/sqlmodels/group.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING
from uuid import UUID
+from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, text
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin
@@ -260,7 +261,7 @@ class Group(GroupBase, UUIDTableBaseMixin):
name: str = Field(max_length=255, unique=True)
"""用户组名"""
- max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
+ max_storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
"""最大存储空间(字节)"""
share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
diff --git a/sqlmodels/object.py b/sqlmodels/object.py
index e642e7c..130479d 100644
--- a/sqlmodels/object.py
+++ b/sqlmodels/object.py
@@ -9,6 +9,8 @@ from sqlmodel import Field, Relationship, CheckConstraint, Index, text
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
+from .policy import PolicyType
+
if TYPE_CHECKING:
from .user import User
from .policy import Policy
@@ -23,18 +25,6 @@ class ObjectType(StrEnum):
FILE = "file"
FOLDER = "folder"
-class StorageType(StrEnum):
- """存储类型枚举"""
- LOCAL = "local"
- QINIU = "qiniu"
- TENCENT = "tencent"
- ALIYUN = "aliyun"
- ONEDRIVE = "onedrive"
- GOOGLE_DRIVE = "google_drive"
- DROPBOX = "dropbox"
- WEBDAV = "webdav"
- REMOTE = "remote"
-
class FileMetadataBase(SQLModelBase):
"""文件元数据基础模型"""
@@ -156,7 +146,7 @@ class PolicyResponse(SQLModelBase):
name: str
"""策略名称"""
- type: StorageType
+ type: PolicyType
"""存储类型"""
max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
@@ -624,6 +614,12 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
storage_path: str | None = Field(default=None, max_length=512)
"""文件存储路径"""
+ s3_upload_id: str | None = Field(default=None, max_length=256)
+ """S3 Multipart Upload ID(仅 S3 策略使用)"""
+
+ s3_part_etags: str | None = None
+ """S3 已上传分片的 ETag 列表,JSON 格式 [[1,"etag1"],[2,"etag2"]](仅 S3 策略使用)"""
+
expires_at: datetime
"""会话过期时间"""
@@ -732,6 +728,16 @@ class CreateFileRequest(SQLModelBase):
"""存储策略UUID,不指定则使用父目录的策略"""
+class ObjectSwitchPolicyRequest(SQLModelBase):
+ """切换对象存储策略请求"""
+
+ policy_id: UUID
+ """目标存储策略UUID"""
+
+ is_migrate_existing: bool = False
+ """(仅目录)是否迁移已有文件,默认 false 只影响新文件"""
+
+
# ==================== 对象操作相关 DTO ====================
class ObjectCopyRequest(SQLModelBase):
diff --git a/sqlmodels/policy.py b/sqlmodels/policy.py
index c65953f..4d841a3 100644
--- a/sqlmodels/policy.py
+++ b/sqlmodels/policy.py
@@ -102,6 +102,94 @@ class PolicySummary(SQLModelBase):
"""是否私有"""
+class PolicyCreateRequest(PolicyBase):
+ """创建存储策略请求 DTO,包含 PolicyOptions 扁平字段"""
+
+ # PolicyOptions 字段(平铺到请求体中,与 GroupCreateRequest 模式一致)
+ token: str | None = None
+ """访问令牌"""
+
+ file_type: str | None = None
+ """允许的文件类型"""
+
+ mimetype: str | None = Field(default=None, max_length=127)
+ """MIME类型"""
+
+ od_redirect: str | None = Field(default=None, max_length=255)
+ """OneDrive重定向地址"""
+
+ chunk_size: int = Field(default=52428800, ge=1)
+ """分片上传大小(字节),默认50MB"""
+
+ s3_path_style: bool = False
+ """是否使用S3路径风格"""
+
+ s3_region: str = Field(default='us-east-1', max_length=64)
+ """S3 区域(如 us-east-1、ap-southeast-1),仅 S3 策略使用"""
+
+
+class PolicyUpdateRequest(SQLModelBase):
+ """更新存储策略请求 DTO(所有字段可选)"""
+
+ name: str | None = Field(default=None, max_length=255)
+ """策略名称"""
+
+ server: str | None = Field(default=None, max_length=255)
+ """服务器地址"""
+
+ bucket_name: str | None = Field(default=None, max_length=255)
+ """存储桶名称"""
+
+ is_private: bool | None = None
+ """是否为私有空间"""
+
+ base_url: str | None = Field(default=None, max_length=255)
+ """访问文件的基础URL"""
+
+ access_key: str | None = None
+ """Access Key"""
+
+ secret_key: str | None = None
+ """Secret Key"""
+
+ max_size: int | None = Field(default=None, ge=0)
+ """允许上传的最大文件尺寸(字节)"""
+
+ auto_rename: bool | None = None
+ """是否自动重命名"""
+
+ dir_name_rule: str | None = Field(default=None, max_length=255)
+ """目录命名规则"""
+
+ file_name_rule: str | None = Field(default=None, max_length=255)
+ """文件命名规则"""
+
+ is_origin_link_enable: bool | None = None
+ """是否开启源链接访问"""
+
+ # PolicyOptions 字段
+ token: str | None = None
+ """访问令牌"""
+
+ file_type: str | None = None
+ """允许的文件类型"""
+
+ mimetype: str | None = Field(default=None, max_length=127)
+ """MIME类型"""
+
+ od_redirect: str | None = Field(default=None, max_length=255)
+ """OneDrive重定向地址"""
+
+ chunk_size: int | None = Field(default=None, ge=1)
+ """分片上传大小(字节)"""
+
+ s3_path_style: bool | None = None
+ """是否使用S3路径风格"""
+
+ s3_region: str | None = Field(default=None, max_length=64)
+ """S3 区域"""
+
+
# ==================== 数据库模型 ====================
@@ -126,6 +214,9 @@ class PolicyOptionsBase(SQLModelBase):
s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否使用S3路径风格"""
+ s3_region: str = Field(default='us-east-1', max_length=64, sa_column_kwargs={"server_default": "'us-east-1'"})
+ """S3 区域(如 us-east-1、ap-southeast-1),仅 S3 策略使用"""
+
class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin):
"""存储策略选项模型(与Policy一对一关联)"""
diff --git a/sqlmodels/task.py b/sqlmodels/task.py
index 980c3f8..52f64f6 100644
--- a/sqlmodels/task.py
+++ b/sqlmodels/task.py
@@ -26,8 +26,8 @@ class TaskStatus(StrEnum):
class TaskType(StrEnum):
"""任务类型枚举"""
- # [TODO] 补充具体任务类型
- pass
+ POLICY_MIGRATE = "policy_migrate"
+ """存储策略迁移"""
# ==================== DTO 模型 ====================
@@ -39,7 +39,7 @@ class TaskSummaryBase(SQLModelBase):
id: int
"""任务ID"""
- type: int
+ type: TaskType
"""任务类型"""
status: TaskStatus
@@ -91,7 +91,14 @@ class TaskPropsBase(SQLModelBase):
file_ids: str | None = None
"""文件ID列表(逗号分隔)"""
- # [TODO] 根据业务需求补充更多字段
+ source_policy_id: UUID | None = None
+ """源存储策略UUID"""
+
+ dest_policy_id: UUID | None = None
+ """目标存储策略UUID"""
+
+ object_id: UUID | None = None
+ """关联的对象UUID"""
class TaskProps(TaskPropsBase, TableBaseMixin):
@@ -99,7 +106,7 @@ class TaskProps(TaskPropsBase, TableBaseMixin):
task_id: int = Field(
foreign_key="task.id",
- primary_key=True,
+ unique=True,
ondelete="CASCADE"
)
"""关联的任务ID"""
@@ -121,8 +128,8 @@ class Task(SQLModelBase, TableBaseMixin):
status: TaskStatus = Field(default=TaskStatus.QUEUED)
"""任务状态"""
- type: int = Field(default=0)
- """任务类型 [TODO] 待定义枚举"""
+ type: TaskType
+ """任务类型"""
progress: int = Field(default=0, ge=0, le=100)
"""任务进度(0-100)"""
diff --git a/sqlmodels/user.py b/sqlmodels/user.py
index 198f7f2..9693a93 100644
--- a/sqlmodels/user.py
+++ b/sqlmodels/user.py
@@ -4,7 +4,7 @@ from typing import Literal, TYPE_CHECKING, TypeVar
from uuid import UUID
from pydantic import BaseModel
-from sqlalchemy import BinaryExpression, ClauseElement, and_
+from sqlalchemy import BigInteger, BinaryExpression, ClauseElement, and_
from sqlmodel import Field, Relationship
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import RelationshipInfo
@@ -473,7 +473,7 @@ class User(UserBase, UUIDTableBaseMixin):
status: UserStatus = UserStatus.ACTIVE
"""用户状态"""
- storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
+ storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"}, ge=0)
"""已用存储空间(字节)"""
avatar: str = Field(default="default", max_length=255)