diff --git a/main.py b/main.py index 83a6a58..25b24e0 100644 --- a/main.py +++ b/main.py @@ -8,6 +8,7 @@ from routers import router from routers.dav import dav_app from routers.dav.provider import EventLoopRef from service.redis import RedisManager +from service.storage import S3StorageService from sqlmodels.database_connection import DatabaseManager from sqlmodels.migration import migration from utils import JWT @@ -50,8 +51,10 @@ lifespan.add_startup(_init_db) lifespan.add_startup(migration) lifespan.add_startup(JWT.load_secret_key) lifespan.add_startup(RedisManager.connect) +lifespan.add_startup(S3StorageService.initialize_session) # 添加关闭项 +lifespan.add_shutdown(S3StorageService.close_session) lifespan.add_shutdown(DatabaseManager.close) lifespan.add_shutdown(RedisManager.disconnect) diff --git a/routers/api/v1/admin/policy/__init__.py b/routers/api/v1/admin/policy/__init__.py index 3eaf0ad..8746d73 100644 --- a/routers/api/v1/admin/policy/__init__.py +++ b/routers/api/v1/admin/policy/__init__.py @@ -8,11 +8,11 @@ from sqlmodel import Field from middleware.auth import admin_required from middleware.dependencies import SessionDep, TableViewRequestDep from sqlmodels import ( - Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase, - ListResponse, Object, + Policy, PolicyCreateRequest, PolicyOptions, PolicyType, PolicySummary, + PolicyUpdateRequest, ResponseBase, ListResponse, Object, ) from sqlmodel_ext import SQLModelBase -from service.storage import DirectoryCreationError, LocalStorageService +from service.storage import DirectoryCreationError, LocalStorageService, S3StorageService admin_policy_router = APIRouter( prefix='/policy', @@ -67,6 +67,12 @@ class PolicyDetailResponse(SQLModelBase): base_url: str | None """基础URL""" + access_key: str | None + """Access Key""" + + secret_key: str | None + """Secret Key""" + max_size: int """最大文件尺寸""" @@ -107,9 +113,45 @@ class PolicyTestSlaveRequest(SQLModelBase): secret: str """从机通信密钥""" -class PolicyCreateRequest(PolicyBase): - """创建存储策略请求 DTO,继承 PolicyBase 中的所有字段""" - pass +class PolicyTestS3Request(SQLModelBase): + """测试 S3 连接请求 DTO""" + + server: str = Field(max_length=255) + """S3 端点地址""" + + bucket_name: str = Field(max_length=255) + """存储桶名称""" + + access_key: str + """Access Key""" + + secret_key: str + """Secret Key""" + + s3_region: str = Field(default='us-east-1', max_length=64) + """S3 区域""" + + s3_path_style: bool = False + """是否使用路径风格""" + + +class PolicyTestS3Response(SQLModelBase): + """S3 连接测试响应""" + + is_connected: bool + """连接是否成功""" + + message: str + """测试结果消息""" + + +# ==================== Options 字段集合(用于分离 Policy 与 Options 字段) ==================== + +_OPTIONS_FIELDS: set[str] = { + 'token', 'file_type', 'mimetype', 'od_redirect', + 'chunk_size', 's3_path_style', 's3_region', +} + @admin_policy_router.get( path='/list', @@ -277,7 +319,20 @@ async def router_policy_add_policy( raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}") # 保存到数据库 - await policy.save(session) + policy = await policy.save(session) + + # 创建策略选项 + options = PolicyOptions( + policy_id=policy.id, + token=request.token, + file_type=request.file_type, + mimetype=request.mimetype, + od_redirect=request.od_redirect, + chunk_size=request.chunk_size, + s3_path_style=request.s3_path_style, + s3_region=request.s3_region, + ) + await options.save(session) @admin_policy_router.post( path='/cors', @@ -371,6 +426,8 @@ async def router_policy_get_policy( bucket_name=policy.bucket_name, is_private=policy.is_private, base_url=policy.base_url, + access_key=policy.access_key, + secret_key=policy.secret_key, max_size=policy.max_size, auto_rename=policy.auto_rename, dir_name_rule=policy.dir_name_rule, @@ -417,4 +474,108 @@ async def router_policy_delete_policy( policy_name = policy.name await Policy.delete(session, policy) - l.info(f"管理员删除了存储策略: {policy_name}") \ No newline at end of file + l.info(f"管理员删除了存储策略: {policy_name}") + + +@admin_policy_router.patch( + path='/{policy_id}', + summary='更新存储策略', + description='更新存储策略配置。策略类型创建后不可更改。', + dependencies=[Depends(admin_required)], + status_code=204, +) +async def router_policy_update_policy( + session: SessionDep, + policy_id: UUID, + request: PolicyUpdateRequest, +) -> None: + """ + 更新存储策略端点 + + 功能: + - 更新策略基础字段和扩展选项 + - 策略类型(type)不可更改 + + 认证: + - 需要管理员权限 + + :param session: 数据库会话 + :param policy_id: 存储策略UUID + :param request: 更新请求 + """ + policy = await Policy.get(session, Policy.id == policy_id, load=Policy.options) + if not policy: + raise HTTPException(status_code=404, detail="存储策略不存在") + + # 检查名称唯一性(如果要更新名称) + if request.name and request.name != policy.name: + existing = await Policy.get(session, Policy.name == request.name) + if existing: + raise HTTPException(status_code=409, detail="策略名称已存在") + + # 分离 Policy 字段和 Options 字段 + all_data = request.model_dump(exclude_unset=True) + policy_data = {k: v for k, v in all_data.items() if k not in _OPTIONS_FIELDS} + options_data = {k: v for k, v in all_data.items() if k in _OPTIONS_FIELDS} + + # 更新 Policy 基础字段 + if policy_data: + for key, value in policy_data.items(): + setattr(policy, key, value) + policy = await policy.save(session) + + # 更新或创建 PolicyOptions + if options_data: + if policy.options: + for key, value in options_data.items(): + setattr(policy.options, key, value) + await policy.options.save(session) + else: + options = PolicyOptions(policy_id=policy.id, **options_data) + await options.save(session) + + l.info(f"管理员更新了存储策略: {policy_id}") + + +@admin_policy_router.post( + path='/test/s3', + summary='测试 S3 连接', + description='测试 S3 存储端点的连通性和凭据有效性。', + dependencies=[Depends(admin_required)], +) +async def router_policy_test_s3( + request: PolicyTestS3Request, +) -> PolicyTestS3Response: + """ + 测试 S3 连接端点 + + 通过向 S3 端点发送 HEAD Bucket 请求,验证凭据和网络连通性。 + + :param request: 测试请求 + :return: 测试结果 + """ + from service.storage import S3APIError + + # 构造临时 Policy 对象用于创建 S3StorageService + temp_policy = Policy( + name="__test__", + type=PolicyType.S3, + server=request.server, + bucket_name=request.bucket_name, + access_key=request.access_key, + secret_key=request.secret_key, + ) + s3_service = S3StorageService( + temp_policy, + region=request.s3_region, + is_path_style=request.s3_path_style, + ) + + try: + # 使用 file_exists 发送 HEAD 请求来验证连通性 + await s3_service.file_exists("__connection_test__") + return PolicyTestS3Response(is_connected=True, message="连接成功") + except S3APIError as e: + return PolicyTestS3Response(is_connected=False, message=f"S3 API 错误: {e}") + except Exception as e: + return PolicyTestS3Response(is_connected=False, message=f"连接失败: {e}") \ No newline at end of file diff --git a/routers/api/v1/directory/__init__.py b/routers/api/v1/directory/__init__.py index c703137..ca7218f 100644 --- a/routers/api/v1/directory/__init__.py +++ b/routers/api/v1/directory/__init__.py @@ -57,7 +57,7 @@ async def _get_directory_response( policy_response = PolicyResponse( id=policy.id, name=policy.name, - type=policy.type.value, + type=policy.type, max_size=policy.max_size, ) @@ -189,6 +189,14 @@ async def router_directory_create( raise HTTPException(status_code=409, detail="同名文件或目录已存在") policy_id = request.policy_id if request.policy_id else parent.policy_id + + # 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时) + if request.policy_id: + group = await user.awaitable_attrs.group + await session.refresh(group, ['policies']) + if request.policy_id not in {p.id for p in group.policies}: + raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略") + parent_id = parent.id # 在 save 前保存 new_folder = Object( diff --git a/routers/api/v1/file/__init__.py b/routers/api/v1/file/__init__.py index ddb1c05..72e54f2 100644 --- a/routers/api/v1/file/__init__.py +++ b/routers/api/v1/file/__init__.py @@ -13,9 +13,11 @@ from datetime import datetime, timedelta from typing import Annotated from uuid import UUID +import orjson import whatthepatch from fastapi import APIRouter, Depends, File, HTTPException, UploadFile from fastapi.responses import FileResponse, RedirectResponse +from starlette.responses import Response from loguru import logger as l from sqlmodel_ext import SQLModelBase from whatthepatch.exceptions import HunkApplyException @@ -44,7 +46,9 @@ from sqlmodels import ( User, WopiSessionResponse, ) -from service.storage import LocalStorageService, adjust_user_storage +import orjson + +from service.storage import LocalStorageService, S3StorageService, adjust_user_storage from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL from utils.JWT.wopi_token import create_wopi_token from utils import http_exceptions @@ -184,6 +188,13 @@ async def create_upload_session( if not policy: raise HTTPException(status_code=404, detail="存储策略不存在") + # 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时) + if request.policy_id: + group = await user.awaitable_attrs.group + await session.refresh(group, ['policies']) + if request.policy_id not in {p.id for p in group.policies}: + raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略") + # 验证文件大小限制 _check_policy_size_limit(policy, request.file_size) @@ -210,6 +221,7 @@ async def create_upload_session( # 生成存储路径 storage_path: str | None = None + s3_upload_id: str | None = None if policy.type == PolicyType.LOCAL: storage_service = LocalStorageService(policy) dir_path, storage_name, full_path = await storage_service.generate_file_path( @@ -217,8 +229,25 @@ async def create_upload_session( original_filename=request.file_name, ) storage_path = full_path - else: - raise HTTPException(status_code=501, detail="S3 存储暂未实现") + elif policy.type == PolicyType.S3: + s3_service = S3StorageService( + policy, + region=options.s3_region if options else 'us-east-1', + is_path_style=options.s3_path_style if options else False, + ) + dir_path, storage_name, storage_path = await s3_service.generate_file_path( + user_id=user.id, + original_filename=request.file_name, + ) + # 多分片时创建 multipart upload + if total_chunks > 1: + s3_upload_id = await s3_service.create_multipart_upload( + storage_path, content_type='application/octet-stream', + ) + + # 预扣存储空间(与创建会话在同一事务中提交,防止并发绕过配额) + if request.file_size > 0: + await adjust_user_storage(session, user.id, request.file_size, commit=False) # 创建上传会话 upload_session = UploadSession( @@ -227,6 +256,7 @@ async def create_upload_session( chunk_size=chunk_size, total_chunks=total_chunks, storage_path=storage_path, + s3_upload_id=s3_upload_id, expires_at=datetime.now() + timedelta(hours=24), owner_id=user.id, parent_id=request.parent_id, @@ -302,8 +332,38 @@ async def upload_chunk( content, offset, ) - else: - raise HTTPException(status_code=501, detail="S3 存储暂未实现") + elif policy.type == PolicyType.S3: + if not upload_session.storage_path: + raise HTTPException(status_code=500, detail="存储路径丢失") + + s3_service = await S3StorageService.from_policy(policy) + + if upload_session.total_chunks == 1: + # 单分片:直接 PUT 上传 + await s3_service.upload_file(upload_session.storage_path, content) + else: + # 多分片:UploadPart + if not upload_session.s3_upload_id: + raise HTTPException(status_code=500, detail="S3 分片上传 ID 丢失") + + etag = await s3_service.upload_part( + upload_session.storage_path, + upload_session.s3_upload_id, + chunk_index + 1, # S3 part number 从 1 开始 + content, + ) + # 追加 ETag 到 s3_part_etags + etags: list[list[int | str]] = orjson.loads(upload_session.s3_part_etags or '[]') + etags.append([chunk_index + 1, etag]) + upload_session.s3_part_etags = orjson.dumps(etags).decode() + + # 在 save(commit)前缓存后续需要的属性(commit 后 ORM 对象会过期) + policy_type = policy.type + s3_upload_id = upload_session.s3_upload_id + s3_part_etags = upload_session.s3_part_etags + s3_service_for_complete: S3StorageService | None = None + if policy_type == PolicyType.S3: + s3_service_for_complete = await S3StorageService.from_policy(policy) # 更新会话进度 upload_session.uploaded_chunks += 1 @@ -319,12 +379,26 @@ async def upload_chunk( if is_complete: # 保存 upload_session 属性(commit 后会过期) file_name = upload_session.file_name + file_size = upload_session.file_size uploaded_size = upload_session.uploaded_size storage_path = upload_session.storage_path upload_session_id = upload_session.id parent_id = upload_session.parent_id policy_id = upload_session.policy_id + # S3 多分片上传完成:合并分片 + if ( + policy_type == PolicyType.S3 + and s3_upload_id + and s3_part_etags + and s3_service_for_complete + ): + parts_data: list[list[int | str]] = orjson.loads(s3_part_etags) + parts = [(int(pn), str(et)) for pn, et in parts_data] + await s3_service_for_complete.complete_multipart_upload( + storage_path, s3_upload_id, parts, + ) + # 创建 PhysicalFile 记录 physical_file = PhysicalFile( storage_path=storage_path, @@ -355,9 +429,10 @@ async def upload_chunk( commit=False ) - # 更新用户存储配额 - if uploaded_size > 0: - await adjust_user_storage(session, user_id, uploaded_size, commit=False) + # 调整存储配额差值(创建会话时已预扣 file_size,这里只补差) + size_diff = uploaded_size - file_size + if size_diff != 0: + await adjust_user_storage(session, user_id, size_diff, commit=False) # 统一提交所有更改 await session.commit() @@ -390,9 +465,25 @@ async def delete_upload_session( # 删除临时文件 policy = await Policy.get(session, Policy.id == upload_session.policy_id) - if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path: - storage_service = LocalStorageService(policy) - await storage_service.delete_file(upload_session.storage_path) + if policy and upload_session.storage_path: + if policy.type == PolicyType.LOCAL: + storage_service = LocalStorageService(policy) + await storage_service.delete_file(upload_session.storage_path) + elif policy.type == PolicyType.S3: + s3_service = await S3StorageService.from_policy(policy) + # 如果有分片上传,先取消 + if upload_session.s3_upload_id: + await s3_service.abort_multipart_upload( + upload_session.storage_path, upload_session.s3_upload_id, + ) + else: + # 单分片上传已完成的话,删除已上传的文件 + if upload_session.uploaded_chunks > 0: + await s3_service.delete_file(upload_session.storage_path) + + # 释放预扣的存储空间 + if upload_session.file_size > 0: + await adjust_user_storage(session, user.id, -upload_session.file_size) # 删除会话记录 await UploadSession.delete(session, upload_session) @@ -422,9 +513,22 @@ async def clear_upload_sessions( for upload_session in sessions: # 删除临时文件 policy = await Policy.get(session, Policy.id == upload_session.policy_id) - if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path: - storage_service = LocalStorageService(policy) - await storage_service.delete_file(upload_session.storage_path) + if policy and upload_session.storage_path: + if policy.type == PolicyType.LOCAL: + storage_service = LocalStorageService(policy) + await storage_service.delete_file(upload_session.storage_path) + elif policy.type == PolicyType.S3: + s3_service = await S3StorageService.from_policy(policy) + if upload_session.s3_upload_id: + await s3_service.abort_multipart_upload( + upload_session.storage_path, upload_session.s3_upload_id, + ) + elif upload_session.uploaded_chunks > 0: + await s3_service.delete_file(upload_session.storage_path) + + # 释放预扣的存储空间 + if upload_session.file_size > 0: + await adjust_user_storage(session, user.id, -upload_session.file_size) await UploadSession.delete(session, upload_session) deleted_count += 1 @@ -486,11 +590,12 @@ async def create_download_token_endpoint( path='/{token}', summary='下载文件', description='使用下载令牌下载文件,令牌在有效期内可重复使用。', + response_model=None, ) async def download_file( session: SessionDep, token: str, -) -> FileResponse: +) -> Response: """ 下载文件端点 @@ -540,8 +645,15 @@ async def download_file( filename=file_obj.name, media_type="application/octet-stream", ) + elif policy.type == PolicyType.S3: + s3_service = await S3StorageService.from_policy(policy) + # 302 重定向到预签名 URL + presigned_url = s3_service.generate_presigned_url( + storage_path, method='GET', expires_in=3600, filename=file_obj.name, + ) + return RedirectResponse(url=presigned_url, status_code=302) else: - raise HTTPException(status_code=501, detail="S3 存储暂未实现") + raise HTTPException(status_code=500, detail="不支持的存储类型") # ==================== 包含子路由 ==================== @@ -613,8 +725,13 @@ async def create_empty_file( ) await storage_service.create_empty_file(full_path) storage_path = full_path - else: - raise HTTPException(status_code=501, detail="S3 存储暂未实现") + elif policy.type == PolicyType.S3: + s3_service = await S3StorageService.from_policy(policy) + dir_path, storage_name, storage_path = await s3_service.generate_file_path( + user_id=user_id, + original_filename=request.name, + ) + await s3_service.upload_file(storage_path, b'') # 创建 PhysicalFile 记录 physical_file = PhysicalFile( @@ -798,12 +915,13 @@ async def _validate_source_link( path='/get/{file_id}/{name}', summary='文件外链(直接输出文件数据)', description='通过外链直接获取文件内容,公开访问无需认证。', + response_model=None, ) async def file_get( session: SessionDep, file_id: UUID, name: str, -) -> FileResponse: +) -> Response: """ 文件外链端点(直接输出) @@ -815,13 +933,6 @@ async def file_get( """ file_obj, link, physical_file, policy = await _validate_source_link(session, file_id) - if policy.type != PolicyType.LOCAL: - http_exceptions.raise_not_implemented("S3 存储暂未实现") - - storage_service = LocalStorageService(policy) - if not await storage_service.file_exists(physical_file.storage_path): - http_exceptions.raise_not_found("物理文件不存在") - # 缓存物理路径(save 后对象属性会过期) file_path = physical_file.storage_path @@ -829,11 +940,25 @@ async def file_get( link.downloads += 1 await link.save(session) - return FileResponse( - path=file_path, - filename=name, - media_type="application/octet-stream", - ) + if policy.type == PolicyType.LOCAL: + storage_service = LocalStorageService(policy) + if not await storage_service.file_exists(file_path): + http_exceptions.raise_not_found("物理文件不存在") + + return FileResponse( + path=file_path, + filename=name, + media_type="application/octet-stream", + ) + elif policy.type == PolicyType.S3: + # S3 外链直接输出:302 重定向到预签名 URL + s3_service = await S3StorageService.from_policy(policy) + presigned_url = s3_service.generate_presigned_url( + file_path, method='GET', expires_in=3600, filename=name, + ) + return RedirectResponse(url=presigned_url, status_code=302) + else: + http_exceptions.raise_internal_error("不支持的存储类型") @router.get( @@ -846,7 +971,7 @@ async def file_source_redirect( session: SessionDep, file_id: UUID, name: str, -) -> FileResponse | RedirectResponse: +) -> Response: """ 文件外链端点(重定向/直接输出) @@ -860,13 +985,6 @@ async def file_source_redirect( """ file_obj, link, physical_file, policy = await _validate_source_link(session, file_id) - if policy.type != PolicyType.LOCAL: - http_exceptions.raise_not_implemented("S3 存储暂未实现") - - storage_service = LocalStorageService(policy) - if not await storage_service.file_exists(physical_file.storage_path): - http_exceptions.raise_not_found("物理文件不存在") - # 缓存所有需要的值(save 后对象属性会过期) file_path = physical_file.storage_path is_private = policy.is_private @@ -876,18 +994,36 @@ async def file_source_redirect( link.downloads += 1 await link.save(session) - # 公有存储:302 重定向到 base_url - if not is_private and base_url: - relative_path = storage_service.get_relative_path(file_path) - redirect_url = f"{base_url}/{relative_path}" - return RedirectResponse(url=redirect_url, status_code=302) + if policy.type == PolicyType.LOCAL: + storage_service = LocalStorageService(policy) + if not await storage_service.file_exists(file_path): + http_exceptions.raise_not_found("物理文件不存在") - # 私有存储或 base_url 为空:通过应用代理文件 - return FileResponse( - path=file_path, - filename=name, - media_type="application/octet-stream", - ) + # 公有存储:302 重定向到 base_url + if not is_private and base_url: + relative_path = storage_service.get_relative_path(file_path) + redirect_url = f"{base_url}/{relative_path}" + return RedirectResponse(url=redirect_url, status_code=302) + + # 私有存储或 base_url 为空:通过应用代理文件 + return FileResponse( + path=file_path, + filename=name, + media_type="application/octet-stream", + ) + elif policy.type == PolicyType.S3: + s3_service = await S3StorageService.from_policy(policy) + # 公有存储且有 base_url:直接重定向到公开 URL + if not is_private and base_url: + redirect_url = f"{base_url.rstrip('/')}/{file_path}" + return RedirectResponse(url=redirect_url, status_code=302) + # 私有存储:生成预签名 URL 重定向 + presigned_url = s3_service.generate_presigned_url( + file_path, method='GET', expires_in=3600, filename=name, + ) + return RedirectResponse(url=presigned_url, status_code=302) + else: + http_exceptions.raise_internal_error("不支持的存储类型") @router.put( @@ -941,11 +1077,15 @@ async def file_content( if not policy: http_exceptions.raise_internal_error("存储策略不存在") - if policy.type != PolicyType.LOCAL: - http_exceptions.raise_not_implemented("S3 存储暂未实现") - - storage_service = LocalStorageService(policy) - raw_bytes = await storage_service.read_file(physical_file.storage_path) + # 读取文件内容 + if policy.type == PolicyType.LOCAL: + storage_service = LocalStorageService(policy) + raw_bytes = await storage_service.read_file(physical_file.storage_path) + elif policy.type == PolicyType.S3: + s3_service = await S3StorageService.from_policy(policy) + raw_bytes = await s3_service.download_file(physical_file.storage_path) + else: + http_exceptions.raise_internal_error("不支持的存储类型") try: content = raw_bytes.decode('utf-8') @@ -1011,11 +1151,15 @@ async def patch_file_content( if not policy: http_exceptions.raise_internal_error("存储策略不存在") - if policy.type != PolicyType.LOCAL: - http_exceptions.raise_not_implemented("S3 存储暂未实现") - - storage_service = LocalStorageService(policy) - raw_bytes = await storage_service.read_file(storage_path) + # 读取文件内容 + if policy.type == PolicyType.LOCAL: + storage_service = LocalStorageService(policy) + raw_bytes = await storage_service.read_file(storage_path) + elif policy.type == PolicyType.S3: + s3_service = await S3StorageService.from_policy(policy) + raw_bytes = await s3_service.download_file(storage_path) + else: + http_exceptions.raise_internal_error("不支持的存储类型") # 解码 + 规范化 original_text = raw_bytes.decode('utf-8') @@ -1049,7 +1193,10 @@ async def patch_file_content( _check_policy_size_limit(policy, len(new_bytes)) # 写入文件 - await storage_service.write_file(storage_path, new_bytes) + if policy.type == PolicyType.LOCAL: + await storage_service.write_file(storage_path, new_bytes) + elif policy.type == PolicyType.S3: + await s3_service.upload_file(storage_path, new_bytes) # 更新数据库 owner_id = file_obj.owner_id diff --git a/routers/api/v1/object/__init__.py b/routers/api/v1/object/__init__.py index 7e7fcc4..59f352a 100644 --- a/routers/api/v1/object/__init__.py +++ b/routers/api/v1/object/__init__.py @@ -8,13 +8,14 @@ from typing import Annotated from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from loguru import logger as l from middleware.auth import auth_required from middleware.dependencies import SessionDep from sqlmodels import ( CreateFileRequest, + Group, Object, ObjectCopyRequest, ObjectDeleteRequest, @@ -22,18 +23,27 @@ from sqlmodels import ( ObjectPropertyDetailResponse, ObjectPropertyResponse, ObjectRenameRequest, + ObjectSwitchPolicyRequest, ObjectType, PhysicalFile, Policy, PolicyType, + Task, + TaskProps, + TaskStatus, + TaskSummaryBase, + TaskType, User, ) from service.storage import ( LocalStorageService, adjust_user_storage, copy_object_recursive, + migrate_file_with_task, + migrate_directory_files, ) from service.storage.object import soft_delete_objects +from sqlmodels.database_connection import DatabaseManager from utils import http_exceptions object_router = APIRouter( @@ -575,3 +585,136 @@ async def router_object_property_detail( response.checksum_md5 = obj.file_metadata.checksum_md5 return response + + +@object_router.patch( + path='/{object_id}/policy', + summary='切换对象存储策略', +) +async def router_object_switch_policy( + session: SessionDep, + background_tasks: BackgroundTasks, + user: Annotated[User, Depends(auth_required)], + object_id: UUID, + request: ObjectSwitchPolicyRequest, +) -> TaskSummaryBase: + """ + 切换对象的存储策略 + + 文件:立即创建后台迁移任务,将文件从源策略搬到目标策略。 + 目录:更新目录 policy_id(新文件使用新策略); + 若 is_migrate_existing=True,额外创建后台任务迁移所有已有文件。 + + 认证:JWT Bearer Token + + 错误处理: + - 404: 对象不存在 + - 403: 无权操作此对象 / 用户组无权使用目标策略 + - 400: 目标策略与当前相同 / 不能对根目录操作 + """ + user_id = user.id + + # 查找对象 + obj = await Object.get( + session, + (Object.id == object_id) & (Object.deleted_at == None) + ) + if not obj: + http_exceptions.raise_not_found("对象不存在") + if obj.owner_id != user_id: + http_exceptions.raise_forbidden("无权操作此对象") + if obj.is_banned: + http_exceptions.raise_banned() + + # 根目录不能直接切换策略(应通过子对象或子目录操作) + if obj.parent_id is None: + raise HTTPException(status_code=400, detail="不能对根目录切换存储策略,请对子目录操作") + + # 校验目标策略存在 + dest_policy = await Policy.get(session, Policy.id == request.policy_id) + if not dest_policy: + http_exceptions.raise_not_found("目标存储策略不存在") + + # 校验用户组权限 + group: Group = await user.awaitable_attrs.group + await session.refresh(group, ['policies']) + allowed_ids = {p.id for p in group.policies} + if request.policy_id not in allowed_ids: + http_exceptions.raise_forbidden("当前用户组无权使用该存储策略") + + # 不能切换到相同策略 + if obj.policy_id == request.policy_id: + raise HTTPException(status_code=400, detail="目标策略与当前策略相同") + + # 保存必要的属性,避免 save 后对象过期 + src_policy_id = obj.policy_id + obj_id = obj.id + obj_is_file = obj.type == ObjectType.FILE + dest_policy_id = request.policy_id + dest_policy_name = dest_policy.name + + # 创建任务记录 + task = Task( + type=TaskType.POLICY_MIGRATE, + status=TaskStatus.QUEUED, + user_id=user_id, + ) + task = await task.save(session) + task_id = task.id + + task_props = TaskProps( + task_id=task_id, + source_policy_id=src_policy_id, + dest_policy_id=dest_policy_id, + object_id=obj_id, + ) + await task_props.save(session) + + if obj_is_file: + # 文件:后台迁移 + async def _run_file_migration() -> None: + async with DatabaseManager.session() as bg_session: + bg_obj = await Object.get(bg_session, Object.id == obj_id) + bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id) + bg_task = await Task.get(bg_session, Task.id == task_id) + await migrate_file_with_task(bg_session, bg_obj, bg_policy, bg_task) + + background_tasks.add_task(_run_file_migration) + else: + # 目录:先更新目录自身的 policy_id + obj = await Object.get(session, Object.id == obj_id) + obj.policy_id = dest_policy_id + await obj.save(session) + + if request.is_migrate_existing: + # 后台迁移所有已有文件 + async def _run_dir_migration() -> None: + async with DatabaseManager.session() as bg_session: + bg_folder = await Object.get(bg_session, Object.id == obj_id) + bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id) + bg_task = await Task.get(bg_session, Task.id == task_id) + await migrate_directory_files(bg_session, bg_folder, bg_policy, bg_task) + + background_tasks.add_task(_run_dir_migration) + else: + # 不迁移已有文件,直接完成任务 + task = await Task.get(session, Task.id == task_id) + task.status = TaskStatus.COMPLETED + task.progress = 100 + await task.save(session) + + # 重新获取 task 以读取最新状态 + task = await Task.get(session, Task.id == task_id) + + l.info(f"用户 {user_id} 请求切换对象 {obj_id} 存储策略 → {dest_policy_name}") + + return TaskSummaryBase( + id=task.id, + type=task.type, + status=task.status, + progress=task.progress, + error=task.error, + user_id=task.user_id, + created_at=task.created_at, + updated_at=task.updated_at, + ) diff --git a/routers/api/v1/user/__init__.py b/routers/api/v1/user/__init__.py index 2e829c0..9e722f8 100644 --- a/routers/api/v1/user/__init__.py +++ b/routers/api/v1/user/__init__.py @@ -247,11 +247,12 @@ async def router_user_register( ) await identity.save(session) - # 8. 创建用户根目录 - default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储") - if not default_policy: - logger.error("默认存储策略不存在") + # 8. 创建用户根目录(使用用户组关联的第一个存储策略) + await session.refresh(default_group, ['policies']) + if not default_group.policies: + logger.error("默认用户组未关联任何存储策略") http_exceptions.raise_internal_error() + default_policy = default_group.policies[0] await sqlmodels.Object( name="/", diff --git a/routers/api/v1/user/settings/__init__.py b/routers/api/v1/user/settings/__init__.py index 19a39c6..cd7f342 100644 --- a/routers/api/v1/user/settings/__init__.py +++ b/routers/api/v1/user/settings/__init__.py @@ -13,6 +13,7 @@ from sqlmodels import ( AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest, ChangePasswordRequest, AuthnDetailResponse, AuthnRenameRequest, + PolicySummary, ) from sqlmodels.color import ThemeColorsBase from sqlmodels.user_authn import UserAuthn @@ -31,16 +32,25 @@ user_settings_router.include_router(file_viewers_router) @user_settings_router.get( path='/policies', summary='获取用户可选存储策略', - description='Get user selectable storage policies.', ) -def router_user_settings_policies() -> sqlmodels.ResponseBase: +async def router_user_settings_policies( + session: SessionDep, + user: Annotated[sqlmodels.user.User, Depends(auth_required)], +) -> list[PolicySummary]: """ - Get user selectable storage policies. + 获取当前用户所在组可选的存储策略列表 - Returns: - dict: A dictionary containing available storage policies for the user. + 返回用户组关联的所有存储策略的摘要信息。 """ - http_exceptions.raise_not_implemented() + group = await user.awaitable_attrs.group + await session.refresh(group, ['policies']) + return [ + PolicySummary( + id=p.id, name=p.name, type=p.type, + server=p.server, max_size=p.max_size, is_private=p.is_private, + ) + for p in group.policies + ] @user_settings_router.get( diff --git a/service/storage/__init__.py b/service/storage/__init__.py index 0a0c3a5..de96449 100644 --- a/service/storage/__init__.py +++ b/service/storage/__init__.py @@ -3,6 +3,7 @@ 提供文件存储相关的服务,包括: - 本地存储服务 +- S3 存储服务 - 命名规则解析器 - 存储异常定义 """ @@ -11,6 +12,8 @@ from .exceptions import ( FileReadError, FileWriteError, InvalidPathError, + S3APIError, + S3MultipartUploadError, StorageException, StorageFileNotFoundError, UploadSessionExpiredError, @@ -25,4 +28,6 @@ from .object import ( permanently_delete_objects, restore_objects, soft_delete_objects, -) \ No newline at end of file +) +from .migrate import migrate_file_with_task, migrate_directory_files +from .s3_storage import S3StorageService \ No newline at end of file diff --git a/service/storage/exceptions.py b/service/storage/exceptions.py index ae1e4e3..0ba9251 100644 --- a/service/storage/exceptions.py +++ b/service/storage/exceptions.py @@ -43,3 +43,13 @@ class UploadSessionExpiredError(StorageException): class InvalidPathError(StorageException): """无效的路径""" pass + + +class S3APIError(StorageException): + """S3 API 请求错误""" + pass + + +class S3MultipartUploadError(S3APIError): + """S3 分片上传错误""" + pass diff --git a/service/storage/migrate.py b/service/storage/migrate.py new file mode 100644 index 0000000..0cbab60 --- /dev/null +++ b/service/storage/migrate.py @@ -0,0 +1,291 @@ +""" +存储策略迁移服务 + +提供跨存储策略的文件迁移功能: +- 单文件迁移:从源策略下载 → 上传到目标策略 → 更新数据库记录 +- 目录批量迁移:递归遍历目录下所有文件逐个迁移,同时更新子目录的 policy_id +""" +from uuid import UUID + +from loguru import logger as l +from sqlmodel.ext.asyncio.session import AsyncSession + +from sqlmodels.object import Object, ObjectType +from sqlmodels.physical_file import PhysicalFile +from sqlmodels.policy import Policy, PolicyType +from sqlmodels.task import Task, TaskStatus + +from .local_storage import LocalStorageService +from .s3_storage import S3StorageService + + +async def _get_storage_service( + policy: Policy, +) -> LocalStorageService | S3StorageService: + """ + 根据策略类型创建对应的存储服务实例 + + :param policy: 存储策略 + :return: 存储服务实例 + """ + if policy.type == PolicyType.LOCAL: + return LocalStorageService(policy) + elif policy.type == PolicyType.S3: + return await S3StorageService.from_policy(policy) + else: + raise ValueError(f"不支持的存储策略类型: {policy.type}") + + +async def _read_file_from_storage( + service: LocalStorageService | S3StorageService, + storage_path: str, +) -> bytes: + """ + 从存储服务读取文件内容 + + :param service: 存储服务实例 + :param storage_path: 文件存储路径 + :return: 文件二进制内容 + """ + if isinstance(service, LocalStorageService): + return await service.read_file(storage_path) + else: + return await service.download_file(storage_path) + + +async def _write_file_to_storage( + service: LocalStorageService | S3StorageService, + storage_path: str, + data: bytes, +) -> None: + """ + 将文件内容写入存储服务 + + :param service: 存储服务实例 + :param storage_path: 文件存储路径 + :param data: 文件二进制内容 + """ + if isinstance(service, LocalStorageService): + await service.write_file(storage_path, data) + else: + await service.upload_file(storage_path, data) + + +async def _delete_file_from_storage( + service: LocalStorageService | S3StorageService, + storage_path: str, +) -> None: + """ + 从存储服务删除文件 + + :param service: 存储服务实例 + :param storage_path: 文件存储路径 + """ + if isinstance(service, LocalStorageService): + await service.delete_file(storage_path) + else: + await service.delete_file(storage_path) + + +async def migrate_single_file( + session: AsyncSession, + obj: Object, + dest_policy: Policy, +) -> None: + """ + 将单个文件对象从当前存储策略迁移到目标策略 + + 流程: + 1. 获取源物理文件和存储服务 + 2. 读取源文件内容 + 3. 在目标存储中生成新路径并写入 + 4. 创建新的 PhysicalFile 记录 + 5. 更新 Object 的 policy_id 和 physical_file_id + 6. 旧 PhysicalFile 引用计数 -1,如为 0 则删除源物理文件 + + :param session: 数据库会话 + :param obj: 待迁移的文件对象(必须为文件类型) + :param dest_policy: 目标存储策略 + """ + if obj.type != ObjectType.FILE: + raise ValueError(f"只能迁移文件对象,当前类型: {obj.type}") + + # 获取源策略和物理文件 + src_policy: Policy = await obj.awaitable_attrs.policy + old_physical: PhysicalFile | None = await obj.awaitable_attrs.physical_file + + if not old_physical: + l.warning(f"文件 {obj.id} 没有关联物理文件,跳过迁移") + return + + if src_policy.id == dest_policy.id: + l.debug(f"文件 {obj.id} 已在目标策略中,跳过") + return + + # 1. 从源存储读取文件 + src_service = await _get_storage_service(src_policy) + data = await _read_file_from_storage(src_service, old_physical.storage_path) + + # 2. 在目标存储生成新路径并写入 + dest_service = await _get_storage_service(dest_policy) + _dir_path, _storage_name, new_storage_path = await dest_service.generate_file_path( + user_id=obj.owner_id, + original_filename=obj.name, + ) + await _write_file_to_storage(dest_service, new_storage_path, data) + + # 3. 创建新的 PhysicalFile + new_physical = PhysicalFile( + storage_path=new_storage_path, + size=old_physical.size, + checksum_md5=old_physical.checksum_md5, + policy_id=dest_policy.id, + reference_count=1, + ) + new_physical = await new_physical.save(session) + + # 4. 更新 Object + obj.policy_id = dest_policy.id + obj.physical_file_id = new_physical.id + await obj.save(session) + + # 5. 旧 PhysicalFile 引用计数 -1 + old_physical.decrement_reference() + if old_physical.can_be_deleted: + # 删除源存储中的物理文件 + try: + await _delete_file_from_storage(src_service, old_physical.storage_path) + except Exception as e: + l.warning(f"删除源文件失败(不影响迁移结果): {old_physical.storage_path}: {e}") + await PhysicalFile.delete(session, old_physical) + else: + await old_physical.save(session) + + l.info(f"文件迁移完成: {obj.name} ({obj.id}), {src_policy.name} → {dest_policy.name}") + + +async def migrate_file_with_task( + session: AsyncSession, + obj: Object, + dest_policy: Policy, + task: Task, +) -> None: + """ + 迁移单个文件并更新任务状态 + + :param session: 数据库会话 + :param obj: 待迁移的文件对象 + :param dest_policy: 目标存储策略 + :param task: 关联的任务记录 + """ + try: + task.status = TaskStatus.RUNNING + task.progress = 0 + task = await task.save(session) + + await migrate_single_file(session, obj, dest_policy) + + task.status = TaskStatus.COMPLETED + task.progress = 100 + await task.save(session) + except Exception as e: + l.error(f"文件迁移任务失败: {obj.id}: {e}") + task.status = TaskStatus.ERROR + task.error = str(e)[:500] + await task.save(session) + + +async def migrate_directory_files( + session: AsyncSession, + folder: Object, + dest_policy: Policy, + task: Task, +) -> None: + """ + 迁移目录下所有文件到目标存储策略 + + 递归遍历目录树,将所有文件迁移到目标策略。 + 子目录的 policy_id 同步更新。 + 任务进度按文件数比例更新。 + + :param session: 数据库会话 + :param folder: 目录对象 + :param dest_policy: 目标存储策略 + :param task: 关联的任务记录 + """ + try: + task.status = TaskStatus.RUNNING + task.progress = 0 + task = await task.save(session) + + # 收集所有需要迁移的文件 + files_to_migrate: list[Object] = [] + folders_to_update: list[Object] = [] + await _collect_objects_recursive(session, folder, files_to_migrate, folders_to_update) + + total = len(files_to_migrate) + migrated = 0 + errors: list[str] = [] + + for file_obj in files_to_migrate: + try: + await migrate_single_file(session, file_obj, dest_policy) + migrated += 1 + except Exception as e: + error_msg = f"{file_obj.name}: {e}" + l.error(f"迁移文件失败: {error_msg}") + errors.append(error_msg) + + # 更新进度 + if total > 0: + task.progress = min(99, int(migrated / total * 100)) + task = await task.save(session) + + # 更新所有子目录的 policy_id + for sub_folder in folders_to_update: + sub_folder.policy_id = dest_policy.id + await sub_folder.save(session) + + # 完成任务 + if errors: + task.status = TaskStatus.ERROR + task.error = f"部分文件迁移失败 ({len(errors)}/{total}): " + "; ".join(errors[:5]) + else: + task.status = TaskStatus.COMPLETED + + task.progress = 100 + await task.save(session) + + l.info( + f"目录迁移完成: {folder.name} ({folder.id}), " + f"成功 {migrated}/{total}, 错误 {len(errors)}" + ) + except Exception as e: + l.error(f"目录迁移任务失败: {folder.id}: {e}") + task.status = TaskStatus.ERROR + task.error = str(e)[:500] + await task.save(session) + + +async def _collect_objects_recursive( + session: AsyncSession, + folder: Object, + files: list[Object], + folders: list[Object], +) -> None: + """ + 递归收集目录下所有文件和子目录 + + :param session: 数据库会话 + :param folder: 当前目录 + :param files: 文件列表(输出) + :param folders: 子目录列表(输出) + """ + children: list[Object] = await Object.get_children(session, folder.owner_id, folder.id) + + for child in children: + if child.type == ObjectType.FILE: + files.append(child) + elif child.type == ObjectType.FOLDER: + folders.append(child) + await _collect_objects_recursive(session, child, files, folders) diff --git a/service/storage/object.py b/service/storage/object.py index 057954e..23f3803 100644 --- a/service/storage/object.py +++ b/service/storage/object.py @@ -6,7 +6,8 @@ from sqlalchemy import update as sql_update from sqlalchemy.sql.functions import func from middleware.dependencies import SessionDep -from service.storage import LocalStorageService +from .local_storage import LocalStorageService +from .s3_storage import S3StorageService from sqlmodels import ( Object, PhysicalFile, @@ -271,10 +272,14 @@ async def permanently_delete_objects( if physical_file.can_be_deleted: # 物理删除文件 policy = await Policy.get(session, Policy.id == physical_file.policy_id) - if policy and policy.type == PolicyType.LOCAL: + if policy: try: - storage_service = LocalStorageService(policy) - await storage_service.delete_file(physical_file.storage_path) + if policy.type == PolicyType.LOCAL: + storage_service = LocalStorageService(policy) + await storage_service.delete_file(physical_file.storage_path) + elif policy.type == PolicyType.S3: + s3_service = await S3StorageService.from_policy(policy) + await s3_service.delete_file(physical_file.storage_path) l.debug(f"物理文件已删除: {obj_name}") except Exception as e: l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}") @@ -374,10 +379,19 @@ async def delete_object_recursive( if physical_file.can_be_deleted: # 物理删除文件 policy = await Policy.get(session, Policy.id == physical_file.policy_id) - if policy and policy.type == PolicyType.LOCAL: + if policy: try: - storage_service = LocalStorageService(policy) - await storage_service.delete_file(physical_file.storage_path) + if policy.type == PolicyType.LOCAL: + storage_service = LocalStorageService(policy) + await storage_service.delete_file(physical_file.storage_path) + elif policy.type == PolicyType.S3: + options = await policy.awaitable_attrs.options + s3_service = S3StorageService( + policy, + region=options.s3_region if options else 'us-east-1', + is_path_style=options.s3_path_style if options else False, + ) + await s3_service.delete_file(physical_file.storage_path) l.debug(f"物理文件已删除: {obj_name}") except Exception as e: l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}") diff --git a/service/storage/s3_storage.py b/service/storage/s3_storage.py new file mode 100644 index 0000000..35b3f8f --- /dev/null +++ b/service/storage/s3_storage.py @@ -0,0 +1,709 @@ +""" +S3 存储服务 + +使用 AWS Signature V4 签名的异步 S3 API 客户端。 +从 Policy 配置中读取 S3 连接信息,提供文件上传/下载/删除及分片上传功能。 + +移植自 foxline-pro-backend-server 项目的 S3APIClient, +适配 DiskNext 现有的 Service 架构(与 LocalStorageService 平行)。 +""" +import hashlib +import hmac +import xml.etree.ElementTree as ET +from collections.abc import AsyncIterator +from datetime import datetime, timezone +from typing import ClassVar, Literal +from urllib.parse import quote, urlencode +from uuid import UUID + +import aiohttp +from yarl import URL +from loguru import logger as l + +from sqlmodels.policy import Policy +from .exceptions import S3APIError, S3MultipartUploadError +from .naming_rule import NamingContext, NamingRuleParser + + +def _sign(key: bytes, msg: str) -> bytes: + """HMAC-SHA256 签名""" + return hmac.new(key, msg.encode(), hashlib.sha256).digest() + + +_NS_AWS = "http://s3.amazonaws.com/doc/2006-03-01/" + + +class S3StorageService: + """ + S3 存储服务 + + 使用 AWS Signature V4 签名的异步 S3 API 客户端。 + 从 Policy 配置中读取 S3 连接信息。 + + 使用示例:: + + service = S3StorageService(policy, region='us-east-1') + await service.upload_file('path/to/file.txt', b'content') + data = await service.download_file('path/to/file.txt') + """ + + _http_session: ClassVar[aiohttp.ClientSession | None] = None + + def __init__( + self, + policy: Policy, + region: str = 'us-east-1', + is_path_style: bool = False, + ): + """ + :param policy: 存储策略(server=endpoint_url, bucket_name, access_key, secret_key) + :param region: S3 区域 + :param is_path_style: 是否使用路径风格 URL + """ + if not policy.server: + raise S3APIError("S3 策略必须指定 server (endpoint URL)") + if not policy.bucket_name: + raise S3APIError("S3 策略必须指定 bucket_name") + if not policy.access_key: + raise S3APIError("S3 策略必须指定 access_key") + if not policy.secret_key: + raise S3APIError("S3 策略必须指定 secret_key") + + self._policy = policy + self._endpoint_url = policy.server.rstrip("/") + self._bucket_name = policy.bucket_name + self._access_key = policy.access_key + self._secret_key = policy.secret_key + self._region = region + self._is_path_style = is_path_style + self._base_url = policy.base_url + + # 从 endpoint_url 提取 host + self._host = self._endpoint_url.replace("https://", "").replace("http://", "").split("/")[0] + + # ==================== 工厂方法 ==================== + + @classmethod + async def from_policy(cls, policy: Policy) -> 'S3StorageService': + """ + 根据 Policy 异步创建 S3StorageService(自动加载 options) + + :param policy: 存储策略 + :return: S3StorageService 实例 + """ + options = await policy.awaitable_attrs.options + region = options.s3_region if options else 'us-east-1' + is_path_style = options.s3_path_style if options else False + return cls(policy, region=region, is_path_style=is_path_style) + + # ==================== HTTP Session 管理 ==================== + + @classmethod + async def initialize_session(cls) -> None: + """初始化全局 aiohttp ClientSession""" + if cls._http_session is None or cls._http_session.closed: + cls._http_session = aiohttp.ClientSession() + l.info("S3StorageService HTTP session 已初始化") + + @classmethod + async def close_session(cls) -> None: + """关闭全局 aiohttp ClientSession""" + if cls._http_session and not cls._http_session.closed: + await cls._http_session.close() + cls._http_session = None + l.info("S3StorageService HTTP session 已关闭") + + @classmethod + def _get_session(cls) -> aiohttp.ClientSession: + """获取 HTTP session""" + if cls._http_session is None or cls._http_session.closed: + # 懒初始化,以防 initialize_session 未被调用 + cls._http_session = aiohttp.ClientSession() + return cls._http_session + + # ==================== AWS Signature V4 签名 ==================== + + def _get_signature_key(self, date_stamp: str) -> bytes: + """生成 AWS Signature V4 签名密钥""" + k_date = _sign(f"AWS4{self._secret_key}".encode(), date_stamp) + k_region = _sign(k_date, self._region) + k_service = _sign(k_region, "s3") + return _sign(k_service, "aws4_request") + + def _create_authorization_header( + self, + method: str, + uri: str, + query_string: str, + headers: dict[str, str], + payload_hash: str, + amz_date: str, + date_stamp: str, + ) -> str: + """创建 AWS Signature V4 授权头""" + signed_headers = ";".join(sorted(k.lower() for k in headers.keys())) + canonical_headers = "".join( + f"{k.lower()}:{v.strip()}\n" for k, v in sorted(headers.items()) + ) + canonical_request = ( + f"{method}\n{uri}\n{query_string}\n{canonical_headers}\n" + f"{signed_headers}\n{payload_hash}" + ) + + algorithm = "AWS4-HMAC-SHA256" + credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request" + string_to_sign = ( + f"{algorithm}\n{amz_date}\n{credential_scope}\n" + f"{hashlib.sha256(canonical_request.encode()).hexdigest()}" + ) + + signing_key = self._get_signature_key(date_stamp) + signature = hmac.new( + signing_key, string_to_sign.encode(), hashlib.sha256 + ).hexdigest() + + return ( + f"{algorithm} Credential={self._access_key}/{credential_scope}, " + f"SignedHeaders={signed_headers}, Signature={signature}" + ) + + def _build_headers( + self, + method: str, + uri: str, + query_string: str = "", + payload: bytes = b"", + content_type: str | None = None, + extra_headers: dict[str, str] | None = None, + payload_hash: str | None = None, + host: str | None = None, + ) -> dict[str, str]: + """ + 构建包含 AWS V4 签名的完整请求头 + + :param method: HTTP 方法 + :param uri: 请求 URI + :param query_string: 查询字符串 + :param payload: 请求体字节(用于计算哈希) + :param content_type: Content-Type + :param extra_headers: 额外请求头 + :param payload_hash: 预计算的 payload 哈希,流式上传时传 "UNSIGNED-PAYLOAD" + :param host: Host 头(默认使用 self._host) + """ + now_utc = datetime.now(timezone.utc) + amz_date = now_utc.strftime("%Y%m%dT%H%M%SZ") + date_stamp = now_utc.strftime("%Y%m%d") + + if payload_hash is None: + payload_hash = hashlib.sha256(payload).hexdigest() + + effective_host = host or self._host + + headers: dict[str, str] = { + "Host": effective_host, + "X-Amz-Date": amz_date, + "X-Amz-Content-Sha256": payload_hash, + } + if content_type: + headers["Content-Type"] = content_type + if extra_headers: + headers.update(extra_headers) + + authorization = self._create_authorization_header( + method, uri, query_string, headers, payload_hash, amz_date, date_stamp + ) + headers["Authorization"] = authorization + return headers + + # ==================== 内部请求方法 ==================== + + def _build_uri(self, key: str | None = None) -> str: + """ + 构建请求 URI + + 按 AWS S3 Signature V4 规范对路径进行 URI 编码(S3 仅需一次)。 + 斜杠作为路径分隔符保留不编码。 + """ + if self._is_path_style: + if key: + return f"/{self._bucket_name}/{quote(key, safe='/')}" + return f"/{self._bucket_name}" + else: + if key: + return f"/{quote(key, safe='/')}" + return "/" + + def _build_url(self, uri: str, query_string: str = "") -> str: + """构建完整请求 URL""" + if self._is_path_style: + base = self._endpoint_url + else: + # 虚拟主机风格:bucket.endpoint + protocol = "https://" if self._endpoint_url.startswith("https://") else "http://" + base = f"{protocol}{self._bucket_name}.{self._host}" + + url = f"{base}{uri}" + if query_string: + url = f"{url}?{query_string}" + return url + + def _get_effective_host(self) -> str: + """获取实际请求的 Host 头""" + if self._is_path_style: + return self._host + return f"{self._bucket_name}.{self._host}" + + async def _request( + self, + method: str, + key: str | None = None, + query_params: dict[str, str] | None = None, + payload: bytes = b"", + content_type: str | None = None, + extra_headers: dict[str, str] | None = None, + ) -> aiohttp.ClientResponse: + """发送签名请求""" + uri = self._build_uri(key) + query_string = urlencode(sorted(query_params.items())) if query_params else "" + effective_host = self._get_effective_host() + + headers = self._build_headers( + method, uri, query_string, payload, content_type, + extra_headers, host=effective_host, + ) + + url = self._build_url(uri, query_string) + + try: + response = await self._get_session().request( + method, URL(url, encoded=True), + headers=headers, data=payload if payload else None, + ) + return response + except Exception as e: + raise S3APIError(f"S3 请求失败: {method} {url}: {e}") from e + + async def _request_streaming( + self, + method: str, + key: str, + data_stream: AsyncIterator[bytes], + content_length: int, + content_type: str | None = None, + ) -> aiohttp.ClientResponse: + """ + 发送流式签名请求(大文件上传) + + 使用 UNSIGNED-PAYLOAD 作为 payload hash。 + """ + uri = self._build_uri(key) + effective_host = self._get_effective_host() + + headers = self._build_headers( + method, + uri, + query_string="", + content_type=content_type, + extra_headers={"Content-Length": str(content_length)}, + payload_hash="UNSIGNED-PAYLOAD", + host=effective_host, + ) + + url = self._build_url(uri) + + try: + response = await self._get_session().request( + method, URL(url, encoded=True), + headers=headers, data=data_stream, + ) + return response + except Exception as e: + raise S3APIError(f"S3 流式请求失败: {method} {url}: {e}") from e + + # ==================== 文件操作 ==================== + + async def upload_file( + self, + key: str, + data: bytes, + content_type: str = 'application/octet-stream', + ) -> None: + """ + 上传文件 + + :param key: S3 对象键 + :param data: 文件内容 + :param content_type: MIME 类型 + """ + async with await self._request( + "PUT", key=key, payload=data, content_type=content_type, + ) as response: + if response.status not in (200, 201): + body = await response.text() + raise S3APIError( + f"上传失败: {self._bucket_name}/{key}, " + f"状态: {response.status}, {body}" + ) + l.debug(f"S3 上传成功: {self._bucket_name}/{key}") + + async def upload_file_streaming( + self, + key: str, + data_stream: AsyncIterator[bytes], + content_length: int, + content_type: str | None = None, + ) -> None: + """ + 流式上传文件(大文件,避免全部加载到内存) + + :param key: S3 对象键 + :param data_stream: 异步字节流迭代器 + :param content_length: 数据总长度(必须准确) + :param content_type: MIME 类型 + """ + async with await self._request_streaming( + "PUT", key=key, data_stream=data_stream, + content_length=content_length, content_type=content_type, + ) as response: + if response.status not in (200, 201): + body = await response.text() + raise S3APIError( + f"流式上传失败: {self._bucket_name}/{key}, " + f"状态: {response.status}, {body}" + ) + l.debug(f"S3 流式上传成功: {self._bucket_name}/{key}, 大小: {content_length}") + + async def download_file(self, key: str) -> bytes: + """ + 下载文件 + + :param key: S3 对象键 + :return: 文件内容 + """ + async with await self._request("GET", key=key) as response: + if response.status != 200: + body = await response.text() + raise S3APIError( + f"下载失败: {self._bucket_name}/{key}, " + f"状态: {response.status}, {body}" + ) + data = await response.read() + l.debug(f"S3 下载成功: {self._bucket_name}/{key}, 大小: {len(data)}") + return data + + async def delete_file(self, key: str) -> None: + """ + 删除文件 + + :param key: S3 对象键 + """ + async with await self._request("DELETE", key=key) as response: + if response.status in (200, 204): + l.debug(f"S3 删除成功: {self._bucket_name}/{key}") + else: + body = await response.text() + raise S3APIError( + f"删除失败: {self._bucket_name}/{key}, " + f"状态: {response.status}, {body}" + ) + + async def file_exists(self, key: str) -> bool: + """ + 检查文件是否存在 + + :param key: S3 对象键 + :return: 是否存在 + """ + async with await self._request("HEAD", key=key) as response: + if response.status == 200: + return True + elif response.status == 404: + return False + else: + raise S3APIError( + f"检查文件存在性失败: {self._bucket_name}/{key}, 状态: {response.status}" + ) + + async def get_file_size(self, key: str) -> int: + """ + 获取文件大小 + + :param key: S3 对象键 + :return: 文件大小(字节) + """ + async with await self._request("HEAD", key=key) as response: + if response.status != 200: + raise S3APIError( + f"获取文件信息失败: {self._bucket_name}/{key}, 状态: {response.status}" + ) + return int(response.headers.get("Content-Length", 0)) + + # ==================== Multipart Upload ==================== + + async def create_multipart_upload( + self, + key: str, + content_type: str = 'application/octet-stream', + ) -> str: + """ + 创建分片上传任务 + + :param key: S3 对象键 + :param content_type: MIME 类型 + :return: Upload ID + """ + async with await self._request( + "POST", + key=key, + query_params={"uploads": ""}, + content_type=content_type, + ) as response: + if response.status != 200: + body = await response.text() + raise S3MultipartUploadError( + f"创建分片上传失败: {self._bucket_name}/{key}, " + f"状态: {response.status}, {body}" + ) + + body = await response.text() + root = ET.fromstring(body) + + # 查找 UploadId 元素(支持命名空间) + upload_id_elem = root.find("UploadId") + if upload_id_elem is None: + upload_id_elem = root.find(f"{{{_NS_AWS}}}UploadId") + if upload_id_elem is None or not upload_id_elem.text: + raise S3MultipartUploadError( + f"创建分片上传响应中未找到 UploadId: {body}" + ) + + upload_id = upload_id_elem.text + l.debug(f"S3 分片上传已创建: {self._bucket_name}/{key}, upload_id={upload_id}") + return upload_id + + async def upload_part( + self, + key: str, + upload_id: str, + part_number: int, + data: bytes, + ) -> str: + """ + 上传单个分片 + + :param key: S3 对象键 + :param upload_id: 分片上传 ID + :param part_number: 分片编号(从 1 开始) + :param data: 分片数据 + :return: ETag + """ + async with await self._request( + "PUT", + key=key, + query_params={ + "partNumber": str(part_number), + "uploadId": upload_id, + }, + payload=data, + ) as response: + if response.status != 200: + body = await response.text() + raise S3MultipartUploadError( + f"上传分片失败: {self._bucket_name}/{key}, " + f"part={part_number}, 状态: {response.status}, {body}" + ) + + etag = response.headers.get("ETag", "").strip('"') + l.debug( + f"S3 分片上传成功: {self._bucket_name}/{key}, " + f"part={part_number}, etag={etag}" + ) + return etag + + async def complete_multipart_upload( + self, + key: str, + upload_id: str, + parts: list[tuple[int, str]], + ) -> None: + """ + 完成分片上传 + + :param key: S3 对象键 + :param upload_id: 分片上传 ID + :param parts: 分片列表 [(part_number, etag)] + """ + # 按 part_number 排序 + parts_sorted = sorted(parts, key=lambda p: p[0]) + + # 构建 CompleteMultipartUpload XML + xml_parts = ''.join( + f"{pn}{etag}" + for pn, etag in parts_sorted + ) + payload = f'{xml_parts}' + payload_bytes = payload.encode('utf-8') + + async with await self._request( + "POST", + key=key, + query_params={"uploadId": upload_id}, + payload=payload_bytes, + content_type="application/xml", + ) as response: + if response.status != 200: + body = await response.text() + raise S3MultipartUploadError( + f"完成分片上传失败: {self._bucket_name}/{key}, " + f"状态: {response.status}, {body}" + ) + l.info( + f"S3 分片上传已完成: {self._bucket_name}/{key}, " + f"共 {len(parts)} 个分片" + ) + + async def abort_multipart_upload(self, key: str, upload_id: str) -> None: + """ + 取消分片上传 + + :param key: S3 对象键 + :param upload_id: 分片上传 ID + """ + async with await self._request( + "DELETE", + key=key, + query_params={"uploadId": upload_id}, + ) as response: + if response.status in (200, 204): + l.debug(f"S3 分片上传已取消: {self._bucket_name}/{key}") + else: + body = await response.text() + l.warning( + f"取消分片上传失败: {self._bucket_name}/{key}, " + f"状态: {response.status}, {body}" + ) + + # ==================== 预签名 URL ==================== + + def generate_presigned_url( + self, + key: str, + method: Literal['GET', 'PUT'] = 'GET', + expires_in: int = 3600, + filename: str | None = None, + ) -> str: + """ + 生成 S3 预签名 URL(AWS Signature V4 Query String) + + :param key: S3 对象键 + :param method: HTTP 方法(GET 下载,PUT 上传) + :param expires_in: URL 有效期(秒) + :param filename: 文件名(GET 请求时设置 Content-Disposition) + :return: 预签名 URL + """ + current_time = datetime.now(timezone.utc) + amz_date = current_time.strftime("%Y%m%dT%H%M%SZ") + date_stamp = current_time.strftime("%Y%m%d") + + credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request" + credential = f"{self._access_key}/{credential_scope}" + + uri = self._build_uri(key) + effective_host = self._get_effective_host() + + query_params: dict[str, str] = { + 'X-Amz-Algorithm': 'AWS4-HMAC-SHA256', + 'X-Amz-Credential': credential, + 'X-Amz-Date': amz_date, + 'X-Amz-Expires': str(expires_in), + 'X-Amz-SignedHeaders': 'host', + } + + # GET 请求时添加 Content-Disposition + if method == "GET" and filename: + encoded_filename = quote(filename, safe='') + query_params['response-content-disposition'] = ( + f"attachment; filename*=UTF-8''{encoded_filename}" + ) + + canonical_query_string = "&".join( + f"{quote(k, safe='')}={quote(v, safe='')}" + for k, v in sorted(query_params.items()) + ) + + canonical_headers = f"host:{effective_host}\n" + signed_headers = "host" + payload_hash = "UNSIGNED-PAYLOAD" + + canonical_request = ( + f"{method}\n" + f"{uri}\n" + f"{canonical_query_string}\n" + f"{canonical_headers}\n" + f"{signed_headers}\n" + f"{payload_hash}" + ) + + algorithm = "AWS4-HMAC-SHA256" + string_to_sign = ( + f"{algorithm}\n" + f"{amz_date}\n" + f"{credential_scope}\n" + f"{hashlib.sha256(canonical_request.encode()).hexdigest()}" + ) + + signing_key = self._get_signature_key(date_stamp) + signature = hmac.new( + signing_key, string_to_sign.encode(), hashlib.sha256 + ).hexdigest() + + base_url = self._build_url(uri) + return ( + f"{base_url}?" + f"{canonical_query_string}&" + f"X-Amz-Signature={signature}" + ) + + # ==================== 路径生成 ==================== + + async def generate_file_path( + self, + user_id: UUID, + original_filename: str, + ) -> tuple[str, str, str]: + """ + 根据命名规则生成 S3 文件存储路径 + + 与 LocalStorageService.generate_file_path 接口一致。 + + :param user_id: 用户UUID + :param original_filename: 原始文件名 + :return: (相对目录路径, 存储文件名, 完整存储路径) + """ + context = NamingContext( + user_id=user_id, + original_filename=original_filename, + ) + + # 解析目录规则 + dir_path = "" + if self._policy.dir_name_rule: + dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context) + + # 解析文件名规则 + if self._policy.auto_rename and self._policy.file_name_rule: + storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context) + # 确保有扩展名 + if '.' in original_filename and '.' not in storage_name: + ext = original_filename.rsplit('.', 1)[1] + storage_name = f"{storage_name}.{ext}" + else: + storage_name = original_filename + + # S3 不需要创建目录,直接拼接路径 + if dir_path: + storage_path = f"{dir_path}/{storage_name}" + else: + storage_path = storage_name + + return dir_path, storage_name, storage_path diff --git a/sqlmodels/README.md b/sqlmodels/README.md index 6855961..5550d73 100644 --- a/sqlmodels/README.md +++ b/sqlmodels/README.md @@ -954,18 +954,11 @@ class PolicyType(StrEnum): S3 = "s3" # S3 兼容存储 ``` -### StorageType +### PolicyType ```python -class StorageType(StrEnum): +class PolicyType(StrEnum): LOCAL = "local" # 本地存储 - QINIU = "qiniu" # 七牛云 - TENCENT = "tencent" # 腾讯云 - ALIYUN = "aliyun" # 阿里云 - ONEDRIVE = "onedrive" # OneDrive - GOOGLE_DRIVE = "google_drive" # Google Drive - DROPBOX = "dropbox" # Dropbox - WEBDAV = "webdav" # WebDAV - REMOTE = "remote" # 远程存储 + S3 = "s3" # S3 兼容存储 ``` ### UserStatus diff --git a/sqlmodels/__init__.py b/sqlmodels/__init__.py index 97898f5..9feec90 100644 --- a/sqlmodels/__init__.py +++ b/sqlmodels/__init__.py @@ -82,6 +82,7 @@ from .object import ( ObjectPropertyResponse, ObjectRenameRequest, ObjectResponse, + ObjectSwitchPolicyRequest, ObjectType, PolicyResponse, UploadChunkResponse, @@ -100,7 +101,10 @@ from .object import ( from .physical_file import PhysicalFile, PhysicalFileBase from .uri import DiskNextURI, FileSystemNamespace from .order import Order, OrderStatus, OrderType -from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary +from .policy import ( + Policy, PolicyBase, PolicyCreateRequest, PolicyOptions, PolicyOptionsBase, + PolicyType, PolicySummary, PolicyUpdateRequest, +) from .redeem import Redeem, RedeemType from .report import Report, ReportReason from .setting import ( @@ -116,7 +120,7 @@ from .share import ( from .source_link import SourceLink from .storage_pack import StoragePack from .tag import Tag, TagType -from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary +from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary, TaskSummaryBase from .webdav import ( WebDAV, WebDAVBase, WebDAVCreateRequest, WebDAVUpdateRequest, WebDAVAccountResponse, diff --git a/sqlmodels/group.py b/sqlmodels/group.py index 8bea70d..3b7f9a6 100644 --- a/sqlmodels/group.py +++ b/sqlmodels/group.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from uuid import UUID +from sqlalchemy import BigInteger from sqlmodel import Field, Relationship, text from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin @@ -260,7 +261,7 @@ class Group(GroupBase, UUIDTableBaseMixin): name: str = Field(max_length=255, unique=True) """用户组名""" - max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) + max_storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"}) """最大存储空间(字节)""" share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}) diff --git a/sqlmodels/object.py b/sqlmodels/object.py index e642e7c..130479d 100644 --- a/sqlmodels/object.py +++ b/sqlmodels/object.py @@ -9,6 +9,8 @@ from sqlmodel import Field, Relationship, CheckConstraint, Index, text from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin +from .policy import PolicyType + if TYPE_CHECKING: from .user import User from .policy import Policy @@ -23,18 +25,6 @@ class ObjectType(StrEnum): FILE = "file" FOLDER = "folder" -class StorageType(StrEnum): - """存储类型枚举""" - LOCAL = "local" - QINIU = "qiniu" - TENCENT = "tencent" - ALIYUN = "aliyun" - ONEDRIVE = "onedrive" - GOOGLE_DRIVE = "google_drive" - DROPBOX = "dropbox" - WEBDAV = "webdav" - REMOTE = "remote" - class FileMetadataBase(SQLModelBase): """文件元数据基础模型""" @@ -156,7 +146,7 @@ class PolicyResponse(SQLModelBase): name: str """策略名称""" - type: StorageType + type: PolicyType """存储类型""" max_size: int = Field(ge=0, default=0, sa_type=BigInteger) @@ -624,6 +614,12 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin): storage_path: str | None = Field(default=None, max_length=512) """文件存储路径""" + s3_upload_id: str | None = Field(default=None, max_length=256) + """S3 Multipart Upload ID(仅 S3 策略使用)""" + + s3_part_etags: str | None = None + """S3 已上传分片的 ETag 列表,JSON 格式 [[1,"etag1"],[2,"etag2"]](仅 S3 策略使用)""" + expires_at: datetime """会话过期时间""" @@ -732,6 +728,16 @@ class CreateFileRequest(SQLModelBase): """存储策略UUID,不指定则使用父目录的策略""" +class ObjectSwitchPolicyRequest(SQLModelBase): + """切换对象存储策略请求""" + + policy_id: UUID + """目标存储策略UUID""" + + is_migrate_existing: bool = False + """(仅目录)是否迁移已有文件,默认 false 只影响新文件""" + + # ==================== 对象操作相关 DTO ==================== class ObjectCopyRequest(SQLModelBase): diff --git a/sqlmodels/policy.py b/sqlmodels/policy.py index c65953f..4d841a3 100644 --- a/sqlmodels/policy.py +++ b/sqlmodels/policy.py @@ -102,6 +102,94 @@ class PolicySummary(SQLModelBase): """是否私有""" +class PolicyCreateRequest(PolicyBase): + """创建存储策略请求 DTO,包含 PolicyOptions 扁平字段""" + + # PolicyOptions 字段(平铺到请求体中,与 GroupCreateRequest 模式一致) + token: str | None = None + """访问令牌""" + + file_type: str | None = None + """允许的文件类型""" + + mimetype: str | None = Field(default=None, max_length=127) + """MIME类型""" + + od_redirect: str | None = Field(default=None, max_length=255) + """OneDrive重定向地址""" + + chunk_size: int = Field(default=52428800, ge=1) + """分片上传大小(字节),默认50MB""" + + s3_path_style: bool = False + """是否使用S3路径风格""" + + s3_region: str = Field(default='us-east-1', max_length=64) + """S3 区域(如 us-east-1、ap-southeast-1),仅 S3 策略使用""" + + +class PolicyUpdateRequest(SQLModelBase): + """更新存储策略请求 DTO(所有字段可选)""" + + name: str | None = Field(default=None, max_length=255) + """策略名称""" + + server: str | None = Field(default=None, max_length=255) + """服务器地址""" + + bucket_name: str | None = Field(default=None, max_length=255) + """存储桶名称""" + + is_private: bool | None = None + """是否为私有空间""" + + base_url: str | None = Field(default=None, max_length=255) + """访问文件的基础URL""" + + access_key: str | None = None + """Access Key""" + + secret_key: str | None = None + """Secret Key""" + + max_size: int | None = Field(default=None, ge=0) + """允许上传的最大文件尺寸(字节)""" + + auto_rename: bool | None = None + """是否自动重命名""" + + dir_name_rule: str | None = Field(default=None, max_length=255) + """目录命名规则""" + + file_name_rule: str | None = Field(default=None, max_length=255) + """文件命名规则""" + + is_origin_link_enable: bool | None = None + """是否开启源链接访问""" + + # PolicyOptions 字段 + token: str | None = None + """访问令牌""" + + file_type: str | None = None + """允许的文件类型""" + + mimetype: str | None = Field(default=None, max_length=127) + """MIME类型""" + + od_redirect: str | None = Field(default=None, max_length=255) + """OneDrive重定向地址""" + + chunk_size: int | None = Field(default=None, ge=1) + """分片上传大小(字节)""" + + s3_path_style: bool | None = None + """是否使用S3路径风格""" + + s3_region: str | None = Field(default=None, max_length=64) + """S3 区域""" + + # ==================== 数据库模型 ==================== @@ -126,6 +214,9 @@ class PolicyOptionsBase(SQLModelBase): s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}) """是否使用S3路径风格""" + s3_region: str = Field(default='us-east-1', max_length=64, sa_column_kwargs={"server_default": "'us-east-1'"}) + """S3 区域(如 us-east-1、ap-southeast-1),仅 S3 策略使用""" + class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin): """存储策略选项模型(与Policy一对一关联)""" diff --git a/sqlmodels/task.py b/sqlmodels/task.py index 980c3f8..52f64f6 100644 --- a/sqlmodels/task.py +++ b/sqlmodels/task.py @@ -26,8 +26,8 @@ class TaskStatus(StrEnum): class TaskType(StrEnum): """任务类型枚举""" - # [TODO] 补充具体任务类型 - pass + POLICY_MIGRATE = "policy_migrate" + """存储策略迁移""" # ==================== DTO 模型 ==================== @@ -39,7 +39,7 @@ class TaskSummaryBase(SQLModelBase): id: int """任务ID""" - type: int + type: TaskType """任务类型""" status: TaskStatus @@ -91,7 +91,14 @@ class TaskPropsBase(SQLModelBase): file_ids: str | None = None """文件ID列表(逗号分隔)""" - # [TODO] 根据业务需求补充更多字段 + source_policy_id: UUID | None = None + """源存储策略UUID""" + + dest_policy_id: UUID | None = None + """目标存储策略UUID""" + + object_id: UUID | None = None + """关联的对象UUID""" class TaskProps(TaskPropsBase, TableBaseMixin): @@ -99,7 +106,7 @@ class TaskProps(TaskPropsBase, TableBaseMixin): task_id: int = Field( foreign_key="task.id", - primary_key=True, + unique=True, ondelete="CASCADE" ) """关联的任务ID""" @@ -121,8 +128,8 @@ class Task(SQLModelBase, TableBaseMixin): status: TaskStatus = Field(default=TaskStatus.QUEUED) """任务状态""" - type: int = Field(default=0) - """任务类型 [TODO] 待定义枚举""" + type: TaskType + """任务类型""" progress: int = Field(default=0, ge=0, le=100) """任务进度(0-100)""" diff --git a/sqlmodels/user.py b/sqlmodels/user.py index 198f7f2..9693a93 100644 --- a/sqlmodels/user.py +++ b/sqlmodels/user.py @@ -4,7 +4,7 @@ from typing import Literal, TYPE_CHECKING, TypeVar from uuid import UUID from pydantic import BaseModel -from sqlalchemy import BinaryExpression, ClauseElement, and_ +from sqlalchemy import BigInteger, BinaryExpression, ClauseElement, and_ from sqlmodel import Field, Relationship from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.main import RelationshipInfo @@ -473,7 +473,7 @@ class User(UserBase, UUIDTableBaseMixin): status: UserStatus = UserStatus.ACTIVE """用户状态""" - storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0) + storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"}, ge=0) """已用存储空间(字节)""" avatar: str = Field(default="default", max_length=255)