From 446d219aca446b45bccfbf0d048519aa775f2762 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=8E=E5=B0=8F=E4=B8=98?= Date: Tue, 23 Dec 2025 12:20:06 +0800 Subject: [PATCH] Refactor import statements for ResponseBase in API routers - Updated import statements in the following files to import ResponseBase directly from models instead of models.response: - routers/api/v1/share/__init__.py - routers/api/v1/site/__init__.py - routers/api/v1/slave/__init__.py - routers/api/v1/tag/__init__.py - routers/api/v1/user/__init__.py - routers/api/v1/vas/__init__.py - routers/api/v1/webdav/__init__.py Enhance user registration and related endpoints in user router - Changed return type annotations from models.response.ResponseBase to models.ResponseBase in multiple functions. - Updated return statements to reflect the new import structure. - Improved documentation for clarity. Add PhysicalFile model and storage service implementation - Introduced PhysicalFile model to represent actual files on disk with reference counting logic. - Created storage service module with local storage implementation, including file operations and error handling. - Defined exceptions for storage operations to improve error handling. - Implemented naming rule parser for generating file and directory names based on templates. Update dependency management in uv.lock - Added aiofiles version 25.1.0 to the project dependencies. --- .gitignore | 3 + models/__init__.py | 24 +- models/migration.py | 7 +- models/object.py | 257 +++++++- models/physical_file.py | 90 +++ models/response.py | 14 - pyproject.toml | 1 + routers/api/v1/__init__.py | 3 +- routers/api/v1/admin/__init__.py | 121 +++- routers/api/v1/callback/__init__.py | 2 +- routers/api/v1/directory/__init__.py | 8 +- routers/api/v1/download/__init__.py | 2 +- routers/api/v1/file/__init__.py | 846 ++++++++++++++++++--------- routers/api/v1/object/__init__.py | 496 ++++++++++++++-- routers/api/v1/share/__init__.py | 2 +- routers/api/v1/site/__init__.py | 2 +- routers/api/v1/slave/__init__.py | 2 +- routers/api/v1/tag/__init__.py | 2 +- routers/api/v1/user/__init__.py | 60 +- routers/api/v1/vas/__init__.py | 2 +- routers/api/v1/webdav/__init__.py | 2 +- service/storage/__init__.py | 20 + service/storage/exceptions.py | 45 ++ service/storage/local_storage.py | 388 ++++++++++++ service/storage/naming_rule.py | 144 +++++ uv.lock | 11 + 26 files changed, 2155 insertions(+), 399 deletions(-) create mode 100644 models/physical_file.py delete mode 100644 models/response.py create mode 100644 service/storage/__init__.py create mode 100644 service/storage/exceptions.py create mode 100644 service/storage/local_storage.py create mode 100644 service/storage/naming_rule.py diff --git a/.gitignore b/.gitignore index da7e787..a9c401d 100644 --- a/.gitignore +++ b/.gitignore @@ -62,3 +62,6 @@ node_modules/ *.bak *.tmp *.temp + +# 文件 +data/ \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py index 3579b57..18ee124 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,5 +1,3 @@ -from . import response - from .user import ( LoginRequest, RegisterRequest, @@ -31,18 +29,29 @@ from .node import ( ) from .group import Group, GroupBase, GroupOptions, GroupOptionsBase, GroupResponse from .object import ( + CreateFileRequest, + CreateUploadSessionRequest, DirectoryCreateRequest, DirectoryResponse, FileMetadata, FileMetadataBase, Object, ObjectBase, + ObjectCopyRequest, ObjectDeleteRequest, ObjectMoveRequest, + ObjectPropertyDetailResponse, + ObjectPropertyResponse, + ObjectRenameRequest, ObjectResponse, ObjectType, PolicyResponse, + UploadChunkResponse, + UploadSession, + UploadSessionBase, + UploadSessionResponse, ) +from .physical_file import PhysicalFile, PhysicalFileBase from .order import Order, OrderStatus, OrderType from .policy import Policy, PolicyOptions, PolicyOptionsBase, PolicyType from .redeem import Redeem, RedeemType @@ -56,3 +65,14 @@ from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType from .webdav import WebDAV from .database import engine, get_session + + +import uuid +from sqlmodel import Field +from .base import SQLModelBase + +class ResponseBase(SQLModelBase): + """通用响应模型""" + + instance_id: uuid.UUID = Field(default_factory=uuid.uuid4) + """实例ID,用于标识请求的唯一性""" \ No newline at end of file diff --git a/models/migration.py b/models/migration.py index bd08020..8852b29 100644 --- a/models/migration.py +++ b/models/migration.py @@ -283,6 +283,7 @@ async def init_default_user() -> None: async def init_default_policy() -> None: from .policy import Policy, PolicyType from .database import get_session + from service.storage import LocalStorageService log.info('初始化默认存储策略...') @@ -302,6 +303,10 @@ async def init_default_policy() -> None: file_name_rule="{randomkey16}_{originname}", ) - await local_policy.save(session) + local_policy = await local_policy.save(session) + + # 创建物理存储目录 + storage_service = LocalStorageService(local_policy) + await storage_service.ensure_base_directory() log.info('已创建默认本地存储策略,存储目录:./data') \ No newline at end of file diff --git a/models/object.py b/models/object.py index 0e0c2c5..d7ce648 100644 --- a/models/object.py +++ b/models/object.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from .policy import Policy from .source_link import SourceLink from .share import Share + from .physical_file import PhysicalFile class ObjectType(StrEnum): @@ -112,9 +113,6 @@ class ObjectResponse(ObjectBase): id: UUID """对象UUID""" - path: str - """对象路径""" - thumb: bool = False """是否有缩略图""" @@ -222,15 +220,20 @@ class Object(ObjectBase, UUIDTableBaseMixin): # ==================== 文件专属字段 ==================== - source_name: str | None = Field(default=None, max_length=255) - """源文件名(仅文件有效)""" - size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) """文件大小(字节),目录为 0""" upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True) """分块上传会话ID(仅文件有效)""" + physical_file_id: UUID | None = Field( + default=None, + foreign_key="physicalfile.id", + index=True, + ondelete="SET NULL" + ) + """关联的物理文件UUID(仅文件有效,目录为NULL)""" + # ==================== 外键 ==================== parent_id: UUID | None = Field( @@ -295,8 +298,22 @@ class Object(ObjectBase, UUIDTableBaseMixin): ) """分享列表""" + physical_file: "PhysicalFile" = Relationship(back_populates="objects") + """关联的物理文件(仅文件有效)""" + # ==================== 业务属性 ==================== + @property + def source_name(self) -> str | None: + """ + 源文件存储路径(向后兼容属性) + + :return: 物理文件存储路径,如果没有关联物理文件则返回 None + """ + if self.physical_file: + return self.physical_file.storage_path + return None + @property def is_file(self) -> bool: """是否为文件""" @@ -397,3 +414,231 @@ class Object(ObjectBase, UUIDTableBaseMixin): (cls.owner_id == user_id) & (cls.parent_id == parent_id), fetch_mode="all" ) + + +# ==================== 上传会话模型 ==================== + +class UploadSessionBase(SQLModelBase): + """上传会话基础字段""" + + file_name: str = Field(max_length=255) + """原始文件名""" + + file_size: int = Field(ge=0) + """文件总大小(字节)""" + + chunk_size: int = Field(ge=1) + """分片大小(字节)""" + + total_chunks: int = Field(ge=1) + """总分片数""" + + +class UploadSession(UploadSessionBase, UUIDTableBaseMixin): + """ + 上传会话模型 + + 用于管理分片上传的会话状态。 + 会话有效期为24小时,过期后自动失效。 + """ + + # 会话状态 + uploaded_chunks: int = 0 + """已上传分片数""" + + uploaded_size: int = 0 + """已上传大小(字节)""" + + storage_path: str | None = Field(default=None, max_length=512) + """文件存储路径""" + + expires_at: datetime + """会话过期时间""" + + # 外键 + owner_id: UUID = Field(foreign_key="user.id", index=True, ondelete="CASCADE") + """上传者用户UUID""" + + parent_id: UUID = Field(foreign_key="object.id", index=True, ondelete="CASCADE") + """目标父目录UUID""" + + policy_id: UUID = Field(foreign_key="policy.id", index=True, ondelete="RESTRICT") + """存储策略UUID""" + + # 关系 + owner: "User" = Relationship() + """上传者""" + + parent: "Object" = Relationship( + sa_relationship_kwargs={"foreign_keys": "[UploadSession.parent_id]"} + ) + """目标父目录""" + + policy: "Policy" = Relationship() + """存储策略""" + + @property + def is_expired(self) -> bool: + """会话是否已过期""" + return datetime.now() > self.expires_at + + @property + def is_complete(self) -> bool: + """上传是否完成""" + return self.uploaded_chunks >= self.total_chunks + + +# ==================== 上传会话相关 DTO ==================== + +class CreateUploadSessionRequest(SQLModelBase): + """创建上传会话请求 DTO""" + + file_name: str = Field(max_length=255) + """文件名""" + + file_size: int = Field(ge=0) + """文件大小(字节)""" + + parent_id: UUID + """父目录UUID""" + + policy_id: UUID | None = None + """存储策略UUID,不指定则使用父目录的策略""" + + +class UploadSessionResponse(SQLModelBase): + """上传会话响应 DTO""" + + id: UUID + """会话UUID""" + + file_name: str + """原始文件名""" + + file_size: int + """文件总大小(字节)""" + + chunk_size: int + """分片大小(字节)""" + + total_chunks: int + """总分片数""" + + uploaded_chunks: int + """已上传分片数""" + + expires_at: datetime + """过期时间""" + + +class UploadChunkResponse(SQLModelBase): + """上传分片响应 DTO""" + + uploaded_chunks: int + """已上传分片数""" + + total_chunks: int + """总分片数""" + + is_complete: bool + """是否上传完成""" + + object_id: UUID | None = None + """完成后的文件对象UUID,未完成时为None""" + + +class CreateFileRequest(SQLModelBase): + """创建空白文件请求 DTO""" + + name: str = Field(max_length=255) + """文件名""" + + parent_id: UUID + """父目录UUID""" + + policy_id: UUID | None = None + """存储策略UUID,不指定则使用父目录的策略""" + + +# ==================== 对象操作相关 DTO ==================== + +class ObjectCopyRequest(SQLModelBase): + """复制对象请求 DTO""" + + src_ids: list[UUID] + """源对象UUID列表""" + + dst_id: UUID + """目标文件夹UUID""" + + +class ObjectRenameRequest(SQLModelBase): + """重命名对象请求 DTO""" + + id: UUID + """对象UUID""" + + new_name: str = Field(max_length=255) + """新名称""" + + +class ObjectPropertyResponse(SQLModelBase): + """对象基本属性响应 DTO""" + + id: UUID + """对象UUID""" + + name: str + """对象名称""" + + type: ObjectType + """对象类型""" + + size: int + """文件大小(字节)""" + + created_at: datetime + """创建时间""" + + updated_at: datetime + """修改时间""" + + parent_id: UUID | None + """父目录UUID""" + + +class ObjectPropertyDetailResponse(ObjectPropertyResponse): + """对象详细属性响应 DTO(继承基本属性)""" + + # 元数据信息 + mime_type: str | None = None + """MIME类型""" + + width: int | None = None + """图片宽度(像素)""" + + height: int | None = None + """图片高度(像素)""" + + duration: float | None = None + """音视频时长(秒)""" + + checksum_md5: str | None = None + """MD5校验和""" + + # 分享统计 + share_count: int = 0 + """分享次数""" + + total_views: int = 0 + """总浏览次数""" + + total_downloads: int = 0 + """总下载次数""" + + # 存储信息 + policy_name: str | None = None + """存储策略名称""" + + reference_count: int = 1 + """物理文件引用计数(仅文件有效)""" diff --git a/models/physical_file.py b/models/physical_file.py new file mode 100644 index 0000000..49fd5e5 --- /dev/null +++ b/models/physical_file.py @@ -0,0 +1,90 @@ +""" +物理文件模型 + +表示磁盘上的实际文件。多个 Object 可以引用同一个 PhysicalFile, +实现文件共享而不复制物理文件。 + +引用计数逻辑: +- 每个引用此文件的 Object 都会增加引用计数 +- 当 Object 被删除时,减少引用计数 +- 只有当引用计数为 0 时,才物理删除文件 +""" +from typing import TYPE_CHECKING +from uuid import UUID + +from sqlmodel import Field, Relationship, Index + +from .base import SQLModelBase +from .mixin import UUIDTableBaseMixin + +if TYPE_CHECKING: + from .object import Object + from .policy import Policy + + +class PhysicalFileBase(SQLModelBase): + """物理文件基础模型""" + + storage_path: str = Field(max_length=512) + """物理存储路径(相对于存储策略根目录)""" + + size: int = 0 + """文件大小(字节)""" + + checksum_md5: str | None = Field(default=None, max_length=32) + """MD5校验和(用于文件去重和完整性校验)""" + + +class PhysicalFile(PhysicalFileBase, UUIDTableBaseMixin): + """ + 物理文件模型 + + 表示磁盘上的实际文件。多个 Object 可以引用同一个 PhysicalFile, + 实现文件共享而不复制物理文件。 + """ + + __table_args__ = ( + Index("ix_physical_file_policy_path", "policy_id", "storage_path"), + Index("ix_physical_file_checksum", "checksum_md5"), + ) + + policy_id: UUID = Field( + foreign_key="policy.id", + index=True, + ondelete="RESTRICT", + ) + """存储策略UUID""" + + reference_count: int = Field(default=1, ge=0) + """引用计数(有多少个 Object 引用此物理文件)""" + + # 关系 + policy: "Policy" = Relationship() + """存储策略""" + + objects: list["Object"] = Relationship(back_populates="physical_file") + """引用此物理文件的所有逻辑对象""" + + def increment_reference(self) -> int: + """ + 增加引用计数 + + :return: 更新后的引用计数 + """ + self.reference_count += 1 + return self.reference_count + + def decrement_reference(self) -> int: + """ + 减少引用计数 + + :return: 更新后的引用计数 + """ + if self.reference_count > 0: + self.reference_count -= 1 + return self.reference_count + + @property + def can_be_deleted(self) -> bool: + """是否可以物理删除(引用计数为0)""" + return self.reference_count == 0 diff --git a/models/response.py b/models/response.py deleted file mode 100644 index 11bc235..0000000 --- a/models/response.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -通用响应模型定义 -""" -import uuid - -from sqlmodel import Field - -from .base import SQLModelBase - -class ResponseBase(SQLModelBase): - """通用响应模型""" - - instance_id: uuid.UUID = Field(default_factory=uuid.uuid4) - """实例ID,用于标识请求的唯一性""" diff --git a/pyproject.toml b/pyproject.toml index 19d246f..f2625da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.13" dependencies = [ + "aiofiles>=25.1.0", "aiohttp>=3.13.2", "aiosqlite>=0.21.0", "argon2-cffi>=25.1.0", diff --git a/routers/api/v1/__init__.py b/routers/api/v1/__init__.py index 38ef1b6..9e04141 100644 --- a/routers/api/v1/__init__.py +++ b/routers/api/v1/__init__.py @@ -12,7 +12,7 @@ from .admin import admin_vas_router from .callback import callback_router from .directory import directory_router from .download import download_router -from .file import file_router +from .file import file_router, file_upload_router from .object import object_router from .share import share_router from .site import site_router @@ -36,6 +36,7 @@ router.include_router(callback_router) router.include_router(directory_router) router.include_router(download_router) router.include_router(file_router) +router.include_router(file_upload_router) router.include_router(object_router) router.include_router(share_router) router.include_router(site_router) diff --git a/routers/api/v1/admin/__init__.py b/routers/api/v1/admin/__init__.py index d942b91..0ee727e 100644 --- a/routers/api/v1/admin/__init__.py +++ b/routers/api/v1/admin/__init__.py @@ -1,11 +1,57 @@ -from fastapi import APIRouter, Depends -from loguru import logger +from fastapi import APIRouter, Depends, HTTPException +from loguru import logger as l +from sqlmodel import Field from middleware.auth import AdminRequired from middleware.dependencies import SessionDep -from models import User +from models import Policy, PolicyOptions, PolicyType, User +from models.base import SQLModelBase +from models import ResponseBase from models.user import UserPublic -from models.response import ResponseBase +from service.storage import DirectoryCreationError, LocalStorageService + + +class PolicyCreateRequest(SQLModelBase): + """创建存储策略请求 DTO""" + + name: str = Field(max_length=255) + """策略名称""" + + type: PolicyType + """策略类型""" + + server: str | None = Field(default=None, max_length=255) + """服务器地址/本地路径(本地存储必填)""" + + bucket_name: str | None = Field(default=None, max_length=255) + """存储桶名称(S3必填)""" + + is_private: bool = True + """是否为私有空间""" + + base_url: str | None = Field(default=None, max_length=255) + """访问文件的基础URL""" + + access_key: str | None = None + """Access Key""" + + secret_key: str | None = None + """Secret Key""" + + max_size: int = Field(default=0, ge=0) + """允许上传的最大文件尺寸(字节),0表示不限制""" + + auto_rename: bool = False + """是否自动重命名""" + + dir_name_rule: str | None = Field(default=None, max_length=255) + """目录命名规则""" + + file_name_rule: str | None = Field(default=None, max_length=255) + """文件命名规则""" + + is_origin_link_enable: bool = False + """是否开启源链接访问""" # 管理员根目录 /api/admin admin_router = APIRouter( @@ -464,11 +510,72 @@ def router_policy_test_slave() -> ResponseBase: @admin_policy_router.post( path='/', summary='创建存储策略', - description='', + description='创建新的存储策略。对于本地存储策略,会自动创建物理目录。', dependencies=[Depends(AdminRequired)] ) -def router_policy_add_policy() -> ResponseBase: - pass +async def router_policy_add_policy( + session: SessionDep, + request: PolicyCreateRequest, +) -> ResponseBase: + """ + 创建存储策略端点 + + 功能: + - 创建新的存储策略配置 + - 对于 LOCAL 类型,自动创建物理目录 + + 认证: + - 需要管理员权限 + + :param session: 数据库会话 + :param request: 创建请求 + :return: 创建结果 + """ + # 验证本地存储策略必须指定 server 路径 + if request.type == PolicyType.LOCAL: + if not request.server: + raise HTTPException(status_code=400, detail="本地存储策略必须指定 server 路径") + + # 检查策略名称是否已存在 + existing = await Policy.get(session, Policy.name == request.name) + if existing: + raise HTTPException(status_code=409, detail="策略名称已存在") + + # 创建策略对象 + policy = Policy( + name=request.name, + type=request.type, + server=request.server, + bucket_name=request.bucket_name, + is_private=request.is_private, + base_url=request.base_url, + access_key=request.access_key, + secret_key=request.secret_key, + max_size=request.max_size, + auto_rename=request.auto_rename, + dir_name_rule=request.dir_name_rule, + file_name_rule=request.file_name_rule, + is_origin_link_enable=request.is_origin_link_enable, + ) + + # 对于本地存储策略,创建物理目录 + if policy.type == PolicyType.LOCAL: + try: + storage_service = LocalStorageService(policy) + await storage_service.ensure_base_directory() + l.info(f"已为本地存储策略 '{policy.name}' 创建目录: {policy.server}") + except DirectoryCreationError as e: + raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}") + + # 保存到数据库 + policy = await policy.save(session) + + return ResponseBase(data={ + "id": str(policy.id), + "name": policy.name, + "type": policy.type.value, + "server": policy.server, + }) @admin_policy_router.post( path='/cors', diff --git a/routers/api/v1/callback/__init__.py b/routers/api/v1/callback/__init__.py index 6f15df5..004a161 100644 --- a/routers/api/v1/callback/__init__.py +++ b/routers/api/v1/callback/__init__.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends, Query from fastapi.responses import PlainTextResponse, RedirectResponse from middleware.auth import SignRequired -from models.response import ResponseBase +from models import ResponseBase import service.oauth callback_router = APIRouter( diff --git a/routers/api/v1/directory/__init__.py b/routers/api/v1/directory/__init__.py index 8da3894..db34710 100644 --- a/routers/api/v1/directory/__init__.py +++ b/routers/api/v1/directory/__init__.py @@ -12,7 +12,7 @@ from models import ( ObjectType, PolicyResponse, User, - response, + ResponseBase, ) directory_router = APIRouter( @@ -63,7 +63,6 @@ async def router_directory_get( ObjectResponse( id=child.id, name=child.name, - path=f"/{child.name}", # TODO: 完整路径 thumb=False, size=child.size, type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE, @@ -97,7 +96,7 @@ async def router_directory_create( session: SessionDep, user: Annotated[User, Depends(AuthRequired)], request: DirectoryCreateRequest -) -> response.ResponseBase: +) -> ResponseBase: """ 创建目录 @@ -111,6 +110,7 @@ async def router_directory_create( if not name: raise HTTPException(status_code=400, detail="目录名称不能为空") + # [TODO] 进一步验证名称合法性 if "/" in name or "\\" in name: raise HTTPException(status_code=400, detail="目录名称不能包含斜杠") @@ -146,7 +146,7 @@ async def router_directory_create( new_folder_name = new_folder.name await new_folder.save(session) - return response.ResponseBase( + return ResponseBase( data={ "id": new_folder_id, "name": new_folder_name, diff --git a/routers/api/v1/download/__init__.py b/routers/api/v1/download/__init__.py index 5c39621..e70778d 100644 --- a/routers/api/v1/download/__init__.py +++ b/routers/api/v1/download/__init__.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Depends from middleware.auth import SignRequired -from models.response import ResponseBase +from models import ResponseBase download_router = APIRouter( prefix="/download", diff --git a/routers/api/v1/file/__init__.py b/routers/api/v1/file/__init__.py index df16db3..1ba8705 100644 --- a/routers/api/v1/file/__init__.py +++ b/routers/api/v1/file/__init__.py @@ -1,7 +1,37 @@ -from fastapi import APIRouter, Depends, UploadFile +""" +文件操作路由 + +提供文件上传、下载、创建等核心功能。 + +路由前缀: +- /file - 文件操作 +- /file/upload - 上传相关操作 +""" +from datetime import datetime, timedelta +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile from fastapi.responses import FileResponse -from middleware.auth import SignRequired -from models.response import ResponseBase +from loguru import logger as l + +from middleware.auth import AuthRequired, SignRequired +from middleware.dependencies import SessionDep +from models import ( + CreateFileRequest, + CreateUploadSessionRequest, + Object, + ObjectType, + PhysicalFile, + Policy, + PolicyType, + UploadChunkResponse, + UploadSession, + UploadSessionResponse, + User, +) +from models import ResponseBase +from service.storage import LocalStorageService, StorageFileNotFoundError file_router = APIRouter( prefix="/file", @@ -13,370 +43,614 @@ file_upload_router = APIRouter( tags=["file"] ) -@file_router.get( - path='/get/{id}/{name}', - summary='文件外链(直接输出文件数据)', - description='Get file external link endpoint.', -) -def router_file_get(id: str, name: str) -> FileResponse: - """ - Get file external link endpoint. - - Args: - id (str): The ID of the file. - name (str): The name of the file. - - Returns: - FileResponse: A response containing the file data. - """ - pass -@file_router.get( - path='/source/{id}/{name}', - summary='文件外链(301跳转)', - description='Get file external link with 301 redirect endpoint.', -) -def router_file_source(id: str, name: str) -> ResponseBase: - """ - Get file external link with 301 redirect endpoint. - - Args: - id (str): The ID of the file. - name (str): The name of the file. - - Returns: - ResponseBase: A model containing the response data for the file with a redirect. - """ - pass - -@file_upload_router.get( - path='/download/{id}', - summary='下载文件', - description='Download file endpoint.', -) -def router_file_download(id: str) -> ResponseBase: - """ - Download file endpoint. - - Args: - id (str): The ID of the file to download. - - Returns: - ResponseBase: A model containing the response data for the file download. - """ - pass - -@file_upload_router.get( - path='/archive/{sessionID}/archive.zip', - summary='打包并下载文件', - description='Archive and download files endpoint.', -) -def router_file_archive_download(sessionID: str) -> ResponseBase: - """ - Archive and download files endpoint. - - Args: - sessionID (str): The session ID for the archive. - - Returns: - ResponseBase: A model containing the response data for the archived files download. - """ - pass - -@file_upload_router.post( - path='/{sessionID}/{index}', - summary='文件上传', - description='File upload endpoint.', -) -def router_file_upload(sessionID: str, index: int, file: UploadFile) -> ResponseBase: - """ - File upload endpoint. - - Args: - sessionID (str): The session ID for the upload. - index (int): The index of the file being uploaded. - - Returns: - ResponseBase: A model containing the response data. - """ - pass +# ==================== 上传会话管理 ==================== @file_upload_router.put( path='/', summary='创建上传会话', - description='Create an upload session endpoint.', - dependencies=[Depends(SignRequired)], + description='创建文件上传会话,返回会话ID用于后续分片上传。', ) -def router_file_upload_session() -> ResponseBase: +async def create_upload_session( + session: SessionDep, + user: Annotated[User, Depends(AuthRequired)], + request: CreateUploadSessionRequest, +) -> UploadSessionResponse: """ - Create an upload session endpoint. - - Returns: - ResponseBase: A model containing the response data for the upload session. + 创建上传会话端点 + + 流程: + 1. 验证父目录存在且属于当前用户 + 2. 确定存储策略(使用指定的或继承父目录的) + 3. 验证文件大小限制 + 4. 创建上传会话并生成存储路径 + 5. 返回会话信息 + + :param session: 数据库会话 + :param user: 当前登录用户 + :param request: 创建请求 + :return: 上传会话信息 """ - pass + # 验证文件名 + if not request.file_name or '/' in request.file_name or '\\' in request.file_name: + raise HTTPException(status_code=400, detail="无效的文件名") + + # 验证父目录 + parent = await Object.get(session, Object.id == request.parent_id) + if not parent or parent.owner_id != user.id: + raise HTTPException(status_code=404, detail="父目录不存在") + + if not parent.is_folder: + raise HTTPException(status_code=400, detail="父对象不是目录") + + # 确定存储策略 + policy_id = request.policy_id or parent.policy_id + policy = await Policy.get(session, Policy.id == policy_id) + if not policy: + raise HTTPException(status_code=404, detail="存储策略不存在") + + # 验证文件大小限制 + if policy.max_size > 0 and request.file_size > policy.max_size: + raise HTTPException( + status_code=400, + detail=f"文件大小超过限制 ({policy.max_size} bytes)" + ) + + # 检查是否已存在同名文件 + existing = await Object.get( + session, + (Object.owner_id == user.id) & + (Object.parent_id == parent.id) & + (Object.name == request.file_name) + ) + if existing: + raise HTTPException(status_code=409, detail="同名文件已存在") + + # 计算分片信息 + options = await policy.awaitable_attrs.options + chunk_size = options.chunk_size if options else 52428800 # 默认 50MB + total_chunks = max(1, (request.file_size + chunk_size - 1) // chunk_size) if request.file_size > 0 else 1 + + # 生成存储路径 + storage_path: str | None = None + if policy.type == PolicyType.LOCAL: + storage_service = LocalStorageService(policy) + dir_path, storage_name, full_path = await storage_service.generate_file_path( + user_id=user.id, + original_filename=request.file_name, + ) + storage_path = full_path + else: + # S3 后续实现 + raise HTTPException(status_code=501, detail="S3 存储暂未实现") + + # 创建上传会话 + upload_session = UploadSession( + file_name=request.file_name, + file_size=request.file_size, + chunk_size=chunk_size, + total_chunks=total_chunks, + storage_path=storage_path, + expires_at=datetime.now() + timedelta(hours=24), # 24小时过期 + owner_id=user.id, + parent_id=request.parent_id, + policy_id=policy_id, + ) + upload_session = await upload_session.save(session) + + l.info(f"创建上传会话: {upload_session.id}, 文件: {request.file_name}, 大小: {request.file_size}") + + return UploadSessionResponse( + id=upload_session.id, + file_name=upload_session.file_name, + file_size=upload_session.file_size, + chunk_size=upload_session.chunk_size, + total_chunks=upload_session.total_chunks, + uploaded_chunks=0, + expires_at=upload_session.expires_at, + ) + + +@file_upload_router.post( + path='/{session_id}/{chunk_index}', + summary='上传文件分片', + description='上传指定分片,分片索引从0开始。', +) +async def upload_chunk( + session: SessionDep, + user: Annotated[User, Depends(AuthRequired)], + session_id: UUID, + chunk_index: int, + file: UploadFile = File(...), +) -> UploadChunkResponse: + """ + 上传文件分片端点 + + 流程: + 1. 验证上传会话 + 2. 写入分片数据 + 3. 更新会话进度 + 4. 如果所有分片上传完成,创建 Object 记录 + + :param session: 数据库会话 + :param user: 当前登录用户 + :param session_id: 上传会话UUID + :param chunk_index: 分片索引(从0开始) + :param file: 上传的文件分片 + :return: 上传进度信息 + """ + # 获取上传会话 + upload_session = await UploadSession.get(session, UploadSession.id == session_id) + if not upload_session or upload_session.owner_id != user.id: + raise HTTPException(status_code=404, detail="上传会话不存在") + + if upload_session.is_expired: + raise HTTPException(status_code=400, detail="上传会话已过期") + + # 存储 user.id,避免后续 save() 导致 user 过期后无法访问 + user_id = user.id + + if chunk_index < 0 or chunk_index >= upload_session.total_chunks: + raise HTTPException(status_code=400, detail="无效的分片索引") + + # 获取策略 + policy = await Policy.get(session, Policy.id == upload_session.policy_id) + if not policy: + raise HTTPException(status_code=500, detail="存储策略不存在") + + # 读取分片内容 + content = await file.read() + + # 写入分片 + if policy.type == PolicyType.LOCAL: + if not upload_session.storage_path: + raise HTTPException(status_code=500, detail="存储路径丢失") + + storage_service = LocalStorageService(policy) + offset = chunk_index * upload_session.chunk_size + await storage_service.write_file_chunk( + upload_session.storage_path, + content, + offset, + ) + else: + raise HTTPException(status_code=501, detail="S3 存储暂未实现") + + # 更新会话进度 + upload_session.uploaded_chunks += 1 + upload_session.uploaded_size += len(content) + upload_session = await upload_session.save(session) + + # 检查是否完成 + is_complete = upload_session.is_complete + file_object_id: UUID | None = None + + if is_complete: + # 创建 PhysicalFile 记录 + physical_file = PhysicalFile( + storage_path=upload_session.storage_path, + size=upload_session.uploaded_size, + policy_id=upload_session.policy_id, + reference_count=1, + ) + physical_file = await physical_file.save(session) + + # 创建 Object 记录 + file_object = Object( + name=upload_session.file_name, + type=ObjectType.FILE, + size=upload_session.uploaded_size, + physical_file_id=physical_file.id, + upload_session_id=str(upload_session.id), + parent_id=upload_session.parent_id, + owner_id=user_id, + policy_id=upload_session.policy_id, + ) + file_object = await file_object.save(session) + file_object_id = file_object.id + + # 删除上传会话 + await UploadSession.delete(session, upload_session) + + l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}") + + return UploadChunkResponse( + uploaded_chunks=upload_session.uploaded_chunks if not is_complete else upload_session.total_chunks, + total_chunks=upload_session.total_chunks, + is_complete=is_complete, + object_id=file_object_id, + ) + @file_upload_router.delete( - path='/{sessionID}', + path='/{session_id}', summary='删除上传会话', - description='Delete an upload session endpoint.', - dependencies=[Depends(SignRequired)] + description='取消上传并删除会话及已上传的临时文件。', ) -def router_file_upload_session_delete(sessionID: str) -> ResponseBase: +async def delete_upload_session( + session: SessionDep, + user: Annotated[User, Depends(AuthRequired)], + session_id: UUID, +) -> ResponseBase: """ - Delete an upload session endpoint. - - Args: - sessionID (str): The session ID to delete. - - Returns: - ResponseBase: A model containing the response data for the deletion. + 删除上传会话端点 + + :param session: 数据库会话 + :param user: 当前登录用户 + :param session_id: 上传会话UUID + :return: 删除结果 """ - pass + upload_session = await UploadSession.get(session, UploadSession.id == session_id) + if not upload_session or upload_session.owner_id != user.id: + raise HTTPException(status_code=404, detail="上传会话不存在") + + # 删除临时文件 + policy = await Policy.get(session, Policy.id == upload_session.policy_id) + if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path: + storage_service = LocalStorageService(policy) + await storage_service.delete_file(upload_session.storage_path) + + # 删除会话记录 + await UploadSession.delete(session, upload_session) + + l.info(f"删除上传会话: {session_id}") + + return ResponseBase(data={"deleted": True}) + @file_upload_router.delete( path='/', summary='清除所有上传会话', - description='Clear all upload sessions endpoint.', - dependencies=[Depends(SignRequired)] + description='清除当前用户的所有上传会话。', ) -def router_file_upload_session_clear() -> ResponseBase: +async def clear_upload_sessions( + session: SessionDep, + user: Annotated[User, Depends(AuthRequired)], +) -> ResponseBase: """ - Clear all upload sessions endpoint. - - Returns: - ResponseBase: A model containing the response data for clearing all sessions. - """ - pass + 清除所有上传会话端点 -@file_router.put( - path='/update/{id}', - summary='更新文件', - description='Update file information endpoint.', - dependencies=[Depends(SignRequired)] + :param session: 数据库会话 + :param user: 当前登录用户 + :return: 清除结果 + """ + # 获取所有会话 + sessions = await UploadSession.get( + session, + UploadSession.owner_id == user.id, + fetch_mode="all" + ) + + deleted_count = 0 + for upload_session in sessions: + # 删除临时文件 + policy = await Policy.get(session, Policy.id == upload_session.policy_id) + if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path: + storage_service = LocalStorageService(policy) + await storage_service.delete_file(upload_session.storage_path) + + await UploadSession.delete(session, upload_session) + deleted_count += 1 + + l.info(f"清除用户 {user.id} 的所有上传会话,共 {deleted_count} 个") + + return ResponseBase(data={"deleted": deleted_count}) + + +# ==================== 文件下载 ==================== + +@file_upload_router.get( + path='/download/{file_id}', + summary='下载文件', + description='下载指定文件。', ) -def router_file_update(id: str) -> ResponseBase: +async def download_file( + session: SessionDep, + user: Annotated[User, Depends(AuthRequired)], + file_id: UUID, +) -> FileResponse: """ - Update file information endpoint. - - Args: - id (str): The ID of the file to update. - - Returns: - ResponseBase: A model containing the response data for the file update. + 下载文件端点 + + :param session: 数据库会话 + :param user: 当前登录用户 + :param file_id: 文件UUID + :return: 文件响应 """ - pass + file_obj = await Object.get(session, Object.id == file_id) + if not file_obj or file_obj.owner_id != user.id: + raise HTTPException(status_code=404, detail="文件不存在") + + if not file_obj.is_file: + raise HTTPException(status_code=400, detail="对象不是文件") + + if not file_obj.source_name: + raise HTTPException(status_code=500, detail="文件存储路径丢失") + + # 获取策略 + policy = await Policy.get(session, Policy.id == file_obj.policy_id) + if not policy: + raise HTTPException(status_code=500, detail="存储策略不存在") + + if policy.type == PolicyType.LOCAL: + storage_service = LocalStorageService(policy) + if not await storage_service.file_exists(file_obj.source_name): + raise HTTPException(status_code=404, detail="物理文件不存在") + + return FileResponse( + path=file_obj.source_name, + filename=file_obj.name, + media_type="application/octet-stream", + ) + else: + raise HTTPException(status_code=501, detail="S3 存储暂未实现") + + +# ==================== 创建空白文件 ==================== @file_router.post( path='/create', summary='创建空白文件', - description='Create a blank file endpoint.', - dependencies=[Depends(SignRequired)] + description='在指定目录下创建空白文件。', ) -def router_file_create() -> ResponseBase: +async def create_empty_file( + session: SessionDep, + user: Annotated[User, Depends(AuthRequired)], + request: CreateFileRequest, +) -> ResponseBase: """ - Create a blank file endpoint. - - Returns: - ResponseBase: A model containing the response data for the file creation. + 创建空白文件端点 + + :param session: 数据库会话 + :param user: 当前登录用户 + :param request: 创建请求 + :return: 创建结果 """ - pass + # 存储 user.id,避免后续 save() 导致 user 过期后无法访问 + user_id = user.id + + # 验证文件名 + if not request.name or '/' in request.name or '\\' in request.name: + raise HTTPException(status_code=400, detail="无效的文件名") + + # 验证父目录 + parent = await Object.get(session, Object.id == request.parent_id) + if not parent or parent.owner_id != user_id: + raise HTTPException(status_code=404, detail="父目录不存在") + + if not parent.is_folder: + raise HTTPException(status_code=400, detail="父对象不是目录") + + # 检查是否已存在同名文件 + existing = await Object.get( + session, + (Object.owner_id == user_id) & + (Object.parent_id == parent.id) & + (Object.name == request.name) + ) + if existing: + raise HTTPException(status_code=409, detail="同名文件已存在") + + # 确定存储策略 + policy_id = request.policy_id or parent.policy_id + policy = await Policy.get(session, Policy.id == policy_id) + if not policy: + raise HTTPException(status_code=404, detail="存储策略不存在") + + # 生成存储路径并创建空文件 + storage_path: str | None = None + if policy.type == PolicyType.LOCAL: + storage_service = LocalStorageService(policy) + dir_path, storage_name, full_path = await storage_service.generate_file_path( + user_id=user_id, + original_filename=request.name, + ) + await storage_service.create_empty_file(full_path) + storage_path = full_path + else: + raise HTTPException(status_code=501, detail="S3 存储暂未实现") + + # 创建 PhysicalFile 记录 + physical_file = PhysicalFile( + storage_path=storage_path, + size=0, + policy_id=policy_id, + reference_count=1, + ) + physical_file = await physical_file.save(session) + + # 创建 Object 记录 + file_object = Object( + name=request.name, + type=ObjectType.FILE, + size=0, + physical_file_id=physical_file.id, + parent_id=request.parent_id, + owner_id=user_id, + policy_id=policy_id, + ) + file_object = await file_object.save(session) + + l.info(f"创建空白文件: {file_object.name}, id={file_object.id}") + + return ResponseBase(data={ + "id": str(file_object.id), + "name": file_object.name, + "size": file_object.size, + }) + + +# ==================== 文件外链(保留原有端点结构) ==================== + +@file_router.get( + path='/get/{id}/{name}', + summary='文件外链(直接输出文件数据)', + description='通过外链直接获取文件内容。', +) +async def router_file_get( + session: SessionDep, + id: str, + name: str, +) -> FileResponse: + """ + 文件外链端点(直接输出) + + TODO: 实现签名验证和权限控制 + """ + raise HTTPException(status_code=501, detail="外链功能暂未实现") + + +@file_router.get( + path='/source/{id}/{name}', + summary='文件外链(301跳转)', + description='通过外链获取文件重定向地址。', +) +async def router_file_source_redirect(id: str, name: str) -> ResponseBase: + """ + 文件外链端点(301跳转) + + TODO: 实现签名验证和重定向 + """ + raise HTTPException(status_code=501, detail="外链功能暂未实现") + @file_router.put( - path='/download/{id}', - summary='创建文件下载会话', - description='Create a file download session endpoint.', - dependencies=[Depends(SignRequired)] + path='/update/{id}', + summary='更新文件', + description='更新文件内容。', + dependencies=[Depends(AuthRequired)] ) -def router_file_download(id: str) -> ResponseBase: - """ - Create a file download session endpoint. - - Args: - id (str): The ID of the file to download. - - Returns: - ResponseBase: A model containing the response data for the file download session. - """ - pass +async def router_file_update(id: str) -> ResponseBase: + """更新文件内容""" + raise HTTPException(status_code=501, detail="更新文件功能暂未实现") + @file_router.get( path='/preview/{id}', summary='预览文件', - description='Preview file endpoint.', - dependencies=[Depends(SignRequired)] + description='获取文件预览。', + dependencies=[Depends(AuthRequired)] ) -def router_file_preview(id: str) -> ResponseBase: - """ - Preview file endpoint. - - Args: - id (str): The ID of the file to preview. - - Returns: - ResponseBase: A model containing the response data for the file preview. - """ - pass +async def router_file_preview(id: str) -> ResponseBase: + """预览文件""" + raise HTTPException(status_code=501, detail="预览功能暂未实现") + @file_router.get( path='/content/{id}', summary='获取文本文件内容', - description='Get text file content endpoint.', - dependencies=[Depends(SignRequired)] + description='获取文本文件内容。', + dependencies=[Depends(AuthRequired)] ) -def router_file_content(id: str) -> ResponseBase: - """ - Get text file content endpoint. - - Args: - id (str): The ID of the text file. - - Returns: - ResponseBase: A model containing the response data for the text file content. - """ - pass +async def router_file_content(id: str) -> ResponseBase: + """获取文本文件内容""" + raise HTTPException(status_code=501, detail="文本内容功能暂未实现") + @file_router.get( path='/doc/{id}', summary='获取Office文档预览地址', - description='Get Office document preview URL endpoint.', - dependencies=[Depends(SignRequired)] + description='获取Office文档在线预览地址。', + dependencies=[Depends(AuthRequired)] ) -def router_file_doc(id: str) -> ResponseBase: - """ - Get Office document preview URL endpoint. - - Args: - id (str): The ID of the Office document. - - Returns: - ResponseBase: A model containing the response data for the Office document preview URL. - """ - pass +async def router_file_doc(id: str) -> ResponseBase: + """获取Office文档预览地址""" + raise HTTPException(status_code=501, detail="Office预览功能暂未实现") + @file_router.get( path='/thumb/{id}', summary='获取文件缩略图', - description='Get file thumbnail endpoint.', - dependencies=[Depends(SignRequired)] + description='获取文件缩略图。', + dependencies=[Depends(AuthRequired)] ) -def router_file_thumb(id: str) -> ResponseBase: - """ - Get file thumbnail endpoint. - - Args: - id (str): The ID of the file to get the thumbnail for. - - Returns: - ResponseBase: A model containing the response data for the file thumbnail. - """ - pass +async def router_file_thumb(id: str) -> ResponseBase: + """获取文件缩略图""" + raise HTTPException(status_code=501, detail="缩略图功能暂未实现") + @file_router.post( path='/source/{id}', summary='取得文件外链', - description='Get file external link endpoint.', - dependencies=[Depends(SignRequired)] + description='获取文件的外链地址。', + dependencies=[Depends(AuthRequired)] ) -def router_file_source(id: str) -> ResponseBase: - """ - Get file external link endpoint. - - Args: - id (str): The ID of the file to get the external link for. - - Returns: - ResponseBase: A model containing the response data for the file external link. - """ - pass +async def router_file_source(id: str) -> ResponseBase: + """获取文件外链""" + raise HTTPException(status_code=501, detail="外链功能暂未实现") + @file_router.post( path='/archive', summary='打包要下载的文件', - description='Archive files for download endpoint.', - dependencies=[Depends(SignRequired)] + description='将多个文件打包下载。', + dependencies=[Depends(AuthRequired)] ) -def router_file_archive(id: str) -> ResponseBase: - """ - Archive files for download endpoint. - - Args: - id (str): The ID of the file to archive. - - Returns: - ResponseBase: A model containing the response data for the archived files. - """ - pass +async def router_file_archive() -> ResponseBase: + """打包文件""" + raise HTTPException(status_code=501, detail="打包功能暂未实现") + @file_router.post( path='/compress', summary='创建文件压缩任务', - description='Create file compression task endpoint.', - dependencies=[Depends(SignRequired)] + description='创建文件压缩任务。', + dependencies=[Depends(AuthRequired)] ) -def router_file_compress(id: str) -> ResponseBase: - """ - Create file compression task endpoint. - - Args: - id (str): The ID of the file to compress. - - Returns: - ResponseBase: A model containing the response data for the file compression task. - """ - pass +async def router_file_compress() -> ResponseBase: + """创建压缩任务""" + raise HTTPException(status_code=501, detail="压缩功能暂未实现") + @file_router.post( path='/decompress', summary='创建文件解压任务', - description='Create file extraction task endpoint.', - dependencies=[Depends(SignRequired)] + description='创建文件解压任务。', + dependencies=[Depends(AuthRequired)] ) -def router_file_decompress(id: str) -> ResponseBase: - """ - Create file extraction task endpoint. - - Args: - id (str): The ID of the file to decompress. - - Returns: - ResponseBase: A model containing the response data for the file extraction task. - """ - pass +async def router_file_decompress() -> ResponseBase: + """创建解压任务""" + raise HTTPException(status_code=501, detail="解压功能暂未实现") + @file_router.post( path='/relocate', summary='创建文件转移任务', - description='Create file relocation task endpoint.', - dependencies=[Depends(SignRequired)] + description='创建文件转移任务。', + dependencies=[Depends(AuthRequired)] ) -def router_file_relocate(id: str) -> ResponseBase: - """ - Create file relocation task endpoint. - - Args: - id (str): The ID of the file to relocate. - - Returns: - ResponseBase: A model containing the response data for the file relocation task. - """ - pass +async def router_file_relocate() -> ResponseBase: + """创建转移任务""" + raise HTTPException(status_code=501, detail="转移功能暂未实现") + @file_router.get( path='/search/{type}/{keyword}', summary='搜索文件', - description='Search files by keyword endpoint.', - dependencies=[Depends(SignRequired)] + description='按关键字搜索文件。', + dependencies=[Depends(AuthRequired)] ) -def router_file_search(type: str, keyword: str) -> ResponseBase: - """ - Search files by keyword endpoint. - - Args: - type (str): The type of search (e.g., 'name', 'content'). - keyword (str): The keyword to search for. - - Returns: - ResponseBase: A model containing the response data for the file search. - """ - pass \ No newline at end of file +async def router_file_search(type: str, keyword: str) -> ResponseBase: + """搜索文件""" + raise HTTPException(status_code=501, detail="搜索功能暂未实现") + + +@file_upload_router.get( + path='/archive/{sessionID}/archive.zip', + summary='打包并下载文件', + description='获取打包后的文件。', +) +async def router_file_archive_download(sessionID: str) -> ResponseBase: + """打包下载""" + raise HTTPException(status_code=501, detail="打包下载功能暂未实现") + + +@file_router.put( + path='/download/{id}', + summary='创建文件下载会话', + description='创建文件下载会话。', + dependencies=[Depends(AuthRequired)] +) +async def router_file_download_session(id: str) -> ResponseBase: + """创建下载会话""" + raise HTTPException(status_code=501, detail="下载会话功能暂未实现") diff --git a/routers/api/v1/object/__init__.py b/routers/api/v1/object/__init__.py index 151fd90..567b441 100644 --- a/routers/api/v1/object/__init__.py +++ b/routers/api/v1/object/__init__.py @@ -1,11 +1,35 @@ +""" +对象操作路由 + +提供文件和目录对象的管理功能:删除、移动、复制、重命名等。 + +路由前缀:/object +""" from typing import Annotated +from uuid import UUID from fastapi import APIRouter, Depends, HTTPException +from loguru import logger as l +from sqlmodel.ext.asyncio.session import AsyncSession from middleware.auth import AuthRequired from middleware.dependencies import SessionDep -from models import Object, ObjectDeleteRequest, ObjectMoveRequest, User -from models.response import ResponseBase +from models import ( + Object, + ObjectCopyRequest, + ObjectDeleteRequest, + ObjectMoveRequest, + ObjectPropertyDetailResponse, + ObjectPropertyResponse, + ObjectRenameRequest, + ObjectType, + PhysicalFile, + Policy, + PolicyType, + User, +) +from models import ResponseBase +from service.storage import LocalStorageService object_router = APIRouter( prefix="/object", @@ -13,10 +37,137 @@ object_router = APIRouter( ) +async def _delete_object_recursive( + session: AsyncSession, + obj: Object, + user_id: UUID, +) -> int: + """ + 递归删除对象(软删除) + + 对于文件: + - 减少 PhysicalFile 引用计数 + - 只有引用计数为0时才移动物理文件到回收站 + + 对于目录: + - 递归处理所有子对象 + + :param session: 数据库会话 + :param obj: 要删除的对象 + :param user_id: 用户UUID + :return: 删除的对象数量 + """ + deleted_count = 0 + + if obj.is_folder: + # 递归删除子对象 + children = await Object.get_children(session, user_id, obj.id) + for child in children: + deleted_count += await _delete_object_recursive(session, child, user_id) + + # 如果是文件,处理物理文件引用 + if obj.is_file and obj.physical_file_id: + physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id) + if physical_file: + # 减少引用计数 + new_count = physical_file.decrement_reference() + + if physical_file.can_be_deleted: + # 引用计数为0,移动物理文件到回收站 + policy = await Policy.get(session, Policy.id == physical_file.policy_id) + if policy and policy.type == PolicyType.LOCAL: + try: + storage_service = LocalStorageService(policy) + await storage_service.move_to_trash( + source_path=physical_file.storage_path, + user_id=user_id, + object_id=obj.id, + ) + l.debug(f"物理文件已移动到回收站: {obj.name}") + except Exception as e: + l.warning(f"移动物理文件到回收站失败: {obj.name}, 错误: {e}") + + # 删除 PhysicalFile 记录 + await PhysicalFile.delete(session, physical_file) + l.debug(f"物理文件记录已删除: {physical_file.storage_path}") + else: + # 还有其他引用,只更新引用计数 + await physical_file.save(session) + l.debug(f"物理文件仍有 {new_count} 个引用,不删除: {physical_file.storage_path}") + + # 删除数据库记录 + await Object.delete(session, obj) + deleted_count += 1 + + return deleted_count + + +async def _copy_object_recursive( + session: AsyncSession, + src: Object, + dst_parent_id: UUID, + user_id: UUID, +) -> tuple[int, list[UUID]]: + """ + 递归复制对象 + + 对于文件: + - 增加 PhysicalFile 引用计数 + - 创建新的 Object 记录指向同一 PhysicalFile + + 对于目录: + - 创建新目录 + - 递归复制所有子对象 + + :param session: 数据库会话 + :param src: 源对象 + :param dst_parent_id: 目标父目录UUID + :param user_id: 用户UUID + :return: (复制数量, 新对象UUID列表) + """ + copied_count = 0 + new_ids: list[UUID] = [] + + # 创建新的 Object 记录 + new_obj = Object( + name=src.name, + type=src.type, + size=src.size, + password=src.password, + parent_id=dst_parent_id, + owner_id=user_id, + policy_id=src.policy_id, + physical_file_id=src.physical_file_id, + ) + + # 如果是文件,增加物理文件引用计数 + if src.is_file and src.physical_file_id: + physical_file = await PhysicalFile.get(session, PhysicalFile.id == src.physical_file_id) + if physical_file: + physical_file.increment_reference() + await physical_file.save(session) + + new_obj = await new_obj.save(session) + copied_count += 1 + new_ids.append(new_obj.id) + + # 如果是目录,递归复制子对象 + if src.is_folder: + children = await Object.get_children(session, user_id, src.id) + for child in children: + child_count, child_ids = await _copy_object_recursive( + session, child, new_obj.id, user_id + ) + copied_count += child_count + new_ids.extend(child_ids) + + return copied_count, new_ids + + @object_router.delete( path='/', summary='删除对象', - description='删除一个或多个对象(文件或目录)', + description='删除一个或多个对象(文件或目录),文件会移动到用户回收站。', ) async def router_object_delete( session: SessionDep, @@ -24,22 +175,39 @@ async def router_object_delete( request: ObjectDeleteRequest, ) -> ResponseBase: """ - 删除对象端点 + 删除对象端点(软删除) + + 流程: + 1. 验证对象存在且属于当前用户 + 2. 对于文件,减少物理文件引用计数 + 3. 如果引用计数为0,移动物理文件到 .trash 目录 + 4. 对于目录,递归处理子对象 + 5. 从数据库中删除记录 :param session: 数据库会话 :param user: 当前登录用户 :param request: 删除请求(包含待删除对象的UUID列表) :return: 删除结果 """ + # 存储 user.id,避免后续 save() 导致 user 过期后无法访问 + user_id = user.id deleted_count = 0 for obj_id in request.ids: obj = await Object.get(session, Object.id == obj_id) - if obj and obj.owner_id == user.id: - # TODO: 递归删除子对象(如果是目录) - # TODO: 更新用户存储空间 - await obj.delete(session) - deleted_count += 1 + if not obj or obj.owner_id != user_id: + continue + + # 不能删除根目录 + if obj.parent_id is None: + l.warning(f"尝试删除根目录被阻止: {obj.name}") + continue + + # 递归删除(包含引用计数逻辑) + count = await _delete_object_recursive(session, obj, user_id) + deleted_count += count + + l.info(f"用户 {user_id} 删除了 {deleted_count} 个对象") return ResponseBase( data={ @@ -67,9 +235,12 @@ async def router_object_move( :param request: 移动请求(包含源对象UUID列表和目标目录UUID) :return: 移动结果 """ + # 存储 user.id,避免后续 save() 导致 user 过期后无法访问 + user_id = user.id + # 验证目标目录 dst = await Object.get(session, Object.id == request.dst_id) - if not dst or dst.owner_id != user.id: + if not dst or dst.owner_id != user_id: raise HTTPException(status_code=404, detail="目标目录不存在") if not dst.is_folder: @@ -79,17 +250,33 @@ async def router_object_move( for src_id in request.src_ids: src = await Object.get(session, Object.id == src_id) - if not src or src.owner_id != user.id: + if not src or src.owner_id != user_id: + continue + + # 不能移动根目录 + if src.parent_id is None: continue # 检查是否移动到自身或子目录(防止循环引用) if src.id == dst.id: continue + # 检查是否将目录移动到其子目录中(循环检测) + if src.is_folder: + current = dst + is_cycle = False + while current and current.parent_id: + if current.parent_id == src.id: + is_cycle = True + break + current = await Object.get(session, Object.id == current.parent_id) + if is_cycle: + continue + # 检查目标目录下是否存在同名对象 existing = await Object.get( session, - (Object.owner_id == user.id) & + (Object.owner_id == user_id) & (Object.parent_id == dst.id) & (Object.name == src.name) ) @@ -107,50 +294,279 @@ async def router_object_move( } ) + @object_router.post( path='/copy', summary='复制对象', - description='Copy an object endpoint.', - dependencies=[Depends(AuthRequired)] + description='复制一个或多个对象到目标目录。文件复制仅增加物理文件引用计数,不复制物理文件。', ) -def router_object_copy() -> ResponseBase: +async def router_object_copy( + session: SessionDep, + user: Annotated[User, Depends(AuthRequired)], + request: ObjectCopyRequest, +) -> ResponseBase: """ - Copy an object endpoint. - - Returns: - ResponseBase: A model containing the response data for the object copy. + 复制对象端点 + + 流程: + 1. 验证目标目录存在且属于当前用户 + 2. 对于每个源对象: + - 验证源对象存在且属于当前用户 + - 检查目标目录下是否存在同名对象 + - 如果是文件:增加 PhysicalFile 引用计数,创建新 Object + - 如果是目录:递归复制所有子对象 + 3. 返回复制结果 + + :param session: 数据库会话 + :param user: 当前登录用户 + :param request: 复制请求 + :return: 复制结果 """ - pass + # 存储 user.id,避免后续 save() 导致 user 过期后无法访问 + user_id = user.id + + # 验证目标目录 + dst = await Object.get(session, Object.id == request.dst_id) + if not dst or dst.owner_id != user_id: + raise HTTPException(status_code=404, detail="目标目录不存在") + + if not dst.is_folder: + raise HTTPException(status_code=400, detail="目标不是有效文件夹") + + copied_count = 0 + new_ids: list[UUID] = [] + + for src_id in request.src_ids: + src = await Object.get(session, Object.id == src_id) + if not src or src.owner_id != user_id: + continue + + # 不能复制根目录 + if src.parent_id is None: + continue + + # 不能复制到自身 + if src.id == dst.id: + continue + + # 不能将目录复制到其子目录中 + if src.is_folder: + current = dst + is_cycle = False + while current and current.parent_id: + if current.parent_id == src.id: + is_cycle = True + break + current = await Object.get(session, Object.id == current.parent_id) + if is_cycle: + continue + + # 检查目标目录下是否存在同名对象 + existing = await Object.get( + session, + (Object.owner_id == user_id) & + (Object.parent_id == dst.id) & + (Object.name == src.name) + ) + if existing: + continue # 跳过重名对象 + + # 递归复制 + count, ids = await _copy_object_recursive(session, src, dst.id, user_id) + copied_count += count + new_ids.extend(ids) + + l.info(f"用户 {user_id} 复制了 {copied_count} 个对象") + + return ResponseBase( + data={ + "copied": copied_count, + "total": len(request.src_ids), + "new_ids": new_ids, + } + ) + @object_router.post( path='/rename', summary='重命名对象', - description='Rename an object endpoint.', - dependencies=[Depends(AuthRequired)] + description='重命名对象(文件或目录)。', ) -def router_object_rename() -> ResponseBase: +async def router_object_rename( + session: SessionDep, + user: Annotated[User, Depends(AuthRequired)], + request: ObjectRenameRequest, +) -> ResponseBase: """ - Rename an object endpoint. - - Returns: - ResponseBase: A model containing the response data for the object rename. + 重命名对象端点 + + 流程: + 1. 验证对象存在且属于当前用户 + 2. 验证新名称格式(不含非法字符) + 3. 检查同目录下是否存在同名对象 + 4. 更新 name 字段 + 5. 返回更新结果 + + :param session: 数据库会话 + :param user: 当前登录用户 + :param request: 重命名请求 + :return: 重命名结果 """ - pass + # 存储 user.id,避免后续 save() 导致 user 过期后无法访问 + user_id = user.id + + # 验证对象存在 + obj = await Object.get(session, Object.id == request.id) + if not obj: + raise HTTPException(status_code=404, detail="对象不存在") + + if obj.owner_id != user_id: + raise HTTPException(status_code=403, detail="无权操作此对象") + + # 不能重命名根目录 + if obj.parent_id is None: + raise HTTPException(status_code=400, detail="无法重命名根目录") + + # 验证新名称格式 + new_name = request.new_name.strip() + if not new_name: + raise HTTPException(status_code=400, detail="名称不能为空") + + if '/' in new_name or '\\' in new_name: + raise HTTPException(status_code=400, detail="名称不能包含斜杠") + + # 如果名称没有变化,直接返回成功 + if obj.name == new_name: + return ResponseBase(data={"success": True}) + + # 检查同目录下是否存在同名对象 + existing = await Object.get( + session, + (Object.owner_id == user_id) & + (Object.parent_id == obj.parent_id) & + (Object.name == new_name) + ) + if existing: + raise HTTPException(status_code=409, detail="同名对象已存在") + + # 更新名称 + obj.name = new_name + await obj.save(session) + + l.info(f"用户 {user_id} 将对象 {obj.id} 重命名为 {new_name}") + + return ResponseBase(data={"success": True}) + @object_router.get( path='/property/{id}', - summary='获取对象属性', - description='Get object properties endpoint.', - dependencies=[Depends(AuthRequired)] + summary='获取对象基本属性', + description='获取对象的基本属性信息(名称、类型、大小、创建/修改时间等)。', ) -def router_object_property(id: str) -> ResponseBase: +async def router_object_property( + session: SessionDep, + user: Annotated[User, Depends(AuthRequired)], + id: UUID, +) -> ObjectPropertyResponse: """ - Get object properties endpoint. - - Args: - id (str): The ID of the object to retrieve properties for. - - Returns: - ResponseBase: A model containing the response data for the object properties. + 获取对象基本属性端点 + + :param session: 数据库会话 + :param user: 当前登录用户 + :param id: 对象UUID + :return: 对象基本属性 """ - pass \ No newline at end of file + obj = await Object.get(session, Object.id == id) + if not obj: + raise HTTPException(status_code=404, detail="对象不存在") + + if obj.owner_id != user.id: + raise HTTPException(status_code=403, detail="无权查看此对象") + + return ObjectPropertyResponse( + id=obj.id, + name=obj.name, + type=obj.type, + size=obj.size, + created_at=obj.created_at, + updated_at=obj.updated_at, + parent_id=obj.parent_id, + ) + + +@object_router.get( + path='/property/{id}/detail', + summary='获取对象详细属性', + description='获取对象的详细属性信息,包括元数据、分享统计、存储信息等。', +) +async def router_object_property_detail( + session: SessionDep, + user: Annotated[User, Depends(AuthRequired)], + id: UUID, +) -> ObjectPropertyDetailResponse: + """ + 获取对象详细属性端点 + + :param session: 数据库会话 + :param user: 当前登录用户 + :param id: 对象UUID + :return: 对象详细属性 + """ + obj = await Object.get( + session, + Object.id == id, + load=Object.file_metadata, + ) + if not obj: + raise HTTPException(status_code=404, detail="对象不存在") + + if obj.owner_id != user.id: + raise HTTPException(status_code=403, detail="无权查看此对象") + + # 获取策略名称 + policy = await Policy.get(session, Policy.id == obj.policy_id) + policy_name = policy.name if policy else None + + # 获取分享统计 + from models import Share + shares = await Share.get( + session, + Share.object_id == obj.id, + fetch_mode="all" + ) + share_count = len(shares) + total_views = sum(s.views for s in shares) + total_downloads = sum(s.downloads for s in shares) + + # 获取物理文件引用计数 + reference_count = 1 + if obj.physical_file_id: + physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id) + if physical_file: + reference_count = physical_file.reference_count + + # 构建响应 + response = ObjectPropertyDetailResponse( + id=obj.id, + name=obj.name, + type=obj.type, + size=obj.size, + created_at=obj.created_at, + updated_at=obj.updated_at, + parent_id=obj.parent_id, + policy_name=policy_name, + share_count=share_count, + total_views=total_views, + total_downloads=total_downloads, + reference_count=reference_count, + ) + + # 添加文件元数据 + if obj.file_metadata: + response.mime_type = obj.file_metadata.mime_type + response.width = obj.file_metadata.width + response.height = obj.file_metadata.height + response.duration = obj.file_metadata.duration + response.checksum_md5 = obj.file_metadata.checksum_md5 + + return response diff --git a/routers/api/v1/share/__init__.py b/routers/api/v1/share/__init__.py index 66a0569..910b285 100644 --- a/routers/api/v1/share/__init__.py +++ b/routers/api/v1/share/__init__.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Depends from middleware.auth import SignRequired -from models.response import ResponseBase +from models import ResponseBase share_router = APIRouter( prefix='/share', diff --git a/routers/api/v1/site/__init__.py b/routers/api/v1/site/__init__.py index 6254d97..ace0b1a 100644 --- a/routers/api/v1/site/__init__.py +++ b/routers/api/v1/site/__init__.py @@ -3,7 +3,7 @@ from sqlalchemy import and_ import json from middleware.dependencies import SessionDep -from models.response import ResponseBase +from models import ResponseBase from models.setting import Setting site_router = APIRouter( diff --git a/routers/api/v1/slave/__init__.py b/routers/api/v1/slave/__init__.py index bfbf6c0..4addaa0 100644 --- a/routers/api/v1/slave/__init__.py +++ b/routers/api/v1/slave/__init__.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends from fastapi.responses import FileResponse from middleware.auth import SignRequired -from models.response import ResponseBase +from models import ResponseBase slave_router = APIRouter( prefix="/slave", diff --git a/routers/api/v1/tag/__init__.py b/routers/api/v1/tag/__init__.py index 47ea411..c0eb6dd 100644 --- a/routers/api/v1/tag/__init__.py +++ b/routers/api/v1/tag/__init__.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Depends from middleware.auth import SignRequired -from models.response import ResponseBase +from models import ResponseBase tag_router = APIRouter( prefix='/tag', diff --git a/routers/api/v1/user/__init__.py b/routers/api/v1/user/__init__.py index ff60433..30c68d9 100644 --- a/routers/api/v1/user/__init__.py +++ b/routers/api/v1/user/__init__.py @@ -96,7 +96,7 @@ async def router_user_session( async def router_user_register( session: SessionDep, request: models.RegisterRequest, -) -> models.response.ResponseBase: +) -> models.ResponseBase: """ 用户注册端点 @@ -157,7 +157,7 @@ async def router_user_register( policy_id=default_policy.id, ).save(session) - return models.response.ResponseBase( + return models.ResponseBase( data={ "user_id": new_user_id, "username": new_user_username, @@ -172,7 +172,7 @@ async def router_user_register( ) def router_user_email_code( reason: Literal['register', 'reset'] = 'register', -) -> models.response.ResponseBase: +) -> models.ResponseBase: """ Send a verification code email. @@ -186,7 +186,7 @@ def router_user_email_code( summary='初始化QQ登录', description='Initialize QQ login for a user.', ) -def router_user_qq() -> models.response.ResponseBase: +def router_user_qq() -> models.ResponseBase: """ Initialize QQ login for a user. @@ -200,7 +200,7 @@ def router_user_qq() -> models.response.ResponseBase: summary='WebAuthn登录初始化', description='Initialize WebAuthn login for a user.', ) -async def router_user_authn(username: str) -> models.response.ResponseBase: +async def router_user_authn(username: str) -> models.ResponseBase: pass @@ -209,7 +209,7 @@ async def router_user_authn(username: str) -> models.response.ResponseBase: summary='WebAuthn登录', description='Finish WebAuthn login for a user.', ) -def router_user_authn_finish(username: str) -> models.response.ResponseBase: +def router_user_authn_finish(username: str) -> models.ResponseBase: """ Finish WebAuthn login for a user. @@ -226,7 +226,7 @@ def router_user_authn_finish(username: str) -> models.response.ResponseBase: summary='获取用户主页展示用分享', description='Get user profile for display.', ) -def router_user_profile(id: str) -> models.response.ResponseBase: +def router_user_profile(id: str) -> models.ResponseBase: """ Get user profile for display. @@ -243,7 +243,7 @@ def router_user_profile(id: str) -> models.response.ResponseBase: summary='获取用户头像', description='Get user avatar by ID and size.', ) -def router_user_avatar(id: str, size: int = 128) -> models.response.ResponseBase: +def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase: """ Get user avatar by ID and size. @@ -265,17 +265,17 @@ def router_user_avatar(id: str, size: int = 128) -> models.response.ResponseBase summary='获取用户信息', description='Get user information.', dependencies=[Depends(dependency=AuthRequired)], - response_model=models.response.ResponseBase, + response_model=models.ResponseBase, ) async def router_user_me( session: SessionDep, user: Annotated[models.User, Depends(AuthRequired)], -) -> models.response.ResponseBase: +) -> models.ResponseBase: """ 获取用户信息. - :return: response.ResponseBase containing user information. - :rtype: response.ResponseBase + :return: ResponseBase containing user information. + :rtype: ResponseBase """ # 加载 group 及其 options 关系 group = await models.Group.get( @@ -302,7 +302,7 @@ async def router_user_me( tags=[tag.name for tag in user_tags] if user_tags else [], ) - return models.response.ResponseBase(data=user_response.model_dump()) + return models.ResponseBase(data=user_response.model_dump()) @user_router.get( path='/storage', @@ -313,7 +313,7 @@ async def router_user_me( async def router_user_storage( session: SessionDep, user: Annotated[models.user.User, Depends(AuthRequired)], -) -> models.response.ResponseBase: +) -> models.ResponseBase: """ 获取用户存储空间信息。 @@ -330,7 +330,7 @@ async def router_user_storage( used: int = user.storage free: int = max(0, total - used) - return models.response.ResponseBase( + return models.ResponseBase( data={ "used": used, "free": free, @@ -347,7 +347,7 @@ async def router_user_storage( async def router_user_authn_start( session: SessionDep, user: Annotated[models.user.User, Depends(AuthRequired)], -) -> models.response.ResponseBase: +) -> models.ResponseBase: """ Initialize WebAuthn login for a user. @@ -378,7 +378,7 @@ async def router_user_authn_start( user_display_name=user.nick or user.username, ) - return models.response.ResponseBase(data=options_to_json_dict(options)) + return models.ResponseBase(data=options_to_json_dict(options)) @user_router.put( path='/authn/finish', @@ -386,7 +386,7 @@ async def router_user_authn_start( description='Finish WebAuthn login for a user.', dependencies=[Depends(AuthRequired)], ) -def router_user_authn_finish() -> models.response.ResponseBase: +def router_user_authn_finish() -> models.ResponseBase: """ Finish WebAuthn login for a user. @@ -400,7 +400,7 @@ def router_user_authn_finish() -> models.response.ResponseBase: summary='获取用户可选存储策略', description='Get user selectable storage policies.', ) -def router_user_settings_policies() -> models.response.ResponseBase: +def router_user_settings_policies() -> models.ResponseBase: """ Get user selectable storage policies. @@ -415,7 +415,7 @@ def router_user_settings_policies() -> models.response.ResponseBase: description='Get user selectable nodes.', dependencies=[Depends(AuthRequired)], ) -def router_user_settings_nodes() -> models.response.ResponseBase: +def router_user_settings_nodes() -> models.ResponseBase: """ Get user selectable nodes. @@ -430,7 +430,7 @@ def router_user_settings_nodes() -> models.response.ResponseBase: description='Get user task queue.', dependencies=[Depends(AuthRequired)], ) -def router_user_settings_tasks() -> models.response.ResponseBase: +def router_user_settings_tasks() -> models.ResponseBase: """ Get user task queue. @@ -445,14 +445,14 @@ def router_user_settings_tasks() -> models.response.ResponseBase: description='Get current user settings.', dependencies=[Depends(AuthRequired)], ) -def router_user_settings() -> models.response.ResponseBase: +def router_user_settings() -> models.ResponseBase: """ Get current user settings. Returns: dict: A dictionary containing the current user settings. """ - return models.response.ResponseBase(data=models.UserSettingResponse().model_dump()) + return models.ResponseBase(data=models.UserSettingResponse().model_dump()) @user_settings_router.post( path='/avatar', @@ -460,7 +460,7 @@ def router_user_settings() -> models.response.ResponseBase: description='Upload user avatar from file.', dependencies=[Depends(AuthRequired)], ) -def router_user_settings_avatar() -> models.response.ResponseBase: +def router_user_settings_avatar() -> models.ResponseBase: """ Upload user avatar from file. @@ -475,7 +475,7 @@ def router_user_settings_avatar() -> models.response.ResponseBase: description='Set user avatar to Gravatar.', dependencies=[Depends(AuthRequired)], ) -def router_user_settings_avatar_gravatar() -> models.response.ResponseBase: +def router_user_settings_avatar_gravatar() -> models.ResponseBase: """ Set user avatar to Gravatar. @@ -490,7 +490,7 @@ def router_user_settings_avatar_gravatar() -> models.response.ResponseBase: description='Update user settings.', dependencies=[Depends(AuthRequired)], ) -def router_user_settings_patch(option: str) -> models.response.ResponseBase: +def router_user_settings_patch(option: str) -> models.ResponseBase: """ Update user settings. @@ -510,7 +510,7 @@ def router_user_settings_patch(option: str) -> models.response.ResponseBase: ) async def router_user_settings_2fa( user: Annotated[models.user.User, Depends(AuthRequired)], -) -> models.response.ResponseBase: +) -> models.ResponseBase: """ Get two-factor authentication initialization information. @@ -518,7 +518,7 @@ async def router_user_settings_2fa( dict: A dictionary containing two-factor authentication setup information. """ - return models.response.ResponseBase( + return models.ResponseBase( data=await Password.generate_totp(user.username) ) @@ -533,7 +533,7 @@ async def router_user_settings_2fa_enable( user: Annotated[models.user.User, Depends(AuthRequired)], setup_token: str, code: str, -) -> models.response.ResponseBase: +) -> models.ResponseBase: """ Enable two-factor authentication for the user. @@ -559,6 +559,6 @@ async def router_user_settings_2fa_enable( user.two_factor = secret user = await user.save(session) - return models.response.ResponseBase( + return models.ResponseBase( data={"message": "Two-factor authentication enabled successfully"} ) \ No newline at end of file diff --git a/routers/api/v1/vas/__init__.py b/routers/api/v1/vas/__init__.py index 15dc498..25e6b99 100644 --- a/routers/api/v1/vas/__init__.py +++ b/routers/api/v1/vas/__init__.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Depends from middleware.auth import SignRequired -from models.response import ResponseBase +from models import ResponseBase vas_router = APIRouter( prefix="/vas", diff --git a/routers/api/v1/webdav/__init__.py b/routers/api/v1/webdav/__init__.py index f41a596..52f064f 100644 --- a/routers/api/v1/webdav/__init__.py +++ b/routers/api/v1/webdav/__init__.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Depends, Request from middleware.auth import SignRequired -from models.response import ResponseBase +from models import ResponseBase # WebDAV 管理路由 webdav_router = APIRouter( diff --git a/service/storage/__init__.py b/service/storage/__init__.py new file mode 100644 index 0000000..2b8091c --- /dev/null +++ b/service/storage/__init__.py @@ -0,0 +1,20 @@ +""" +存储服务模块 + +提供文件存储相关的服务,包括: +- 本地存储服务 +- 命名规则解析器 +- 存储异常定义 +""" +from .exceptions import ( + DirectoryCreationError, + FileReadError, + FileWriteError, + InvalidPathError, + StorageException, + StorageFileNotFoundError, + UploadSessionExpiredError, + UploadSessionNotFoundError, +) +from .local_storage import LocalStorageService +from .naming_rule import NamingContext, NamingRuleParser diff --git a/service/storage/exceptions.py b/service/storage/exceptions.py new file mode 100644 index 0000000..ae1e4e3 --- /dev/null +++ b/service/storage/exceptions.py @@ -0,0 +1,45 @@ +""" +存储服务异常定义 + +定义存储操作相关的异常类型,用于精确的错误处理和诊断。 +""" + + +class StorageException(Exception): + """存储服务基础异常""" + pass + + +class DirectoryCreationError(StorageException): + """目录创建失败""" + pass + + +class StorageFileNotFoundError(StorageException): + """文件不存在""" + pass + + +class FileWriteError(StorageException): + """文件写入失败""" + pass + + +class FileReadError(StorageException): + """文件读取失败""" + pass + + +class UploadSessionNotFoundError(StorageException): + """上传会话不存在""" + pass + + +class UploadSessionExpiredError(StorageException): + """上传会话已过期""" + pass + + +class InvalidPathError(StorageException): + """无效的路径""" + pass diff --git a/service/storage/local_storage.py b/service/storage/local_storage.py new file mode 100644 index 0000000..7eb16ca --- /dev/null +++ b/service/storage/local_storage.py @@ -0,0 +1,388 @@ +""" +本地存储服务 + +负责本地文件系统的物理操作: +- 目录创建 +- 文件写入/读取/删除 +- 文件移动(软删除到 .trash) + +所有 IO 操作都使用 aiofiles 确保异步执行。 +""" +from pathlib import Path +from uuid import UUID + +import aiofiles +import aiofiles.os +from loguru import logger as l + +from models.policy import Policy +from .exceptions import ( + DirectoryCreationError, + FileReadError, + FileWriteError, + InvalidPathError, + StorageException, + StorageFileNotFoundError, +) +from .naming_rule import NamingContext, NamingRuleParser + + +class LocalStorageService: + """ + 本地存储服务 + + 实现本地文件系统的异步文件操作。 + 所有 IO 操作都使用 aiofiles 确保异步执行。 + + 使用示例:: + + service = LocalStorageService(policy) + await service.ensure_base_directory() + + dir_path, storage_name, full_path = await service.generate_file_path( + user_id=user.id, + original_filename="document.pdf", + ) + await service.write_file(full_path, content) + """ + + def __init__(self, policy: Policy): + """ + 初始化本地存储服务 + + :param policy: 存储策略配置 + :raises StorageException: 本地存储策略未指定 server 路径时抛出 + """ + if not policy.server: + raise StorageException("本地存储策略必须指定 server 路径") + + self._policy = policy + self._base_path = Path(policy.server).resolve() + + @property + def base_path(self) -> Path: + """存储根目录""" + return self._base_path + + # ==================== 目录操作 ==================== + + async def ensure_base_directory(self) -> None: + """ + 确保存储根目录存在 + + 创建策略时调用,确保物理目录已创建。 + + :raises DirectoryCreationError: 目录创建失败时抛出 + """ + try: + await aiofiles.os.makedirs(str(self._base_path), exist_ok=True) + l.info(f"已确保存储目录存在: {self._base_path}") + except OSError as e: + raise DirectoryCreationError(f"无法创建存储目录 {self._base_path}: {e}") + + async def ensure_directory(self, relative_path: str) -> Path: + """ + 确保相对路径的目录存在 + + :param relative_path: 相对于存储根目录的路径 + :return: 完整的目录路径 + :raises DirectoryCreationError: 目录创建失败时抛出 + """ + try: + full_path = self._base_path / relative_path + await aiofiles.os.makedirs(str(full_path), exist_ok=True) + return full_path + except OSError as e: + raise DirectoryCreationError(f"无法创建目录 {relative_path}: {e}") + + async def ensure_trash_directory(self, user_id: UUID) -> Path: + """ + 确保用户的回收站目录存在 + + 回收站路径格式: {storage_root}/{user_id}/.trash + + :param user_id: 用户UUID + :return: 回收站目录路径 + :raises DirectoryCreationError: 目录创建失败时抛出 + """ + trash_path = self._base_path / str(user_id) / ".trash" + try: + await aiofiles.os.makedirs(str(trash_path), exist_ok=True) + return trash_path + except OSError as e: + raise DirectoryCreationError(f"无法创建回收站目录: {e}") + + # ==================== 路径生成 ==================== + + async def generate_file_path( + self, + user_id: UUID, + original_filename: str, + ) -> tuple[str, str, str]: + """ + 根据命名规则生成文件存储路径 + + :param user_id: 用户UUID + :param original_filename: 原始文件名 + :return: (相对目录路径, 存储文件名, 完整物理路径) + """ + context = NamingContext( + user_id=user_id, + original_filename=original_filename, + ) + + # 解析目录规则 + dir_path = "" + if self._policy.dir_name_rule: + dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context) + + # 解析文件名规则 + if self._policy.auto_rename and self._policy.file_name_rule: + storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context) + # 确保有扩展名 + if '.' in original_filename and '.' not in storage_name: + ext = original_filename.rsplit('.', 1)[1] + storage_name = f"{storage_name}.{ext}" + else: + storage_name = original_filename + + # 确保目录存在 + if dir_path: + full_dir = await self.ensure_directory(dir_path) + else: + full_dir = self._base_path + + full_path = str(full_dir / storage_name) + + return dir_path, storage_name, full_path + + # ==================== 文件写入 ==================== + + async def write_file(self, path: str, content: bytes) -> int: + """ + 写入文件内容 + + :param path: 完整文件路径 + :param content: 文件内容 + :return: 写入的字节数 + :raises FileWriteError: 写入失败时抛出 + """ + try: + async with aiofiles.open(path, 'wb') as f: + await f.write(content) + return len(content) + except OSError as e: + raise FileWriteError(f"写入文件失败 {path}: {e}") + + async def write_file_chunk( + self, + path: str, + content: bytes, + offset: int, + ) -> int: + """ + 写入文件分片 + + :param path: 完整文件路径 + :param content: 分片内容 + :param offset: 写入偏移量 + :return: 写入的字节数 + :raises FileWriteError: 写入失败时抛出 + """ + try: + # 检查文件是否存在,决定打开模式 + is_exists = await self.file_exists(path) + mode = 'r+b' if is_exists else 'wb' + + async with aiofiles.open(path, mode) as f: + await f.seek(offset) + await f.write(content) + return len(content) + except OSError as e: + raise FileWriteError(f"写入文件分片失败 {path}: {e}") + + async def create_empty_file(self, path: str) -> None: + """ + 创建空白文件 + + :param path: 完整文件路径 + :raises FileWriteError: 创建失败时抛出 + """ + try: + async with aiofiles.open(path, 'wb'): + pass # 创建空文件 + except OSError as e: + raise FileWriteError(f"创建空文件失败 {path}: {e}") + + # ==================== 文件读取 ==================== + + async def read_file(self, path: str) -> bytes: + """ + 读取完整文件 + + :param path: 完整文件路径 + :return: 文件内容 + :raises StorageFileNotFoundError: 文件不存在时抛出 + :raises FileReadError: 读取失败时抛出 + """ + if not await self.file_exists(path): + raise StorageFileNotFoundError(f"文件不存在: {path}") + + try: + async with aiofiles.open(path, 'rb') as f: + return await f.read() + except OSError as e: + raise FileReadError(f"读取文件失败 {path}: {e}") + + async def get_file_size(self, path: str) -> int: + """ + 获取文件大小 + + :param path: 完整文件路径 + :return: 文件大小(字节) + :raises StorageFileNotFoundError: 文件不存在时抛出 + """ + if not await self.file_exists(path): + raise StorageFileNotFoundError(f"文件不存在: {path}") + + stat = await aiofiles.os.stat(path) + return stat.st_size + + async def file_exists(self, path: str) -> bool: + """ + 检查文件是否存在 + + :param path: 完整文件路径 + :return: 是否存在 + """ + return await aiofiles.os.path.exists(path) + + # ==================== 文件删除和移动 ==================== + + async def delete_file(self, path: str) -> None: + """ + 删除文件(物理删除) + + :param path: 完整文件路径 + """ + if await self.file_exists(path): + try: + await aiofiles.os.remove(path) + l.debug(f"已删除文件: {path}") + except OSError as e: + l.warning(f"删除文件失败 {path}: {e}") + + async def move_to_trash( + self, + source_path: str, + user_id: UUID, + object_id: UUID, + ) -> str: + """ + 将文件移动到回收站(软删除) + + 回收站中的文件名格式: {object_uuid}_{original_filename} + + :param source_path: 源文件完整路径 + :param user_id: 用户UUID + :param object_id: 对象UUID(用于生成唯一的回收站文件名) + :return: 回收站中的文件路径 + :raises StorageFileNotFoundError: 源文件不存在时抛出 + """ + if not await self.file_exists(source_path): + raise StorageFileNotFoundError(f"源文件不存在: {source_path}") + + # 确保回收站目录存在 + trash_dir = await self.ensure_trash_directory(user_id) + + # 使用 object_id 作为回收站文件名前缀,避免冲突 + source_filename = Path(source_path).name + trash_filename = f"{object_id}_{source_filename}" + trash_path = trash_dir / trash_filename + + # 移动文件 + try: + await aiofiles.os.rename(source_path, str(trash_path)) + l.info(f"文件已移动到回收站: {source_path} -> {trash_path}") + return str(trash_path) + except OSError as e: + raise StorageException(f"移动文件到回收站失败: {e}") + + async def restore_from_trash( + self, + trash_path: str, + restore_path: str, + ) -> None: + """ + 从回收站恢复文件 + + :param trash_path: 回收站中的文件路径 + :param restore_path: 恢复目标路径 + :raises StorageFileNotFoundError: 回收站文件不存在时抛出 + """ + if not await self.file_exists(trash_path): + raise StorageFileNotFoundError(f"回收站文件不存在: {trash_path}") + + # 确保目标目录存在 + restore_dir = Path(restore_path).parent + await aiofiles.os.makedirs(str(restore_dir), exist_ok=True) + + try: + await aiofiles.os.rename(trash_path, restore_path) + l.info(f"文件已从回收站恢复: {trash_path} -> {restore_path}") + except OSError as e: + raise StorageException(f"从回收站恢复文件失败: {e}") + + async def empty_trash(self, user_id: UUID) -> int: + """ + 清空用户回收站 + + :param user_id: 用户UUID + :return: 删除的文件数量 + """ + trash_dir = self._base_path / str(user_id) / ".trash" + if not await aiofiles.os.path.exists(str(trash_dir)): + return 0 + + deleted_count = 0 + try: + entries = await aiofiles.os.listdir(str(trash_dir)) + for entry in entries: + file_path = trash_dir / entry + if await aiofiles.os.path.isfile(str(file_path)): + await aiofiles.os.remove(str(file_path)) + deleted_count += 1 + l.info(f"已清空用户 {user_id} 的回收站,删除 {deleted_count} 个文件") + except OSError as e: + l.warning(f"清空回收站时出错: {e}") + + return deleted_count + + # ==================== 路径验证 ==================== + + def validate_path(self, path: str) -> bool: + """ + 验证路径是否在存储根目录下(防止路径遍历攻击) + + :param path: 要验证的路径 + :return: 路径是否有效 + """ + try: + resolved = Path(path).resolve() + return str(resolved).startswith(str(self._base_path)) + except (ValueError, OSError): + return False + + def get_relative_path(self, full_path: str) -> str: + """ + 获取相对于存储根目录的相对路径 + + :param full_path: 完整路径 + :return: 相对路径 + :raises InvalidPathError: 路径不在存储根目录下时抛出 + """ + if not self.validate_path(full_path): + raise InvalidPathError(f"路径不在存储根目录下: {full_path}") + + resolved = Path(full_path).resolve() + return str(resolved.relative_to(self._base_path)) diff --git a/service/storage/naming_rule.py b/service/storage/naming_rule.py new file mode 100644 index 0000000..beb823a --- /dev/null +++ b/service/storage/naming_rule.py @@ -0,0 +1,144 @@ +""" +命名规则解析器 + +将包含占位符的规则模板转换为实际的文件名/目录路径。 + +支持的占位符: +- {date}: 当前日期 YYYY-MM-DD +- {timestamp}: Unix 时间戳 +- {year}: 年份 YYYY +- {month}: 月份 MM +- {day}: 日期 DD +- {hour}: 小时 HH +- {minute}: 分钟 MM +- {randomkey16}: 16位随机字符串 +- {originname}: 原始文件名(不含扩展名) +- {ext}: 文件扩展名(不含点) +- {uid}: 用户UUID +- {uuid}: 新生成的UUID +""" +import re +import secrets +import string +from datetime import datetime +from uuid import UUID, uuid4 + +from models.base import SQLModelBase + + +class NamingContext(SQLModelBase): + """ + 命名上下文 + + 包含生成文件名/目录名所需的所有信息。 + """ + + user_id: UUID + """用户UUID""" + + original_filename: str + """原始文件名(包含扩展名)""" + + timestamp: datetime | None = None + """时间戳,默认为当前时间""" + + +class NamingRuleParser: + """ + 命名规则解析器 + + 将包含占位符的规则模板转换为实际的文件名/目录路径。 + + 使用示例:: + + context = NamingContext( + user_id=UUID("..."), + original_filename="document.pdf", + ) + dir_path = NamingRuleParser.parse("{date}/{randomkey16}", context) + # -> "2025-12-23/a1b2c3d4e5f6g7h8" + + file_name = NamingRuleParser.parse("{randomkey16}_{originname}.{ext}", context) + # -> "x9y8z7w6v5u4t3s2_document.pdf" + """ + + # 支持的占位符正则 + _PLACEHOLDER_PATTERN = re.compile(r'\{(\w+)\}') + + # 随机字符集 + _RANDOM_CHARS = string.ascii_lowercase + string.digits + + @classmethod + def parse(cls, rule: str, context: NamingContext) -> str: + """ + 解析命名规则,替换所有占位符 + + :param rule: 命名规则模板,如 "{date}/{randomkey16}" + :param context: 命名上下文 + :return: 解析后的实际路径/文件名 + """ + timestamp = context.timestamp or datetime.now() + + # 解析原始文件名 + origin_name, ext = cls._split_filename(context.original_filename) + + # 占位符替换映射 + replacements: dict[str, str] = { + 'date': timestamp.strftime('%Y-%m-%d'), + 'timestamp': str(int(timestamp.timestamp())), + 'year': timestamp.strftime('%Y'), + 'month': timestamp.strftime('%m'), + 'day': timestamp.strftime('%d'), + 'hour': timestamp.strftime('%H'), + 'minute': timestamp.strftime('%M'), + 'randomkey16': cls._generate_random_key(16), + 'originname': origin_name, + 'ext': ext, + 'uid': str(context.user_id), + 'uuid': str(uuid4()), + } + + def replace_placeholder(match: re.Match[str]) -> str: + placeholder = match.group(1) + return replacements.get(placeholder, match.group(0)) + + return cls._PLACEHOLDER_PATTERN.sub(replace_placeholder, rule) + + @classmethod + def _split_filename(cls, filename: str) -> tuple[str, str]: + """ + 分离文件名和扩展名 + + :param filename: 完整文件名 + :return: (文件名不含扩展名, 扩展名不含点) + """ + if '.' in filename: + parts = filename.rsplit('.', 1) + return parts[0], parts[1] + return filename, '' + + @classmethod + def _generate_random_key(cls, length: int) -> str: + """ + 生成随机字符串 + + :param length: 字符串长度 + :return: 随机字符串 + """ + return ''.join(secrets.choice(cls._RANDOM_CHARS) for _ in range(length)) + + @classmethod + def validate_rule(cls, rule: str) -> bool: + """ + 验证命名规则是否有效 + + :param rule: 命名规则模板 + :return: 是否有效 + """ + valid_placeholders = { + 'date', 'timestamp', 'year', 'month', 'day', 'hour', 'minute', + 'randomkey16', 'originname', 'ext', 'uid', 'uuid', + } + + placeholders = cls._PLACEHOLDER_PATTERN.findall(rule) + return all(p in valid_placeholders for p in placeholders) diff --git a/uv.lock b/uv.lock index d653b65..3f7634c 100644 --- a/uv.lock +++ b/uv.lock @@ -6,6 +6,15 @@ resolution-markers = [ "python_full_version < '3.14'", ] +[[package]] +name = "aiofiles" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/41/c3/534eac40372d8ee36ef40df62ec129bee4fdb5ad9706e58a29be53b2c970/aiofiles-25.1.0.tar.gz", hash = "sha256:a8d728f0a29de45dc521f18f07297428d56992a742f0cd2701ba86e44d23d5b2", size = 46354, upload-time = "2025-10-09T20:51:04.358Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/8a/340a1555ae33d7354dbca4faa54948d76d89a27ceef032c8c3bc661d003e/aiofiles-25.1.0-py3-none-any.whl", hash = "sha256:abe311e527c862958650f9438e859c1fa7568a141b22abcd015e120e86a85695", size = 14668, upload-time = "2025-10-09T20:51:03.174Z" }, +] + [[package]] name = "aiohappyeyeballs" version = "2.6.1" @@ -421,6 +430,7 @@ name = "disknext-server" version = "0.0.1" source = { virtual = "." } dependencies = [ + { name = "aiofiles" }, { name = "aiohttp" }, { name = "aiosqlite" }, { name = "argon2-cffi" }, @@ -444,6 +454,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "aiofiles", specifier = ">=25.1.0" }, { name = "aiohttp", specifier = ">=3.13.2" }, { name = "aiosqlite", specifier = ">=0.21.0" }, { name = "argon2-cffi", specifier = ">=25.1.0" },