feat: add S3 storage support, policy migration, and quota enforcement
Some checks failed
Test / test (push) Failing after 2m21s
Some checks failed
Test / test (push) Failing after 2m21s
- Add S3StorageService with AWS Signature V4 signing (URI-encoded for non-ASCII keys)
- Add PATCH /object/{id}/policy endpoint for switching storage policies with background migration
- Implement cross-storage file migration service (local <-> S3)
- Replace deprecated StorageType enum with PolicyType (local/s3)
- Implement GET /user/settings/policies endpoint (was 501 stub)
- Add storage quota pre-allocation on upload session creation to prevent concurrent bypass
- Fix BigInteger for max_storage and user.storage to support >2GB values
- Add policy permission validation on upload and directory creation
- Use group's first policy as default on registration instead of hardcoded name
- Define TaskType.POLICY_MIGRATE and extend TaskProps with migration fields
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -8,11 +8,11 @@ from sqlmodel import Field
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from sqlmodels import (
|
||||
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
|
||||
ListResponse, Object,
|
||||
Policy, PolicyCreateRequest, PolicyOptions, PolicyType, PolicySummary,
|
||||
PolicyUpdateRequest, ResponseBase, ListResponse, Object,
|
||||
)
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
from service.storage import DirectoryCreationError, LocalStorageService
|
||||
from service.storage import DirectoryCreationError, LocalStorageService, S3StorageService
|
||||
|
||||
admin_policy_router = APIRouter(
|
||||
prefix='/policy',
|
||||
@@ -67,6 +67,12 @@ class PolicyDetailResponse(SQLModelBase):
|
||||
base_url: str | None
|
||||
"""基础URL"""
|
||||
|
||||
access_key: str | None
|
||||
"""Access Key"""
|
||||
|
||||
secret_key: str | None
|
||||
"""Secret Key"""
|
||||
|
||||
max_size: int
|
||||
"""最大文件尺寸"""
|
||||
|
||||
@@ -107,9 +113,45 @@ class PolicyTestSlaveRequest(SQLModelBase):
|
||||
secret: str
|
||||
"""从机通信密钥"""
|
||||
|
||||
class PolicyCreateRequest(PolicyBase):
|
||||
"""创建存储策略请求 DTO,继承 PolicyBase 中的所有字段"""
|
||||
pass
|
||||
class PolicyTestS3Request(SQLModelBase):
|
||||
"""测试 S3 连接请求 DTO"""
|
||||
|
||||
server: str = Field(max_length=255)
|
||||
"""S3 端点地址"""
|
||||
|
||||
bucket_name: str = Field(max_length=255)
|
||||
"""存储桶名称"""
|
||||
|
||||
access_key: str
|
||||
"""Access Key"""
|
||||
|
||||
secret_key: str
|
||||
"""Secret Key"""
|
||||
|
||||
s3_region: str = Field(default='us-east-1', max_length=64)
|
||||
"""S3 区域"""
|
||||
|
||||
s3_path_style: bool = False
|
||||
"""是否使用路径风格"""
|
||||
|
||||
|
||||
class PolicyTestS3Response(SQLModelBase):
|
||||
"""S3 连接测试响应"""
|
||||
|
||||
is_connected: bool
|
||||
"""连接是否成功"""
|
||||
|
||||
message: str
|
||||
"""测试结果消息"""
|
||||
|
||||
|
||||
# ==================== Options 字段集合(用于分离 Policy 与 Options 字段) ====================
|
||||
|
||||
_OPTIONS_FIELDS: set[str] = {
|
||||
'token', 'file_type', 'mimetype', 'od_redirect',
|
||||
'chunk_size', 's3_path_style', 's3_region',
|
||||
}
|
||||
|
||||
|
||||
@admin_policy_router.get(
|
||||
path='/list',
|
||||
@@ -277,7 +319,20 @@ async def router_policy_add_policy(
|
||||
raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}")
|
||||
|
||||
# 保存到数据库
|
||||
await policy.save(session)
|
||||
policy = await policy.save(session)
|
||||
|
||||
# 创建策略选项
|
||||
options = PolicyOptions(
|
||||
policy_id=policy.id,
|
||||
token=request.token,
|
||||
file_type=request.file_type,
|
||||
mimetype=request.mimetype,
|
||||
od_redirect=request.od_redirect,
|
||||
chunk_size=request.chunk_size,
|
||||
s3_path_style=request.s3_path_style,
|
||||
s3_region=request.s3_region,
|
||||
)
|
||||
await options.save(session)
|
||||
|
||||
@admin_policy_router.post(
|
||||
path='/cors',
|
||||
@@ -371,6 +426,8 @@ async def router_policy_get_policy(
|
||||
bucket_name=policy.bucket_name,
|
||||
is_private=policy.is_private,
|
||||
base_url=policy.base_url,
|
||||
access_key=policy.access_key,
|
||||
secret_key=policy.secret_key,
|
||||
max_size=policy.max_size,
|
||||
auto_rename=policy.auto_rename,
|
||||
dir_name_rule=policy.dir_name_rule,
|
||||
@@ -417,4 +474,108 @@ async def router_policy_delete_policy(
|
||||
policy_name = policy.name
|
||||
await Policy.delete(session, policy)
|
||||
|
||||
l.info(f"管理员删除了存储策略: {policy_name}")
|
||||
l.info(f"管理员删除了存储策略: {policy_name}")
|
||||
|
||||
|
||||
@admin_policy_router.patch(
|
||||
path='/{policy_id}',
|
||||
summary='更新存储策略',
|
||||
description='更新存储策略配置。策略类型创建后不可更改。',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_policy_update_policy(
|
||||
session: SessionDep,
|
||||
policy_id: UUID,
|
||||
request: PolicyUpdateRequest,
|
||||
) -> None:
|
||||
"""
|
||||
更新存储策略端点
|
||||
|
||||
功能:
|
||||
- 更新策略基础字段和扩展选项
|
||||
- 策略类型(type)不可更改
|
||||
|
||||
认证:
|
||||
- 需要管理员权限
|
||||
|
||||
:param session: 数据库会话
|
||||
:param policy_id: 存储策略UUID
|
||||
:param request: 更新请求
|
||||
"""
|
||||
policy = await Policy.get(session, Policy.id == policy_id, load=Policy.options)
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
|
||||
# 检查名称唯一性(如果要更新名称)
|
||||
if request.name and request.name != policy.name:
|
||||
existing = await Policy.get(session, Policy.name == request.name)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="策略名称已存在")
|
||||
|
||||
# 分离 Policy 字段和 Options 字段
|
||||
all_data = request.model_dump(exclude_unset=True)
|
||||
policy_data = {k: v for k, v in all_data.items() if k not in _OPTIONS_FIELDS}
|
||||
options_data = {k: v for k, v in all_data.items() if k in _OPTIONS_FIELDS}
|
||||
|
||||
# 更新 Policy 基础字段
|
||||
if policy_data:
|
||||
for key, value in policy_data.items():
|
||||
setattr(policy, key, value)
|
||||
policy = await policy.save(session)
|
||||
|
||||
# 更新或创建 PolicyOptions
|
||||
if options_data:
|
||||
if policy.options:
|
||||
for key, value in options_data.items():
|
||||
setattr(policy.options, key, value)
|
||||
await policy.options.save(session)
|
||||
else:
|
||||
options = PolicyOptions(policy_id=policy.id, **options_data)
|
||||
await options.save(session)
|
||||
|
||||
l.info(f"管理员更新了存储策略: {policy_id}")
|
||||
|
||||
|
||||
@admin_policy_router.post(
|
||||
path='/test/s3',
|
||||
summary='测试 S3 连接',
|
||||
description='测试 S3 存储端点的连通性和凭据有效性。',
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_policy_test_s3(
|
||||
request: PolicyTestS3Request,
|
||||
) -> PolicyTestS3Response:
|
||||
"""
|
||||
测试 S3 连接端点
|
||||
|
||||
通过向 S3 端点发送 HEAD Bucket 请求,验证凭据和网络连通性。
|
||||
|
||||
:param request: 测试请求
|
||||
:return: 测试结果
|
||||
"""
|
||||
from service.storage import S3APIError
|
||||
|
||||
# 构造临时 Policy 对象用于创建 S3StorageService
|
||||
temp_policy = Policy(
|
||||
name="__test__",
|
||||
type=PolicyType.S3,
|
||||
server=request.server,
|
||||
bucket_name=request.bucket_name,
|
||||
access_key=request.access_key,
|
||||
secret_key=request.secret_key,
|
||||
)
|
||||
s3_service = S3StorageService(
|
||||
temp_policy,
|
||||
region=request.s3_region,
|
||||
is_path_style=request.s3_path_style,
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用 file_exists 发送 HEAD 请求来验证连通性
|
||||
await s3_service.file_exists("__connection_test__")
|
||||
return PolicyTestS3Response(is_connected=True, message="连接成功")
|
||||
except S3APIError as e:
|
||||
return PolicyTestS3Response(is_connected=False, message=f"S3 API 错误: {e}")
|
||||
except Exception as e:
|
||||
return PolicyTestS3Response(is_connected=False, message=f"连接失败: {e}")
|
||||
@@ -57,7 +57,7 @@ async def _get_directory_response(
|
||||
policy_response = PolicyResponse(
|
||||
id=policy.id,
|
||||
name=policy.name,
|
||||
type=policy.type.value,
|
||||
type=policy.type,
|
||||
max_size=policy.max_size,
|
||||
)
|
||||
|
||||
@@ -189,6 +189,14 @@ async def router_directory_create(
|
||||
raise HTTPException(status_code=409, detail="同名文件或目录已存在")
|
||||
|
||||
policy_id = request.policy_id if request.policy_id else parent.policy_id
|
||||
|
||||
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
|
||||
if request.policy_id:
|
||||
group = await user.awaitable_attrs.group
|
||||
await session.refresh(group, ['policies'])
|
||||
if request.policy_id not in {p.id for p in group.policies}:
|
||||
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
|
||||
|
||||
parent_id = parent.id # 在 save 前保存
|
||||
|
||||
new_folder = Object(
|
||||
|
||||
@@ -13,9 +13,11 @@ from datetime import datetime, timedelta
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
import orjson
|
||||
import whatthepatch
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
from starlette.responses import Response
|
||||
from loguru import logger as l
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
from whatthepatch.exceptions import HunkApplyException
|
||||
@@ -44,7 +46,9 @@ from sqlmodels import (
|
||||
User,
|
||||
WopiSessionResponse,
|
||||
)
|
||||
from service.storage import LocalStorageService, adjust_user_storage
|
||||
import orjson
|
||||
|
||||
from service.storage import LocalStorageService, S3StorageService, adjust_user_storage
|
||||
from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL
|
||||
from utils.JWT.wopi_token import create_wopi_token
|
||||
from utils import http_exceptions
|
||||
@@ -184,6 +188,13 @@ async def create_upload_session(
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
|
||||
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
|
||||
if request.policy_id:
|
||||
group = await user.awaitable_attrs.group
|
||||
await session.refresh(group, ['policies'])
|
||||
if request.policy_id not in {p.id for p in group.policies}:
|
||||
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
|
||||
|
||||
# 验证文件大小限制
|
||||
_check_policy_size_limit(policy, request.file_size)
|
||||
|
||||
@@ -210,6 +221,7 @@ async def create_upload_session(
|
||||
|
||||
# 生成存储路径
|
||||
storage_path: str | None = None
|
||||
s3_upload_id: str | None = None
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
dir_path, storage_name, full_path = await storage_service.generate_file_path(
|
||||
@@ -217,8 +229,25 @@ async def create_upload_session(
|
||||
original_filename=request.file_name,
|
||||
)
|
||||
storage_path = full_path
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = S3StorageService(
|
||||
policy,
|
||||
region=options.s3_region if options else 'us-east-1',
|
||||
is_path_style=options.s3_path_style if options else False,
|
||||
)
|
||||
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
|
||||
user_id=user.id,
|
||||
original_filename=request.file_name,
|
||||
)
|
||||
# 多分片时创建 multipart upload
|
||||
if total_chunks > 1:
|
||||
s3_upload_id = await s3_service.create_multipart_upload(
|
||||
storage_path, content_type='application/octet-stream',
|
||||
)
|
||||
|
||||
# 预扣存储空间(与创建会话在同一事务中提交,防止并发绕过配额)
|
||||
if request.file_size > 0:
|
||||
await adjust_user_storage(session, user.id, request.file_size, commit=False)
|
||||
|
||||
# 创建上传会话
|
||||
upload_session = UploadSession(
|
||||
@@ -227,6 +256,7 @@ async def create_upload_session(
|
||||
chunk_size=chunk_size,
|
||||
total_chunks=total_chunks,
|
||||
storage_path=storage_path,
|
||||
s3_upload_id=s3_upload_id,
|
||||
expires_at=datetime.now() + timedelta(hours=24),
|
||||
owner_id=user.id,
|
||||
parent_id=request.parent_id,
|
||||
@@ -302,8 +332,38 @@ async def upload_chunk(
|
||||
content,
|
||||
offset,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
elif policy.type == PolicyType.S3:
|
||||
if not upload_session.storage_path:
|
||||
raise HTTPException(status_code=500, detail="存储路径丢失")
|
||||
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
|
||||
if upload_session.total_chunks == 1:
|
||||
# 单分片:直接 PUT 上传
|
||||
await s3_service.upload_file(upload_session.storage_path, content)
|
||||
else:
|
||||
# 多分片:UploadPart
|
||||
if not upload_session.s3_upload_id:
|
||||
raise HTTPException(status_code=500, detail="S3 分片上传 ID 丢失")
|
||||
|
||||
etag = await s3_service.upload_part(
|
||||
upload_session.storage_path,
|
||||
upload_session.s3_upload_id,
|
||||
chunk_index + 1, # S3 part number 从 1 开始
|
||||
content,
|
||||
)
|
||||
# 追加 ETag 到 s3_part_etags
|
||||
etags: list[list[int | str]] = orjson.loads(upload_session.s3_part_etags or '[]')
|
||||
etags.append([chunk_index + 1, etag])
|
||||
upload_session.s3_part_etags = orjson.dumps(etags).decode()
|
||||
|
||||
# 在 save(commit)前缓存后续需要的属性(commit 后 ORM 对象会过期)
|
||||
policy_type = policy.type
|
||||
s3_upload_id = upload_session.s3_upload_id
|
||||
s3_part_etags = upload_session.s3_part_etags
|
||||
s3_service_for_complete: S3StorageService | None = None
|
||||
if policy_type == PolicyType.S3:
|
||||
s3_service_for_complete = await S3StorageService.from_policy(policy)
|
||||
|
||||
# 更新会话进度
|
||||
upload_session.uploaded_chunks += 1
|
||||
@@ -319,12 +379,26 @@ async def upload_chunk(
|
||||
if is_complete:
|
||||
# 保存 upload_session 属性(commit 后会过期)
|
||||
file_name = upload_session.file_name
|
||||
file_size = upload_session.file_size
|
||||
uploaded_size = upload_session.uploaded_size
|
||||
storage_path = upload_session.storage_path
|
||||
upload_session_id = upload_session.id
|
||||
parent_id = upload_session.parent_id
|
||||
policy_id = upload_session.policy_id
|
||||
|
||||
# S3 多分片上传完成:合并分片
|
||||
if (
|
||||
policy_type == PolicyType.S3
|
||||
and s3_upload_id
|
||||
and s3_part_etags
|
||||
and s3_service_for_complete
|
||||
):
|
||||
parts_data: list[list[int | str]] = orjson.loads(s3_part_etags)
|
||||
parts = [(int(pn), str(et)) for pn, et in parts_data]
|
||||
await s3_service_for_complete.complete_multipart_upload(
|
||||
storage_path, s3_upload_id, parts,
|
||||
)
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
physical_file = PhysicalFile(
|
||||
storage_path=storage_path,
|
||||
@@ -355,9 +429,10 @@ async def upload_chunk(
|
||||
commit=False
|
||||
)
|
||||
|
||||
# 更新用户存储配额
|
||||
if uploaded_size > 0:
|
||||
await adjust_user_storage(session, user_id, uploaded_size, commit=False)
|
||||
# 调整存储配额差值(创建会话时已预扣 file_size,这里只补差)
|
||||
size_diff = uploaded_size - file_size
|
||||
if size_diff != 0:
|
||||
await adjust_user_storage(session, user_id, size_diff, commit=False)
|
||||
|
||||
# 统一提交所有更改
|
||||
await session.commit()
|
||||
@@ -390,9 +465,25 @@ async def delete_upload_session(
|
||||
|
||||
# 删除临时文件
|
||||
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
|
||||
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(upload_session.storage_path)
|
||||
if policy and upload_session.storage_path:
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(upload_session.storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
# 如果有分片上传,先取消
|
||||
if upload_session.s3_upload_id:
|
||||
await s3_service.abort_multipart_upload(
|
||||
upload_session.storage_path, upload_session.s3_upload_id,
|
||||
)
|
||||
else:
|
||||
# 单分片上传已完成的话,删除已上传的文件
|
||||
if upload_session.uploaded_chunks > 0:
|
||||
await s3_service.delete_file(upload_session.storage_path)
|
||||
|
||||
# 释放预扣的存储空间
|
||||
if upload_session.file_size > 0:
|
||||
await adjust_user_storage(session, user.id, -upload_session.file_size)
|
||||
|
||||
# 删除会话记录
|
||||
await UploadSession.delete(session, upload_session)
|
||||
@@ -422,9 +513,22 @@ async def clear_upload_sessions(
|
||||
for upload_session in sessions:
|
||||
# 删除临时文件
|
||||
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
|
||||
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(upload_session.storage_path)
|
||||
if policy and upload_session.storage_path:
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(upload_session.storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
if upload_session.s3_upload_id:
|
||||
await s3_service.abort_multipart_upload(
|
||||
upload_session.storage_path, upload_session.s3_upload_id,
|
||||
)
|
||||
elif upload_session.uploaded_chunks > 0:
|
||||
await s3_service.delete_file(upload_session.storage_path)
|
||||
|
||||
# 释放预扣的存储空间
|
||||
if upload_session.file_size > 0:
|
||||
await adjust_user_storage(session, user.id, -upload_session.file_size)
|
||||
|
||||
await UploadSession.delete(session, upload_session)
|
||||
deleted_count += 1
|
||||
@@ -486,11 +590,12 @@ async def create_download_token_endpoint(
|
||||
path='/{token}',
|
||||
summary='下载文件',
|
||||
description='使用下载令牌下载文件,令牌在有效期内可重复使用。',
|
||||
response_model=None,
|
||||
)
|
||||
async def download_file(
|
||||
session: SessionDep,
|
||||
token: str,
|
||||
) -> FileResponse:
|
||||
) -> Response:
|
||||
"""
|
||||
下载文件端点
|
||||
|
||||
@@ -540,8 +645,15 @@ async def download_file(
|
||||
filename=file_obj.name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
# 302 重定向到预签名 URL
|
||||
presigned_url = s3_service.generate_presigned_url(
|
||||
storage_path, method='GET', expires_in=3600, filename=file_obj.name,
|
||||
)
|
||||
return RedirectResponse(url=presigned_url, status_code=302)
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
raise HTTPException(status_code=500, detail="不支持的存储类型")
|
||||
|
||||
|
||||
# ==================== 包含子路由 ====================
|
||||
@@ -613,8 +725,13 @@ async def create_empty_file(
|
||||
)
|
||||
await storage_service.create_empty_file(full_path)
|
||||
storage_path = full_path
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
|
||||
user_id=user_id,
|
||||
original_filename=request.name,
|
||||
)
|
||||
await s3_service.upload_file(storage_path, b'')
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
physical_file = PhysicalFile(
|
||||
@@ -798,12 +915,13 @@ async def _validate_source_link(
|
||||
path='/get/{file_id}/{name}',
|
||||
summary='文件外链(直接输出文件数据)',
|
||||
description='通过外链直接获取文件内容,公开访问无需认证。',
|
||||
response_model=None,
|
||||
)
|
||||
async def file_get(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
name: str,
|
||||
) -> FileResponse:
|
||||
) -> Response:
|
||||
"""
|
||||
文件外链端点(直接输出)
|
||||
|
||||
@@ -815,13 +933,6 @@ async def file_get(
|
||||
"""
|
||||
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
if not await storage_service.file_exists(physical_file.storage_path):
|
||||
http_exceptions.raise_not_found("物理文件不存在")
|
||||
|
||||
# 缓存物理路径(save 后对象属性会过期)
|
||||
file_path = physical_file.storage_path
|
||||
|
||||
@@ -829,11 +940,25 @@ async def file_get(
|
||||
link.downloads += 1
|
||||
await link.save(session)
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
if not await storage_service.file_exists(file_path):
|
||||
http_exceptions.raise_not_found("物理文件不存在")
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
elif policy.type == PolicyType.S3:
|
||||
# S3 外链直接输出:302 重定向到预签名 URL
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
presigned_url = s3_service.generate_presigned_url(
|
||||
file_path, method='GET', expires_in=3600, filename=name,
|
||||
)
|
||||
return RedirectResponse(url=presigned_url, status_code=302)
|
||||
else:
|
||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -846,7 +971,7 @@ async def file_source_redirect(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
name: str,
|
||||
) -> FileResponse | RedirectResponse:
|
||||
) -> Response:
|
||||
"""
|
||||
文件外链端点(重定向/直接输出)
|
||||
|
||||
@@ -860,13 +985,6 @@ async def file_source_redirect(
|
||||
"""
|
||||
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
if not await storage_service.file_exists(physical_file.storage_path):
|
||||
http_exceptions.raise_not_found("物理文件不存在")
|
||||
|
||||
# 缓存所有需要的值(save 后对象属性会过期)
|
||||
file_path = physical_file.storage_path
|
||||
is_private = policy.is_private
|
||||
@@ -876,18 +994,36 @@ async def file_source_redirect(
|
||||
link.downloads += 1
|
||||
await link.save(session)
|
||||
|
||||
# 公有存储:302 重定向到 base_url
|
||||
if not is_private and base_url:
|
||||
relative_path = storage_service.get_relative_path(file_path)
|
||||
redirect_url = f"{base_url}/{relative_path}"
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
if not await storage_service.file_exists(file_path):
|
||||
http_exceptions.raise_not_found("物理文件不存在")
|
||||
|
||||
# 私有存储或 base_url 为空:通过应用代理文件
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
# 公有存储:302 重定向到 base_url
|
||||
if not is_private and base_url:
|
||||
relative_path = storage_service.get_relative_path(file_path)
|
||||
redirect_url = f"{base_url}/{relative_path}"
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
||||
|
||||
# 私有存储或 base_url 为空:通过应用代理文件
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
# 公有存储且有 base_url:直接重定向到公开 URL
|
||||
if not is_private and base_url:
|
||||
redirect_url = f"{base_url.rstrip('/')}/{file_path}"
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
||||
# 私有存储:生成预签名 URL 重定向
|
||||
presigned_url = s3_service.generate_presigned_url(
|
||||
file_path, method='GET', expires_in=3600, filename=name,
|
||||
)
|
||||
return RedirectResponse(url=presigned_url, status_code=302)
|
||||
else:
|
||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
||||
|
||||
|
||||
@router.put(
|
||||
@@ -941,11 +1077,15 @@ async def file_content(
|
||||
if not policy:
|
||||
http_exceptions.raise_internal_error("存储策略不存在")
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(physical_file.storage_path)
|
||||
# 读取文件内容
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(physical_file.storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
raw_bytes = await s3_service.download_file(physical_file.storage_path)
|
||||
else:
|
||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
||||
|
||||
try:
|
||||
content = raw_bytes.decode('utf-8')
|
||||
@@ -1011,11 +1151,15 @@ async def patch_file_content(
|
||||
if not policy:
|
||||
http_exceptions.raise_internal_error("存储策略不存在")
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(storage_path)
|
||||
# 读取文件内容
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
raw_bytes = await s3_service.download_file(storage_path)
|
||||
else:
|
||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
||||
|
||||
# 解码 + 规范化
|
||||
original_text = raw_bytes.decode('utf-8')
|
||||
@@ -1049,7 +1193,10 @@ async def patch_file_content(
|
||||
_check_policy_size_limit(policy, len(new_bytes))
|
||||
|
||||
# 写入文件
|
||||
await storage_service.write_file(storage_path, new_bytes)
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
await storage_service.write_file(storage_path, new_bytes)
|
||||
elif policy.type == PolicyType.S3:
|
||||
await s3_service.upload_file(storage_path, new_bytes)
|
||||
|
||||
# 更新数据库
|
||||
owner_id = file_obj.owner_id
|
||||
|
||||
@@ -8,13 +8,14 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
CreateFileRequest,
|
||||
Group,
|
||||
Object,
|
||||
ObjectCopyRequest,
|
||||
ObjectDeleteRequest,
|
||||
@@ -22,18 +23,27 @@ from sqlmodels import (
|
||||
ObjectPropertyDetailResponse,
|
||||
ObjectPropertyResponse,
|
||||
ObjectRenameRequest,
|
||||
ObjectSwitchPolicyRequest,
|
||||
ObjectType,
|
||||
PhysicalFile,
|
||||
Policy,
|
||||
PolicyType,
|
||||
Task,
|
||||
TaskProps,
|
||||
TaskStatus,
|
||||
TaskSummaryBase,
|
||||
TaskType,
|
||||
User,
|
||||
)
|
||||
from service.storage import (
|
||||
LocalStorageService,
|
||||
adjust_user_storage,
|
||||
copy_object_recursive,
|
||||
migrate_file_with_task,
|
||||
migrate_directory_files,
|
||||
)
|
||||
from service.storage.object import soft_delete_objects
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from utils import http_exceptions
|
||||
|
||||
object_router = APIRouter(
|
||||
@@ -575,3 +585,136 @@ async def router_object_property_detail(
|
||||
response.checksum_md5 = obj.file_metadata.checksum_md5
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@object_router.patch(
|
||||
path='/{object_id}/policy',
|
||||
summary='切换对象存储策略',
|
||||
)
|
||||
async def router_object_switch_policy(
|
||||
session: SessionDep,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
object_id: UUID,
|
||||
request: ObjectSwitchPolicyRequest,
|
||||
) -> TaskSummaryBase:
|
||||
"""
|
||||
切换对象的存储策略
|
||||
|
||||
文件:立即创建后台迁移任务,将文件从源策略搬到目标策略。
|
||||
目录:更新目录 policy_id(新文件使用新策略);
|
||||
若 is_migrate_existing=True,额外创建后台任务迁移所有已有文件。
|
||||
|
||||
认证:JWT Bearer Token
|
||||
|
||||
错误处理:
|
||||
- 404: 对象不存在
|
||||
- 403: 无权操作此对象 / 用户组无权使用目标策略
|
||||
- 400: 目标策略与当前相同 / 不能对根目录操作
|
||||
"""
|
||||
user_id = user.id
|
||||
|
||||
# 查找对象
|
||||
obj = await Object.get(
|
||||
session,
|
||||
(Object.id == object_id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not obj:
|
||||
http_exceptions.raise_not_found("对象不存在")
|
||||
if obj.owner_id != user_id:
|
||||
http_exceptions.raise_forbidden("无权操作此对象")
|
||||
if obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 根目录不能直接切换策略(应通过子对象或子目录操作)
|
||||
if obj.parent_id is None:
|
||||
raise HTTPException(status_code=400, detail="不能对根目录切换存储策略,请对子目录操作")
|
||||
|
||||
# 校验目标策略存在
|
||||
dest_policy = await Policy.get(session, Policy.id == request.policy_id)
|
||||
if not dest_policy:
|
||||
http_exceptions.raise_not_found("目标存储策略不存在")
|
||||
|
||||
# 校验用户组权限
|
||||
group: Group = await user.awaitable_attrs.group
|
||||
await session.refresh(group, ['policies'])
|
||||
allowed_ids = {p.id for p in group.policies}
|
||||
if request.policy_id not in allowed_ids:
|
||||
http_exceptions.raise_forbidden("当前用户组无权使用该存储策略")
|
||||
|
||||
# 不能切换到相同策略
|
||||
if obj.policy_id == request.policy_id:
|
||||
raise HTTPException(status_code=400, detail="目标策略与当前策略相同")
|
||||
|
||||
# 保存必要的属性,避免 save 后对象过期
|
||||
src_policy_id = obj.policy_id
|
||||
obj_id = obj.id
|
||||
obj_is_file = obj.type == ObjectType.FILE
|
||||
dest_policy_id = request.policy_id
|
||||
dest_policy_name = dest_policy.name
|
||||
|
||||
# 创建任务记录
|
||||
task = Task(
|
||||
type=TaskType.POLICY_MIGRATE,
|
||||
status=TaskStatus.QUEUED,
|
||||
user_id=user_id,
|
||||
)
|
||||
task = await task.save(session)
|
||||
task_id = task.id
|
||||
|
||||
task_props = TaskProps(
|
||||
task_id=task_id,
|
||||
source_policy_id=src_policy_id,
|
||||
dest_policy_id=dest_policy_id,
|
||||
object_id=obj_id,
|
||||
)
|
||||
await task_props.save(session)
|
||||
|
||||
if obj_is_file:
|
||||
# 文件:后台迁移
|
||||
async def _run_file_migration() -> None:
|
||||
async with DatabaseManager.session() as bg_session:
|
||||
bg_obj = await Object.get(bg_session, Object.id == obj_id)
|
||||
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
|
||||
bg_task = await Task.get(bg_session, Task.id == task_id)
|
||||
await migrate_file_with_task(bg_session, bg_obj, bg_policy, bg_task)
|
||||
|
||||
background_tasks.add_task(_run_file_migration)
|
||||
else:
|
||||
# 目录:先更新目录自身的 policy_id
|
||||
obj = await Object.get(session, Object.id == obj_id)
|
||||
obj.policy_id = dest_policy_id
|
||||
await obj.save(session)
|
||||
|
||||
if request.is_migrate_existing:
|
||||
# 后台迁移所有已有文件
|
||||
async def _run_dir_migration() -> None:
|
||||
async with DatabaseManager.session() as bg_session:
|
||||
bg_folder = await Object.get(bg_session, Object.id == obj_id)
|
||||
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
|
||||
bg_task = await Task.get(bg_session, Task.id == task_id)
|
||||
await migrate_directory_files(bg_session, bg_folder, bg_policy, bg_task)
|
||||
|
||||
background_tasks.add_task(_run_dir_migration)
|
||||
else:
|
||||
# 不迁移已有文件,直接完成任务
|
||||
task = await Task.get(session, Task.id == task_id)
|
||||
task.status = TaskStatus.COMPLETED
|
||||
task.progress = 100
|
||||
await task.save(session)
|
||||
|
||||
# 重新获取 task 以读取最新状态
|
||||
task = await Task.get(session, Task.id == task_id)
|
||||
|
||||
l.info(f"用户 {user_id} 请求切换对象 {obj_id} 存储策略 → {dest_policy_name}")
|
||||
|
||||
return TaskSummaryBase(
|
||||
id=task.id,
|
||||
type=task.type,
|
||||
status=task.status,
|
||||
progress=task.progress,
|
||||
error=task.error,
|
||||
user_id=task.user_id,
|
||||
created_at=task.created_at,
|
||||
updated_at=task.updated_at,
|
||||
)
|
||||
|
||||
@@ -247,11 +247,12 @@ async def router_user_register(
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
# 8. 创建用户根目录
|
||||
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
|
||||
if not default_policy:
|
||||
logger.error("默认存储策略不存在")
|
||||
# 8. 创建用户根目录(使用用户组关联的第一个存储策略)
|
||||
await session.refresh(default_group, ['policies'])
|
||||
if not default_group.policies:
|
||||
logger.error("默认用户组未关联任何存储策略")
|
||||
http_exceptions.raise_internal_error()
|
||||
default_policy = default_group.policies[0]
|
||||
|
||||
await sqlmodels.Object(
|
||||
name="/",
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlmodels import (
|
||||
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
|
||||
ChangePasswordRequest,
|
||||
AuthnDetailResponse, AuthnRenameRequest,
|
||||
PolicySummary,
|
||||
)
|
||||
from sqlmodels.color import ThemeColorsBase
|
||||
from sqlmodels.user_authn import UserAuthn
|
||||
@@ -31,16 +32,25 @@ user_settings_router.include_router(file_viewers_router)
|
||||
@user_settings_router.get(
|
||||
path='/policies',
|
||||
summary='获取用户可选存储策略',
|
||||
description='Get user selectable storage policies.',
|
||||
)
|
||||
def router_user_settings_policies() -> sqlmodels.ResponseBase:
|
||||
async def router_user_settings_policies(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> list[PolicySummary]:
|
||||
"""
|
||||
Get user selectable storage policies.
|
||||
获取当前用户所在组可选的存储策略列表
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available storage policies for the user.
|
||||
返回用户组关联的所有存储策略的摘要信息。
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
group = await user.awaitable_attrs.group
|
||||
await session.refresh(group, ['policies'])
|
||||
return [
|
||||
PolicySummary(
|
||||
id=p.id, name=p.name, type=p.type,
|
||||
server=p.server, max_size=p.max_size, is_private=p.is_private,
|
||||
)
|
||||
for p in group.policies
|
||||
]
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
|
||||
Reference in New Issue
Block a user