Refactor import statements for ResponseBase in API routers

- Updated import statements in the following files to import ResponseBase directly from models instead of models.response:
  - routers/api/v1/share/__init__.py
  - routers/api/v1/site/__init__.py
  - routers/api/v1/slave/__init__.py
  - routers/api/v1/tag/__init__.py
  - routers/api/v1/user/__init__.py
  - routers/api/v1/vas/__init__.py
  - routers/api/v1/webdav/__init__.py

Enhance user registration and related endpoints in user router

- Changed return type annotations from models.response.ResponseBase to models.ResponseBase in multiple functions.
- Updated return statements to reflect the new import structure.
- Improved documentation for clarity.

Add PhysicalFile model and storage service implementation

- Introduced PhysicalFile model to represent actual files on disk with reference counting logic.
- Created storage service module with local storage implementation, including file operations and error handling.
- Defined exceptions for storage operations to improve error handling.
- Implemented naming rule parser for generating file and directory names based on templates.

Update dependency management in uv.lock

- Added aiofiles version 25.1.0 to the project dependencies.
This commit is contained in:
2025-12-23 12:20:06 +08:00
parent 96bf447426
commit 446d219aca
26 changed files with 2155 additions and 399 deletions

3
.gitignore vendored
View File

@@ -62,3 +62,6 @@ node_modules/
*.bak *.bak
*.tmp *.tmp
*.temp *.temp
# 文件
data/

View File

@@ -1,5 +1,3 @@
from . import response
from .user import ( from .user import (
LoginRequest, LoginRequest,
RegisterRequest, RegisterRequest,
@@ -31,18 +29,29 @@ from .node import (
) )
from .group import Group, GroupBase, GroupOptions, GroupOptionsBase, GroupResponse from .group import Group, GroupBase, GroupOptions, GroupOptionsBase, GroupResponse
from .object import ( from .object import (
CreateFileRequest,
CreateUploadSessionRequest,
DirectoryCreateRequest, DirectoryCreateRequest,
DirectoryResponse, DirectoryResponse,
FileMetadata, FileMetadata,
FileMetadataBase, FileMetadataBase,
Object, Object,
ObjectBase, ObjectBase,
ObjectCopyRequest,
ObjectDeleteRequest, ObjectDeleteRequest,
ObjectMoveRequest, ObjectMoveRequest,
ObjectPropertyDetailResponse,
ObjectPropertyResponse,
ObjectRenameRequest,
ObjectResponse, ObjectResponse,
ObjectType, ObjectType,
PolicyResponse, PolicyResponse,
UploadChunkResponse,
UploadSession,
UploadSessionBase,
UploadSessionResponse,
) )
from .physical_file import PhysicalFile, PhysicalFileBase
from .order import Order, OrderStatus, OrderType from .order import Order, OrderStatus, OrderType
from .policy import Policy, PolicyOptions, PolicyOptionsBase, PolicyType from .policy import Policy, PolicyOptions, PolicyOptionsBase, PolicyType
from .redeem import Redeem, RedeemType from .redeem import Redeem, RedeemType
@@ -56,3 +65,14 @@ from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType
from .webdav import WebDAV from .webdav import WebDAV
from .database import engine, get_session from .database import engine, get_session
import uuid
from sqlmodel import Field
from .base import SQLModelBase
class ResponseBase(SQLModelBase):
"""通用响应模型"""
instance_id: uuid.UUID = Field(default_factory=uuid.uuid4)
"""实例ID用于标识请求的唯一性"""

View File

@@ -283,6 +283,7 @@ async def init_default_user() -> None:
async def init_default_policy() -> None: async def init_default_policy() -> None:
from .policy import Policy, PolicyType from .policy import Policy, PolicyType
from .database import get_session from .database import get_session
from service.storage import LocalStorageService
log.info('初始化默认存储策略...') log.info('初始化默认存储策略...')
@@ -302,6 +303,10 @@ async def init_default_policy() -> None:
file_name_rule="{randomkey16}_{originname}", file_name_rule="{randomkey16}_{originname}",
) )
await local_policy.save(session) local_policy = await local_policy.save(session)
# 创建物理存储目录
storage_service = LocalStorageService(local_policy)
await storage_service.ensure_base_directory()
log.info('已创建默认本地存储策略,存储目录:./data') log.info('已创建默认本地存储策略,存储目录:./data')

View File

