feat: add S3 storage support, policy migration, and quota enforcement
Some checks failed
Test / test (push) Failing after 2m21s
- Add S3StorageService with AWS Signature V4 signing (URI-encoded for non-ASCII keys)
- Add PATCH /object/{id}/policy endpoint for switching storage policies with background migration
- Implement cross-storage file migration service (local <-> S3)
- Replace deprecated StorageType enum with PolicyType (local/s3)
- Implement GET /user/settings/policies endpoint (was 501 stub)
- Add storage quota pre-allocation on upload session creation to prevent concurrent bypass
- Fix BigInteger for max_storage and user.storage to support >2GB values
- Add policy permission validation on upload and directory creation
- Use group's first policy as default on registration instead of hardcoded name
- Define TaskType.POLICY_MIGRATE and extend TaskProps with migration fields
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
提供文件存储相关的服务,包括:
|
||||
- 本地存储服务
|
||||
- S3 存储服务
|
||||
- 命名规则解析器
|
||||
- 存储异常定义
|
||||
"""
|
||||
@@ -11,6 +12,8 @@ from .exceptions import (
|
||||
FileReadError,
|
||||
FileWriteError,
|
||||
InvalidPathError,
|
||||
S3APIError,
|
||||
S3MultipartUploadError,
|
||||
StorageException,
|
||||
StorageFileNotFoundError,
|
||||
UploadSessionExpiredError,
|
||||
@@ -25,4 +28,6 @@ from .object import (
|
||||
permanently_delete_objects,
|
||||
restore_objects,
|
||||
soft_delete_objects,
|
||||
)
|
||||
)
|
||||
from .migrate import migrate_file_with_task, migrate_directory_files
|
||||
from .s3_storage import S3StorageService
|
||||
@@ -43,3 +43,13 @@ class UploadSessionExpiredError(StorageException):
|
||||
class InvalidPathError(StorageException):
    """Raised when a storage path is malformed or outside the allowed root."""
|
||||
|
||||
|
||||
class S3APIError(StorageException):
    """Raised when an S3 API request fails (non-success status or transport error)."""
|
||||
|
||||
|
||||
class S3MultipartUploadError(S3APIError):
    """Raised when an S3 multipart-upload step (create/part/complete) fails."""
|
||||
|
||||
291
service/storage/migrate.py
Normal file
291
service/storage/migrate.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
存储策略迁移服务
|
||||
|
||||
提供跨存储策略的文件迁移功能:
|
||||
- 单文件迁移:从源策略下载 → 上传到目标策略 → 更新数据库记录
|
||||
- 目录批量迁移:递归遍历目录下所有文件逐个迁移,同时更新子目录的 policy_id
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.physical_file import PhysicalFile
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
from sqlmodels.task import Task, TaskStatus
|
||||
|
||||
from .local_storage import LocalStorageService
|
||||
from .s3_storage import S3StorageService
|
||||
|
||||
|
||||
async def _get_storage_service(
|
||||
policy: Policy,
|
||||
) -> LocalStorageService | S3StorageService:
|
||||
"""
|
||||
根据策略类型创建对应的存储服务实例
|
||||
|
||||
:param policy: 存储策略
|
||||
:return: 存储服务实例
|
||||
"""
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
return LocalStorageService(policy)
|
||||
elif policy.type == PolicyType.S3:
|
||||
return await S3StorageService.from_policy(policy)
|
||||
else:
|
||||
raise ValueError(f"不支持的存储策略类型: {policy.type}")
|
||||
|
||||
|
||||
async def _read_file_from_storage(
    service: LocalStorageService | S3StorageService,
    storage_path: str,
) -> bytes:
    """
    Read a file's full content through the given storage service.

    Local storage exposes ``read_file``; S3 exposes ``download_file`` —
    pick the matching coroutine and await it.

    :param service: storage service instance
    :param storage_path: path of the file inside the storage backend
    :return: file content as bytes
    """
    reader = (
        service.read_file
        if isinstance(service, LocalStorageService)
        else service.download_file
    )
    return await reader(storage_path)
|
||||
|
||||
|
||||
async def _write_file_to_storage(
    service: LocalStorageService | S3StorageService,
    storage_path: str,
    data: bytes,
) -> None:
    """
    Write file content through the given storage service.

    Local storage exposes ``write_file``; S3 exposes ``upload_file`` —
    pick the matching coroutine and await it.

    :param service: storage service instance
    :param storage_path: path of the file inside the storage backend
    :param data: file content as bytes
    """
    writer = (
        service.write_file
        if isinstance(service, LocalStorageService)
        else service.upload_file
    )
    await writer(storage_path, data)
|
||||
|
||||
|
||||
async def _delete_file_from_storage(
|
||||
service: LocalStorageService | S3StorageService,
|
||||
storage_path: str,
|
||||
) -> None:
|
||||
"""
|
||||
从存储服务删除文件
|
||||
|
||||
:param service: 存储服务实例
|
||||
:param storage_path: 文件存储路径
|
||||
"""
|
||||
if isinstance(service, LocalStorageService):
|
||||
await service.delete_file(storage_path)
|
||||
else:
|
||||
await service.delete_file(storage_path)
|
||||
|
||||
|
||||
async def migrate_single_file(
    session: AsyncSession,
    obj: Object,
    dest_policy: Policy,
) -> None:
    """
    Migrate one file object from its current storage policy to *dest_policy*.

    Steps:
      1. resolve the source physical file and its storage service
      2. read the source content
      3. generate a path in the destination storage and write the content
      4. create a fresh PhysicalFile record for the copy
      5. repoint the Object's policy_id / physical_file_id
      6. decrement the old PhysicalFile's refcount; if it drops to zero,
         best-effort delete the source bytes and remove the record

    :param session: database session
    :param obj: file object to migrate (must be of FILE type)
    :param dest_policy: destination storage policy
    :raises ValueError: if *obj* is not a file
    """
    if obj.type != ObjectType.FILE:
        raise ValueError(f"只能迁移文件对象,当前类型: {obj.type}")

    # Resolve source policy and the physical file backing this object.
    source_policy: Policy = await obj.awaitable_attrs.policy
    old_physical: PhysicalFile | None = await obj.awaitable_attrs.physical_file

    if not old_physical:
        l.warning(f"文件 {obj.id} 没有关联物理文件,跳过迁移")
        return

    if source_policy.id == dest_policy.id:
        l.debug(f"文件 {obj.id} 已在目标策略中,跳过")
        return

    # 1. Read the content from the source storage.
    source_service = await _get_storage_service(source_policy)
    payload = await _read_file_from_storage(source_service, old_physical.storage_path)

    # 2. Generate a destination path and write the content there.
    target_service = await _get_storage_service(dest_policy)
    _dir_path, _storage_name, dest_path = await target_service.generate_file_path(
        user_id=obj.owner_id,
        original_filename=obj.name,
    )
    await _write_file_to_storage(target_service, dest_path, payload)

    # 3. Record the new copy as its own PhysicalFile.
    new_physical = await PhysicalFile(
        storage_path=dest_path,
        size=old_physical.size,
        checksum_md5=old_physical.checksum_md5,
        policy_id=dest_policy.id,
        reference_count=1,
    ).save(session)

    # 4. Repoint the Object at the new policy and physical file.
    obj.policy_id = dest_policy.id
    obj.physical_file_id = new_physical.id
    await obj.save(session)

    # 5. Drop one reference from the old physical file.
    old_physical.decrement_reference()
    if old_physical.can_be_deleted:
        # Best-effort removal of the source bytes; a failure here must not
        # undo an otherwise successful migration.
        try:
            await _delete_file_from_storage(source_service, old_physical.storage_path)
        except Exception as e:
            l.warning(f"删除源文件失败(不影响迁移结果): {old_physical.storage_path}: {e}")
        await PhysicalFile.delete(session, old_physical)
    else:
        await old_physical.save(session)

    l.info(f"文件迁移完成: {obj.name} ({obj.id}), {source_policy.name} → {dest_policy.name}")
|
||||
|
||||
|
||||
async def migrate_file_with_task(
    session: AsyncSession,
    obj: Object,
    dest_policy: Policy,
    task: Task,
) -> None:
    """
    Migrate one file and mirror the outcome into a Task record.

    Marks the task RUNNING, performs the migration, then marks it
    COMPLETED — or ERROR (with a truncated message) on any failure.

    :param session: database session
    :param obj: file object to migrate
    :param dest_policy: destination storage policy
    :param task: task record tracking this migration
    """
    try:
        task.status = TaskStatus.RUNNING
        task.progress = 0
        task = await task.save(session)

        await migrate_single_file(session, obj, dest_policy)

        task.status = TaskStatus.COMPLETED
        task.progress = 100
        await task.save(session)
    except Exception as exc:
        l.error(f"文件迁移任务失败: {obj.id}: {exc}")
        task.status = TaskStatus.ERROR
        # Keep the stored message bounded.
        task.error = str(exc)[:500]
        await task.save(session)
|
||||
|
||||
|
||||
async def migrate_directory_files(
    session: AsyncSession,
    folder: Object,
    dest_policy: Policy,
    task: Task,
) -> None:
    """
    Migrate every file under *folder* (recursively) to *dest_policy*.

    Walks the directory tree, migrates each file individually, and updates
    every sub-folder's policy_id. Task progress is advanced proportionally
    to the number of files processed; per-file failures are collected and
    reported on the task instead of aborting the whole run.

    :param session: database session
    :param folder: directory object whose subtree is migrated
    :param dest_policy: destination storage policy
    :param task: task record tracking this migration
    """
    try:
        task.status = TaskStatus.RUNNING
        task.progress = 0
        task = await task.save(session)

        # Gather the whole subtree up front so progress can be computed.
        file_queue: list[Object] = []
        subfolders: list[Object] = []
        await _collect_objects_recursive(session, folder, file_queue, subfolders)

        total = len(file_queue)
        success_count = 0
        failure_messages: list[str] = []

        for entry in file_queue:
            try:
                await migrate_single_file(session, entry, dest_policy)
                success_count += 1
            except Exception as exc:
                error_msg = f"{entry.name}: {exc}"
                l.error(f"迁移文件失败: {error_msg}")
                failure_messages.append(error_msg)

            # Advance progress (cap at 99 until finalization below).
            if total > 0:
                task.progress = min(99, int(success_count / total * 100))
                task = await task.save(session)

        # Repoint every sub-folder at the destination policy.
        for sf in subfolders:
            sf.policy_id = dest_policy.id
            await sf.save(session)

        # Finalize: surface partial failures on the task, else mark complete.
        if failure_messages:
            task.status = TaskStatus.ERROR
            task.error = f"部分文件迁移失败 ({len(failure_messages)}/{total}): " + "; ".join(failure_messages[:5])
        else:
            task.status = TaskStatus.COMPLETED

        task.progress = 100
        await task.save(session)

        l.info(
            f"目录迁移完成: {folder.name} ({folder.id}), "
            f"成功 {success_count}/{total}, 错误 {len(failure_messages)}"
        )
    except Exception as exc:
        l.error(f"目录迁移任务失败: {folder.id}: {exc}")
        task.status = TaskStatus.ERROR
        task.error = str(exc)[:500]
        await task.save(session)
|
||||
|
||||
|
||||
async def _collect_objects_recursive(
|
||||
session: AsyncSession,
|
||||
folder: Object,
|
||||
files: list[Object],
|
||||
folders: list[Object],
|
||||
) -> None:
|
||||
"""
|
||||
递归收集目录下所有文件和子目录
|
||||
|
||||
:param session: 数据库会话
|
||||
:param folder: 当前目录
|
||||
:param files: 文件列表(输出)
|
||||
:param folders: 子目录列表(输出)
|
||||
"""
|
||||
children: list[Object] = await Object.get_children(session, folder.owner_id, folder.id)
|
||||
|
||||
for child in children:
|
||||
if child.type == ObjectType.FILE:
|
||||
files.append(child)
|
||||
elif child.type == ObjectType.FOLDER:
|
||||
folders.append(child)
|
||||
await _collect_objects_recursive(session, child, files, folders)
|
||||
@@ -6,7 +6,8 @@ from sqlalchemy import update as sql_update
|
||||
from sqlalchemy.sql.functions import func
|
||||
from middleware.dependencies import SessionDep
|
||||
|
||||
from service.storage import LocalStorageService
|
||||
from .local_storage import LocalStorageService
|
||||
from .s3_storage import S3StorageService
|
||||
from sqlmodels import (
|
||||
Object,
|
||||
PhysicalFile,
|
||||
@@ -271,10 +272,14 @@ async def permanently_delete_objects(
|
||||
if physical_file.can_be_deleted:
|
||||
# 物理删除文件
|
||||
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
|
||||
if policy and policy.type == PolicyType.LOCAL:
|
||||
if policy:
|
||||
try:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(physical_file.storage_path)
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(physical_file.storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
await s3_service.delete_file(physical_file.storage_path)
|
||||
l.debug(f"物理文件已删除: {obj_name}")
|
||||
except Exception as e:
|
||||
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
|
||||
@@ -374,10 +379,19 @@ async def delete_object_recursive(
|
||||
if physical_file.can_be_deleted:
|
||||
# 物理删除文件
|
||||
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
|
||||
if policy and policy.type == PolicyType.LOCAL:
|
||||
if policy:
|
||||
try:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(physical_file.storage_path)
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(physical_file.storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
options = await policy.awaitable_attrs.options
|
||||
s3_service = S3StorageService(
|
||||
policy,
|
||||
region=options.s3_region if options else 'us-east-1',
|
||||
is_path_style=options.s3_path_style if options else False,
|
||||
)
|
||||
await s3_service.delete_file(physical_file.storage_path)
|
||||
l.debug(f"物理文件已删除: {obj_name}")
|
||||
except Exception as e:
|
||||
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
|
||||
|
||||
709
service/storage/s3_storage.py
Normal file
709
service/storage/s3_storage.py
Normal file
@@ -0,0 +1,709 @@
|
||||
"""
|
||||
S3 存储服务
|
||||
|
||||
使用 AWS Signature V4 签名的异步 S3 API 客户端。
|
||||
从 Policy 配置中读取 S3 连接信息,提供文件上传/下载/删除及分片上传功能。
|
||||
|
||||
移植自 foxline-pro-backend-server 项目的 S3APIClient,
|
||||
适配 DiskNext 现有的 Service 架构(与 LocalStorageService 平行)。
|
||||
"""
|
||||
import hashlib
|
||||
import hmac
|
||||
import xml.etree.ElementTree as ET
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime, timezone
|
||||
from typing import ClassVar, Literal
|
||||
from urllib.parse import quote, urlencode
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
from loguru import logger as l
|
||||
|
||||
from sqlmodels.policy import Policy
|
||||
from .exceptions import S3APIError, S3MultipartUploadError
|
||||
from .naming_rule import NamingContext, NamingRuleParser
|
||||
|
||||
|
||||
def _sign(key: bytes, msg: str) -> bytes:
|
||||
"""HMAC-SHA256 签名"""
|
||||
return hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||
|
||||
|
||||
_NS_AWS = "http://s3.amazonaws.com/doc/2006-03-01/"
|
||||
|
||||
|
||||
class S3StorageService:
    """
    S3 storage service.

    Asynchronous S3 API client using AWS Signature V4 request signing.
    Connection settings (endpoint, bucket, credentials) are read from a
    Policy record.

    Example::

        service = S3StorageService(policy, region='us-east-1')
        await service.upload_file('path/to/file.txt', b'content')
        data = await service.download_file('path/to/file.txt')
    """

    # Process-wide aiohttp session shared by all instances (created lazily).
    _http_session: ClassVar[aiohttp.ClientSession | None] = None

    def __init__(
        self,
        policy: Policy,
        region: str = 'us-east-1',
        is_path_style: bool = False,
    ):
        """
        :param policy: storage policy (server=endpoint_url, bucket_name, access_key, secret_key)
        :param region: S3 region
        :param is_path_style: use path-style URLs instead of virtual-hosted style
        :raises S3APIError: if any required policy field is missing
        """
        if not policy.server:
            raise S3APIError("S3 策略必须指定 server (endpoint URL)")
        if not policy.bucket_name:
            raise S3APIError("S3 策略必须指定 bucket_name")
        if not policy.access_key:
            raise S3APIError("S3 策略必须指定 access_key")
        if not policy.secret_key:
            raise S3APIError("S3 策略必须指定 secret_key")

        self._policy = policy
        self._endpoint_url = policy.server.rstrip("/")
        self._bucket_name = policy.bucket_name
        self._access_key = policy.access_key
        self._secret_key = policy.secret_key
        self._region = region
        self._is_path_style = is_path_style
        self._base_url = policy.base_url

        # Extract the bare host from the endpoint URL.
        self._host = self._endpoint_url.replace("https://", "").replace("http://", "").split("/")[0]

    # ==================== Factory ====================

    @classmethod
    async def from_policy(cls, policy: Policy) -> 'S3StorageService':
        """
        Asynchronously build an S3StorageService from a Policy (loads its options).

        :param policy: storage policy
        :return: S3StorageService instance
        """
        options = await policy.awaitable_attrs.options
        region = options.s3_region if options else 'us-east-1'
        is_path_style = options.s3_path_style if options else False
        return cls(policy, region=region, is_path_style=is_path_style)

    # ==================== HTTP session management ====================

    @classmethod
    async def initialize_session(cls) -> None:
        """Initialize the shared aiohttp ClientSession."""
        if cls._http_session is None or cls._http_session.closed:
            cls._http_session = aiohttp.ClientSession()
            l.info("S3StorageService HTTP session 已初始化")

    @classmethod
    async def close_session(cls) -> None:
        """Close the shared aiohttp ClientSession."""
        if cls._http_session and not cls._http_session.closed:
            await cls._http_session.close()
            cls._http_session = None
            l.info("S3StorageService HTTP session 已关闭")

    @classmethod
    def _get_session(cls) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating it lazily if needed."""
        if cls._http_session is None or cls._http_session.closed:
            # Lazy init in case initialize_session was never called.
            cls._http_session = aiohttp.ClientSession()
        return cls._http_session

    # ==================== AWS Signature V4 ====================

    def _get_signature_key(self, date_stamp: str) -> bytes:
        """Derive the AWS Signature V4 signing key for *date_stamp* (YYYYMMDD)."""
        k_date = _sign(f"AWS4{self._secret_key}".encode(), date_stamp)
        k_region = _sign(k_date, self._region)
        k_service = _sign(k_region, "s3")
        return _sign(k_service, "aws4_request")

    def _create_authorization_header(
        self,
        method: str,
        uri: str,
        query_string: str,
        headers: dict[str, str],
        payload_hash: str,
        amz_date: str,
        date_stamp: str,
    ) -> str:
        """Build the AWS Signature V4 ``Authorization`` header value."""
        signed_headers = ";".join(sorted(k.lower() for k in headers.keys()))
        # AWS SigV4 requires canonical headers ordered by *lowercased* name.
        # Sorting the original-cased keys happens to match for Title-Case
        # headers but would corrupt the signature if a caller passed
        # lowercase extra headers — sort explicitly by lowercased key.
        canonical_headers = "".join(
            f"{k.lower()}:{v.strip()}\n"
            for k, v in sorted(headers.items(), key=lambda kv: kv[0].lower())
        )
        canonical_request = (
            f"{method}\n{uri}\n{query_string}\n{canonical_headers}\n"
            f"{signed_headers}\n{payload_hash}"
        )

        algorithm = "AWS4-HMAC-SHA256"
        credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
        string_to_sign = (
            f"{algorithm}\n{amz_date}\n{credential_scope}\n"
            f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
        )

        signing_key = self._get_signature_key(date_stamp)
        signature = hmac.new(
            signing_key, string_to_sign.encode(), hashlib.sha256
        ).hexdigest()

        return (
            f"{algorithm} Credential={self._access_key}/{credential_scope}, "
            f"SignedHeaders={signed_headers}, Signature={signature}"
        )

    def _build_headers(
        self,
        method: str,
        uri: str,
        query_string: str = "",
        payload: bytes = b"",
        content_type: str | None = None,
        extra_headers: dict[str, str] | None = None,
        payload_hash: str | None = None,
        host: str | None = None,
    ) -> dict[str, str]:
        """
        Build the complete request headers, including the AWS V4 signature.

        :param method: HTTP method
        :param uri: request URI
        :param query_string: canonical query string
        :param payload: request body bytes (used to compute the hash)
        :param content_type: Content-Type header value
        :param extra_headers: additional headers to sign and send
        :param payload_hash: precomputed payload hash; pass "UNSIGNED-PAYLOAD" for streaming
        :param host: Host header (defaults to self._host)
        """
        now_utc = datetime.now(timezone.utc)
        amz_date = now_utc.strftime("%Y%m%dT%H%M%SZ")
        date_stamp = now_utc.strftime("%Y%m%d")

        if payload_hash is None:
            payload_hash = hashlib.sha256(payload).hexdigest()

        effective_host = host or self._host

        headers: dict[str, str] = {
            "Host": effective_host,
            "X-Amz-Date": amz_date,
            "X-Amz-Content-Sha256": payload_hash,
        }
        if content_type:
            headers["Content-Type"] = content_type
        if extra_headers:
            headers.update(extra_headers)

        authorization = self._create_authorization_header(
            method, uri, query_string, headers, payload_hash, amz_date, date_stamp
        )
        headers["Authorization"] = authorization
        return headers

    # ==================== Internal request helpers ====================

    def _build_uri(self, key: str | None = None) -> str:
        """
        Build the request URI.

        The path is URI-encoded once, per the AWS S3 Signature V4 spec
        (S3 requires single encoding). Slashes are kept as path separators.
        """
        if self._is_path_style:
            if key:
                return f"/{self._bucket_name}/{quote(key, safe='/')}"
            return f"/{self._bucket_name}"
        else:
            if key:
                return f"/{quote(key, safe='/')}"
            return "/"

    def _build_url(self, uri: str, query_string: str = "") -> str:
        """Build the full request URL for *uri* (plus optional query string)."""
        if self._is_path_style:
            base = self._endpoint_url
        else:
            # Virtual-hosted style: bucket.endpoint
            protocol = "https://" if self._endpoint_url.startswith("https://") else "http://"
            base = f"{protocol}{self._bucket_name}.{self._host}"

        url = f"{base}{uri}"
        if query_string:
            url = f"{url}?{query_string}"
        return url

    def _get_effective_host(self) -> str:
        """Return the Host header value actually used for requests."""
        if self._is_path_style:
            return self._host
        return f"{self._bucket_name}.{self._host}"

    async def _request(
        self,
        method: str,
        key: str | None = None,
        query_params: dict[str, str] | None = None,
        payload: bytes = b"",
        content_type: str | None = None,
        extra_headers: dict[str, str] | None = None,
    ) -> aiohttp.ClientResponse:
        """Send a signed S3 request and return the raw response."""
        uri = self._build_uri(key)
        query_string = urlencode(sorted(query_params.items())) if query_params else ""
        effective_host = self._get_effective_host()

        headers = self._build_headers(
            method, uri, query_string, payload, content_type,
            extra_headers, host=effective_host,
        )

        url = self._build_url(uri, query_string)

        try:
            # encoded=True: the URI was already percent-encoded for signing;
            # yarl must not re-encode it or the signature would break.
            response = await self._get_session().request(
                method, URL(url, encoded=True),
                headers=headers, data=payload if payload else None,
            )
            return response
        except Exception as e:
            raise S3APIError(f"S3 请求失败: {method} {url}: {e}") from e

    async def _request_streaming(
        self,
        method: str,
        key: str,
        data_stream: AsyncIterator[bytes],
        content_length: int,
        content_type: str | None = None,
    ) -> aiohttp.ClientResponse:
        """
        Send a signed streaming request (large uploads).

        Uses UNSIGNED-PAYLOAD as the payload hash since the body is not
        available up front.
        """
        uri = self._build_uri(key)
        effective_host = self._get_effective_host()

        headers = self._build_headers(
            method,
            uri,
            query_string="",
            content_type=content_type,
            extra_headers={"Content-Length": str(content_length)},
            payload_hash="UNSIGNED-PAYLOAD",
            host=effective_host,
        )

        url = self._build_url(uri)

        try:
            response = await self._get_session().request(
                method, URL(url, encoded=True),
                headers=headers, data=data_stream,
            )
            return response
        except Exception as e:
            raise S3APIError(f"S3 流式请求失败: {method} {url}: {e}") from e

    # ==================== File operations ====================

    async def upload_file(
        self,
        key: str,
        data: bytes,
        content_type: str = 'application/octet-stream',
    ) -> None:
        """
        Upload a file.

        :param key: S3 object key
        :param data: file content
        :param content_type: MIME type
        :raises S3APIError: on non-success status
        """
        async with await self._request(
            "PUT", key=key, payload=data, content_type=content_type,
        ) as response:
            if response.status not in (200, 201):
                body = await response.text()
                raise S3APIError(
                    f"上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
            l.debug(f"S3 上传成功: {self._bucket_name}/{key}")

    async def upload_file_streaming(
        self,
        key: str,
        data_stream: AsyncIterator[bytes],
        content_length: int,
        content_type: str | None = None,
    ) -> None:
        """
        Upload a file as a stream (large files, avoids buffering in memory).

        :param key: S3 object key
        :param data_stream: async iterator of byte chunks
        :param content_length: total byte length (must be exact)
        :param content_type: MIME type
        :raises S3APIError: on non-success status
        """
        async with await self._request_streaming(
            "PUT", key=key, data_stream=data_stream,
            content_length=content_length, content_type=content_type,
        ) as response:
            if response.status not in (200, 201):
                body = await response.text()
                raise S3APIError(
                    f"流式上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
            l.debug(f"S3 流式上传成功: {self._bucket_name}/{key}, 大小: {content_length}")

    async def download_file(self, key: str) -> bytes:
        """
        Download a file.

        :param key: S3 object key
        :return: file content
        :raises S3APIError: on non-200 status
        """
        async with await self._request("GET", key=key) as response:
            if response.status != 200:
                body = await response.text()
                raise S3APIError(
                    f"下载失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
            data = await response.read()
            l.debug(f"S3 下载成功: {self._bucket_name}/{key}, 大小: {len(data)}")
            return data

    async def delete_file(self, key: str) -> None:
        """
        Delete a file.

        :param key: S3 object key
        :raises S3APIError: on non-success status
        """
        async with await self._request("DELETE", key=key) as response:
            if response.status in (200, 204):
                l.debug(f"S3 删除成功: {self._bucket_name}/{key}")
            else:
                body = await response.text()
                raise S3APIError(
                    f"删除失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )

    async def file_exists(self, key: str) -> bool:
        """
        Check whether an object exists (HEAD request).

        :param key: S3 object key
        :return: True if present, False if 404
        :raises S3APIError: on any other status
        """
        async with await self._request("HEAD", key=key) as response:
            if response.status == 200:
                return True
            elif response.status == 404:
                return False
            else:
                raise S3APIError(
                    f"检查文件存在性失败: {self._bucket_name}/{key}, 状态: {response.status}"
                )

    async def get_file_size(self, key: str) -> int:
        """
        Get an object's size via HEAD.

        :param key: S3 object key
        :return: size in bytes (0 if Content-Length is absent)
        :raises S3APIError: on non-200 status
        """
        async with await self._request("HEAD", key=key) as response:
            if response.status != 200:
                raise S3APIError(
                    f"获取文件信息失败: {self._bucket_name}/{key}, 状态: {response.status}"
                )
            return int(response.headers.get("Content-Length", 0))

    # ==================== Multipart upload ====================

    async def create_multipart_upload(
        self,
        key: str,
        content_type: str = 'application/octet-stream',
    ) -> str:
        """
        Initiate a multipart upload.

        :param key: S3 object key
        :param content_type: MIME type
        :return: upload ID
        :raises S3MultipartUploadError: on failure or malformed response
        """
        async with await self._request(
            "POST",
            key=key,
            query_params={"uploads": ""},
            content_type=content_type,
        ) as response:
            if response.status != 200:
                body = await response.text()
                raise S3MultipartUploadError(
                    f"创建分片上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )

            body = await response.text()
            root = ET.fromstring(body)

            # Look up the UploadId element, with and without the AWS namespace.
            upload_id_elem = root.find("UploadId")
            if upload_id_elem is None:
                upload_id_elem = root.find(f"{{{_NS_AWS}}}UploadId")
            if upload_id_elem is None or not upload_id_elem.text:
                raise S3MultipartUploadError(
                    f"创建分片上传响应中未找到 UploadId: {body}"
                )

            upload_id = upload_id_elem.text
            l.debug(f"S3 分片上传已创建: {self._bucket_name}/{key}, upload_id={upload_id}")
            return upload_id

    async def upload_part(
        self,
        key: str,
        upload_id: str,
        part_number: int,
        data: bytes,
    ) -> str:
        """
        Upload a single part.

        :param key: S3 object key
        :param upload_id: multipart upload ID
        :param part_number: part number (starting at 1)
        :param data: part content
        :return: ETag of the stored part
        :raises S3MultipartUploadError: on non-200 status
        """
        async with await self._request(
            "PUT",
            key=key,
            query_params={
                "partNumber": str(part_number),
                "uploadId": upload_id,
            },
            payload=data,
        ) as response:
            if response.status != 200:
                body = await response.text()
                raise S3MultipartUploadError(
                    f"上传分片失败: {self._bucket_name}/{key}, "
                    f"part={part_number}, 状态: {response.status}, {body}"
                )

            etag = response.headers.get("ETag", "").strip('"')
            l.debug(
                f"S3 分片上传成功: {self._bucket_name}/{key}, "
                f"part={part_number}, etag={etag}"
            )
            return etag

    async def complete_multipart_upload(
        self,
        key: str,
        upload_id: str,
        parts: list[tuple[int, str]],
    ) -> None:
        """
        Complete a multipart upload.

        :param key: S3 object key
        :param upload_id: multipart upload ID
        :param parts: list of (part_number, etag) pairs
        :raises S3MultipartUploadError: on non-200 status
        """
        # S3 requires parts listed in ascending part-number order.
        parts_sorted = sorted(parts, key=lambda p: p[0])

        # Build the CompleteMultipartUpload XML body.
        xml_parts = ''.join(
            f"<Part><PartNumber>{pn}</PartNumber><ETag>{etag}</ETag></Part>"
            for pn, etag in parts_sorted
        )
        payload = f'<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUpload>{xml_parts}</CompleteMultipartUpload>'
        payload_bytes = payload.encode('utf-8')

        async with await self._request(
            "POST",
            key=key,
            query_params={"uploadId": upload_id},
            payload=payload_bytes,
            content_type="application/xml",
        ) as response:
            if response.status != 200:
                body = await response.text()
                raise S3MultipartUploadError(
                    f"完成分片上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
            l.info(
                f"S3 分片上传已完成: {self._bucket_name}/{key}, "
                f"共 {len(parts)} 个分片"
            )

    async def abort_multipart_upload(self, key: str, upload_id: str) -> None:
        """
        Abort a multipart upload (best effort: failure is logged, not raised).

        :param key: S3 object key
        :param upload_id: multipart upload ID
        """
        async with await self._request(
            "DELETE",
            key=key,
            query_params={"uploadId": upload_id},
        ) as response:
            if response.status in (200, 204):
                l.debug(f"S3 分片上传已取消: {self._bucket_name}/{key}")
            else:
                body = await response.text()
                l.warning(
                    f"取消分片上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )

    # ==================== Presigned URLs ====================

    def generate_presigned_url(
        self,
        key: str,
        method: Literal['GET', 'PUT'] = 'GET',
        expires_in: int = 3600,
        filename: str | None = None,
    ) -> str:
        """
        Generate a presigned S3 URL (AWS Signature V4, query-string auth).

        :param key: S3 object key
        :param method: HTTP method (GET for download, PUT for upload)
        :param expires_in: URL validity in seconds
        :param filename: download filename (sets Content-Disposition on GET)
        :return: presigned URL
        """
        current_time = datetime.now(timezone.utc)
        amz_date = current_time.strftime("%Y%m%dT%H%M%SZ")
        date_stamp = current_time.strftime("%Y%m%d")

        credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
        credential = f"{self._access_key}/{credential_scope}"

        uri = self._build_uri(key)
        effective_host = self._get_effective_host()

        query_params: dict[str, str] = {
            'X-Amz-Algorithm': 'AWS4-HMAC-SHA256',
            'X-Amz-Credential': credential,
            'X-Amz-Date': amz_date,
            'X-Amz-Expires': str(expires_in),
            'X-Amz-SignedHeaders': 'host',
        }

        # For GET, ask S3 to serve the object as an attachment with the
        # RFC 5987-encoded filename.
        if method == "GET" and filename:
            encoded_filename = quote(filename, safe='')
            query_params['response-content-disposition'] = (
                f"attachment; filename*=UTF-8''{encoded_filename}"
            )

        # Canonical query string: keys sorted, strict percent-encoding.
        canonical_query_string = "&".join(
            f"{quote(k, safe='')}={quote(v, safe='')}"
            for k, v in sorted(query_params.items())
        )

        canonical_headers = f"host:{effective_host}\n"
        signed_headers = "host"
        payload_hash = "UNSIGNED-PAYLOAD"

        canonical_request = (
            f"{method}\n"
            f"{uri}\n"
            f"{canonical_query_string}\n"
            f"{canonical_headers}\n"
            f"{signed_headers}\n"
            f"{payload_hash}"
        )

        algorithm = "AWS4-HMAC-SHA256"
        string_to_sign = (
            f"{algorithm}\n"
            f"{amz_date}\n"
            f"{credential_scope}\n"
            f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
        )

        signing_key = self._get_signature_key(date_stamp)
        signature = hmac.new(
            signing_key, string_to_sign.encode(), hashlib.sha256
        ).hexdigest()

        base_url = self._build_url(uri)
        return (
            f"{base_url}?"
            f"{canonical_query_string}&"
            f"X-Amz-Signature={signature}"
        )

    # ==================== Path generation ====================

    async def generate_file_path(
        self,
        user_id: UUID,
        original_filename: str,
    ) -> tuple[str, str, str]:
        """
        Generate an S3 storage path from the policy's naming rules.

        Mirrors the LocalStorageService.generate_file_path interface.

        :param user_id: user UUID
        :param original_filename: original file name
        :return: (relative directory path, storage file name, full storage path)
        """
        context = NamingContext(
            user_id=user_id,
            original_filename=original_filename,
        )

        # Resolve the directory naming rule.
        dir_path = ""
        if self._policy.dir_name_rule:
            dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context)

        # Resolve the file naming rule.
        if self._policy.auto_rename and self._policy.file_name_rule:
            storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context)
            # Preserve the original extension if the rule dropped it.
            if '.' in original_filename and '.' not in storage_name:
                ext = original_filename.rsplit('.', 1)[1]
                storage_name = f"{storage_name}.{ext}"
        else:
            storage_name = original_filename

        # S3 has no real directories; just join the key.
        if dir_path:
            storage_path = f"{dir_path}/{storage_name}"
        else:
            storage_path = storage_name

        return dir_path, storage_name, storage_path
|
||||
Reference in New Issue
Block a user