feat: add S3 storage support, policy migration, and quota enforcement

- Add S3StorageService with AWS Signature V4 signing (URI-encoded for non-ASCII keys)
- Add PATCH /object/{id}/policy endpoint for switching storage policies with background migration
- Implement cross-storage file migration service (local <-> S3)
- Replace deprecated StorageType enum with PolicyType (local/s3)
- Implement GET /user/settings/policies endpoint (was 501 stub)
- Add storage quota pre-allocation on upload session creation so concurrent upload sessions cannot bypass the quota check (see the sketch after the commit message)
- Fix BigInteger for max_storage and user.storage to support >2GB values
- Add policy permission validation on upload and directory creation
- Use group's first policy as default on registration instead of hardcoded name
- Define TaskType.POLICY_MIGRATE and extend TaskProps with migration fields

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
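
The quota pre-allocation bullet above deserves a sketch. This is a minimal, self-contained illustration of the pattern only — the in-memory lock stands in for whatever row lock or atomic UPDATE the real upload-session code uses, and every name except the `storage`/`max_storage` fields is hypothetical:

import asyncio
from dataclasses import dataclass

@dataclass
class User:
    storage: int      # bytes already used or reserved
    max_storage: int  # quota ceiling (BigInteger in the real schema)

class QuotaExceededError(Exception):
    pass

_quota_lock = asyncio.Lock()  # stand-in for a DB row lock / atomic UPDATE

async def reserve_quota(user: User, size: int) -> None:
    # Charge the quota when the upload session is created, not when the
    # upload finishes: two concurrent sessions then cannot both pass the
    # check before either one is charged.
    async with _quota_lock:
        if user.storage + size > user.max_storage:
            raise QuotaExceededError(
                f"need {size} bytes, {user.max_storage - user.storage} left"
            )
        user.storage += size  # pre-allocated; released if the session expires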
2026-02-23 13:38:20 +08:00
parent 7200df6d87
commit 3639a31163
19 changed files with 1728 additions and 124 deletions


@@ -3,6 +3,7 @@
Provides file-storage-related services, including:
- Local storage service
- S3 storage service
- Naming rule parser
- Storage exception definitions
"""
@@ -11,6 +12,8 @@ from .exceptions import (
FileReadError,
FileWriteError,
InvalidPathError,
S3APIError,
S3MultipartUploadError,
StorageException,
StorageFileNotFoundError,
UploadSessionExpiredError,
@@ -25,4 +28,6 @@ from .object import (
permanently_delete_objects,
restore_objects,
soft_delete_objects,
)
from .migrate import migrate_file_with_task, migrate_directory_files
from .s3_storage import S3StorageService


@@ -43,3 +43,13 @@ class UploadSessionExpiredError(StorageException):
class InvalidPathError(StorageException):
"""无效的路径"""
pass
class S3APIError(StorageException):
"""S3 API 请求错误"""
pass
class S3MultipartUploadError(S3APIError):
"""S3 分片上传错误"""
pass

service/storage/migrate.py (new file)

@@ -0,0 +1,291 @@
"""
存储策略迁移服务
提供跨存储策略的文件迁移功能:
- 单文件迁移:从源策略下载 → 上传到目标策略 → 更新数据库记录
- 目录批量迁移:递归遍历目录下所有文件逐个迁移,同时更新子目录的 policy_id
"""
from uuid import UUID
from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.object import Object, ObjectType
from sqlmodels.physical_file import PhysicalFile
from sqlmodels.policy import Policy, PolicyType
from sqlmodels.task import Task, TaskStatus
from .local_storage import LocalStorageService
from .s3_storage import S3StorageService
async def _get_storage_service(
policy: Policy,
) -> LocalStorageService | S3StorageService:
"""
根据策略类型创建对应的存储服务实例
:param policy: 存储策略
:return: 存储服务实例
"""
if policy.type == PolicyType.LOCAL:
return LocalStorageService(policy)
elif policy.type == PolicyType.S3:
return await S3StorageService.from_policy(policy)
else:
raise ValueError(f"不支持的存储策略类型: {policy.type}")
async def _read_file_from_storage(
service: LocalStorageService | S3StorageService,
storage_path: str,
) -> bytes:
"""
从存储服务读取文件内容
:param service: 存储服务实例
:param storage_path: 文件存储路径
:return: 文件二进制内容
"""
if isinstance(service, LocalStorageService):
return await service.read_file(storage_path)
else:
return await service.download_file(storage_path)
async def _write_file_to_storage(
service: LocalStorageService | S3StorageService,
storage_path: str,
data: bytes,
) -> None:
"""
将文件内容写入存储服务
:param service: 存储服务实例
:param storage_path: 文件存储路径
:param data: 文件二进制内容
"""
if isinstance(service, LocalStorageService):
await service.write_file(storage_path, data)
else:
await service.upload_file(storage_path, data)
async def _delete_file_from_storage(
service: LocalStorageService | S3StorageService,
storage_path: str,
) -> None:
"""
从存储服务删除文件
:param service: 存储服务实例
:param storage_path: 文件存储路径
"""
if isinstance(service, LocalStorageService):
await service.delete_file(storage_path)
else:
await service.delete_file(storage_path)
async def migrate_single_file(
session: AsyncSession,
obj: Object,
dest_policy: Policy,
) -> None:
"""
将单个文件对象从当前存储策略迁移到目标策略
流程:
1. 获取源物理文件和存储服务
2. 读取源文件内容
3. 在目标存储中生成新路径并写入
4. 创建新的 PhysicalFile 记录
5. 更新 Object 的 policy_id 和 physical_file_id
6. 旧 PhysicalFile 引用计数 -1如为 0 则删除源物理文件
:param session: 数据库会话
:param obj: 待迁移的文件对象(必须为文件类型)
:param dest_policy: 目标存储策略
"""
if obj.type != ObjectType.FILE:
raise ValueError(f"只能迁移文件对象,当前类型: {obj.type}")
# Fetch the source policy and physical file
src_policy: Policy = await obj.awaitable_attrs.policy
old_physical: PhysicalFile | None = await obj.awaitable_attrs.physical_file
if not old_physical:
l.warning(f"文件 {obj.id} 没有关联物理文件,跳过迁移")
return
if src_policy.id == dest_policy.id:
l.debug(f"文件 {obj.id} 已在目标策略中,跳过")
return
# 1. Read the file from source storage
src_service = await _get_storage_service(src_policy)
data = await _read_file_from_storage(src_service, old_physical.storage_path)
# 2. Generate a new path in target storage and write the data
dest_service = await _get_storage_service(dest_policy)
_dir_path, _storage_name, new_storage_path = await dest_service.generate_file_path(
user_id=obj.owner_id,
original_filename=obj.name,
)
await _write_file_to_storage(dest_service, new_storage_path, data)
# 3. Create the new PhysicalFile
new_physical = PhysicalFile(
storage_path=new_storage_path,
size=old_physical.size,
checksum_md5=old_physical.checksum_md5,
policy_id=dest_policy.id,
reference_count=1,
)
new_physical = await new_physical.save(session)
# 4. Update the Object
obj.policy_id = dest_policy.id
obj.physical_file_id = new_physical.id
await obj.save(session)
# 5. Decrement the old PhysicalFile's reference count
old_physical.decrement_reference()
if old_physical.can_be_deleted:
# Delete the physical file from source storage
try:
await _delete_file_from_storage(src_service, old_physical.storage_path)
except Exception as e:
l.warning(f"删除源文件失败(不影响迁移结果): {old_physical.storage_path}: {e}")
await PhysicalFile.delete(session, old_physical)
else:
await old_physical.save(session)
l.info(f"文件迁移完成: {obj.name} ({obj.id}), {src_policy.name}{dest_policy.name}")
async def migrate_file_with_task(
session: AsyncSession,
obj: Object,
dest_policy: Policy,
task: Task,
) -> None:
"""
迁移单个文件并更新任务状态
:param session: 数据库会话
:param obj: 待迁移的文件对象
:param dest_policy: 目标存储策略
:param task: 关联的任务记录
"""
try:
task.status = TaskStatus.RUNNING
task.progress = 0
task = await task.save(session)
await migrate_single_file(session, obj, dest_policy)
task.status = TaskStatus.COMPLETED
task.progress = 100
await task.save(session)
except Exception as e:
l.error(f"文件迁移任务失败: {obj.id}: {e}")
task.status = TaskStatus.ERROR
task.error = str(e)[:500]
await task.save(session)
async def migrate_directory_files(
session: AsyncSession,
folder: Object,
dest_policy: Policy,
task: Task,
) -> None:
"""
迁移目录下所有文件到目标存储策略
递归遍历目录树,将所有文件迁移到目标策略。
子目录的 policy_id 同步更新。
任务进度按文件数比例更新。
:param session: 数据库会话
:param folder: 目录对象
:param dest_policy: 目标存储策略
:param task: 关联的任务记录
"""
try:
task.status = TaskStatus.RUNNING
task.progress = 0
task = await task.save(session)
# Collect all files that need migration
files_to_migrate: list[Object] = []
folders_to_update: list[Object] = []
await _collect_objects_recursive(session, folder, files_to_migrate, folders_to_update)
total = len(files_to_migrate)
migrated = 0
errors: list[str] = []
for file_obj in files_to_migrate:
try:
await migrate_single_file(session, file_obj, dest_policy)
migrated += 1
except Exception as e:
error_msg = f"{file_obj.name}: {e}"
l.error(f"迁移文件失败: {error_msg}")
errors.append(error_msg)
# Update progress
if total > 0:
task.progress = min(99, int(migrated / total * 100))
task = await task.save(session)
# Update policy_id on all subdirectories
for sub_folder in folders_to_update:
sub_folder.policy_id = dest_policy.id
await sub_folder.save(session)
# Finalize the task
if errors:
task.status = TaskStatus.ERROR
task.error = f"部分文件迁移失败 ({len(errors)}/{total}): " + "; ".join(errors[:5])
else:
task.status = TaskStatus.COMPLETED
task.progress = 100
await task.save(session)
l.info(
f"目录迁移完成: {folder.name} ({folder.id}), "
f"成功 {migrated}/{total}, 错误 {len(errors)}"
)
except Exception as e:
l.error(f"目录迁移任务失败: {folder.id}: {e}")
task.status = TaskStatus.ERROR
task.error = str(e)[:500]
await task.save(session)
async def _collect_objects_recursive(
session: AsyncSession,
folder: Object,
files: list[Object],
folders: list[Object],
) -> None:
"""
递归收集目录下所有文件和子目录
:param session: 数据库会话
:param folder: 当前目录
:param files: 文件列表(输出)
:param folders: 子目录列表(输出)
"""
children: list[Object] = await Object.get_children(session, folder.owner_id, folder.id)
for child in children:
if child.type == ObjectType.FILE:
files.append(child)
elif child.type == ObjectType.FOLDER:
folders.append(child)
await _collect_objects_recursive(session, child, files, folders)
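
For context, a hypothetical caller sketch — not part of this diff — of how the PATCH /object/{id}/policy endpoint might hand work to these functions. Everything here except migrate_file_with_task, Task, and TaskStatus is an assumption; in particular, the real endpoint may schedule tasks differently, and a background task would normally open its own DB session rather than reuse the request's:

import asyncio

async def start_policy_migration(session, obj, dest_policy) -> Task:
    # Create a task row the client can poll for progress/errors.
    task = Task(status=TaskStatus.RUNNING, progress=0)
    task = await task.save(session)
    # Fire-and-forget; migrate_file_with_task updates the task row itself.
    asyncio.create_task(migrate_file_with_task(session, obj, dest_policy, task))
    return task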


@@ -6,7 +6,8 @@ from sqlalchemy import update as sql_update
from sqlalchemy.sql.functions import func
from middleware.dependencies import SessionDep
from service.storage import LocalStorageService
from .local_storage import LocalStorageService
from .s3_storage import S3StorageService
from sqlmodels import (
Object,
PhysicalFile,
@@ -271,10 +272,14 @@ async def permanently_delete_objects(
if physical_file.can_be_deleted:
# Physically delete the file
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
if policy and policy.type == PolicyType.LOCAL:
if policy:
try:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
await s3_service.delete_file(physical_file.storage_path)
l.debug(f"物理文件已删除: {obj_name}")
except Exception as e:
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
@@ -374,10 +379,19 @@ async def delete_object_recursive(
if physical_file.can_be_deleted:
# Physically delete the file
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
if policy and policy.type == PolicyType.LOCAL:
if policy:
try:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
elif policy.type == PolicyType.S3:
options = await policy.awaitable_attrs.options
s3_service = S3StorageService(
policy,
region=options.s3_region if options else 'us-east-1',
is_path_style=options.s3_path_style if options else False,
)
await s3_service.delete_file(physical_file.storage_path)
l.debug(f"物理文件已删除: {obj_name}")
except Exception as e:
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")


@@ -0,0 +1,709 @@
"""
S3 存储服务
使用 AWS Signature V4 签名的异步 S3 API 客户端。
从 Policy 配置中读取 S3 连接信息,提供文件上传/下载/删除及分片上传功能。
移植自 foxline-pro-backend-server 项目的 S3APIClient
适配 DiskNext 现有的 Service 架构(与 LocalStorageService 平行)。
"""
import hashlib
import hmac
import xml.etree.ElementTree as ET
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from typing import ClassVar, Literal
from urllib.parse import quote, urlencode
from uuid import UUID
import aiohttp
from yarl import URL
from loguru import logger as l
from sqlmodels.policy import Policy
from .exceptions import S3APIError, S3MultipartUploadError
from .naming_rule import NamingContext, NamingRuleParser
def _sign(key: bytes, msg: str) -> bytes:
"""HMAC-SHA256 签名"""
return hmac.new(key, msg.encode(), hashlib.sha256).digest()
_NS_AWS = "http://s3.amazonaws.com/doc/2006-03-01/"
class S3StorageService:
"""
S3 storage service
An async S3 API client using AWS Signature V4 signing.
Reads S3 connection settings from the Policy configuration.
Usage example::
service = S3StorageService(policy, region='us-east-1')
await service.upload_file('path/to/file.txt', b'content')
data = await service.download_file('path/to/file.txt')
"""
_http_session: ClassVar[aiohttp.ClientSession | None] = None
def __init__(
self,
policy: Policy,
region: str = 'us-east-1',
is_path_style: bool = False,
):
"""
:param policy: 存储策略server=endpoint_url, bucket_name, access_key, secret_key
:param region: S3 区域
:param is_path_style: 是否使用路径风格 URL
"""
if not policy.server:
raise S3APIError("S3 policy must specify server (endpoint URL)")
if not policy.bucket_name:
raise S3APIError("S3 policy must specify bucket_name")
if not policy.access_key:
raise S3APIError("S3 policy must specify access_key")
if not policy.secret_key:
raise S3APIError("S3 policy must specify secret_key")
self._policy = policy
self._endpoint_url = policy.server.rstrip("/")
self._bucket_name = policy.bucket_name
self._access_key = policy.access_key
self._secret_key = policy.secret_key
self._region = region
self._is_path_style = is_path_style
self._base_url = policy.base_url
# Extract the host from endpoint_url
self._host = self._endpoint_url.replace("https://", "").replace("http://", "").split("/")[0]
# ==================== Factory methods ====================
@classmethod
async def from_policy(cls, policy: Policy) -> 'S3StorageService':
"""
根据 Policy 异步创建 S3StorageService自动加载 options
:param policy: 存储策略
:return: S3StorageService 实例
"""
options = await policy.awaitable_attrs.options
region = options.s3_region if options else 'us-east-1'
is_path_style = options.s3_path_style if options else False
return cls(policy, region=region, is_path_style=is_path_style)
# ==================== HTTP session management ====================
@classmethod
async def initialize_session(cls) -> None:
"""初始化全局 aiohttp ClientSession"""
if cls._http_session is None or cls._http_session.closed:
cls._http_session = aiohttp.ClientSession()
l.info("S3StorageService HTTP session 已初始化")
@classmethod
async def close_session(cls) -> None:
"""关闭全局 aiohttp ClientSession"""
if cls._http_session and not cls._http_session.closed:
await cls._http_session.close()
cls._http_session = None
l.info("S3StorageService HTTP session 已关闭")
@classmethod
def _get_session(cls) -> aiohttp.ClientSession:
"""获取 HTTP session"""
if cls._http_session is None or cls._http_session.closed:
# 懒初始化,以防 initialize_session 未被调用
cls._http_session = aiohttp.ClientSession()
return cls._http_session
# ==================== AWS Signature V4 signing ====================
def _get_signature_key(self, date_stamp: str) -> bytes:
"""生成 AWS Signature V4 签名密钥"""
k_date = _sign(f"AWS4{self._secret_key}".encode(), date_stamp)
k_region = _sign(k_date, self._region)
k_service = _sign(k_region, "s3")
return _sign(k_service, "aws4_request")
def _create_authorization_header(
self,
method: str,
uri: str,
query_string: str,
headers: dict[str, str],
payload_hash: str,
amz_date: str,
date_stamp: str,
) -> str:
"""创建 AWS Signature V4 授权头"""
signed_headers = ";".join(sorted(k.lower() for k in headers.keys()))
canonical_headers = "".join(
f"{k.lower()}:{v.strip()}\n" for k, v in sorted(headers.items())
)
canonical_request = (
f"{method}\n{uri}\n{query_string}\n{canonical_headers}\n"
f"{signed_headers}\n{payload_hash}"
)
algorithm = "AWS4-HMAC-SHA256"
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
string_to_sign = (
f"{algorithm}\n{amz_date}\n{credential_scope}\n"
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
)
signing_key = self._get_signature_key(date_stamp)
signature = hmac.new(
signing_key, string_to_sign.encode(), hashlib.sha256
).hexdigest()
return (
f"{algorithm} Credential={self._access_key}/{credential_scope}, "
f"SignedHeaders={signed_headers}, Signature={signature}"
)
def _build_headers(
self,
method: str,
uri: str,
query_string: str = "",
payload: bytes = b"",
content_type: str | None = None,
extra_headers: dict[str, str] | None = None,
payload_hash: str | None = None,
host: str | None = None,
) -> dict[str, str]:
"""
构建包含 AWS V4 签名的完整请求头
:param method: HTTP 方法
:param uri: 请求 URI
:param query_string: 查询字符串
:param payload: 请求体字节(用于计算哈希)
:param content_type: Content-Type
:param extra_headers: 额外请求头
:param payload_hash: 预计算的 payload 哈希,流式上传时传 "UNSIGNED-PAYLOAD"
:param host: Host 头(默认使用 self._host
"""
now_utc = datetime.now(timezone.utc)
amz_date = now_utc.strftime("%Y%m%dT%H%M%SZ")
date_stamp = now_utc.strftime("%Y%m%d")
if payload_hash is None:
payload_hash = hashlib.sha256(payload).hexdigest()
effective_host = host or self._host
headers: dict[str, str] = {
"Host": effective_host,
"X-Amz-Date": amz_date,
"X-Amz-Content-Sha256": payload_hash,
}
if content_type:
headers["Content-Type"] = content_type
if extra_headers:
headers.update(extra_headers)
authorization = self._create_authorization_header(
method, uri, query_string, headers, payload_hash, amz_date, date_stamp
)
headers["Authorization"] = authorization
return headers
# ==================== Internal request methods ====================
def _build_uri(self, key: str | None = None) -> str:
"""
构建请求 URI
按 AWS S3 Signature V4 规范对路径进行 URI 编码S3 仅需一次)。
斜杠作为路径分隔符保留不编码。
"""
if self._is_path_style:
if key:
return f"/{self._bucket_name}/{quote(key, safe='/')}"
return f"/{self._bucket_name}"
else:
if key:
return f"/{quote(key, safe='/')}"
return "/"
def _build_url(self, uri: str, query_string: str = "") -> str:
"""构建完整请求 URL"""
if self._is_path_style:
base = self._endpoint_url
else:
# Virtual-hosted style: bucket.endpoint
protocol = "https://" if self._endpoint_url.startswith("https://") else "http://"
base = f"{protocol}{self._bucket_name}.{self._host}"
url = f"{base}{uri}"
if query_string:
url = f"{url}?{query_string}"
return url
def _get_effective_host(self) -> str:
"""获取实际请求的 Host 头"""
if self._is_path_style:
return self._host
return f"{self._bucket_name}.{self._host}"
async def _request(
self,
method: str,
key: str | None = None,
query_params: dict[str, str] | None = None,
payload: bytes = b"",
content_type: str | None = None,
extra_headers: dict[str, str] | None = None,
) -> aiohttp.ClientResponse:
"""发送签名请求"""
uri = self._build_uri(key)
query_string = urlencode(sorted(query_params.items())) if query_params else ""
effective_host = self._get_effective_host()
headers = self._build_headers(
method, uri, query_string, payload, content_type,
extra_headers, host=effective_host,
)
url = self._build_url(uri, query_string)
try:
response = await self._get_session().request(
method, URL(url, encoded=True),
headers=headers, data=payload if payload else None,
)
return response
except Exception as e:
raise S3APIError(f"S3 请求失败: {method} {url}: {e}") from e
async def _request_streaming(
self,
method: str,
key: str,
data_stream: AsyncIterator[bytes],
content_length: int,
content_type: str | None = None,
) -> aiohttp.ClientResponse:
"""
发送流式签名请求(大文件上传)
使用 UNSIGNED-PAYLOAD 作为 payload hash。
"""
uri = self._build_uri(key)
effective_host = self._get_effective_host()
headers = self._build_headers(
method,
uri,
query_string="",
content_type=content_type,
extra_headers={"Content-Length": str(content_length)},
payload_hash="UNSIGNED-PAYLOAD",
host=effective_host,
)
url = self._build_url(uri)
try:
response = await self._get_session().request(
method, URL(url, encoded=True),
headers=headers, data=data_stream,
)
return response
except Exception as e:
raise S3APIError(f"S3 流式请求失败: {method} {url}: {e}") from e
# ==================== File operations ====================
async def upload_file(
self,
key: str,
data: bytes,
content_type: str = 'application/octet-stream',
) -> None:
"""
上传文件
:param key: S3 对象键
:param data: 文件内容
:param content_type: MIME 类型
"""
async with await self._request(
"PUT", key=key, payload=data, content_type=content_type,
) as response:
if response.status not in (200, 201):
body = await response.text()
raise S3APIError(
f"上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
l.debug(f"S3 上传成功: {self._bucket_name}/{key}")
async def upload_file_streaming(
self,
key: str,
data_stream: AsyncIterator[bytes],
content_length: int,
content_type: str | None = None,
) -> None:
"""
流式上传文件(大文件,避免全部加载到内存)
:param key: S3 对象键
:param data_stream: 异步字节流迭代器
:param content_length: 数据总长度(必须准确)
:param content_type: MIME 类型
"""
async with await self._request_streaming(
"PUT", key=key, data_stream=data_stream,
content_length=content_length, content_type=content_type,
) as response:
if response.status not in (200, 201):
body = await response.text()
raise S3APIError(
f"流式上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
l.debug(f"S3 流式上传成功: {self._bucket_name}/{key}, 大小: {content_length}")
async def download_file(self, key: str) -> bytes:
"""
下载文件
:param key: S3 对象键
:return: 文件内容
"""
async with await self._request("GET", key=key) as response:
if response.status != 200:
body = await response.text()
raise S3APIError(
f"下载失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
data = await response.read()
l.debug(f"S3 下载成功: {self._bucket_name}/{key}, 大小: {len(data)}")
return data
async def delete_file(self, key: str) -> None:
"""
删除文件
:param key: S3 对象键
"""
async with await self._request("DELETE", key=key) as response:
if response.status in (200, 204):
l.debug(f"S3 删除成功: {self._bucket_name}/{key}")
else:
body = await response.text()
raise S3APIError(
f"删除失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
async def file_exists(self, key: str) -> bool:
"""
检查文件是否存在
:param key: S3 对象键
:return: 是否存在
"""
async with await self._request("HEAD", key=key) as response:
if response.status == 200:
return True
elif response.status == 404:
return False
else:
raise S3APIError(
f"检查文件存在性失败: {self._bucket_name}/{key}, 状态: {response.status}"
)
async def get_file_size(self, key: str) -> int:
"""
获取文件大小
:param key: S3 对象键
:return: 文件大小(字节)
"""
async with await self._request("HEAD", key=key) as response:
if response.status != 200:
raise S3APIError(
f"获取文件信息失败: {self._bucket_name}/{key}, 状态: {response.status}"
)
return int(response.headers.get("Content-Length", 0))
# ==================== Multipart Upload ====================
async def create_multipart_upload(
self,
key: str,
content_type: str = 'application/octet-stream',
) -> str:
"""
创建分片上传任务
:param key: S3 对象键
:param content_type: MIME 类型
:return: Upload ID
"""
async with await self._request(
"POST",
key=key,
query_params={"uploads": ""},
content_type=content_type,
) as response:
if response.status != 200:
body = await response.text()
raise S3MultipartUploadError(
f"创建分片上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
body = await response.text()
root = ET.fromstring(body)
# Locate the UploadId element (namespace-aware)
upload_id_elem = root.find("UploadId")
if upload_id_elem is None:
upload_id_elem = root.find(f"{{{_NS_AWS}}}UploadId")
if upload_id_elem is None or not upload_id_elem.text:
raise S3MultipartUploadError(
f"创建分片上传响应中未找到 UploadId: {body}"
)
upload_id = upload_id_elem.text
l.debug(f"S3 分片上传已创建: {self._bucket_name}/{key}, upload_id={upload_id}")
return upload_id
async def upload_part(
self,
key: str,
upload_id: str,
part_number: int,
data: bytes,
) -> str:
"""
上传单个分片
:param key: S3 对象键
:param upload_id: 分片上传 ID
:param part_number: 分片编号(从 1 开始)
:param data: 分片数据
:return: ETag
"""
async with await self._request(
"PUT",
key=key,
query_params={
"partNumber": str(part_number),
"uploadId": upload_id,
},
payload=data,
) as response:
if response.status != 200:
body = await response.text()
raise S3MultipartUploadError(
f"上传分片失败: {self._bucket_name}/{key}, "
f"part={part_number}, 状态: {response.status}, {body}"
)
etag = response.headers.get("ETag", "").strip('"')
l.debug(
f"S3 分片上传成功: {self._bucket_name}/{key}, "
f"part={part_number}, etag={etag}"
)
return etag
async def complete_multipart_upload(
self,
key: str,
upload_id: str,
parts: list[tuple[int, str]],
) -> None:
"""
完成分片上传
:param key: S3 对象键
:param upload_id: 分片上传 ID
:param parts: 分片列表 [(part_number, etag)]
"""
# Sort by part_number
parts_sorted = sorted(parts, key=lambda p: p[0])
# Build the CompleteMultipartUpload XML
xml_parts = ''.join(
f"<Part><PartNumber>{pn}</PartNumber><ETag>{etag}</ETag></Part>"
for pn, etag in parts_sorted
)
payload = f'<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUpload>{xml_parts}</CompleteMultipartUpload>'
payload_bytes = payload.encode('utf-8')
async with await self._request(
"POST",
key=key,
query_params={"uploadId": upload_id},
payload=payload_bytes,
content_type="application/xml",
) as response:
if response.status != 200:
body = await response.text()
raise S3MultipartUploadError(
f"完成分片上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
l.info(
f"S3 分片上传已完成: {self._bucket_name}/{key}, "
f"{len(parts)} 个分片"
)
async def abort_multipart_upload(self, key: str, upload_id: str) -> None:
"""
取消分片上传
:param key: S3 对象键
:param upload_id: 分片上传 ID
"""
async with await self._request(
"DELETE",
key=key,
query_params={"uploadId": upload_id},
) as response:
if response.status in (200, 204):
l.debug(f"S3 分片上传已取消: {self._bucket_name}/{key}")
else:
body = await response.text()
l.warning(
f"取消分片上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
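# Putting the multipart methods together — a minimal caller sketch
# (illustrative only; the chunk source and sizes are assumptions, though
# S3 itself requires every part except the last to be at least 5 MiB):
#
#   upload_id = await service.create_multipart_upload(key)
#   try:
#       parts: list[tuple[int, str]] = []
#       for number, chunk in enumerate(chunks, start=1):
#           etag = await service.upload_part(key, upload_id, number, chunk)
#           parts.append((number, etag))
#       await service.complete_multipart_upload(key, upload_id, parts)
#   except Exception:
#       await service.abort_multipart_upload(key, upload_id)
#       raise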
# ==================== Presigned URLs ====================
def generate_presigned_url(
self,
key: str,
method: Literal['GET', 'PUT'] = 'GET',
expires_in: int = 3600,
filename: str | None = None,
) -> str:
"""
生成 S3 预签名 URLAWS Signature V4 Query String
:param key: S3 对象键
:param method: HTTP 方法GET 下载PUT 上传)
:param expires_in: URL 有效期(秒)
:param filename: 文件名GET 请求时设置 Content-Disposition
:return: 预签名 URL
"""
current_time = datetime.now(timezone.utc)
amz_date = current_time.strftime("%Y%m%dT%H%M%SZ")
date_stamp = current_time.strftime("%Y%m%d")
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
credential = f"{self._access_key}/{credential_scope}"
uri = self._build_uri(key)
effective_host = self._get_effective_host()
query_params: dict[str, str] = {
'X-Amz-Algorithm': 'AWS4-HMAC-SHA256',
'X-Amz-Credential': credential,
'X-Amz-Date': amz_date,
'X-Amz-Expires': str(expires_in),
'X-Amz-SignedHeaders': 'host',
}
# Add Content-Disposition for GET requests
if method == "GET" and filename:
encoded_filename = quote(filename, safe='')
query_params['response-content-disposition'] = (
f"attachment; filename*=UTF-8''{encoded_filename}"
)
canonical_query_string = "&".join(
f"{quote(k, safe='')}={quote(v, safe='')}"
for k, v in sorted(query_params.items())
)
canonical_headers = f"host:{effective_host}\n"
signed_headers = "host"
payload_hash = "UNSIGNED-PAYLOAD"
canonical_request = (
f"{method}\n"
f"{uri}\n"
f"{canonical_query_string}\n"
f"{canonical_headers}\n"
f"{signed_headers}\n"
f"{payload_hash}"
)
algorithm = "AWS4-HMAC-SHA256"
string_to_sign = (
f"{algorithm}\n"
f"{amz_date}\n"
f"{credential_scope}\n"
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
)
signing_key = self._get_signature_key(date_stamp)
signature = hmac.new(
signing_key, string_to_sign.encode(), hashlib.sha256
).hexdigest()
base_url = self._build_url(uri)
return (
f"{base_url}?"
f"{canonical_query_string}&"
f"X-Amz-Signature={signature}"
)
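# Usage sketch (illustrative): a one-hour download link that forces the
# browser to save under the original name:
#
#   url = service.generate_presigned_url(
#       key, method='GET', expires_in=3600, filename='report.pdf',
#   )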
# ==================== Path generation ====================
async def generate_file_path(
self,
user_id: UUID,
original_filename: str,
) -> tuple[str, str, str]:
"""
根据命名规则生成 S3 文件存储路径
与 LocalStorageService.generate_file_path 接口一致。
:param user_id: 用户UUID
:param original_filename: 原始文件名
:return: (相对目录路径, 存储文件名, 完整存储路径)
"""
context = NamingContext(
user_id=user_id,
original_filename=original_filename,
)
# Parse the directory rule
dir_path = ""
if self._policy.dir_name_rule:
dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context)
# Parse the filename rule
if self._policy.auto_rename and self._policy.file_name_rule:
storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context)
# Ensure an extension is present
if '.' in original_filename and '.' not in storage_name:
ext = original_filename.rsplit('.', 1)[1]
storage_name = f"{storage_name}.{ext}"
else:
storage_name = original_filename
# S3 has no directories to create; just join the path
if dir_path:
storage_path = f"{dir_path}/{storage_name}"
else:
storage_path = storage_name
return dir_path, storage_name, storage_path
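
A minimal end-to-end sketch tying the two halves of this service together (illustrative; the user id and file contents are placeholders, and `service` is assumed to come from S3StorageService.from_policy):

from uuid import uuid4

async def upload_example(service: S3StorageService) -> str:
    # Let the policy's naming rules choose the key, then push the bytes.
    _dir, _name, key = await service.generate_file_path(
        user_id=uuid4(),
        original_filename="example.txt",
    )
    await service.upload_file(key, b"hello", content_type="text/plain")
    return key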