Compare commits

..

3 Commits

Author SHA1 Message Date
1ecc0fdc1c feat: implement source link endpoints and enforce policy rules
All checks were successful
Test / test (push) Successful in 1m56s
- Add POST/GET source link endpoints for file sharing via permanent URLs
- Enforce max_size check in PATCH /file/content to prevent size limit bypass
- Support is_private (proxy) vs public (302 redirect) storage modes
- Replace all ResponseBase(data=...) with proper DTOs or 204 responses
- Add 18 integration tests for source link and policy rule enforcement

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:07:20 +08:00
71883d32c0 feat: add PATCH /user/settings/password endpoint for changing password
All checks were successful
Test / test (push) Successful in 1m43s
Register the fixed /password route before the wildcard /{option} to
prevent FastAPI from matching it as a path parameter.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 15:11:56 +08:00
ccadfe57cd feat: migrate ORM base to sqlmodel-ext, add file viewers and WOPI integration
All checks were successful
Test / test (push) Successful in 1m45s
- Migrate SQLModel base classes, mixins, and database management to
  external sqlmodel-ext package; remove sqlmodels/base/, sqlmodels/mixin/,
  and sqlmodels/database.py
- Add file viewer/editor system with WOPI protocol support for
  collaborative editing (OnlyOffice, Collabora)
- Add enterprise edition license verification module (ee/)
- Add Dockerfile multi-stage build with Cython compilation support
- Add new dependencies: sqlmodel-ext, cryptography, whatthepatch

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 14:23:17 +08:00
87 changed files with 1676 additions and 6743 deletions

View File

@@ -5,8 +5,7 @@
"Bash(findstr:*)",
"Bash(find:*)",
"Bash(yarn tsc:*)",
"Bash(dir:*)",
"mcp__server-notify__notify"
"Bash(dir:*)"
]
}
}

5
.gitignore vendored
View File

@@ -1,6 +1,8 @@
# Python
__pycache__/
*.py[cod]
*.pyo
*.pyd
*.so
*.egg
*.egg-info/
@@ -77,6 +79,3 @@ statics/
# 许可证密钥(保密)
license_private.pem
license.key
avatar/
.dev/

3
.gitmodules vendored
View File

@@ -1,3 +0,0 @@
[submodule "ee"]
path = ee
url = https://git.yxqi.cn/Yuerchu/disknext-ee.git

1
ee

Submodule ee deleted from cc32d8db91

42
ee/__init__.py Normal file
View File

@@ -0,0 +1,42 @@
"""
DiskNext Enterprise Edition (EE) 模块
通过 ``try: from ee import init_ee`` 检测是否存在。
CE 版本中此目录不存在ImportError 被 main.py 捕获。
"""
from loguru import logger as l
from utils.conf import appmeta
_ee_initialized: bool = False
def is_pro() -> bool:
"""当前实例是否以 Pro 版本运行。"""
return _ee_initialized
async def init_ee() -> None:
"""
初始化企业版功能。
1. 加载并验证许可证
2. 设置 appmeta.IsPro = True
3. 标记 EE 已初始化
许可证无效或缺失时抛出异常,阻止应用启动。
"""
global _ee_initialized
from ee.service.license_service import load_and_validate_license
payload = await load_and_validate_license()
appmeta.IsPro = True
_ee_initialized = True
l.info(
f"Pro 版本已激活 — 域名: {payload.domain}, "
f"过期: {payload.expires_at.isoformat()}, "
f"功能: {payload.features}"
)

86
ee/license.py Normal file
View File

@@ -0,0 +1,86 @@
"""
RSA-PSS 许可证验签核心(编译为 .so 后公钥藏入二进制)
此文件只包含纯函数和常量,不包含 SQLModel 类。
"""
import base64
from datetime import datetime, timezone
import orjson
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
_PUBLIC_KEY_PEM: bytes = b"""-----BEGIN PUBLIC KEY-----
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAyNltXQ/Nuechx3kjj3T5
oR6pZvTmpsDowqqxXJy7FXUI8d7XprhV+HrBQPsrT/Ngo9FwW3XyiK10m1WrzpGW
eaf9990Z5Z2naEn5TzGrh71p/D7mZcNGVumo9uAuhtNEemm6xB3FoyGYZj7X0cwA
VDvIiKAwYyRJX2LqVh1/tZM6tTO3oaGZXRMZzCNUPFSo4ZZudU3Boa5oQg08evu4
vaOqeFrMX47R3MSUmO9hOh+NS53XNqO0f0zw5sv95CtyR5qvJ4gpkgYaRCSQFd19
TnHU5saFVrH9jdADz1tdkMYcyYE+uJActZBapxCHSYB2tSCKWjDxeUFl/oY/ZFtY
l4MNz1ovkjNhpmR3g+I5fbvN0cxDIjnZ9vJ84ozGqTGT9s1jHaLbpLri/vhuT4F2
7kifXk8ImwtMZpZvzhmucH9/5VgcWKNuMATzEMif+YjFpuOGx8gc1XL1W/3q+dH0
EFESp+/knjcVIfwpAkIKyV7XvDgFHsif1SeI0zZMW4utowVvGocP1ZzK5BGNTk2z
CEtQDO7Rqo+UDckOJSG66VW3c2QO8o6uuy6fzx7q0MFEmUMwGf2iMVtR/KnXe99C
enOT0BpU1EQvqssErUqivDss7jm98iD8M/TCE7pFboqZ+SC9G+QAqNIQNFWh8bWA
R9hyXM/x5ysHd6MC4eEQnhMCAwEAAQ==
-----END PUBLIC KEY-----"""
class LicenseError(Exception):
"""许可证验证基础异常"""
class LicenseExpiredError(LicenseError):
"""许可证已过期"""
def verify_license(raw: str) -> dict:
"""
验证许可证字符串并返回载荷字典。
:param raw: 格式为 ``base64(json_payload).base64(signature)``
:returns: 解析后的载荷字典
:raises LicenseError: 格式无效或签名验证失败
:raises LicenseExpiredError: 许可证已过期
"""
parts = raw.strip().split(".")
if len(parts) != 2:
raise LicenseError("许可证格式无效:需要 payload.signature")
payload_b64, signature_b64 = parts
try:
payload_bytes = base64.urlsafe_b64decode(payload_b64)
signature = base64.urlsafe_b64decode(signature_b64)
except Exception as exc:
raise LicenseError(f"许可证 base64 解码失败: {exc}") from exc
public_key = serialization.load_pem_public_key(_PUBLIC_KEY_PEM)
try:
public_key.verify( # type: ignore[union-attr]
signature,
payload_bytes,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
hashes.SHA256(),
)
except Exception as exc:
raise LicenseError(f"许可证签名验证失败: {exc}") from exc
data: dict = orjson.loads(payload_bytes)
expires_at_str: str | None = data.get('expires_at')
if not expires_at_str:
raise LicenseError("许可证缺少 expires_at 字段")
expires_at = datetime.fromisoformat(expires_at_str)
if expires_at < datetime.now(timezone.utc):
raise LicenseExpiredError(
f"许可证已过期: {expires_at.isoformat()}"
)
return data

6
ee/routers/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
from fastapi import APIRouter
from .pro import router as pro_router
ee_router = APIRouter()
ee_router.include_router(pro_router)

View File

@@ -0,0 +1,69 @@
"""
Pro 版本状态端点
提供许可证状态查询,需管理员权限。
"""
from datetime import datetime
from fastapi import APIRouter, Depends
from ee.service import LicensePayload
from ee.service.license_service import get_cached_license
from middleware.auth import admin_required
from sqlmodel_ext import SQLModelBase
router = APIRouter(prefix="/pro")
class ProStatusResponse(SQLModelBase):
"""Pro 版本状态响应"""
is_active: bool
"""许可证是否有效"""
domain: str
"""授权域名"""
expires_at: datetime
"""过期时间"""
max_users: int
"""最大用户数0 = 无限制)"""
features: list[str]
"""已授权的功能列表"""
@router.get(
'/status',
dependencies=[Depends(admin_required)],
)
async def get_pro_status() -> ProStatusResponse:
"""
查询 Pro 版本许可证状态
认证:
- JWT token in Authorization header
- 需要管理员权限
响应:
- ProStatusResponse: 当前许可证信息
错误处理:
- HTTPException 401: 未授权
- HTTPException 403: 非管理员
- HTTPException 500: 许可证缓存异常
"""
payload: LicensePayload | None = get_cached_license()
if not payload:
# init_ee 成功后 payload 一定存在,此处做防御性编程
from utils.http.http_exceptions import raise_internal_error
raise_internal_error()
return ProStatusResponse(
is_active=True,
domain=payload.domain,
expires_at=payload.expires_at,
max_users=payload.max_users,
features=payload.features,
)

30
ee/service/__init__.py Normal file
View File

@@ -0,0 +1,30 @@
"""
EE 许可证服务模块
提供许可证加载、验证和缓存功能。
"""
from datetime import datetime
from sqlmodel_ext import SQLModelBase
class LicensePayload(SQLModelBase):
"""许可证载荷RSA 验签后的明文数据)"""
domain: str
"""授权域名"""
expires_at: datetime
"""过期时间UTC"""
max_users: int
"""最大用户数0 = 无限制)"""
features: list[str]
"""已授权的功能列表"""
issued_at: datetime
"""签发时间UTC"""
from .license_service import get_cached_license, load_and_validate_license

View File

@@ -0,0 +1,49 @@
"""
许可证加载与缓存服务(编译为 .so
从环境变量 LICENSE_KEY 或 license.key 文件加载许可证,
调用 verify_license() 验证后缓存结果。
"""
import os
from pathlib import Path
import aiofiles
from ee.license import LicenseError, verify_license
from ee.service import LicensePayload
_cached_payload: LicensePayload | None = None
async def load_and_validate_license() -> LicensePayload:
"""
加载并验证许可证,成功后缓存。
加载优先级:
1. 环境变量 ``LICENSE_KEY``
2. 项目根目录 ``license.key`` 文件
:returns: 验证通过的 LicensePayload
:raises LicenseError: 未找到许可证 / 验证失败 / 已过期
"""
global _cached_payload
raw: str | None = os.getenv("LICENSE_KEY")
if not raw:
key_path = Path("license.key")
if key_path.is_file():
async with aiofiles.open(key_path, 'r') as f:
raw = (await f.read()).strip()
if not raw:
raise LicenseError("未找到许可证:请设置 LICENSE_KEY 环境变量或提供 license.key 文件")
data = verify_license(raw)
_cached_payload = LicensePayload.model_validate(data)
return _cached_payload
def get_cached_license() -> LicensePayload | None:
"""获取已缓存的许可证载荷(未加载时返回 None"""
return _cached_payload

