feat: add S3 storage support, policy migration, and quota enforcement
Some checks failed
Test / test (push) Failing after 2m21s
Some checks failed
Test / test (push) Failing after 2m21s
- Add S3StorageService with AWS Signature V4 signing (URI-encoded for non-ASCII keys)
- Add PATCH /object/{id}/policy endpoint for switching storage policies with background migration
- Implement cross-storage file migration service (local <-> S3)
- Replace deprecated StorageType enum with PolicyType (local/s3)
- Implement GET /user/settings/policies endpoint (was 501 stub)
- Add storage quota pre-allocation on upload session creation to prevent concurrent bypass
- Fix BigInteger for max_storage and user.storage to support >2GB values
- Add policy permission validation on upload and directory creation
- Use group's first policy as default on registration instead of hardcoded name
- Define TaskType.POLICY_MIGRATE and extend TaskProps with migration fields
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -13,9 +13,11 @@ from datetime import datetime, timedelta
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
import orjson
|
||||
import whatthepatch
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
from starlette.responses import Response
|
||||
from loguru import logger as l
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
from whatthepatch.exceptions import HunkApplyException
|
||||
@@ -44,7 +46,9 @@ from sqlmodels import (
|
||||
User,
|
||||
WopiSessionResponse,
|
||||
)
|
||||
from service.storage import LocalStorageService, adjust_user_storage
|
||||
import orjson
|
||||
|
||||
from service.storage import LocalStorageService, S3StorageService, adjust_user_storage
|
||||
from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL
|
||||
from utils.JWT.wopi_token import create_wopi_token
|
||||
from utils import http_exceptions
|
||||
@@ -184,6 +188,13 @@ async def create_upload_session(
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
|
||||
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
|
||||
if request.policy_id:
|
||||
group = await user.awaitable_attrs.group
|
||||
await session.refresh(group, ['policies'])
|
||||
if request.policy_id not in {p.id for p in group.policies}:
|
||||
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
|
||||
|
||||
# 验证文件大小限制
|
||||
_check_policy_size_limit(policy, request.file_size)
|
||||
|
||||
@@ -210,6 +221,7 @@ async def create_upload_session(
|
||||
|
||||
# 生成存储路径
|
||||
storage_path: str | None = None
|
||||
s3_upload_id: str | None = None
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
dir_path, storage_name, full_path = await storage_service.generate_file_path(
|
||||
@@ -217,8 +229,25 @@ async def create_upload_session(
|
||||
original_filename=request.file_name,
|
||||
)
|
||||
storage_path = full_path
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = S3StorageService(
|
||||
policy,
|
||||
region=options.s3_region if options else 'us-east-1',
|
||||
is_path_style=options.s3_path_style if options else False,
|
||||
)
|
||||
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
|
||||
user_id=user.id,
|
||||
original_filename=request.file_name,
|
||||
)
|
||||
# 多分片时创建 multipart upload
|
||||
if total_chunks > 1:
|
||||
s3_upload_id = await s3_service.create_multipart_upload(
|
||||
storage_path, content_type='application/octet-stream',
|
||||
)
|
||||
|
||||
# 预扣存储空间(与创建会话在同一事务中提交,防止并发绕过配额)
|
||||
if request.file_size > 0:
|
||||
await adjust_user_storage(session, user.id, request.file_size, commit=False)
|
||||
|
||||
# 创建上传会话
|
||||
upload_session = UploadSession(
|
||||
@@ -227,6 +256,7 @@ async def create_upload_session(
|
||||
chunk_size=chunk_size,
|
||||
total_chunks=total_chunks,
|
||||
storage_path=storage_path,
|
||||
s3_upload_id=s3_upload_id,
|
||||
expires_at=datetime.now() + timedelta(hours=24),
|
||||
owner_id=user.id,
|
||||
parent_id=request.parent_id,
|
||||
@@ -302,8 +332,38 @@ async def upload_chunk(
|
||||
content,
|
||||
offset,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
elif policy.type == PolicyType.S3:
|
||||
if not upload_session.storage_path:
|
||||
raise HTTPException(status_code=500, detail="存储路径丢失")
|
||||
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
|
||||
if upload_session.total_chunks == 1:
|
||||
# 单分片:直接 PUT 上传
|
||||
await s3_service.upload_file(upload_session.storage_path, content)
|
||||
else:
|
||||
# 多分片:UploadPart
|
||||
if not upload_session.s3_upload_id:
|
||||
raise HTTPException(status_code=500, detail="S3 分片上传 ID 丢失")
|
||||
|
||||
etag = await s3_service.upload_part(
|
||||
upload_session.storage_path,
|
||||
upload_session.s3_upload_id,
|
||||
chunk_index + 1, # S3 part number 从 1 开始
|
||||
content,
|
||||
)
|
||||
# 追加 ETag 到 s3_part_etags
|
||||
etags: list[list[int | str]] = orjson.loads(upload_session.s3_part_etags or '[]')
|
||||
etags.append([chunk_index + 1, etag])
|
||||
upload_session.s3_part_etags = orjson.dumps(etags).decode()
|
||||
|
||||
# 在 save(commit)前缓存后续需要的属性(commit 后 ORM 对象会过期)
|
||||
policy_type = policy.type
|
||||
s3_upload_id = upload_session.s3_upload_id
|
||||
s3_part_etags = upload_session.s3_part_etags
|
||||
s3_service_for_complete: S3StorageService | None = None
|
||||
if policy_type == PolicyType.S3:
|
||||
s3_service_for_complete = await S3StorageService.from_policy(policy)
|
||||
|
||||
# 更新会话进度
|
||||
upload_session.uploaded_chunks += 1
|
||||
@@ -319,12 +379,26 @@ async def upload_chunk(
|
||||
if is_complete:
|
||||
# 保存 upload_session 属性(commit 后会过期)
|
||||
file_name = upload_session.file_name
|
||||
file_size = upload_session.file_size
|
||||
uploaded_size = upload_session.uploaded_size
|
||||
storage_path = upload_session.storage_path
|
||||
upload_session_id = upload_session.id
|
||||
parent_id = upload_session.parent_id
|
||||
policy_id = upload_session.policy_id
|
||||
|
||||
# S3 多分片上传完成:合并分片
|
||||
if (
|
||||
policy_type == PolicyType.S3
|
||||
and s3_upload_id
|
||||
and s3_part_etags
|
||||
and s3_service_for_complete
|
||||
):
|
||||
parts_data: list[list[int | str]] = orjson.loads(s3_part_etags)
|
||||
parts = [(int(pn), str(et)) for pn, et in parts_data]
|
||||
await s3_service_for_complete.complete_multipart_upload(
|
||||
storage_path, s3_upload_id, parts,
|
||||
)
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
physical_file = PhysicalFile(
|
||||
storage_path=storage_path,
|
||||
@@ -355,9 +429,10 @@ async def upload_chunk(
|
||||
commit=False
|
||||
)
|
||||
|
||||
# 更新用户存储配额
|
||||
if uploaded_size > 0:
|
||||
await adjust_user_storage(session, user_id, uploaded_size, commit=False)
|
||||
# 调整存储配额差值(创建会话时已预扣 file_size,这里只补差)
|
||||
size_diff = uploaded_size - file_size
|
||||
if size_diff != 0:
|
||||
await adjust_user_storage(session, user_id, size_diff, commit=False)
|
||||
|
||||
# 统一提交所有更改
|
||||
await session.commit()
|
||||
@@ -390,9 +465,25 @@ async def delete_upload_session(
|
||||
|
||||
# 删除临时文件
|
||||
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
|
||||
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(upload_session.storage_path)
|
||||
if policy and upload_session.storage_path:
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(upload_session.storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
# 如果有分片上传,先取消
|
||||
if upload_session.s3_upload_id:
|
||||
await s3_service.abort_multipart_upload(
|
||||
upload_session.storage_path, upload_session.s3_upload_id,
|
||||
)
|
||||
else:
|
||||
# 单分片上传已完成的话,删除已上传的文件
|
||||
if upload_session.uploaded_chunks > 0:
|
||||
await s3_service.delete_file(upload_session.storage_path)
|
||||
|
||||
# 释放预扣的存储空间
|
||||
if upload_session.file_size > 0:
|
||||
await adjust_user_storage(session, user.id, -upload_session.file_size)
|
||||
|
||||
# 删除会话记录
|
||||
await UploadSession.delete(session, upload_session)
|
||||
@@ -422,9 +513,22 @@ async def clear_upload_sessions(
|
||||
for upload_session in sessions:
|
||||
# 删除临时文件
|
||||
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
|
||||
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(upload_session.storage_path)
|
||||
if policy and upload_session.storage_path:
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(upload_session.storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
if upload_session.s3_upload_id:
|
||||
await s3_service.abort_multipart_upload(
|
||||
upload_session.storage_path, upload_session.s3_upload_id,
|
||||
)
|
||||
elif upload_session.uploaded_chunks > 0:
|
||||
await s3_service.delete_file(upload_session.storage_path)
|
||||
|
||||
# 释放预扣的存储空间
|
||||
if upload_session.file_size > 0:
|
||||
await adjust_user_storage(session, user.id, -upload_session.file_size)
|
||||
|
||||
await UploadSession.delete(session, upload_session)
|
||||
deleted_count += 1
|
||||
@@ -486,11 +590,12 @@ async def create_download_token_endpoint(
|
||||
path='/{token}',
|
||||
summary='下载文件',
|
||||
description='使用下载令牌下载文件,令牌在有效期内可重复使用。',
|
||||
response_model=None,
|
||||
)
|
||||
async def download_file(
|
||||
session: SessionDep,
|
||||
token: str,
|
||||
) -> FileResponse:
|
||||
) -> Response:
|
||||
"""
|
||||
下载文件端点
|
||||
|
||||
@@ -540,8 +645,15 @@ async def download_file(
|
||||
filename=file_obj.name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
# 302 重定向到预签名 URL
|
||||
presigned_url = s3_service.generate_presigned_url(
|
||||
storage_path, method='GET', expires_in=3600, filename=file_obj.name,
|
||||
)
|
||||
return RedirectResponse(url=presigned_url, status_code=302)
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
raise HTTPException(status_code=500, detail="不支持的存储类型")
|
||||
|
||||
|
||||
# ==================== 包含子路由 ====================
|
||||
@@ -613,8 +725,13 @@ async def create_empty_file(
|
||||
)
|
||||
await storage_service.create_empty_file(full_path)
|
||||
storage_path = full_path
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
|
||||
user_id=user_id,
|
||||
original_filename=request.name,
|
||||
)
|
||||
await s3_service.upload_file(storage_path, b'')
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
physical_file = PhysicalFile(
|
||||
@@ -798,12 +915,13 @@ async def _validate_source_link(
|
||||
path='/get/{file_id}/{name}',
|
||||
summary='文件外链(直接输出文件数据)',
|
||||
description='通过外链直接获取文件内容,公开访问无需认证。',
|
||||
response_model=None,
|
||||
)
|
||||
async def file_get(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
name: str,
|
||||
) -> FileResponse:
|
||||
) -> Response:
|
||||
"""
|
||||
文件外链端点(直接输出)
|
||||
|
||||
@@ -815,13 +933,6 @@ async def file_get(
|
||||
"""
|
||||
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
if not await storage_service.file_exists(physical_file.storage_path):
|
||||
http_exceptions.raise_not_found("物理文件不存在")
|
||||
|
||||
# 缓存物理路径(save 后对象属性会过期)
|
||||
file_path = physical_file.storage_path
|
||||
|
||||
@@ -829,11 +940,25 @@ async def file_get(
|
||||
link.downloads += 1
|
||||
await link.save(session)
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
if not await storage_service.file_exists(file_path):
|
||||
http_exceptions.raise_not_found("物理文件不存在")
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
elif policy.type == PolicyType.S3:
|
||||
# S3 外链直接输出:302 重定向到预签名 URL
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
presigned_url = s3_service.generate_presigned_url(
|
||||
file_path, method='GET', expires_in=3600, filename=name,
|
||||
)
|
||||
return RedirectResponse(url=presigned_url, status_code=302)
|
||||
else:
|
||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -846,7 +971,7 @@ async def file_source_redirect(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
name: str,
|
||||
) -> FileResponse | RedirectResponse:
|
||||
) -> Response:
|
||||
"""
|
||||
文件外链端点(重定向/直接输出)
|
||||
|
||||
@@ -860,13 +985,6 @@ async def file_source_redirect(
|
||||
"""
|
||||
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
if not await storage_service.file_exists(physical_file.storage_path):
|
||||
http_exceptions.raise_not_found("物理文件不存在")
|
||||
|
||||
# 缓存所有需要的值(save 后对象属性会过期)
|
||||
file_path = physical_file.storage_path
|
||||
is_private = policy.is_private
|
||||
@@ -876,18 +994,36 @@ async def file_source_redirect(
|
||||
link.downloads += 1
|
||||
await link.save(session)
|
||||
|
||||
# 公有存储:302 重定向到 base_url
|
||||
if not is_private and base_url:
|
||||
relative_path = storage_service.get_relative_path(file_path)
|
||||
redirect_url = f"{base_url}/{relative_path}"
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
if not await storage_service.file_exists(file_path):
|
||||
http_exceptions.raise_not_found("物理文件不存在")
|
||||
|
||||
# 私有存储或 base_url 为空:通过应用代理文件
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
# 公有存储:302 重定向到 base_url
|
||||
if not is_private and base_url:
|
||||
relative_path = storage_service.get_relative_path(file_path)
|
||||
redirect_url = f"{base_url}/{relative_path}"
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
||||
|
||||
# 私有存储或 base_url 为空:通过应用代理文件
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
# 公有存储且有 base_url:直接重定向到公开 URL
|
||||
if not is_private and base_url:
|
||||
redirect_url = f"{base_url.rstrip('/')}/{file_path}"
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
||||
# 私有存储:生成预签名 URL 重定向
|
||||
presigned_url = s3_service.generate_presigned_url(
|
||||
file_path, method='GET', expires_in=3600, filename=name,
|
||||
)
|
||||
return RedirectResponse(url=presigned_url, status_code=302)
|
||||
else:
|
||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
||||
|
||||
|
||||
@router.put(
|
||||
@@ -941,11 +1077,15 @@ async def file_content(
|
||||
if not policy:
|
||||
http_exceptions.raise_internal_error("存储策略不存在")
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(physical_file.storage_path)
|
||||
# 读取文件内容
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(physical_file.storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
raw_bytes = await s3_service.download_file(physical_file.storage_path)
|
||||
else:
|
||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
||||
|
||||
try:
|
||||
content = raw_bytes.decode('utf-8')
|
||||
@@ -1011,11 +1151,15 @@ async def patch_file_content(
|
||||
if not policy:
|
||||
http_exceptions.raise_internal_error("存储策略不存在")
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(storage_path)
|
||||
# 读取文件内容
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
raw_bytes = await s3_service.download_file(storage_path)
|
||||
else:
|
||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
||||
|
||||
# 解码 + 规范化
|
||||
original_text = raw_bytes.decode('utf-8')
|
||||
@@ -1049,7 +1193,10 @@ async def patch_file_content(
|
||||
_check_policy_size_limit(policy, len(new_bytes))
|
||||
|
||||
# 写入文件
|
||||
await storage_service.write_file(storage_path, new_bytes)
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
await storage_service.write_file(storage_path, new_bytes)
|
||||
elif policy.type == PolicyType.S3:
|
||||
await s3_service.upload_file(storage_path, new_bytes)
|
||||
|
||||
# 更新数据库
|
||||
owner_id = file_obj.owner_id
|
||||
|
||||
Reference in New Issue
Block a user