@@ -14,6 +14,7 @@ if TYPE_CHECKING:
from .policy import Policy from .policy import Policy
from .source_link import SourceLink from .source_link import SourceLink
from .share import Share from .share import Share
from .physical_file import PhysicalFile
class ObjectType(StrEnum): class ObjectType(StrEnum):
@@ -112,9 +113,6 @@ class ObjectResponse(ObjectBase):
id: UUID id: UUID
"""对象UUID""" """对象UUID"""
path: str
"""对象路径"""
thumb: bool = False thumb: bool = False
"""是否有缩略图""" """是否有缩略图"""
@@ -222,15 +220,20 @@ class Object(ObjectBase, UUIDTableBaseMixin):
# ==================== 文件专属字段 ==================== # ==================== 文件专属字段 ====================
source_name: str | None = Field(default=None, max_length=255)
"""源文件名(仅文件有效)"""
size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) size: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""文件大小(字节),目录为 0""" """文件大小(字节),目录为 0"""
upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True) upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
"""分块上传会话ID仅文件有效""" """分块上传会话ID仅文件有效"""
physical_file_id: UUID | None = Field(
default=None,
foreign_key="physicalfile.id",
index=True,
ondelete="SET NULL"
)
"""关联的物理文件UUID仅文件有效目录为NULL"""
# ==================== 外键 ==================== # ==================== 外键 ====================
parent_id: UUID | None = Field( parent_id: UUID | None = Field(
@@ -295,8 +298,22 @@ class Object(ObjectBase, UUIDTableBaseMixin):
) )
"""分享列表""" """分享列表"""
physical_file: "PhysicalFile" = Relationship(back_populates="objects")
"""关联的物理文件(仅文件有效)"""
# ==================== 业务属性 ==================== # ==================== 业务属性 ====================
@property
def source_name(self) -> str | None:
"""
源文件存储路径(向后兼容属性)
:return: 物理文件存储路径,如果没有关联物理文件则返回 None
"""
if self.physical_file:
return self.physical_file.storage_path
return None
@property @property
def is_file(self) -> bool: def is_file(self) -> bool:
"""是否为文件""" """是否为文件"""
@@ -397,3 +414,231 @@ class Object(ObjectBase, UUIDTableBaseMixin):
(cls.owner_id == user_id) & (cls.parent_id == parent_id), (cls.owner_id == user_id) & (cls.parent_id == parent_id),
fetch_mode="all" fetch_mode="all"
) )
# ==================== 上传会话模型 ====================
class UploadSessionBase(SQLModelBase):
"""上传会话基础字段"""
file_name: str = Field(max_length=255)
"""原始文件名"""
file_size: int = Field(ge=0)
"""文件总大小(字节)"""
chunk_size: int = Field(ge=1)
"""分片大小(字节)"""
total_chunks: int = Field(ge=1)
"""总分片数"""
class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
"""
上传会话模型
用于管理分片上传的会话状态。
会话有效期为24小时过期后自动失效。
"""
# 会话状态
uploaded_chunks: int = 0
"""已上传分片数"""
uploaded_size: int = 0
"""已上传大小(字节)"""
storage_path: str | None = Field(default=None, max_length=512)
"""文件存储路径"""
expires_at: datetime
"""会话过期时间"""
# 外键
owner_id: UUID = Field(foreign_key="user.id", index=True, ondelete="CASCADE")
"""上传者用户UUID"""
parent_id: UUID = Field(foreign_key="object.id", index=True, ondelete="CASCADE")
"""目标父目录UUID"""
policy_id: UUID = Field(foreign_key="policy.id", index=True, ondelete="RESTRICT")
"""存储策略UUID"""
# 关系
owner: "User" = Relationship()
"""上传者"""
parent: "Object" = Relationship(
sa_relationship_kwargs={"foreign_keys": "[UploadSession.parent_id]"}
)
"""目标父目录"""
policy: "Policy" = Relationship()
"""存储策略"""
@property
def is_expired(self) -> bool:
"""会话是否已过期"""
return datetime.now() > self.expires_at
@property
def is_complete(self) -> bool:
"""上传是否完成"""
return self.uploaded_chunks >= self.total_chunks
# ==================== 上传会话相关 DTO ====================
class CreateUploadSessionRequest(SQLModelBase):
"""创建上传会话请求 DTO"""
file_name: str = Field(max_length=255)
"""文件名"""
file_size: int = Field(ge=0)
"""文件大小(字节)"""
parent_id: UUID
"""父目录UUID"""
policy_id: UUID | None = None
"""存储策略UUID不指定则使用父目录的策略"""
class UploadSessionResponse(SQLModelBase):
"""上传会话响应 DTO"""
id: UUID
"""会话UUID"""
file_name: str
"""原始文件名"""
file_size: int
"""文件总大小(字节)"""
chunk_size: int
"""分片大小(字节)"""
total_chunks: int
"""总分片数"""
uploaded_chunks: int
"""已上传分片数"""
expires_at: datetime
"""过期时间"""
class UploadChunkResponse(SQLModelBase):
"""上传分片响应 DTO"""
uploaded_chunks: int
"""已上传分片数"""
total_chunks: int
"""总分片数"""
is_complete: bool
"""是否上传完成"""
object_id: UUID | None = None
"""完成后的文件对象UUID未完成时为None"""
class CreateFileRequest(SQLModelBase):
"""创建空白文件请求 DTO"""
name: str = Field(max_length=255)
"""文件名"""
parent_id: UUID
"""父目录UUID"""
policy_id: UUID | None = None
"""存储策略UUID不指定则使用父目录的策略"""
# ==================== 对象操作相关 DTO ====================
class ObjectCopyRequest(SQLModelBase):
"""复制对象请求 DTO"""
src_ids: list[UUID]
"""源对象UUID列表"""
dst_id: UUID
"""目标文件夹UUID"""
class ObjectRenameRequest(SQLModelBase):
"""重命名对象请求 DTO"""
id: UUID
"""对象UUID"""
new_name: str = Field(max_length=255)
"""新名称"""
class ObjectPropertyResponse(SQLModelBase):
"""对象基本属性响应 DTO"""
id: UUID
"""对象UUID"""
name: str
"""对象名称"""
type: ObjectType
"""对象类型"""
size: int
"""文件大小(字节)"""
created_at: datetime
"""创建时间"""
updated_at: datetime
"""修改时间"""
parent_id: UUID | None
"""父目录UUID"""
class ObjectPropertyDetailResponse(ObjectPropertyResponse):
"""对象详细属性响应 DTO继承基本属性"""
# 元数据信息
mime_type: str | None = None
"""MIME类型"""
width: int | None = None
"""图片宽度(像素)"""
height: int | None = None
"""图片高度(像素)"""
duration: float | None = None
"""音视频时长(秒)"""
checksum_md5: str | None = None
"""MD5校验和"""
# 分享统计
share_count: int = 0
"""分享次数"""
total_views: int = 0
"""总浏览次数"""
total_downloads: int = 0
"""总下载次数"""
# 存储信息
policy_name: str | None = None
"""存储策略名称"""
reference_count: int = 1
"""物理文件引用计数(仅文件有效)"""

90
models/physical_file.py Normal file
View File

@@ -0,0 +1,90 @@
"""
物理文件模型
表示磁盘上的实际文件。多个 Object 可以引用同一个 PhysicalFile
实现文件共享而不复制物理文件。
引用计数逻辑:
- 每个引用此文件的 Object 都会增加引用计数
- 当 Object 被删除时,减少引用计数
- 只有当引用计数为 0 时,才物理删除文件
"""
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, Index
from .base import SQLModelBase
from .mixin import UUIDTableBaseMixin
if TYPE_CHECKING:
from .object import Object
from .policy import Policy
class PhysicalFileBase(SQLModelBase):
"""物理文件基础模型"""
storage_path: str = Field(max_length=512)
"""物理存储路径(相对于存储策略根目录)"""
size: int = 0
"""文件大小(字节)"""
checksum_md5: str | None = Field(default=None, max_length=32)
"""MD5校验和用于文件去重和完整性校验"""
class PhysicalFile(PhysicalFileBase, UUIDTableBaseMixin):
"""
物理文件模型
表示磁盘上的实际文件。多个 Object 可以引用同一个 PhysicalFile
实现文件共享而不复制物理文件。
"""
__table_args__ = (
Index("ix_physical_file_policy_path", "policy_id", "storage_path"),
Index("ix_physical_file_checksum", "checksum_md5"),
)
policy_id: UUID = Field(
foreign_key="policy.id",
index=True,
ondelete="RESTRICT",
)
"""存储策略UUID"""
reference_count: int = Field(default=1, ge=0)
"""引用计数(有多少个 Object 引用此物理文件)"""
# 关系
policy: "Policy" = Relationship()
"""存储策略"""
objects: list["Object"] = Relationship(back_populates="physical_file")
"""引用此物理文件的所有逻辑对象"""
def increment_reference(self) -> int:
"""
增加引用计数
:return: 更新后的引用计数
"""
self.reference_count += 1
return self.reference_count
def decrement_reference(self) -> int:
"""
减少引用计数
:return: 更新后的引用计数
"""
if self.reference_count > 0:
self.reference_count -= 1
return self.reference_count
@property
def can_be_deleted(self) -> bool:
"""是否可以物理删除引用计数为0"""
return self.reference_count == 0

View File

@@ -1,14 +0,0 @@
"""
通用响应模型定义
"""
import uuid
from sqlmodel import Field
from .base import SQLModelBase
class ResponseBase(SQLModelBase):
"""通用响应模型"""
instance_id: uuid.UUID = Field(default_factory=uuid.uuid4)
"""实例ID用于标识请求的唯一性"""

View File

@@ -5,6 +5,7 @@ description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.13" requires-python = ">=3.13"
dependencies = [ dependencies = [
"aiofiles>=25.1.0",
"aiohttp>=3.13.2", "aiohttp>=3.13.2",
"aiosqlite>=0.21.0", "aiosqlite>=0.21.0",
"argon2-cffi>=25.1.0", "argon2-cffi>=25.1.0",

View File

@@ -12,7 +12,7 @@ from .admin import admin_vas_router
from .callback import callback_router from .callback import callback_router
from .directory import directory_router from .directory import directory_router
from .download import download_router from .download import download_router
from .file import file_router from .file import file_router, file_upload_router
from .object import object_router from .object import object_router
from .share import share_router from .share import share_router
from .site import site_router from .site import site_router
@@ -36,6 +36,7 @@ router.include_router(callback_router)
router.include_router(directory_router) router.include_router(directory_router)
router.include_router(download_router) router.include_router(download_router)
router.include_router(file_router) router.include_router(file_router)
router.include_router(file_upload_router)
router.include_router(object_router) router.include_router(object_router)
router.include_router(share_router) router.include_router(share_router)
router.include_router(site_router) router.include_router(site_router)

View File

@@ -1,11 +1,57 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, HTTPException
from loguru import logger from loguru import logger as l
from sqlmodel import Field
from middleware.auth import AdminRequired from middleware.auth import AdminRequired
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import User from models import Policy, PolicyOptions, PolicyType, User
from models.base import SQLModelBase
from models import ResponseBase
from models.user import UserPublic from models.user import UserPublic
from models.response import ResponseBase from service.storage import DirectoryCreationError, LocalStorageService
class PolicyCreateRequest(SQLModelBase):
"""创建存储策略请求 DTO"""
name: str = Field(max_length=255)
"""策略名称"""
type: PolicyType
"""策略类型"""
server: str | None = Field(default=None, max_length=255)
"""服务器地址/本地路径(本地存储必填)"""
bucket_name: str | None = Field(default=None, max_length=255)
"""存储桶名称S3必填"""
is_private: bool = True
"""是否为私有空间"""
base_url: str | None = Field(default=None, max_length=255)
"""访问文件的基础URL"""
access_key: str | None = None
"""Access Key"""
secret_key: str | None = None
"""Secret Key"""
max_size: int = Field(default=0, ge=0)
"""允许上传的最大文件尺寸字节0表示不限制"""
auto_rename: bool = False
"""是否自动重命名"""
dir_name_rule: str | None = Field(default=None, max_length=255)
"""目录命名规则"""
file_name_rule: str | None = Field(default=None, max_length=255)
"""文件命名规则"""
is_origin_link_enable: bool = False
"""是否开启源链接访问"""
# 管理员根目录 /api/admin # 管理员根目录 /api/admin
admin_router = APIRouter( admin_router = APIRouter(
@@ -464,11 +510,72 @@ def router_policy_test_slave() -> ResponseBase:
@admin_policy_router.post( @admin_policy_router.post(
path='/', path='/',
summary='创建存储策略', summary='创建存储策略',
description='', description='创建新的存储策略。对于本地存储策略,会自动创建物理目录。',
dependencies=[Depends(AdminRequired)] dependencies=[Depends(AdminRequired)]
) )
def router_policy_add_policy() -> ResponseBase: async def router_policy_add_policy(
pass session: SessionDep,
request: PolicyCreateRequest,
) -> ResponseBase:
"""
创建存储策略端点
功能:
- 创建新的存储策略配置
- 对于 LOCAL 类型,自动创建物理目录
认证:
- 需要管理员权限
:param session: 数据库会话
:param request: 创建请求
:return: 创建结果
"""
# 验证本地存储策略必须指定 server 路径
if request.type == PolicyType.LOCAL:
if not request.server:
raise HTTPException(status_code=400, detail="本地存储策略必须指定 server 路径")
# 检查策略名称是否已存在
existing = await Policy.get(session, Policy.name == request.name)
if existing:
raise HTTPException(status_code=409, detail="策略名称已存在")
# 创建策略对象
policy = Policy(
name=request.name,
type=request.type,
server=request.server,
bucket_name=request.bucket_name,
is_private=request.is_private,
base_url=request.base_url,
access_key=request.access_key,
secret_key=request.secret_key,
max_size=request.max_size,
auto_rename=request.auto_rename,
dir_name_rule=request.dir_name_rule,
file_name_rule=request.file_name_rule,
is_origin_link_enable=request.is_origin_link_enable,
)
# 对于本地存储策略,创建物理目录
if policy.type == PolicyType.LOCAL:
try:
storage_service = LocalStorageService(policy)
await storage_service.ensure_base_directory()
l.info(f"已为本地存储策略 '{policy.name}' 创建目录: {policy.server}")
except DirectoryCreationError as e:
raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}")
# 保存到数据库
policy = await policy.save(session)
return ResponseBase(data={
"id": str(policy.id),
"name": policy.name,
"type": policy.type.value,
"server": policy.server,
})
@admin_policy_router.post( @admin_policy_router.post(
path='/cors', path='/cors',

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from fastapi.responses import PlainTextResponse, RedirectResponse from fastapi.responses import PlainTextResponse, RedirectResponse
from middleware.auth import SignRequired from middleware.auth import SignRequired
from models.response import ResponseBase from models import ResponseBase
import service.oauth import service.oauth
callback_router = APIRouter( callback_router = APIRouter(

View File

@@ -12,7 +12,7 @@ from models import (
ObjectType, ObjectType,
PolicyResponse, PolicyResponse,
User, User,
response, ResponseBase,
) )
directory_router = APIRouter( directory_router = APIRouter(
@@ -63,7 +63,6 @@ async def router_directory_get(
ObjectResponse( ObjectResponse(
id=child.id, id=child.id,
name=child.name, name=child.name,
path=f"/{child.name}", # TODO: 完整路径
thumb=False, thumb=False,
size=child.size, size=child.size,
type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE, type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE,
@@ -97,7 +96,7 @@ async def router_directory_create(
session: SessionDep, session: SessionDep,
user: Annotated[User, Depends(AuthRequired)], user: Annotated[User, Depends(AuthRequired)],
request: DirectoryCreateRequest request: DirectoryCreateRequest
) -> response.ResponseBase: ) -> ResponseBase:
""" """
创建目录 创建目录
@@ -111,6 +110,7 @@ async def router_directory_create(
if not name: if not name:
raise HTTPException(status_code=400, detail="目录名称不能为空") raise HTTPException(status_code=400, detail="目录名称不能为空")
# [TODO] 进一步验证名称合法性
if "/" in name or "\\" in name: if "/" in name or "\\" in name:
raise HTTPException(status_code=400, detail="目录名称不能包含斜杠") raise HTTPException(status_code=400, detail="目录名称不能包含斜杠")
@@ -146,7 +146,7 @@ async def router_directory_create(
new_folder_name = new_folder.name new_folder_name = new_folder.name
await new_folder.save(session) await new_folder.save(session)
return response.ResponseBase( return ResponseBase(
data={ data={
"id": new_folder_id, "id": new_folder_id,
"name": new_folder_name, "name": new_folder_name,

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from middleware.auth import SignRequired from middleware.auth import SignRequired
from models.response import ResponseBase from models import ResponseBase
download_router = APIRouter( download_router = APIRouter(
prefix="/download", prefix="/download",

View File

@@ -1,7 +1,37 @@
from fastapi import APIRouter, Depends, UploadFile """
文件操作路由
提供文件上传、下载、创建等核心功能。
路由前缀:
- /file - 文件操作
- /file/upload - 上传相关操作
"""
from datetime import datetime, timedelta
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from middleware.auth import SignRequired from loguru import logger as l
from models.response import ResponseBase
from middleware.auth import AuthRequired, SignRequired
from middleware.dependencies import SessionDep
from models import (
CreateFileRequest,
CreateUploadSessionRequest,
Object,
ObjectType,
PhysicalFile,
Policy,
PolicyType,
UploadChunkResponse,
UploadSession,
UploadSessionResponse,
User,
)
from models import ResponseBase
from service.storage import LocalStorageService, StorageFileNotFoundError
file_router = APIRouter( file_router = APIRouter(
prefix="/file", prefix="/file",
@@ -13,370 +43,614 @@ file_upload_router = APIRouter(
tags=["file"] tags=["file"]
) )
@file_router.get(
path='/get/{id}/{name}',
summary='文件外链(直接输出文件数据)',
description='Get file external link endpoint.',
)
def router_file_get(id: str, name: str) -> FileResponse:
"""
Get file external link endpoint.
Args:
id (str): The ID of the file.
name (str): The name of the file.
Returns:
FileResponse: A response containing the file data.
"""
pass
@file_router.get( # ==================== 上传会话管理 ====================
path='/source/{id}/{name}',
summary='文件外链(301跳转)',
description='Get file external link with 301 redirect endpoint.',
)
def router_file_source(id: str, name: str) -> ResponseBase:
"""
Get file external link with 301 redirect endpoint.
Args:
id (str): The ID of the file.
name (str): The name of the file.
Returns:
ResponseBase: A model containing the response data for the file with a redirect.
"""
pass
@file_upload_router.get(
path='/download/{id}',
summary='下载文件',
description='Download file endpoint.',
)
def router_file_download(id: str) -> ResponseBase:
"""
Download file endpoint.
Args:
id (str): The ID of the file to download.
Returns:
ResponseBase: A model containing the response data for the file download.
"""
pass
@file_upload_router.get(
path='/archive/{sessionID}/archive.zip',
summary='打包并下载文件',
description='Archive and download files endpoint.',
)
def router_file_archive_download(sessionID: str) -> ResponseBase:
"""
Archive and download files endpoint.
Args:
sessionID (str): The session ID for the archive.
Returns:
ResponseBase: A model containing the response data for the archived files download.
"""
pass
@file_upload_router.post(
path='/{sessionID}/{index}',
summary='文件上传',
description='File upload endpoint.',
)
def router_file_upload(sessionID: str, index: int, file: UploadFile) -> ResponseBase:
"""
File upload endpoint.
Args:
sessionID (str): The session ID for the upload.
index (int): The index of the file being uploaded.
Returns:
ResponseBase: A model containing the response data.
"""
pass
@file_upload_router.put( @file_upload_router.put(
path='/', path='/',
summary='创建上传会话', summary='创建上传会话',
description='Create an upload session endpoint.', description='创建文件上传会话返回会话ID用于后续分片上传。',
dependencies=[Depends(SignRequired)],
) )
def router_file_upload_session() -> ResponseBase: async def create_upload_session(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
request: CreateUploadSessionRequest,
) -> UploadSessionResponse:
""" """
Create an upload session endpoint. 创建上传会话端点
Returns: 流程:
ResponseBase: A model containing the response data for the upload session. 1. 验证父目录存在且属于当前用户
2. 确定存储策略(使用指定的或继承父目录的)
3. 验证文件大小限制
4. 创建上传会话并生成存储路径
5. 返回会话信息
:param session: 数据库会话
:param user: 当前登录用户
:param request: 创建请求
:return: 上传会话信息
""" """
pass # 验证文件名
if not request.file_name or '/' in request.file_name or '\\' in request.file_name:
raise HTTPException(status_code=400, detail="无效的文件名")
# 验证父目录
parent = await Object.get(session, Object.id == request.parent_id)
if not parent or parent.owner_id != user.id:
raise HTTPException(status_code=404, detail="父目录不存在")
if not parent.is_folder:
raise HTTPException(status_code=400, detail="父对象不是目录")
# 确定存储策略
policy_id = request.policy_id or parent.policy_id
policy = await Policy.get(session, Policy.id == policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
# 验证文件大小限制
if policy.max_size > 0 and request.file_size > policy.max_size:
raise HTTPException(
status_code=400,
detail=f"文件大小超过限制 ({policy.max_size} bytes)"
)
# 检查是否已存在同名文件
existing = await Object.get(
session,
(Object.owner_id == user.id) &
(Object.parent_id == parent.id) &
(Object.name == request.file_name)
)
if existing:
raise HTTPException(status_code=409, detail="同名文件已存在")
# 计算分片信息
options = await policy.awaitable_attrs.options
chunk_size = options.chunk_size if options else 52428800 # 默认 50MB
total_chunks = max(1, (request.file_size + chunk_size - 1) // chunk_size) if request.file_size > 0 else 1
# 生成存储路径
storage_path: str | None = None
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
dir_path, storage_name, full_path = await storage_service.generate_file_path(
user_id=user.id,
original_filename=request.file_name,
)
storage_path = full_path
else:
# S3 后续实现
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
# 创建上传会话
upload_session = UploadSession(
file_name=request.file_name,
file_size=request.file_size,
chunk_size=chunk_size,
total_chunks=total_chunks,
storage_path=storage_path,
expires_at=datetime.now() + timedelta(hours=24), # 24小时过期
owner_id=user.id,
parent_id=request.parent_id,
policy_id=policy_id,
)
upload_session = await upload_session.save(session)
l.info(f"创建上传会话: {upload_session.id}, 文件: {request.file_name}, 大小: {request.file_size}")
return UploadSessionResponse(
id=upload_session.id,
file_name=upload_session.file_name,
file_size=upload_session.file_size,
chunk_size=upload_session.chunk_size,
total_chunks=upload_session.total_chunks,
uploaded_chunks=0,
expires_at=upload_session.expires_at,
)
@file_upload_router.post(
path='/{session_id}/{chunk_index}',
summary='上传文件分片',
description='上传指定分片分片索引从0开始。',
)
async def upload_chunk(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
session_id: UUID,
chunk_index: int,
file: UploadFile = File(...),
) -> UploadChunkResponse:
"""
上传文件分片端点
流程:
1. 验证上传会话
2. 写入分片数据
3. 更新会话进度
4. 如果所有分片上传完成,创建 Object 记录
:param session: 数据库会话
:param user: 当前登录用户
:param session_id: 上传会话UUID
:param chunk_index: 分片索引从0开始
:param file: 上传的文件分片
:return: 上传进度信息
"""
# 获取上传会话
upload_session = await UploadSession.get(session, UploadSession.id == session_id)
if not upload_session or upload_session.owner_id != user.id:
raise HTTPException(status_code=404, detail="上传会话不存在")
if upload_session.is_expired:
raise HTTPException(status_code=400, detail="上传会话已过期")
# 存储 user.id避免后续 save() 导致 user 过期后无法访问
user_id = user.id
if chunk_index < 0 or chunk_index >= upload_session.total_chunks:
raise HTTPException(status_code=400, detail="无效的分片索引")
# 获取策略
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
if not policy:
raise HTTPException(status_code=500, detail="存储策略不存在")
# 读取分片内容
content = await file.read()
# 写入分片
if policy.type == PolicyType.LOCAL:
if not upload_session.storage_path:
raise HTTPException(status_code=500, detail="存储路径丢失")
storage_service = LocalStorageService(policy)
offset = chunk_index * upload_session.chunk_size
await storage_service.write_file_chunk(
upload_session.storage_path,
content,
offset,
)
else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
# 更新会话进度
upload_session.uploaded_chunks += 1
upload_session.uploaded_size += len(content)
upload_session = await upload_session.save(session)
# 检查是否完成
is_complete = upload_session.is_complete
file_object_id: UUID | None = None
if is_complete:
# 创建 PhysicalFile 记录
physical_file = PhysicalFile(
storage_path=upload_session.storage_path,
size=upload_session.uploaded_size,
policy_id=upload_session.policy_id,
reference_count=1,
)
physical_file = await physical_file.save(session)
# 创建 Object 记录
file_object = Object(
name=upload_session.file_name,
type=ObjectType.FILE,
size=upload_session.uploaded_size,
physical_file_id=physical_file.id,
upload_session_id=str(upload_session.id),
parent_id=upload_session.parent_id,
owner_id=user_id,
policy_id=upload_session.policy_id,
)
file_object = await file_object.save(session)
file_object_id = file_object.id
# 删除上传会话
await UploadSession.delete(session, upload_session)
l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}")
return UploadChunkResponse(
uploaded_chunks=upload_session.uploaded_chunks if not is_complete else upload_session.total_chunks,
total_chunks=upload_session.total_chunks,
is_complete=is_complete,
object_id=file_object_id,
)
@file_upload_router.delete( @file_upload_router.delete(
path='/{sessionID}', path='/{session_id}',
summary='删除上传会话', summary='删除上传会话',
description='Delete an upload session endpoint.', description='取消上传并删除会话及已上传的临时文件。',
dependencies=[Depends(SignRequired)]
) )
def router_file_upload_session_delete(sessionID: str) -> ResponseBase: async def delete_upload_session(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
session_id: UUID,
) -> ResponseBase:
""" """
Delete an upload session endpoint. 删除上传会话端点
Args: :param session: 数据库会话
sessionID (str): The session ID to delete. :param user: 当前登录用户
:param session_id: 上传会话UUID
Returns: :return: 删除结果
ResponseBase: A model containing the response data for the deletion.
""" """
pass upload_session = await UploadSession.get(session, UploadSession.id == session_id)
if not upload_session or upload_session.owner_id != user.id:
raise HTTPException(status_code=404, detail="上传会话不存在")
# 删除临时文件
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
# 删除会话记录
await UploadSession.delete(session, upload_session)
l.info(f"删除上传会话: {session_id}")
return ResponseBase(data={"deleted": True})
@file_upload_router.delete( @file_upload_router.delete(
path='/', path='/',
summary='清除所有上传会话', summary='清除所有上传会话',
description='Clear all upload sessions endpoint.', description='清除当前用户的所有上传会话。',
dependencies=[Depends(SignRequired)]
) )
def router_file_upload_session_clear() -> ResponseBase: async def clear_upload_sessions(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
) -> ResponseBase:
""" """
Clear all upload sessions endpoint. 清除所有上传会话端点
Returns:
ResponseBase: A model containing the response data for clearing all sessions.
"""
pass
@file_router.put( :param session: 数据库会话
path='/update/{id}', :param user: 当前登录用户
summary='更新文件', :return: 清除结果
description='Update file information endpoint.', """
dependencies=[Depends(SignRequired)] # 获取所有会话
sessions = await UploadSession.get(
session,
UploadSession.owner_id == user.id,
fetch_mode="all"
)
deleted_count = 0
for upload_session in sessions:
# 删除临时文件
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
await UploadSession.delete(session, upload_session)
deleted_count += 1
l.info(f"清除用户 {user.id} 的所有上传会话,共 {deleted_count}")
return ResponseBase(data={"deleted": deleted_count})
# ==================== 文件下载 ====================
@file_upload_router.get(
path='/download/{file_id}',
summary='下载文件',
description='下载指定文件。',
) )
def router_file_update(id: str) -> ResponseBase: async def download_file(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
file_id: UUID,
) -> FileResponse:
""" """
Update file information endpoint. 下载文件端点
Args: :param session: 数据库会话
id (str): The ID of the file to update. :param user: 当前登录用户
:param file_id: 文件UUID
Returns: :return: 文件响应
ResponseBase: A model containing the response data for the file update.
""" """
pass file_obj = await Object.get(session, Object.id == file_id)
if not file_obj or file_obj.owner_id != user.id:
raise HTTPException(status_code=404, detail="文件不存在")
if not file_obj.is_file:
raise HTTPException(status_code=400, detail="对象不是文件")
if not file_obj.source_name:
raise HTTPException(status_code=500, detail="文件存储路径丢失")
# 获取策略
policy = await Policy.get(session, Policy.id == file_obj.policy_id)
if not policy:
raise HTTPException(status_code=500, detail="存储策略不存在")
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(file_obj.source_name):
raise HTTPException(status_code=404, detail="物理文件不存在")
return FileResponse(
path=file_obj.source_name,
filename=file_obj.name,
media_type="application/octet-stream",
)
else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
# ==================== 创建空白文件 ====================
@file_router.post( @file_router.post(
path='/create', path='/create',
summary='创建空白文件', summary='创建空白文件',
description='Create a blank file endpoint.', description='在指定目录下创建空白文件。',
dependencies=[Depends(SignRequired)]
) )
def router_file_create() -> ResponseBase: async def create_empty_file(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
request: CreateFileRequest,
) -> ResponseBase:
""" """
Create a blank file endpoint. 创建空白文件端点
Returns: :param session: 数据库会话
ResponseBase: A model containing the response data for the file creation. :param user: 当前登录用户
:param request: 创建请求
:return: 创建结果
""" """
pass # 存储 user.id避免后续 save() 导致 user 过期后无法访问
user_id = user.id
# 验证文件名
if not request.name or '/' in request.name or '\\' in request.name:
raise HTTPException(status_code=400, detail="无效的文件名")
# 验证父目录
parent = await Object.get(session, Object.id == request.parent_id)
if not parent or parent.owner_id != user_id:
raise HTTPException(status_code=404, detail="父目录不存在")
if not parent.is_folder:
raise HTTPException(status_code=400, detail="父对象不是目录")
# 检查是否已存在同名文件
existing = await Object.get(
session,
(Object.owner_id == user_id) &
(Object.parent_id == parent.id) &
(Object.name == request.name)
)
if existing:
raise HTTPException(status_code=409, detail="同名文件已存在")
# 确定存储策略
policy_id = request.policy_id or parent.policy_id
policy = await Policy.get(session, Policy.id == policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
# 生成存储路径并创建空文件
storage_path: str | None = None
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
dir_path, storage_name, full_path = await storage_service.generate_file_path(
user_id=user_id,
original_filename=request.name,
)
await storage_service.create_empty_file(full_path)
storage_path = full_path
else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
# 创建 PhysicalFile 记录
physical_file = PhysicalFile(
storage_path=storage_path,
size=0,
policy_id=policy_id,
reference_count=1,
)
physical_file = await physical_file.save(session)
# 创建 Object 记录
file_object = Object(
name=request.name,
type=ObjectType.FILE,
size=0,
physical_file_id=physical_file.id,
parent_id=request.parent_id,
owner_id=user_id,
policy_id=policy_id,
)
file_object = await file_object.save(session)
l.info(f"创建空白文件: {file_object.name}, id={file_object.id}")
return ResponseBase(data={
"id": str(file_object.id),
"name": file_object.name,
"size": file_object.size,
})
# ==================== 文件外链(保留原有端点结构) ====================
@file_router.get(
path='/get/{id}/{name}',
summary='文件外链(直接输出文件数据)',
description='通过外链直接获取文件内容。',
)
async def router_file_get(
session: SessionDep,
id: str,
name: str,
) -> FileResponse:
"""
文件外链端点(直接输出)
TODO: 实现签名验证和权限控制
"""
raise HTTPException(status_code=501, detail="外链功能暂未实现")
@file_router.get(
path='/source/{id}/{name}',
summary='文件外链(301跳转)',
description='通过外链获取文件重定向地址。',
)
async def router_file_source_redirect(id: str, name: str) -> ResponseBase:
"""
文件外链端点301跳转
TODO: 实现签名验证和重定向
"""
raise HTTPException(status_code=501, detail="外链功能暂未实现")
@file_router.put( @file_router.put(
path='/download/{id}', path='/update/{id}',
summary='创建文件下载会话', summary='更新文件',
description='Create a file download session endpoint.', description='更新文件内容。',
dependencies=[Depends(SignRequired)] dependencies=[Depends(AuthRequired)]
) )
def router_file_download(id: str) -> ResponseBase: async def router_file_update(id: str) -> ResponseBase:
""" """更新文件内容"""
Create a file download session endpoint. raise HTTPException(status_code=501, detail="更新文件功能暂未实现")
Args:
id (str): The ID of the file to download.
Returns:
ResponseBase: A model containing the response data for the file download session.
"""
pass
@file_router.get( @file_router.get(
path='/preview/{id}', path='/preview/{id}',
summary='预览文件', summary='预览文件',
description='Preview file endpoint.', description='获取文件预览。',
dependencies=[Depends(SignRequired)] dependencies=[Depends(AuthRequired)]
) )
def router_file_preview(id: str) -> ResponseBase: async def router_file_preview(id: str) -> ResponseBase:
""" """预览文件"""
Preview file endpoint. raise HTTPException(status_code=501, detail="预览功能暂未实现")
Args:
id (str): The ID of the file to preview.
Returns:
ResponseBase: A model containing the response data for the file preview.
"""
pass
@file_router.get( @file_router.get(
path='/content/{id}', path='/content/{id}',
summary='获取文本文件内容', summary='获取文本文件内容',
description='Get text file content endpoint.', description='获取文本文件内容。',
dependencies=[Depends(SignRequired)] dependencies=[Depends(AuthRequired)]
) )
def router_file_content(id: str) -> ResponseBase: async def router_file_content(id: str) -> ResponseBase:
""" """获取文本文件内容"""
Get text file content endpoint. raise HTTPException(status_code=501, detail="文本内容功能暂未实现")
Args:
id (str): The ID of the text file.
Returns:
ResponseBase: A model containing the response data for the text file content.
"""
pass
@file_router.get( @file_router.get(
path='/doc/{id}', path='/doc/{id}',
summary='获取Office文档预览地址', summary='获取Office文档预览地址',
description='Get Office document preview URL endpoint.', description='获取Office文档在线预览地址。',
dependencies=[Depends(SignRequired)] dependencies=[Depends(AuthRequired)]
) )
def router_file_doc(id: str) -> ResponseBase: async def router_file_doc(id: str) -> ResponseBase:
""" """获取Office文档预览地址"""
Get Office document preview URL endpoint. raise HTTPException(status_code=501, detail="Office预览功能暂未实现")
Args:
id (str): The ID of the Office document.
Returns:
ResponseBase: A model containing the response data for the Office document preview URL.
"""
pass
@file_router.get( @file_router.get(
path='/thumb/{id}', path='/thumb/{id}',
summary='获取文件缩略图', summary='获取文件缩略图',
description='Get file thumbnail endpoint.', description='获取文件缩略图。',
dependencies=[Depends(SignRequired)] dependencies=[Depends(AuthRequired)]
) )
def router_file_thumb(id: str) -> ResponseBase: async def router_file_thumb(id: str) -> ResponseBase:
""" """获取文件缩略图"""
Get file thumbnail endpoint. raise HTTPException(status_code=501, detail="缩略图功能暂未实现")
Args:
id (str): The ID of the file to get the thumbnail for.
Returns:
ResponseBase: A model containing the response data for the file thumbnail.
"""
pass
@file_router.post( @file_router.post(
path='/source/{id}', path='/source/{id}',
summary='取得文件外链', summary='取得文件外链',
description='Get file external link endpoint.', description='获取文件的外链地址。',
dependencies=[Depends(SignRequired)] dependencies=[Depends(AuthRequired)]
) )
def router_file_source(id: str) -> ResponseBase: async def router_file_source(id: str) -> ResponseBase:
""" """获取文件外链"""
Get file external link endpoint. raise HTTPException(status_code=501, detail="外链功能暂未实现")
Args:
id (str): The ID of the file to get the external link for.
Returns:
ResponseBase: A model containing the response data for the file external link.
"""
pass
@file_router.post( @file_router.post(
path='/archive', path='/archive',
summary='打包要下载的文件', summary='打包要下载的文件',
description='Archive files for download endpoint.', description='将多个文件打包下载。',
dependencies=[Depends(SignRequired)] dependencies=[Depends(AuthRequired)]
) )
def router_file_archive(id: str) -> ResponseBase: async def router_file_archive() -> ResponseBase:
""" """打包文件"""
Archive files for download endpoint. raise HTTPException(status_code=501, detail="打包功能暂未实现")
Args:
id (str): The ID of the file to archive.
Returns:
ResponseBase: A model containing the response data for the archived files.
"""
pass
@file_router.post( @file_router.post(
path='/compress', path='/compress',
summary='创建文件压缩任务', summary='创建文件压缩任务',
description='Create file compression task endpoint.', description='创建文件压缩任务。',
dependencies=[Depends(SignRequired)] dependencies=[Depends(AuthRequired)]
) )
def router_file_compress(id: str) -> ResponseBase: async def router_file_compress() -> ResponseBase:
""" """创建压缩任务"""
Create file compression task endpoint. raise HTTPException(status_code=501, detail="压缩功能暂未实现")
Args:
id (str): The ID of the file to compress.
Returns:
ResponseBase: A model containing the response data for the file compression task.
"""
pass
@file_router.post( @file_router.post(
path='/decompress', path='/decompress',
summary='创建文件解压任务', summary='创建文件解压任务',
description='Create file extraction task endpoint.', description='创建文件解压任务。',
dependencies=[Depends(SignRequired)] dependencies=[Depends(AuthRequired)]
) )
def router_file_decompress(id: str) -> ResponseBase: async def router_file_decompress() -> ResponseBase:
""" """创建解压任务"""
Create file extraction task endpoint. raise HTTPException(status_code=501, detail="解压功能暂未实现")
Args:
id (str): The ID of the file to decompress.
Returns:
ResponseBase: A model containing the response data for the file extraction task.
"""
pass
@file_router.post( @file_router.post(
path='/relocate', path='/relocate',
summary='创建文件转移任务', summary='创建文件转移任务',
description='Create file relocation task endpoint.', description='创建文件转移任务。',
dependencies=[Depends(SignRequired)] dependencies=[Depends(AuthRequired)]
) )
def router_file_relocate(id: str) -> ResponseBase: async def router_file_relocate() -> ResponseBase:
""" """创建转移任务"""
Create file relocation task endpoint. raise HTTPException(status_code=501, detail="转移功能暂未实现")
Args:
id (str): The ID of the file to relocate.
Returns:
ResponseBase: A model containing the response data for the file relocation task.
"""
pass
@file_router.get( @file_router.get(
path='/search/{type}/{keyword}', path='/search/{type}/{keyword}',
summary='搜索文件', summary='搜索文件',
description='Search files by keyword endpoint.', description='按关键字搜索文件。',
dependencies=[Depends(SignRequired)] dependencies=[Depends(AuthRequired)]
) )
def router_file_search(type: str, keyword: str) -> ResponseBase: async def router_file_search(type: str, keyword: str) -> ResponseBase:
""" """搜索文件"""
Search files by keyword endpoint. raise HTTPException(status_code=501, detail="搜索功能暂未实现")
Args:
type (str): The type of search (e.g., 'name', 'content'). @file_upload_router.get(
keyword (str): The keyword to search for. path='/archive/{sessionID}/archive.zip',
summary='打包并下载文件',
Returns: description='获取打包后的文件。',
ResponseBase: A model containing the response data for the file search. )
""" async def router_file_archive_download(sessionID: str) -> ResponseBase:
pass """打包下载"""
raise HTTPException(status_code=501, detail="打包下载功能暂未实现")
@file_router.put(
path='/download/{id}',
summary='创建文件下载会话',
description='创建文件下载会话。',
dependencies=[Depends(AuthRequired)]
)
async def router_file_download_session(id: str) -> ResponseBase:
"""创建下载会话"""
raise HTTPException(status_code=501, detail="下载会话功能暂未实现")

View File

@@ -1,11 +1,35 @@
"""
对象操作路由
提供文件和目录对象的管理功能:删除、移动、复制、重命名等。
路由前缀:/object
"""
from typing import Annotated from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession
from middleware.auth import AuthRequired from middleware.auth import AuthRequired
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import Object, ObjectDeleteRequest, ObjectMoveRequest, User from models import (
from models.response import ResponseBase Object,
ObjectCopyRequest,
ObjectDeleteRequest,
ObjectMoveRequest,
ObjectPropertyDetailResponse,
ObjectPropertyResponse,
ObjectRenameRequest,
ObjectType,
PhysicalFile,
Policy,
PolicyType,
User,
)
from models import ResponseBase
from service.storage import LocalStorageService
object_router = APIRouter( object_router = APIRouter(
prefix="/object", prefix="/object",
@@ -13,10 +37,137 @@ object_router = APIRouter(
) )
async def _delete_object_recursive(
session: AsyncSession,
obj: Object,
user_id: UUID,
) -> int:
"""
递归删除对象(软删除)
对于文件:
- 减少 PhysicalFile 引用计数
- 只有引用计数为0时才移动物理文件到回收站
对于目录:
- 递归处理所有子对象
:param session: 数据库会话
:param obj: 要删除的对象
:param user_id: 用户UUID
:return: 删除的对象数量
"""
deleted_count = 0
if obj.is_folder:
# 递归删除子对象
children = await Object.get_children(session, user_id, obj.id)
for child in children:
deleted_count += await _delete_object_recursive(session, child, user_id)
# 如果是文件,处理物理文件引用
if obj.is_file and obj.physical_file_id:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
if physical_file:
# 减少引用计数
new_count = physical_file.decrement_reference()
if physical_file.can_be_deleted:
# 引用计数为0移动物理文件到回收站
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
if policy and policy.type == PolicyType.LOCAL:
try:
storage_service = LocalStorageService(policy)
await storage_service.move_to_trash(
source_path=physical_file.storage_path,
user_id=user_id,
object_id=obj.id,
)
l.debug(f"物理文件已移动到回收站: {obj.name}")
except Exception as e:
l.warning(f"移动物理文件到回收站失败: {obj.name}, 错误: {e}")
# 删除 PhysicalFile 记录
await PhysicalFile.delete(session, physical_file)
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
else:
# 还有其他引用,只更新引用计数
await physical_file.save(session)
l.debug(f"物理文件仍有 {new_count} 个引用,不删除: {physical_file.storage_path}")
# 删除数据库记录
await Object.delete(session, obj)
deleted_count += 1
return deleted_count
async def _copy_object_recursive(
session: AsyncSession,
src: Object,
dst_parent_id: UUID,
user_id: UUID,
) -> tuple[int, list[UUID]]:
"""
递归复制对象
对于文件:
- 增加 PhysicalFile 引用计数
- 创建新的 Object 记录指向同一 PhysicalFile
对于目录:
- 创建新目录
- 递归复制所有子对象
:param session: 数据库会话
:param src: 源对象
:param dst_parent_id: 目标父目录UUID
:param user_id: 用户UUID
:return: (复制数量, 新对象UUID列表)
"""
copied_count = 0
new_ids: list[UUID] = []
# 创建新的 Object 记录
new_obj = Object(
name=src.name,
type=src.type,
size=src.size,
password=src.password,
parent_id=dst_parent_id,
owner_id=user_id,
policy_id=src.policy_id,
physical_file_id=src.physical_file_id,
)
# 如果是文件,增加物理文件引用计数
if src.is_file and src.physical_file_id:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == src.physical_file_id)
if physical_file:
physical_file.increment_reference()
await physical_file.save(session)
new_obj = await new_obj.save(session)
copied_count += 1
new_ids.append(new_obj.id)
# 如果是目录,递归复制子对象
if src.is_folder:
children = await Object.get_children(session, user_id, src.id)
for child in children:
child_count, child_ids = await _copy_object_recursive(
session, child, new_obj.id, user_id
)
copied_count += child_count
new_ids.extend(child_ids)
return copied_count, new_ids
@object_router.delete( @object_router.delete(
path='/', path='/',
summary='删除对象', summary='删除对象',
description='删除一个或多个对象(文件或目录)', description='删除一个或多个对象(文件或目录),文件会移动到用户回收站。',
) )
async def router_object_delete( async def router_object_delete(
session: SessionDep, session: SessionDep,
@@ -24,22 +175,39 @@ async def router_object_delete(
request: ObjectDeleteRequest, request: ObjectDeleteRequest,
) -> ResponseBase: ) -> ResponseBase:
""" """
删除对象端点 删除对象端点(软删除)
流程:
1. 验证对象存在且属于当前用户
2. 对于文件,减少物理文件引用计数
3. 如果引用计数为0移动物理文件到 .trash 目录
4. 对于目录,递归处理子对象
5. 从数据库中删除记录
:param session: 数据库会话 :param session: 数据库会话
:param user: 当前登录用户 :param user: 当前登录用户
:param request: 删除请求包含待删除对象的UUID列表 :param request: 删除请求包含待删除对象的UUID列表
:return: 删除结果 :return: 删除结果
""" """
# 存储 user.id避免后续 save() 导致 user 过期后无法访问
user_id = user.id
deleted_count = 0 deleted_count = 0
for obj_id in request.ids: for obj_id in request.ids:
obj = await Object.get(session, Object.id == obj_id) obj = await Object.get(session, Object.id == obj_id)
if obj and obj.owner_id == user.id: if not obj or obj.owner_id != user_id:
# TODO: 递归删除子对象(如果是目录) continue
# TODO: 更新用户存储空间
await obj.delete(session) # 不能删除根目录
deleted_count += 1 if obj.parent_id is None:
l.warning(f"尝试删除根目录被阻止: {obj.name}")
continue
# 递归删除(包含引用计数逻辑)
count = await _delete_object_recursive(session, obj, user_id)
deleted_count += count
l.info(f"用户 {user_id} 删除了 {deleted_count} 个对象")
return ResponseBase( return ResponseBase(
data={ data={
@@ -67,9 +235,12 @@ async def router_object_move(
:param request: 移动请求包含源对象UUID列表和目标目录UUID :param request: 移动请求包含源对象UUID列表和目标目录UUID
:return: 移动结果 :return: 移动结果
""" """
# 存储 user.id避免后续 save() 导致 user 过期后无法访问
user_id = user.id
# 验证目标目录 # 验证目标目录
dst = await Object.get(session, Object.id == request.dst_id) dst = await Object.get(session, Object.id == request.dst_id)
if not dst or dst.owner_id != user.id: if not dst or dst.owner_id != user_id:
raise HTTPException(status_code=404, detail="目标目录不存在") raise HTTPException(status_code=404, detail="目标目录不存在")
if not dst.is_folder: if not dst.is_folder:
@@ -79,17 +250,33 @@ async def router_object_move(
for src_id in request.src_ids: for src_id in request.src_ids:
src = await Object.get(session, Object.id == src_id) src = await Object.get(session, Object.id == src_id)
if not src or src.owner_id != user.id: if not src or src.owner_id != user_id:
continue
# 不能移动根目录
if src.parent_id is None:
continue continue
# 检查是否移动到自身或子目录(防止循环引用) # 检查是否移动到自身或子目录(防止循环引用)
if src.id == dst.id: if src.id == dst.id:
continue continue
# 检查是否将目录移动到其子目录中(循环检测)
if src.is_folder:
current = dst
is_cycle = False
while current and current.parent_id:
if current.parent_id == src.id:
is_cycle = True
break
current = await Object.get(session, Object.id == current.parent_id)
if is_cycle:
continue
# 检查目标目录下是否存在同名对象 # 检查目标目录下是否存在同名对象
existing = await Object.get( existing = await Object.get(
session, session,
(Object.owner_id == user.id) & (Object.owner_id == user_id) &
(Object.parent_id == dst.id) & (Object.parent_id == dst.id) &
(Object.name == src.name) (Object.name == src.name)
) )
@@ -107,50 +294,279 @@ async def router_object_move(
} }
) )
@object_router.post( @object_router.post(
path='/copy', path='/copy',
summary='复制对象', summary='复制对象',
description='Copy an object endpoint.', description='复制一个或多个对象到目标目录。文件复制仅增加物理文件引用计数,不复制物理文件。',
dependencies=[Depends(AuthRequired)]
) )
def router_object_copy() -> ResponseBase: async def router_object_copy(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
request: ObjectCopyRequest,
) -> ResponseBase:
""" """
Copy an object endpoint. 复制对象端点
Returns: 流程:
ResponseBase: A model containing the response data for the object copy. 1. 验证目标目录存在且属于当前用户
2. 对于每个源对象:
- 验证源对象存在且属于当前用户
- 检查目标目录下是否存在同名对象
- 如果是文件:增加 PhysicalFile 引用计数,创建新 Object
- 如果是目录:递归复制所有子对象
3. 返回复制结果
:param session: 数据库会话
:param user: 当前登录用户
:param request: 复制请求
:return: 复制结果
""" """
pass # 存储 user.id避免后续 save() 导致 user 过期后无法访问
user_id = user.id
# 验证目标目录
dst = await Object.get(session, Object.id == request.dst_id)
if not dst or dst.owner_id != user_id:
raise HTTPException(status_code=404, detail="目标目录不存在")
if not dst.is_folder:
raise HTTPException(status_code=400, detail="目标不是有效文件夹")
copied_count = 0
new_ids: list[UUID] = []
for src_id in request.src_ids:
src = await Object.get(session, Object.id == src_id)
if not src or src.owner_id != user_id:
continue
# 不能复制根目录
if src.parent_id is None:
continue
# 不能复制到自身
if src.id == dst.id:
continue
# 不能将目录复制到其子目录中
if src.is_folder:
current = dst
is_cycle = False
while current and current.parent_id:
if current.parent_id == src.id:
is_cycle = True
break
current = await Object.get(session, Object.id == current.parent_id)
if is_cycle:
continue
# 检查目标目录下是否存在同名对象
existing = await Object.get(
session,
(Object.owner_id == user_id) &
(Object.parent_id == dst.id) &
(Object.name == src.name)
)
if existing:
continue # 跳过重名对象
# 递归复制
count, ids = await _copy_object_recursive(session, src, dst.id, user_id)
copied_count += count
new_ids.extend(ids)
l.info(f"用户 {user_id} 复制了 {copied_count} 个对象")
return ResponseBase(
data={
"copied": copied_count,
"total": len(request.src_ids),
"new_ids": new_ids,
}
)
@object_router.post( @object_router.post(
path='/rename', path='/rename',
summary='重命名对象', summary='重命名对象',
description='Rename an object endpoint.', description='重命名对象(文件或目录)。',
dependencies=[Depends(AuthRequired)]
) )
def router_object_rename() -> ResponseBase: async def router_object_rename(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
request: ObjectRenameRequest,
) -> ResponseBase:
""" """
Rename an object endpoint. 重命名对象端点
Returns: 流程:
ResponseBase: A model containing the response data for the object rename. 1. 验证对象存在且属于当前用户
2. 验证新名称格式(不含非法字符)
3. 检查同目录下是否存在同名对象
4. 更新 name 字段
5. 返回更新结果
:param session: 数据库会话
:param user: 当前登录用户
:param request: 重命名请求
:return: 重命名结果
""" """
pass # 存储 user.id避免后续 save() 导致 user 过期后无法访问
user_id = user.id
# 验证对象存在
obj = await Object.get(session, Object.id == request.id)
if not obj:
raise HTTPException(status_code=404, detail="对象不存在")
if obj.owner_id != user_id:
raise HTTPException(status_code=403, detail="无权操作此对象")
# 不能重命名根目录
if obj.parent_id is None:
raise HTTPException(status_code=400, detail="无法重命名根目录")
# 验证新名称格式
new_name = request.new_name.strip()
if not new_name:
raise HTTPException(status_code=400, detail="名称不能为空")
if '/' in new_name or '\\' in new_name:
raise HTTPException(status_code=400, detail="名称不能包含斜杠")
# 如果名称没有变化,直接返回成功
if obj.name == new_name:
return ResponseBase(data={"success": True})
# 检查同目录下是否存在同名对象
existing = await Object.get(
session,
(Object.owner_id == user_id) &
(Object.parent_id == obj.parent_id) &
(Object.name == new_name)
)
if existing:
raise HTTPException(status_code=409, detail="同名对象已存在")
# 更新名称
obj.name = new_name
await obj.save(session)
l.info(f"用户 {user_id} 将对象 {obj.id} 重命名为 {new_name}")
return ResponseBase(data={"success": True})
@object_router.get( @object_router.get(
path='/property/{id}', path='/property/{id}',
summary='获取对象属性', summary='获取对象基本属性',
description='Get object properties endpoint.', description='获取对象的基本属性信息(名称、类型、大小、创建/修改时间等)。',
dependencies=[Depends(AuthRequired)]
) )
def router_object_property(id: str) -> ResponseBase: async def router_object_property(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
id: UUID,
) -> ObjectPropertyResponse:
""" """
Get object properties endpoint. 获取对象基本属性端点
Args: :param session: 数据库会话
id (str): The ID of the object to retrieve properties for. :param user: 当前登录用户
:param id: 对象UUID
Returns: :return: 对象基本属性
ResponseBase: A model containing the response data for the object properties.
""" """
pass obj = await Object.get(session, Object.id == id)
if not obj:
raise HTTPException(status_code=404, detail="对象不存在")
if obj.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权查看此对象")
return ObjectPropertyResponse(
id=obj.id,
name=obj.name,
type=obj.type,
size=obj.size,
created_at=obj.created_at,
updated_at=obj.updated_at,
parent_id=obj.parent_id,
)
@object_router.get(
path='/property/{id}/detail',
summary='获取对象详细属性',
description='获取对象的详细属性信息,包括元数据、分享统计、存储信息等。',
)
async def router_object_property_detail(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
id: UUID,
) -> ObjectPropertyDetailResponse:
"""
获取对象详细属性端点
:param session: 数据库会话
:param user: 当前登录用户
:param id: 对象UUID
:return: 对象详细属性
"""
obj = await Object.get(
session,
Object.id == id,
load=Object.file_metadata,
)
if not obj:
raise HTTPException(status_code=404, detail="对象不存在")
if obj.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权查看此对象")
# 获取策略名称
policy = await Policy.get(session, Policy.id == obj.policy_id)
policy_name = policy.name if policy else None
# 获取分享统计
from models import Share
shares = await Share.get(
session,
Share.object_id == obj.id,
fetch_mode="all"
)
share_count = len(shares)
total_views = sum(s.views for s in shares)
total_downloads = sum(s.downloads for s in shares)
# 获取物理文件引用计数
reference_count = 1
if obj.physical_file_id:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
if physical_file:
reference_count = physical_file.reference_count
# 构建响应
response = ObjectPropertyDetailResponse(
id=obj.id,
name=obj.name,
type=obj.type,
size=obj.size,
created_at=obj.created_at,
updated_at=obj.updated_at,
parent_id=obj.parent_id,
policy_name=policy_name,
share_count=share_count,
total_views=total_views,
total_downloads=total_downloads,
reference_count=reference_count,
)
# 添加文件元数据
if obj.file_metadata:
response.mime_type = obj.file_metadata.mime_type
response.width = obj.file_metadata.width
response.height = obj.file_metadata.height
response.duration = obj.file_metadata.duration
response.checksum_md5 = obj.file_metadata.checksum_md5
return response

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from middleware.auth import SignRequired from middleware.auth import SignRequired
from models.response import ResponseBase from models import ResponseBase
share_router = APIRouter( share_router = APIRouter(
prefix='/share', prefix='/share',

View File

@@ -3,7 +3,7 @@ from sqlalchemy import and_
import json import json
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models.response import ResponseBase from models import ResponseBase
from models.setting import Setting from models.setting import Setting
site_router = APIRouter( site_router = APIRouter(

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from middleware.auth import SignRequired from middleware.auth import SignRequired
from models.response import ResponseBase from models import ResponseBase
slave_router = APIRouter( slave_router = APIRouter(
prefix="/slave", prefix="/slave",

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from middleware.auth import SignRequired from middleware.auth import SignRequired
from models.response import ResponseBase from models import ResponseBase
tag_router = APIRouter( tag_router = APIRouter(
prefix='/tag', prefix='/tag',

View File

@@ -96,7 +96,7 @@ async def router_user_session(
async def router_user_register( async def router_user_register(
session: SessionDep, session: SessionDep,
request: models.RegisterRequest, request: models.RegisterRequest,
) -> models.response.ResponseBase: ) -> models.ResponseBase:
""" """
用户注册端点 用户注册端点
@@ -157,7 +157,7 @@ async def router_user_register(
policy_id=default_policy.id, policy_id=default_policy.id,
).save(session) ).save(session)
return models.response.ResponseBase( return models.ResponseBase(
data={ data={
"user_id": new_user_id, "user_id": new_user_id,
"username": new_user_username, "username": new_user_username,
@@ -172,7 +172,7 @@ async def router_user_register(
) )
def router_user_email_code( def router_user_email_code(
reason: Literal['register', 'reset'] = 'register', reason: Literal['register', 'reset'] = 'register',
) -> models.response.ResponseBase: ) -> models.ResponseBase:
""" """
Send a verification code email. Send a verification code email.
@@ -186,7 +186,7 @@ def router_user_email_code(
summary='初始化QQ登录', summary='初始化QQ登录',
description='Initialize QQ login for a user.', description='Initialize QQ login for a user.',
) )
def router_user_qq() -> models.response.ResponseBase: def router_user_qq() -> models.ResponseBase:
""" """
Initialize QQ login for a user. Initialize QQ login for a user.
@@ -200,7 +200,7 @@ def router_user_qq() -> models.response.ResponseBase:
summary='WebAuthn登录初始化', summary='WebAuthn登录初始化',
description='Initialize WebAuthn login for a user.', description='Initialize WebAuthn login for a user.',
) )
async def router_user_authn(username: str) -> models.response.ResponseBase: async def router_user_authn(username: str) -> models.ResponseBase:
pass pass
@@ -209,7 +209,7 @@ async def router_user_authn(username: str) -> models.response.ResponseBase:
summary='WebAuthn登录', summary='WebAuthn登录',
description='Finish WebAuthn login for a user.', description='Finish WebAuthn login for a user.',
) )
def router_user_authn_finish(username: str) -> models.response.ResponseBase: def router_user_authn_finish(username: str) -> models.ResponseBase:
""" """
Finish WebAuthn login for a user. Finish WebAuthn login for a user.
@@ -226,7 +226,7 @@ def router_user_authn_finish(username: str) -> models.response.ResponseBase:
summary='获取用户主页展示用分享', summary='获取用户主页展示用分享',
description='Get user profile for display.', description='Get user profile for display.',
) )
def router_user_profile(id: str) -> models.response.ResponseBase: def router_user_profile(id: str) -> models.ResponseBase:
""" """
Get user profile for display. Get user profile for display.
@@ -243,7 +243,7 @@ def router_user_profile(id: str) -> models.response.ResponseBase:
summary='获取用户头像', summary='获取用户头像',
description='Get user avatar by ID and size.', description='Get user avatar by ID and size.',
) )
def router_user_avatar(id: str, size: int = 128) -> models.response.ResponseBase: def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
""" """
Get user avatar by ID and size. Get user avatar by ID and size.
@@ -265,17 +265,17 @@ def router_user_avatar(id: str, size: int = 128) -> models.response.ResponseBase
summary='获取用户信息', summary='获取用户信息',
description='Get user information.', description='Get user information.',
dependencies=[Depends(dependency=AuthRequired)], dependencies=[Depends(dependency=AuthRequired)],
response_model=models.response.ResponseBase, response_model=models.ResponseBase,
) )
async def router_user_me( async def router_user_me(
session: SessionDep, session: SessionDep,
user: Annotated[models.User, Depends(AuthRequired)], user: Annotated[models.User, Depends(AuthRequired)],
) -> models.response.ResponseBase: ) -> models.ResponseBase:
""" """
获取用户信息. 获取用户信息.
:return: response.ResponseBase containing user information. :return: ResponseBase containing user information.
:rtype: response.ResponseBase :rtype: ResponseBase
""" """
# 加载 group 及其 options 关系 # 加载 group 及其 options 关系
group = await models.Group.get( group = await models.Group.get(
@@ -302,7 +302,7 @@ async def router_user_me(
tags=[tag.name for tag in user_tags] if user_tags else [], tags=[tag.name for tag in user_tags] if user_tags else [],
) )
return models.response.ResponseBase(data=user_response.model_dump()) return models.ResponseBase(data=user_response.model_dump())
@user_router.get( @user_router.get(
path='/storage', path='/storage',
@@ -313,7 +313,7 @@ async def router_user_me(
async def router_user_storage( async def router_user_storage(
session: SessionDep, session: SessionDep,
user: Annotated[models.user.User, Depends(AuthRequired)], user: Annotated[models.user.User, Depends(AuthRequired)],
) -> models.response.ResponseBase: ) -> models.ResponseBase:
""" """
获取用户存储空间信息。 获取用户存储空间信息。
@@ -330,7 +330,7 @@ async def router_user_storage(
used: int = user.storage used: int = user.storage
free: int = max(0, total - used) free: int = max(0, total - used)
return models.response.ResponseBase( return models.ResponseBase(
data={ data={
"used": used, "used": used,
"free": free, "free": free,
@@ -347,7 +347,7 @@ async def router_user_storage(
async def router_user_authn_start( async def router_user_authn_start(
session: SessionDep, session: SessionDep,
user: Annotated[models.user.User, Depends(AuthRequired)], user: Annotated[models.user.User, Depends(AuthRequired)],
) -> models.response.ResponseBase: ) -> models.ResponseBase:
""" """
Initialize WebAuthn login for a user. Initialize WebAuthn login for a user.
@@ -378,7 +378,7 @@ async def router_user_authn_start(
user_display_name=user.nick or user.username, user_display_name=user.nick or user.username,
) )
return models.response.ResponseBase(data=options_to_json_dict(options)) return models.ResponseBase(data=options_to_json_dict(options))
@user_router.put( @user_router.put(
path='/authn/finish', path='/authn/finish',
@@ -386,7 +386,7 @@ async def router_user_authn_start(
description='Finish WebAuthn login for a user.', description='Finish WebAuthn login for a user.',
dependencies=[Depends(AuthRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_authn_finish() -> models.response.ResponseBase: def router_user_authn_finish() -> models.ResponseBase:
""" """
Finish WebAuthn login for a user. Finish WebAuthn login for a user.
@@ -400,7 +400,7 @@ def router_user_authn_finish() -> models.response.ResponseBase:
summary='获取用户可选存储策略', summary='获取用户可选存储策略',
description='Get user selectable storage policies.', description='Get user selectable storage policies.',
) )
def router_user_settings_policies() -> models.response.ResponseBase: def router_user_settings_policies() -> models.ResponseBase:
""" """
Get user selectable storage policies. Get user selectable storage policies.
@@ -415,7 +415,7 @@ def router_user_settings_policies() -> models.response.ResponseBase:
description='Get user selectable nodes.', description='Get user selectable nodes.',
dependencies=[Depends(AuthRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_settings_nodes() -> models.response.ResponseBase: def router_user_settings_nodes() -> models.ResponseBase:
""" """
Get user selectable nodes. Get user selectable nodes.
@@ -430,7 +430,7 @@ def router_user_settings_nodes() -> models.response.ResponseBase:
description='Get user task queue.', description='Get user task queue.',
dependencies=[Depends(AuthRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_settings_tasks() -> models.response.ResponseBase: def router_user_settings_tasks() -> models.ResponseBase:
""" """
Get user task queue. Get user task queue.
@@ -445,14 +445,14 @@ def router_user_settings_tasks() -> models.response.ResponseBase:
description='Get current user settings.', description='Get current user settings.',
dependencies=[Depends(AuthRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_settings() -> models.response.ResponseBase: def router_user_settings() -> models.ResponseBase:
""" """
Get current user settings. Get current user settings.
Returns: Returns:
dict: A dictionary containing the current user settings. dict: A dictionary containing the current user settings.
""" """
return models.response.ResponseBase(data=models.UserSettingResponse().model_dump()) return models.ResponseBase(data=models.UserSettingResponse().model_dump())
@user_settings_router.post( @user_settings_router.post(
path='/avatar', path='/avatar',
@@ -460,7 +460,7 @@ def router_user_settings() -> models.response.ResponseBase:
description='Upload user avatar from file.', description='Upload user avatar from file.',
dependencies=[Depends(AuthRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_settings_avatar() -> models.response.ResponseBase: def router_user_settings_avatar() -> models.ResponseBase:
""" """
Upload user avatar from file. Upload user avatar from file.
@@ -475,7 +475,7 @@ def router_user_settings_avatar() -> models.response.ResponseBase:
description='Set user avatar to Gravatar.', description='Set user avatar to Gravatar.',
dependencies=[Depends(AuthRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_settings_avatar_gravatar() -> models.response.ResponseBase: def router_user_settings_avatar_gravatar() -> models.ResponseBase:
""" """
Set user avatar to Gravatar. Set user avatar to Gravatar.
@@ -490,7 +490,7 @@ def router_user_settings_avatar_gravatar() -> models.response.ResponseBase:
description='Update user settings.', description='Update user settings.',
dependencies=[Depends(AuthRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_settings_patch(option: str) -> models.response.ResponseBase: def router_user_settings_patch(option: str) -> models.ResponseBase:
""" """
Update user settings. Update user settings.
@@ -510,7 +510,7 @@ def router_user_settings_patch(option: str) -> models.response.ResponseBase:
) )
async def router_user_settings_2fa( async def router_user_settings_2fa(
user: Annotated[models.user.User, Depends(AuthRequired)], user: Annotated[models.user.User, Depends(AuthRequired)],
) -> models.response.ResponseBase: ) -> models.ResponseBase:
""" """
Get two-factor authentication initialization information. Get two-factor authentication initialization information.
@@ -518,7 +518,7 @@ async def router_user_settings_2fa(
dict: A dictionary containing two-factor authentication setup information. dict: A dictionary containing two-factor authentication setup information.
""" """
return models.response.ResponseBase( return models.ResponseBase(
data=await Password.generate_totp(user.username) data=await Password.generate_totp(user.username)
) )
@@ -533,7 +533,7 @@ async def router_user_settings_2fa_enable(
user: Annotated[models.user.User, Depends(AuthRequired)], user: Annotated[models.user.User, Depends(AuthRequired)],
setup_token: str, setup_token: str,
code: str, code: str,
) -> models.response.ResponseBase: ) -> models.ResponseBase:
""" """
Enable two-factor authentication for the user. Enable two-factor authentication for the user.
@@ -559,6 +559,6 @@ async def router_user_settings_2fa_enable(
user.two_factor = secret user.two_factor = secret
user = await user.save(session) user = await user.save(session)
return models.response.ResponseBase( return models.ResponseBase(
data={"message": "Two-factor authentication enabled successfully"} data={"message": "Two-factor authentication enabled successfully"}
) )

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from middleware.auth import SignRequired from middleware.auth import SignRequired
from models.response import ResponseBase from models import ResponseBase
vas_router = APIRouter( vas_router = APIRouter(
prefix="/vas", prefix="/vas",

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request
from middleware.auth import SignRequired from middleware.auth import SignRequired
from models.response import ResponseBase from models import ResponseBase
# WebDAV 管理路由 # WebDAV 管理路由
webdav_router = APIRouter( webdav_router = APIRouter(

View File

@@ -0,0 +1,20 @@
"""
存储服务模块
提供文件存储相关的服务,包括:
- 本地存储服务
- 命名规则解析器
- 存储异常定义
"""
from .exceptions import (
DirectoryCreationError,
FileReadError,
FileWriteError,
InvalidPathError,
StorageException,
StorageFileNotFoundError,
UploadSessionExpiredError,
UploadSessionNotFoundError,
)
from .local_storage import LocalStorageService
from .naming_rule import NamingContext, NamingRuleParser

View File

@@ -0,0 +1,45 @@
"""
存储服务异常定义
定义存储操作相关的异常类型,用于精确的错误处理和诊断。
"""
class StorageException(Exception):
"""存储服务基础异常"""
pass
class DirectoryCreationError(StorageException):
"""目录创建失败"""
pass
class StorageFileNotFoundError(StorageException):
"""文件不存在"""
pass
class FileWriteError(StorageException):
"""文件写入失败"""
pass
class FileReadError(StorageException):
"""文件读取失败"""
pass
class UploadSessionNotFoundError(StorageException):
"""上传会话不存在"""
pass
class UploadSessionExpiredError(StorageException):
"""上传会话已过期"""
pass
class InvalidPathError(StorageException):
"""无效的路径"""
pass

View File

@@ -0,0 +1,388 @@
"""
本地存储服务
负责本地文件系统的物理操作:
- 目录创建
- 文件写入/读取/删除
- 文件移动(软删除到 .trash
所有 IO 操作都使用 aiofiles 确保异步执行。
"""
from pathlib import Path
from uuid import UUID
import aiofiles
import aiofiles.os
from loguru import logger as l
from models.policy import Policy
from .exceptions import (
DirectoryCreationError,
FileReadError,
FileWriteError,
InvalidPathError,
StorageException,
StorageFileNotFoundError,
)
from .naming_rule import NamingContext, NamingRuleParser
class LocalStorageService:
"""
本地存储服务
实现本地文件系统的异步文件操作。
所有 IO 操作都使用 aiofiles 确保异步执行。
使用示例::
service = LocalStorageService(policy)
await service.ensure_base_directory()
dir_path, storage_name, full_path = await service.generate_file_path(
user_id=user.id,
original_filename="document.pdf",
)
await service.write_file(full_path, content)
"""
def __init__(self, policy: Policy):
"""
初始化本地存储服务
:param policy: 存储策略配置
:raises StorageException: 本地存储策略未指定 server 路径时抛出
"""
if not policy.server:
raise StorageException("本地存储策略必须指定 server 路径")
self._policy = policy
self._base_path = Path(policy.server).resolve()
@property
def base_path(self) -> Path:
"""存储根目录"""
return self._base_path
# ==================== 目录操作 ====================
async def ensure_base_directory(self) -> None:
"""
确保存储根目录存在
创建策略时调用,确保物理目录已创建。
:raises DirectoryCreationError: 目录创建失败时抛出
"""
try:
await aiofiles.os.makedirs(str(self._base_path), exist_ok=True)
l.info(f"已确保存储目录存在: {self._base_path}")
except OSError as e:
raise DirectoryCreationError(f"无法创建存储目录 {self._base_path}: {e}")
async def ensure_directory(self, relative_path: str) -> Path:
"""
确保相对路径的目录存在
:param relative_path: 相对于存储根目录的路径
:return: 完整的目录路径
:raises DirectoryCreationError: 目录创建失败时抛出
"""
try:
full_path = self._base_path / relative_path
await aiofiles.os.makedirs(str(full_path), exist_ok=True)
return full_path
except OSError as e:
raise DirectoryCreationError(f"无法创建目录 {relative_path}: {e}")
async def ensure_trash_directory(self, user_id: UUID) -> Path:
"""
确保用户的回收站目录存在
回收站路径格式: {storage_root}/{user_id}/.trash
:param user_id: 用户UUID
:return: 回收站目录路径
:raises DirectoryCreationError: 目录创建失败时抛出
"""
trash_path = self._base_path / str(user_id) / ".trash"
try:
await aiofiles.os.makedirs(str(trash_path), exist_ok=True)
return trash_path
except OSError as e:
raise DirectoryCreationError(f"无法创建回收站目录: {e}")
# ==================== 路径生成 ====================
async def generate_file_path(
self,
user_id: UUID,
original_filename: str,
) -> tuple[str, str, str]:
"""
根据命名规则生成文件存储路径
:param user_id: 用户UUID
:param original_filename: 原始文件名
:return: (相对目录路径, 存储文件名, 完整物理路径)
"""
context = NamingContext(
user_id=user_id,
original_filename=original_filename,
)
# 解析目录规则
dir_path = ""
if self._policy.dir_name_rule:
dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context)
# 解析文件名规则
if self._policy.auto_rename and self._policy.file_name_rule:
storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context)
# 确保有扩展名
if '.' in original_filename and '.' not in storage_name:
ext = original_filename.rsplit('.', 1)[1]
storage_name = f"{storage_name}.{ext}"
else:
storage_name = original_filename
# 确保目录存在
if dir_path:
full_dir = await self.ensure_directory(dir_path)
else:
full_dir = self._base_path
full_path = str(full_dir / storage_name)
return dir_path, storage_name, full_path
# ==================== 文件写入 ====================
async def write_file(self, path: str, content: bytes) -> int:
"""
写入文件内容
:param path: 完整文件路径
:param content: 文件内容
:return: 写入的字节数
:raises FileWriteError: 写入失败时抛出
"""
try:
async with aiofiles.open(path, 'wb') as f:
await f.write(content)
return len(content)
except OSError as e:
raise FileWriteError(f"写入文件失败 {path}: {e}")
async def write_file_chunk(
self,
path: str,
content: bytes,
offset: int,
) -> int:
"""
写入文件分片
:param path: 完整文件路径
:param content: 分片内容
:param offset: 写入偏移量
:return: 写入的字节数
:raises FileWriteError: 写入失败时抛出
"""
try:
# 检查文件是否存在,决定打开模式
is_exists = await self.file_exists(path)
mode = 'r+b' if is_exists else 'wb'
async with aiofiles.open(path, mode) as f:
await f.seek(offset)
await f.write(content)
return len(content)
except OSError as e:
raise FileWriteError(f"写入文件分片失败 {path}: {e}")
async def create_empty_file(self, path: str) -> None:
"""
创建空白文件
:param path: 完整文件路径
:raises FileWriteError: 创建失败时抛出
"""
try:
async with aiofiles.open(path, 'wb'):
pass # 创建空文件
except OSError as e:
raise FileWriteError(f"创建空文件失败 {path}: {e}")
# ==================== 文件读取 ====================
async def read_file(self, path: str) -> bytes:
"""
读取完整文件
:param path: 完整文件路径
:return: 文件内容
:raises StorageFileNotFoundError: 文件不存在时抛出
:raises FileReadError: 读取失败时抛出
"""
if not await self.file_exists(path):
raise StorageFileNotFoundError(f"文件不存在: {path}")
try:
async with aiofiles.open(path, 'rb') as f:
return await f.read()
except OSError as e:
raise FileReadError(f"读取文件失败 {path}: {e}")
async def get_file_size(self, path: str) -> int:
"""
获取文件大小
:param path: 完整文件路径
:return: 文件大小(字节)
:raises StorageFileNotFoundError: 文件不存在时抛出
"""
if not await self.file_exists(path):
raise StorageFileNotFoundError(f"文件不存在: {path}")
stat = await aiofiles.os.stat(path)
return stat.st_size
async def file_exists(self, path: str) -> bool:
"""
检查文件是否存在
:param path: 完整文件路径
:return: 是否存在
"""
return await aiofiles.os.path.exists(path)
# ==================== 文件删除和移动 ====================
async def delete_file(self, path: str) -> None:
"""
删除文件(物理删除)
:param path: 完整文件路径
"""
if await self.file_exists(path):
try:
await aiofiles.os.remove(path)
l.debug(f"已删除文件: {path}")
except OSError as e:
l.warning(f"删除文件失败 {path}: {e}")
async def move_to_trash(
self,
source_path: str,
user_id: UUID,
object_id: UUID,
) -> str:
"""
将文件移动到回收站(软删除)
回收站中的文件名格式: {object_uuid}_{original_filename}
:param source_path: 源文件完整路径
:param user_id: 用户UUID
:param object_id: 对象UUID用于生成唯一的回收站文件名
:return: 回收站中的文件路径
:raises StorageFileNotFoundError: 源文件不存在时抛出
"""
if not await self.file_exists(source_path):
raise StorageFileNotFoundError(f"源文件不存在: {source_path}")
# 确保回收站目录存在
trash_dir = await self.ensure_trash_directory(user_id)
# 使用 object_id 作为回收站文件名前缀,避免冲突
source_filename = Path(source_path).name
trash_filename = f"{object_id}_{source_filename}"
trash_path = trash_dir / trash_filename
# 移动文件
try:
await aiofiles.os.rename(source_path, str(trash_path))
l.info(f"文件已移动到回收站: {source_path} -> {trash_path}")
return str(trash_path)
except OSError as e:
raise StorageException(f"移动文件到回收站失败: {e}")
async def restore_from_trash(
self,
trash_path: str,
restore_path: str,
) -> None:
"""
从回收站恢复文件
:param trash_path: 回收站中的文件路径
:param restore_path: 恢复目标路径
:raises StorageFileNotFoundError: 回收站文件不存在时抛出
"""
if not await self.file_exists(trash_path):
raise StorageFileNotFoundError(f"回收站文件不存在: {trash_path}")
# 确保目标目录存在
restore_dir = Path(restore_path).parent
await aiofiles.os.makedirs(str(restore_dir), exist_ok=True)
try:
await aiofiles.os.rename(trash_path, restore_path)
l.info(f"文件已从回收站恢复: {trash_path} -> {restore_path}")
except OSError as e:
raise StorageException(f"从回收站恢复文件失败: {e}")
async def empty_trash(self, user_id: UUID) -> int:
"""
清空用户回收站
:param user_id: 用户UUID
:return: 删除的文件数量
"""
trash_dir = self._base_path / str(user_id) / ".trash"
if not await aiofiles.os.path.exists(str(trash_dir)):
return 0
deleted_count = 0
try:
entries = await aiofiles.os.listdir(str(trash_dir))
for entry in entries:
file_path = trash_dir / entry
if await aiofiles.os.path.isfile(str(file_path)):
await aiofiles.os.remove(str(file_path))
deleted_count += 1
l.info(f"已清空用户 {user_id} 的回收站,删除 {deleted_count} 个文件")
except OSError as e:
l.warning(f"清空回收站时出错: {e}")
return deleted_count
# ==================== 路径验证 ====================
def validate_path(self, path: str) -> bool:
"""
验证路径是否在存储根目录下(防止路径遍历攻击)
:param path: 要验证的路径
:return: 路径是否有效
"""
try:
resolved = Path(path).resolve()
return str(resolved).startswith(str(self._base_path))
except (ValueError, OSError):
return False
def get_relative_path(self, full_path: str) -> str:
"""
获取相对于存储根目录的相对路径
:param full_path: 完整路径
:return: 相对路径
:raises InvalidPathError: 路径不在存储根目录下时抛出
"""
if not self.validate_path(full_path):
raise InvalidPathError(f"路径不在存储根目录下: {full_path}")
resolved = Path(full_path).resolve()
return str(resolved.relative_to(self._base_path))

View File

@@ -0,0 +1,144 @@
"""
命名规则解析器
将包含占位符的规则模板转换为实际的文件名/目录路径。
支持的占位符:
- {date}: 当前日期 YYYY-MM-DD
- {timestamp}: Unix 时间戳
- {year}: 年份 YYYY
- {month}: 月份 MM
- {day}: 日期 DD
- {hour}: 小时 HH
- {minute}: 分钟 MM
- {randomkey16}: 16位随机字符串
- {originname}: 原始文件名(不含扩展名)
- {ext}: 文件扩展名(不含点)
- {uid}: 用户UUID
- {uuid}: 新生成的UUID
"""
import re
import secrets
import string
from datetime import datetime
from uuid import UUID, uuid4
from models.base import SQLModelBase
class NamingContext(SQLModelBase):
"""
命名上下文
包含生成文件名/目录名所需的所有信息。
"""
user_id: UUID
"""用户UUID"""
original_filename: str
"""原始文件名(包含扩展名)"""
timestamp: datetime | None = None
"""时间戳,默认为当前时间"""
class NamingRuleParser:
"""
命名规则解析器
将包含占位符的规则模板转换为实际的文件名/目录路径。
使用示例::
context = NamingContext(
user_id=UUID("..."),
original_filename="document.pdf",
)
dir_path = NamingRuleParser.parse("{date}/{randomkey16}", context)
# -> "2025-12-23/a1b2c3d4e5f6g7h8"
file_name = NamingRuleParser.parse("{randomkey16}_{originname}.{ext}", context)
# -> "x9y8z7w6v5u4t3s2_document.pdf"
"""
# 支持的占位符正则
_PLACEHOLDER_PATTERN = re.compile(r'\{(\w+)\}')
# 随机字符集
_RANDOM_CHARS = string.ascii_lowercase + string.digits
@classmethod
def parse(cls, rule: str, context: NamingContext) -> str:
"""
解析命名规则,替换所有占位符
:param rule: 命名规则模板,如 "{date}/{randomkey16}"
:param context: 命名上下文
:return: 解析后的实际路径/文件名
"""
timestamp = context.timestamp or datetime.now()
# 解析原始文件名
origin_name, ext = cls._split_filename(context.original_filename)
# 占位符替换映射
replacements: dict[str, str] = {
'date': timestamp.strftime('%Y-%m-%d'),
'timestamp': str(int(timestamp.timestamp())),
'year': timestamp.strftime('%Y'),
'month': timestamp.strftime('%m'),
'day': timestamp.strftime('%d'),
'hour': timestamp.strftime('%H'),
'minute': timestamp.strftime('%M'),
'randomkey16': cls._generate_random_key(16),
'originname': origin_name,
'ext': ext,
'uid': str(context.user_id),
'uuid': str(uuid4()),
}
def replace_placeholder(match: re.Match[str]) -> str:
placeholder = match.group(1)
return replacements.get(placeholder, match.group(0))
return cls._PLACEHOLDER_PATTERN.sub(replace_placeholder, rule)
@classmethod
def _split_filename(cls, filename: str) -> tuple[str, str]:
"""
分离文件名和扩展名
:param filename: 完整文件名
:return: (文件名不含扩展名, 扩展名不含点)
"""
if '.' in filename:
parts = filename.rsplit('.', 1)
return parts[0], parts[1]
return filename, ''
@classmethod
def _generate_random_key(cls, length: int) -> str:
"""
生成随机字符串
:param length: 字符串长度
:return: 随机字符串
"""
return ''.join(secrets.choice(cls._RANDOM_CHARS) for _ in range(length))
@classmethod
def validate_rule(cls, rule: str) -> bool:
"""
验证命名规则是否有效
:param rule: 命名规则模板
:return: 是否有效
"""
valid_placeholders = {
'date', 'timestamp', 'year', 'month', 'day', 'hour', 'minute',
'randomkey16', 'originname', 'ext', 'uid', 'uuid',
}
placeholders = cls._PLACEHOLDER_PATTERN.findall(rule)
return all(p in valid_placeholders for p in placeholders)

11
uv.lock generated
View File

@@ -6,6 +6,15 @@ resolution-markers = [
"python_full_version < '3.14'", "python_full_version < '3.14'",
] ]
[[package]]
name = "aiofiles"
version = "25.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/41/c3/534eac40372d8ee36ef40df62ec129bee4fdb5ad9706e58a29be53b2c970/aiofiles-25.1.0.tar.gz", hash = "sha256:a8d728f0a29de45dc521f18f07297428d56992a742f0cd2701ba86e44d23d5b2", size = 46354, upload-time = "2025-10-09T20:51:04.358Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/8a/340a1555ae33d7354dbca4faa54948d76d89a27ceef032c8c3bc661d003e/aiofiles-25.1.0-py3-none-any.whl", hash = "sha256:abe311e527c862958650f9438e859c1fa7568a141b22abcd015e120e86a85695", size = 14668, upload-time = "2025-10-09T20:51:03.174Z" },
]
[[package]] [[package]]
name = "aiohappyeyeballs" name = "aiohappyeyeballs"
version = "2.6.1" version = "2.6.1"
@@ -421,6 +430,7 @@ name = "disknext-server"
version = "0.0.1" version = "0.0.1"
source = { virtual = "." } source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "aiofiles" },
{ name = "aiohttp" }, { name = "aiohttp" },
{ name = "aiosqlite" }, { name = "aiosqlite" },
{ name = "argon2-cffi" }, { name = "argon2-cffi" },
@@ -444,6 +454,7 @@ dependencies = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "aiofiles", specifier = ">=25.1.0" },
{ name = "aiohttp", specifier = ">=3.13.2" }, { name = "aiohttp", specifier = ">=3.13.2" },
{ name = "aiosqlite", specifier = ">=0.21.0" }, { name = "aiosqlite", specifier = ">=0.21.0" },
{ name = "argon2-cffi", specifier = ">=25.1.0" }, { name = "argon2-cffi", specifier = ">=25.1.0" },