5
ee/sqlmodels/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""
EE 版本数据库模型
后续 Pro 功能的 SQLModel 定义位置。
"""

31
main.py
View File

@@ -5,10 +5,7 @@ from fastapi import FastAPI, Request
from loguru import logger as l
from routers import router
from routers.dav import dav_app
from routers.dav.provider import EventLoopRef
from service.redis import RedisManager
from service.storage import S3StorageService
from sqlmodels.database_connection import DatabaseManager
from sqlmodels.migration import migration
from utils import JWT
@@ -17,26 +14,24 @@ from utils.http.http_exceptions import raise_internal_error
from utils.lifespan import lifespan
# 尝试加载企业版功能
_has_ee: bool = False
try:
from ee import init_ee
from ee.license import LicenseError
from ee.routers import ee_router
_has_ee = True
async def _init_ee() -> None:
"""启动时验证许可证,路由由 license_valid_required 依赖保护"""
async def _init_ee_and_routes() -> None:
try:
await init_ee()
except LicenseError as exc:
l.critical(f"许可证验证失败: {exc}")
raise SystemExit(1) from exc
lifespan.add_startup(_init_ee)
except ImportError as exc:
ee_router = None
l.info(f"以 Community 版本运行 (原因: {exc})")
from ee.routers import ee_router
from routers.api.v1 import router as v1_router
v1_router.include_router(ee_router)
lifespan.add_startup(_init_ee_and_routes)
except ImportError:
l.info("以 Community 版本运行")
STATICS_DIR: Path = (Path(__file__).parent / "statics").resolve()
"""前端静态文件目录(由 Docker 构建时复制)"""
@@ -45,18 +40,13 @@ async def _init_db() -> None:
"""初始化数据库连接引擎"""
await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug)
# 捕获事件循环引用(供 WSGI 线程桥接使用)
lifespan.add_startup(EventLoopRef.capture)
# 添加初始化数据库启动项
lifespan.add_startup(_init_db)
lifespan.add_startup(migration)
lifespan.add_startup(JWT.load_secret_key)
lifespan.add_startup(RedisManager.connect)
lifespan.add_startup(S3StorageService.initialize_session)
# 添加关闭项
lifespan.add_shutdown(S3StorageService.close_session)
lifespan.add_shutdown(DatabaseManager.close)
lifespan.add_shutdown(RedisManager.disconnect)
@@ -97,11 +87,6 @@ async def handle_unexpected_exceptions(
# 挂载路由
app.include_router(router)
if _has_ee:
app.include_router(ee_router, prefix="/api/v1")
# 挂载 WebDAV 协议端点(优先于 SPA catch-all
app.mount("/dav", dav_app)
# 挂载前端静态文件(仅当 statics/ 目录存在时,即 Docker 部署环境)
if STATICS_DIR.is_dir():

View File

@@ -33,8 +33,6 @@ dependencies = [
"uvicorn>=0.38.0",
"webauthn>=2.7.0",
"whatthepatch>=1.0.6",
"wsgidav>=4.3.0",
"a2wsgi>=1.10.0",
]
[project.optional-dependencies]

View File

@@ -5,7 +5,6 @@ from utils.conf import appmeta
from .admin import admin_router
from .callback import callback_router
from .category import category_router
from .directory import directory_router
from .download import download_router
from .file import router as file_router
@@ -15,6 +14,7 @@ from .trash import trash_router
from .site import site_router
from .slave import slave_router
from .user import user_router
from .vas import vas_router
from .webdav import webdav_router
router = APIRouter(prefix="/v1")
@@ -24,7 +24,6 @@ router = APIRouter(prefix="/v1")
if appmeta.mode == "master":
router.include_router(admin_router)
router.include_router(callback_router)
router.include_router(category_router)
router.include_router(directory_router)
router.include_router(download_router)
router.include_router(file_router)
@@ -33,6 +32,7 @@ if appmeta.mode == "master":
router.include_router(site_router)
router.include_router(trash_router)
router.include_router(user_router)
router.include_router(vas_router)
router.include_router(webdav_router)
elif appmeta.mode == "slave":
router.include_router(slave_router)

View File

@@ -16,12 +16,6 @@ from sqlmodels.setting import (
from sqlmodels.setting import SettingsType
from utils import http_exceptions
from utils.conf import appmeta
try:
from ee.service import get_cached_license
except ImportError:
get_cached_license = None
from .file import admin_file_router
from .file_app import admin_file_app_router
from .group import admin_group_router
@@ -30,6 +24,7 @@ from .share import admin_share_router
from .task import admin_task_router
from .user import admin_user_router
from .theme import admin_theme_router
from .vas import admin_vas_router
class Aria2TestRequest(SQLModelBase):
@@ -55,6 +50,7 @@ admin_router.include_router(admin_policy_router)
admin_router.include_router(admin_share_router)
admin_router.include_router(admin_task_router)
admin_router.include_router(admin_theme_router)
admin_router.include_router(admin_vas_router)
# 离线下载 /api/admin/aria2
admin_aria2_router = APIRouter(
@@ -163,24 +159,14 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
if site_url_setting and site_url_setting.value:
site_urls.append(site_url_setting.value)
# 许可证信息(Pro 版本从缓存读取CE 版本永不过期
if appmeta.IsPro and get_cached_license:
payload = get_cached_license()
license_info = LicenseInfo(
expired_at=payload.expires_at,
signed_at=payload.issued_at,
root_domains=[],
domains=[payload.domain],
vol_domains=[],
)
else:
license_info = LicenseInfo(
expired_at=datetime.max,
signed_at=now,
root_domains=[],
domains=[],
vol_domains=[],
)
# 许可证信息(从设置读取或使用默认值
license_info = LicenseInfo(
expired_at=now + timedelta(days=365),
signed_at=now,
root_domains=[],
domains=[],
vol_domains=[],
)
# 版本信息
version_info = VersionInfo(
@@ -239,11 +225,11 @@ async def router_admin_update_settings(
if existing:
existing.value = item.value
existing = await existing.save(session)
await existing.save(session)
updated_count += 1
else:
new_setting = Setting(type=item.type, name=item.name, value=item.value)
new_setting = await new_setting.save(session)
await new_setting.save(session)
created_count += 1
l.info(f"管理员更新了 {updated_count} 个设置项,新建了 {created_count} 个设置项")

View File

@@ -54,7 +54,7 @@ async def _set_ban_recursive(
obj.banned_by = None
obj.ban_reason = None
obj = await obj.save(session)
await obj.save(session)
count += 1
return count
@@ -131,7 +131,9 @@ async def router_admin_preview_file(
:param file_id: 文件UUID
:return: 文件内容
"""
file_obj = await Object.get_exist_one(session, file_id)
file_obj = await Object.get(session, Object.id == file_id)
if not file_obj:
raise HTTPException(status_code=404, detail="文件不存在")
if not file_obj.is_file:
raise HTTPException(status_code=400, detail="对象不是文件")
@@ -180,7 +182,9 @@ async def router_admin_ban_file(
:param claims: 当前管理员 JWT claims
:return: 封禁结果
"""
file_obj = await Object.get_exist_one(session, file_id)
file_obj = await Object.get(session, Object.id == file_id)
if not file_obj:
raise HTTPException(status_code=404, detail="文件不存在")
count = await _set_ban_recursive(session, file_obj, request.ban, claims.sub, request.reason)
@@ -208,7 +212,9 @@ async def router_admin_delete_file(
:param delete_physical: 是否同时删除物理文件
:return: 删除结果
"""
file_obj = await Object.get_exist_one(session, file_id)
file_obj = await Object.get(session, Object.id == file_id)
if not file_obj:
raise HTTPException(status_code=404, detail="文件不存在")
if not file_obj.is_file:
raise HTTPException(status_code=400, detail="对象不是文件")

View File

@@ -1,18 +1,16 @@
"""
管理员文件应用管理端点
提供文件查看器应用的 CRUD、扩展名管理用户组权限管理和 WOPI Discovery
提供文件查看器应用的 CRUD、扩展名管理用户组权限管理。
"""
from uuid import UUID
import aiohttp
from fastapi import APIRouter, Depends, status
from loguru import logger as l
from sqlalchemy import select
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from service.wopi import parse_wopi_discovery_xml
from sqlmodels import (
FileApp,
FileAppCreateRequest,
@@ -23,10 +21,7 @@ from sqlmodels import (
FileAppUpdateRequest,
ExtensionUpdateRequest,
GroupAccessUpdateRequest,
WopiDiscoveredExtension,
WopiDiscoveryResponse,
)
from sqlmodels.file_app import FileAppType
from utils import http_exceptions
admin_file_app_router = APIRouter(
@@ -128,7 +123,6 @@ async def create_file_app(
group_links.append(link)
if group_links:
await session.commit()
await session.refresh(app)
l.info(f"创建文件应用: {app.name} ({app.app_key})")
@@ -151,7 +145,9 @@ async def get_file_app(
错误处理:
- 404: 应用不存在
"""
app = await FileApp.get_exist_one(session, app_id)
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
if not app:
http_exceptions.raise_not_found("应用不存在")
extensions = await FileAppExtension.get(
session,
@@ -184,7 +180,9 @@ async def update_file_app(
- 404: 应用不存在
- 409: 新 app_key 已被其他应用使用
"""
app = await FileApp.get_exist_one(session, app_id)
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
if not app:
http_exceptions.raise_not_found("应用不存在")
# 检查 app_key 唯一性
if request.app_key is not None and request.app_key != app.app_key:
@@ -231,7 +229,9 @@ async def delete_file_app(
错误处理:
- 404: 应用不存在
"""
app = await FileApp.get_exist_one(session, app_id)
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
if not app:
http_exceptions.raise_not_found("应用不存在")
app_name = app.app_key
await FileApp.delete(session, app)
@@ -257,24 +257,20 @@ async def update_extensions(
错误处理:
- 404: 应用不存在
"""
app = await FileApp.get_exist_one(session, app_id)
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
if not app:
http_exceptions.raise_not_found("应用不存在")
# 保留旧扩展名的 wopi_action_urlDiscovery 填充的值)
# 删除旧的扩展名
old_extensions: list[FileAppExtension] = await FileAppExtension.get(
session,
FileAppExtension.app_id == app_id,
fetch_mode="all",
)
old_url_map: dict[str, str] = {
ext.extension: ext.wopi_action_url
for ext in old_extensions
if ext.wopi_action_url
}
for old_ext in old_extensions:
await FileAppExtension.delete(session, old_ext, commit=False)
await session.flush()
# 创建新的扩展名(保留已有的 wopi_action_url
# 创建新的扩展名
new_extensions: list[FileAppExtension] = []
for i, ext in enumerate(request.extensions):
normalized = ext.lower().strip().lstrip('.')
@@ -282,14 +278,12 @@ async def update_extensions(
app_id=app_id,
extension=normalized,
priority=i,
wopi_action_url=old_url_map.get(normalized),
)
session.add(ext_record)
new_extensions.append(ext_record)
await session.commit()
# refresh commit 后过期的对象
await session.refresh(app)
# refresh 新创建的记录
for ext_record in new_extensions:
await session.refresh(ext_record)
@@ -322,7 +316,9 @@ async def update_group_access(
错误处理:
- 404: 应用不存在
"""
app = await FileApp.get_exist_one(session, app_id)
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
if not app:
http_exceptions.raise_not_found("应用不存在")
# 删除旧的用户组关联
old_links_result = await session.exec(
@@ -340,7 +336,6 @@ async def update_group_access(
new_links.append(link)
await session.commit()
await session.refresh(app)
extensions = await FileAppExtension.get(
session,
@@ -351,100 +346,3 @@ async def update_group_access(
l.info(f"更新文件应用 {app.app_key} 的用户组权限: {request.group_ids}")
return FileAppResponse.from_app(app, extensions, new_links)
@admin_file_app_router.post(
path='/{app_id}/discover',
summary='执行 WOPI Discovery',
)
async def discover_wopi(
session: SessionDep,
app_id: UUID,
) -> WopiDiscoveryResponse:
"""
从 WOPI 服务端获取 Discovery XML 并自动配置扩展名和 URL 模板。
流程:
1. 验证 FileApp 存在且为 WOPI 类型
2. 使用 FileApp.wopi_discovery_url 获取 Discovery XML
3. 解析 XML提取扩展名和动作 URL
4. 全量替换 FileAppExtension 记录(带 wopi_action_url
认证:管理员权限
错误处理:
- 404: 应用不存在
- 400: 非 WOPI 类型 / discovery URL 未配置 / XML 解析失败
- 502: WOPI 服务端不可达或返回无效响应
"""
app = await FileApp.get_exist_one(session, app_id)
if app.type != FileAppType.WOPI:
http_exceptions.raise_bad_request("仅 WOPI 类型应用支持自动发现")
if not app.wopi_discovery_url:
http_exceptions.raise_bad_request("未配置 WOPI Discovery URL")
# commit 后对象会过期,先保存需要的值
discovery_url = app.wopi_discovery_url
app_key = app.app_key
# 获取 Discovery XML
try:
async with aiohttp.ClientSession() as client:
async with client.get(
discovery_url,
timeout=aiohttp.ClientTimeout(total=15),
) as resp:
if resp.status != 200:
http_exceptions.raise_bad_gateway(
f"WOPI 服务端返回 HTTP {resp.status}"
)
xml_content = await resp.text()
except aiohttp.ClientError as e:
http_exceptions.raise_bad_gateway(f"无法连接 WOPI 服务端: {e}")
# 解析 XML
try:
action_urls, app_names = parse_wopi_discovery_xml(xml_content)
except ValueError as e:
http_exceptions.raise_bad_request(str(e))
if not action_urls:
return WopiDiscoveryResponse(app_names=app_names)
# 全量替换扩展名
old_extensions: list[FileAppExtension] = await FileAppExtension.get(
session,
FileAppExtension.app_id == app_id,
fetch_mode="all",
)
for old_ext in old_extensions:
await FileAppExtension.delete(session, old_ext, commit=False)
await session.flush()
new_extensions: list[FileAppExtension] = []
discovered: list[WopiDiscoveredExtension] = []
for i, (ext, action_url) in enumerate(sorted(action_urls.items())):
ext_record = FileAppExtension(
app_id=app_id,
extension=ext,
priority=i,
wopi_action_url=action_url,
)
session.add(ext_record)
new_extensions.append(ext_record)
discovered.append(WopiDiscoveredExtension(extension=ext, action_url=action_url))
await session.commit()
l.info(
f"WOPI Discovery 完成: 应用 {app_key}, "
f"发现 {len(discovered)} 个扩展名"
)
return WopiDiscoveryResponse(
discovered_extensions=discovered,
app_names=app_names,
applied_count=len(discovered),
)

View File

@@ -63,7 +63,10 @@ async def router_admin_get_group(
:param group_id: 用户组UUID
:return: 用户组详情
"""
group = await Group.get_exist_one(session, group_id, load=[Group.options, Group.policies])
group = await Group.get(session, Group.id == group_id, load=[Group.options, Group.policies])
if not group:
raise HTTPException(status_code=404, detail="用户组不存在")
# 直接访问已加载的关系,无需额外查询
policies = group.policies
@@ -91,7 +94,9 @@ async def router_admin_get_group_members(
:return: 分页成员列表
"""
# 验证组存在
await Group.get_exist_one(session, group_id)
group = await Group.get(session, Group.id == group_id)
if not group:
raise HTTPException(status_code=404, detail="用户组不存在")
result = await User.get_with_count(session, User.group_id == group_id, table_view=table_view)
@@ -133,11 +138,10 @@ async def router_admin_create_group(
speed_limit=request.speed_limit,
)
group = await group.save(session)
group_id_val: UUID = group.id
# 创建选项
options = GroupOptions(
group_id=group_id_val,
group_id=group.id,
share_download=request.share_download,
share_free=request.share_free,
relocate=request.relocate,
@@ -150,11 +154,11 @@ async def router_admin_create_group(
aria2=request.aria2,
redirected_source=request.redirected_source,
)
options = await options.save(session)
await options.save(session)
# 关联存储策略
for policy_id in request.policy_ids:
link = GroupPolicyLink(group_id=group_id_val, policy_id=policy_id)
link = GroupPolicyLink(group_id=group.id, policy_id=policy_id)
session.add(link)
await session.commit()
@@ -181,7 +185,9 @@ async def router_admin_update_group(
:param request: 更新请求
:return: 更新结果
"""
group = await Group.get_exist_one(session, group_id, load=Group.options)
group = await Group.get(session, Group.id == group_id, load=Group.options)
if not group:
raise HTTPException(status_code=404, detail="用户组不存在")
# 检查名称唯一性(如果要更新名称)
if request.name and request.name != group.name:
@@ -211,7 +217,7 @@ async def router_admin_update_group(
if options_data:
for key, value in options_data.items():
setattr(group.options, key, value)
group.options = await group.options.save(session)
await group.options.save(session)
# 更新策略关联
if request.policy_ids is not None:
@@ -249,7 +255,9 @@ async def router_admin_delete_group(
:param group_id: 用户组UUID
:return: 删除结果
"""
group = await Group.get_exist_one(session, group_id)
group = await Group.get(session, Group.id == group_id)
if not group:
raise HTTPException(status_code=404, detail="用户组不存在")
# 检查是否有用户属于该组
user_count = await User.count(session, User.group_id == group_id)

View File

@@ -8,11 +8,11 @@ from sqlmodel import Field
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from sqlmodels import (
Policy, PolicyCreateRequest, PolicyOptions, PolicyType, PolicySummary,
PolicyUpdateRequest, ResponseBase, ListResponse, Object,
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
ListResponse, Object,
)
from sqlmodel_ext import SQLModelBase
from service.storage import DirectoryCreationError, LocalStorageService, S3StorageService
from service.storage import DirectoryCreationError, LocalStorageService
admin_policy_router = APIRouter(
prefix='/policy',
@@ -67,12 +67,6 @@ class PolicyDetailResponse(SQLModelBase):
base_url: str | None
"""基础URL"""
access_key: str | None
"""Access Key"""
secret_key: str | None
"""Secret Key"""
max_size: int
"""最大文件尺寸"""
@@ -113,45 +107,9 @@ class PolicyTestSlaveRequest(SQLModelBase):
secret: str
"""从机通信密钥"""
class PolicyTestS3Request(SQLModelBase):
"""测试 S3 连接请求 DTO"""
server: str = Field(max_length=255)
"""S3 端点地址"""
bucket_name: str = Field(max_length=255)
"""存储桶名称"""
access_key: str
"""Access Key"""
secret_key: str
"""Secret Key"""
s3_region: str = Field(default='us-east-1', max_length=64)
"""S3 区域"""
s3_path_style: bool = False
"""是否使用路径风格"""
class PolicyTestS3Response(SQLModelBase):
"""S3 连接测试响应"""
is_connected: bool
"""连接是否成功"""
message: str
"""测试结果消息"""
# ==================== Options 字段集合(用于分离 Policy 与 Options 字段) ====================
_OPTIONS_FIELDS: set[str] = {
'token', 'file_type', 'mimetype', 'od_redirect',
'chunk_size', 's3_path_style', 's3_region',
}
class PolicyCreateRequest(PolicyBase):
"""创建存储策略请求 DTO继承 PolicyBase 中的所有字段"""
pass
@admin_policy_router.get(
path='/list',
@@ -319,20 +277,7 @@ async def router_policy_add_policy(
raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}")
# 保存到数据库
policy = await policy.save(session)
# 创建策略选项
options = PolicyOptions(
policy_id=policy.id,
token=request.token,
file_type=request.file_type,
mimetype=request.mimetype,
od_redirect=request.od_redirect,
chunk_size=request.chunk_size,
s3_path_style=request.s3_path_style,
s3_region=request.s3_region,
)
options = await options.save(session)
await policy.save(session)
@admin_policy_router.post(
path='/cors',
@@ -383,7 +328,9 @@ async def router_policy_onddrive_oauth(
:param policy_id: 存储策略UUID
:return: OAuth URL
"""
policy = await Policy.get_exist_one(session, policy_id)
policy = await Policy.get(session, Policy.id == policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
# TODO: 实现OneDrive OAuth
raise HTTPException(status_code=501, detail="OneDrive OAuth暂未实现")
@@ -406,7 +353,9 @@ async def router_policy_get_policy(
:param policy_id: 存储策略UUID
:return: 策略详情
"""
policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)
policy = await Policy.get(session, Policy.id == policy_id, load=Policy.options)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
# 获取使用此策略的用户组
groups = await policy.awaitable_attrs.groups
@@ -422,8 +371,6 @@ async def router_policy_get_policy(
bucket_name=policy.bucket_name,
is_private=policy.is_private,
base_url=policy.base_url,
access_key=policy.access_key,
secret_key=policy.secret_key,
max_size=policy.max_size,
auto_rename=policy.auto_rename,
dir_name_rule=policy.dir_name_rule,
@@ -455,7 +402,9 @@ async def router_policy_delete_policy(
:param policy_id: 存储策略UUID
:return: 删除结果
"""
policy = await Policy.get_exist_one(session, policy_id)
policy = await Policy.get(session, Policy.id == policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
# 检查是否有文件使用此策略
file_count = await Object.count(session, Object.policy_id == policy_id)
@@ -468,106 +417,4 @@ async def router_policy_delete_policy(
policy_name = policy.name
await Policy.delete(session, policy)
l.info(f"管理员删除了存储策略: {policy_name}")
@admin_policy_router.patch(
path='/{policy_id}',
summary='更新存储策略',
description='更新存储策略配置。策略类型创建后不可更改。',
dependencies=[Depends(admin_required)],
status_code=204,
)
async def router_policy_update_policy(
session: SessionDep,
policy_id: UUID,
request: PolicyUpdateRequest,
) -> None:
"""
更新存储策略端点
功能:
- 更新策略基础字段和扩展选项
- 策略类型type不可更改
认证:
- 需要管理员权限
:param session: 数据库会话
:param policy_id: 存储策略UUID
:param request: 更新请求
"""
policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)
# 检查名称唯一性(如果要更新名称)
if request.name and request.name != policy.name:
existing = await Policy.get(session, Policy.name == request.name)
if existing:
raise HTTPException(status_code=409, detail="策略名称已存在")
# 分离 Policy 字段和 Options 字段
all_data = request.model_dump(exclude_unset=True)
policy_data = {k: v for k, v in all_data.items() if k not in _OPTIONS_FIELDS}
options_data = {k: v for k, v in all_data.items() if k in _OPTIONS_FIELDS}
# 更新 Policy 基础字段
if policy_data:
for key, value in policy_data.items():
setattr(policy, key, value)
policy = await policy.save(session)
# 更新或创建 PolicyOptions
if options_data:
if policy.options:
for key, value in options_data.items():
setattr(policy.options, key, value)
policy.options = await policy.options.save(session)
else:
options = PolicyOptions(policy_id=policy.id, **options_data)
options = await options.save(session)
l.info(f"管理员更新了存储策略: {policy_id}")
@admin_policy_router.post(
path='/test/s3',
summary='测试 S3 连接',
description='测试 S3 存储端点的连通性和凭据有效性。',
dependencies=[Depends(admin_required)],
)
async def router_policy_test_s3(
request: PolicyTestS3Request,
) -> PolicyTestS3Response:
"""
测试 S3 连接端点
通过向 S3 端点发送 HEAD Bucket 请求,验证凭据和网络连通性。
:param request: 测试请求
:return: 测试结果
"""
from service.storage import S3APIError
# 构造临时 Policy 对象用于创建 S3StorageService
temp_policy = Policy(
name="__test__",
type=PolicyType.S3,
server=request.server,
bucket_name=request.bucket_name,
access_key=request.access_key,
secret_key=request.secret_key,
)
s3_service = S3StorageService(
temp_policy,
region=request.s3_region,
is_path_style=request.s3_path_style,
)
try:
# 使用 file_exists 发送 HEAD 请求来验证连通性
await s3_service.file_exists("__connection_test__")
return PolicyTestS3Response(is_connected=True, message="连接成功")
except S3APIError as e:
return PolicyTestS3Response(is_connected=False, message=f"S3 API 错误: {e}")
except Exception as e:
return PolicyTestS3Response(is_connected=False, message=f"连接失败: {e}")
l.info(f"管理员删除了存储策略: {policy_name}")

View File

@@ -155,7 +155,9 @@ async def router_admin_delete_share(
:param share_id: 分享ID
:return: 删除结果
"""
share = await Share.get_exist_one(session, share_id)
share = await Share.get(session, Share.id == share_id)
if not share:
raise HTTPException(status_code=404, detail="分享不存在")
await Share.delete(session, share)

View File

@@ -8,7 +8,7 @@ from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from sqlmodels import (
ListResponse,
Task, TaskSummary, TaskStatus, TaskType,
Task, TaskSummary,
)
from sqlmodel_ext import SQLModelBase
@@ -19,10 +19,10 @@ class TaskDetailResponse(SQLModelBase):
id: int
"""任务ID"""
status: TaskStatus
status: int
"""任务状态"""
type: TaskType
type: int
"""任务类型"""
progress: int
@@ -150,7 +150,9 @@ async def router_admin_delete_task(
:param task_id: 任务ID
:return: 删除结果
"""
task = await Task.get_exist_one(session, task_id)
task = await Task.get(session, Task.id == task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
await Task.delete(session, task)

View File

@@ -71,7 +71,7 @@ async def router_admin_theme_create(
name=request.name,
**request.colors.model_dump(),
)
preset = await preset.save(session)
await preset.save(session)
l.info(f"管理员创建了主题预设: {request.name}")
@@ -101,7 +101,11 @@ async def router_admin_theme_update(
- 404: 预设不存在
- 409: 名称已被其他预设使用
"""
preset = await ThemePreset.get_exist_one(session, preset_id)
preset: ThemePreset | None = await ThemePreset.get(
session, ThemePreset.id == preset_id
)
if not preset:
http_exceptions.raise_not_found("主题预设不存在")
# 检查名称唯一性(排除自身)
if request.name is not None and request.name != preset.name:
@@ -116,7 +120,7 @@ async def router_admin_theme_update(
for key, value in color_data.items():
setattr(preset, key, value)
preset = await preset.save(session)
await preset.save(session)
l.info(f"管理员更新了主题预设: {preset.name}")
@@ -143,7 +147,11 @@ async def router_admin_theme_delete(
副作用:
- 关联用户的 theme_preset_id 会被数据库 SET NULL
"""
preset = await ThemePreset.get_exist_one(session, preset_id)
preset: ThemePreset | None = await ThemePreset.get(
session, ThemePreset.id == preset_id
)
if not preset:
http_exceptions.raise_not_found("主题预设不存在")
await preset.delete(session)
l.info(f"管理员删除了主题预设: {preset.name}")
@@ -172,7 +180,11 @@ async def router_admin_theme_set_default(
逻辑:
- 事务中先清除所有旧默认,再设新默认
"""
preset = await ThemePreset.get_exist_one(session, preset_id)
preset: ThemePreset | None = await ThemePreset.get(
session, ThemePreset.id == preset_id
)
if not preset:
http_exceptions.raise_not_found("主题预设不存在")
# 清除所有旧默认
await session.execute(
@@ -183,5 +195,5 @@ async def router_admin_theme_set_default(
# 设新默认
preset.is_default = True
preset = await preset.save(session)
await preset.save(session)
l.info(f"管理员将主题预设 '{preset.name}' 设为默认")

View File

@@ -128,9 +128,8 @@ async def router_admin_create_user(
is_verified=True,
user_id=user.id,
)
identity = await identity.save(session)
await identity.save(session)
user = await User.get(session, User.id == user.id, load=User.group)
return user.to_public()
@@ -154,7 +153,9 @@ async def router_admin_update_user(
:param request: 更新请求
:return: 更新结果
"""
user = await User.get_exist_one(session, user_id)
user = await User.get(session, User.id == user_id)
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
# 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
default_admin_setting = await Setting.get(
@@ -251,7 +252,9 @@ async def router_admin_calibrate_storage(
:param user_id: 用户UUID
:return: 校准结果
"""
user = await User.get_exist_one(session, user_id)
user = await User.get(session, User.id == user_id)
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
previous_storage = user.storage

View File

@@ -0,0 +1,81 @@
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from middleware.auth import admin_required
from middleware.dependencies import SessionDep
from sqlmodels import (
ResponseBase,
)
admin_vas_router = APIRouter(
prefix='/vas',
tags=['admin', 'admin_vas']
)
@admin_vas_router.get(
path='/list',
summary='获取增值服务列表',
description='Get VAS list (orders and storage packs)',
dependencies=[Depends(admin_required)]
)
async def router_admin_get_vas_list(
session: SessionDep,
user_id: UUID | None = None,
page: int = 1,
page_size: int = 20,
) -> ResponseBase:
"""
获取增值服务列表(订单和存储包)。
:param session: 数据库会话
:param user_id: 按用户筛选
:param page: 页码
:param page_size: 每页数量
:return: 增值服务列表
"""
# TODO: 实现增值服务列表
# 需要查询 Order 和 StoragePack 模型
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
@admin_vas_router.get(
path='/{vas_id}',
summary='获取增值服务详情',
description='Get VAS detail by ID',
dependencies=[Depends(admin_required)]
)
async def router_admin_get_vas(
session: SessionDep,
vas_id: UUID,
) -> ResponseBase:
"""
获取增值服务详情。
:param session: 数据库会话
:param vas_id: 增值服务UUID
:return: 增值服务详情
"""
# TODO: 实现增值服务详情
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
@admin_vas_router.delete(
path='/{vas_id}',
summary='删除增值服务',
description='Delete VAS by ID',
dependencies=[Depends(admin_required)]
)
async def router_admin_delete_vas(
session: SessionDep,
vas_id: UUID,
) -> ResponseBase:
"""
删除增值服务。
:param session: 数据库会话
:param vas_id: 增值服务UUID
:return: 删除结果
"""
# TODO: 实现增值服务删除
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")

View File

@@ -1,6 +1,5 @@
from fastapi import APIRouter, Query
from fastapi.responses import PlainTextResponse
from loguru import logger as l
from sqlmodels import ResponseBase
import service.oauth
@@ -16,12 +15,18 @@ oauth_router = APIRouter(
tags=["callback", "oauth"],
)
pay_router = APIRouter(
prefix='/callback/pay',
tags=["callback", "pay"],
)
upload_router = APIRouter(
prefix='/callback/upload',
tags=["callback", "upload"],
)
callback_router.include_router(oauth_router)
callback_router.include_router(pay_router)
callback_router.include_router(upload_router)
@oauth_router.post(
@@ -32,7 +37,7 @@ callback_router.include_router(upload_router)
def router_callback_qq() -> ResponseBase:
"""
Handle QQ OAuth callback and return user information.
Returns:
ResponseBase: A model containing the response data for the QQ OAuth callback.
"""
@@ -49,27 +54,101 @@ async def router_callback_github(
GitHub OAuth 回调处理
- 错误响应示例:
- {
'error': 'bad_verification_code',
'error_description': 'The code passed is incorrect or expired.',
'error': 'bad_verification_code',
'error_description': 'The code passed is incorrect or expired.',
'error_uri': 'https://docs.github.com/apps/managing-oauth-apps/troubleshooting-oauth-app-access-token-request-errors/#bad-verification-code'
}
Returns:
PlainTextResponse: A response containing the user information from GitHub.
"""
try:
access_token = await service.oauth.github.get_access_token(code)
# [TODO] 把access_token写数据库里
if not access_token:
return PlainTextResponse("GitHub 认证失败", status_code=400)
return PlainTextResponse("Failed to retrieve access token from GitHub.", status_code=400)
user_data = await service.oauth.github.get_user_info(access_token.access_token)
# [TODO] 把 access_token 和 user_data 写数据库,生成 JWT重定向到前端
l.info(f"GitHub OAuth 回调成功: user={user_data.user_data.login}")
return PlainTextResponse("认证成功,功能开发中", status_code=200)
# [TODO] 把user_data写数据库
return PlainTextResponse(f"User information processed successfully, code: {code}, user_data: {user_data.json_dump()}", status_code=200)
except Exception as e:
l.error(f"GitHub OAuth 回调异常: {e}")
return PlainTextResponse("认证过程中发生错误,请重试", status_code=500)
return PlainTextResponse(f"An error occurred: {str(e)}", status_code=500)
@pay_router.post(
path='/alipay',
summary='支付宝支付回调',
description='Handle Alipay payment callback and return payment status.',
)
def router_callback_alipay() -> ResponseBase:
"""
Handle Alipay payment callback and return payment status.
Returns:
ResponseBase: A model containing the response data for the Alipay payment callback.
"""
http_exceptions.raise_not_implemented()
@pay_router.post(
path='/wechat',
summary='微信支付回调',
description='Handle WeChat Pay payment callback and return payment status.',
)
def router_callback_wechat() -> ResponseBase:
"""
Handle WeChat Pay payment callback and return payment status.
Returns:
ResponseBase: A model containing the response data for the WeChat Pay payment callback.
"""
http_exceptions.raise_not_implemented()
@pay_router.post(
path='/stripe',
summary='Stripe支付回调',
description='Handle Stripe payment callback and return payment status.',
)
def router_callback_stripe() -> ResponseBase:
"""
Handle Stripe payment callback and return payment status.
Returns:
ResponseBase: A model containing the response data for the Stripe payment callback.
"""
http_exceptions.raise_not_implemented()
@pay_router.get(
path='/easypay',
summary='易支付回调',
description='Handle EasyPay payment callback and return payment status.',
)
def router_callback_easypay() -> PlainTextResponse:
"""
Handle EasyPay payment callback and return payment status.
Returns:
PlainTextResponse: A response containing the payment status for the EasyPay payment callback.
"""
http_exceptions.raise_not_implemented()
# return PlainTextResponse("success", status_code=200)
@pay_router.get(
path='/custom/{order_no}/{id}',
summary='自定义支付回调',
description='Handle custom payment callback and return payment status.',
)
def router_callback_custom(order_no: str, id: str) -> ResponseBase:
"""
Handle custom payment callback and return payment status.
Args:
order_no (str): The order number for the payment.
id (str): The ID associated with the payment.
Returns:
ResponseBase: A model containing the response data for the custom payment callback.
"""
http_exceptions.raise_not_implemented()
@upload_router.post(
path='/remote/{session_id}/{key}',
@@ -79,11 +158,11 @@ async def router_callback_github(
def router_callback_remote(session_id: str, key: str) -> ResponseBase:
"""
Handle remote upload callback and return upload status.
Args:
session_id (str): The session ID for the upload.
key (str): The key for the uploaded file.
Returns:
ResponseBase: A model containing the response data for the remote upload callback.
"""
@@ -97,15 +176,15 @@ def router_callback_remote(session_id: str, key: str) -> ResponseBase:
def router_callback_qiniu(session_id: str) -> ResponseBase:
"""
Handle Qiniu Cloud upload callback and return upload status.
Args:
session_id (str): The session ID for the upload.
Returns:
ResponseBase: A model containing the response data for the Qiniu Cloud upload callback.
"""
http_exceptions.raise_not_implemented()
@upload_router.post(
path='/tencent/{session_id}',
summary='腾讯云上传回调',
@@ -114,16 +193,16 @@ def router_callback_qiniu(session_id: str) -> ResponseBase:
def router_callback_tencent(session_id: str) -> ResponseBase:
"""
Handle Tencent Cloud upload callback and return upload status.
Args:
session_id (str): The session ID for the upload.
Returns:
ResponseBase: A model containing the response data for the Tencent Cloud upload callback.
"""
http_exceptions.raise_not_implemented()
@upload_router.post(
@upload_router.post(
path='/aliyun/{session_id}',
summary='阿里云上传回调',
description='Handle Aliyun upload callback and return upload status.',
@@ -131,16 +210,16 @@ def router_callback_tencent(session_id: str) -> ResponseBase:
def router_callback_aliyun(session_id: str) -> ResponseBase:
"""
Handle Aliyun upload callback and return upload status.
Args:
session_id (str): The session ID for the upload.
Returns:
ResponseBase: A model containing the response data for the Aliyun upload callback.
"""
http_exceptions.raise_not_implemented()
@upload_router.post(
@upload_router.post(
path='/upyun/{session_id}',
summary='又拍云上传回调',
description='Handle Upyun upload callback and return upload status.',
@@ -148,10 +227,10 @@ def router_callback_aliyun(session_id: str) -> ResponseBase:
def router_callback_upyun(session_id: str) -> ResponseBase:
"""
Handle Upyun upload callback and return upload status.
Args:
session_id (str): The session ID for the upload.
Returns:
ResponseBase: A model containing the response data for the Upyun upload callback.
"""
@@ -165,10 +244,10 @@ def router_callback_upyun(session_id: str) -> ResponseBase:
def router_callback_aws(session_id: str) -> ResponseBase:
"""
Handle AWS S3 upload callback and return upload status.
Args:
session_id (str): The session ID for the upload.
Returns:
ResponseBase: A model containing the response data for the AWS S3 upload callback.
"""
@@ -182,10 +261,10 @@ def router_callback_aws(session_id: str) -> ResponseBase:
def router_callback_onedrive_finish(session_id: str) -> ResponseBase:
"""
Handle OneDrive upload completion callback and return upload status.
Args:
session_id (str): The session ID for the upload.
Returns:
ResponseBase: A model containing the response data for the OneDrive upload completion callback.
"""
@@ -199,7 +278,7 @@ def router_callback_onedrive_finish(session_id: str) -> ResponseBase:
def router_callback_onedrive_auth() -> ResponseBase:
"""
Handle OneDrive authorization callback and return authorization status.
Returns:
ResponseBase: A model containing the response data for the OneDrive authorization callback.
"""
@@ -213,8 +292,8 @@ def router_callback_onedrive_auth() -> ResponseBase:
def router_callback_google_auth() -> ResponseBase:
"""
Handle Google OAuth completion callback and return authorization status.
Returns:
ResponseBase: A model containing the response data for the Google OAuth completion callback.
"""
http_exceptions.raise_not_implemented()
http_exceptions.raise_not_implemented()

View File

@@ -1,100 +0,0 @@
"""
文件分类筛选端点
按文件类型分类(图片/视频/音频/文档)查询用户的所有文件,
跨目录搜索,支持分页。扩展名映射从数据库 Setting 表读取。
"""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger as l
from middleware.auth import auth_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from sqlmodels import (
FileCategory,
ListResponse,
Object,
ObjectResponse,
ObjectType,
Setting,
SettingsType,
User,
)
category_router = APIRouter(
prefix="/category",
tags=["category"],
)
@category_router.get(
path="/{category}",
summary="按分类获取文件列表",
)
async def router_category_list(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
category: FileCategory,
table_view: TableViewRequestDep,
) -> ListResponse[ObjectResponse]:
"""
按文件类型分类查询用户的所有文件
跨所有目录搜索,返回分页结果。
扩展名配置从数据库 Setting 表读取type=file_category
认证:
- JWT token in Authorization header
路径参数:
- category: 文件分类image / video / audio / document
查询参数:
- offset: 分页偏移量默认0
- limit: 每页数量默认20最大100
- desc: 是否降序默认true
- order: 排序字段created_at / updated_at
响应:
- ListResponse[ObjectResponse]: 分页文件列表
错误处理:
- HTTPException 422: category 参数无效
- HTTPException 404: 该分类未配置扩展名
"""
# 从数据库读取该分类的扩展名配置
setting = await Setting.get(
session,
(Setting.type == SettingsType.FILE_CATEGORY) & (Setting.name == category.value),
)
if not setting or not setting.value:
raise HTTPException(status_code=404, detail=f"分类 {category.value} 未配置扩展名")
extensions = [ext.strip() for ext in setting.value.split(",") if ext.strip()]
if not extensions:
raise HTTPException(status_code=404, detail=f"分类 {category.value} 扩展名列表为空")
result = await Object.get_by_category(
session,
user.id,
extensions,
table_view=table_view,
)
items = [
ObjectResponse(
id=obj.id,
name=obj.name,
type=ObjectType.FILE,
size=obj.size,
mime_type=obj.mime_type,
thumb=False,
created_at=obj.created_at,
updated_at=obj.updated_at,
source_enabled=False,
)
for obj in result.items
]
return ListResponse(count=result.count, items=items)

View File

@@ -57,7 +57,7 @@ async def _get_directory_response(
policy_response = PolicyResponse(
id=policy.id,
name=policy.name,
type=policy.type,
type=policy.type.value,
max_size=policy.max_size,
)
@@ -189,14 +189,6 @@ async def router_directory_create(
raise HTTPException(status_code=409, detail="同名文件或目录已存在")
policy_id = request.policy_id if request.policy_id else parent.policy_id
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
if request.policy_id:
group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
if request.policy_id not in {p.id for p in group.policies}:
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
parent_id = parent.id # 在 save 前保存
new_folder = Object(
@@ -206,4 +198,4 @@ async def router_directory_create(
parent_id=parent_id,
policy_id=policy_id,
)
new_folder = await new_folder.save(session)
await new_folder.save(session)

View File

@@ -13,11 +13,9 @@ from datetime import datetime, timedelta
from typing import Annotated
from uuid import UUID
import orjson
import whatthepatch
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, RedirectResponse
from starlette.responses import Response
from loguru import logger as l
from sqlmodel_ext import SQLModelBase
from whatthepatch.exceptions import HunkApplyException
@@ -46,9 +44,7 @@ from sqlmodels import (
User,
WopiSessionResponse,
)
import orjson
from service.storage import LocalStorageService, S3StorageService, adjust_user_storage
from service.storage import LocalStorageService, adjust_user_storage
from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL
from utils.JWT.wopi_token import create_wopi_token
from utils import http_exceptions
@@ -184,14 +180,9 @@ async def create_upload_session(
# 确定存储策略
policy_id = request.policy_id or parent.policy_id
policy = await Policy.get_exist_one(session, policy_id)
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
if request.policy_id:
group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
if request.policy_id not in {p.id for p in group.policies}:
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
policy = await Policy.get(session, Policy.id == policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
# 验证文件大小限制
_check_policy_size_limit(policy, request.file_size)
@@ -219,7 +210,6 @@ async def create_upload_session(
# 生成存储路径
storage_path: str | None = None
s3_upload_id: str | None = None
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
dir_path, storage_name, full_path = await storage_service.generate_file_path(
@@ -227,25 +217,8 @@ async def create_upload_session(
original_filename=request.file_name,
)
storage_path = full_path
elif policy.type == PolicyType.S3:
s3_service = S3StorageService(
policy,
region=options.s3_region if options else 'us-east-1',
is_path_style=options.s3_path_style if options else False,
)
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
user_id=user.id,
original_filename=request.file_name,
)
# 多分片时创建 multipart upload
if total_chunks > 1:
s3_upload_id = await s3_service.create_multipart_upload(
storage_path, content_type='application/octet-stream',
)
# 预扣存储空间(与创建会话在同一事务中提交,防止并发绕过配额)
if request.file_size > 0:
await adjust_user_storage(session, user.id, request.file_size, commit=False)
else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
# 创建上传会话
upload_session = UploadSession(
@@ -254,7 +227,6 @@ async def create_upload_session(
chunk_size=chunk_size,
total_chunks=total_chunks,
storage_path=storage_path,
s3_upload_id=s3_upload_id,
expires_at=datetime.now() + timedelta(hours=24),
owner_id=user.id,
parent_id=request.parent_id,
@@ -330,38 +302,8 @@ async def upload_chunk(
content,
offset,
)
elif policy.type == PolicyType.S3:
if not upload_session.storage_path:
raise HTTPException(status_code=500, detail="存储路径丢失")
s3_service = await S3StorageService.from_policy(policy)
if upload_session.total_chunks == 1:
# 单分片:直接 PUT 上传
await s3_service.upload_file(upload_session.storage_path, content)
else:
# 多分片UploadPart
if not upload_session.s3_upload_id:
raise HTTPException(status_code=500, detail="S3 分片上传 ID 丢失")
etag = await s3_service.upload_part(
upload_session.storage_path,
upload_session.s3_upload_id,
chunk_index + 1, # S3 part number 从 1 开始
content,
)
# 追加 ETag 到 s3_part_etags
etags: list[list[int | str]] = orjson.loads(upload_session.s3_part_etags or '[]')
etags.append([chunk_index + 1, etag])
upload_session.s3_part_etags = orjson.dumps(etags).decode()
# 在 savecommit前缓存后续需要的属性commit 后 ORM 对象会过期)
policy_type = policy.type
s3_upload_id = upload_session.s3_upload_id
s3_part_etags = upload_session.s3_part_etags
s3_service_for_complete: S3StorageService | None = None
if policy_type == PolicyType.S3:
s3_service_for_complete = await S3StorageService.from_policy(policy)
else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
# 更新会话进度
upload_session.uploaded_chunks += 1
@@ -377,26 +319,12 @@ async def upload_chunk(
if is_complete:
# 保存 upload_session 属性commit 后会过期)
file_name = upload_session.file_name
file_size = upload_session.file_size
uploaded_size = upload_session.uploaded_size
storage_path = upload_session.storage_path
upload_session_id = upload_session.id
parent_id = upload_session.parent_id
policy_id = upload_session.policy_id
# S3 多分片上传完成:合并分片
if (
policy_type == PolicyType.S3
and s3_upload_id
and s3_part_etags
and s3_service_for_complete
):
parts_data: list[list[int | str]] = orjson.loads(s3_part_etags)
parts = [(int(pn), str(et)) for pn, et in parts_data]
await s3_service_for_complete.complete_multipart_upload(
storage_path, s3_upload_id, parts,
)
# 创建 PhysicalFile 记录
physical_file = PhysicalFile(
storage_path=storage_path,
@@ -427,10 +355,9 @@ async def upload_chunk(
commit=False
)
# 调整存储配额差值(创建会话时已预扣 file_size这里只补差
size_diff = uploaded_size - file_size
if size_diff != 0:
await adjust_user_storage(session, user_id, size_diff, commit=False)
# 更新用户存储配额
if uploaded_size > 0:
await adjust_user_storage(session, user_id, uploaded_size, commit=False)
# 统一提交所有更改
await session.commit()
@@ -463,25 +390,9 @@ async def delete_upload_session(
# 删除临时文件
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
if policy and upload_session.storage_path:
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
# 如果有分片上传,先取消
if upload_session.s3_upload_id:
await s3_service.abort_multipart_upload(
upload_session.storage_path, upload_session.s3_upload_id,
)
else:
# 单分片上传已完成的话,删除已上传的文件
if upload_session.uploaded_chunks > 0:
await s3_service.delete_file(upload_session.storage_path)
# 释放预扣的存储空间
if upload_session.file_size > 0:
await adjust_user_storage(session, user.id, -upload_session.file_size)
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
# 删除会话记录
await UploadSession.delete(session, upload_session)
@@ -511,22 +422,9 @@ async def clear_upload_sessions(
for upload_session in sessions:
# 删除临时文件
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
if policy and upload_session.storage_path:
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
if upload_session.s3_upload_id:
await s3_service.abort_multipart_upload(
upload_session.storage_path, upload_session.s3_upload_id,
)
elif upload_session.uploaded_chunks > 0:
await s3_service.delete_file(upload_session.storage_path)
# 释放预扣的存储空间
if upload_session.file_size > 0:
await adjust_user_storage(session, user.id, -upload_session.file_size)
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
await UploadSession.delete(session, upload_session)
deleted_count += 1
@@ -588,12 +486,11 @@ async def create_download_token_endpoint(
path='/{token}',
summary='下载文件',
description='使用下载令牌下载文件,令牌在有效期内可重复使用。',
response_model=None,
)
async def download_file(
session: SessionDep,
token: str,
) -> Response:
) -> FileResponse:
"""
下载文件端点
@@ -643,15 +540,8 @@ async def download_file(
filename=file_obj.name,
media_type="application/octet-stream",
)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
# 302 重定向到预签名 URL
presigned_url = s3_service.generate_presigned_url(
storage_path, method='GET', expires_in=3600, filename=file_obj.name,
)
return RedirectResponse(url=presigned_url, status_code=302)
else:
raise HTTPException(status_code=500, detail="不支持的存储类型")
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
# ==================== 包含子路由 ====================
@@ -709,7 +599,9 @@ async def create_empty_file(
# 确定存储策略
policy_id = request.policy_id or parent.policy_id
policy = await Policy.get_exist_one(session, policy_id)
policy = await Policy.get(session, Policy.id == policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
# 生成存储路径并创建空文件
storage_path: str | None = None
@@ -721,13 +613,8 @@ async def create_empty_file(
)
await storage_service.create_empty_file(full_path)
storage_path = full_path
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
user_id=user_id,
original_filename=request.name,
)
await s3_service.upload_file(storage_path, b'')
else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
# 创建 PhysicalFile 记录
physical_file = PhysicalFile(
@@ -808,7 +695,6 @@ async def create_wopi_session(
)
wopi_app: FileApp | None = None
matched_ext_record: FileAppExtension | None = None
for ext_record in ext_records:
app = ext_record.app
if app.type == FileAppType.WOPI and app.is_enabled:
@@ -824,20 +710,13 @@ async def create_wopi_session(
if not result.first():
continue
wopi_app = app
matched_ext_record = ext_record
break
if not wopi_app:
http_exceptions.raise_not_found("无可用的 WOPI 查看器")
# 优先使用 per-extension URLDiscovery 自动填充),回退到全局模板
editor_url_template: str | None = None
if matched_ext_record and matched_ext_record.wopi_action_url:
editor_url_template = matched_ext_record.wopi_action_url
if not editor_url_template:
editor_url_template = wopi_app.wopi_editor_url_template
if not editor_url_template:
http_exceptions.raise_bad_request("WOPI 应用未配置编辑器 URL 模板,请先执行 Discovery 或手动配置")
if not wopi_app.wopi_editor_url_template:
http_exceptions.raise_bad_request("WOPI 应用未配置编辑器 URL 模板")
# 获取站点 URL
site_url_setting: Setting | None = await Setting.get(
@@ -853,8 +732,12 @@ async def create_wopi_session(
# 构建 wopi_src
wopi_src = f"{site_url}/wopi/files/{file_id}"
# 构建 editor URL(只替换 wopi_srctoken 通过 POST 表单传递)
editor_url = editor_url_template.format(wopi_src=wopi_src)
# 构建 editor URL
editor_url = wopi_app.wopi_editor_url_template.format(
wopi_src=wopi_src,
access_token=token,
access_token_ttl=access_token_ttl,
)
return WopiSessionResponse(
wopi_src=wopi_src,
@@ -915,13 +798,12 @@ async def _validate_source_link(
path='/get/{file_id}/{name}',
summary='文件外链(直接输出文件数据)',
description='通过外链直接获取文件内容,公开访问无需认证。',
response_model=None,
)
async def file_get(
session: SessionDep,
file_id: UUID,
name: str,
) -> Response:
) -> FileResponse:
"""
文件外链端点(直接输出)
@@ -933,32 +815,25 @@ async def file_get(
"""
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
if policy.type != PolicyType.LOCAL:
http_exceptions.raise_not_implemented("S3 存储暂未实现")
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(physical_file.storage_path):
http_exceptions.raise_not_found("物理文件不存在")
# 缓存物理路径save 后对象属性会过期)
file_path = physical_file.storage_path
# 递增下载次数
link.downloads += 1
link = await link.save(session)
await link.save(session)
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(file_path):
http_exceptions.raise_not_found("物理文件不存在")
return FileResponse(
path=file_path,
filename=name,
media_type="application/octet-stream",
)
elif policy.type == PolicyType.S3:
# S3 外链直接输出302 重定向到预签名 URL
s3_service = await S3StorageService.from_policy(policy)
presigned_url = s3_service.generate_presigned_url(
file_path, method='GET', expires_in=3600, filename=name,
)
return RedirectResponse(url=presigned_url, status_code=302)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
return FileResponse(
path=file_path,
filename=name,
media_type="application/octet-stream",
)
@router.get(
@@ -971,7 +846,7 @@ async def file_source_redirect(
session: SessionDep,
file_id: UUID,
name: str,
) -> Response:
) -> FileResponse | RedirectResponse:
"""
文件外链端点(重定向/直接输出)
@@ -985,6 +860,13 @@ async def file_source_redirect(
"""
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
if policy.type != PolicyType.LOCAL:
http_exceptions.raise_not_implemented("S3 存储暂未实现")
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(physical_file.storage_path):
http_exceptions.raise_not_found("物理文件不存在")
# 缓存所有需要的值save 后对象属性会过期)
file_path = physical_file.storage_path
is_private = policy.is_private
@@ -992,38 +874,20 @@ async def file_source_redirect(
# 递增下载次数
link.downloads += 1
link = await link.save(session)
await link.save(session)
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(file_path):
http_exceptions.raise_not_found("物理文件不存在")
# 公有存储302 重定向到 base_url
if not is_private and base_url:
relative_path = storage_service.get_relative_path(file_path)
redirect_url = f"{base_url}/{relative_path}"
return RedirectResponse(url=redirect_url, status_code=302)
# 有存储302 重定向到 base_url
if not is_private and base_url:
relative_path = storage_service.get_relative_path(file_path)
redirect_url = f"{base_url}/{relative_path}"
return RedirectResponse(url=redirect_url, status_code=302)
# 私有存储或 base_url 为空:通过应用代理文件
return FileResponse(
path=file_path,
filename=name,
media_type="application/octet-stream",
)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
# 公有存储且有 base_url直接重定向到公开 URL
if not is_private and base_url:
redirect_url = f"{base_url.rstrip('/')}/{file_path}"
return RedirectResponse(url=redirect_url, status_code=302)
# 私有存储:生成预签名 URL 重定向
presigned_url = s3_service.generate_presigned_url(
file_path, method='GET', expires_in=3600, filename=name,
)
return RedirectResponse(url=presigned_url, status_code=302)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
# 有存储或 base_url 为空:通过应用代理文件
return FileResponse(
path=file_path,
filename=name,
media_type="application/octet-stream",
)
@router.put(
@@ -1077,15 +941,11 @@ async def file_content(
if not policy:
http_exceptions.raise_internal_error("存储策略不存在")
# 读取文件内容
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
raw_bytes = await storage_service.read_file(physical_file.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
raw_bytes = await s3_service.download_file(physical_file.storage_path)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
if policy.type != PolicyType.LOCAL:
http_exceptions.raise_not_implemented("S3 存储暂未实现")
storage_service = LocalStorageService(policy)
raw_bytes = await storage_service.read_file(physical_file.storage_path)
try:
content = raw_bytes.decode('utf-8')
@@ -1151,15 +1011,11 @@ async def patch_file_content(
if not policy:
http_exceptions.raise_internal_error("存储策略不存在")
# 读取文件内容
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
raw_bytes = await storage_service.read_file(storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
raw_bytes = await s3_service.download_file(storage_path)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
if policy.type != PolicyType.LOCAL:
http_exceptions.raise_not_implemented("S3 存储暂未实现")
storage_service = LocalStorageService(policy)
raw_bytes = await storage_service.read_file(storage_path)
# 解码 + 规范化
original_text = raw_bytes.decode('utf-8')
@@ -1193,10 +1049,7 @@ async def patch_file_content(
_check_policy_size_limit(policy, len(new_bytes))
# 写入文件
if policy.type == PolicyType.LOCAL:
await storage_service.write_file(storage_path, new_bytes)
elif policy.type == PolicyType.S3:
await s3_service.upload_file(storage_path, new_bytes)
await storage_service.write_file(storage_path, new_bytes)
# 更新数据库
owner_id = file_obj.owner_id

View File

@@ -8,14 +8,13 @@
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger as l
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from sqlmodels import (
CreateFileRequest,
Group,
Object,
ObjectCopyRequest,
ObjectDeleteRequest,
@@ -23,42 +22,24 @@ from sqlmodels import (
ObjectPropertyDetailResponse,
ObjectPropertyResponse,
ObjectRenameRequest,
ObjectSwitchPolicyRequest,
ObjectType,
PhysicalFile,
Policy,
PolicyType,
Task,
TaskProps,
TaskStatus,
TaskSummaryBase,
TaskType,
User,
# 元数据相关
ObjectMetadata,
MetadataResponse,
MetadataPatchRequest,
INTERNAL_NAMESPACES,
USER_WRITABLE_NAMESPACES,
)
from service.storage import (
LocalStorageService,
adjust_user_storage,
copy_object_recursive,
migrate_file_with_task,
migrate_directory_files,
)
from service.storage.object import soft_delete_objects
from sqlmodels.database_connection import DatabaseManager
from utils import http_exceptions
from .custom_property import router as custom_property_router
object_router = APIRouter(
prefix="/object",
tags=["object"]
)
object_router.include_router(custom_property_router)
@object_router.post(
path='/',
@@ -112,7 +93,9 @@ async def router_object_create(
# 确定存储策略
policy_id = request.policy_id or parent.policy_id
policy = await Policy.get_exist_one(session, policy_id)
policy = await Policy.get(session, Policy.id == policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
parent_id = parent.id
@@ -147,7 +130,7 @@ async def router_object_create(
owner_id=user_id,
policy_id=policy_id,
)
file_object = await file_object.save(session)
await file_object.save(session)
l.info(f"创建空白文件: {request.name}")
@@ -472,7 +455,7 @@ async def router_object_rename(
# 更新名称
obj.name = new_name
obj = await obj.save(session)
await obj.save(session)
l.info(f"用户 {user_id} 将对象 {obj.id} 重命名为 {new_name}")
@@ -510,7 +493,6 @@ async def router_object_property(
name=obj.name,
type=obj.type,
size=obj.size,
mime_type=obj.mime_type,
created_at=obj.created_at,
updated_at=obj.updated_at,
parent_id=obj.parent_id,
@@ -538,7 +520,7 @@ async def router_object_property_detail(
obj = await Object.get(
session,
(Object.id == id) & (Object.deleted_at == None),
load=Object.metadata_entries,
load=Object.file_metadata,
)
if not obj:
raise HTTPException(status_code=404, detail="对象不存在")
@@ -561,301 +543,35 @@ async def router_object_property_detail(
total_views = sum(s.views for s in shares)
total_downloads = sum(s.downloads for s in shares)
# 获取物理文件信息(引用计数、校验和)
# 获取物理文件引用计数
reference_count = 1
checksum_md5: str | None = None
checksum_sha256: str | None = None
if obj.physical_file_id:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
if physical_file:
reference_count = physical_file.reference_count
checksum_md5 = physical_file.checksum_md5
checksum_sha256 = physical_file.checksum_sha256
# 构建元数据字典(排除内部命名空间)
metadata: dict[str, str] = {}
for entry in obj.metadata_entries:
ns = entry.name.split(":")[0] if ":" in entry.name else ""
if ns not in INTERNAL_NAMESPACES:
metadata[entry.name] = entry.value
return ObjectPropertyDetailResponse(
# 构建响应
response = ObjectPropertyDetailResponse(
id=obj.id,
name=obj.name,
type=obj.type,
size=obj.size,
mime_type=obj.mime_type,
created_at=obj.created_at,
updated_at=obj.updated_at,
parent_id=obj.parent_id,
checksum_md5=checksum_md5,
checksum_sha256=checksum_sha256,
policy_name=policy_name,
share_count=share_count,
total_views=total_views,
total_downloads=total_downloads,
reference_count=reference_count,
metadatas=metadata,
)
# 添加文件元数据
if obj.file_metadata:
response.mime_type = obj.file_metadata.mime_type
response.width = obj.file_metadata.width
response.height = obj.file_metadata.height
response.duration = obj.file_metadata.duration
response.checksum_md5 = obj.file_metadata.checksum_md5
@object_router.patch(
path='/{object_id}/policy',
summary='切换对象存储策略',
)
async def router_object_switch_policy(
session: SessionDep,
background_tasks: BackgroundTasks,
user: Annotated[User, Depends(auth_required)],
object_id: UUID,
request: ObjectSwitchPolicyRequest,
) -> TaskSummaryBase:
"""
切换对象的存储策略
文件:立即创建后台迁移任务,将文件从源策略搬到目标策略。
目录:更新目录 policy_id新文件使用新策略
若 is_migrate_existing=True额外创建后台任务迁移所有已有文件。
认证JWT Bearer Token
错误处理:
- 404: 对象不存在
- 403: 无权操作此对象 / 用户组无权使用目标策略
- 400: 目标策略与当前相同 / 不能对根目录操作
"""
user_id = user.id
# 查找对象
obj = await Object.get(
session,
(Object.id == object_id) & (Object.deleted_at == None)
)
if not obj:
http_exceptions.raise_not_found("对象不存在")
if obj.owner_id != user_id:
http_exceptions.raise_forbidden("无权操作此对象")
if obj.is_banned:
http_exceptions.raise_banned()
# 根目录不能直接切换策略(应通过子对象或子目录操作)
if obj.parent_id is None:
raise HTTPException(status_code=400, detail="不能对根目录切换存储策略,请对子目录操作")
# 校验目标策略存在
dest_policy = await Policy.get(session, Policy.id == request.policy_id)
if not dest_policy:
http_exceptions.raise_not_found("目标存储策略不存在")
# 校验用户组权限
group: Group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
allowed_ids = {p.id for p in group.policies}
if request.policy_id not in allowed_ids:
http_exceptions.raise_forbidden("当前用户组无权使用该存储策略")
# 不能切换到相同策略
if obj.policy_id == request.policy_id:
raise HTTPException(status_code=400, detail="目标策略与当前策略相同")
# 保存必要的属性,避免 save 后对象过期
src_policy_id = obj.policy_id
obj_id = obj.id
obj_is_file = obj.type == ObjectType.FILE
dest_policy_id = request.policy_id
dest_policy_name = dest_policy.name
# 创建任务记录
task = Task(
type=TaskType.POLICY_MIGRATE,
status=TaskStatus.QUEUED,
user_id=user_id,
)
task = await task.save(session)
task_id = task.id
task_props = TaskProps(
task_id=task_id,
source_policy_id=src_policy_id,
dest_policy_id=dest_policy_id,
object_id=obj_id,
)
task_props = await task_props.save(session)
if obj_is_file:
# 文件:后台迁移
async def _run_file_migration() -> None:
async with DatabaseManager.session() as bg_session:
bg_obj = await Object.get(bg_session, Object.id == obj_id)
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
bg_task = await Task.get(bg_session, Task.id == task_id)
await migrate_file_with_task(bg_session, bg_obj, bg_policy, bg_task)
background_tasks.add_task(_run_file_migration)
else:
# 目录:先更新目录自身的 policy_id
obj = await Object.get(session, Object.id == obj_id)
obj.policy_id = dest_policy_id
obj = await obj.save(session)
if request.is_migrate_existing:
# 后台迁移所有已有文件
async def _run_dir_migration() -> None:
async with DatabaseManager.session() as bg_session:
bg_folder = await Object.get(bg_session, Object.id == obj_id)
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
bg_task = await Task.get(bg_session, Task.id == task_id)
await migrate_directory_files(bg_session, bg_folder, bg_policy, bg_task)
background_tasks.add_task(_run_dir_migration)
else:
# 不迁移已有文件,直接完成任务
task = await Task.get(session, Task.id == task_id)
task.status = TaskStatus.COMPLETED
task.progress = 100
task = await task.save(session)
# 重新获取 task 以读取最新状态
task = await Task.get(session, Task.id == task_id)
l.info(f"用户 {user_id} 请求切换对象 {obj_id} 存储策略 → {dest_policy_name}")
return TaskSummaryBase(
id=task.id,
type=task.type,
status=task.status,
progress=task.progress,
error=task.error,
user_id=task.user_id,
created_at=task.created_at,
updated_at=task.updated_at,
)
# ==================== 元数据端点 ====================
@object_router.get(
path='/{object_id}/metadata',
summary='获取对象元数据',
description='获取对象的元数据键值对,可按命名空间过滤。',
)
async def router_get_object_metadata(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
object_id: UUID,
ns: str | None = None,
) -> MetadataResponse:
"""
获取对象元数据端点
认证JWT token 必填
查询参数:
- ns: 逗号分隔的命名空间列表(如 exif,stream不传返回所有非内部命名空间
错误处理:
- 404: 对象不存在
- 403: 无权查看此对象
"""
obj = await Object.get(
session,
(Object.id == object_id) & (Object.deleted_at == None),
load=Object.metadata_entries,
)
if not obj:
raise HTTPException(status_code=404, detail="对象不存在")
if obj.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权查看此对象")
# 解析命名空间过滤
ns_filter: set[str] | None = None
if ns:
ns_filter = {n.strip() for n in ns.split(",") if n.strip()}
# 不允许查看内部命名空间
ns_filter -= INTERNAL_NAMESPACES
# 构建元数据字典
metadata: dict[str, str] = {}
for entry in obj.metadata_entries:
entry_ns = entry.name.split(":")[0] if ":" in entry.name else ""
if entry_ns in INTERNAL_NAMESPACES:
continue
if ns_filter is not None and entry_ns not in ns_filter:
continue
metadata[entry.name] = entry.value
return MetadataResponse(metadatas=metadata)
@object_router.patch(
path='/{object_id}/metadata',
summary='批量更新对象元数据',
description='批量设置或删除对象的元数据条目。仅允许修改 custom: 命名空间。',
status_code=204,
)
async def router_patch_object_metadata(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
object_id: UUID,
request: MetadataPatchRequest,
) -> None:
"""
批量更新对象元数据端点
请求体中值为 None 的键将被删除,其余键将被设置/更新。
用户只能修改 custom: 命名空间的条目。
认证JWT token 必填
错误处理:
- 400: 尝试修改非 custom: 命名空间的条目
- 404: 对象不存在
- 403: 无权操作此对象
"""
obj = await Object.get(
session,
(Object.id == object_id) & (Object.deleted_at == None),
)
if not obj:
raise HTTPException(status_code=404, detail="对象不存在")
if obj.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权操作此对象")
for patch in request.patches:
# 验证命名空间
patch_ns = patch.key.split(":")[0] if ":" in patch.key else ""
if patch_ns not in USER_WRITABLE_NAMESPACES:
raise HTTPException(
status_code=400,
detail=f"不允许修改命名空间 '{patch_ns}' 的元数据,仅允许 custom: 命名空间",
)
if patch.value is None:
# 删除元数据条目
existing = await ObjectMetadata.get(
session,
(ObjectMetadata.object_id == object_id) & (ObjectMetadata.name == patch.key),
)
if existing:
await ObjectMetadata.delete(session, instances=existing)
else:
# 设置/更新元数据条目
existing = await ObjectMetadata.get(
session,
(ObjectMetadata.object_id == object_id) & (ObjectMetadata.name == patch.key),
)
if existing:
existing.value = patch.value
existing = await existing.save(session)
else:
entry = ObjectMetadata(
object_id=object_id,
name=patch.key,
value=patch.value,
is_public=True,
)
entry = await entry.save(session)
l.info(f"用户 {user.id} 更新了对象 {object_id}{len(request.patches)} 条元数据")
return response

View File

@@ -1,168 +0,0 @@
"""
用户自定义属性定义路由
提供自定义属性模板的增删改查功能。
用户可以定义类型化的属性模板(如标签、评分、分类等),
然后通过元数据 PATCH 端点为对象设置属性值。
路由前缀:/custom_property
"""
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger as l
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from sqlmodels import (
CustomPropertyDefinition,
CustomPropertyCreateRequest,
CustomPropertyUpdateRequest,
CustomPropertyResponse,
User,
)
router = APIRouter(
prefix="/custom_property",
tags=["custom_property"],
)
@router.get(
path='',
summary='获取自定义属性定义列表',
description='获取当前用户的所有自定义属性定义,按 sort_order 排序。',
)
async def router_list_custom_properties(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
) -> list[CustomPropertyResponse]:
"""
获取自定义属性定义列表端点
认证JWT token 必填
返回当前用户定义的所有自定义属性模板。
"""
definitions = await CustomPropertyDefinition.get(
session,
CustomPropertyDefinition.owner_id == user.id,
fetch_mode="all",
)
return [
CustomPropertyResponse(
id=d.id,
name=d.name,
type=d.type,
icon=d.icon,
options=d.options,
default_value=d.default_value,
sort_order=d.sort_order,
)
for d in sorted(definitions, key=lambda x: x.sort_order)
]
@router.post(
path='',
summary='创建自定义属性定义',
description='创建一个新的自定义属性模板。',
status_code=204,
)
async def router_create_custom_property(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
request: CustomPropertyCreateRequest,
) -> None:
"""
创建自定义属性定义端点
认证JWT token 必填
错误处理:
- 400: 请求数据无效
- 409: 同名属性已存在
"""
# 检查同名属性
existing = await CustomPropertyDefinition.get(
session,
(CustomPropertyDefinition.owner_id == user.id) &
(CustomPropertyDefinition.name == request.name),
)
if existing:
raise HTTPException(status_code=409, detail="同名自定义属性已存在")
definition = CustomPropertyDefinition(
owner_id=user.id,
name=request.name,
type=request.type,
icon=request.icon,
options=request.options,
default_value=request.default_value,
)
definition = await definition.save(session)
l.info(f"用户 {user.id} 创建了自定义属性: {request.name}")
@router.patch(
path='/{id}',
summary='更新自定义属性定义',
description='更新自定义属性模板的名称、图标、选项等。',
status_code=204,
)
async def router_update_custom_property(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
id: UUID,
request: CustomPropertyUpdateRequest,
) -> None:
"""
更新自定义属性定义端点
认证JWT token 必填
错误处理:
- 404: 属性定义不存在
- 403: 无权操作此属性
"""
definition = await CustomPropertyDefinition.get_exist_one(session, id)
if definition.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权操作此属性")
definition = await definition.update(session, request)
l.info(f"用户 {user.id} 更新了自定义属性: {id}")
@router.delete(
path='/{id}',
summary='删除自定义属性定义',
description='删除自定义属性模板。注意:不会自动清理已使用该属性的元数据条目。',
status_code=204,
)
async def router_delete_custom_property(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
id: UUID,
) -> None:
"""
删除自定义属性定义端点
认证JWT token 必填
错误处理:
- 404: 属性定义不存在
- 403: 无权操作此属性
"""
definition = await CustomPropertyDefinition.get_exist_one(session, id)
if definition.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权操作此属性")
await CustomPropertyDefinition.delete(session, instances=definition)
l.info(f"用户 {user.id} 删除了自定义属性: {id}")

View File

@@ -45,7 +45,12 @@ async def router_share_get(
4. 返回分享详情(含文件树和分享者信息)
"""
# 1. 查询分享(预加载 user 和 object
share = await Share.get_exist_one(session, id, load=[Share.user, Share.object])
share = await Share.get(
session, Share.id == id,
load=[Share.user, Share.object],
)
if not share:
http_exceptions.raise_not_found(detail="分享不存在或已被取消")
# 2. 检查过期
now = datetime.now()
@@ -469,29 +474,16 @@ def router_share_update(id: str) -> ResponseBase:
path='/{id}',
summary='删除分享',
description='Delete a share by ID.',
status_code=204,
dependencies=[Depends(auth_required)]
)
async def router_share_delete(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
id: UUID,
) -> None:
def router_share_delete(id: str) -> ResponseBase:
"""
删除分享
认证:需要 JWT token
流程:
1. 通过分享ID查找分享
2. 验证分享属于当前用户
3. 删除分享记录
Delete a share by ID.
Args:
id (str): The ID of the share to be deleted.
Returns:
ResponseBase: A model containing the response data for the deleted share.
"""
share = await Share.get_exist_one(session, id)
if share.user_id != user.id:
http_exceptions.raise_forbidden(detail="无权删除此分享")
user_id = user.id
share_code = share.code
await Share.delete(session, share)
l.info(f"用户 {user_id} 删除了分享: {share_code}")
http_exceptions.raise_not_implemented()

View File

@@ -82,8 +82,7 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
(Setting.type == SettingsType.REGISTER) |
(Setting.type == SettingsType.CAPTCHA) |
(Setting.type == SettingsType.AUTH) |
(Setting.type == SettingsType.OAUTH) |
(Setting.type == SettingsType.AVATAR),
(Setting.type == SettingsType.OAUTH),
fetch_mode="all",
)
@@ -123,7 +122,6 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
password_required=s.get("auth_password_required") == "1",
phone_binding_required=s.get("auth_phone_binding_required") == "1",
email_binding_required=s.get("auth_email_binding_required") == "1",
avatar_max_size=int(s["avatar_size"]),
footer_code=s.get("footer_code"),
tos_url=s.get("tos_url"),
privacy_url=s.get("privacy_url"),

View File

@@ -5,7 +5,6 @@ import json
import jwt
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse, RedirectResponse
from itsdangerous import URLSafeTimedSerializer
from loguru import logger
from webauthn import (
@@ -234,7 +233,7 @@ async def router_user_register(
group_id=default_group.id,
)
new_user_id = new_user.id
new_user = await new_user.save(session)
await new_user.save(session)
# 7. 创建 AuthIdentity
hashed_password = Password.hash(request.credential) if request.credential else None
@@ -246,14 +245,13 @@ async def router_user_register(
is_verified=False,
user_id=new_user_id,
)
identity = await identity.save(session)
await identity.save(session)
# 8. 创建用户根目录(使用用户组关联的第一个存储策略)
await session.refresh(default_group, ['policies'])
if not default_group.policies:
logger.error("默认用户组未关联任何存储策略")
# 8. 创建用户根目录
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
if not default_policy:
logger.error("默认存储策略不存在")
http_exceptions.raise_internal_error()
default_policy = default_group.policies[0]
await sqlmodels.Object(
name="/",
@@ -320,7 +318,7 @@ async def router_user_magic_link(
site_url = site_url_setting.value if site_url_setting else "http://localhost"
# TODO: 发送邮件(包含 {site_url}/auth/magic-link?token={token}
logger.info(f"Magic Link token 已 {request.email} 生成 (邮件发送待实现)")
logger.info(f"Magic Link token 已生成: {token} (邮件发送待实现)")
@user_router.post(
@@ -359,78 +357,20 @@ def router_user_profile(id: str) -> sqlmodels.ResponseBase:
@user_router.get(
path='/avatar/{id}/{size}',
summary='获取用户头像',
response_model=None,
description='Get user avatar by ID and size.',
)
async def router_user_avatar(
session: SessionDep,
id: UUID,
size: int = 128,
) -> FileResponse | RedirectResponse:
def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase:
"""
获取指定用户指定尺寸的头像(公开端点,无需认证)
Get user avatar by ID and size.
路径参数:
- id: 用户 UUID
- size: 请求的头像尺寸px默认 128
Args:
id (str): The user ID.
size (int): The size of the avatar image.
行为:
- default: 302 重定向到 Gravatar identicon
- gravatar: 302 重定向到 Gravatar使用用户邮箱 MD5
- file: 返回本地 WebP 文件
响应:
- 200: image/webpfile 模式)
- 302: 重定向到外部 URLdefault/gravatar 模式)
- 404: 用户不存在
缓存Cache-Control: public, max-age=3600
Returns:
str: A Base64 encoded string of the user avatar image.
"""
import aiofiles.os
from service.avatar import (
get_avatar_file_path,
get_avatar_settings,
gravatar_url,
resolve_avatar_size,
)
user = await sqlmodels.User.get(session, sqlmodels.User.id == id)
if not user:
http_exceptions.raise_not_found("用户不存在")
avatar_path, _, size_l, size_m, size_s = await get_avatar_settings(session)
if user.avatar == "file":
size_label = resolve_avatar_size(size, size_l, size_m, size_s)
file_path = get_avatar_file_path(avatar_path, user.id, size_label)
if not await aiofiles.os.path.exists(file_path):
# 文件丢失,降级为 identicon
fallback_url = gravatar_url(str(user.id), size, "https://www.gravatar.com/")
return RedirectResponse(url=fallback_url, status_code=302)
return FileResponse(
path=file_path,
media_type="image/webp",
headers={"Cache-Control": "public, max-age=3600"},
)
elif user.avatar == "gravatar":
gravatar_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.AVATAR)
& (sqlmodels.Setting.name == "gravatar_server"),
)
server = gravatar_setting.value if gravatar_setting else "https://www.gravatar.com/"
email = user.email or str(user.id)
url = gravatar_url(email, size, server)
return RedirectResponse(url=url, status_code=302)
else:
# default: identicon
email_or_id = user.email or str(user.id)
url = gravatar_url(email_or_id, size, "https://www.gravatar.com/")
return RedirectResponse(url=url, status_code=302)
http_exceptions.raise_not_implemented()
#####################
# 需要登录的接口
@@ -494,24 +434,9 @@ async def router_user_storage(
if not group:
raise HTTPException(status_code=404, detail="用户组不存在")
# 查询用户所有未过期容量包的 size 总和
from datetime import datetime
from sqlalchemy import func, select, and_, or_
# [TODO] 总空间加上用户购买的额外空间
now = datetime.now()
stmt = select(func.coalesce(func.sum(sqlmodels.StoragePack.size), 0)).where(
and_(
sqlmodels.StoragePack.user_id == user.id,
or_(
sqlmodels.StoragePack.expired_time.is_(None),
sqlmodels.StoragePack.expired_time > now,
),
)
)
result = await session.exec(stmt)
active_packs_total: int = result.scalar_one()
total: int = group.max_storage + active_packs_total
total: int = group.max_storage
used: int = user.storage
free: int = max(0, total - used)
@@ -653,7 +578,7 @@ async def router_user_authn_finish(
is_verified=True,
user_id=user.id,
)
identity = await identity.save(session)
await identity.save(session)
return authn.to_detail_response()

View File

@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from fastapi import APIRouter, Depends, HTTPException, status
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
import sqlmodels
@@ -13,7 +13,6 @@ from sqlmodels import (
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
ChangePasswordRequest,
AuthnDetailResponse, AuthnRenameRequest,
PolicySummary,
)
from sqlmodels.color import ThemeColorsBase
from sqlmodels.user_authn import UserAuthn
@@ -32,25 +31,16 @@ user_settings_router.include_router(file_viewers_router)
@user_settings_router.get(
path='/policies',
summary='获取用户可选存储策略',
description='Get user selectable storage policies.',
)
async def router_user_settings_policies(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> list[PolicySummary]:
def router_user_settings_policies() -> sqlmodels.ResponseBase:
"""
获取当前用户所在组可选的存储策略列表
Get user selectable storage policies.
返回用户组关联的所有存储策略的摘要信息。
Returns:
dict: A dictionary containing available storage policies for the user.
"""
group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
return [
PolicySummary(
id=p.id, name=p.name, type=p.type,
server=p.server, max_size=p.max_size, is_private=p.is_private,
)
for p in group.policies
]
http_exceptions.raise_not_implemented()
@user_settings_router.get(
@@ -165,121 +155,34 @@ async def router_user_settings(
@user_settings_router.post(
path='/avatar',
summary='从文件上传头像',
status_code=204,
description='Upload user avatar from file.',
dependencies=[Depends(auth_required)],
)
async def router_user_settings_avatar(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
file: UploadFile = File(...),
) -> None:
def router_user_settings_avatar() -> sqlmodels.ResponseBase:
"""
上传头像文件
Upload user avatar from file.
认证JWT token
请求体multipart/form-datafile 字段
流程:
1. 验证文件 MIME 类型JPEG/PNG/GIF/WebP
2. 验证文件大小 <= avatar_size 设置(默认 2MB
3. 调用 Pillow 验证图片有效性并处理(居中裁剪、缩放 L/M/S
4. 保存三种尺寸的 WebP 文件
5. 更新 User.avatar = "file"
错误处理:
- 400: 文件类型不支持 / 图片无法解析
- 413: 文件过大
Returns:
dict: A dictionary containing the result of the avatar upload.
"""
from service.avatar import (
ALLOWED_CONTENT_TYPES,
get_avatar_settings,
process_and_save_avatar,
)
# 验证 MIME 类型
if file.content_type not in ALLOWED_CONTENT_TYPES:
http_exceptions.raise_bad_request(
f"不支持的图片格式,允许: {', '.join(ALLOWED_CONTENT_TYPES)}"
)
# 读取并验证大小
_, max_upload_size, _, _, _ = await get_avatar_settings(session)
raw_bytes = await file.read()
if len(raw_bytes) > max_upload_size:
raise HTTPException(
status_code=413,
detail=f"文件过大,最大允许 {max_upload_size} 字节",
)
# 处理并保存(内部会验证图片有效性,无效抛出 ValueError
try:
await process_and_save_avatar(session, user.id, raw_bytes)
except ValueError as e:
http_exceptions.raise_bad_request(str(e))
# 更新用户头像字段
user.avatar = "file"
user = await user.save(session)
http_exceptions.raise_not_implemented()
@user_settings_router.put(
path='/avatar',
summary='设定为 Gravatar 头像',
summary='设定为Gravatar头像',
description='Set user avatar to Gravatar.',
dependencies=[Depends(auth_required)],
status_code=204,
)
async def router_user_settings_avatar_gravatar(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> None:
def router_user_settings_avatar_gravatar() -> None:
"""
将头像切换为 Gravatar
Set user avatar to Gravatar.
认证JWT token
流程:
1. 验证用户有邮箱Gravatar 基于邮箱 MD5
2. 如果当前是 FILE 头像,删除本地文件
3. 更新 User.avatar = "gravatar"
错误处理:
- 400: 用户没有邮箱
Returns:
dict: A dictionary containing the result of setting the Gravatar avatar.
"""
from service.avatar import delete_avatar_files
if not user.email:
http_exceptions.raise_bad_request("Gravatar 需要邮箱,请先绑定邮箱")
if user.avatar == "file":
await delete_avatar_files(session, user.id)
user.avatar = "gravatar"
user = await user.save(session)
@user_settings_router.delete(
path='/avatar',
summary='重置头像为默认',
status_code=204,
)
async def router_user_settings_avatar_delete(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> None:
"""
重置头像为默认
认证JWT token
流程:
1. 如果当前是 FILE 头像,删除本地文件
2. 更新 User.avatar = "default"
"""
from service.avatar import delete_avatar_files
if user.avatar == "file":
await delete_avatar_files(session, user.id)
user.avatar = "default"
user = await user.save(session)
http_exceptions.raise_not_implemented()
@user_settings_router.patch(
@@ -321,7 +224,7 @@ async def router_user_settings_theme(
user.color_error = request.theme_colors.error
user.color_neutral = request.theme_colors.neutral
user = await user.save(session)
await user.save(session)
@user_settings_router.patch(
@@ -358,7 +261,7 @@ async def router_user_settings_change_password(
http_exceptions.raise_forbidden("当前密码错误")
email_identity.credential = Password.hash(request.new_password)
email_identity = await email_identity.save(session)
await email_identity.save(session)
@user_settings_router.patch(
@@ -392,7 +295,7 @@ async def router_user_settings_patch(
http_exceptions.raise_bad_request(f"设置项 {option.value} 不允许为空")
setattr(user, option.value, value)
user = await user.save(session)
await user.save(session)
@user_settings_router.get(
@@ -454,7 +357,7 @@ async def router_user_settings_2fa_enable(
extra: dict = orjson.loads(email_identity.extra_data) if email_identity.extra_data else {}
extra["two_factor"] = secret
email_identity.extra_data = orjson.dumps(extra).decode('utf-8')
email_identity = await email_identity.save(session)
await email_identity.save(session)
# ==================== 认证身份管理 ====================

View File

@@ -79,7 +79,9 @@ async def set_default_viewer(
if existing:
existing.app_id = request.app_id
existing = await existing.save(session, load=UserFileAppDefault.app)
existing = await existing.save(session)
# 重新加载 app 关系
await session.refresh(existing, attribute_names=["app"])
return existing.to_response()
else:
new_default = UserFileAppDefault(
@@ -87,7 +89,9 @@ async def set_default_viewer(
extension=normalized_ext,
app_id=request.app_id,
)
new_default = await new_default.save(session, load=UserFileAppDefault.app)
new_default = await new_default.save(session)
# 重新加载 app 关系
await session.refresh(new_default, attribute_names=["app"])
return new_default.to_response()

View File

@@ -0,0 +1,106 @@
from fastapi import APIRouter, Depends
from middleware.auth import auth_required
from sqlmodels import ResponseBase
from utils import http_exceptions
vas_router = APIRouter(
prefix="/vas",
tags=["vas"]
)
@vas_router.get(
path='/pack',
summary='获取容量包及配额信息',
description='Get information about storage packs and quotas.',
dependencies=[Depends(auth_required)]
)
def router_vas_pack() -> ResponseBase:
"""
Get information about storage packs and quotas.
Returns:
ResponseBase: A model containing the response data for storage packs and quotas.
"""
http_exceptions.raise_not_implemented()
@vas_router.get(
path='/product',
summary='获取商品信息,同时返回支付信息',
description='Get product information along with payment details.',
dependencies=[Depends(auth_required)]
)
def router_vas_product() -> ResponseBase:
"""
Get product information along with payment details.
Returns:
ResponseBase: A model containing the response data for products and payment information.
"""
http_exceptions.raise_not_implemented()
@vas_router.post(
path='/order',
summary='新建支付订单',
description='Create an order for a product.',
dependencies=[Depends(auth_required)]
)
def router_vas_order() -> ResponseBase:
"""
Create an order for a product.
Returns:
ResponseBase: A model containing the response data for the created order.
"""
http_exceptions.raise_not_implemented()
@vas_router.get(
path='/order/{id}',
summary='查询订单状态',
description='Get information about a specific payment order by ID.',
dependencies=[Depends(auth_required)]
)
def router_vas_order_get(id: str) -> ResponseBase:
"""
Get information about a specific payment order by ID.
Args:
id (str): The ID of the order to retrieve information for.
Returns:
ResponseBase: A model containing the response data for the specified order.
"""
http_exceptions.raise_not_implemented()
@vas_router.get(
path='/redeem',
summary='获取兑换码信息',
description='Get information about a specific redemption code.',
dependencies=[Depends(auth_required)]
)
def router_vas_redeem(code: str) -> ResponseBase:
"""
Get information about a specific redemption code.
Args:
code (str): The redemption code to retrieve information for.
Returns:
ResponseBase: A model containing the response data for the specified redemption code.
"""
http_exceptions.raise_not_implemented()
@vas_router.post(
path='/redeem',
summary='执行兑换',
description='Redeem a redemption code for a product or service.',
dependencies=[Depends(auth_required)]
)
def router_vas_redeem_post() -> ResponseBase:
"""
Redeem a redemption code for a product or service.
Returns:
ResponseBase: A model containing the response data for the redeemed code.
"""
http_exceptions.raise_not_implemented()

View File

@@ -1,207 +1,110 @@
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends
from loguru import logger as l
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from sqlmodels import (
Object,
User,
WebDAV,
WebDAVAccountResponse,
WebDAVCreateRequest,
WebDAVUpdateRequest,
)
from service.redis.webdav_auth_cache import WebDAVAuthCache
from sqlmodels import ResponseBase
from utils import http_exceptions
from utils.password.pwd import Password
# WebDAV 管理路由
webdav_router = APIRouter(
prefix='/webdav',
tags=["webdav"],
)
def _check_webdav_enabled(user: User) -> None:
"""检查用户组是否启用了 WebDAV 功能"""
if not user.group.web_dav_enabled:
http_exceptions.raise_forbidden("WebDAV 功能未启用")
def _to_response(account: WebDAV) -> WebDAVAccountResponse:
"""将 WebDAV 数据库模型转换为响应 DTO"""
return WebDAVAccountResponse(
id=account.id,
name=account.name,
root=account.root,
readonly=account.readonly,
use_proxy=account.use_proxy,
created_at=str(account.created_at),
updated_at=str(account.updated_at),
)
@webdav_router.get(
path='/accounts',
summary='获取账号列表',
summary='获取账号信息',
description='Get account information for WebDAV.',
dependencies=[Depends(auth_required)],
)
async def list_accounts(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
) -> list[WebDAVAccountResponse]:
def router_webdav_accounts() -> ResponseBase:
"""
列出当前用户所有 WebDAV 账户
认证JWT Bearer Token
Get account information for WebDAV.
Returns:
ResponseBase: A model containing the response data for the account information.
"""
_check_webdav_enabled(user)
user_id: UUID = user.id
accounts: list[WebDAV] = await WebDAV.get(
session,
WebDAV.user_id == user_id,
fetch_mode="all",
)
return [_to_response(a) for a in accounts]
http_exceptions.raise_not_implemented()
@webdav_router.post(
path='/accounts',
summary='建账号',
status_code=201,
summary='建账号',
description='Create a new WebDAV account.',
dependencies=[Depends(auth_required)],
)
async def create_account(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
request: WebDAVCreateRequest,
) -> WebDAVAccountResponse:
def router_webdav_create_account() -> ResponseBase:
"""
创建 WebDAV 账户
认证JWT Bearer Token
错误处理:
- 403: WebDAV 功能未启用
- 400: 根目录路径不存在或不是目录
- 409: 账户名已存在
Create a new WebDAV account.
Returns:
ResponseBase: A model containing the response data for the created account.
"""
_check_webdav_enabled(user)
user_id: UUID = user.id
# 验证账户名唯一
existing = await WebDAV.get(
session,
(WebDAV.name == request.name) & (WebDAV.user_id == user_id),
)
if existing:
http_exceptions.raise_conflict("账户名已存在")
# 验证 root 路径存在且为目录
root_obj = await Object.get_by_path(session, user_id, request.root)
if not root_obj or not root_obj.is_folder:
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
# 创建账户
account = WebDAV(
name=request.name,
password=Password.hash(request.password),
root=request.root,
readonly=request.readonly,
use_proxy=request.use_proxy,
user_id=user_id,
)
account = await account.save(session)
l.info(f"用户 {user_id} 创建 WebDAV 账户: {account.name}")
return _to_response(account)
@webdav_router.patch(
path='/accounts/{account_id}',
summary='更新账号',
)
async def update_account(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
account_id: int,
request: WebDAVUpdateRequest,
) -> WebDAVAccountResponse:
"""
更新 WebDAV 账户
认证JWT Bearer Token
错误处理:
- 403: WebDAV 功能未启用
- 404: 账户不存在
- 400: 根目录路径不存在或不是目录
"""
_check_webdav_enabled(user)
user_id: UUID = user.id
account = await WebDAV.get(
session,
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
)
if not account:
http_exceptions.raise_not_found("WebDAV 账户不存在")
# 验证 root 路径
if request.root is not None:
root_obj = await Object.get_by_path(session, user_id, request.root)
if not root_obj or not root_obj.is_folder:
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
# 密码哈希后原地替换update() 会通过 model_dump(exclude_unset=True) 只取已设置字段
is_password_changed = request.password is not None
if is_password_changed:
request.password = Password.hash(request.password)
account = await account.update(session, request)
# 密码变更时清除认证缓存
if is_password_changed:
await WebDAVAuthCache.invalidate_account(user_id, account.name)
l.info(f"用户 {user_id} 更新 WebDAV 账户: {account.name}")
return _to_response(account)
http_exceptions.raise_not_implemented()
@webdav_router.delete(
path='/accounts/{account_id}',
path='/accounts/{id}',
summary='删除账号',
status_code=204,
description='Delete a WebDAV account by its ID.',
dependencies=[Depends(auth_required)],
)
async def delete_account(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
account_id: int,
) -> None:
def router_webdav_delete_account(id: str) -> ResponseBase:
"""
删除 WebDAV 账户
认证JWT Bearer Token
错误处理:
- 403: WebDAV 功能未启用
- 404: 账户不存在
Delete a WebDAV account by its ID.
Args:
id (str): The ID of the account to be deleted.
Returns:
ResponseBase: A model containing the response data for the deletion operation.
"""
_check_webdav_enabled(user)
user_id: UUID = user.id
http_exceptions.raise_not_implemented()
account = await WebDAV.get(
session,
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
)
if not account:
http_exceptions.raise_not_found("WebDAV 账户不存在")
@webdav_router.post(
path='/mount',
summary='新建目录挂载',
description='Create a new WebDAV mount point.',
dependencies=[Depends(auth_required)],
)
def router_webdav_create_mount() -> ResponseBase:
"""
Create a new WebDAV mount point.
Returns:
ResponseBase: A model containing the response data for the created mount point.
"""
http_exceptions.raise_not_implemented()
account_name = account.name
await WebDAV.delete(session, account)
@webdav_router.delete(
path='/mount/{id}',
summary='删除目录挂载',
description='Delete a WebDAV mount point by its ID.',
dependencies=[Depends(auth_required)],
)
def router_webdav_delete_mount(id: str) -> ResponseBase:
"""
Delete a WebDAV mount point by its ID.
Args:
id (str): The ID of the mount point to be deleted.
Returns:
ResponseBase: A model containing the response data for the deletion operation.
"""
http_exceptions.raise_not_implemented()
# 清除认证缓存
await WebDAVAuthCache.invalidate_account(user_id, account_name)
l.info(f"用户 {user_id} 删除 WebDAV 账户: {account_name}")
@webdav_router.patch(
path='accounts/{id}',
summary='更新账号信息',
description='Update WebDAV account information by ID.',
dependencies=[Depends(auth_required)],
)
def router_webdav_update_account(id: str) -> ResponseBase:
"""
Update WebDAV account information by ID.
Args:
id (str): The ID of the account to be updated.
Returns:
ResponseBase: A model containing the response data for the updated account.
"""
http_exceptions.raise_not_implemented()

1
routers/dav/README.md Normal file
View File

@@ -0,0 +1 @@
# WebDAV 操作路由

View File

@@ -1,35 +0,0 @@
"""
WebDAV 协议入口
使用 WsgiDAV + a2wsgi 提供 WebDAV 协议支持。
WsgiDAV 在 a2wsgi 的线程池中运行,不阻塞 FastAPI 事件循环。
"""
from a2wsgi import WSGIMiddleware
from wsgidav.wsgidav_app import WsgiDAVApp
from .domain_controller import DiskNextDomainController
from .provider import DiskNextDAVProvider
_wsgidav_config: dict[str, object] = {
"provider_mapping": {
"/": DiskNextDAVProvider(),
},
"http_authenticator": {
"domain_controller": DiskNextDomainController,
"accept_basic": True,
"accept_digest": False,
"default_to_digest": False,
},
"verbose": 1,
# 使用 WsgiDAV 内置的内存锁管理器
"lock_storage": True,
# 禁用 WsgiDAV 的目录浏览器(纯 DAV 协议)
"dir_browser": {
"enable": False,
},
}
_wsgidav_app = WsgiDAVApp(_wsgidav_config)
dav_app = WSGIMiddleware(_wsgidav_app, workers=10)
"""ASGI 应用,挂载到 /dav 路径"""

View File

@@ -1,148 +0,0 @@
"""
WebDAV 认证控制器
实现 WsgiDAV 的 BaseDomainController 接口,使用 HTTP Basic Auth
通过 DiskNext 的 WebDAV 账户模型进行认证。
用户名格式: {email}/{webdav_account_name}
"""
import asyncio
from uuid import UUID
from loguru import logger as l
from wsgidav.dc.base_dc import BaseDomainController
from routers.dav.provider import EventLoopRef, _get_session
from service.redis.webdav_auth_cache import WebDAVAuthCache
from sqlmodels.user import User, UserStatus
from sqlmodels.webdav import WebDAV
from utils.password.pwd import Password, PasswordStatus
async def _authenticate(
email: str,
account_name: str,
password: str,
) -> tuple[UUID, int] | None:
"""
异步认证 WebDAV 用户。
:param email: 用户邮箱
:param account_name: WebDAV 账户名
:param password: 明文密码
:return: (user_id, webdav_id) 或 None
"""
# 1. 查缓存
cached = await WebDAVAuthCache.get(email, account_name, password)
if cached is not None:
return cached
# 2. 缓存未命中,查库验证
async with _get_session() as session:
user = await User.get(session, User.email == email, load=User.group)
if not user:
return None
if user.status != UserStatus.ACTIVE:
return None
if not user.group.web_dav_enabled:
return None
account = await WebDAV.get(
session,
(WebDAV.name == account_name) & (WebDAV.user_id == user.id),
)
if not account:
return None
status = Password.verify(account.password, password)
if status == PasswordStatus.INVALID:
return None
user_id: UUID = user.id
webdav_id: int = account.id
# 3. 写入缓存
await WebDAVAuthCache.set(email, account_name, password, user_id, webdav_id)
return user_id, webdav_id
class DiskNextDomainController(BaseDomainController):
"""
DiskNext WebDAV 认证控制器
用户名格式: {email}/{webdav_account_name}
密码: WebDAV 账户密码(创建账户时设置)
"""
def __init__(self, wsgidav_app: object, config: dict[str, object]) -> None:
super().__init__(wsgidav_app, config)
def get_domain_realm(self, path_info: str, environ: dict[str, object]) -> str:
"""返回 realm 名称"""
return "DiskNext WebDAV"
def require_authentication(self, realm: str, environ: dict[str, object]) -> bool:
"""所有请求都需要认证"""
return True
def is_share_anonymous(self, path_info: str) -> bool:
"""不支持匿名访问"""
return False
def supports_http_digest_auth(self) -> bool:
"""不支持 Digest 认证(密码存的是 Argon2 哈希,无法反推)"""
return False
def basic_auth_user(
self,
realm: str,
user_name: str,
password: str,
environ: dict[str, object],
) -> bool:
"""
HTTP Basic Auth 认证。
用户名格式: {email}/{webdav_account_name}
在 WSGI 线程中通过 anyio.from_thread.run 调用异步认证逻辑。
"""
# 解析用户名
if "/" not in user_name:
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
return False
email, account_name = user_name.split("/", 1)
if not email or not account_name:
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
return False
# 在 WSGI 线程中调用异步认证
future = asyncio.run_coroutine_threadsafe(
_authenticate(email, account_name, password),
EventLoopRef.get(),
)
result = future.result()
if result is None:
l.debug(f"WebDAV 认证失败: {email}/{account_name}")
return False
user_id, webdav_id = result
# 将认证信息存入 environ供 Provider 使用
environ["disknext.user_id"] = user_id
environ["disknext.webdav_id"] = webdav_id
environ["disknext.email"] = email
environ["disknext.account_name"] = account_name
return True
def digest_auth_user(
self,
realm: str,
user_name: str,
environ: dict[str, object],
) -> bool:
"""不支持 Digest 认证"""
return False

View File

@@ -1,645 +0,0 @@
"""
DiskNext WebDAV 存储 Provider
将 WsgiDAV 的文件操作映射到 DiskNext 的 Object 模型。
所有异步数据库/文件操作通过 asyncio.run_coroutine_threadsafe() 桥接。
"""
import asyncio
import io
import mimetypes
from pathlib import Path
from typing import ClassVar
from uuid import UUID
from loguru import logger as l
from wsgidav.dav_error import (
DAVError,
HTTP_FORBIDDEN,
HTTP_INSUFFICIENT_STORAGE,
HTTP_NOT_FOUND,
)
from wsgidav.dav_provider import DAVCollection, DAVNonCollection, DAVProvider
from service.storage import LocalStorageService, adjust_user_storage
from sqlmodels.database_connection import DatabaseManager
from sqlmodels.object import Object, ObjectType
from sqlmodels.physical_file import PhysicalFile
from sqlmodels.policy import Policy
from sqlmodels.user import User
from sqlmodels.webdav import WebDAV
class EventLoopRef:
"""持有主线程事件循环引用,供 WSGI 线程使用"""
_loop: ClassVar[asyncio.AbstractEventLoop | None] = None
@classmethod
async def capture(cls) -> None:
"""在 async 上下文中调用,捕获当前事件循环"""
cls._loop = asyncio.get_running_loop()
@classmethod
def get(cls) -> asyncio.AbstractEventLoop:
if cls._loop is None:
raise RuntimeError("事件循环尚未捕获,请先调用 EventLoopRef.capture()")
return cls._loop
def _run_async(coro): # type: ignore[no-untyped-def]
"""在 WSGI 线程中通过 run_coroutine_threadsafe 运行协程"""
future = asyncio.run_coroutine_threadsafe(coro, EventLoopRef.get())
return future.result()
def _get_session(): # type: ignore[no-untyped-def]
"""获取数据库会话上下文管理器"""
return DatabaseManager._async_session_factory()
# ==================== 异步辅助函数 ====================
async def _get_webdav_account(webdav_id: int) -> WebDAV | None:
"""获取 WebDAV 账户"""
async with _get_session() as session:
return await WebDAV.get(session, WebDAV.id == webdav_id)
async def _get_object_by_path(user_id: UUID, path: str) -> Object | None:
"""根据路径获取对象"""
async with _get_session() as session:
return await Object.get_by_path(session, user_id, path)
async def _get_children(user_id: UUID, parent_id: UUID) -> list[Object]:
"""获取目录子对象"""
async with _get_session() as session:
return await Object.get_children(session, user_id, parent_id)
async def _get_object_by_id(object_id: UUID) -> Object | None:
"""根据ID获取对象"""
async with _get_session() as session:
return await Object.get(session, Object.id == object_id, load=Object.physical_file)
async def _get_user(user_id: UUID) -> User | None:
"""获取用户(含 group 关系)"""
async with _get_session() as session:
return await User.get(session, User.id == user_id, load=User.group)
async def _get_policy(policy_id: UUID) -> Policy | None:
"""获取存储策略"""
async with _get_session() as session:
return await Policy.get(session, Policy.id == policy_id)
async def _create_folder(
name: str,
parent_id: UUID,
owner_id: UUID,
policy_id: UUID,
) -> Object:
"""创建目录对象"""
async with _get_session() as session:
obj = Object(
name=name,
type=ObjectType.FOLDER,
size=0,
parent_id=parent_id,
owner_id=owner_id,
policy_id=policy_id,
)
obj = await obj.save(session)
return obj
async def _create_file(
name: str,
parent_id: UUID,
owner_id: UUID,
policy_id: UUID,
) -> Object:
"""创建空文件对象"""
async with _get_session() as session:
obj = Object(
name=name,
type=ObjectType.FILE,
size=0,
parent_id=parent_id,
owner_id=owner_id,
policy_id=policy_id,
)
obj = await obj.save(session)
return obj
async def _soft_delete_object(object_id: UUID) -> None:
"""软删除对象(移入回收站)"""
from service.storage import soft_delete_objects
async with _get_session() as session:
obj = await Object.get(session, Object.id == object_id)
if obj:
await soft_delete_objects(session, [obj])
async def _finalize_upload(
object_id: UUID,
physical_path: str,
size: int,
owner_id: UUID,
policy_id: UUID,
) -> None:
"""上传完成后更新对象元数据和物理文件记录"""
async with _get_session() as session:
# 获取存储路径(相对路径)
policy = await Policy.get(session, Policy.id == policy_id)
if not policy or not policy.server:
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
base_path = Path(policy.server).resolve()
full_path = Path(physical_path).resolve()
storage_path = str(full_path.relative_to(base_path))
# 创建 PhysicalFile 记录
pf = PhysicalFile(
storage_path=storage_path,
size=size,
policy_id=policy_id,
reference_count=1,
)
pf = await pf.save(session)
# 更新 Object
obj = await Object.get(session, Object.id == object_id)
if obj:
obj.sqlmodel_update({'size': size, 'physical_file_id': pf.id})
obj = await obj.save(session)
# 更新用户存储用量
if size > 0:
await adjust_user_storage(session, owner_id, size)
async def _move_object(
object_id: UUID,
new_parent_id: UUID,
new_name: str,
) -> None:
"""移动/重命名对象"""
async with _get_session() as session:
obj = await Object.get(session, Object.id == object_id)
if obj:
obj.sqlmodel_update({'parent_id': new_parent_id, 'name': new_name})
obj = await obj.save(session)
async def _copy_object_recursive(
src_id: UUID,
dst_parent_id: UUID,
dst_name: str,
owner_id: UUID,
) -> None:
"""递归复制对象"""
from service.storage import copy_object_recursive
async with _get_session() as session:
src = await Object.get(session, Object.id == src_id)
if not src:
return
await copy_object_recursive(session, src, dst_parent_id, owner_id, new_name=dst_name)
# ==================== 辅助工具 ====================
def _get_environ_info(environ: dict[str, object]) -> tuple[UUID, int]:
"""从 environ 中提取认证信息"""
user_id: UUID = environ["disknext.user_id"] # type: ignore[assignment]
webdav_id: int = environ["disknext.webdav_id"] # type: ignore[assignment]
return user_id, webdav_id
def _resolve_dav_path(account_root: str, dav_path: str) -> str:
"""
将 DAV 相对路径映射到 DiskNext 绝对路径。
:param account_root: 账户挂载根路径,如 "/""/docs"
:param dav_path: DAV 请求路径,如 "/""/photos/cat.jpg"
:return: DiskNext 内部路径,如 "/docs/photos/cat.jpg"
"""
# 规范化根路径
root = account_root.rstrip("/")
if not root:
root = ""
# 规范化 DAV 路径
if not dav_path or dav_path == "/":
return root + "/" if root else "/"
if not dav_path.startswith("/"):
dav_path = "/" + dav_path
full = root + dav_path
return full if full else "/"
def _check_readonly(environ: dict[str, object]) -> None:
"""检查账户是否只读,只读则抛出 403"""
account = environ.get("disknext.webdav_account")
if account and getattr(account, 'readonly', False):
raise DAVError(HTTP_FORBIDDEN, "WebDAV 账户为只读模式")
def _check_storage_quota(user: User, additional_bytes: int) -> None:
"""检查存储配额"""
max_storage = user.group.max_storage
if max_storage > 0 and user.storage + additional_bytes > max_storage:
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
class QuotaLimitedWriter(io.RawIOBase):
"""带配额限制的写入流包装器"""
def __init__(self, stream: io.BufferedWriter, max_bytes: int) -> None:
self._stream = stream
self._max_bytes = max_bytes
self._bytes_written = 0
def writable(self) -> bool:
return True
def write(self, b: bytes | bytearray) -> int:
if self._bytes_written + len(b) > self._max_bytes:
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
written = self._stream.write(b)
self._bytes_written += written
return written
def close(self) -> None:
self._stream.close()
super().close()
@property
def bytes_written(self) -> int:
return self._bytes_written
# ==================== Provider ====================
class DiskNextDAVProvider(DAVProvider):
"""DiskNext WebDAV 存储 Provider"""
def __init__(self) -> None:
super().__init__()
def get_resource_inst(
self,
path: str,
environ: dict[str, object],
) -> 'DiskNextCollection | DiskNextFile | None':
"""
将 WebDAV 路径映射到资源对象。
首次调用时加载 WebDAV 账户信息并缓存到 environ。
"""
user_id, webdav_id = _get_environ_info(environ)
# 首次请求时加载账户信息
if "disknext.webdav_account" not in environ:
account = _run_async(_get_webdav_account(webdav_id))
if not account:
return None
environ["disknext.webdav_account"] = account
account: WebDAV = environ["disknext.webdav_account"] # type: ignore[no-redef]
disknext_path = _resolve_dav_path(account.root, path)
obj = _run_async(_get_object_by_path(user_id, disknext_path))
if not obj:
return None
if obj.is_folder:
return DiskNextCollection(path, environ, obj, user_id, account)
else:
return DiskNextFile(path, environ, obj, user_id, account)
def is_readonly(self) -> bool:
"""只读由账户级别控制,不在 provider 级别限制"""
return False
# ==================== Collection目录 ====================
class DiskNextCollection(DAVCollection):
"""DiskNext 目录资源"""
def __init__(
self,
path: str,
environ: dict[str, object],
obj: Object,
user_id: UUID,
account: WebDAV,
) -> None:
super().__init__(path, environ)
self._obj = obj
self._user_id = user_id
self._account = account
def get_display_info(self) -> dict[str, str]:
return {"type": "Directory"}
def get_member_names(self) -> list[str]:
"""获取子对象名称列表"""
children = _run_async(_get_children(self._user_id, self._obj.id))
return [c.name for c in children]
def get_member(self, name: str) -> 'DiskNextCollection | DiskNextFile | None':
"""获取指定名称的子资源"""
member_path = self.path.rstrip("/") + "/" + name
account_root = self._account.root
disknext_path = _resolve_dav_path(account_root, member_path)
obj = _run_async(_get_object_by_path(self._user_id, disknext_path))
if not obj:
return None
if obj.is_folder:
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
else:
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
def get_creation_date(self) -> float | None:
if self._obj.created_at:
return self._obj.created_at.timestamp()
return None
def get_last_modified(self) -> float | None:
if self._obj.updated_at:
return self._obj.updated_at.timestamp()
return None
def create_empty_resource(self, name: str) -> 'DiskNextFile':
"""创建空文件PUT 操作的第一步)"""
_check_readonly(self.environ)
obj = _run_async(_create_file(
name=name,
parent_id=self._obj.id,
owner_id=self._user_id,
policy_id=self._obj.policy_id,
))
member_path = self.path.rstrip("/") + "/" + name
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
def create_collection(self, name: str) -> 'DiskNextCollection':
"""创建子目录MKCOL"""
_check_readonly(self.environ)
obj = _run_async(_create_folder(
name=name,
parent_id=self._obj.id,
owner_id=self._user_id,
policy_id=self._obj.policy_id,
))
member_path = self.path.rstrip("/") + "/" + name
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
def delete(self) -> None:
"""软删除目录"""
_check_readonly(self.environ)
_run_async(_soft_delete_object(self._obj.id))
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
"""复制或移动目录"""
_check_readonly(self.environ)
account_root = self._account.root
dest_disknext = _resolve_dav_path(account_root, dest_path)
# 解析目标父路径和新名称
if "/" in dest_disknext.rstrip("/"):
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
new_name = dest_disknext.rsplit("/", 1)[1]
else:
parent_path = "/"
new_name = dest_disknext.lstrip("/")
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
if not dest_parent:
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
if is_move:
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
else:
_run_async(_copy_object_recursive(
self._obj.id, dest_parent.id, new_name, self._user_id,
))
return True
def support_recursive_delete(self) -> bool:
return True
def support_recursive_move(self, dest_path: str) -> bool:
return True
# ==================== NonCollection文件 ====================
class DiskNextFile(DAVNonCollection):
"""DiskNext 文件资源"""
def __init__(
self,
path: str,
environ: dict[str, object],
obj: Object,
user_id: UUID,
account: WebDAV,
) -> None:
super().__init__(path, environ)
self._obj = obj
self._user_id = user_id
self._account = account
self._write_path: str | None = None
self._write_stream: io.BufferedWriter | QuotaLimitedWriter | None = None
def get_content_length(self) -> int | None:
return self._obj.size if self._obj.size else 0
def get_content_type(self) -> str | None:
# 尝试从文件名推断 MIME 类型
mime, _ = mimetypes.guess_type(self._obj.name)
return mime or "application/octet-stream"
def get_creation_date(self) -> float | None:
if self._obj.created_at:
return self._obj.created_at.timestamp()
return None
def get_last_modified(self) -> float | None:
if self._obj.updated_at:
return self._obj.updated_at.timestamp()
return None
def get_display_info(self) -> dict[str, str]:
return {"type": "File"}
def get_content(self) -> io.BufferedReader | None:
"""
返回文件内容的可读流。
WsgiDAV 在线程中运行,可安全使用同步 open()。
"""
obj_with_file = _run_async(_get_object_by_id(self._obj.id))
if not obj_with_file or not obj_with_file.physical_file:
return None
pf = obj_with_file.physical_file
policy = _run_async(_get_policy(obj_with_file.policy_id))
if not policy or not policy.server:
return None
full_path = Path(policy.server).resolve() / pf.storage_path
if not full_path.is_file():
l.warning(f"WebDAV: 物理文件不存在: {full_path}")
return None
return open(full_path, "rb") # noqa: SIM115
def begin_write(
self,
*,
content_type: str | None = None,
) -> io.BufferedWriter | QuotaLimitedWriter:
"""
开始写入文件PUT 操作)。
返回一个可写的文件流WsgiDAV 将向其中写入请求体数据。
当用户有配额限制时,返回 QuotaLimitedWriter 在写入过程中实时检查配额。
"""
_check_readonly(self.environ)
# 检查配额
remaining_quota: int = 0
user = _run_async(_get_user(self._user_id))
if user:
max_storage = user.group.max_storage
if max_storage > 0:
remaining_quota = max_storage - user.storage
if remaining_quota <= 0:
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
# Content-Length 预检(如果有的话)
content_length = self.environ.get("CONTENT_LENGTH")
if content_length and int(content_length) > remaining_quota:
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
# 获取策略以确定存储路径
policy = _run_async(_get_policy(self._obj.policy_id))
if not policy or not policy.server:
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
storage_service = LocalStorageService(policy)
dir_path, storage_name, full_path = _run_async(
storage_service.generate_file_path(
user_id=self._user_id,
original_filename=self._obj.name,
)
)
self._write_path = full_path
raw_stream = open(full_path, "wb") # noqa: SIM115
# 有配额限制时使用包装流,实时检查写入量
if remaining_quota > 0:
self._write_stream = QuotaLimitedWriter(raw_stream, remaining_quota)
else:
self._write_stream = raw_stream
return self._write_stream
def end_write(self, *, with_errors: bool) -> None:
"""写入完成后的收尾工作"""
if self._write_stream:
self._write_stream.close()
self._write_stream = None
if with_errors:
if self._write_path:
file_path = Path(self._write_path)
if file_path.exists():
file_path.unlink()
return
if not self._write_path:
return
# 获取文件大小
file_path = Path(self._write_path)
if not file_path.exists():
return
size = file_path.stat().st_size
# 更新数据库记录
_run_async(_finalize_upload(
object_id=self._obj.id,
physical_path=self._write_path,
size=size,
owner_id=self._user_id,
policy_id=self._obj.policy_id,
))
l.debug(f"WebDAV 文件写入完成: {self._obj.name}, size={size}")
def delete(self) -> None:
"""软删除文件"""
_check_readonly(self.environ)
_run_async(_soft_delete_object(self._obj.id))
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
"""复制或移动文件"""
_check_readonly(self.environ)
account_root = self._account.root
dest_disknext = _resolve_dav_path(account_root, dest_path)
# 解析目标父路径和新名称
if "/" in dest_disknext.rstrip("/"):
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
new_name = dest_disknext.rsplit("/", 1)[1]
else:
parent_path = "/"
new_name = dest_disknext.lstrip("/")
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
if not dest_parent:
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
if is_move:
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
else:
_run_async(_copy_object_recursive(
self._obj.id, dest_parent.id, new_name, self._user_id,
))
return True
def support_content_length(self) -> bool:
return True
def get_etag(self) -> str | None:
"""返回 ETag基于ID和更新时间WsgiDAV 会自动加双引号"""
if self._obj.updated_at:
return f"{self._obj.id}-{int(self._obj.updated_at.timestamp())}"
return None
def support_etag(self) -> bool:
return True
def support_ranges(self) -> bool:
return True

View File

@@ -1,128 +0,0 @@
"""
WebDAV 认证缓存
缓存 HTTP Basic Auth 的认证结果,避免每次请求都查库 + Argon2 验证。
支持 Redis首选和内存缓存降级两种存储后端。
"""
import hashlib
from typing import ClassVar
from uuid import UUID
from cachetools import TTLCache
from loguru import logger as l
from . import RedisManager
_AUTH_TTL: int = 300
"""认证缓存 TTL5 分钟"""
class WebDAVAuthCache:
"""
WebDAV 认证结果缓存
缓存键格式: webdav_auth:{email}/{account_name}:{sha256(password)}
缓存值格式: {user_id}:{webdav_id}
密码的 SHA256 作为缓存键的一部分,密码变更后旧缓存自然 miss。
"""
_memory_cache: ClassVar[TTLCache[str, str]] = TTLCache(maxsize=10000, ttl=_AUTH_TTL)
"""内存缓存降级方案"""
@classmethod
def _build_key(cls, email: str, account_name: str, password: str) -> str:
"""构建缓存键"""
pwd_hash = hashlib.sha256(password.encode()).hexdigest()[:16]
return f"webdav_auth:{email}/{account_name}:{pwd_hash}"
@classmethod
async def get(
cls,
email: str,
account_name: str,
password: str,
) -> tuple[UUID, int] | None:
"""
查询缓存中的认证结果。
:param email: 用户邮箱
:param account_name: WebDAV 账户名
:param password: 用户提供的明文密码
:return: (user_id, webdav_id) 或 None缓存未命中
"""
key = cls._build_key(email, account_name, password)
client = RedisManager.get_client()
if client is not None:
value = await client.get(key)
if value is not None:
raw = value.decode() if isinstance(value, bytes) else value
user_id_str, webdav_id_str = raw.split(":", 1)
return UUID(user_id_str), int(webdav_id_str)
else:
raw = cls._memory_cache.get(key)
if raw is not None:
user_id_str, webdav_id_str = raw.split(":", 1)
return UUID(user_id_str), int(webdav_id_str)
return None
@classmethod
async def set(
cls,
email: str,
account_name: str,
password: str,
user_id: UUID,
webdav_id: int,
) -> None:
"""
写入认证结果到缓存。
:param email: 用户邮箱
:param account_name: WebDAV 账户名
:param password: 用户提供的明文密码
:param user_id: 用户UUID
:param webdav_id: WebDAV 账户ID
"""
key = cls._build_key(email, account_name, password)
value = f"{user_id}:{webdav_id}"
client = RedisManager.get_client()
if client is not None:
await client.set(key, value, ex=_AUTH_TTL)
else:
cls._memory_cache[key] = value
@classmethod
async def invalidate_account(cls, user_id: UUID, account_name: str) -> None:
"""
失效指定账户的所有缓存。
由于缓存键包含 password hash无法精确删除
Redis 端使用 pattern scan 删除,内存端清空全部。
:param user_id: 用户UUID
:param account_name: WebDAV 账户名
"""
client = RedisManager.get_client()
if client is not None:
pattern = f"webdav_auth:*/{account_name}:*"
cursor: int = 0
while True:
cursor, keys = await client.scan(cursor, match=pattern, count=100)
if keys:
await client.delete(*keys)
if cursor == 0:
break
else:
# 内存缓存无法按 pattern 删除,清除所有含该账户名的条目
keys_to_delete = [
k for k in cls._memory_cache
if f"/{account_name}:" in k
]
for k in keys_to_delete:
cls._memory_cache.pop(k, None)
l.debug(f"已清除 WebDAV 认证缓存: user={user_id}, account={account_name}")

View File

@@ -3,7 +3,6 @@
提供文件存储相关的服务,包括:
- 本地存储服务
- S3 存储服务
- 命名规则解析器
- 存储异常定义
"""
@@ -12,8 +11,6 @@ from .exceptions import (
FileReadError,
FileWriteError,
InvalidPathError,
S3APIError,
S3MultipartUploadError,
StorageException,
StorageFileNotFoundError,
UploadSessionExpiredError,
@@ -28,6 +25,4 @@ from .object import (
permanently_delete_objects,
restore_objects,
soft_delete_objects,
)
from .migrate import migrate_file_with_task, migrate_directory_files
from .s3_storage import S3StorageService
)

View File

@@ -43,13 +43,3 @@ class UploadSessionExpiredError(StorageException):
class InvalidPathError(StorageException):
"""无效的路径"""
pass
class S3APIError(StorageException):
"""S3 API 请求错误"""
pass
class S3MultipartUploadError(S3APIError):
"""S3 分片上传错误"""
pass

View File

@@ -263,49 +263,15 @@ class LocalStorageService:
"""
删除文件(物理删除)
删除文件后会尝试清理因此变空的父目录。
:param path: 完整文件路径
"""
if await self.file_exists(path):
try:
await aiofiles.os.remove(path)
l.debug(f"已删除文件: {path}")
await self._cleanup_empty_parents(path)
except OSError as e:
l.warning(f"删除文件失败 {path}: {e}")
async def _cleanup_empty_parents(self, file_path: str) -> None:
"""
从被删文件的父目录开始,向上逐级删除空目录
在以下情况停止:
- 到达存储根目录_base_path
- 遇到非空目录
- 遇到 .trash 目录
- 删除失败(权限、并发等)
:param file_path: 被删文件的完整路径
"""
current = Path(file_path).parent
while current != self._base_path and str(current).startswith(str(self._base_path)):
if current.name == '.trash':
break
try:
entries = await aiofiles.os.listdir(str(current))
if entries:
break
await aiofiles.os.rmdir(str(current))
l.debug(f"已清理空目录: {current}")
current = current.parent
except OSError as e:
l.debug(f"清理空目录失败(忽略): {current}: {e}")
break
async def move_to_trash(
self,
source_path: str,
@@ -338,7 +304,6 @@ class LocalStorageService:
try:
await aiofiles.os.rename(source_path, str(trash_path))
l.info(f"文件已移动到回收站: {source_path} -> {trash_path}")
await self._cleanup_empty_parents(source_path)
return str(trash_path)
except OSError as e:
raise StorageException(f"移动文件到回收站失败: {e}")

View File

@@ -1,291 +0,0 @@
"""
存储策略迁移服务
提供跨存储策略的文件迁移功能:
- 单文件迁移:从源策略下载 → 上传到目标策略 → 更新数据库记录
- 目录批量迁移:递归遍历目录下所有文件逐个迁移,同时更新子目录的 policy_id
"""
from uuid import UUID
from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.object import Object, ObjectType
from sqlmodels.physical_file import PhysicalFile
from sqlmodels.policy import Policy, PolicyType
from sqlmodels.task import Task, TaskStatus
from .local_storage import LocalStorageService
from .s3_storage import S3StorageService
async def _get_storage_service(
policy: Policy,
) -> LocalStorageService | S3StorageService:
"""
根据策略类型创建对应的存储服务实例
:param policy: 存储策略
:return: 存储服务实例
"""
if policy.type == PolicyType.LOCAL:
return LocalStorageService(policy)
elif policy.type == PolicyType.S3:
return await S3StorageService.from_policy(policy)
else:
raise ValueError(f"不支持的存储策略类型: {policy.type}")
async def _read_file_from_storage(
service: LocalStorageService | S3StorageService,
storage_path: str,
) -> bytes:
"""
从存储服务读取文件内容
:param service: 存储服务实例
:param storage_path: 文件存储路径
:return: 文件二进制内容
"""
if isinstance(service, LocalStorageService):
return await service.read_file(storage_path)
else:
return await service.download_file(storage_path)
async def _write_file_to_storage(
service: LocalStorageService | S3StorageService,
storage_path: str,
data: bytes,
) -> None:
"""
将文件内容写入存储服务
:param service: 存储服务实例
:param storage_path: 文件存储路径
:param data: 文件二进制内容
"""
if isinstance(service, LocalStorageService):
await service.write_file(storage_path, data)
else:
await service.upload_file(storage_path, data)
async def _delete_file_from_storage(
service: LocalStorageService | S3StorageService,
storage_path: str,
) -> None:
"""
从存储服务删除文件
:param service: 存储服务实例
:param storage_path: 文件存储路径
"""
if isinstance(service, LocalStorageService):
await service.delete_file(storage_path)
else:
await service.delete_file(storage_path)
async def migrate_single_file(
session: AsyncSession,
obj: Object,
dest_policy: Policy,
) -> None:
"""
将单个文件对象从当前存储策略迁移到目标策略
流程:
1. 获取源物理文件和存储服务
2. 读取源文件内容
3. 在目标存储中生成新路径并写入
4. 创建新的 PhysicalFile 记录
5. 更新 Object 的 policy_id 和 physical_file_id
6. 旧 PhysicalFile 引用计数 -1如为 0 则删除源物理文件
:param session: 数据库会话
:param obj: 待迁移的文件对象(必须为文件类型)
:param dest_policy: 目标存储策略
"""
if obj.type != ObjectType.FILE:
raise ValueError(f"只能迁移文件对象,当前类型: {obj.type}")
# 获取源策略和物理文件
src_policy: Policy = await obj.awaitable_attrs.policy
old_physical: PhysicalFile | None = await obj.awaitable_attrs.physical_file
if not old_physical:
l.warning(f"文件 {obj.id} 没有关联物理文件,跳过迁移")
return
if src_policy.id == dest_policy.id:
l.debug(f"文件 {obj.id} 已在目标策略中,跳过")
return
# 1. 从源存储读取文件
src_service = await _get_storage_service(src_policy)
data = await _read_file_from_storage(src_service, old_physical.storage_path)
# 2. 在目标存储生成新路径并写入
dest_service = await _get_storage_service(dest_policy)
_dir_path, _storage_name, new_storage_path = await dest_service.generate_file_path(
user_id=obj.owner_id,
original_filename=obj.name,
)
await _write_file_to_storage(dest_service, new_storage_path, data)
# 3. 创建新的 PhysicalFile
new_physical = PhysicalFile(
storage_path=new_storage_path,
size=old_physical.size,
checksum_md5=old_physical.checksum_md5,
policy_id=dest_policy.id,
reference_count=1,
)
new_physical = await new_physical.save(session)
# 4. 更新 Object
obj.policy_id = dest_policy.id
obj.physical_file_id = new_physical.id
obj = await obj.save(session)
# 5. 旧 PhysicalFile 引用计数 -1
old_physical.decrement_reference()
if old_physical.can_be_deleted:
# 删除源存储中的物理文件
try:
await _delete_file_from_storage(src_service, old_physical.storage_path)
except Exception as e:
l.warning(f"删除源文件失败(不影响迁移结果): {old_physical.storage_path}: {e}")
await PhysicalFile.delete(session, old_physical)
else:
old_physical = await old_physical.save(session)
l.info(f"文件迁移完成: {obj.name} ({obj.id}), {src_policy.name}{dest_policy.name}")
async def migrate_file_with_task(
session: AsyncSession,
obj: Object,
dest_policy: Policy,
task: Task,
) -> None:
"""
迁移单个文件并更新任务状态
:param session: 数据库会话
:param obj: 待迁移的文件对象
:param dest_policy: 目标存储策略
:param task: 关联的任务记录
"""
try:
task.status = TaskStatus.RUNNING
task.progress = 0
task = await task.save(session)
await migrate_single_file(session, obj, dest_policy)
task.status = TaskStatus.COMPLETED
task.progress = 100
task = await task.save(session)
except Exception as e:
l.error(f"文件迁移任务失败: {obj.id}: {e}")
task.status = TaskStatus.ERROR
task.error = str(e)[:500]
task = await task.save(session)
async def migrate_directory_files(
session: AsyncSession,
folder: Object,
dest_policy: Policy,
task: Task,
) -> None:
"""
迁移目录下所有文件到目标存储策略
递归遍历目录树,将所有文件迁移到目标策略。
子目录的 policy_id 同步更新。
任务进度按文件数比例更新。
:param session: 数据库会话
:param folder: 目录对象
:param dest_policy: 目标存储策略
:param task: 关联的任务记录
"""
try:
task.status = TaskStatus.RUNNING
task.progress = 0
task = await task.save(session)
# 收集所有需要迁移的文件
files_to_migrate: list[Object] = []
folders_to_update: list[Object] = []
await _collect_objects_recursive(session, folder, files_to_migrate, folders_to_update)
total = len(files_to_migrate)
migrated = 0
errors: list[str] = []
for file_obj in files_to_migrate:
try:
await migrate_single_file(session, file_obj, dest_policy)
migrated += 1
except Exception as e:
error_msg = f"{file_obj.name}: {e}"
l.error(f"迁移文件失败: {error_msg}")
errors.append(error_msg)
# 更新进度
if total > 0:
task.progress = min(99, int(migrated / total * 100))
task = await task.save(session)
# 更新所有子目录的 policy_id
for sub_folder in folders_to_update:
sub_folder.policy_id = dest_policy.id
sub_folder = await sub_folder.save(session)
# 完成任务
if errors:
task.status = TaskStatus.ERROR
task.error = f"部分文件迁移失败 ({len(errors)}/{total}): " + "; ".join(errors[:5])
else:
task.status = TaskStatus.COMPLETED
task.progress = 100
task = await task.save(session)
l.info(
f"目录迁移完成: {folder.name} ({folder.id}), "
f"成功 {migrated}/{total}, 错误 {len(errors)}"
)
except Exception as e:
l.error(f"目录迁移任务失败: {folder.id}: {e}")
task.status = TaskStatus.ERROR
task.error = str(e)[:500]
task = await task.save(session)
async def _collect_objects_recursive(
session: AsyncSession,
folder: Object,
files: list[Object],
folders: list[Object],
) -> None:
"""
递归收集目录下所有文件和子目录
:param session: 数据库会话
:param folder: 当前目录
:param files: 文件列表(输出)
:param folders: 子目录列表(输出)
"""
children: list[Object] = await Object.get_children(session, folder.owner_id, folder.id)
for child in children:
if child.type == ObjectType.FILE:
files.append(child)
elif child.type == ObjectType.FOLDER:
folders.append(child)
await _collect_objects_recursive(session, child, files, folders)

View File

@@ -6,8 +6,7 @@ from sqlalchemy import update as sql_update
from sqlalchemy.sql.functions import func
from middleware.dependencies import SessionDep
from .local_storage import LocalStorageService
from .s3_storage import S3StorageService
from service.storage import LocalStorageService
from sqlmodels import (
Object,
PhysicalFile,
@@ -272,14 +271,10 @@ async def permanently_delete_objects(
if physical_file.can_be_deleted:
# 物理删除文件
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
if policy:
if policy and policy.type == PolicyType.LOCAL:
try:
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
await s3_service.delete_file(physical_file.storage_path)
storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
l.debug(f"物理文件已删除: {obj_name}")
except Exception as e:
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
@@ -287,7 +282,7 @@ async def permanently_delete_objects(
await PhysicalFile.delete(session, physical_file, commit=False)
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
else:
physical_file = await physical_file.save(session, commit=False)
await physical_file.save(session, commit=False)
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
# 更新用户存储配额
@@ -379,19 +374,10 @@ async def delete_object_recursive(
if physical_file.can_be_deleted:
# 物理删除文件
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
if policy:
if policy and policy.type == PolicyType.LOCAL:
try:
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
elif policy.type == PolicyType.S3:
options = await policy.awaitable_attrs.options
s3_service = S3StorageService(
policy,
region=options.s3_region if options else 'us-east-1',
is_path_style=options.s3_path_style if options else False,
)
await s3_service.delete_file(physical_file.storage_path)
storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
l.debug(f"物理文件已删除: {obj_name}")
except Exception as e:
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
@@ -399,7 +385,7 @@ async def delete_object_recursive(
await PhysicalFile.delete(session, physical_file, commit=False)
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
else:
physical_file = await physical_file.save(session, commit=False)
await physical_file.save(session, commit=False)
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
# 阶段三:更新用户存储配额(与删除在同一事务中)
@@ -458,7 +444,7 @@ async def _copy_object_recursive(
physical_file = await PhysicalFile.get(session, PhysicalFile.id == src_physical_file_id)
if physical_file:
physical_file.increment_reference()
physical_file = await physical_file.save(session)
await physical_file.save(session)
total_copied_size += src_size
new_obj = await new_obj.save(session)

View File

@@ -1,709 +0,0 @@
"""
S3 存储服务
使用 AWS Signature V4 签名的异步 S3 API 客户端。
从 Policy 配置中读取 S3 连接信息,提供文件上传/下载/删除及分片上传功能。
移植自 foxline-pro-backend-server 项目的 S3APIClient
适配 DiskNext 现有的 Service 架构(与 LocalStorageService 平行)。
"""
import hashlib
import hmac
import xml.etree.ElementTree as ET
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from typing import ClassVar, Literal
from urllib.parse import quote, urlencode
from uuid import UUID
import aiohttp
from yarl import URL
from loguru import logger as l
from sqlmodels.policy import Policy
from .exceptions import S3APIError, S3MultipartUploadError
from .naming_rule import NamingContext, NamingRuleParser
def _sign(key: bytes, msg: str) -> bytes:
"""HMAC-SHA256 签名"""
return hmac.new(key, msg.encode(), hashlib.sha256).digest()
_NS_AWS = "http://s3.amazonaws.com/doc/2006-03-01/"
class S3StorageService:
"""
S3 存储服务
使用 AWS Signature V4 签名的异步 S3 API 客户端。
从 Policy 配置中读取 S3 连接信息。
使用示例::
service = S3StorageService(policy, region='us-east-1')
await service.upload_file('path/to/file.txt', b'content')
data = await service.download_file('path/to/file.txt')
"""
_http_session: ClassVar[aiohttp.ClientSession | None] = None
def __init__(
self,
policy: Policy,
region: str = 'us-east-1',
is_path_style: bool = False,
):
"""
:param policy: 存储策略server=endpoint_url, bucket_name, access_key, secret_key
:param region: S3 区域
:param is_path_style: 是否使用路径风格 URL
"""
if not policy.server:
raise S3APIError("S3 策略必须指定 server (endpoint URL)")
if not policy.bucket_name:
raise S3APIError("S3 策略必须指定 bucket_name")
if not policy.access_key:
raise S3APIError("S3 策略必须指定 access_key")
if not policy.secret_key:
raise S3APIError("S3 策略必须指定 secret_key")
self._policy = policy
self._endpoint_url = policy.server.rstrip("/")
self._bucket_name = policy.bucket_name
self._access_key = policy.access_key
self._secret_key = policy.secret_key
self._region = region
self._is_path_style = is_path_style
self._base_url = policy.base_url
# 从 endpoint_url 提取 host
self._host = self._endpoint_url.replace("https://", "").replace("http://", "").split("/")[0]
# ==================== 工厂方法 ====================
@classmethod
async def from_policy(cls, policy: Policy) -> 'S3StorageService':
"""
根据 Policy 异步创建 S3StorageService自动加载 options
:param policy: 存储策略
:return: S3StorageService 实例
"""
options = await policy.awaitable_attrs.options
region = options.s3_region if options else 'us-east-1'
is_path_style = options.s3_path_style if options else False
return cls(policy, region=region, is_path_style=is_path_style)
# ==================== HTTP Session 管理 ====================
@classmethod
async def initialize_session(cls) -> None:
"""初始化全局 aiohttp ClientSession"""
if cls._http_session is None or cls._http_session.closed:
cls._http_session = aiohttp.ClientSession()
l.info("S3StorageService HTTP session 已初始化")
@classmethod
async def close_session(cls) -> None:
"""关闭全局 aiohttp ClientSession"""
if cls._http_session and not cls._http_session.closed:
await cls._http_session.close()
cls._http_session = None
l.info("S3StorageService HTTP session 已关闭")
@classmethod
def _get_session(cls) -> aiohttp.ClientSession:
"""获取 HTTP session"""
if cls._http_session is None or cls._http_session.closed:
# 懒初始化,以防 initialize_session 未被调用
cls._http_session = aiohttp.ClientSession()
return cls._http_session
# ==================== AWS Signature V4 签名 ====================
def _get_signature_key(self, date_stamp: str) -> bytes:
"""生成 AWS Signature V4 签名密钥"""
k_date = _sign(f"AWS4{self._secret_key}".encode(), date_stamp)
k_region = _sign(k_date, self._region)
k_service = _sign(k_region, "s3")
return _sign(k_service, "aws4_request")
def _create_authorization_header(
self,
method: str,
uri: str,
query_string: str,
headers: dict[str, str],
payload_hash: str,
amz_date: str,
date_stamp: str,
) -> str:
"""创建 AWS Signature V4 授权头"""
signed_headers = ";".join(sorted(k.lower() for k in headers.keys()))
canonical_headers = "".join(
f"{k.lower()}:{v.strip()}\n" for k, v in sorted(headers.items())
)
canonical_request = (
f"{method}\n{uri}\n{query_string}\n{canonical_headers}\n"
f"{signed_headers}\n{payload_hash}"
)
algorithm = "AWS4-HMAC-SHA256"
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
string_to_sign = (
f"{algorithm}\n{amz_date}\n{credential_scope}\n"
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
)
signing_key = self._get_signature_key(date_stamp)
signature = hmac.new(
signing_key, string_to_sign.encode(), hashlib.sha256
).hexdigest()
return (
f"{algorithm} Credential={self._access_key}/{credential_scope}, "
f"SignedHeaders={signed_headers}, Signature={signature}"
)
def _build_headers(
self,
method: str,
uri: str,
query_string: str = "",
payload: bytes = b"",
content_type: str | None = None,
extra_headers: dict[str, str] | None = None,
payload_hash: str | None = None,
host: str | None = None,
) -> dict[str, str]:
"""
构建包含 AWS V4 签名的完整请求头
:param method: HTTP 方法
:param uri: 请求 URI
:param query_string: 查询字符串
:param payload: 请求体字节(用于计算哈希)
:param content_type: Content-Type
:param extra_headers: 额外请求头
:param payload_hash: 预计算的 payload 哈希,流式上传时传 "UNSIGNED-PAYLOAD"
:param host: Host 头(默认使用 self._host
"""
now_utc = datetime.now(timezone.utc)
amz_date = now_utc.strftime("%Y%m%dT%H%M%SZ")
date_stamp = now_utc.strftime("%Y%m%d")
if payload_hash is None:
payload_hash = hashlib.sha256(payload).hexdigest()
effective_host = host or self._host
headers: dict[str, str] = {
"Host": effective_host,
"X-Amz-Date": amz_date,
"X-Amz-Content-Sha256": payload_hash,
}
if content_type:
headers["Content-Type"] = content_type
if extra_headers:
headers.update(extra_headers)
authorization = self._create_authorization_header(
method, uri, query_string, headers, payload_hash, amz_date, date_stamp
)
headers["Authorization"] = authorization
return headers
# ==================== 内部请求方法 ====================
def _build_uri(self, key: str | None = None) -> str:
"""
构建请求 URI
按 AWS S3 Signature V4 规范对路径进行 URI 编码S3 仅需一次)。
斜杠作为路径分隔符保留不编码。
"""
if self._is_path_style:
if key:
return f"/{self._bucket_name}/{quote(key, safe='/')}"
return f"/{self._bucket_name}"
else:
if key:
return f"/{quote(key, safe='/')}"
return "/"
def _build_url(self, uri: str, query_string: str = "") -> str:
"""构建完整请求 URL"""
if self._is_path_style:
base = self._endpoint_url
else:
# 虚拟主机风格bucket.endpoint
protocol = "https://" if self._endpoint_url.startswith("https://") else "http://"
base = f"{protocol}{self._bucket_name}.{self._host}"
url = f"{base}{uri}"
if query_string:
url = f"{url}?{query_string}"
return url
def _get_effective_host(self) -> str:
"""获取实际请求的 Host 头"""
if self._is_path_style:
return self._host
return f"{self._bucket_name}.{self._host}"
async def _request(
self,
method: str,
key: str | None = None,
query_params: dict[str, str] | None = None,
payload: bytes = b"",
content_type: str | None = None,
extra_headers: dict[str, str] | None = None,
) -> aiohttp.ClientResponse:
"""发送签名请求"""
uri = self._build_uri(key)
query_string = urlencode(sorted(query_params.items())) if query_params else ""
effective_host = self._get_effective_host()
headers = self._build_headers(
method, uri, query_string, payload, content_type,
extra_headers, host=effective_host,
)
url = self._build_url(uri, query_string)
try:
response = await self._get_session().request(
method, URL(url, encoded=True),
headers=headers, data=payload if payload else None,
)
return response
except Exception as e:
raise S3APIError(f"S3 请求失败: {method} {url}: {e}") from e
async def _request_streaming(
self,
method: str,
key: str,
data_stream: AsyncIterator[bytes],
content_length: int,
content_type: str | None = None,
) -> aiohttp.ClientResponse:
"""
发送流式签名请求(大文件上传)
使用 UNSIGNED-PAYLOAD 作为 payload hash。
"""
uri = self._build_uri(key)
effective_host = self._get_effective_host()
headers = self._build_headers(
method,
uri,
query_string="",
content_type=content_type,
extra_headers={"Content-Length": str(content_length)},
payload_hash="UNSIGNED-PAYLOAD",
host=effective_host,
)
url = self._build_url(uri)
try:
response = await self._get_session().request(
method, URL(url, encoded=True),
headers=headers, data=data_stream,
)
return response
except Exception as e:
raise S3APIError(f"S3 流式请求失败: {method} {url}: {e}") from e
# ==================== 文件操作 ====================
async def upload_file(
self,
key: str,
data: bytes,
content_type: str = 'application/octet-stream',
) -> None:
"""
上传文件
:param key: S3 对象键
:param data: 文件内容
:param content_type: MIME 类型
"""
async with await self._request(
"PUT", key=key, payload=data, content_type=content_type,
) as response:
if response.status not in (200, 201):
body = await response.text()
raise S3APIError(
f"上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
l.debug(f"S3 上传成功: {self._bucket_name}/{key}")
async def upload_file_streaming(
self,
key: str,
data_stream: AsyncIterator[bytes],
content_length: int,
content_type: str | None = None,
) -> None:
"""
流式上传文件(大文件,避免全部加载到内存)
:param key: S3 对象键
:param data_stream: 异步字节流迭代器
:param content_length: 数据总长度(必须准确)
:param content_type: MIME 类型
"""
async with await self._request_streaming(
"PUT", key=key, data_stream=data_stream,
content_length=content_length, content_type=content_type,
) as response:
if response.status not in (200, 201):
body = await response.text()
raise S3APIError(
f"流式上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
l.debug(f"S3 流式上传成功: {self._bucket_name}/{key}, 大小: {content_length}")
async def download_file(self, key: str) -> bytes:
"""
下载文件
:param key: S3 对象键
:return: 文件内容
"""
async with await self._request("GET", key=key) as response:
if response.status != 200:
body = await response.text()
raise S3APIError(
f"下载失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
data = await response.read()
l.debug(f"S3 下载成功: {self._bucket_name}/{key}, 大小: {len(data)}")
return data
async def delete_file(self, key: str) -> None:
"""
删除文件
:param key: S3 对象键
"""
async with await self._request("DELETE", key=key) as response:
if response.status in (200, 204):
l.debug(f"S3 删除成功: {self._bucket_name}/{key}")
else:
body = await response.text()
raise S3APIError(
f"删除失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
async def file_exists(self, key: str) -> bool:
"""
检查文件是否存在
:param key: S3 对象键
:return: 是否存在
"""
async with await self._request("HEAD", key=key) as response:
if response.status == 200:
return True
elif response.status == 404:
return False
else:
raise S3APIError(
f"检查文件存在性失败: {self._bucket_name}/{key}, 状态: {response.status}"
)
async def get_file_size(self, key: str) -> int:
"""
获取文件大小
:param key: S3 对象键
:return: 文件大小(字节)
"""
async with await self._request("HEAD", key=key) as response:
if response.status != 200:
raise S3APIError(
f"获取文件信息失败: {self._bucket_name}/{key}, 状态: {response.status}"
)
return int(response.headers.get("Content-Length", 0))
# ==================== Multipart Upload ====================
async def create_multipart_upload(
self,
key: str,
content_type: str = 'application/octet-stream',
) -> str:
"""
创建分片上传任务
:param key: S3 对象键
:param content_type: MIME 类型
:return: Upload ID
"""
async with await self._request(
"POST",
key=key,
query_params={"uploads": ""},
content_type=content_type,
) as response:
if response.status != 200:
body = await response.text()
raise S3MultipartUploadError(
f"创建分片上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
body = await response.text()
root = ET.fromstring(body)
# 查找 UploadId 元素(支持命名空间)
upload_id_elem = root.find("UploadId")
if upload_id_elem is None:
upload_id_elem = root.find(f"{{{_NS_AWS}}}UploadId")
if upload_id_elem is None or not upload_id_elem.text:
raise S3MultipartUploadError(
f"创建分片上传响应中未找到 UploadId: {body}"
)
upload_id = upload_id_elem.text
l.debug(f"S3 分片上传已创建: {self._bucket_name}/{key}, upload_id={upload_id}")
return upload_id
async def upload_part(
self,
key: str,
upload_id: str,
part_number: int,
data: bytes,
) -> str:
"""
上传单个分片
:param key: S3 对象键
:param upload_id: 分片上传 ID
:param part_number: 分片编号(从 1 开始)
:param data: 分片数据
:return: ETag
"""
async with await self._request(
"PUT",
key=key,
query_params={
"partNumber": str(part_number),
"uploadId": upload_id,
},
payload=data,
) as response:
if response.status != 200:
body = await response.text()
raise S3MultipartUploadError(
f"上传分片失败: {self._bucket_name}/{key}, "
f"part={part_number}, 状态: {response.status}, {body}"
)
etag = response.headers.get("ETag", "").strip('"')
l.debug(
f"S3 分片上传成功: {self._bucket_name}/{key}, "
f"part={part_number}, etag={etag}"
)
return etag
async def complete_multipart_upload(
self,
key: str,
upload_id: str,
parts: list[tuple[int, str]],
) -> None:
"""
完成分片上传
:param key: S3 对象键
:param upload_id: 分片上传 ID
:param parts: 分片列表 [(part_number, etag)]
"""
# 按 part_number 排序
parts_sorted = sorted(parts, key=lambda p: p[0])
# 构建 CompleteMultipartUpload XML
xml_parts = ''.join(
f"<Part><PartNumber>{pn}</PartNumber><ETag>{etag}</ETag></Part>"
for pn, etag in parts_sorted
)
payload = f'<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUpload>{xml_parts}</CompleteMultipartUpload>'
payload_bytes = payload.encode('utf-8')
async with await self._request(
"POST",
key=key,
query_params={"uploadId": upload_id},
payload=payload_bytes,
content_type="application/xml",
) as response:
if response.status != 200:
body = await response.text()
raise S3MultipartUploadError(
f"完成分片上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
l.info(
f"S3 分片上传已完成: {self._bucket_name}/{key}, "
f"{len(parts)} 个分片"
)
async def abort_multipart_upload(self, key: str, upload_id: str) -> None:
"""
取消分片上传
:param key: S3 对象键
:param upload_id: 分片上传 ID
"""
async with await self._request(
"DELETE",
key=key,
query_params={"uploadId": upload_id},
) as response:
if response.status in (200, 204):
l.debug(f"S3 分片上传已取消: {self._bucket_name}/{key}")
else:
body = await response.text()
l.warning(
f"取消分片上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
# ==================== 预签名 URL ====================
def generate_presigned_url(
self,
key: str,
method: Literal['GET', 'PUT'] = 'GET',
expires_in: int = 3600,
filename: str | None = None,
) -> str:
"""
生成 S3 预签名 URLAWS Signature V4 Query String
:param key: S3 对象键
:param method: HTTP 方法GET 下载PUT 上传)
:param expires_in: URL 有效期(秒)
:param filename: 文件名GET 请求时设置 Content-Disposition
:return: 预签名 URL
"""
current_time = datetime.now(timezone.utc)
amz_date = current_time.strftime("%Y%m%dT%H%M%SZ")
date_stamp = current_time.strftime("%Y%m%d")
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
credential = f"{self._access_key}/{credential_scope}"
uri = self._build_uri(key)
effective_host = self._get_effective_host()
query_params: dict[str, str] = {
'X-Amz-Algorithm': 'AWS4-HMAC-SHA256',
'X-Amz-Credential': credential,
'X-Amz-Date': amz_date,
'X-Amz-Expires': str(expires_in),
'X-Amz-SignedHeaders': 'host',
}
# GET 请求时添加 Content-Disposition
if method == "GET" and filename:
encoded_filename = quote(filename, safe='')
query_params['response-content-disposition'] = (
f"attachment; filename*=UTF-8''{encoded_filename}"
)
canonical_query_string = "&".join(
f"{quote(k, safe='')}={quote(v, safe='')}"
for k, v in sorted(query_params.items())
)
canonical_headers = f"host:{effective_host}\n"
signed_headers = "host"
payload_hash = "UNSIGNED-PAYLOAD"
canonical_request = (
f"{method}\n"
f"{uri}\n"
f"{canonical_query_string}\n"
f"{canonical_headers}\n"
f"{signed_headers}\n"
f"{payload_hash}"
)
algorithm = "AWS4-HMAC-SHA256"
string_to_sign = (
f"{algorithm}\n"
f"{amz_date}\n"
f"{credential_scope}\n"
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
)
signing_key = self._get_signature_key(date_stamp)
signature = hmac.new(
signing_key, string_to_sign.encode(), hashlib.sha256
).hexdigest()
base_url = self._build_url(uri)
return (
f"{base_url}?"
f"{canonical_query_string}&"
f"X-Amz-Signature={signature}"
)
# ==================== 路径生成 ====================
async def generate_file_path(
self,
user_id: UUID,
original_filename: str,
) -> tuple[str, str, str]:
"""
根据命名规则生成 S3 文件存储路径
与 LocalStorageService.generate_file_path 接口一致。
:param user_id: 用户UUID
:param original_filename: 原始文件名
:return: (相对目录路径, 存储文件名, 完整存储路径)
"""
context = NamingContext(
user_id=user_id,
original_filename=original_filename,
)
# 解析目录规则
dir_path = ""
if self._policy.dir_name_rule:
dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context)
# 解析文件名规则
if self._policy.auto_rename and self._policy.file_name_rule:
storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context)
# 确保有扩展名
if '.' in original_filename and '.' not in storage_name:
ext = original_filename.rsplit('.', 1)[1]
storage_name = f"{storage_name}.{ext}"
else:
storage_name = original_filename
# S3 不需要创建目录,直接拼接路径
if dir_path:
storage_path = f"{dir_path}/{storage_name}"
else:
storage_path = storage_name
return dir_path, storage_name, storage_path

View File

@@ -3,14 +3,12 @@
支持多种认证方式邮箱密码、GitHub OAuth、QQ OAuth、Passkey、Magic Link、手机短信预留
"""
import hashlib
from uuid import UUID, uuid4
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession
from service.redis.token_store import TokenStore
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.group import GroupClaims, GroupOptions
from sqlmodels.object import Object, ObjectType
@@ -192,7 +190,7 @@ async def _login_oauth(
# 已绑定 → 更新 OAuth 信息并返回关联用户
identity.display_name = nickname
identity.avatar_url = avatar_url
identity = await identity.save(session)
await identity.save(session)
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
if not user:
@@ -254,7 +252,7 @@ async def _auto_register_oauth_user(
is_verified=True,
user_id=new_user_id,
)
identity = await identity.save(session)
await identity.save(session)
# 创建用户根目录
default_policy = await Policy.get(session, Policy.name == "本地存储")
@@ -335,7 +333,7 @@ async def _login_passkey(
# 更新签名计数
authn.sign_count = verification.new_sign_count
authn = await authn.save(session)
await authn.save(session)
# 加载用户
user: User = await User.get(session, User.id == authn.user_id, load=User.group)
@@ -365,12 +363,6 @@ async def _login_magic_link(
except BadSignature:
http_exceptions.raise_unauthorized("Magic Link 无效")
# 防重放:使用 token 哈希作为标识符
token_hash = hashlib.sha256(request.identifier.encode()).hexdigest()
is_first_use = await TokenStore.mark_used(f"magic_link:{token_hash}", ttl=600)
if not is_first_use:
http_exceptions.raise_unauthorized("Magic Link 已被使用")
# 查找绑定了该邮箱的 AuthIdentityemail_password 或 magic_link
identity: AuthIdentity | None = await AuthIdentity.get(
session,
@@ -392,7 +384,7 @@ async def _login_magic_link(
# 标记邮箱已验证
if not identity.is_verified:
identity.is_verified = True
identity = await identity.save(session)
await identity.save(session)
return user

View File

@@ -1,185 +0,0 @@
"""
WOPI Discovery 服务模块
解析 WOPI 服务端Collabora / OnlyOffice 等)的 Discovery XML
提取支持的文件扩展名及对应的编辑器 URL 模板。
参考Cloudreve pkg/wopi/discovery.go 和 pkg/wopi/wopi.go
"""
import xml.etree.ElementTree as ET
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from loguru import logger as l
# WOPI URL 模板中已知的查询参数占位符及其替换值
# 值为 None 表示删除该参数,非 None 表示替换为该值
# 参考 Cloudreve pkg/wopi/wopi.go queryPlaceholders
_WOPI_QUERY_PLACEHOLDERS: dict[str, str | None] = {
'BUSINESS_USER': None,
'DC_LLCC': 'lng',
'DISABLE_ASYNC': None,
'DISABLE_CHAT': None,
'EMBEDDED': 'true',
'FULLSCREEN': 'true',
'HOST_SESSION_ID': None,
'SESSION_CONTEXT': None,
'RECORDING': None,
'THEME_ID': 'darkmode',
'UI_LLCC': 'lng',
'VALIDATOR_TEST_CATEGORY': None,
}
_WOPI_SRC_PLACEHOLDER = 'WOPI_SOURCE'
def process_wopi_action_url(raw_urlsrc: str) -> str:
"""
将 WOPI Discovery 中的原始 urlsrc 转换为 DiskNext 可用的 URL 模板。
处理流程(参考 Cloudreve generateActionUrl
1. 去除 ``<>`` 占位符标记
2. 解析查询参数,替换/删除已知占位符
3. ``WOPI_SOURCE`` → ``{wopi_src}``
注意access_token 和 access_token_ttl 不放在 URL 中,
根据 WOPI 规范它们通过 POST 表单字段传递给编辑器。
:param raw_urlsrc: WOPI Discovery XML 中的 urlsrc 原始值
:return: 处理后的 URL 模板字符串,包含 {wopi_src} 占位符
"""
# 去除 <> 标记
cleaned = raw_urlsrc.replace('<', '').replace('>', '')
parsed = urlparse(cleaned)
raw_params = parse_qs(parsed.query, keep_blank_values=True)
new_params: list[tuple[str, str]] = []
is_src_replaced = False
for key, values in raw_params.items():
value = values[0] if values else ''
# WOPI_SOURCE 占位符 → {wopi_src}
if value == _WOPI_SRC_PLACEHOLDER:
new_params.append((key, '{wopi_src}'))
is_src_replaced = True
continue
# 已知占位符
if value in _WOPI_QUERY_PLACEHOLDERS:
replacement = _WOPI_QUERY_PLACEHOLDERS[value]
if replacement is not None:
new_params.append((key, replacement))
# replacement 为 None 时删除该参数
continue
# 其他参数保留原值
new_params.append((key, value))
# 如果没有找到 WOPI_SOURCE 占位符,手动添加 WOPISrc
if not is_src_replaced:
new_params.append(('WOPISrc', '{wopi_src}'))
# LibreOffice/Collabora 需要 lang 参数(避免重复添加)
existing_keys = {k for k, _ in new_params}
if 'lang' not in existing_keys:
new_params.append(('lang', 'lng'))
# 注意access_token 和 access_token_ttl 不放在 URL 中
# 根据 WOPI 规范,它们通过 POST 表单字段传递给编辑器
# 重建 URL
new_query = urlencode(new_params, safe='{}')
result = urlunparse((
parsed.scheme,
parsed.netloc,
parsed.path,
parsed.params,
new_query,
'',
))
return result
def parse_wopi_discovery_xml(xml_content: str) -> tuple[dict[str, str], list[str]]:
"""
解析 WOPI Discovery XML提取扩展名到 URL 模板的映射。
XML 结构::
<wopi-discovery>
<net-zone name="external-https">
<app name="Writer" favIconUrl="...">
<action name="edit" ext="docx" urlsrc="https://..."/>
<action name="view" ext="docx" urlsrc="https://..."/>
</app>
</net-zone>
</wopi-discovery>
动作优先级edit > embedview > view参考 Cloudreve discovery.go
:param xml_content: WOPI Discovery 端点返回的 XML 字符串
:return: (action_urls, app_names) 元组
action_urls: {extension: processed_url_template}
app_names: 发现的应用名称列表
:raises ValueError: XML 解析失败或格式无效
"""
try:
root = ET.fromstring(xml_content)
except ET.ParseError as e:
raise ValueError(f"WOPI Discovery XML 解析失败: {e}")
# 查找 net-zone可能有多个取第一个非空的
net_zones = root.findall('net-zone')
if not net_zones:
raise ValueError("WOPI Discovery XML 缺少 net-zone 节点")
# ext_actions: {extension: {action_name: urlsrc}}
ext_actions: dict[str, dict[str, str]] = {}
app_names: list[str] = []
for net_zone in net_zones:
for app_elem in net_zone.findall('app'):
app_name = app_elem.get('name', '')
if app_name:
app_names.append(app_name)
for action_elem in app_elem.findall('action'):
action_name = action_elem.get('name', '')
ext = action_elem.get('ext', '')
urlsrc = action_elem.get('urlsrc', '')
if not ext or not urlsrc:
continue
# 只关注 edit / embedview / view 三种动作
if action_name not in ('edit', 'embedview', 'view'):
continue
if ext not in ext_actions:
ext_actions[ext] = {}
ext_actions[ext][action_name] = urlsrc
# 为每个扩展名选择最佳 URL: edit > embedview > view
action_urls: dict[str, str] = {}
for ext, actions_map in ext_actions.items():
selected_urlsrc: str | None = None
for preferred in ('edit', 'embedview', 'view'):
if preferred in actions_map:
selected_urlsrc = actions_map[preferred]
break
if selected_urlsrc:
action_urls[ext] = process_wopi_action_url(selected_urlsrc)
# 去重 app_names
seen: set[str] = set()
unique_names: list[str] = []
for name in app_names:
if name not in seen:
seen.add(name)
unique_names.append(name)
l.info(f"WOPI Discovery 解析完成: {len(action_urls)} 个扩展名, 应用: {unique_names}")
return action_urls, unique_names

View File

@@ -84,7 +84,6 @@ if __name__ == "__main__":
setup(
name="disknext-ee",
packages=[],
ext_modules=cythonize(
extensions,
compiler_directives={'language_level': "3"},

View File

@@ -954,11 +954,18 @@ class PolicyType(StrEnum):
S3 = "s3" # S3 兼容存储
```
### PolicyType
### StorageType
```python
class PolicyType(StrEnum):
class StorageType(StrEnum):
LOCAL = "local" # 本地存储
S3 = "s3" # S3 兼容存储
QINIU = "qiniu" # 七牛云
TENCENT = "tencent" # 腾讯云
ALIYUN = "aliyun" # 阿里云
ONEDRIVE = "onedrive" # OneDrive
GOOGLE_DRIVE = "google_drive" # Google Drive
DROPBOX = "dropbox" # Dropbox
WEBDAV = "webdav" # WebDAV
REMOTE = "remote" # 远程存储
```
### UserStatus

View File

@@ -69,20 +69,18 @@ from .object import (
CreateUploadSessionRequest,
DirectoryCreateRequest,
DirectoryResponse,
FileMetadata,
FileMetadataBase,
Object,
ObjectBase,
ObjectCopyRequest,
ObjectDeleteRequest,
ObjectFileFinalize,
ObjectMoveRequest,
ObjectMoveUpdate,
ObjectPropertyDetailResponse,
ObjectPropertyResponse,
ObjectRenameRequest,
ObjectResponse,
ObjectSwitchPolicyRequest,
ObjectType,
FileCategory,
PolicyResponse,
UploadChunkResponse,
UploadSession,
@@ -97,42 +95,11 @@ from .object import (
TrashRestoreRequest,
TrashDeleteRequest,
)
from .object_metadata import (
ObjectMetadata,
ObjectMetadataBase,
MetadataNamespace,
MetadataResponse,
MetadataPatchItem,
MetadataPatchRequest,
INTERNAL_NAMESPACES,
USER_WRITABLE_NAMESPACES,
)
from .custom_property import (
CustomPropertyDefinition,
CustomPropertyDefinitionBase,
CustomPropertyType,
CustomPropertyCreateRequest,
CustomPropertyUpdateRequest,
CustomPropertyResponse,
)
from .physical_file import PhysicalFile, PhysicalFileBase
from .uri import DiskNextURI, FileSystemNamespace
from .order import (
Order, OrderStatus, OrderType,
CreateOrderRequest, OrderResponse,
)
from .policy import (
Policy, PolicyBase, PolicyCreateRequest, PolicyOptions, PolicyOptionsBase,
PolicyType, PolicySummary, PolicyUpdateRequest,
)
from .product import (
Product, ProductBase, ProductType, PaymentMethod,
ProductCreateRequest, ProductUpdateRequest, ProductResponse,
)
from .redeem import (
Redeem, RedeemType,
RedeemCreateRequest, RedeemUseRequest, RedeemInfoResponse, RedeemAdminResponse,
)
from .order import Order, OrderStatus, OrderType
from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary
from .redeem import Redeem, RedeemType
from .report import Report, ReportReason
from .setting import (
Setting, SettingsType, SiteConfigResponse, AuthMethodConfig,
@@ -145,20 +112,16 @@ from .share import (
AdminShareListItem,
)
from .source_link import SourceLink
from .storage_pack import StoragePack, StoragePackResponse
from .storage_pack import StoragePack
from .tag import Tag, TagType
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary, TaskSummaryBase
from .webdav import (
WebDAV, WebDAVBase,
WebDAVCreateRequest, WebDAVUpdateRequest, WebDAVAccountResponse,
)
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
from .webdav import WebDAV
from .file_app import (
FileApp, FileAppType, FileAppExtension, FileAppGroupLink, UserFileAppDefault,
# DTO
FileAppSummary, FileViewersResponse, SetDefaultViewerRequest, UserFileAppDefaultResponse,
FileAppCreateRequest, FileAppUpdateRequest, FileAppResponse, FileAppListResponse,
ExtensionUpdateRequest, GroupAccessUpdateRequest, WopiSessionResponse,
WopiDiscoveredExtension, WopiDiscoveryResponse,
)
from .wopi import WopiFileInfo, WopiAccessTokenPayload

View File

@@ -10,7 +10,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100, Str128, Str255, Text1024
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
if TYPE_CHECKING:
from .user import User
@@ -87,7 +87,7 @@ class ChangePasswordRequest(SQLModelBase):
old_password: str = Field(min_length=1)
"""当前密码"""
new_password: Str128 = Field(min_length=8)
new_password: str = Field(min_length=8, max_length=128)
"""新密码(至少 8 位)"""
@@ -103,13 +103,13 @@ class AuthIdentity(SQLModelBase, UUIDTableBaseMixin):
provider: AuthProviderType = Field(index=True)
"""提供者类型"""
identifier: Str255 = Field(index=True)
identifier: str = Field(max_length=255, index=True)
"""标识符(邮箱/手机号/OAuth openid"""
credential: Text1024 | None = None
credential: str | None = Field(default=None, max_length=1024)
"""凭证Argon2 哈希密码 / null"""
display_name: Str100 | None = None
display_name: str | None = Field(default=None, max_length=100)
"""OAuth 昵称"""
avatar_url: str | None = Field(default=None, max_length=512)

View File

@@ -1,135 +0,0 @@
"""
用户自定义属性定义模型
允许用户定义类型化的自定义属性模板(如标签、评分、分类等),
实际值通过 ObjectMetadata KV 表存储键名格式custom:{property_definition_id}
支持的属性类型text, number, boolean, select, multi_select, rating, link
"""
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import JSON
from sqlmodel import Field, Relationship
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100
if TYPE_CHECKING:
from .user import User
# ==================== 枚举 ====================
class CustomPropertyType(StrEnum):
"""自定义属性值类型枚举"""
TEXT = "text"
"""文本"""
NUMBER = "number"
"""数字"""
BOOLEAN = "boolean"
"""布尔值"""
SELECT = "select"
"""单选"""
MULTI_SELECT = "multi_select"
"""多选"""
RATING = "rating"
"""评分1-5"""
LINK = "link"
"""链接"""
# ==================== Base 模型 ====================
class CustomPropertyDefinitionBase(SQLModelBase):
"""自定义属性定义基础模型"""
name: Str100
"""属性显示名称"""
type: CustomPropertyType
"""属性值类型"""
icon: Str100 | None = None
"""图标标识iconify 名称)"""
options: list[str] | None = Field(default=None, sa_type=JSON)
"""可选值列表(仅 select/multi_select 类型)"""
default_value: str | None = Field(default=None, max_length=500)
"""默认值"""
# ==================== 数据库模型 ====================
class CustomPropertyDefinition(CustomPropertyDefinitionBase, UUIDTableBaseMixin):
"""
用户自定义属性定义
每个用户独立管理自己的属性模板。
实际属性值存储在 ObjectMetadata 表中键名格式custom:{id}
"""
owner_id: UUID = Field(
foreign_key="user.id",
ondelete="CASCADE",
index=True,
)
"""所有者用户UUID"""
sort_order: int = 0
"""排序顺序"""
# 关系
owner: "User" = Relationship()
"""所有者"""
# ==================== DTO 模型 ====================
class CustomPropertyCreateRequest(SQLModelBase):
"""创建自定义属性请求 DTO"""
name: Str100
"""属性显示名称"""
type: CustomPropertyType
"""属性值类型"""
icon: str | None = None
"""图标标识"""
options: list[str] | None = None
"""可选值列表(仅 select/multi_select 类型)"""
default_value: str | None = None
"""默认值"""
class CustomPropertyUpdateRequest(SQLModelBase):
"""更新自定义属性请求 DTO"""
name: str | None = None
"""属性显示名称"""
icon: str | None = None
"""图标标识"""
options: list[str] | None = None
"""可选值列表"""
default_value: str | None = None
"""默认值"""
sort_order: int | None = None
"""排序顺序"""
class CustomPropertyResponse(CustomPropertyDefinitionBase):
"""自定义属性响应 DTO"""
id: UUID
"""属性定义UUID"""
sort_order: int
"""排序顺序"""

View File

@@ -4,7 +4,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint, Index
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableBaseMixin, Str255
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableBaseMixin
if TYPE_CHECKING:
from .user import User
@@ -141,7 +141,7 @@ class Download(DownloadBase, UUIDTableBaseMixin):
speed: int = Field(default=0)
"""下载速度bytes/s"""
parent: Str255 | None = None
parent: str | None = Field(default=None, max_length=255)
"""父任务标识"""
error: str | None = Field(default=None)

View File

@@ -20,7 +20,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str100, Str255, Text1024
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin
if TYPE_CHECKING:
from .group import Group
@@ -119,7 +119,7 @@ class UserFileAppDefaultResponse(SQLModelBase):
class FileAppCreateRequest(SQLModelBase):
"""管理员创建应用请求 DTO"""
name: Str100
name: str = Field(max_length=100)
"""应用名称"""
app_key: str = Field(max_length=50)
@@ -128,7 +128,7 @@ class FileAppCreateRequest(SQLModelBase):
type: FileAppType
"""应用类型"""
icon: Str255 | None = None
icon: str | None = Field(default=None, max_length=255)
"""图标名称/URL"""
description: str | None = Field(default=None, max_length=500)
@@ -140,13 +140,13 @@ class FileAppCreateRequest(SQLModelBase):
is_restricted: bool = False
"""是否限制用户组访问"""
iframe_url_template: Text1024 | None = None
iframe_url_template: str | None = Field(default=None, max_length=1024)
"""iframe URL 模板"""
wopi_discovery_url: str | None = Field(default=None, max_length=512)
"""WOPI 发现端点 URL"""
wopi_editor_url_template: Text1024 | None = None
wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
"""WOPI 编辑器 URL 模板"""
extensions: list[str] = []
@@ -159,7 +159,7 @@ class FileAppCreateRequest(SQLModelBase):
class FileAppUpdateRequest(SQLModelBase):
"""管理员更新应用请求 DTO所有字段可选"""
name: Str100 | None = None
name: str | None = Field(default=None, max_length=100)
"""应用名称"""
app_key: str | None = Field(default=None, max_length=50)
@@ -168,7 +168,7 @@ class FileAppUpdateRequest(SQLModelBase):
type: FileAppType | None = None
"""应用类型"""
icon: Str255 | None = None
icon: str | None = Field(default=None, max_length=255)
"""图标名称/URL"""
description: str | None = Field(default=None, max_length=500)
@@ -180,13 +180,13 @@ class FileAppUpdateRequest(SQLModelBase):
is_restricted: bool | None = None
"""是否限制用户组访问"""
iframe_url_template: Text1024 | None = None
iframe_url_template: str | None = Field(default=None, max_length=1024)
"""iframe URL 模板"""
wopi_discovery_url: str | None = Field(default=None, max_length=512)
"""WOPI 发现端点 URL"""
wopi_editor_url_template: Text1024 | None = None
wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
"""WOPI 编辑器 URL 模板"""
@@ -297,35 +297,12 @@ class WopiSessionResponse(SQLModelBase):
"""完整的编辑器 URL"""
class WopiDiscoveredExtension(SQLModelBase):
"""单个 WOPI Discovery 发现的扩展名"""
extension: str
"""文件扩展名"""
action_url: str
"""处理后的动作 URL 模板"""
class WopiDiscoveryResponse(SQLModelBase):
"""WOPI Discovery 结果响应 DTO"""
discovered_extensions: list[WopiDiscoveredExtension] = []
"""发现的扩展名及其 URL 模板"""
app_names: list[str] = []
"""WOPI 服务端报告的应用名称(如 Writer、Calc、Impress"""
applied_count: int = 0
"""已应用到 FileAppExtension 的数量"""
# ==================== 数据库模型 ====================
class FileApp(SQLModelBase, UUIDTableBaseMixin):
"""文件查看器应用注册表"""
name: Str100
name: str = Field(max_length=100)
"""应用名称"""
app_key: str = Field(max_length=50, unique=True, index=True)
@@ -334,7 +311,7 @@ class FileApp(SQLModelBase, UUIDTableBaseMixin):
type: FileAppType
"""应用类型"""
icon: Str255 | None = None
icon: str | None = Field(default=None, max_length=255)
"""图标名称/URL"""
description: str | None = Field(default=None, max_length=500)
@@ -346,13 +323,13 @@ class FileApp(SQLModelBase, UUIDTableBaseMixin):
is_restricted: bool = False
"""是否限制用户组访问"""
iframe_url_template: Text1024 | None = None
iframe_url_template: str | None = Field(default=None, max_length=1024)
"""iframe URL 模板,支持 {file_url} 占位符"""
wopi_discovery_url: str | None = Field(default=None, max_length=512)
"""WOPI 客户端发现端点 URL"""
wopi_editor_url_template: Text1024 | None = None
wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
"""WOPI 编辑器 URL 模板,支持 {wopi_src} {access_token} {access_token_ttl}"""
# 关系
@@ -400,9 +377,6 @@ class FileAppExtension(SQLModelBase, TableBaseMixin):
priority: int = Field(default=0, ge=0)
"""排序优先级(越小越优先)"""
wopi_action_url: str | None = Field(default=None, max_length=2048)
"""WOPI 动作 URL 模板Discovery 自动填充),支持 {wopi_src} {access_token} {access_token_ttl}"""
# 关系
app: FileApp = Relationship(back_populates="extensions")

View File

@@ -2,10 +2,9 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, text
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str255
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin
if TYPE_CHECKING:
from .user import User
@@ -67,7 +66,7 @@ class GroupAllOptionsBase(GroupOptionsBase):
class GroupCreateRequest(GroupAllOptionsBase):
"""创建用户组请求 DTO"""
name: Str255
name: str = Field(max_length=255)
"""用户组名称"""
max_storage: int = Field(default=0, ge=0)
@@ -92,7 +91,7 @@ class GroupCreateRequest(GroupAllOptionsBase):
class GroupUpdateRequest(SQLModelBase):
"""更新用户组请求 DTO所有字段可选"""
name: Str255 | None = None
name: str | None = Field(default=None, max_length=255)
"""用户组名称"""
max_storage: int | None = Field(default=None, ge=0)
@@ -258,10 +257,10 @@ class GroupOptions(GroupAllOptionsBase, TableBaseMixin):
class Group(GroupBase, UUIDTableBaseMixin):
"""用户组模型"""
name: Str255 = Field(unique=True)
name: str = Field(max_length=255, unique=True)
"""用户组名"""
max_storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""最大存储空间(字节)"""
share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})

View File

@@ -130,11 +130,6 @@ default_settings: list[Setting] = [
Setting(name="sms_provider", value="", type=SettingsType.MOBILE),
Setting(name="sms_access_key", value="", type=SettingsType.MOBILE),
Setting(name="sms_secret_key", value="", type=SettingsType.MOBILE),
# ==================== 文件分类扩展名配置 ====================
Setting(name="image", value="jpg,jpeg,png,gif,bmp,webp,svg,ico,tiff,tif,avif,heic,heif,psd,raw", type=SettingsType.FILE_CATEGORY),
Setting(name="video", value="mp4,mkv,avi,mov,wmv,flv,webm,m4v,ts,3gp,mpg,mpeg", type=SettingsType.FILE_CATEGORY),
Setting(name="audio", value="mp3,wav,flac,aac,ogg,wma,m4a,opus,ape,aiff,mid,midi", type=SettingsType.FILE_CATEGORY),
Setting(name="document", value="pdf,doc,docx,odt,rtf,txt,tex,epub,pages,ppt,pptx,odp,key,xls,xlsx,csv,ods,numbers,tsv,md,markdown,mdx", type=SettingsType.FILE_CATEGORY),
]
async def init_default_settings() -> None:
@@ -178,7 +173,7 @@ async def init_default_group() -> None:
admin=True,
)
admin_group_id = admin_group.id # 在 save 前保存 UUID
admin_group = await admin_group.save(session)
await admin_group.save(session)
await GroupOptions(
group_id=admin_group_id,
@@ -208,7 +203,7 @@ async def init_default_group() -> None:
web_dav_enabled=True,
)
member_group_id = member_group.id # 在 save 前保存 UUID
member_group = await member_group.save(session)
await member_group.save(session)
await GroupOptions(
group_id=member_group_id,
@@ -227,7 +222,7 @@ async def init_default_group() -> None:
default_group_setting = await Setting.get(session, Setting.name == "default_group")
if default_group_setting:
default_group_setting.value = str(member_group_id)
default_group_setting = await default_group_setting.save(session)
await default_group_setting.save(session)
# 未找到初始游客组时,则创建
if not await Group.get(session, Group.name == "游客"):
@@ -237,7 +232,7 @@ async def init_default_group() -> None:
web_dav_enabled=False,
)
guest_group_id = guest_group.id # 在 save 前保存 UUID
guest_group = await guest_group.save(session)
await guest_group.save(session)
await GroupOptions(
group_id=guest_group_id,
@@ -289,7 +284,7 @@ async def init_default_user() -> None:
group_id=admin_group.id,
)
admin_user_id = admin_user.id # 在 save 前保存 UUID
admin_user = await admin_user.save(session)
await admin_user.save(session)
# 创建 AuthIdentity邮箱密码身份
await AuthIdentity(
@@ -378,7 +373,7 @@ async def init_default_theme_presets() -> None:
error=ChromaticColor.RED,
neutral=NeutralColor.ZINC,
)
default_preset = await default_preset.save(session)
await default_preset.save(session)
log.info('已创建默认主题预设')
@@ -451,43 +446,36 @@ _DEFAULT_FILE_APPS: list[dict] = [
"is_enabled": True,
"extensions": ["mp3", "wav", "ogg", "flac", "aac", "m4a", "opus"],
},
{
"name": "EPUB 阅读器",
"app_key": "epub_reader",
"type": "builtin",
"icon": "book-open",
"description": "阅读 EPUB 电子书",
"is_enabled": True,
"extensions": ["epub"],
},
{
"name": "3D 模型预览",
"app_key": "model_viewer",
"type": "builtin",
"icon": "cube",
"description": "预览 3D 模型",
"is_enabled": True,
"extensions": ["gltf", "glb", "stl", "obj", "fbx", "ply", "3mf"],
},
{
"name": "Font Viewer",
"app_key": "font_viewer",
"type": "builtin",
"icon": "type",
"description": "预览字体文件并显示元数据和文本样本",
"is_enabled": True,
"extensions": ["ttf", "otf", "woff", "woff2"],
},
# iframe 应用(默认禁用)
{
"name": "Office 在线预览",
"app_key": "office_viewer",
"type": "iframe",
"icon": "file-word",
"description": "使用 Microsoft Office Online 预览文档",
"is_enabled": True,
"is_enabled": False,
"iframe_url_template": "https://view.officeapps.live.com/op/embed.aspx?src={file_url}",
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"],
},
# WOPI 应用(默认禁用)
{
"name": "Collabora Online",
"app_key": "collabora",
"type": "wopi",
"icon": "file-text",
"description": "Collabora Online 文档编辑器(需自行部署)",
"is_enabled": False,
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx", "odt", "ods", "odp"],
},
{
"name": "OnlyOffice",
"app_key": "onlyoffice",
"type": "wopi",
"icon": "file-text",
"description": "OnlyOffice 文档编辑器(需自行部署)",
"is_enabled": False,
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"],
},
]
@@ -505,7 +493,7 @@ async def init_default_file_apps() -> None:
return
for app_data in _DEFAULT_FILE_APPS:
extensions = app_data["extensions"]
extensions = app_data.pop("extensions")
app = FileApp(
name=app_data["name"],
@@ -527,6 +515,6 @@ async def init_default_file_apps() -> None:
extension=ext.lower(),
priority=i,
)
ext_record = await ext_record.save(session)
await ext_record.save(session)
log.info(f'已创建 {len(_DEFAULT_FILE_APPS)} 个默认文件查看器应用')

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, text, Index
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .download import Download
@@ -28,13 +28,13 @@ class NodeType(StrEnum):
class Aria2ConfigurationBase(SQLModelBase):
"""Aria2配置基础模型"""
rpc_url: Str255 | None = None
rpc_url: str | None = Field(default=None, max_length=255)
"""RPC地址"""
rpc_secret: str | None = None
"""RPC密钥"""
temp_path: Str255 | None = None
temp_path: str | None = Field(default=None, max_length=255)
"""临时下载路径"""
max_concurrent: int = Field(default=5, ge=1, le=50)
@@ -70,19 +70,19 @@ class Node(SQLModelBase, TableBaseMixin):
status: NodeStatus = Field(default=NodeStatus.ONLINE)
"""节点状态"""
name: Str255 = Field(unique=True)
name: str = Field(max_length=255, unique=True)
"""节点名称"""
type: NodeType
"""节点类型"""
server: Str255
server: str = Field(max_length=255)
"""节点地址IP或域名"""
slave_key: Str255 | None = None
slave_key: str | None = Field(default=None, max_length=255)
"""从机通讯密钥"""
master_key: Str255 | None = None
master_key: str | None = Field(default=None, max_length=255)
"""主机通讯密钥"""
aria2_enabled: bool = False

View File

@@ -7,9 +7,7 @@ from enum import StrEnum
from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, CheckConstraint, Index, text
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255, Str256
from .policy import PolicyType
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
if TYPE_CHECKING:
from .user import User
@@ -18,21 +16,49 @@ if TYPE_CHECKING:
from .share import Share
from .physical_file import PhysicalFile
from .uri import DiskNextURI
from .object_metadata import ObjectMetadata
class ObjectType(StrEnum):
"""对象类型枚举"""
FILE = "file"
FOLDER = "folder"
class StorageType(StrEnum):
"""存储类型枚举"""
LOCAL = "local"
QINIU = "qiniu"
TENCENT = "tencent"
ALIYUN = "aliyun"
ONEDRIVE = "onedrive"
GOOGLE_DRIVE = "google_drive"
DROPBOX = "dropbox"
WEBDAV = "webdav"
REMOTE = "remote"
class FileCategory(StrEnum):
"""文件类型分类枚举,用于按类别筛选文件"""
IMAGE = "image"
VIDEO = "video"
AUDIO = "audio"
DOCUMENT = "document"
class FileMetadataBase(SQLModelBase):
"""文件元数据基础模型"""
width: int | None = Field(default=None)
"""图片宽度(像素)"""
height: int | None = Field(default=None)
"""图片高度(像素)"""
duration: float | None = Field(default=None)
"""音视频时长(秒)"""
bitrate: int | None = Field(default=None)
"""比特率kbps"""
mime_type: str | None = Field(default=None, max_length=127)
"""MIME类型"""
checksum_md5: str | None = Field(default=None, max_length=32)
"""MD5校验和"""
checksum_sha256: str | None = Field(default=None, max_length=64)
"""SHA256校验和"""
# ==================== Base 模型 ====================
@@ -49,32 +75,9 @@ class ObjectBase(SQLModelBase):
size: int | None = None
"""文件大小(字节),目录为 None"""
mime_type: str | None = Field(default=None, max_length=127)
"""MIME类型仅文件有效"""
# ==================== DTO 模型 ====================
class ObjectFileFinalize(SQLModelBase):
"""文件上传完成后更新 Object 的 DTO"""
size: int
"""文件大小(字节)"""
physical_file_id: UUID
"""关联的物理文件UUID"""
class ObjectMoveUpdate(SQLModelBase):
"""移动/重命名 Object 的 DTO"""
parent_id: UUID
"""新的父目录UUID"""
name: str
"""新名称"""
class DirectoryCreateRequest(SQLModelBase):
"""创建目录请求 DTO"""
@@ -133,7 +136,7 @@ class PolicyResponse(SQLModelBase):
name: str
"""策略名称"""
type: PolicyType
type: StorageType
"""存储类型"""
max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
@@ -161,6 +164,22 @@ class DirectoryResponse(SQLModelBase):
# ==================== 数据库模型 ====================
class FileMetadata(FileMetadataBase, UUIDTableBaseMixin):
"""文件元数据模型与Object一对一关联"""
object_id: UUID = Field(
foreign_key="object.id",
unique=True,
index=True,
ondelete="CASCADE"
)
"""关联的对象UUID"""
# 反向关系
object: "Object" = Relationship(back_populates="file_metadata")
"""关联的对象"""
class Object(ObjectBase, UUIDTableBaseMixin):
"""
统一对象模型
@@ -198,13 +217,13 @@ class Object(ObjectBase, UUIDTableBaseMixin):
# ==================== 基础字段 ====================
name: Str255
name: str = Field(max_length=255)
"""对象名称(文件名或目录名)"""
type: ObjectType
"""对象类型file 或 folder"""
password: Str255 | None = None
password: str | None = Field(default=None, max_length=255)
"""对象独立密码(仅当用户为对象单独设置密码时有效)"""
# ==================== 文件专属字段 ====================
@@ -212,7 +231,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
"""文件大小(字节),目录为 0"""
upload_session_id: Str255 | None = Field(default=None, unique=True, index=True)
upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
"""分块上传会话ID仅文件有效"""
physical_file_id: UUID | None = Field(
@@ -315,11 +334,11 @@ class Object(ObjectBase, UUIDTableBaseMixin):
"""子对象(文件和子目录)"""
# 仅文件有效的关系
metadata_entries: list["ObjectMetadata"] = Relationship(
file_metadata: FileMetadata | None = Relationship(
back_populates="object",
sa_relationship_kwargs={"cascade": "all, delete-orphan"},
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
)
"""元数据键值对列表"""
"""文件元数据(仅文件有效)"""
source_links: list["SourceLink"] = Relationship(
back_populates="object",
@@ -477,37 +496,6 @@ class Object(ObjectBase, UUIDTableBaseMixin):
fetch_mode="all"
)
@classmethod
async def get_by_category(
cls,
session: 'AsyncSession',
user_id: UUID,
extensions: list[str],
table_view: 'TableViewRequest | None' = None,
) -> 'ListResponse[Object]':
"""
按扩展名列表查询用户的所有文件(跨目录)
只查询未删除、未封禁的文件对象,使用 ILIKE 匹配文件名后缀。
:param session: 数据库会话
:param user_id: 用户UUID
:param extensions: 扩展名列表(不含点号)
:param table_view: 分页排序参数
:return: 分页文件列表
"""
from sqlalchemy import or_
ext_conditions = [cls.name.ilike(f"%.{ext}") for ext in extensions]
condition = (
(cls.owner_id == user_id) &
(cls.type == ObjectType.FILE) &
(cls.deleted_at == None) &
(cls.is_banned == False) &
or_(*ext_conditions)
)
return await cls.get_with_count(session, condition, table_view=table_view)
@classmethod
async def resolve_uri(
cls,
@@ -585,7 +573,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
class UploadSessionBase(SQLModelBase):
"""上传会话基础字段"""
file_name: Str255
file_name: str = Field(max_length=255)
"""原始文件名"""
file_size: int = Field(ge=0, sa_type=BigInteger)
@@ -616,12 +604,6 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
storage_path: str | None = Field(default=None, max_length=512)
"""文件存储路径"""
s3_upload_id: Str256 | None = None
"""S3 Multipart Upload ID仅 S3 策略使用)"""
s3_part_etags: str | None = None
"""S3 已上传分片的 ETag 列表JSON 格式 [[1,"etag1"],[2,"etag2"]](仅 S3 策略使用)"""
expires_at: datetime
"""会话过期时间"""
@@ -663,7 +645,7 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
class CreateUploadSessionRequest(SQLModelBase):
"""创建上传会话请求 DTO"""
file_name: Str255
file_name: str = Field(max_length=255)
"""文件名"""
file_size: int = Field(ge=0)
@@ -720,7 +702,7 @@ class UploadChunkResponse(SQLModelBase):
class CreateFileRequest(SQLModelBase):
"""创建空白文件请求 DTO"""
name: Str255
name: str = Field(max_length=255)
"""文件名"""
parent_id: UUID
@@ -730,16 +712,6 @@ class CreateFileRequest(SQLModelBase):
"""存储策略UUID不指定则使用父目录的策略"""
class ObjectSwitchPolicyRequest(SQLModelBase):
"""切换对象存储策略请求"""
policy_id: UUID
"""目标存储策略UUID"""
is_migrate_existing: bool = False
"""(仅目录)是否迁移已有文件,默认 false 只影响新文件"""
# ==================== 对象操作相关 DTO ====================
class ObjectCopyRequest(SQLModelBase):
@@ -758,7 +730,7 @@ class ObjectRenameRequest(SQLModelBase):
id: UUID
"""对象UUID"""
new_name: Str255
new_name: str = Field(max_length=255)
"""新名称"""
@@ -777,9 +749,6 @@ class ObjectPropertyResponse(SQLModelBase):
size: int
"""文件大小(字节)"""
mime_type: str | None = None
"""MIME类型"""
created_at: datetime
"""创建时间"""
@@ -793,13 +762,22 @@ class ObjectPropertyResponse(SQLModelBase):
class ObjectPropertyDetailResponse(ObjectPropertyResponse):
"""对象详细属性响应 DTO继承基本属性"""
# 校验和(从 PhysicalFile 读取)
# 元数据信息
mime_type: str | None = None
"""MIME类型"""
width: int | None = None
"""图片宽度(像素)"""
height: int | None = None
"""图片高度(像素)"""
duration: float | None = None
"""音视频时长(秒)"""
checksum_md5: str | None = None
"""MD5校验和"""
checksum_sha256: str | None = None
"""SHA256校验和"""
# 分享统计
share_count: int = 0
"""分享次数"""
@@ -817,10 +795,6 @@ class ObjectPropertyDetailResponse(ObjectPropertyResponse):
reference_count: int = 1
"""物理文件引用计数(仅文件有效)"""
# 元数据KV 格式)
metadatas: dict[str, str] = {}
"""所有元数据条目(键名 → 值)"""
# ==================== 管理员文件管理 DTO ====================

View File

@@ -1,127 +0,0 @@
"""
对象元数据 KV 模型
以键值对形式存储文件的扩展元数据。键名使用命名空间前缀分类,
如 exif:width, stream:duration, music:artist 等。
架构:
ObjectMetadata (KV 表,与 Object 一对多关系)
└── 每个 Object 可以有多条元数据记录
└── (object_id, name) 组合唯一索引
命名空间:
- exif: 图片 EXIF 信息(尺寸、相机参数、拍摄时间等)
- stream: 音视频流信息(时长、比特率、视频尺寸、编解码等)
- music: 音乐标签(标题、艺术家、专辑等)
- geo: 地理位置(经纬度、地址)
- apk: Android 安装包信息
- custom: 用户自定义属性
- sys: 系统内部元数据
- thumb: 缩略图信息
"""
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, UniqueConstraint, Index, Relationship
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255
if TYPE_CHECKING:
from .object import Object
# ==================== 枚举 ====================
class MetadataNamespace(StrEnum):
"""元数据命名空间枚举"""
EXIF = "exif"
"""图片 EXIF 信息(含尺寸、相机参数、拍摄时间等)"""
MUSIC = "music"
"""音乐标签title/artist/album/genre 等)"""
STREAM = "stream"
"""音视频流信息codec/duration/bitrate/resolution 等)"""
GEO = "geo"
"""地理位置latitude/longitude/address"""
APK = "apk"
"""Android 安装包信息package_name/version 等)"""
THUMB = "thumb"
"""缩略图信息(内部使用)"""
SYS = "sys"
"""系统元数据(内部使用)"""
CUSTOM = "custom"
"""用户自定义属性"""
# 对外不可见的命名空间API 不返回给普通用户)
INTERNAL_NAMESPACES: set[str] = {MetadataNamespace.SYS, MetadataNamespace.THUMB}
# 用户可写的命名空间
USER_WRITABLE_NAMESPACES: set[str] = {MetadataNamespace.CUSTOM}
# ==================== Base 模型 ====================
class ObjectMetadataBase(SQLModelBase):
"""对象元数据 KV 基础模型"""
name: Str255
"""元数据键名格式namespace:key如 exif:width, stream:duration"""
value: str
"""元数据值(统一为字符串存储)"""
# ==================== 数据库模型 ====================
class ObjectMetadata(ObjectMetadataBase, UUIDTableBaseMixin):
"""
对象元数据 KV 模型
以键值对形式存储文件的扩展元数据。键名使用命名空间前缀分类,
每个对象的每个键名唯一(通过唯一索引保证)。
"""
__table_args__ = (
UniqueConstraint("object_id", "name", name="uq_object_metadata_object_name"),
Index("ix_object_metadata_object_id", "object_id"),
)
object_id: UUID = Field(
foreign_key="object.id",
ondelete="CASCADE",
)
"""关联的对象UUID"""
is_public: bool = False
"""是否对分享页面公开"""
# 关系
object: "Object" = Relationship(back_populates="metadata_entries")
"""关联的对象"""
# ==================== DTO 模型 ====================
class MetadataResponse(SQLModelBase):
"""元数据查询响应 DTO"""
metadatas: dict[str, str]
"""元数据字典(键名 → 值)"""
class MetadataPatchItem(SQLModelBase):
"""单条元数据补丁 DTO"""
key: Str255
"""元数据键名"""
value: str | None = None
"""None 表示删除此条目"""
class MetadataPatchRequest(SQLModelBase):
"""元数据批量更新请求 DTO"""
patches: list[MetadataPatchItem]
"""补丁列表"""

View File

@@ -1,122 +1,58 @@
from decimal import Decimal
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import Numeric
from sqlmodel import Field, Relationship
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .product import Product
from .user import User
class OrderStatus(StrEnum):
"""订单状态枚举"""
PENDING = "pending"
"""待支付"""
COMPLETED = "completed"
"""已完成"""
CANCELLED = "cancelled"
"""已取消"""
class OrderType(StrEnum):
"""订单类型枚举"""
# [TODO] 补充具体订单类型
pass
STORAGE_PACK = "storage_pack"
"""容量包"""
GROUP_TIME = "group_time"
"""用户组时长"""
SCORE = "score"
"""积分充值"""
# ==================== DTO 模型 ====================
class CreateOrderRequest(SQLModelBase):
"""创建订单请求 DTO"""
product_id: UUID
"""商品UUID"""
num: int = Field(default=1, ge=1)
"""购买数量"""
method: str
"""支付方式"""
class OrderResponse(SQLModelBase):
"""订单响应 DTO"""
id: int
"""订单ID"""
order_no: str
"""订单号"""
type: OrderType
"""订单类型"""
method: str | None = None
"""支付方式"""
product_id: UUID | None = None
"""商品UUID"""
num: int
"""购买数量"""
name: str
"""商品名称"""
price: float
"""订单价格(元)"""
status: OrderStatus
"""订单状态"""
user_id: UUID
"""所属用户UUID"""
# ==================== 数据库模型 ====================
class Order(SQLModelBase, TableBaseMixin):
"""订单模型"""
order_no: Str255 = Field(unique=True, index=True)
order_no: str = Field(max_length=255, unique=True, index=True)
"""订单号,唯一"""
type: OrderType
"""订单类型"""
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""订单类型 [TODO] 待定义枚举"""
method: Str255 | None = None
method: str | None = Field(default=None, max_length=255)
"""支付方式"""
product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
"""关联商品UUID"""
product_id: int | None = Field(default=None)
"""商品ID"""
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
"""购买数量"""
name: Str255
name: str = Field(max_length=255)
"""商品名称"""
price: Decimal = Field(sa_type=Numeric(12, 2), default=Decimal("0.00"))
"""订单价格("""
price: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""订单价格("""
status: OrderStatus = Field(default=OrderStatus.PENDING)
"""订单状态"""
# 外键
user_id: UUID = Field(
foreign_key="user.id",
@@ -124,22 +60,6 @@ class Order(SQLModelBase, TableBaseMixin):
ondelete="CASCADE"
)
"""所属用户UUID"""
# 关系
user: "User" = Relationship(back_populates="orders")
product: "Product" = Relationship(back_populates="orders")
def to_response(self) -> OrderResponse:
"""转换为响应 DTO"""
return OrderResponse(
id=self.id,
order_no=self.order_no,
type=self.type,
method=self.method,
product_id=self.product_id,
num=self.num,
name=self.name,
price=float(self.price),
status=self.status,
user_id=self.user_id,
)
user: "User" = Relationship(back_populates="orders")

View File

@@ -15,7 +15,7 @@ from uuid import UUID
from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, Index
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str32, Str64
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
if TYPE_CHECKING:
from .object import Object
@@ -31,12 +31,9 @@ class PhysicalFileBase(SQLModelBase):
size: int = Field(default=0, sa_type=BigInteger)
"""文件大小(字节)"""
checksum_md5: Str32 | None = None
checksum_md5: str | None = Field(default=None, max_length=32)
"""MD5校验和用于文件去重和完整性校验"""
checksum_sha256: Str64 | None = None
"""SHA256校验和"""
class PhysicalFile(PhysicalFileBase, UUIDTableBaseMixin):
"""

View File

@@ -4,7 +4,7 @@ from uuid import UUID
from enum import StrEnum
from sqlmodel import Field, Relationship, text
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str64, Str255
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
if TYPE_CHECKING:
from .object import Object
@@ -37,22 +37,22 @@ class PolicyType(StrEnum):
class PolicyBase(SQLModelBase):
"""存储策略基础字段,供 DTO 和数据库模型共享"""
name: Str255
name: str = Field(max_length=255)
"""策略名称"""
type: PolicyType
"""存储策略类型"""
server: Str255 | None = None
server: str | None = Field(default=None, max_length=255)
"""服务器地址(本地策略为绝对路径)"""
bucket_name: Str255 | None = None
bucket_name: str | None = Field(default=None, max_length=255)
"""存储桶名称"""
is_private: bool = True
"""是否为私有空间"""
base_url: Str255 | None = None
base_url: str | None = Field(default=None, max_length=255)
"""访问文件的基础URL"""
access_key: str | None = None
@@ -67,10 +67,10 @@ class PolicyBase(SQLModelBase):
auto_rename: bool = False
"""是否自动重命名"""
dir_name_rule: Str255 | None = None
dir_name_rule: str | None = Field(default=None, max_length=255)
"""目录命名规则"""
file_name_rule: Str255 | None = None
file_name_rule: str | None = Field(default=None, max_length=255)
"""文件命名规则"""
is_origin_link_enable: bool = False
@@ -102,94 +102,6 @@ class PolicySummary(SQLModelBase):
"""是否私有"""
class PolicyCreateRequest(PolicyBase):
"""创建存储策略请求 DTO包含 PolicyOptions 扁平字段"""
# PolicyOptions 字段(平铺到请求体中,与 GroupCreateRequest 模式一致)
token: str | None = None
"""访问令牌"""
file_type: str | None = None
"""允许的文件类型"""
mimetype: str | None = Field(default=None, max_length=127)
"""MIME类型"""
od_redirect: Str255 | None = None
"""OneDrive重定向地址"""
chunk_size: int = Field(default=52428800, ge=1)
"""分片上传大小字节默认50MB"""
s3_path_style: bool = False
"""是否使用S3路径风格"""
s3_region: Str64 = 'us-east-1'
"""S3 区域(如 us-east-1、ap-southeast-1仅 S3 策略使用"""
class PolicyUpdateRequest(SQLModelBase):
"""更新存储策略请求 DTO所有字段可选"""
name: Str255 | None = None
"""策略名称"""
server: Str255 | None = None
"""服务器地址"""
bucket_name: Str255 | None = None
"""存储桶名称"""
is_private: bool | None = None
"""是否为私有空间"""
base_url: Str255 | None = None
"""访问文件的基础URL"""
access_key: str | None = None
"""Access Key"""
secret_key: str | None = None
"""Secret Key"""
max_size: int | None = Field(default=None, ge=0)
"""允许上传的最大文件尺寸(字节)"""
auto_rename: bool | None = None
"""是否自动重命名"""
dir_name_rule: Str255 | None = None
"""目录命名规则"""
file_name_rule: Str255 | None = None
"""文件命名规则"""
is_origin_link_enable: bool | None = None
"""是否开启源链接访问"""
# PolicyOptions 字段
token: str | None = None
"""访问令牌"""
file_type: str | None = None
"""允许的文件类型"""
mimetype: str | None = Field(default=None, max_length=127)
"""MIME类型"""
od_redirect: Str255 | None = None
"""OneDrive重定向地址"""
chunk_size: int | None = Field(default=None, ge=1)
"""分片上传大小(字节)"""
s3_path_style: bool | None = None
"""是否使用S3路径风格"""
s3_region: Str64 | None = None
"""S3 区域"""
# ==================== 数据库模型 ====================
@@ -205,7 +117,7 @@ class PolicyOptionsBase(SQLModelBase):
mimetype: str | None = Field(default=None, max_length=127)
"""MIME类型"""
od_redirect: Str255 | None = None
od_redirect: str | None = Field(default=None, max_length=255)
"""OneDrive重定向地址"""
chunk_size: int = Field(default=52428800, sa_column_kwargs={"server_default": "52428800"})
@@ -214,9 +126,6 @@ class PolicyOptionsBase(SQLModelBase):
s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否使用S3路径风格"""
s3_region: Str64 = Field(default='us-east-1', sa_column_kwargs={"server_default": "'us-east-1'"})
"""S3 区域(如 us-east-1、ap-southeast-1仅 S3 策略使用"""
class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin):
"""存储策略选项模型与Policy一对一关联"""
@@ -237,7 +146,7 @@ class Policy(PolicyBase, UUIDTableBaseMixin):
"""存储策略模型"""
# 覆盖基类字段以添加数据库专有配置
name: Str255 = Field(unique=True)
name: str = Field(max_length=255, unique=True)
"""策略名称"""
is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})

View File

@@ -1,206 +0,0 @@
from decimal import Decimal
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import Numeric, BigInteger
from sqlmodel import Field, Relationship, text
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255
if TYPE_CHECKING:
from .order import Order
from .redeem import Redeem
class ProductType(StrEnum):
"""商品类型枚举"""
STORAGE_PACK = "storage_pack"
"""容量包"""
GROUP_TIME = "group_time"
"""用户组时长"""
SCORE = "score"
"""积分充值"""
class PaymentMethod(StrEnum):
"""支付方式枚举"""
ALIPAY = "alipay"
"""支付宝"""
WECHAT = "wechat"
"""微信支付"""
STRIPE = "stripe"
"""Stripe"""
EASYPAY = "easypay"
"""易支付"""
CUSTOM = "custom"
"""自定义支付"""
# ==================== DTO 模型 ====================
class ProductBase(SQLModelBase):
"""商品基础字段"""
name: str
"""商品名称"""
type: ProductType
"""商品类型"""
description: str | None = None
"""商品描述"""
class ProductCreateRequest(ProductBase):
"""创建商品请求 DTO"""
name: Str255
"""商品名称"""
price: Decimal = Field(ge=0, decimal_places=2)
"""商品价格(元)"""
is_active: bool = True
"""是否上架"""
sort_order: int = Field(default=0, ge=0)
"""排序权重(越大越靠前)"""
# storage_pack 专用
size: int | None = Field(default=None, ge=0)
"""容量大小字节type=storage_pack 时必填"""
duration_days: int | None = Field(default=None, ge=1)
"""有效天数type=storage_pack/group_time 时必填"""
# group_time 专用
group_id: UUID | None = None
"""目标用户组UUIDtype=group_time 时必填"""
# score 专用
score_amount: int | None = Field(default=None, ge=1)
"""积分数量type=score 时必填"""
class ProductUpdateRequest(SQLModelBase):
"""更新商品请求 DTO所有字段可选"""
name: Str255 | None = None
"""商品名称"""
description: str | None = None
"""商品描述"""
price: Decimal | None = Field(default=None, ge=0, decimal_places=2)
"""商品价格(元)"""
is_active: bool | None = None
"""是否上架"""
sort_order: int | None = Field(default=None, ge=0)
"""排序权重"""
size: int | None = Field(default=None, ge=0)
"""容量大小(字节)"""
duration_days: int | None = Field(default=None, ge=1)
"""有效天数"""
group_id: UUID | None = None
"""目标用户组UUID"""
score_amount: int | None = Field(default=None, ge=1)
"""积分数量"""
class ProductResponse(ProductBase):
"""商品响应 DTO"""
id: UUID
"""商品UUID"""
price: float
"""商品价格(元)"""
is_active: bool
"""是否上架"""
sort_order: int
"""排序权重"""
size: int | None = None
"""容量大小(字节)"""
duration_days: int | None = None
"""有效天数"""
group_id: UUID | None = None
"""目标用户组UUID"""
score_amount: int | None = None
"""积分数量"""
# ==================== 数据库模型 ====================
class Product(ProductBase, UUIDTableBaseMixin):
"""商品模型"""
name: Str255
"""商品名称"""
price: Decimal = Field(sa_type=Numeric(12, 2), default=Decimal("0.00"))
"""商品价格(元)"""
is_active: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
"""是否上架"""
sort_order: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""排序权重(越大越靠前)"""
# storage_pack 专用
size: int | None = Field(default=None, sa_type=BigInteger)
"""容量大小字节type=storage_pack 时必填"""
duration_days: int | None = None
"""有效天数type=storage_pack/group_time 时必填"""
# group_time 专用
group_id: UUID | None = Field(default=None, foreign_key="group.id", ondelete="SET NULL")
"""目标用户组UUIDtype=group_time 时必填"""
# score 专用
score_amount: int | None = None
"""积分数量type=score 时必填"""
# 关系
orders: list["Order"] = Relationship(back_populates="product")
"""关联的订单列表"""
redeems: list["Redeem"] = Relationship(back_populates="product")
"""关联的兑换码列表"""
def to_response(self) -> ProductResponse:
"""转换为响应 DTO"""
return ProductResponse(
id=self.id,
name=self.name,
type=self.type,
description=self.description,
price=float(self.price),
is_active=self.is_active,
sort_order=self.sort_order,
size=self.size,
duration_days=self.duration_days,
group_id=self.group_id,
score_amount=self.score_amount,
)

View File

@@ -1,141 +1,22 @@
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, text
from sqlmodel import Field, text
from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .product import Product
from .user import User
class RedeemType(StrEnum):
"""兑换码类型枚举"""
# [TODO] 补充具体兑换码类型
pass
STORAGE_PACK = "storage_pack"
"""容量包"""
GROUP_TIME = "group_time"
"""用户组时长"""
SCORE = "score"
"""积分充值"""
# ==================== DTO 模型 ====================
class RedeemCreateRequest(SQLModelBase):
"""批量生成兑换码请求 DTO"""
product_id: UUID
"""关联商品UUID"""
count: int = Field(default=1, ge=1, le=100)
"""生成数量"""
class RedeemUseRequest(SQLModelBase):
"""使用兑换码请求 DTO"""
code: str
"""兑换码"""
class RedeemInfoResponse(SQLModelBase):
"""兑换码信息响应 DTO用户侧"""
type: RedeemType
"""兑换码类型"""
product_name: str | None = None
"""关联商品名称"""
num: int
"""可兑换数量"""
is_used: bool
"""是否已使用"""
class RedeemAdminResponse(SQLModelBase):
"""兑换码管理响应 DTO管理侧"""
id: int
"""兑换码ID"""
type: RedeemType
"""兑换码类型"""
product_id: UUID | None = None
"""关联商品UUID"""
num: int
"""可兑换数量"""
code: str
"""兑换码"""
is_used: bool
"""是否已使用"""
used_at: datetime | None = None
"""使用时间"""
used_by: UUID | None = None
"""使用者UUID"""
# ==================== 数据库模型 ====================
class Redeem(SQLModelBase, TableBaseMixin):
"""兑换码模型"""
type: RedeemType
"""兑换码类型"""
product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
"""关联商品UUID"""
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
"""可兑换数量/时长等"""
code: str = Field(unique=True, index=True)
"""兑换码,唯一"""
is_used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否已使用"""
used_at: datetime | None = None
"""使用时间"""
used_by: UUID | None = Field(default=None, foreign_key="user.id", ondelete="SET NULL")
"""使用者UUID"""
# 关系
product: "Product" = Relationship(back_populates="redeems")
user: "User" = Relationship(back_populates="redeems")
def to_admin_response(self) -> RedeemAdminResponse:
"""转换为管理侧响应 DTO"""
return RedeemAdminResponse(
id=self.id,
type=self.type,
product_id=self.product_id,
num=self.num,
code=self.code,
is_used=self.is_used,
used_at=self.used_at,
used_by=self.used_by,
)
def to_info_response(self, product_name: str | None = None) -> RedeemInfoResponse:
"""转换为用户侧响应 DTO"""
return RedeemInfoResponse(
type=self.type,
product_name=product_name,
num=self.num,
is_used=self.is_used,
)
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""兑换码类型 [TODO] 待定义枚举"""
product_id: int | None = Field(default=None, description="关联的商品/权益ID")
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="可兑换数量/时长等")
code: str = Field(unique=True, index=True, description="兑换码,唯一")
used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否已使用")

View File

@@ -4,7 +4,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .share import Share
@@ -21,7 +21,7 @@ class Report(SQLModelBase, TableBaseMixin):
reason: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""举报原因 [TODO] 待定义枚举"""
description: Str255 | None = Field(default=None, description="补充描述")
description: str | None = Field(default=None, max_length=255, description="补充描述")
# 外键
share_id: UUID = Field(

View File

@@ -76,9 +76,6 @@ class SiteConfigResponse(SQLModelBase):
email_binding_required: bool = True
"""是否强制绑定邮箱"""
avatar_max_size: int = 2097152
"""头像文件最大字节数(默认 2MB"""
footer_code: str | None = None
"""自定义页脚代码"""
@@ -163,7 +160,6 @@ class SettingsType(StrEnum):
VERSION = "version"
VIEW = "view"
WOPI = "wopi"
FILE_CATEGORY = "file_category"
# 数据库模型
class Setting(SettingItem, TableBaseMixin):

View File

@@ -5,7 +5,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str64, Str255
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
from .model_base import ResponseBase
from .object import ObjectType
@@ -52,10 +52,10 @@ class Share(SQLModelBase, UUIDTableBaseMixin):
Index("ix_share_object", "object_id"),
)
code: Str64 = Field(nullable=False, index=True)
code: str = Field(max_length=64, nullable=False, index=True)
"""分享码"""
password: Str255 | None = None
password: str | None = Field(default=None, max_length=255)
"""分享密码(加密后)"""
object_id: UUID = Field(
@@ -80,7 +80,7 @@ class Share(SQLModelBase, UUIDTableBaseMixin):
preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
"""是否允许预览"""
source_name: Str255 | None = None
source_name: str | None = Field(default=None, max_length=255)
"""源名称(冗余字段,便于展示)"""
score: int = Field(default=0, ge=0)

View File

@@ -4,7 +4,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, Index
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .object import Object
@@ -17,7 +17,7 @@ class SourceLink(SQLModelBase, TableBaseMixin):
Index("ix_sourcelink_object_name", "object_id", "name"),
)
name: Str255
name: str = Field(max_length=255)
"""链接名称"""
downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"})

View File

@@ -1,60 +1,23 @@
from datetime import datetime
from typing import TYPE_CHECKING
from datetime import datetime
from uuid import UUID
from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship
from sqlmodel import Field, Relationship, Column, func, DateTime
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .user import User
# ==================== DTO 模型 ====================
class StoragePackResponse(SQLModelBase):
"""容量包响应 DTO"""
id: int
"""容量包ID"""
name: str
"""容量包名称"""
size: int
"""容量大小(字节)"""
active_time: datetime | None = None
"""激活时间"""
expired_time: datetime | None = None
"""过期时间"""
product_id: UUID | None = None
"""来源商品UUID"""
# ==================== 数据库模型 ====================
class StoragePack(SQLModelBase, TableBaseMixin):
"""容量包模型"""
name: Str255
"""容量包名称"""
active_time: datetime | None = None
"""激活时间"""
expired_time: datetime | None = Field(default=None, index=True)
"""过期时间"""
size: int = Field(sa_type=BigInteger)
"""容量包大小(字节)"""
product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
"""来源商品UUID"""
name: str = Field(max_length=255, description="容量包名称")
active_time: datetime | None = Field(default=None, description="激活时间")
expired_time: datetime | None = Field(default=None, index=True, description="过期时间")
size: int = Field(description="容量包大小(字节)")
# 外键
user_id: UUID = Field(
foreign_key="user.id",
@@ -62,17 +25,6 @@ class StoragePack(SQLModelBase, TableBaseMixin):
ondelete="CASCADE"
)
"""所属用户UUID"""
# 关系
user: "User" = Relationship(back_populates="storage_packs")
def to_response(self) -> StoragePackResponse:
"""转换为响应 DTO"""
return StoragePackResponse(
id=self.id,
name=self.name,
size=self.size,
active_time=self.active_time,
expired_time=self.expired_time,
product_id=self.product_id,
)
user: "User" = Relationship(back_populates="storage_packs")

View File

@@ -5,7 +5,7 @@ from datetime import datetime
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .user import User
@@ -24,13 +24,13 @@ class Tag(SQLModelBase, TableBaseMixin):
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_tag_name_user"),)
name: Str255
name: str = Field(max_length=255)
"""标签名称"""
icon: Str255 | None = None
icon: str | None = Field(default=None, max_length=255)
"""标签图标"""
color: Str255 | None = None
color: str | None = Field(default=None, max_length=255)
"""标签颜色"""
type: TagType = Field(default=TagType.MANUAL)

View File

@@ -26,8 +26,8 @@ class TaskStatus(StrEnum):
class TaskType(StrEnum):
"""任务类型枚举"""
POLICY_MIGRATE = "policy_migrate"
"""存储策略迁移"""
# [TODO] 补充具体任务类型
pass
# ==================== DTO 模型 ====================
@@ -39,7 +39,7 @@ class TaskSummaryBase(SQLModelBase):
id: int
"""任务ID"""
type: TaskType
type: int
"""任务类型"""
status: TaskStatus
@@ -91,14 +91,7 @@ class TaskPropsBase(SQLModelBase):
file_ids: str | None = None
"""文件ID列表逗号分隔"""
source_policy_id: UUID | None = None
"""源存储策略UUID"""
dest_policy_id: UUID | None = None
"""目标存储策略UUID"""
object_id: UUID | None = None
"""关联的对象UUID"""
# [TODO] 根据业务需求补充更多字段
class TaskProps(TaskPropsBase, TableBaseMixin):
@@ -106,7 +99,7 @@ class TaskProps(TaskPropsBase, TableBaseMixin):
task_id: int = Field(
foreign_key="task.id",
unique=True,
primary_key=True,
ondelete="CASCADE"
)
"""关联的任务ID"""
@@ -128,8 +121,8 @@ class Task(SQLModelBase, TableBaseMixin):
status: TaskStatus = Field(default=TaskStatus.QUEUED)
"""任务状态"""
type: TaskType
"""任务类型"""
type: int = Field(default=0)
"""任务类型 [TODO] 待定义枚举"""
progress: int = Field(default=0, ge=0, le=100)
"""任务进度0-100"""

View File

@@ -3,7 +3,7 @@ from uuid import UUID
from sqlmodel import Field
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
from .color import ChromaticColor, NeutralColor, ThemeColorsBase
@@ -11,7 +11,7 @@ from .color import ChromaticColor, NeutralColor, ThemeColorsBase
class ThemePresetBase(SQLModelBase):
"""主题预设基础字段"""
name: Str100
name: str = Field(max_length=100)
"""预设名称"""
is_default: bool = False
@@ -42,7 +42,7 @@ class ThemePresetBase(SQLModelBase):
class ThemePreset(ThemePresetBase, UUIDTableBaseMixin):
"""主题预设表"""
name: Str100 = Field(unique=True)
name: str = Field(max_length=100, unique=True)
"""预设名称(唯一约束)"""
@@ -51,7 +51,7 @@ class ThemePreset(ThemePresetBase, UUIDTableBaseMixin):
class ThemePresetCreateRequest(SQLModelBase):
"""创建主题预设请求 DTO"""
name: Str100
name: str = Field(max_length=100)
"""预设名称"""
colors: ThemeColorsBase
@@ -61,7 +61,7 @@ class ThemePresetCreateRequest(SQLModelBase):
class ThemePresetUpdateRequest(SQLModelBase):
"""更新主题预设请求 DTO"""
name: Str100 | None = None
name: str | None = Field(default=None, max_length=100)
"""预设名称(可选)"""
colors: ThemeColorsBase | None = None

View File

@@ -4,12 +4,12 @@ from typing import Literal, TYPE_CHECKING, TypeVar
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import BigInteger, BinaryExpression, ClauseElement, and_
from sqlalchemy import BinaryExpression, ClauseElement, and_
from sqlmodel import Field, Relationship
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import RelationshipInfo
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableViewRequest, ListResponse, Str255
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableViewRequest, ListResponse
from .auth_identity import AuthProviderType
from .color import ChromaticColor, NeutralColor, ThemeColorsBase
@@ -23,7 +23,6 @@ if TYPE_CHECKING:
from .download import Download
from .object import Object
from .order import Order
from .redeem import Redeem
from .share import Share
from .storage_pack import StoragePack
from .tag import Tag
@@ -474,10 +473,10 @@ class User(UserBase, UUIDTableBaseMixin):
status: UserStatus = UserStatus.ACTIVE
"""用户状态"""
storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"}, ge=0)
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
"""已用存储空间(字节)"""
avatar: Str255 = Field(default="default")
avatar: str = Field(default="default", max_length=255)
"""头像地址"""
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
@@ -571,14 +570,6 @@ class User(UserBase, UUIDTableBaseMixin):
back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
redeems: list["Redeem"] = Relationship(
back_populates="user",
sa_relationship_kwargs={
"cascade": "all, delete-orphan",
"foreign_keys": "[Redeem.used_by]"
}
)
"""用户使用过的兑换码列表"""
shares: list["Share"] = Relationship(
back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}

View File

@@ -5,7 +5,7 @@ from uuid import UUID
from sqlalchemy import Column, Text
from sqlmodel import Field, Relationship
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str32, Str100, Str255
from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .user import User
@@ -51,7 +51,7 @@ class AuthnDetailResponse(SQLModelBase):
class AuthnRenameRequest(SQLModelBase):
"""WebAuthn 凭证重命名请求 DTO"""
name: Str100
name: str = Field(max_length=100)
"""新的凭证名称"""
@@ -60,7 +60,7 @@ class AuthnRenameRequest(SQLModelBase):
class UserAuthn(SQLModelBase, TableBaseMixin):
"""用户 WebAuthn 凭证模型,与 User 为多对一关系"""
credential_id: Str255 = Field(unique=True, index=True)
credential_id: str = Field(max_length=255, unique=True, index=True)
"""凭证 IDBase64URL 编码"""
credential_public_key: str = Field(sa_column=Column(Text))
@@ -69,16 +69,16 @@ class UserAuthn(SQLModelBase, TableBaseMixin):
sign_count: int = Field(default=0, ge=0)
"""签名计数器,用于防重放攻击"""
credential_device_type: Str32
credential_device_type: str = Field(max_length=32)
"""凭证设备类型:'single_device''multi_device'"""
credential_backed_up: bool = Field(default=False)
"""凭证是否已备份"""
transports: Str255 | None = None
transports: str | None = Field(default=None, max_length=255)
"""支持的传输方式,逗号分隔,如 'usb,nfc,ble,internal'"""
name: Str100 | None = None
name: str | None = Field(default=None, max_length=100)
"""用户自定义的凭证名称,便于识别"""
# 外键

View File

@@ -1,117 +1,32 @@
"""
WebDAV 账户模型
管理用户的 WebDAV 连接账户,每个账户对应一个挂载根路径。
通过 HTTP Basic Auth 认证访问 DAV 协议端点。
"""
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .user import User
# ==================== Base 模型 ====================
class WebDAVBase(SQLModelBase):
"""WebDAV 账户基础字段"""
name: Str255
"""账户名称(同一用户下唯一)"""
root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"})
"""挂载根目录路径"""
readonly: bool = Field(default=False, sa_column_kwargs={"server_default": "false"})
"""是否只读"""
use_proxy: bool = Field(default=False, sa_column_kwargs={"server_default": "false"})
"""是否使用代理下载"""
# ==================== 数据库模型 ====================
class WebDAV(WebDAVBase, TableBaseMixin):
"""WebDAV 账户模型"""
class WebDAV(SQLModelBase, TableBaseMixin):
"""WebDAV账户模型"""
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),)
password: Str255
"""密码Argon2 哈希)"""
name: str = Field(max_length=255, description="WebDAV账户名")
password: str = Field(max_length=255, description="WebDAV密码")
root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"}, description="根目录路径")
readonly: bool = Field(default=False, description="是否只读")
use_proxy: bool = Field(default=False, description="是否使用代理下载")
# 外键
user_id: UUID = Field(
foreign_key="user.id",
index=True,
ondelete="CASCADE",
ondelete="CASCADE"
)
"""所属用户UUID"""
# 关系
user: "User" = Relationship(back_populates="webdavs")
# ==================== DTO 模型 ====================
class WebDAVCreateRequest(SQLModelBase):
"""创建 WebDAV 账户请求"""
name: Str255
"""账户名称"""
password: Str255 = Field(min_length=1)
"""账户密码(明文,服务端哈希后存储)"""
root: str = "/"
"""挂载根目录路径"""
readonly: bool = False
"""是否只读"""
use_proxy: bool = False
"""是否使用代理下载"""
class WebDAVUpdateRequest(SQLModelBase):
"""更新 WebDAV 账户请求"""
password: Str255 | None = Field(default=None, min_length=1)
"""新密码(为 None 时不修改)"""
root: str | None = None
"""新挂载根目录路径(为 None 时不修改)"""
readonly: bool | None = None
"""是否只读(为 None 时不修改)"""
use_proxy: bool | None = None
"""是否使用代理下载(为 None 时不修改)"""
class WebDAVAccountResponse(SQLModelBase):
"""WebDAV 账户响应"""
id: int
"""账户ID"""
name: str
"""账户名称"""
root: str
"""挂载根目录路径"""
readonly: bool
"""是否只读"""
use_proxy: bool
"""是否使用代理下载"""
created_at: str
"""创建时间"""
updated_at: str
"""更新时间"""
user: "User" = Relationship(back_populates="webdavs")

View File

@@ -92,9 +92,9 @@ class ObjectFactory:
owner_id=owner_id,
policy_id=policy_id,
size=size,
mime_type=kwargs.get("mime_type"),
source_name=kwargs.get("source_name", name),
upload_session_id=kwargs.get("upload_session_id"),
file_metadata=kwargs.get("file_metadata"),
password=kwargs.get("password"),
)

View File

@@ -71,7 +71,7 @@ class UserFactory:
is_verified=True,
user_id=user.id,
)
identity = await identity.save(session)
await identity.save(session)
return user
@@ -123,7 +123,7 @@ class UserFactory:
is_verified=True,
user_id=admin.id,
)
identity = await identity.save(session)
await identity.save(session)
return admin
@@ -170,7 +170,7 @@ class UserFactory:
is_verified=True,
user_id=banned_user.id,
)
identity = await identity.save(session)
await identity.save(session)
return banned_user
@@ -219,6 +219,6 @@ class UserFactory:
is_verified=True,
user_id=user.id,
)
identity = await identity.save(session)
await identity.save(session)
return user

View File

@@ -1,219 +0,0 @@
"""
自定义属性定义端点集成测试
"""
import pytest
from httpx import AsyncClient
from uuid import UUID, uuid4
# ==================== 获取属性定义列表测试 ====================
@pytest.mark.asyncio
async def test_list_custom_properties_requires_auth(async_client: AsyncClient):
"""测试获取属性定义需要认证"""
response = await async_client.get("/api/v1/object/custom_property")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_list_custom_properties_empty(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试获取空的属性定义列表"""
response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data == []
# ==================== 创建属性定义测试 ====================
@pytest.mark.asyncio
async def test_create_custom_property(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试创建自定义属性"""
response = await async_client.post(
"/api/v1/object/custom_property",
headers=auth_headers,
json={
"name": "评分",
"type": "rating",
"icon": "mdi:star",
},
)
assert response.status_code == 204
# 验证已创建
list_response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
data = list_response.json()
assert len(data) == 1
assert data[0]["name"] == "评分"
assert data[0]["type"] == "rating"
assert data[0]["icon"] == "mdi:star"
@pytest.mark.asyncio
async def test_create_custom_property_with_options(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试创建带选项的自定义属性"""
response = await async_client.post(
"/api/v1/object/custom_property",
headers=auth_headers,
json={
"name": "分类",
"type": "select",
"options": ["工作", "个人", "归档"],
"default_value": "个人",
},
)
assert response.status_code == 204
list_response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
data = list_response.json()
prop = next(p for p in data if p["name"] == "分类")
assert prop["type"] == "select"
assert prop["options"] == ["工作", "个人", "归档"]
assert prop["default_value"] == "个人"
@pytest.mark.asyncio
async def test_create_custom_property_duplicate_name(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试创建同名属性返回 409"""
# 先创建
await async_client.post(
"/api/v1/object/custom_property",
headers=auth_headers,
json={"name": "标签", "type": "text"},
)
# 再创建同名
response = await async_client.post(
"/api/v1/object/custom_property",
headers=auth_headers,
json={"name": "标签", "type": "text"},
)
assert response.status_code == 409
# ==================== 更新属性定义测试 ====================
@pytest.mark.asyncio
async def test_update_custom_property(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试更新自定义属性"""
# 先创建
await async_client.post(
"/api/v1/object/custom_property",
headers=auth_headers,
json={"name": "备注", "type": "text"},
)
# 获取 ID
list_response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
prop_id = next(p["id"] for p in list_response.json() if p["name"] == "备注")
# 更新
response = await async_client.patch(
f"/api/v1/object/custom_property/{prop_id}",
headers=auth_headers,
json={"name": "详细备注", "icon": "mdi:note"},
)
assert response.status_code == 204
# 验证已更新
list_response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
prop = next(p for p in list_response.json() if p["id"] == prop_id)
assert prop["name"] == "详细备注"
assert prop["icon"] == "mdi:note"
@pytest.mark.asyncio
async def test_update_custom_property_not_found(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试更新不存在的属性返回 404"""
fake_id = str(uuid4())
response = await async_client.patch(
f"/api/v1/object/custom_property/{fake_id}",
headers=auth_headers,
json={"name": "不存在"},
)
assert response.status_code == 404
# ==================== 删除属性定义测试 ====================
@pytest.mark.asyncio
async def test_delete_custom_property(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试删除自定义属性"""
# 先创建
await async_client.post(
"/api/v1/object/custom_property",
headers=auth_headers,
json={"name": "待删除", "type": "text"},
)
# 获取 ID
list_response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
prop_id = next(p["id"] for p in list_response.json() if p["name"] == "待删除")
# 删除
response = await async_client.delete(
f"/api/v1/object/custom_property/{prop_id}",
headers=auth_headers,
)
assert response.status_code == 204
# 验证已删除
list_response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
prop_names = [p["name"] for p in list_response.json()]
assert "待删除" not in prop_names
@pytest.mark.asyncio
async def test_delete_custom_property_not_found(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试删除不存在的属性返回 404"""
fake_id = str(uuid4())
response = await async_client.delete(
f"/api/v1/object/custom_property/{fake_id}",
headers=auth_headers,
)
assert response.status_code == 404

View File

@@ -1,239 +0,0 @@
"""
对象元数据端点集成测试
"""
import pytest
from httpx import AsyncClient
from uuid import UUID, uuid4
from sqlmodels import ObjectMetadata
# ==================== 获取元数据测试 ====================
@pytest.mark.asyncio
async def test_get_metadata_requires_auth(async_client: AsyncClient):
"""测试获取元数据需要认证"""
fake_id = str(uuid4())
response = await async_client.get(f"/api/v1/object/{fake_id}/metadata")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_get_metadata_empty(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试获取无元数据的对象"""
file_id = test_directory_structure["file_id"]
response = await async_client.get(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["metadatas"] == {}
@pytest.mark.asyncio
async def test_get_metadata_with_entries(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
initialized_db,
):
"""测试获取有元数据的对象"""
file_id = test_directory_structure["file_id"]
# 直接写入元数据
entries = [
ObjectMetadata(object_id=file_id, name="exif:width", value="1920", is_public=True),
ObjectMetadata(object_id=file_id, name="exif:height", value="1080", is_public=True),
ObjectMetadata(object_id=file_id, name="sys:extract_status", value="done", is_public=False),
]
for entry in entries:
initialized_db.add(entry)
await initialized_db.commit()
response = await async_client.get(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
# sys: 命名空间应被过滤
assert "exif:width" in data["metadatas"]
assert "exif:height" in data["metadatas"]
assert "sys:extract_status" not in data["metadatas"]
assert data["metadatas"]["exif:width"] == "1920"
@pytest.mark.asyncio
async def test_get_metadata_ns_filter(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
initialized_db,
):
"""测试按命名空间过滤元数据"""
file_id = test_directory_structure["file_id"]
entries = [
ObjectMetadata(object_id=file_id, name="exif:width", value="1920", is_public=True),
ObjectMetadata(object_id=file_id, name="music:title", value="Test Song", is_public=True),
]
for entry in entries:
initialized_db.add(entry)
await initialized_db.commit()
# 只获取 exif 命名空间
response = await async_client.get(
f"/api/v1/object/{file_id}/metadata?ns=exif",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert "exif:width" in data["metadatas"]
assert "music:title" not in data["metadatas"]
@pytest.mark.asyncio
async def test_get_metadata_nonexistent_object(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试获取不存在对象的元数据"""
fake_id = str(uuid4())
response = await async_client.get(
f"/api/v1/object/{fake_id}/metadata",
headers=auth_headers,
)
assert response.status_code == 404
# ==================== 更新元数据测试 ====================
@pytest.mark.asyncio
async def test_patch_metadata_requires_auth(async_client: AsyncClient):
"""测试更新元数据需要认证"""
fake_id = str(uuid4())
response = await async_client.patch(
f"/api/v1/object/{fake_id}/metadata",
json={"patches": [{"key": "custom:tag", "value": "test"}]},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_patch_metadata_set_custom(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试设置自定义元数据"""
file_id = test_directory_structure["file_id"]
response = await async_client.patch(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
json={
"patches": [
{"key": "custom:tag1", "value": "旅游"},
{"key": "custom:tag2", "value": "风景"},
]
},
)
assert response.status_code == 204
# 验证已写入
get_response = await async_client.get(
f"/api/v1/object/{file_id}/metadata?ns=custom",
headers=auth_headers,
)
assert get_response.status_code == 200
data = get_response.json()
assert data["metadatas"]["custom:tag1"] == "旅游"
assert data["metadatas"]["custom:tag2"] == "风景"
@pytest.mark.asyncio
async def test_patch_metadata_update_existing(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试更新已有的元数据"""
file_id = test_directory_structure["file_id"]
# 先创建
await async_client.patch(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
json={"patches": [{"key": "custom:note", "value": "旧值"}]},
)
# 再更新
response = await async_client.patch(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
json={"patches": [{"key": "custom:note", "value": "新值"}]},
)
assert response.status_code == 204
# 验证已更新
get_response = await async_client.get(
f"/api/v1/object/{file_id}/metadata?ns=custom",
headers=auth_headers,
)
data = get_response.json()
assert data["metadatas"]["custom:note"] == "新值"
@pytest.mark.asyncio
async def test_patch_metadata_delete(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试删除元数据条目"""
file_id = test_directory_structure["file_id"]
# 先创建
await async_client.patch(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
json={"patches": [{"key": "custom:to_delete", "value": "temp"}]},
)
# 删除value 为 null
response = await async_client.patch(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
json={"patches": [{"key": "custom:to_delete", "value": None}]},
)
assert response.status_code == 204
# 验证已删除
get_response = await async_client.get(
f"/api/v1/object/{file_id}/metadata?ns=custom",
headers=auth_headers,
)
data = get_response.json()
assert "custom:to_delete" not in data["metadatas"]
@pytest.mark.asyncio
async def test_patch_metadata_reject_non_custom_namespace(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试拒绝修改非 custom: 命名空间"""
file_id = test_directory_structure["file_id"]
response = await async_client.patch(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
json={"patches": [{"key": "exif:width", "value": "1920"}]},
)
assert response.status_code == 400

View File

@@ -1,591 +0,0 @@
"""
WebDAV 账户管理端点集成测试
"""
from uuid import UUID, uuid4
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, User
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.user import UserStatus
from utils import Password
from utils.JWT import create_access_token
API_PREFIX = "/api/v1/webdav"
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def no_webdav_headers(initialized_db: AsyncSession) -> dict[str, str]:
"""创建一个 WebDAV 被禁用的用户,返回其认证头"""
group = Group(
id=uuid4(),
name="无WebDAV用户组",
max_storage=1024 * 1024 * 1024,
share_enabled=True,
web_dav_enabled=False,
admin=False,
speed_limit=0,
)
initialized_db.add(group)
await initialized_db.commit()
await initialized_db.refresh(group)
group_options = GroupOptions(
group_id=group.id,
share_download=True,
share_free=False,
relocate=False,
source_batch=0,
select_node=False,
advance_delete=False,
)
initialized_db.add(group_options)
await initialized_db.commit()
await initialized_db.refresh(group_options)
user = User(
id=uuid4(),
email="nowebdav@test.local",
nickname="无WebDAV用户",
status=UserStatus.ACTIVE,
storage=0,
score=0,
group_id=group.id,
avatar="default",
)
initialized_db.add(user)
await initialized_db.commit()
await initialized_db.refresh(user)
identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier="nowebdav@test.local",
credential=Password.hash("nowebdav123"),
is_primary=True,
is_verified=True,
user_id=user.id,
)
initialized_db.add(identity)
from sqlmodels import Policy
policy = await Policy.get(initialized_db, Policy.name == "本地存储")
root = Object(
id=uuid4(),
name="/",
type=ObjectType.FOLDER,
owner_id=user.id,
parent_id=None,
policy_id=policy.id,
size=0,
)
initialized_db.add(root)
await initialized_db.commit()
group.options = group_options
group_claims = GroupClaims.from_group(group)
result = create_access_token(
sub=user.id,
jti=uuid4(),
status=user.status.value,
group=group_claims,
)
return {"Authorization": f"Bearer {result.access_token}"}
# ==================== 认证测试 ====================
@pytest.mark.asyncio
async def test_list_accounts_requires_auth(async_client: AsyncClient):
"""测试获取账户列表需要认证"""
response = await async_client.get(f"{API_PREFIX}/accounts")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_create_account_requires_auth(async_client: AsyncClient):
"""测试创建账户需要认证"""
response = await async_client.post(
f"{API_PREFIX}/accounts",
json={"name": "test", "password": "testpass"},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_update_account_requires_auth(async_client: AsyncClient):
"""测试更新账户需要认证"""
response = await async_client.patch(
f"{API_PREFIX}/accounts/1",
json={"readonly": True},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_delete_account_requires_auth(async_client: AsyncClient):
"""测试删除账户需要认证"""
response = await async_client.delete(f"{API_PREFIX}/accounts/1")
assert response.status_code == 401
# ==================== WebDAV 禁用测试 ====================
@pytest.mark.asyncio
async def test_list_accounts_webdav_disabled(
async_client: AsyncClient,
no_webdav_headers: dict[str, str],
):
"""测试 WebDAV 被禁用时返回 403"""
response = await async_client.get(
f"{API_PREFIX}/accounts",
headers=no_webdav_headers,
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_create_account_webdav_disabled(
async_client: AsyncClient,
no_webdav_headers: dict[str, str],
):
"""测试 WebDAV 被禁用时创建账户返回 403"""
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=no_webdav_headers,
json={"name": "test", "password": "testpass"},
)
assert response.status_code == 403
# ==================== 获取账户列表测试 ====================
@pytest.mark.asyncio
async def test_list_accounts_empty(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试初始状态账户列表为空"""
response = await async_client.get(
f"{API_PREFIX}/accounts",
headers=auth_headers,
)
assert response.status_code == 200
assert response.json() == []
# ==================== 创建账户测试 ====================
@pytest.mark.asyncio
async def test_create_account_success(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试成功创建 WebDAV 账户"""
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "my-nas", "password": "secretpass"},
)
assert response.status_code == 201
data = response.json()
assert data["name"] == "my-nas"
assert data["root"] == "/"
assert data["readonly"] is False
assert data["use_proxy"] is False
assert "id" in data
assert "created_at" in data
assert "updated_at" in data
@pytest.mark.asyncio
async def test_create_account_with_options(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试创建带选项的 WebDAV 账户"""
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={
"name": "readonly-nas",
"password": "secretpass",
"readonly": True,
"use_proxy": True,
},
)
assert response.status_code == 201
data = response.json()
assert data["name"] == "readonly-nas"
assert data["readonly"] is True
assert data["use_proxy"] is True
@pytest.mark.asyncio
async def test_create_account_duplicate_name(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试重名账户返回 409"""
# 先创建一个
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "dup-test", "password": "pass1"},
)
assert response.status_code == 201
# 再创建同名的
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "dup-test", "password": "pass2"},
)
assert response.status_code == 409
@pytest.mark.asyncio
async def test_create_account_invalid_root(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试无效根目录路径返回 400"""
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={
"name": "bad-root",
"password": "secretpass",
"root": "/nonexistent/path",
},
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_create_account_with_valid_subdir(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试使用有效的子目录作为根路径"""
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={
"name": "docs-only",
"password": "secretpass",
"root": "/docs",
},
)
assert response.status_code == 201
assert response.json()["root"] == "/docs"
# ==================== 列表包含已创建账户测试 ====================
@pytest.mark.asyncio
async def test_list_accounts_after_create(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试创建后列表中包含该账户"""
# 创建
await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "list-test", "password": "pass"},
)
# 列表
response = await async_client.get(
f"{API_PREFIX}/accounts",
headers=auth_headers,
)
assert response.status_code == 200
accounts = response.json()
assert len(accounts) == 1
assert accounts[0]["name"] == "list-test"
# ==================== 更新账户测试 ====================
@pytest.mark.asyncio
async def test_update_account_success(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试成功更新 WebDAV 账户"""
# 创建
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "update-test", "password": "oldpass"},
)
account_id = create_resp.json()["id"]
# 更新
response = await async_client.patch(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
json={"readonly": True},
)
assert response.status_code == 200
data = response.json()
assert data["readonly"] is True
assert data["name"] == "update-test"
@pytest.mark.asyncio
async def test_update_account_password(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试更新密码"""
# 创建
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "pwd-test", "password": "oldpass"},
)
account_id = create_resp.json()["id"]
# 更新密码
response = await async_client.patch(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
json={"password": "newpass123"},
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_update_account_root(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试更新根目录路径"""
# 创建
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "root-update", "password": "pass"},
)
account_id = create_resp.json()["id"]
# 更新 root 到有效子目录
response = await async_client.patch(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
json={"root": "/docs"},
)
assert response.status_code == 200
assert response.json()["root"] == "/docs"
@pytest.mark.asyncio
async def test_update_account_invalid_root(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试更新为无效根目录返回 400"""
# 创建
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "bad-root-update", "password": "pass"},
)
account_id = create_resp.json()["id"]
# 更新到无效路径
response = await async_client.patch(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
json={"root": "/nonexistent"},
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_update_account_not_found(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试更新不存在的账户返回 404"""
response = await async_client.patch(
f"{API_PREFIX}/accounts/99999",
headers=auth_headers,
json={"readonly": True},
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_update_other_user_account(
async_client: AsyncClient,
auth_headers: dict[str, str],
admin_headers: dict[str, str],
):
"""测试更新其他用户的账户返回 404"""
# 管理员创建账户
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=admin_headers,
json={"name": "admin-account", "password": "pass"},
)
account_id = create_resp.json()["id"]
# 普通用户尝试更新
response = await async_client.patch(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
json={"readonly": True},
)
assert response.status_code == 404
# ==================== 删除账户测试 ====================
@pytest.mark.asyncio
async def test_delete_account_success(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试成功删除 WebDAV 账户"""
# 创建
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "delete-test", "password": "pass"},
)
account_id = create_resp.json()["id"]
# 删除
response = await async_client.delete(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
)
assert response.status_code == 204
# 确认列表中已不存在
list_resp = await async_client.get(
f"{API_PREFIX}/accounts",
headers=auth_headers,
)
assert list_resp.status_code == 200
names = [a["name"] for a in list_resp.json()]
assert "delete-test" not in names
@pytest.mark.asyncio
async def test_delete_account_not_found(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试删除不存在的账户返回 404"""
response = await async_client.delete(
f"{API_PREFIX}/accounts/99999",
headers=auth_headers,
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_delete_other_user_account(
async_client: AsyncClient,
auth_headers: dict[str, str],
admin_headers: dict[str, str],
):
"""测试删除其他用户的账户返回 404"""
# 管理员创建账户
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=admin_headers,
json={"name": "admin-del-test", "password": "pass"},
)
account_id = create_resp.json()["id"]
# 普通用户尝试删除
response = await async_client.delete(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
)
assert response.status_code == 404
# ==================== 多账户测试 ====================
@pytest.mark.asyncio
async def test_multiple_accounts(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试同一用户可以创建多个账户"""
for name in ["account-1", "account-2", "account-3"]:
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": name, "password": "pass"},
)
assert response.status_code == 201
# 列表应有3个
response = await async_client.get(
f"{API_PREFIX}/accounts",
headers=auth_headers,
)
assert response.status_code == 200
assert len(response.json()) == 3
# ==================== 用户隔离测试 ====================
@pytest.mark.asyncio
async def test_accounts_user_isolation(
async_client: AsyncClient,
auth_headers: dict[str, str],
admin_headers: dict[str, str],
):
"""测试不同用户的账户相互隔离"""
# 普通用户创建
await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "user-account", "password": "pass"},
)
# 管理员创建
await async_client.post(
f"{API_PREFIX}/accounts",
headers=admin_headers,
json={"name": "admin-account", "password": "pass"},
)
# 普通用户只看到自己的
response = await async_client.get(
f"{API_PREFIX}/accounts",
headers=auth_headers,
)
assert response.status_code == 200
accounts = response.json()
assert len(accounts) == 1
assert accounts[0]["name"] == "user-account"
# 管理员只看到自己的
response = await async_client.get(
f"{API_PREFIX}/accounts",
headers=admin_headers,
)
assert response.status_code == 200
accounts = response.json()
assert len(accounts) == 1
assert accounts[0]["name"] == "admin-account"

View File

@@ -23,7 +23,6 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../.
from main import app
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
from sqlmodels.policy import GroupPolicyLink
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.user import UserStatus
from utils import Password
@@ -109,12 +108,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
Setting(type=SettingsType.AUTH, name="auth_email_binding_required", value="1"),
Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"),
Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"),
Setting(type=SettingsType.AVATAR, name="gravatar_server", value="https://www.gravatar.com/"),
Setting(type=SettingsType.AVATAR, name="avatar_size", value="2097152"),
Setting(type=SettingsType.AVATAR, name="avatar_size_l", value="200"),
Setting(type=SettingsType.AVATAR, name="avatar_size_m", value="130"),
Setting(type=SettingsType.AVATAR, name="avatar_size_s", value="50"),
Setting(type=SettingsType.PATH, name="avatar_path", value="avatar"),
]
for setting in settings:
test_session.add(setting)
@@ -163,11 +156,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
await test_session.refresh(admin_group)
await test_session.refresh(default_policy)
# 4. 关联用户组与存储策略
test_session.add(GroupPolicyLink(group_id=default_group.id, policy_id=default_policy.id))
test_session.add(GroupPolicyLink(group_id=admin_group.id, policy_id=default_policy.id))
# 5. 创建用户组选项
# 4. 创建用户组选项
default_group_options = GroupOptions(
group_id=default_group.id,
share_download=True,

View File

@@ -37,12 +37,6 @@ async def load_secret_key() -> None:
if setting:
SECRET_KEY = setting.value
if not SECRET_KEY:
raise RuntimeError(
"JWT SECRET_KEY 未配置,拒绝启动。"
"请在 Setting 表中添加 type='auth', name='secret_key' 的记录。"
)
def build_token_payload(
data: dict,

View File

@@ -62,10 +62,6 @@ def raise_not_implemented(detail: str = "尚未支持这种方法") -> NoReturn:
"""Raises an HTTP 501 Not Implemented exception."""
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=detail)
def raise_bad_gateway(detail: str | None = None) -> NoReturn:
"""Raises an HTTP 502 Bad Gateway exception."""
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=detail)
def raise_service_unavailable(detail: str | None = None) -> NoReturn:
"""Raises an HTTP 503 Service Unavailable exception."""
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=detail)

995
uv.lock generated

File diff suppressed because it is too large Load Diff