Compare commits

..

12 Commits

Author SHA1 Message Date
15b2efe52a fix: 修复 update_group_access 中 app 变量未赋值的问题
All checks were successful
Test / test (push) Successful in 2m34s
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 11:31:02 +08:00
6c96c43bea refactor: 统一 sqlmodel_ext 用法至官方推荐模式
Some checks failed
Test / test (push) Failing after 3m47s
- 替换 Field(max_length=X) 为 StrX/TextX 类型别名(21 个 sqlmodels 文件)
- 替换 get + 404 检查为 get_exist_one()(17 个路由文件,约 50 处)
- 替换 save + session.refresh 为 save(load=...)
- 替换 session.add + commit 为 save()(dav/provider.py)
- 更新所有依赖至最新版本

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 11:13:16 +08:00
9185f26b83 feat: 添加 EPUB 阅读器、3D 模型预览和字体查看器应用,启用 Office 在线预览
All checks were successful
Test / test (push) Successful in 2m31s
2026-02-26 12:50:24 +08:00
f4052d229a fix: clean up empty parent directories after file deletion
All checks were successful
Test / test (push) Successful in 2m32s
Prevent local storage fragmentation by removing empty directories
left behind when files are permanently deleted or moved to trash.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 15:56:44 +08:00
bc2182720d feat: implement avatar upload, Gravatar support, and avatar settings
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 15:56:24 +08:00
eddf38d316 chore: remove applied migration script
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 15:56:07 +08:00
03e768d232 chore: update .gitignore for avatar and dev directories
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 15:56:00 +08:00
bcb0a9b322 feat: redesign metadata as KV store, add custom properties and WOPI Discovery
Some checks failed
Test / test (push) Failing after 2m32s
Replace one-to-one FileMetadata table with flexible ObjectMetadata KV pairs,
add custom property definitions, WOPI Discovery auto-configuration, and
per-extension action URL support.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 17:21:22 +08:00
743a2c9d65 fix: use TaskStatus/TaskType enums in TaskDetailResponse
Some checks failed
Test / test (push) Failing after 2m17s
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 16:44:01 +08:00
3639a31163 feat: add S3 storage support, policy migration, and quota enforcement
Some checks failed
Test / test (push) Failing after 2m21s
- Add S3StorageService with AWS Signature V4 signing (URI-encoded for non-ASCII keys)
- Add PATCH /object/{id}/policy endpoint for switching storage policies with background migration
- Implement cross-storage file migration service (local <-> S3)
- Replace deprecated StorageType enum with PolicyType (local/s3)
- Implement GET /user/settings/policies endpoint (was 501 stub)
- Add storage quota pre-allocation on upload session creation to prevent concurrent bypass
- Fix BigInteger for max_storage and user.storage to support >2GB values
- Add policy permission validation on upload and directory creation
- Use group's first policy as default on registration instead of hardcoded name
- Define TaskType.POLICY_MIGRATE and extend TaskProps with migration fields

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 13:38:20 +08:00
7200df6d87 fix: patch storage quota bypass and harden auth security
All checks were successful
Test / test (push) Successful in 2m11s
- Fix WebDAV chunked PUT bypassing storage quota when remaining_quota <= 0
- Add QuotaLimitedWriter to enforce quota during streaming writes
- Clean up residual files on write failure in end_write()
- Add Magic Link replay attack prevention via TokenStore
- Reject startup when JWT SECRET_KEY is not configured
- Sanitize OAuth callback and Magic Link log output

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 22:20:43 +08:00
40b6a31c98 feat: implement WebDAV protocol support with WsgiDAV + account management API
All checks were successful
Test / test (push) Successful in 2m14s
Add complete WebDAV support: management REST API (CRUD accounts at /api/v1/webdav/accounts)
and DAV protocol endpoint (/dav) using WsgiDAV + a2wsgi bridge for client access via
HTTP Basic Auth. Includes Redis+TTLCache auth caching and integration tests (24 cases).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:19:29 +08:00
79 changed files with 6742 additions and 1392 deletions

View File

@@ -5,7 +5,8 @@
"Bash(findstr:*)", "Bash(findstr:*)",
"Bash(find:*)", "Bash(find:*)",
"Bash(yarn tsc:*)", "Bash(yarn tsc:*)",
"Bash(dir:*)" "Bash(dir:*)",
"mcp__server-notify__notify"
] ]
} }
} }

5
.gitignore vendored
View File

@@ -1,8 +1,6 @@
# Python # Python
__pycache__/ __pycache__/
*.py[cod] *.py[cod]
*.pyo
*.pyd
*.so *.so
*.egg *.egg
*.egg-info/ *.egg-info/
@@ -79,3 +77,6 @@ statics/
# 许可证密钥(保密) # 许可证密钥(保密)
license_private.pem license_private.pem
license.key license.key
avatar/
.dev/

2
ee

Submodule ee updated: 52921f9ffe...cc32d8db91

31
main.py
View File

@@ -5,7 +5,10 @@ from fastapi import FastAPI, Request
from loguru import logger as l from loguru import logger as l
from routers import router from routers import router
from routers.dav import dav_app
from routers.dav.provider import EventLoopRef
from service.redis import RedisManager from service.redis import RedisManager
from service.storage import S3StorageService
from sqlmodels.database_connection import DatabaseManager from sqlmodels.database_connection import DatabaseManager
from sqlmodels.migration import migration from sqlmodels.migration import migration
from utils import JWT from utils import JWT
@@ -14,24 +17,26 @@ from utils.http.http_exceptions import raise_internal_error
from utils.lifespan import lifespan from utils.lifespan import lifespan
# 尝试加载企业版功能 # 尝试加载企业版功能
_has_ee: bool = False
try: try:
from ee import init_ee from ee import init_ee
from ee.license import LicenseError from ee.license import LicenseError
from ee.routers import ee_router
async def _init_ee_and_routes() -> None: _has_ee = True
async def _init_ee() -> None:
"""启动时验证许可证,路由由 license_valid_required 依赖保护"""
try: try:
await init_ee() await init_ee()
except LicenseError as exc: except LicenseError as exc:
l.critical(f"许可证验证失败: {exc}") l.critical(f"许可证验证失败: {exc}")
raise SystemExit(1) from exc raise SystemExit(1) from exc
from ee.routers import ee_router lifespan.add_startup(_init_ee)
from routers.api.v1 import router as v1_router except ImportError as exc:
v1_router.include_router(ee_router) ee_router = None
l.info(f"以 Community 版本运行 (原因: {exc})")
lifespan.add_startup(_init_ee_and_routes)
except ImportError:
l.info("以 Community 版本运行")
STATICS_DIR: Path = (Path(__file__).parent / "statics").resolve() STATICS_DIR: Path = (Path(__file__).parent / "statics").resolve()
"""前端静态文件目录(由 Docker 构建时复制)""" """前端静态文件目录(由 Docker 构建时复制)"""
@@ -40,13 +45,18 @@ async def _init_db() -> None:
"""初始化数据库连接引擎""" """初始化数据库连接引擎"""
await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug) await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug)
# 捕获事件循环引用(供 WSGI 线程桥接使用)
lifespan.add_startup(EventLoopRef.capture)
# 添加初始化数据库启动项 # 添加初始化数据库启动项
lifespan.add_startup(_init_db) lifespan.add_startup(_init_db)
lifespan.add_startup(migration) lifespan.add_startup(migration)
lifespan.add_startup(JWT.load_secret_key) lifespan.add_startup(JWT.load_secret_key)
lifespan.add_startup(RedisManager.connect) lifespan.add_startup(RedisManager.connect)
lifespan.add_startup(S3StorageService.initialize_session)
# 添加关闭项 # 添加关闭项
lifespan.add_shutdown(S3StorageService.close_session)
lifespan.add_shutdown(DatabaseManager.close) lifespan.add_shutdown(DatabaseManager.close)
lifespan.add_shutdown(RedisManager.disconnect) lifespan.add_shutdown(RedisManager.disconnect)
@@ -87,6 +97,11 @@ async def handle_unexpected_exceptions(
# 挂载路由 # 挂载路由
app.include_router(router) app.include_router(router)
if _has_ee:
app.include_router(ee_router, prefix="/api/v1")
# 挂载 WebDAV 协议端点(优先于 SPA catch-all
app.mount("/dav", dav_app)
# 挂载前端静态文件(仅当 statics/ 目录存在时,即 Docker 部署环境) # 挂载前端静态文件(仅当 statics/ 目录存在时,即 Docker 部署环境)
if STATICS_DIR.is_dir(): if STATICS_DIR.is_dir():

View File

@@ -33,6 +33,8 @@ dependencies = [
"uvicorn>=0.38.0", "uvicorn>=0.38.0",
"webauthn>=2.7.0", "webauthn>=2.7.0",
"whatthepatch>=1.0.6", "whatthepatch>=1.0.6",
"wsgidav>=4.3.0",
"a2wsgi>=1.10.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]

View File

@@ -5,6 +5,7 @@ from utils.conf import appmeta
from .admin import admin_router from .admin import admin_router
from .callback import callback_router from .callback import callback_router
from .category import category_router
from .directory import directory_router from .directory import directory_router
from .download import download_router from .download import download_router
from .file import router as file_router from .file import router as file_router
@@ -14,7 +15,6 @@ from .trash import trash_router
from .site import site_router from .site import site_router
from .slave import slave_router from .slave import slave_router
from .user import user_router from .user import user_router
from .vas import vas_router
from .webdav import webdav_router from .webdav import webdav_router
router = APIRouter(prefix="/v1") router = APIRouter(prefix="/v1")
@@ -24,6 +24,7 @@ router = APIRouter(prefix="/v1")
if appmeta.mode == "master": if appmeta.mode == "master":
router.include_router(admin_router) router.include_router(admin_router)
router.include_router(callback_router) router.include_router(callback_router)
router.include_router(category_router)
router.include_router(directory_router) router.include_router(directory_router)
router.include_router(download_router) router.include_router(download_router)
router.include_router(file_router) router.include_router(file_router)
@@ -32,7 +33,6 @@ if appmeta.mode == "master":
router.include_router(site_router) router.include_router(site_router)
router.include_router(trash_router) router.include_router(trash_router)
router.include_router(user_router) router.include_router(user_router)
router.include_router(vas_router)
router.include_router(webdav_router) router.include_router(webdav_router)
elif appmeta.mode == "slave": elif appmeta.mode == "slave":
router.include_router(slave_router) router.include_router(slave_router)

View File

@@ -16,6 +16,12 @@ from sqlmodels.setting import (
from sqlmodels.setting import SettingsType from sqlmodels.setting import SettingsType
from utils import http_exceptions from utils import http_exceptions
from utils.conf import appmeta from utils.conf import appmeta
try:
from ee.service import get_cached_license
except ImportError:
get_cached_license = None
from .file import admin_file_router from .file import admin_file_router
from .file_app import admin_file_app_router from .file_app import admin_file_app_router
from .group import admin_group_router from .group import admin_group_router
@@ -24,7 +30,6 @@ from .share import admin_share_router
from .task import admin_task_router from .task import admin_task_router
from .user import admin_user_router from .user import admin_user_router
from .theme import admin_theme_router from .theme import admin_theme_router
from .vas import admin_vas_router
class Aria2TestRequest(SQLModelBase): class Aria2TestRequest(SQLModelBase):
@@ -50,7 +55,6 @@ admin_router.include_router(admin_policy_router)
admin_router.include_router(admin_share_router) admin_router.include_router(admin_share_router)
admin_router.include_router(admin_task_router) admin_router.include_router(admin_task_router)
admin_router.include_router(admin_theme_router) admin_router.include_router(admin_theme_router)
admin_router.include_router(admin_vas_router)
# 离线下载 /api/admin/aria2 # 离线下载 /api/admin/aria2
admin_aria2_router = APIRouter( admin_aria2_router = APIRouter(
@@ -159,14 +163,24 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
if site_url_setting and site_url_setting.value: if site_url_setting and site_url_setting.value:
site_urls.append(site_url_setting.value) site_urls.append(site_url_setting.value)
# 许可证信息(从设置读取或使用默认值 # 许可证信息(Pro 版本从缓存读取CE 版本永不过期
license_info = LicenseInfo( if appmeta.IsPro and get_cached_license:
expired_at=now + timedelta(days=365), payload = get_cached_license()
signed_at=now, license_info = LicenseInfo(
root_domains=[], expired_at=payload.expires_at,
domains=[], signed_at=payload.issued_at,
vol_domains=[], root_domains=[],
) domains=[payload.domain],
vol_domains=[],
)
else:
license_info = LicenseInfo(
expired_at=datetime.max,
signed_at=now,
root_domains=[],
domains=[],
vol_domains=[],
)
# 版本信息 # 版本信息
version_info = VersionInfo( version_info = VersionInfo(
@@ -225,11 +239,11 @@ async def router_admin_update_settings(
if existing: if existing:
existing.value = item.value existing.value = item.value
await existing.save(session) existing = await existing.save(session)
updated_count += 1 updated_count += 1
else: else:
new_setting = Setting(type=item.type, name=item.name, value=item.value) new_setting = Setting(type=item.type, name=item.name, value=item.value)
await new_setting.save(session) new_setting = await new_setting.save(session)
created_count += 1 created_count += 1
l.info(f"管理员更新了 {updated_count} 个设置项,新建了 {created_count} 个设置项") l.info(f"管理员更新了 {updated_count} 个设置项,新建了 {created_count} 个设置项")

View File

@@ -54,7 +54,7 @@ async def _set_ban_recursive(
obj.banned_by = None obj.banned_by = None
obj.ban_reason = None obj.ban_reason = None
await obj.save(session) obj = await obj.save(session)
count += 1 count += 1
return count return count
@@ -131,9 +131,7 @@ async def router_admin_preview_file(
:param file_id: 文件UUID :param file_id: 文件UUID
:return: 文件内容 :return: 文件内容
""" """
file_obj = await Object.get(session, Object.id == file_id) file_obj = await Object.get_exist_one(session, file_id)
if not file_obj:
raise HTTPException(status_code=404, detail="文件不存在")
if not file_obj.is_file: if not file_obj.is_file:
raise HTTPException(status_code=400, detail="对象不是文件") raise HTTPException(status_code=400, detail="对象不是文件")
@@ -182,9 +180,7 @@ async def router_admin_ban_file(
:param claims: 当前管理员 JWT claims :param claims: 当前管理员 JWT claims
:return: 封禁结果 :return: 封禁结果
""" """
file_obj = await Object.get(session, Object.id == file_id) file_obj = await Object.get_exist_one(session, file_id)
if not file_obj:
raise HTTPException(status_code=404, detail="文件不存在")
count = await _set_ban_recursive(session, file_obj, request.ban, claims.sub, request.reason) count = await _set_ban_recursive(session, file_obj, request.ban, claims.sub, request.reason)
@@ -212,9 +208,7 @@ async def router_admin_delete_file(
:param delete_physical: 是否同时删除物理文件 :param delete_physical: 是否同时删除物理文件
:return: 删除结果 :return: 删除结果
""" """
file_obj = await Object.get(session, Object.id == file_id) file_obj = await Object.get_exist_one(session, file_id)
if not file_obj:
raise HTTPException(status_code=404, detail="文件不存在")
if not file_obj.is_file: if not file_obj.is_file:
raise HTTPException(status_code=400, detail="对象不是文件") raise HTTPException(status_code=400, detail="对象不是文件")

View File

@@ -1,16 +1,18 @@
""" """
管理员文件应用管理端点 管理员文件应用管理端点
提供文件查看器应用的 CRUD、扩展名管理用户组权限管理。 提供文件查看器应用的 CRUD、扩展名管理用户组权限管理和 WOPI Discovery
""" """
from uuid import UUID from uuid import UUID
import aiohttp
from fastapi import APIRouter, Depends, status from fastapi import APIRouter, Depends, status
from loguru import logger as l from loguru import logger as l
from sqlalchemy import select from sqlalchemy import select
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep from middleware.dependencies import SessionDep, TableViewRequestDep
from service.wopi import parse_wopi_discovery_xml
from sqlmodels import ( from sqlmodels import (
FileApp, FileApp,
FileAppCreateRequest, FileAppCreateRequest,
@@ -21,7 +23,10 @@ from sqlmodels import (
FileAppUpdateRequest, FileAppUpdateRequest,
ExtensionUpdateRequest, ExtensionUpdateRequest,
GroupAccessUpdateRequest, GroupAccessUpdateRequest,
WopiDiscoveredExtension,
WopiDiscoveryResponse,
) )
from sqlmodels.file_app import FileAppType
from utils import http_exceptions from utils import http_exceptions
admin_file_app_router = APIRouter( admin_file_app_router = APIRouter(
@@ -123,6 +128,7 @@ async def create_file_app(
group_links.append(link) group_links.append(link)
if group_links: if group_links:
await session.commit() await session.commit()
await session.refresh(app)
l.info(f"创建文件应用: {app.name} ({app.app_key})") l.info(f"创建文件应用: {app.name} ({app.app_key})")
@@ -145,9 +151,7 @@ async def get_file_app(
错误处理: 错误处理:
- 404: 应用不存在 - 404: 应用不存在
""" """
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id) app = await FileApp.get_exist_one(session, app_id)
if not app:
http_exceptions.raise_not_found("应用不存在")
extensions = await FileAppExtension.get( extensions = await FileAppExtension.get(
session, session,
@@ -180,9 +184,7 @@ async def update_file_app(
- 404: 应用不存在 - 404: 应用不存在
- 409: 新 app_key 已被其他应用使用 - 409: 新 app_key 已被其他应用使用
""" """
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id) app = await FileApp.get_exist_one(session, app_id)
if not app:
http_exceptions.raise_not_found("应用不存在")
# 检查 app_key 唯一性 # 检查 app_key 唯一性
if request.app_key is not None and request.app_key != app.app_key: if request.app_key is not None and request.app_key != app.app_key:
@@ -229,9 +231,7 @@ async def delete_file_app(
错误处理: 错误处理:
- 404: 应用不存在 - 404: 应用不存在
""" """
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id) app = await FileApp.get_exist_one(session, app_id)
if not app:
http_exceptions.raise_not_found("应用不存在")
app_name = app.app_key app_name = app.app_key
await FileApp.delete(session, app) await FileApp.delete(session, app)
@@ -257,20 +257,24 @@ async def update_extensions(
错误处理: 错误处理:
- 404: 应用不存在 - 404: 应用不存在
""" """
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id) app = await FileApp.get_exist_one(session, app_id)
if not app:
http_exceptions.raise_not_found("应用不存在")
# 删除旧的扩展名 # 保留旧扩展名的 wopi_action_urlDiscovery 填充的值)
old_extensions: list[FileAppExtension] = await FileAppExtension.get( old_extensions: list[FileAppExtension] = await FileAppExtension.get(
session, session,
FileAppExtension.app_id == app_id, FileAppExtension.app_id == app_id,
fetch_mode="all", fetch_mode="all",
) )
old_url_map: dict[str, str] = {
ext.extension: ext.wopi_action_url
for ext in old_extensions
if ext.wopi_action_url
}
for old_ext in old_extensions: for old_ext in old_extensions:
await FileAppExtension.delete(session, old_ext, commit=False) await FileAppExtension.delete(session, old_ext, commit=False)
await session.flush()
# 创建新的扩展名 # 创建新的扩展名(保留已有的 wopi_action_url
new_extensions: list[FileAppExtension] = [] new_extensions: list[FileAppExtension] = []
for i, ext in enumerate(request.extensions): for i, ext in enumerate(request.extensions):
normalized = ext.lower().strip().lstrip('.') normalized = ext.lower().strip().lstrip('.')
@@ -278,12 +282,14 @@ async def update_extensions(
app_id=app_id, app_id=app_id,
extension=normalized, extension=normalized,
priority=i, priority=i,
wopi_action_url=old_url_map.get(normalized),
) )
session.add(ext_record) session.add(ext_record)
new_extensions.append(ext_record) new_extensions.append(ext_record)
await session.commit() await session.commit()
# refresh 新创建的记录 # refresh commit 后过期的对象
await session.refresh(app)
for ext_record in new_extensions: for ext_record in new_extensions:
await session.refresh(ext_record) await session.refresh(ext_record)
@@ -316,9 +322,7 @@ async def update_group_access(
错误处理: 错误处理:
- 404: 应用不存在 - 404: 应用不存在
""" """
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id) app = await FileApp.get_exist_one(session, app_id)
if not app:
http_exceptions.raise_not_found("应用不存在")
# 删除旧的用户组关联 # 删除旧的用户组关联
old_links_result = await session.exec( old_links_result = await session.exec(
@@ -336,6 +340,7 @@ async def update_group_access(
new_links.append(link) new_links.append(link)
await session.commit() await session.commit()
await session.refresh(app)
extensions = await FileAppExtension.get( extensions = await FileAppExtension.get(
session, session,
@@ -346,3 +351,100 @@ async def update_group_access(
l.info(f"更新文件应用 {app.app_key} 的用户组权限: {request.group_ids}") l.info(f"更新文件应用 {app.app_key} 的用户组权限: {request.group_ids}")
return FileAppResponse.from_app(app, extensions, new_links) return FileAppResponse.from_app(app, extensions, new_links)
@admin_file_app_router.post(
path='/{app_id}/discover',
summary='执行 WOPI Discovery',
)
async def discover_wopi(
session: SessionDep,
app_id: UUID,
) -> WopiDiscoveryResponse:
"""
从 WOPI 服务端获取 Discovery XML 并自动配置扩展名和 URL 模板。
流程:
1. 验证 FileApp 存在且为 WOPI 类型
2. 使用 FileApp.wopi_discovery_url 获取 Discovery XML
3. 解析 XML提取扩展名和动作 URL
4. 全量替换 FileAppExtension 记录(带 wopi_action_url
认证:管理员权限
错误处理:
- 404: 应用不存在
- 400: 非 WOPI 类型 / discovery URL 未配置 / XML 解析失败
- 502: WOPI 服务端不可达或返回无效响应
"""
app = await FileApp.get_exist_one(session, app_id)
if app.type != FileAppType.WOPI:
http_exceptions.raise_bad_request("仅 WOPI 类型应用支持自动发现")
if not app.wopi_discovery_url:
http_exceptions.raise_bad_request("未配置 WOPI Discovery URL")
# commit 后对象会过期,先保存需要的值
discovery_url = app.wopi_discovery_url
app_key = app.app_key
# 获取 Discovery XML
try:
async with aiohttp.ClientSession() as client:
async with client.get(
discovery_url,
timeout=aiohttp.ClientTimeout(total=15),
) as resp:
if resp.status != 200:
http_exceptions.raise_bad_gateway(
f"WOPI 服务端返回 HTTP {resp.status}"
)
xml_content = await resp.text()
except aiohttp.ClientError as e:
http_exceptions.raise_bad_gateway(f"无法连接 WOPI 服务端: {e}")
# 解析 XML
try:
action_urls, app_names = parse_wopi_discovery_xml(xml_content)
except ValueError as e:
http_exceptions.raise_bad_request(str(e))
if not action_urls:
return WopiDiscoveryResponse(app_names=app_names)
# 全量替换扩展名
old_extensions: list[FileAppExtension] = await FileAppExtension.get(
session,
FileAppExtension.app_id == app_id,
fetch_mode="all",
)
for old_ext in old_extensions:
await FileAppExtension.delete(session, old_ext, commit=False)
await session.flush()
new_extensions: list[FileAppExtension] = []
discovered: list[WopiDiscoveredExtension] = []
for i, (ext, action_url) in enumerate(sorted(action_urls.items())):
ext_record = FileAppExtension(
app_id=app_id,
extension=ext,
priority=i,
wopi_action_url=action_url,
)
session.add(ext_record)
new_extensions.append(ext_record)
discovered.append(WopiDiscoveredExtension(extension=ext, action_url=action_url))
await session.commit()
l.info(
f"WOPI Discovery 完成: 应用 {app_key}, "
f"发现 {len(discovered)} 个扩展名"
)
return WopiDiscoveryResponse(
discovered_extensions=discovered,
app_names=app_names,
applied_count=len(discovered),
)

View File

@@ -63,10 +63,7 @@ async def router_admin_get_group(
:param group_id: 用户组UUID :param group_id: 用户组UUID
:return: 用户组详情 :return: 用户组详情
""" """
group = await Group.get(session, Group.id == group_id, load=[Group.options, Group.policies]) group = await Group.get_exist_one(session, group_id, load=[Group.options, Group.policies])
if not group:
raise HTTPException(status_code=404, detail="用户组不存在")
# 直接访问已加载的关系,无需额外查询 # 直接访问已加载的关系,无需额外查询
policies = group.policies policies = group.policies
@@ -94,9 +91,7 @@ async def router_admin_get_group_members(
:return: 分页成员列表 :return: 分页成员列表
""" """
# 验证组存在 # 验证组存在
group = await Group.get(session, Group.id == group_id) await Group.get_exist_one(session, group_id)
if not group:
raise HTTPException(status_code=404, detail="用户组不存在")
result = await User.get_with_count(session, User.group_id == group_id, table_view=table_view) result = await User.get_with_count(session, User.group_id == group_id, table_view=table_view)
@@ -138,10 +133,11 @@ async def router_admin_create_group(
speed_limit=request.speed_limit, speed_limit=request.speed_limit,
) )
group = await group.save(session) group = await group.save(session)
group_id_val: UUID = group.id
# 创建选项 # 创建选项
options = GroupOptions( options = GroupOptions(
group_id=group.id, group_id=group_id_val,
share_download=request.share_download, share_download=request.share_download,
share_free=request.share_free, share_free=request.share_free,
relocate=request.relocate, relocate=request.relocate,
@@ -154,11 +150,11 @@ async def router_admin_create_group(
aria2=request.aria2, aria2=request.aria2,
redirected_source=request.redirected_source, redirected_source=request.redirected_source,
) )
await options.save(session) options = await options.save(session)
# 关联存储策略 # 关联存储策略
for policy_id in request.policy_ids: for policy_id in request.policy_ids:
link = GroupPolicyLink(group_id=group.id, policy_id=policy_id) link = GroupPolicyLink(group_id=group_id_val, policy_id=policy_id)
session.add(link) session.add(link)
await session.commit() await session.commit()
@@ -185,9 +181,7 @@ async def router_admin_update_group(
:param request: 更新请求 :param request: 更新请求
:return: 更新结果 :return: 更新结果
""" """
group = await Group.get(session, Group.id == group_id, load=Group.options) group = await Group.get_exist_one(session, group_id, load=Group.options)
if not group:
raise HTTPException(status_code=404, detail="用户组不存在")
# 检查名称唯一性(如果要更新名称) # 检查名称唯一性(如果要更新名称)
if request.name and request.name != group.name: if request.name and request.name != group.name:
@@ -217,7 +211,7 @@ async def router_admin_update_group(
if options_data: if options_data:
for key, value in options_data.items(): for key, value in options_data.items():
setattr(group.options, key, value) setattr(group.options, key, value)
await group.options.save(session) group.options = await group.options.save(session)
# 更新策略关联 # 更新策略关联
if request.policy_ids is not None: if request.policy_ids is not None:
@@ -255,9 +249,7 @@ async def router_admin_delete_group(
:param group_id: 用户组UUID :param group_id: 用户组UUID
:return: 删除结果 :return: 删除结果
""" """
group = await Group.get(session, Group.id == group_id) group = await Group.get_exist_one(session, group_id)
if not group:
raise HTTPException(status_code=404, detail="用户组不存在")
# 检查是否有用户属于该组 # 检查是否有用户属于该组
user_count = await User.count(session, User.group_id == group_id) user_count = await User.count(session, User.group_id == group_id)

View File

@@ -8,11 +8,11 @@ from sqlmodel import Field
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep from middleware.dependencies import SessionDep, TableViewRequestDep
from sqlmodels import ( from sqlmodels import (
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase, Policy, PolicyCreateRequest, PolicyOptions, PolicyType, PolicySummary,
ListResponse, Object, PolicyUpdateRequest, ResponseBase, ListResponse, Object,
) )
from sqlmodel_ext import SQLModelBase from sqlmodel_ext import SQLModelBase
from service.storage import DirectoryCreationError, LocalStorageService from service.storage import DirectoryCreationError, LocalStorageService, S3StorageService
admin_policy_router = APIRouter( admin_policy_router = APIRouter(
prefix='/policy', prefix='/policy',
@@ -67,6 +67,12 @@ class PolicyDetailResponse(SQLModelBase):
base_url: str | None base_url: str | None
"""基础URL""" """基础URL"""
access_key: str | None
"""Access Key"""
secret_key: str | None
"""Secret Key"""
max_size: int max_size: int
"""最大文件尺寸""" """最大文件尺寸"""
@@ -107,9 +113,45 @@ class PolicyTestSlaveRequest(SQLModelBase):
secret: str secret: str
"""从机通信密钥""" """从机通信密钥"""
class PolicyCreateRequest(PolicyBase): class PolicyTestS3Request(SQLModelBase):
"""创建存储策略请求 DTO继承 PolicyBase 中的所有字段""" """测试 S3 连接请求 DTO"""
pass
server: str = Field(max_length=255)
"""S3 端点地址"""
bucket_name: str = Field(max_length=255)
"""存储桶名称"""
access_key: str
"""Access Key"""
secret_key: str
"""Secret Key"""
s3_region: str = Field(default='us-east-1', max_length=64)
"""S3 区域"""
s3_path_style: bool = False
"""是否使用路径风格"""
class PolicyTestS3Response(SQLModelBase):
"""S3 连接测试响应"""
is_connected: bool
"""连接是否成功"""
message: str
"""测试结果消息"""
# ==================== Options 字段集合(用于分离 Policy 与 Options 字段) ====================
_OPTIONS_FIELDS: set[str] = {
'token', 'file_type', 'mimetype', 'od_redirect',
'chunk_size', 's3_path_style', 's3_region',
}
@admin_policy_router.get( @admin_policy_router.get(
path='/list', path='/list',
@@ -277,7 +319,20 @@ async def router_policy_add_policy(
raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}") raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}")
# 保存到数据库 # 保存到数据库
await policy.save(session) policy = await policy.save(session)
# 创建策略选项
options = PolicyOptions(
policy_id=policy.id,
token=request.token,
file_type=request.file_type,
mimetype=request.mimetype,
od_redirect=request.od_redirect,
chunk_size=request.chunk_size,
s3_path_style=request.s3_path_style,
s3_region=request.s3_region,
)
options = await options.save(session)
@admin_policy_router.post( @admin_policy_router.post(
path='/cors', path='/cors',
@@ -328,9 +383,7 @@ async def router_policy_onddrive_oauth(
:param policy_id: 存储策略UUID :param policy_id: 存储策略UUID
:return: OAuth URL :return: OAuth URL
""" """
policy = await Policy.get(session, Policy.id == policy_id) policy = await Policy.get_exist_one(session, policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
# TODO: 实现OneDrive OAuth # TODO: 实现OneDrive OAuth
raise HTTPException(status_code=501, detail="OneDrive OAuth暂未实现") raise HTTPException(status_code=501, detail="OneDrive OAuth暂未实现")
@@ -353,9 +406,7 @@ async def router_policy_get_policy(
:param policy_id: 存储策略UUID :param policy_id: 存储策略UUID
:return: 策略详情 :return: 策略详情
""" """
policy = await Policy.get(session, Policy.id == policy_id, load=Policy.options) policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
# 获取使用此策略的用户组 # 获取使用此策略的用户组
groups = await policy.awaitable_attrs.groups groups = await policy.awaitable_attrs.groups
@@ -371,6 +422,8 @@ async def router_policy_get_policy(
bucket_name=policy.bucket_name, bucket_name=policy.bucket_name,
is_private=policy.is_private, is_private=policy.is_private,
base_url=policy.base_url, base_url=policy.base_url,
access_key=policy.access_key,
secret_key=policy.secret_key,
max_size=policy.max_size, max_size=policy.max_size,
auto_rename=policy.auto_rename, auto_rename=policy.auto_rename,
dir_name_rule=policy.dir_name_rule, dir_name_rule=policy.dir_name_rule,
@@ -402,9 +455,7 @@ async def router_policy_delete_policy(
:param policy_id: 存储策略UUID :param policy_id: 存储策略UUID
:return: 删除结果 :return: 删除结果
""" """
policy = await Policy.get(session, Policy.id == policy_id) policy = await Policy.get_exist_one(session, policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
# 检查是否有文件使用此策略 # 检查是否有文件使用此策略
file_count = await Object.count(session, Object.policy_id == policy_id) file_count = await Object.count(session, Object.policy_id == policy_id)
@@ -418,3 +469,105 @@ async def router_policy_delete_policy(
await Policy.delete(session, policy) await Policy.delete(session, policy)
l.info(f"管理员删除了存储策略: {policy_name}") l.info(f"管理员删除了存储策略: {policy_name}")
@admin_policy_router.patch(
path='/{policy_id}',
summary='更新存储策略',
description='更新存储策略配置。策略类型创建后不可更改。',
dependencies=[Depends(admin_required)],
status_code=204,
)
async def router_policy_update_policy(
session: SessionDep,
policy_id: UUID,
request: PolicyUpdateRequest,
) -> None:
"""
更新存储策略端点
功能:
- 更新策略基础字段和扩展选项
- 策略类型type不可更改
认证:
- 需要管理员权限
:param session: 数据库会话
:param policy_id: 存储策略UUID
:param request: 更新请求
"""
policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)
# 检查名称唯一性(如果要更新名称)
if request.name and request.name != policy.name:
existing = await Policy.get(session, Policy.name == request.name)
if existing:
raise HTTPException(status_code=409, detail="策略名称已存在")
# 分离 Policy 字段和 Options 字段
all_data = request.model_dump(exclude_unset=True)
policy_data = {k: v for k, v in all_data.items() if k not in _OPTIONS_FIELDS}
options_data = {k: v for k, v in all_data.items() if k in _OPTIONS_FIELDS}
# 更新 Policy 基础字段
if policy_data:
for key, value in policy_data.items():
setattr(policy, key, value)
policy = await policy.save(session)
# 更新或创建 PolicyOptions
if options_data:
if policy.options:
for key, value in options_data.items():
setattr(policy.options, key, value)
policy.options = await policy.options.save(session)
else:
options = PolicyOptions(policy_id=policy.id, **options_data)
options = await options.save(session)
l.info(f"管理员更新了存储策略: {policy_id}")
@admin_policy_router.post(
path='/test/s3',
summary='测试 S3 连接',
description='测试 S3 存储端点的连通性和凭据有效性。',
dependencies=[Depends(admin_required)],
)
async def router_policy_test_s3(
request: PolicyTestS3Request,
) -> PolicyTestS3Response:
"""
测试 S3 连接端点
通过向 S3 端点发送 HEAD Bucket 请求,验证凭据和网络连通性。
:param request: 测试请求
:return: 测试结果
"""
from service.storage import S3APIError
# 构造临时 Policy 对象用于创建 S3StorageService
temp_policy = Policy(
name="__test__",
type=PolicyType.S3,
server=request.server,
bucket_name=request.bucket_name,
access_key=request.access_key,
secret_key=request.secret_key,
)
s3_service = S3StorageService(
temp_policy,
region=request.s3_region,
is_path_style=request.s3_path_style,
)
try:
# 使用 file_exists 发送 HEAD 请求来验证连通性
await s3_service.file_exists("__connection_test__")
return PolicyTestS3Response(is_connected=True, message="连接成功")
except S3APIError as e:
return PolicyTestS3Response(is_connected=False, message=f"S3 API 错误: {e}")
except Exception as e:
return PolicyTestS3Response(is_connected=False, message=f"连接失败: {e}")

View File

@@ -155,9 +155,7 @@ async def router_admin_delete_share(
:param share_id: 分享ID :param share_id: 分享ID
:return: 删除结果 :return: 删除结果
""" """
share = await Share.get(session, Share.id == share_id) share = await Share.get_exist_one(session, share_id)
if not share:
raise HTTPException(status_code=404, detail="分享不存在")
await Share.delete(session, share) await Share.delete(session, share)

View File

@@ -8,7 +8,7 @@ from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep from middleware.dependencies import SessionDep, TableViewRequestDep
from sqlmodels import ( from sqlmodels import (
ListResponse, ListResponse,
Task, TaskSummary, Task, TaskSummary, TaskStatus, TaskType,
) )
from sqlmodel_ext import SQLModelBase from sqlmodel_ext import SQLModelBase
@@ -19,10 +19,10 @@ class TaskDetailResponse(SQLModelBase):
id: int id: int
"""任务ID""" """任务ID"""
status: int status: TaskStatus
"""任务状态""" """任务状态"""
type: int type: TaskType
"""任务类型""" """任务类型"""
progress: int progress: int
@@ -150,9 +150,7 @@ async def router_admin_delete_task(
:param task_id: 任务ID :param task_id: 任务ID
:return: 删除结果 :return: 删除结果
""" """
task = await Task.get(session, Task.id == task_id) task = await Task.get_exist_one(session, task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
await Task.delete(session, task) await Task.delete(session, task)

View File

@@ -71,7 +71,7 @@ async def router_admin_theme_create(
name=request.name, name=request.name,
**request.colors.model_dump(), **request.colors.model_dump(),
) )
await preset.save(session) preset = await preset.save(session)
l.info(f"管理员创建了主题预设: {request.name}") l.info(f"管理员创建了主题预设: {request.name}")
@@ -101,11 +101,7 @@ async def router_admin_theme_update(
- 404: 预设不存在 - 404: 预设不存在
- 409: 名称已被其他预设使用 - 409: 名称已被其他预设使用
""" """
preset: ThemePreset | None = await ThemePreset.get( preset = await ThemePreset.get_exist_one(session, preset_id)
session, ThemePreset.id == preset_id
)
if not preset:
http_exceptions.raise_not_found("主题预设不存在")
# 检查名称唯一性(排除自身) # 检查名称唯一性(排除自身)
if request.name is not None and request.name != preset.name: if request.name is not None and request.name != preset.name:
@@ -120,7 +116,7 @@ async def router_admin_theme_update(
for key, value in color_data.items(): for key, value in color_data.items():
setattr(preset, key, value) setattr(preset, key, value)
await preset.save(session) preset = await preset.save(session)
l.info(f"管理员更新了主题预设: {preset.name}") l.info(f"管理员更新了主题预设: {preset.name}")
@@ -147,11 +143,7 @@ async def router_admin_theme_delete(
副作用: 副作用:
- 关联用户的 theme_preset_id 会被数据库 SET NULL - 关联用户的 theme_preset_id 会被数据库 SET NULL
""" """
preset: ThemePreset | None = await ThemePreset.get( preset = await ThemePreset.get_exist_one(session, preset_id)
session, ThemePreset.id == preset_id
)
if not preset:
http_exceptions.raise_not_found("主题预设不存在")
await preset.delete(session) await preset.delete(session)
l.info(f"管理员删除了主题预设: {preset.name}") l.info(f"管理员删除了主题预设: {preset.name}")
@@ -180,11 +172,7 @@ async def router_admin_theme_set_default(
逻辑: 逻辑:
- 事务中先清除所有旧默认,再设新默认 - 事务中先清除所有旧默认,再设新默认
""" """
preset: ThemePreset | None = await ThemePreset.get( preset = await ThemePreset.get_exist_one(session, preset_id)
session, ThemePreset.id == preset_id
)
if not preset:
http_exceptions.raise_not_found("主题预设不存在")
# 清除所有旧默认 # 清除所有旧默认
await session.execute( await session.execute(
@@ -195,5 +183,5 @@ async def router_admin_theme_set_default(
# 设新默认 # 设新默认
preset.is_default = True preset.is_default = True
await preset.save(session) preset = await preset.save(session)
l.info(f"管理员将主题预设 '{preset.name}' 设为默认") l.info(f"管理员将主题预设 '{preset.name}' 设为默认")

View File

@@ -128,8 +128,9 @@ async def router_admin_create_user(
is_verified=True, is_verified=True,
user_id=user.id, user_id=user.id,
) )
await identity.save(session) identity = await identity.save(session)
user = await User.get(session, User.id == user.id, load=User.group)
return user.to_public() return user.to_public()
@@ -153,9 +154,7 @@ async def router_admin_update_user(
:param request: 更新请求 :param request: 更新请求
:return: 更新结果 :return: 更新结果
""" """
user = await User.get(session, User.id == user_id) user = await User.get_exist_one(session, user_id)
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
# 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别) # 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
default_admin_setting = await Setting.get( default_admin_setting = await Setting.get(
@@ -252,9 +251,7 @@ async def router_admin_calibrate_storage(
:param user_id: 用户UUID :param user_id: 用户UUID
:return: 校准结果 :return: 校准结果
""" """
user = await User.get(session, User.id == user_id) user = await User.get_exist_one(session, user_id)
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
previous_storage = user.storage previous_storage = user.storage

View File

@@ -1,81 +0,0 @@
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from middleware.auth import admin_required
from middleware.dependencies import SessionDep
from sqlmodels import (
ResponseBase,
)
admin_vas_router = APIRouter(
prefix='/vas',
tags=['admin', 'admin_vas']
)
@admin_vas_router.get(
path='/list',
summary='获取增值服务列表',
description='Get VAS list (orders and storage packs)',
dependencies=[Depends(admin_required)]
)
async def router_admin_get_vas_list(
session: SessionDep,
user_id: UUID | None = None,
page: int = 1,
page_size: int = 20,
) -> ResponseBase:
"""
获取增值服务列表(订单和存储包)。
:param session: 数据库会话
:param user_id: 按用户筛选
:param page: 页码
:param page_size: 每页数量
:return: 增值服务列表
"""
# TODO: 实现增值服务列表
# 需要查询 Order 和 StoragePack 模型
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
@admin_vas_router.get(
path='/{vas_id}',
summary='获取增值服务详情',
description='Get VAS detail by ID',
dependencies=[Depends(admin_required)]
)
async def router_admin_get_vas(
session: SessionDep,
vas_id: UUID,
) -> ResponseBase:
"""
获取增值服务详情。
:param session: 数据库会话
:param vas_id: 增值服务UUID
:return: 增值服务详情
"""
# TODO: 实现增值服务详情
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
@admin_vas_router.delete(
path='/{vas_id}',
summary='删除增值服务',
description='Delete VAS by ID',
dependencies=[Depends(admin_required)]
)
async def router_admin_delete_vas(
session: SessionDep,
vas_id: UUID,
) -> ResponseBase:
"""
删除增值服务。
:param session: 数据库会话
:param vas_id: 增值服务UUID
:return: 删除结果
"""
# TODO: 实现增值服务删除
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")

View File

@@ -1,5 +1,6 @@
from fastapi import APIRouter, Query from fastapi import APIRouter, Query
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from loguru import logger as l
from sqlmodels import ResponseBase from sqlmodels import ResponseBase
import service.oauth import service.oauth
@@ -15,18 +16,12 @@ oauth_router = APIRouter(
tags=["callback", "oauth"], tags=["callback", "oauth"],
) )
pay_router = APIRouter(
prefix='/callback/pay',
tags=["callback", "pay"],
)
upload_router = APIRouter( upload_router = APIRouter(
prefix='/callback/upload', prefix='/callback/upload',
tags=["callback", "upload"], tags=["callback", "upload"],
) )
callback_router.include_router(oauth_router) callback_router.include_router(oauth_router)
callback_router.include_router(pay_router)
callback_router.include_router(upload_router) callback_router.include_router(upload_router)
@oauth_router.post( @oauth_router.post(
@@ -64,91 +59,17 @@ async def router_callback_github(
""" """
try: try:
access_token = await service.oauth.github.get_access_token(code) access_token = await service.oauth.github.get_access_token(code)
# [TODO] 把access_token写数据库里
if not access_token: if not access_token:
return PlainTextResponse("Failed to retrieve access token from GitHub.", status_code=400) return PlainTextResponse("GitHub 认证失败", status_code=400)
user_data = await service.oauth.github.get_user_info(access_token.access_token) user_data = await service.oauth.github.get_user_info(access_token.access_token)
# [TODO] 把user_data写数据库 # [TODO] 把 access_token 和 user_data 写数据库,生成 JWT重定向到前端
l.info(f"GitHub OAuth 回调成功: user={user_data.user_data.login}")
return PlainTextResponse(f"User information processed successfully, code: {code}, user_data: {user_data.json_dump()}", status_code=200) return PlainTextResponse("认证成功,功能开发中", status_code=200)
except Exception as e: except Exception as e:
return PlainTextResponse(f"An error occurred: {str(e)}", status_code=500) l.error(f"GitHub OAuth 回调异常: {e}")
return PlainTextResponse("认证过程中发生错误,请重试", status_code=500)
@pay_router.post(
path='/alipay',
summary='支付宝支付回调',
description='Handle Alipay payment callback and return payment status.',
)
def router_callback_alipay() -> ResponseBase:
"""
Handle Alipay payment callback and return payment status.
Returns:
ResponseBase: A model containing the response data for the Alipay payment callback.
"""
http_exceptions.raise_not_implemented()
@pay_router.post(
path='/wechat',
summary='微信支付回调',
description='Handle WeChat Pay payment callback and return payment status.',
)
def router_callback_wechat() -> ResponseBase:
"""
Handle WeChat Pay payment callback and return payment status.
Returns:
ResponseBase: A model containing the response data for the WeChat Pay payment callback.
"""
http_exceptions.raise_not_implemented()
@pay_router.post(
path='/stripe',
summary='Stripe支付回调',
description='Handle Stripe payment callback and return payment status.',
)
def router_callback_stripe() -> ResponseBase:
"""
Handle Stripe payment callback and return payment status.
Returns:
ResponseBase: A model containing the response data for the Stripe payment callback.
"""
http_exceptions.raise_not_implemented()
@pay_router.get(
path='/easypay',
summary='易支付回调',
description='Handle EasyPay payment callback and return payment status.',
)
def router_callback_easypay() -> PlainTextResponse:
"""
Handle EasyPay payment callback and return payment status.
Returns:
PlainTextResponse: A response containing the payment status for the EasyPay payment callback.
"""
http_exceptions.raise_not_implemented()
# return PlainTextResponse("success", status_code=200)
@pay_router.get(
path='/custom/{order_no}/{id}',
summary='自定义支付回调',
description='Handle custom payment callback and return payment status.',
)
def router_callback_custom(order_no: str, id: str) -> ResponseBase:
"""
Handle custom payment callback and return payment status.
Args:
order_no (str): The order number for the payment.
id (str): The ID associated with the payment.
Returns:
ResponseBase: A model containing the response data for the custom payment callback.
"""
http_exceptions.raise_not_implemented()
@upload_router.post( @upload_router.post(
path='/remote/{session_id}/{key}', path='/remote/{session_id}/{key}',

View File

@@ -0,0 +1,100 @@
"""
文件分类筛选端点
按文件类型分类(图片/视频/音频/文档)查询用户的所有文件,
跨目录搜索,支持分页。扩展名映射从数据库 Setting 表读取。
"""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger as l
from middleware.auth import auth_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from sqlmodels import (
FileCategory,
ListResponse,
Object,
ObjectResponse,
ObjectType,
Setting,
SettingsType,
User,
)
category_router = APIRouter(
prefix="/category",
tags=["category"],
)
@category_router.get(
path="/{category}",
summary="按分类获取文件列表",
)
async def router_category_list(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
category: FileCategory,
table_view: TableViewRequestDep,
) -> ListResponse[ObjectResponse]:
"""
按文件类型分类查询用户的所有文件
跨所有目录搜索,返回分页结果。
扩展名配置从数据库 Setting 表读取type=file_category
认证:
- JWT token in Authorization header
路径参数:
- category: 文件分类image / video / audio / document
查询参数:
- offset: 分页偏移量默认0
- limit: 每页数量默认20最大100
- desc: 是否降序默认true
- order: 排序字段created_at / updated_at
响应:
- ListResponse[ObjectResponse]: 分页文件列表
错误处理:
- HTTPException 422: category 参数无效
- HTTPException 404: 该分类未配置扩展名
"""
# 从数据库读取该分类的扩展名配置
setting = await Setting.get(
session,
(Setting.type == SettingsType.FILE_CATEGORY) & (Setting.name == category.value),
)
if not setting or not setting.value:
raise HTTPException(status_code=404, detail=f"分类 {category.value} 未配置扩展名")
extensions = [ext.strip() for ext in setting.value.split(",") if ext.strip()]
if not extensions:
raise HTTPException(status_code=404, detail=f"分类 {category.value} 扩展名列表为空")
result = await Object.get_by_category(
session,
user.id,
extensions,
table_view=table_view,
)
items = [
ObjectResponse(
id=obj.id,
name=obj.name,
type=ObjectType.FILE,
size=obj.size,
mime_type=obj.mime_type,
thumb=False,
created_at=obj.created_at,
updated_at=obj.updated_at,
source_enabled=False,
)
for obj in result.items
]
return ListResponse(count=result.count, items=items)

View File

@@ -57,7 +57,7 @@ async def _get_directory_response(
policy_response = PolicyResponse( policy_response = PolicyResponse(
id=policy.id, id=policy.id,
name=policy.name, name=policy.name,
type=policy.type.value, type=policy.type,
max_size=policy.max_size, max_size=policy.max_size,
) )
@@ -189,6 +189,14 @@ async def router_directory_create(
raise HTTPException(status_code=409, detail="同名文件或目录已存在") raise HTTPException(status_code=409, detail="同名文件或目录已存在")
policy_id = request.policy_id if request.policy_id else parent.policy_id policy_id = request.policy_id if request.policy_id else parent.policy_id
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
if request.policy_id:
group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
if request.policy_id not in {p.id for p in group.policies}:
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
parent_id = parent.id # 在 save 前保存 parent_id = parent.id # 在 save 前保存
new_folder = Object( new_folder = Object(
@@ -198,4 +206,4 @@ async def router_directory_create(
parent_id=parent_id, parent_id=parent_id,
policy_id=policy_id, policy_id=policy_id,
) )
await new_folder.save(session) new_folder = await new_folder.save(session)

View File

@@ -13,9 +13,11 @@ from datetime import datetime, timedelta
from typing import Annotated from typing import Annotated
from uuid import UUID from uuid import UUID
import orjson
import whatthepatch import whatthepatch
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, RedirectResponse from fastapi.responses import FileResponse, RedirectResponse
from starlette.responses import Response
from loguru import logger as l from loguru import logger as l
from sqlmodel_ext import SQLModelBase from sqlmodel_ext import SQLModelBase
from whatthepatch.exceptions import HunkApplyException from whatthepatch.exceptions import HunkApplyException
@@ -44,7 +46,9 @@ from sqlmodels import (
User, User,
WopiSessionResponse, WopiSessionResponse,
) )
from service.storage import LocalStorageService, adjust_user_storage import orjson
from service.storage import LocalStorageService, S3StorageService, adjust_user_storage
from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL
from utils.JWT.wopi_token import create_wopi_token from utils.JWT.wopi_token import create_wopi_token
from utils import http_exceptions from utils import http_exceptions
@@ -180,9 +184,14 @@ async def create_upload_session(
# 确定存储策略 # 确定存储策略
policy_id = request.policy_id or parent.policy_id policy_id = request.policy_id or parent.policy_id
policy = await Policy.get(session, Policy.id == policy_id) policy = await Policy.get_exist_one(session, policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在") # 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
if request.policy_id:
group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
if request.policy_id not in {p.id for p in group.policies}:
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
# 验证文件大小限制 # 验证文件大小限制
_check_policy_size_limit(policy, request.file_size) _check_policy_size_limit(policy, request.file_size)
@@ -210,6 +219,7 @@ async def create_upload_session(
# 生成存储路径 # 生成存储路径
storage_path: str | None = None storage_path: str | None = None
s3_upload_id: str | None = None
if policy.type == PolicyType.LOCAL: if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy) storage_service = LocalStorageService(policy)
dir_path, storage_name, full_path = await storage_service.generate_file_path( dir_path, storage_name, full_path = await storage_service.generate_file_path(
@@ -217,8 +227,25 @@ async def create_upload_session(
original_filename=request.file_name, original_filename=request.file_name,
) )
storage_path = full_path storage_path = full_path
else: elif policy.type == PolicyType.S3:
raise HTTPException(status_code=501, detail="S3 存储暂未实现") s3_service = S3StorageService(
policy,
region=options.s3_region if options else 'us-east-1',
is_path_style=options.s3_path_style if options else False,
)
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
user_id=user.id,
original_filename=request.file_name,
)
# 多分片时创建 multipart upload
if total_chunks > 1:
s3_upload_id = await s3_service.create_multipart_upload(
storage_path, content_type='application/octet-stream',
)
# 预扣存储空间(与创建会话在同一事务中提交,防止并发绕过配额)
if request.file_size > 0:
await adjust_user_storage(session, user.id, request.file_size, commit=False)
# 创建上传会话 # 创建上传会话
upload_session = UploadSession( upload_session = UploadSession(
@@ -227,6 +254,7 @@ async def create_upload_session(
chunk_size=chunk_size, chunk_size=chunk_size,
total_chunks=total_chunks, total_chunks=total_chunks,
storage_path=storage_path, storage_path=storage_path,
s3_upload_id=s3_upload_id,
expires_at=datetime.now() + timedelta(hours=24), expires_at=datetime.now() + timedelta(hours=24),
owner_id=user.id, owner_id=user.id,
parent_id=request.parent_id, parent_id=request.parent_id,
@@ -302,8 +330,38 @@ async def upload_chunk(
content, content,
offset, offset,
) )
else: elif policy.type == PolicyType.S3:
raise HTTPException(status_code=501, detail="S3 存储暂未实现") if not upload_session.storage_path:
raise HTTPException(status_code=500, detail="存储路径丢失")
s3_service = await S3StorageService.from_policy(policy)
if upload_session.total_chunks == 1:
# 单分片:直接 PUT 上传
await s3_service.upload_file(upload_session.storage_path, content)
else:
# 多分片UploadPart
if not upload_session.s3_upload_id:
raise HTTPException(status_code=500, detail="S3 分片上传 ID 丢失")
etag = await s3_service.upload_part(
upload_session.storage_path,
upload_session.s3_upload_id,
chunk_index + 1, # S3 part number 从 1 开始
content,
)
# 追加 ETag 到 s3_part_etags
etags: list[list[int | str]] = orjson.loads(upload_session.s3_part_etags or '[]')
etags.append([chunk_index + 1, etag])
upload_session.s3_part_etags = orjson.dumps(etags).decode()
# 在 savecommit前缓存后续需要的属性commit 后 ORM 对象会过期)
policy_type = policy.type
s3_upload_id = upload_session.s3_upload_id
s3_part_etags = upload_session.s3_part_etags
s3_service_for_complete: S3StorageService | None = None
if policy_type == PolicyType.S3:
s3_service_for_complete = await S3StorageService.from_policy(policy)
# 更新会话进度 # 更新会话进度
upload_session.uploaded_chunks += 1 upload_session.uploaded_chunks += 1
@@ -319,12 +377,26 @@ async def upload_chunk(
if is_complete: if is_complete:
# 保存 upload_session 属性commit 后会过期) # 保存 upload_session 属性commit 后会过期)
file_name = upload_session.file_name file_name = upload_session.file_name
file_size = upload_session.file_size
uploaded_size = upload_session.uploaded_size uploaded_size = upload_session.uploaded_size
storage_path = upload_session.storage_path storage_path = upload_session.storage_path
upload_session_id = upload_session.id upload_session_id = upload_session.id
parent_id = upload_session.parent_id parent_id = upload_session.parent_id
policy_id = upload_session.policy_id policy_id = upload_session.policy_id
# S3 多分片上传完成:合并分片
if (
policy_type == PolicyType.S3
and s3_upload_id
and s3_part_etags
and s3_service_for_complete
):
parts_data: list[list[int | str]] = orjson.loads(s3_part_etags)
parts = [(int(pn), str(et)) for pn, et in parts_data]
await s3_service_for_complete.complete_multipart_upload(
storage_path, s3_upload_id, parts,
)
# 创建 PhysicalFile 记录 # 创建 PhysicalFile 记录
physical_file = PhysicalFile( physical_file = PhysicalFile(
storage_path=storage_path, storage_path=storage_path,
@@ -355,9 +427,10 @@ async def upload_chunk(
commit=False commit=False
) )
# 更新用户存储配额 # 调整存储配额差值(创建会话时已预扣 file_size这里只补差
if uploaded_size > 0: size_diff = uploaded_size - file_size
await adjust_user_storage(session, user_id, uploaded_size, commit=False) if size_diff != 0:
await adjust_user_storage(session, user_id, size_diff, commit=False)
# 统一提交所有更改 # 统一提交所有更改
await session.commit() await session.commit()
@@ -390,9 +463,25 @@ async def delete_upload_session(
# 删除临时文件 # 删除临时文件
policy = await Policy.get(session, Policy.id == upload_session.policy_id) policy = await Policy.get(session, Policy.id == upload_session.policy_id)
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path: if policy and upload_session.storage_path:
storage_service = LocalStorageService(policy) if policy.type == PolicyType.LOCAL:
await storage_service.delete_file(upload_session.storage_path) storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
# 如果有分片上传,先取消
if upload_session.s3_upload_id:
await s3_service.abort_multipart_upload(
upload_session.storage_path, upload_session.s3_upload_id,
)
else:
# 单分片上传已完成的话,删除已上传的文件
if upload_session.uploaded_chunks > 0:
await s3_service.delete_file(upload_session.storage_path)
# 释放预扣的存储空间
if upload_session.file_size > 0:
await adjust_user_storage(session, user.id, -upload_session.file_size)
# 删除会话记录 # 删除会话记录
await UploadSession.delete(session, upload_session) await UploadSession.delete(session, upload_session)
@@ -422,9 +511,22 @@ async def clear_upload_sessions(
for upload_session in sessions: for upload_session in sessions:
# 删除临时文件 # 删除临时文件
policy = await Policy.get(session, Policy.id == upload_session.policy_id) policy = await Policy.get(session, Policy.id == upload_session.policy_id)
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path: if policy and upload_session.storage_path:
storage_service = LocalStorageService(policy) if policy.type == PolicyType.LOCAL:
await storage_service.delete_file(upload_session.storage_path) storage_service = LocalStorageService(policy)
await storage_service.delete_file(upload_session.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
if upload_session.s3_upload_id:
await s3_service.abort_multipart_upload(
upload_session.storage_path, upload_session.s3_upload_id,
)
elif upload_session.uploaded_chunks > 0:
await s3_service.delete_file(upload_session.storage_path)
# 释放预扣的存储空间
if upload_session.file_size > 0:
await adjust_user_storage(session, user.id, -upload_session.file_size)
await UploadSession.delete(session, upload_session) await UploadSession.delete(session, upload_session)
deleted_count += 1 deleted_count += 1
@@ -486,11 +588,12 @@ async def create_download_token_endpoint(
path='/{token}', path='/{token}',
summary='下载文件', summary='下载文件',
description='使用下载令牌下载文件,令牌在有效期内可重复使用。', description='使用下载令牌下载文件,令牌在有效期内可重复使用。',
response_model=None,
) )
async def download_file( async def download_file(
session: SessionDep, session: SessionDep,
token: str, token: str,
) -> FileResponse: ) -> Response:
""" """
下载文件端点 下载文件端点
@@ -540,8 +643,15 @@ async def download_file(
filename=file_obj.name, filename=file_obj.name,
media_type="application/octet-stream", media_type="application/octet-stream",
) )
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
# 302 重定向到预签名 URL
presigned_url = s3_service.generate_presigned_url(
storage_path, method='GET', expires_in=3600, filename=file_obj.name,
)
return RedirectResponse(url=presigned_url, status_code=302)
else: else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现") raise HTTPException(status_code=500, detail="不支持的存储类型")
# ==================== 包含子路由 ==================== # ==================== 包含子路由 ====================
@@ -599,9 +709,7 @@ async def create_empty_file(
# 确定存储策略 # 确定存储策略
policy_id = request.policy_id or parent.policy_id policy_id = request.policy_id or parent.policy_id
policy = await Policy.get(session, Policy.id == policy_id) policy = await Policy.get_exist_one(session, policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
# 生成存储路径并创建空文件 # 生成存储路径并创建空文件
storage_path: str | None = None storage_path: str | None = None
@@ -613,8 +721,13 @@ async def create_empty_file(
) )
await storage_service.create_empty_file(full_path) await storage_service.create_empty_file(full_path)
storage_path = full_path storage_path = full_path
else: elif policy.type == PolicyType.S3:
raise HTTPException(status_code=501, detail="S3 存储暂未实现") s3_service = await S3StorageService.from_policy(policy)
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
user_id=user_id,
original_filename=request.name,
)
await s3_service.upload_file(storage_path, b'')
# 创建 PhysicalFile 记录 # 创建 PhysicalFile 记录
physical_file = PhysicalFile( physical_file = PhysicalFile(
@@ -695,6 +808,7 @@ async def create_wopi_session(
) )
wopi_app: FileApp | None = None wopi_app: FileApp | None = None
matched_ext_record: FileAppExtension | None = None
for ext_record in ext_records: for ext_record in ext_records:
app = ext_record.app app = ext_record.app
if app.type == FileAppType.WOPI and app.is_enabled: if app.type == FileAppType.WOPI and app.is_enabled:
@@ -710,13 +824,20 @@ async def create_wopi_session(
if not result.first(): if not result.first():
continue continue
wopi_app = app wopi_app = app
matched_ext_record = ext_record
break break
if not wopi_app: if not wopi_app:
http_exceptions.raise_not_found("无可用的 WOPI 查看器") http_exceptions.raise_not_found("无可用的 WOPI 查看器")
if not wopi_app.wopi_editor_url_template: # 优先使用 per-extension URLDiscovery 自动填充),回退到全局模板
http_exceptions.raise_bad_request("WOPI 应用未配置编辑器 URL 模板") editor_url_template: str | None = None
if matched_ext_record and matched_ext_record.wopi_action_url:
editor_url_template = matched_ext_record.wopi_action_url
if not editor_url_template:
editor_url_template = wopi_app.wopi_editor_url_template
if not editor_url_template:
http_exceptions.raise_bad_request("WOPI 应用未配置编辑器 URL 模板,请先执行 Discovery 或手动配置")
# 获取站点 URL # 获取站点 URL
site_url_setting: Setting | None = await Setting.get( site_url_setting: Setting | None = await Setting.get(
@@ -732,12 +853,8 @@ async def create_wopi_session(
# 构建 wopi_src # 构建 wopi_src
wopi_src = f"{site_url}/wopi/files/{file_id}" wopi_src = f"{site_url}/wopi/files/{file_id}"
# 构建 editor URL # 构建 editor URL(只替换 wopi_srctoken 通过 POST 表单传递)
editor_url = wopi_app.wopi_editor_url_template.format( editor_url = editor_url_template.format(wopi_src=wopi_src)
wopi_src=wopi_src,
access_token=token,
access_token_ttl=access_token_ttl,
)
return WopiSessionResponse( return WopiSessionResponse(
wopi_src=wopi_src, wopi_src=wopi_src,
@@ -798,12 +915,13 @@ async def _validate_source_link(
path='/get/{file_id}/{name}', path='/get/{file_id}/{name}',
summary='文件外链(直接输出文件数据)', summary='文件外链(直接输出文件数据)',
description='通过外链直接获取文件内容,公开访问无需认证。', description='通过外链直接获取文件内容,公开访问无需认证。',
response_model=None,
) )
async def file_get( async def file_get(
session: SessionDep, session: SessionDep,
file_id: UUID, file_id: UUID,
name: str, name: str,
) -> FileResponse: ) -> Response:
""" """
文件外链端点(直接输出) 文件外链端点(直接输出)
@@ -815,25 +933,32 @@ async def file_get(
""" """
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id) file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
if policy.type != PolicyType.LOCAL:
http_exceptions.raise_not_implemented("S3 存储暂未实现")
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(physical_file.storage_path):
http_exceptions.raise_not_found("物理文件不存在")
# 缓存物理路径save 后对象属性会过期) # 缓存物理路径save 后对象属性会过期)
file_path = physical_file.storage_path file_path = physical_file.storage_path
# 递增下载次数 # 递增下载次数
link.downloads += 1 link.downloads += 1
await link.save(session) link = await link.save(session)
return FileResponse( if policy.type == PolicyType.LOCAL:
path=file_path, storage_service = LocalStorageService(policy)
filename=name, if not await storage_service.file_exists(file_path):
media_type="application/octet-stream", http_exceptions.raise_not_found("物理文件不存在")
)
return FileResponse(
path=file_path,
filename=name,
media_type="application/octet-stream",
)
elif policy.type == PolicyType.S3:
# S3 外链直接输出302 重定向到预签名 URL
s3_service = await S3StorageService.from_policy(policy)
presigned_url = s3_service.generate_presigned_url(
file_path, method='GET', expires_in=3600, filename=name,
)
return RedirectResponse(url=presigned_url, status_code=302)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
@router.get( @router.get(
@@ -846,7 +971,7 @@ async def file_source_redirect(
session: SessionDep, session: SessionDep,
file_id: UUID, file_id: UUID,
name: str, name: str,
) -> FileResponse | RedirectResponse: ) -> Response:
""" """
文件外链端点(重定向/直接输出) 文件外链端点(重定向/直接输出)
@@ -860,13 +985,6 @@ async def file_source_redirect(
""" """
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id) file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
if policy.type != PolicyType.LOCAL:
http_exceptions.raise_not_implemented("S3 存储暂未实现")
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(physical_file.storage_path):
http_exceptions.raise_not_found("物理文件不存在")
# 缓存所有需要的值save 后对象属性会过期) # 缓存所有需要的值save 后对象属性会过期)
file_path = physical_file.storage_path file_path = physical_file.storage_path
is_private = policy.is_private is_private = policy.is_private
@@ -874,20 +992,38 @@ async def file_source_redirect(
# 递增下载次数 # 递增下载次数
link.downloads += 1 link.downloads += 1
await link.save(session) link = await link.save(session)
# 公有存储302 重定向到 base_url if policy.type == PolicyType.LOCAL:
if not is_private and base_url: storage_service = LocalStorageService(policy)
relative_path = storage_service.get_relative_path(file_path) if not await storage_service.file_exists(file_path):
redirect_url = f"{base_url}/{relative_path}" http_exceptions.raise_not_found("物理文件不存在")
return RedirectResponse(url=redirect_url, status_code=302)
# 有存储或 base_url 为空:通过应用代理文件 # 有存储302 重定向到 base_url
return FileResponse( if not is_private and base_url:
path=file_path, relative_path = storage_service.get_relative_path(file_path)
filename=name, redirect_url = f"{base_url}/{relative_path}"
media_type="application/octet-stream", return RedirectResponse(url=redirect_url, status_code=302)
)
# 私有存储或 base_url 为空:通过应用代理文件
return FileResponse(
path=file_path,
filename=name,
media_type="application/octet-stream",
)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
# 公有存储且有 base_url直接重定向到公开 URL
if not is_private and base_url:
redirect_url = f"{base_url.rstrip('/')}/{file_path}"
return RedirectResponse(url=redirect_url, status_code=302)
# 私有存储:生成预签名 URL 重定向
presigned_url = s3_service.generate_presigned_url(
file_path, method='GET', expires_in=3600, filename=name,
)
return RedirectResponse(url=presigned_url, status_code=302)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
@router.put( @router.put(
@@ -941,11 +1077,15 @@ async def file_content(
if not policy: if not policy:
http_exceptions.raise_internal_error("存储策略不存在") http_exceptions.raise_internal_error("存储策略不存在")
if policy.type != PolicyType.LOCAL: # 读取文件内容
http_exceptions.raise_not_implemented("S3 存储暂未实现") if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
storage_service = LocalStorageService(policy) raw_bytes = await storage_service.read_file(physical_file.storage_path)
raw_bytes = await storage_service.read_file(physical_file.storage_path) elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
raw_bytes = await s3_service.download_file(physical_file.storage_path)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
try: try:
content = raw_bytes.decode('utf-8') content = raw_bytes.decode('utf-8')
@@ -1011,11 +1151,15 @@ async def patch_file_content(
if not policy: if not policy:
http_exceptions.raise_internal_error("存储策略不存在") http_exceptions.raise_internal_error("存储策略不存在")
if policy.type != PolicyType.LOCAL: # 读取文件内容
http_exceptions.raise_not_implemented("S3 存储暂未实现") if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
storage_service = LocalStorageService(policy) raw_bytes = await storage_service.read_file(storage_path)
raw_bytes = await storage_service.read_file(storage_path) elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
raw_bytes = await s3_service.download_file(storage_path)
else:
http_exceptions.raise_internal_error("不支持的存储类型")
# 解码 + 规范化 # 解码 + 规范化
original_text = raw_bytes.decode('utf-8') original_text = raw_bytes.decode('utf-8')
@@ -1049,7 +1193,10 @@ async def patch_file_content(
_check_policy_size_limit(policy, len(new_bytes)) _check_policy_size_limit(policy, len(new_bytes))
# 写入文件 # 写入文件
await storage_service.write_file(storage_path, new_bytes) if policy.type == PolicyType.LOCAL:
await storage_service.write_file(storage_path, new_bytes)
elif policy.type == PolicyType.S3:
await s3_service.upload_file(storage_path, new_bytes)
# 更新数据库 # 更新数据库
owner_id = file_obj.owner_id owner_id = file_obj.owner_id

View File

@@ -8,13 +8,14 @@
from typing import Annotated from typing import Annotated
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from loguru import logger as l from loguru import logger as l
from middleware.auth import auth_required from middleware.auth import auth_required
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from sqlmodels import ( from sqlmodels import (
CreateFileRequest, CreateFileRequest,
Group,
Object, Object,
ObjectCopyRequest, ObjectCopyRequest,
ObjectDeleteRequest, ObjectDeleteRequest,
@@ -22,24 +23,42 @@ from sqlmodels import (
ObjectPropertyDetailResponse, ObjectPropertyDetailResponse,
ObjectPropertyResponse, ObjectPropertyResponse,
ObjectRenameRequest, ObjectRenameRequest,
ObjectSwitchPolicyRequest,
ObjectType, ObjectType,
PhysicalFile, PhysicalFile,
Policy, Policy,
PolicyType, PolicyType,
Task,
TaskProps,
TaskStatus,
TaskSummaryBase,
TaskType,
User, User,
# 元数据相关
ObjectMetadata,
MetadataResponse,
MetadataPatchRequest,
INTERNAL_NAMESPACES,
USER_WRITABLE_NAMESPACES,
) )
from service.storage import ( from service.storage import (
LocalStorageService, LocalStorageService,
adjust_user_storage, adjust_user_storage,
copy_object_recursive, copy_object_recursive,
migrate_file_with_task,
migrate_directory_files,
) )
from service.storage.object import soft_delete_objects from service.storage.object import soft_delete_objects
from sqlmodels.database_connection import DatabaseManager
from utils import http_exceptions from utils import http_exceptions
from .custom_property import router as custom_property_router
object_router = APIRouter( object_router = APIRouter(
prefix="/object", prefix="/object",
tags=["object"] tags=["object"]
) )
object_router.include_router(custom_property_router)
@object_router.post( @object_router.post(
path='/', path='/',
@@ -93,9 +112,7 @@ async def router_object_create(
# 确定存储策略 # 确定存储策略
policy_id = request.policy_id or parent.policy_id policy_id = request.policy_id or parent.policy_id
policy = await Policy.get(session, Policy.id == policy_id) policy = await Policy.get_exist_one(session, policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
parent_id = parent.id parent_id = parent.id
@@ -130,7 +147,7 @@ async def router_object_create(
owner_id=user_id, owner_id=user_id,
policy_id=policy_id, policy_id=policy_id,
) )
await file_object.save(session) file_object = await file_object.save(session)
l.info(f"创建空白文件: {request.name}") l.info(f"创建空白文件: {request.name}")
@@ -455,7 +472,7 @@ async def router_object_rename(
# 更新名称 # 更新名称
obj.name = new_name obj.name = new_name
await obj.save(session) obj = await obj.save(session)
l.info(f"用户 {user_id} 将对象 {obj.id} 重命名为 {new_name}") l.info(f"用户 {user_id} 将对象 {obj.id} 重命名为 {new_name}")
@@ -493,6 +510,7 @@ async def router_object_property(
name=obj.name, name=obj.name,
type=obj.type, type=obj.type,
size=obj.size, size=obj.size,
mime_type=obj.mime_type,
created_at=obj.created_at, created_at=obj.created_at,
updated_at=obj.updated_at, updated_at=obj.updated_at,
parent_id=obj.parent_id, parent_id=obj.parent_id,
@@ -520,7 +538,7 @@ async def router_object_property_detail(
obj = await Object.get( obj = await Object.get(
session, session,
(Object.id == id) & (Object.deleted_at == None), (Object.id == id) & (Object.deleted_at == None),
load=Object.file_metadata, load=Object.metadata_entries,
) )
if not obj: if not obj:
raise HTTPException(status_code=404, detail="对象不存在") raise HTTPException(status_code=404, detail="对象不存在")
@@ -543,35 +561,301 @@ async def router_object_property_detail(
total_views = sum(s.views for s in shares) total_views = sum(s.views for s in shares)
total_downloads = sum(s.downloads for s in shares) total_downloads = sum(s.downloads for s in shares)
# 获取物理文件引用计数 # 获取物理文件信息(引用计数、校验和)
reference_count = 1 reference_count = 1
checksum_md5: str | None = None
checksum_sha256: str | None = None
if obj.physical_file_id: if obj.physical_file_id:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id) physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
if physical_file: if physical_file:
reference_count = physical_file.reference_count reference_count = physical_file.reference_count
checksum_md5 = physical_file.checksum_md5
checksum_sha256 = physical_file.checksum_sha256
# 构建响应 # 构建元数据字典(排除内部命名空间)
response = ObjectPropertyDetailResponse( metadata: dict[str, str] = {}
for entry in obj.metadata_entries:
ns = entry.name.split(":")[0] if ":" in entry.name else ""
if ns not in INTERNAL_NAMESPACES:
metadata[entry.name] = entry.value
return ObjectPropertyDetailResponse(
id=obj.id, id=obj.id,
name=obj.name, name=obj.name,
type=obj.type, type=obj.type,
size=obj.size, size=obj.size,
mime_type=obj.mime_type,
created_at=obj.created_at, created_at=obj.created_at,
updated_at=obj.updated_at, updated_at=obj.updated_at,
parent_id=obj.parent_id, parent_id=obj.parent_id,
checksum_md5=checksum_md5,
checksum_sha256=checksum_sha256,
policy_name=policy_name, policy_name=policy_name,
share_count=share_count, share_count=share_count,
total_views=total_views, total_views=total_views,
total_downloads=total_downloads, total_downloads=total_downloads,
reference_count=reference_count, reference_count=reference_count,
metadatas=metadata,
) )
# 添加文件元数据
if obj.file_metadata:
response.mime_type = obj.file_metadata.mime_type
response.width = obj.file_metadata.width
response.height = obj.file_metadata.height
response.duration = obj.file_metadata.duration
response.checksum_md5 = obj.file_metadata.checksum_md5
return response @object_router.patch(
path='/{object_id}/policy',
summary='切换对象存储策略',
)
async def router_object_switch_policy(
session: SessionDep,
background_tasks: BackgroundTasks,
user: Annotated[User, Depends(auth_required)],
object_id: UUID,
request: ObjectSwitchPolicyRequest,
) -> TaskSummaryBase:
"""
切换对象的存储策略
文件:立即创建后台迁移任务,将文件从源策略搬到目标策略。
目录:更新目录 policy_id新文件使用新策略
若 is_migrate_existing=True额外创建后台任务迁移所有已有文件。
认证JWT Bearer Token
错误处理:
- 404: 对象不存在
- 403: 无权操作此对象 / 用户组无权使用目标策略
- 400: 目标策略与当前相同 / 不能对根目录操作
"""
user_id = user.id
# 查找对象
obj = await Object.get(
session,
(Object.id == object_id) & (Object.deleted_at == None)
)
if not obj:
http_exceptions.raise_not_found("对象不存在")
if obj.owner_id != user_id:
http_exceptions.raise_forbidden("无权操作此对象")
if obj.is_banned:
http_exceptions.raise_banned()
# 根目录不能直接切换策略(应通过子对象或子目录操作)
if obj.parent_id is None:
raise HTTPException(status_code=400, detail="不能对根目录切换存储策略,请对子目录操作")
# 校验目标策略存在
dest_policy = await Policy.get(session, Policy.id == request.policy_id)
if not dest_policy:
http_exceptions.raise_not_found("目标存储策略不存在")
# 校验用户组权限
group: Group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
allowed_ids = {p.id for p in group.policies}
if request.policy_id not in allowed_ids:
http_exceptions.raise_forbidden("当前用户组无权使用该存储策略")
# 不能切换到相同策略
if obj.policy_id == request.policy_id:
raise HTTPException(status_code=400, detail="目标策略与当前策略相同")
# 保存必要的属性,避免 save 后对象过期
src_policy_id = obj.policy_id
obj_id = obj.id
obj_is_file = obj.type == ObjectType.FILE
dest_policy_id = request.policy_id
dest_policy_name = dest_policy.name
# 创建任务记录
task = Task(
type=TaskType.POLICY_MIGRATE,
status=TaskStatus.QUEUED,
user_id=user_id,
)
task = await task.save(session)
task_id = task.id
task_props = TaskProps(
task_id=task_id,
source_policy_id=src_policy_id,
dest_policy_id=dest_policy_id,
object_id=obj_id,
)
task_props = await task_props.save(session)
if obj_is_file:
# 文件:后台迁移
async def _run_file_migration() -> None:
async with DatabaseManager.session() as bg_session:
bg_obj = await Object.get(bg_session, Object.id == obj_id)
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
bg_task = await Task.get(bg_session, Task.id == task_id)
await migrate_file_with_task(bg_session, bg_obj, bg_policy, bg_task)
background_tasks.add_task(_run_file_migration)
else:
# 目录:先更新目录自身的 policy_id
obj = await Object.get(session, Object.id == obj_id)
obj.policy_id = dest_policy_id
obj = await obj.save(session)
if request.is_migrate_existing:
# 后台迁移所有已有文件
async def _run_dir_migration() -> None:
async with DatabaseManager.session() as bg_session:
bg_folder = await Object.get(bg_session, Object.id == obj_id)
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
bg_task = await Task.get(bg_session, Task.id == task_id)
await migrate_directory_files(bg_session, bg_folder, bg_policy, bg_task)
background_tasks.add_task(_run_dir_migration)
else:
# 不迁移已有文件,直接完成任务
task = await Task.get(session, Task.id == task_id)
task.status = TaskStatus.COMPLETED
task.progress = 100
task = await task.save(session)
# 重新获取 task 以读取最新状态
task = await Task.get(session, Task.id == task_id)
l.info(f"用户 {user_id} 请求切换对象 {obj_id} 存储策略 → {dest_policy_name}")
return TaskSummaryBase(
id=task.id,
type=task.type,
status=task.status,
progress=task.progress,
error=task.error,
user_id=task.user_id,
created_at=task.created_at,
updated_at=task.updated_at,
)
# ==================== 元数据端点 ====================
@object_router.get(
path='/{object_id}/metadata',
summary='获取对象元数据',
description='获取对象的元数据键值对,可按命名空间过滤。',
)
async def router_get_object_metadata(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
object_id: UUID,
ns: str | None = None,
) -> MetadataResponse:
"""
获取对象元数据端点
认证JWT token 必填
查询参数:
- ns: 逗号分隔的命名空间列表(如 exif,stream不传返回所有非内部命名空间
错误处理:
- 404: 对象不存在
- 403: 无权查看此对象
"""
obj = await Object.get(
session,
(Object.id == object_id) & (Object.deleted_at == None),
load=Object.metadata_entries,
)
if not obj:
raise HTTPException(status_code=404, detail="对象不存在")
if obj.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权查看此对象")
# 解析命名空间过滤
ns_filter: set[str] | None = None
if ns:
ns_filter = {n.strip() for n in ns.split(",") if n.strip()}
# 不允许查看内部命名空间
ns_filter -= INTERNAL_NAMESPACES
# 构建元数据字典
metadata: dict[str, str] = {}
for entry in obj.metadata_entries:
entry_ns = entry.name.split(":")[0] if ":" in entry.name else ""
if entry_ns in INTERNAL_NAMESPACES:
continue
if ns_filter is not None and entry_ns not in ns_filter:
continue
metadata[entry.name] = entry.value
return MetadataResponse(metadatas=metadata)
@object_router.patch(
path='/{object_id}/metadata',
summary='批量更新对象元数据',
description='批量设置或删除对象的元数据条目。仅允许修改 custom: 命名空间。',
status_code=204,
)
async def router_patch_object_metadata(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
object_id: UUID,
request: MetadataPatchRequest,
) -> None:
"""
批量更新对象元数据端点
请求体中值为 None 的键将被删除,其余键将被设置/更新。
用户只能修改 custom: 命名空间的条目。
认证JWT token 必填
错误处理:
- 400: 尝试修改非 custom: 命名空间的条目
- 404: 对象不存在
- 403: 无权操作此对象
"""
obj = await Object.get(
session,
(Object.id == object_id) & (Object.deleted_at == None),
)
if not obj:
raise HTTPException(status_code=404, detail="对象不存在")
if obj.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权操作此对象")
for patch in request.patches:
# 验证命名空间
patch_ns = patch.key.split(":")[0] if ":" in patch.key else ""
if patch_ns not in USER_WRITABLE_NAMESPACES:
raise HTTPException(
status_code=400,
detail=f"不允许修改命名空间 '{patch_ns}' 的元数据,仅允许 custom: 命名空间",
)
if patch.value is None:
# 删除元数据条目
existing = await ObjectMetadata.get(
session,
(ObjectMetadata.object_id == object_id) & (ObjectMetadata.name == patch.key),
)
if existing:
await ObjectMetadata.delete(session, instances=existing)
else:
# 设置/更新元数据条目
existing = await ObjectMetadata.get(
session,
(ObjectMetadata.object_id == object_id) & (ObjectMetadata.name == patch.key),
)
if existing:
existing.value = patch.value
existing = await existing.save(session)
else:
entry = ObjectMetadata(
object_id=object_id,
name=patch.key,
value=patch.value,
is_public=True,
)
entry = await entry.save(session)
l.info(f"用户 {user.id} 更新了对象 {object_id}{len(request.patches)} 条元数据")

View File

@@ -0,0 +1,168 @@
"""
用户自定义属性定义路由
提供自定义属性模板的增删改查功能。
用户可以定义类型化的属性模板(如标签、评分、分类等),
然后通过元数据 PATCH 端点为对象设置属性值。
路由前缀:/custom_property
"""
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger as l
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from sqlmodels import (
CustomPropertyDefinition,
CustomPropertyCreateRequest,
CustomPropertyUpdateRequest,
CustomPropertyResponse,
User,
)
router = APIRouter(
prefix="/custom_property",
tags=["custom_property"],
)
@router.get(
path='',
summary='获取自定义属性定义列表',
description='获取当前用户的所有自定义属性定义,按 sort_order 排序。',
)
async def router_list_custom_properties(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
) -> list[CustomPropertyResponse]:
"""
获取自定义属性定义列表端点
认证JWT token 必填
返回当前用户定义的所有自定义属性模板。
"""
definitions = await CustomPropertyDefinition.get(
session,
CustomPropertyDefinition.owner_id == user.id,
fetch_mode="all",
)
return [
CustomPropertyResponse(
id=d.id,
name=d.name,
type=d.type,
icon=d.icon,
options=d.options,
default_value=d.default_value,
sort_order=d.sort_order,
)
for d in sorted(definitions, key=lambda x: x.sort_order)
]
@router.post(
path='',
summary='创建自定义属性定义',
description='创建一个新的自定义属性模板。',
status_code=204,
)
async def router_create_custom_property(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
request: CustomPropertyCreateRequest,
) -> None:
"""
创建自定义属性定义端点
认证JWT token 必填
错误处理:
- 400: 请求数据无效
- 409: 同名属性已存在
"""
# 检查同名属性
existing = await CustomPropertyDefinition.get(
session,
(CustomPropertyDefinition.owner_id == user.id) &
(CustomPropertyDefinition.name == request.name),
)
if existing:
raise HTTPException(status_code=409, detail="同名自定义属性已存在")
definition = CustomPropertyDefinition(
owner_id=user.id,
name=request.name,
type=request.type,
icon=request.icon,
options=request.options,
default_value=request.default_value,
)
definition = await definition.save(session)
l.info(f"用户 {user.id} 创建了自定义属性: {request.name}")
@router.patch(
path='/{id}',
summary='更新自定义属性定义',
description='更新自定义属性模板的名称、图标、选项等。',
status_code=204,
)
async def router_update_custom_property(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
id: UUID,
request: CustomPropertyUpdateRequest,
) -> None:
"""
更新自定义属性定义端点
认证JWT token 必填
错误处理:
- 404: 属性定义不存在
- 403: 无权操作此属性
"""
definition = await CustomPropertyDefinition.get_exist_one(session, id)
if definition.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权操作此属性")
definition = await definition.update(session, request)
l.info(f"用户 {user.id} 更新了自定义属性: {id}")
@router.delete(
path='/{id}',
summary='删除自定义属性定义',
description='删除自定义属性模板。注意:不会自动清理已使用该属性的元数据条目。',
status_code=204,
)
async def router_delete_custom_property(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
id: UUID,
) -> None:
"""
删除自定义属性定义端点
认证JWT token 必填
错误处理:
- 404: 属性定义不存在
- 403: 无权操作此属性
"""
definition = await CustomPropertyDefinition.get_exist_one(session, id)
if definition.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权操作此属性")
await CustomPropertyDefinition.delete(session, instances=definition)
l.info(f"用户 {user.id} 删除了自定义属性: {id}")

View File

@@ -45,12 +45,7 @@ async def router_share_get(
4. 返回分享详情(含文件树和分享者信息) 4. 返回分享详情(含文件树和分享者信息)
""" """
# 1. 查询分享(预加载 user 和 object # 1. 查询分享(预加载 user 和 object
share = await Share.get( share = await Share.get_exist_one(session, id, load=[Share.user, Share.object])
session, Share.id == id,
load=[Share.user, Share.object],
)
if not share:
http_exceptions.raise_not_found(detail="分享不存在或已被取消")
# 2. 检查过期 # 2. 检查过期
now = datetime.now() now = datetime.now()
@@ -474,16 +469,29 @@ def router_share_update(id: str) -> ResponseBase:
path='/{id}', path='/{id}',
summary='删除分享', summary='删除分享',
description='Delete a share by ID.', description='Delete a share by ID.',
dependencies=[Depends(auth_required)] status_code=204,
) )
def router_share_delete(id: str) -> ResponseBase: async def router_share_delete(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
id: UUID,
) -> None:
""" """
Delete a share by ID. 删除分享
Args: 认证:需要 JWT token
id (str): The ID of the share to be deleted.
Returns: 流程:
ResponseBase: A model containing the response data for the deleted share. 1. 通过分享ID查找分享
2. 验证分享属于当前用户
3. 删除分享记录
""" """
http_exceptions.raise_not_implemented() share = await Share.get_exist_one(session, id)
if share.user_id != user.id:
http_exceptions.raise_forbidden(detail="无权删除此分享")
user_id = user.id
share_code = share.code
await Share.delete(session, share)
l.info(f"用户 {user_id} 删除了分享: {share_code}")

View File

@@ -82,7 +82,8 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
(Setting.type == SettingsType.REGISTER) | (Setting.type == SettingsType.REGISTER) |
(Setting.type == SettingsType.CAPTCHA) | (Setting.type == SettingsType.CAPTCHA) |
(Setting.type == SettingsType.AUTH) | (Setting.type == SettingsType.AUTH) |
(Setting.type == SettingsType.OAUTH), (Setting.type == SettingsType.OAUTH) |
(Setting.type == SettingsType.AVATAR),
fetch_mode="all", fetch_mode="all",
) )
@@ -122,6 +123,7 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
password_required=s.get("auth_password_required") == "1", password_required=s.get("auth_password_required") == "1",
phone_binding_required=s.get("auth_phone_binding_required") == "1", phone_binding_required=s.get("auth_phone_binding_required") == "1",
email_binding_required=s.get("auth_email_binding_required") == "1", email_binding_required=s.get("auth_email_binding_required") == "1",
avatar_max_size=int(s["avatar_size"]),
footer_code=s.get("footer_code"), footer_code=s.get("footer_code"),
tos_url=s.get("tos_url"), tos_url=s.get("tos_url"),
privacy_url=s.get("privacy_url"), privacy_url=s.get("privacy_url"),

View File

@@ -5,6 +5,7 @@ import json
import jwt import jwt
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse, RedirectResponse
from itsdangerous import URLSafeTimedSerializer from itsdangerous import URLSafeTimedSerializer
from loguru import logger from loguru import logger
from webauthn import ( from webauthn import (
@@ -233,7 +234,7 @@ async def router_user_register(
group_id=default_group.id, group_id=default_group.id,
) )
new_user_id = new_user.id new_user_id = new_user.id
await new_user.save(session) new_user = await new_user.save(session)
# 7. 创建 AuthIdentity # 7. 创建 AuthIdentity
hashed_password = Password.hash(request.credential) if request.credential else None hashed_password = Password.hash(request.credential) if request.credential else None
@@ -245,13 +246,14 @@ async def router_user_register(
is_verified=False, is_verified=False,
user_id=new_user_id, user_id=new_user_id,
) )
await identity.save(session) identity = await identity.save(session)
# 8. 创建用户根目录 # 8. 创建用户根目录(使用用户组关联的第一个存储策略)
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储") await session.refresh(default_group, ['policies'])
if not default_policy: if not default_group.policies:
logger.error("默认存储策略不存在") logger.error("默认用户组未关联任何存储策略")
http_exceptions.raise_internal_error() http_exceptions.raise_internal_error()
default_policy = default_group.policies[0]
await sqlmodels.Object( await sqlmodels.Object(
name="/", name="/",
@@ -318,7 +320,7 @@ async def router_user_magic_link(
site_url = site_url_setting.value if site_url_setting else "http://localhost" site_url = site_url_setting.value if site_url_setting else "http://localhost"
# TODO: 发送邮件(包含 {site_url}/auth/magic-link?token={token} # TODO: 发送邮件(包含 {site_url}/auth/magic-link?token={token}
logger.info(f"Magic Link token 已生成: {token} (邮件发送待实现)") logger.info(f"Magic Link token 已 {request.email} 生成 (邮件发送待实现)")
@user_router.post( @user_router.post(
@@ -357,20 +359,78 @@ def router_user_profile(id: str) -> sqlmodels.ResponseBase:
@user_router.get( @user_router.get(
path='/avatar/{id}/{size}', path='/avatar/{id}/{size}',
summary='获取用户头像', summary='获取用户头像',
description='Get user avatar by ID and size.', response_model=None,
) )
def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase: async def router_user_avatar(
session: SessionDep,
id: UUID,
size: int = 128,
) -> FileResponse | RedirectResponse:
""" """
Get user avatar by ID and size. 获取指定用户指定尺寸的头像(公开端点,无需认证)
Args: 路径参数:
id (str): The user ID. - id: 用户 UUID
size (int): The size of the avatar image. - size: 请求的头像尺寸px默认 128
Returns: 行为:
str: A Base64 encoded string of the user avatar image. - default: 302 重定向到 Gravatar identicon
- gravatar: 302 重定向到 Gravatar使用用户邮箱 MD5
- file: 返回本地 WebP 文件
响应:
- 200: image/webpfile 模式)
- 302: 重定向到外部 URLdefault/gravatar 模式)
- 404: 用户不存在
缓存Cache-Control: public, max-age=3600
""" """
http_exceptions.raise_not_implemented() import aiofiles.os
from service.avatar import (
get_avatar_file_path,
get_avatar_settings,
gravatar_url,
resolve_avatar_size,
)
user = await sqlmodels.User.get(session, sqlmodels.User.id == id)
if not user:
http_exceptions.raise_not_found("用户不存在")
avatar_path, _, size_l, size_m, size_s = await get_avatar_settings(session)
if user.avatar == "file":
size_label = resolve_avatar_size(size, size_l, size_m, size_s)
file_path = get_avatar_file_path(avatar_path, user.id, size_label)
if not await aiofiles.os.path.exists(file_path):
# 文件丢失,降级为 identicon
fallback_url = gravatar_url(str(user.id), size, "https://www.gravatar.com/")
return RedirectResponse(url=fallback_url, status_code=302)
return FileResponse(
path=file_path,
media_type="image/webp",
headers={"Cache-Control": "public, max-age=3600"},
)
elif user.avatar == "gravatar":
gravatar_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.AVATAR)
& (sqlmodels.Setting.name == "gravatar_server"),
)
server = gravatar_setting.value if gravatar_setting else "https://www.gravatar.com/"
email = user.email or str(user.id)
url = gravatar_url(email, size, server)
return RedirectResponse(url=url, status_code=302)
else:
# default: identicon
email_or_id = user.email or str(user.id)
url = gravatar_url(email_or_id, size, "https://www.gravatar.com/")
return RedirectResponse(url=url, status_code=302)
##################### #####################
# 需要登录的接口 # 需要登录的接口
@@ -434,9 +494,24 @@ async def router_user_storage(
if not group: if not group:
raise HTTPException(status_code=404, detail="用户组不存在") raise HTTPException(status_code=404, detail="用户组不存在")
# [TODO] 总空间加上用户购买的额外空间 # 查询用户所有未过期容量包的 size 总和
from datetime import datetime
from sqlalchemy import func, select, and_, or_
total: int = group.max_storage now = datetime.now()
stmt = select(func.coalesce(func.sum(sqlmodels.StoragePack.size), 0)).where(
and_(
sqlmodels.StoragePack.user_id == user.id,
or_(
sqlmodels.StoragePack.expired_time.is_(None),
sqlmodels.StoragePack.expired_time > now,
),
)
)
result = await session.exec(stmt)
active_packs_total: int = result.scalar_one()
total: int = group.max_storage + active_packs_total
used: int = user.storage used: int = user.storage
free: int = max(0, total - used) free: int = max(0, total - used)
@@ -578,7 +653,7 @@ async def router_user_authn_finish(
is_verified=True, is_verified=True,
user_id=user.id, user_id=user.id,
) )
await identity.save(session) identity = await identity.save(session)
return authn.to_detail_response() return authn.to_detail_response()

View File

@@ -1,7 +1,7 @@
from typing import Annotated from typing import Annotated
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
import sqlmodels import sqlmodels
@@ -13,6 +13,7 @@ from sqlmodels import (
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest, AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
ChangePasswordRequest, ChangePasswordRequest,
AuthnDetailResponse, AuthnRenameRequest, AuthnDetailResponse, AuthnRenameRequest,
PolicySummary,
) )
from sqlmodels.color import ThemeColorsBase from sqlmodels.color import ThemeColorsBase
from sqlmodels.user_authn import UserAuthn from sqlmodels.user_authn import UserAuthn
@@ -31,16 +32,25 @@ user_settings_router.include_router(file_viewers_router)
@user_settings_router.get( @user_settings_router.get(
path='/policies', path='/policies',
summary='获取用户可选存储策略', summary='获取用户可选存储策略',
description='Get user selectable storage policies.',
) )
def router_user_settings_policies() -> sqlmodels.ResponseBase: async def router_user_settings_policies(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> list[PolicySummary]:
""" """
Get user selectable storage policies. 获取当前用户所在组可选的存储策略列表
Returns: 返回用户组关联的所有存储策略的摘要信息。
dict: A dictionary containing available storage policies for the user.
""" """
http_exceptions.raise_not_implemented() group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
return [
PolicySummary(
id=p.id, name=p.name, type=p.type,
server=p.server, max_size=p.max_size, is_private=p.is_private,
)
for p in group.policies
]
@user_settings_router.get( @user_settings_router.get(
@@ -155,34 +165,121 @@ async def router_user_settings(
@user_settings_router.post( @user_settings_router.post(
path='/avatar', path='/avatar',
summary='从文件上传头像', summary='从文件上传头像',
description='Upload user avatar from file.', status_code=204,
dependencies=[Depends(auth_required)],
) )
def router_user_settings_avatar() -> sqlmodels.ResponseBase: async def router_user_settings_avatar(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
file: UploadFile = File(...),
) -> None:
""" """
Upload user avatar from file. 上传头像文件
Returns: 认证JWT token
dict: A dictionary containing the result of the avatar upload. 请求体multipart/form-datafile 字段
流程:
1. 验证文件 MIME 类型JPEG/PNG/GIF/WebP
2. 验证文件大小 <= avatar_size 设置(默认 2MB
3. 调用 Pillow 验证图片有效性并处理(居中裁剪、缩放 L/M/S
4. 保存三种尺寸的 WebP 文件
5. 更新 User.avatar = "file"
错误处理:
- 400: 文件类型不支持 / 图片无法解析
- 413: 文件过大
""" """
http_exceptions.raise_not_implemented() from service.avatar import (
ALLOWED_CONTENT_TYPES,
get_avatar_settings,
process_and_save_avatar,
)
# 验证 MIME 类型
if file.content_type not in ALLOWED_CONTENT_TYPES:
http_exceptions.raise_bad_request(
f"不支持的图片格式,允许: {', '.join(ALLOWED_CONTENT_TYPES)}"
)
# 读取并验证大小
_, max_upload_size, _, _, _ = await get_avatar_settings(session)
raw_bytes = await file.read()
if len(raw_bytes) > max_upload_size:
raise HTTPException(
status_code=413,
detail=f"文件过大,最大允许 {max_upload_size} 字节",
)
# 处理并保存(内部会验证图片有效性,无效抛出 ValueError
try:
await process_and_save_avatar(session, user.id, raw_bytes)
except ValueError as e:
http_exceptions.raise_bad_request(str(e))
# 更新用户头像字段
user.avatar = "file"
user = await user.save(session)
@user_settings_router.put( @user_settings_router.put(
path='/avatar', path='/avatar',
summary='设定为Gravatar头像', summary='设定为 Gravatar 头像',
description='Set user avatar to Gravatar.',
dependencies=[Depends(auth_required)],
status_code=204, status_code=204,
) )
def router_user_settings_avatar_gravatar() -> None: async def router_user_settings_avatar_gravatar(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> None:
""" """
Set user avatar to Gravatar. 将头像切换为 Gravatar
Returns: 认证JWT token
dict: A dictionary containing the result of setting the Gravatar avatar.
流程:
1. 验证用户有邮箱Gravatar 基于邮箱 MD5
2. 如果当前是 FILE 头像,删除本地文件
3. 更新 User.avatar = "gravatar"
错误处理:
- 400: 用户没有邮箱
""" """
http_exceptions.raise_not_implemented() from service.avatar import delete_avatar_files
if not user.email:
http_exceptions.raise_bad_request("Gravatar 需要邮箱,请先绑定邮箱")
if user.avatar == "file":
await delete_avatar_files(session, user.id)
user.avatar = "gravatar"
user = await user.save(session)
@user_settings_router.delete(
path='/avatar',
summary='重置头像为默认',
status_code=204,
)
async def router_user_settings_avatar_delete(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> None:
"""
重置头像为默认
认证JWT token
流程:
1. 如果当前是 FILE 头像,删除本地文件
2. 更新 User.avatar = "default"
"""
from service.avatar import delete_avatar_files
if user.avatar == "file":
await delete_avatar_files(session, user.id)
user.avatar = "default"
user = await user.save(session)
@user_settings_router.patch( @user_settings_router.patch(
@@ -224,7 +321,7 @@ async def router_user_settings_theme(
user.color_error = request.theme_colors.error user.color_error = request.theme_colors.error
user.color_neutral = request.theme_colors.neutral user.color_neutral = request.theme_colors.neutral
await user.save(session) user = await user.save(session)
@user_settings_router.patch( @user_settings_router.patch(
@@ -261,7 +358,7 @@ async def router_user_settings_change_password(
http_exceptions.raise_forbidden("当前密码错误") http_exceptions.raise_forbidden("当前密码错误")
email_identity.credential = Password.hash(request.new_password) email_identity.credential = Password.hash(request.new_password)
await email_identity.save(session) email_identity = await email_identity.save(session)
@user_settings_router.patch( @user_settings_router.patch(
@@ -295,7 +392,7 @@ async def router_user_settings_patch(
http_exceptions.raise_bad_request(f"设置项 {option.value} 不允许为空") http_exceptions.raise_bad_request(f"设置项 {option.value} 不允许为空")
setattr(user, option.value, value) setattr(user, option.value, value)
await user.save(session) user = await user.save(session)
@user_settings_router.get( @user_settings_router.get(
@@ -357,7 +454,7 @@ async def router_user_settings_2fa_enable(
extra: dict = orjson.loads(email_identity.extra_data) if email_identity.extra_data else {} extra: dict = orjson.loads(email_identity.extra_data) if email_identity.extra_data else {}
extra["two_factor"] = secret extra["two_factor"] = secret
email_identity.extra_data = orjson.dumps(extra).decode('utf-8') email_identity.extra_data = orjson.dumps(extra).decode('utf-8')
await email_identity.save(session) email_identity = await email_identity.save(session)
# ==================== 认证身份管理 ==================== # ==================== 认证身份管理 ====================

View File

@@ -79,9 +79,7 @@ async def set_default_viewer(
if existing: if existing:
existing.app_id = request.app_id existing.app_id = request.app_id
existing = await existing.save(session) existing = await existing.save(session, load=UserFileAppDefault.app)
# 重新加载 app 关系
await session.refresh(existing, attribute_names=["app"])
return existing.to_response() return existing.to_response()
else: else:
new_default = UserFileAppDefault( new_default = UserFileAppDefault(
@@ -89,9 +87,7 @@ async def set_default_viewer(
extension=normalized_ext, extension=normalized_ext,
app_id=request.app_id, app_id=request.app_id,
) )
new_default = await new_default.save(session) new_default = await new_default.save(session, load=UserFileAppDefault.app)
# 重新加载 app 关系
await session.refresh(new_default, attribute_names=["app"])
return new_default.to_response() return new_default.to_response()

View File

@@ -1,106 +0,0 @@
from fastapi import APIRouter, Depends
from middleware.auth import auth_required
from sqlmodels import ResponseBase
from utils import http_exceptions
vas_router = APIRouter(
prefix="/vas",
tags=["vas"]
)
@vas_router.get(
path='/pack',
summary='获取容量包及配额信息',
description='Get information about storage packs and quotas.',
dependencies=[Depends(auth_required)]
)
def router_vas_pack() -> ResponseBase:
"""
Get information about storage packs and quotas.
Returns:
ResponseBase: A model containing the response data for storage packs and quotas.
"""
http_exceptions.raise_not_implemented()
@vas_router.get(
path='/product',
summary='获取商品信息,同时返回支付信息',
description='Get product information along with payment details.',
dependencies=[Depends(auth_required)]
)
def router_vas_product() -> ResponseBase:
"""
Get product information along with payment details.
Returns:
ResponseBase: A model containing the response data for products and payment information.
"""
http_exceptions.raise_not_implemented()
@vas_router.post(
path='/order',
summary='新建支付订单',
description='Create an order for a product.',
dependencies=[Depends(auth_required)]
)
def router_vas_order() -> ResponseBase:
"""
Create an order for a product.
Returns:
ResponseBase: A model containing the response data for the created order.
"""
http_exceptions.raise_not_implemented()
@vas_router.get(
path='/order/{id}',
summary='查询订单状态',
description='Get information about a specific payment order by ID.',
dependencies=[Depends(auth_required)]
)
def router_vas_order_get(id: str) -> ResponseBase:
"""
Get information about a specific payment order by ID.
Args:
id (str): The ID of the order to retrieve information for.
Returns:
ResponseBase: A model containing the response data for the specified order.
"""
http_exceptions.raise_not_implemented()
@vas_router.get(
path='/redeem',
summary='获取兑换码信息',
description='Get information about a specific redemption code.',
dependencies=[Depends(auth_required)]
)
def router_vas_redeem(code: str) -> ResponseBase:
"""
Get information about a specific redemption code.
Args:
code (str): The redemption code to retrieve information for.
Returns:
ResponseBase: A model containing the response data for the specified redemption code.
"""
http_exceptions.raise_not_implemented()
@vas_router.post(
path='/redeem',
summary='执行兑换',
description='Redeem a redemption code for a product or service.',
dependencies=[Depends(auth_required)]
)
def router_vas_redeem_post() -> ResponseBase:
"""
Redeem a redemption code for a product or service.
Returns:
ResponseBase: A model containing the response data for the redeemed code.
"""
http_exceptions.raise_not_implemented()

View File

@@ -1,110 +1,207 @@
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from loguru import logger as l
from middleware.auth import auth_required from middleware.auth import auth_required
from sqlmodels import ResponseBase from middleware.dependencies import SessionDep
from sqlmodels import (
Object,
User,
WebDAV,
WebDAVAccountResponse,
WebDAVCreateRequest,
WebDAVUpdateRequest,
)
from service.redis.webdav_auth_cache import WebDAVAuthCache
from utils import http_exceptions from utils import http_exceptions
from utils.password.pwd import Password
# WebDAV 管理路由
webdav_router = APIRouter( webdav_router = APIRouter(
prefix='/webdav', prefix='/webdav',
tags=["webdav"], tags=["webdav"],
) )
def _check_webdav_enabled(user: User) -> None:
"""检查用户组是否启用了 WebDAV 功能"""
if not user.group.web_dav_enabled:
http_exceptions.raise_forbidden("WebDAV 功能未启用")
def _to_response(account: WebDAV) -> WebDAVAccountResponse:
"""将 WebDAV 数据库模型转换为响应 DTO"""
return WebDAVAccountResponse(
id=account.id,
name=account.name,
root=account.root,
readonly=account.readonly,
use_proxy=account.use_proxy,
created_at=str(account.created_at),
updated_at=str(account.updated_at),
)
@webdav_router.get( @webdav_router.get(
path='/accounts', path='/accounts',
summary='获取账号信息', summary='获取账号列表',
description='Get account information for WebDAV.',
dependencies=[Depends(auth_required)],
) )
def router_webdav_accounts() -> ResponseBase: async def list_accounts(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
) -> list[WebDAVAccountResponse]:
""" """
Get account information for WebDAV. 列出当前用户所有 WebDAV 账户
Returns: 认证JWT Bearer Token
ResponseBase: A model containing the response data for the account information.
""" """
http_exceptions.raise_not_implemented() _check_webdav_enabled(user)
user_id: UUID = user.id
accounts: list[WebDAV] = await WebDAV.get(
session,
WebDAV.user_id == user_id,
fetch_mode="all",
)
return [_to_response(a) for a in accounts]
@webdav_router.post( @webdav_router.post(
path='/accounts', path='/accounts',
summary='建账号', summary='建账号',
description='Create a new WebDAV account.', status_code=201,
dependencies=[Depends(auth_required)],
) )
def router_webdav_create_account() -> ResponseBase: async def create_account(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
request: WebDAVCreateRequest,
) -> WebDAVAccountResponse:
""" """
Create a new WebDAV account. 创建 WebDAV 账户
Returns: 认证JWT Bearer Token
ResponseBase: A model containing the response data for the created account.
错误处理:
- 403: WebDAV 功能未启用
- 400: 根目录路径不存在或不是目录
- 409: 账户名已存在
""" """
http_exceptions.raise_not_implemented() _check_webdav_enabled(user)
user_id: UUID = user.id
@webdav_router.delete( # 验证账户名唯一
path='/accounts/{id}', existing = await WebDAV.get(
summary='删除账号', session,
description='Delete a WebDAV account by its ID.', (WebDAV.name == request.name) & (WebDAV.user_id == user_id),
dependencies=[Depends(auth_required)], )
) if existing:
def router_webdav_delete_account(id: str) -> ResponseBase: http_exceptions.raise_conflict("账户名已存在")
"""
Delete a WebDAV account by its ID.
Args: # 验证 root 路径存在且为目录
id (str): The ID of the account to be deleted. root_obj = await Object.get_by_path(session, user_id, request.root)
if not root_obj or not root_obj.is_folder:
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
Returns: # 创建账户
ResponseBase: A model containing the response data for the deletion operation. account = WebDAV(
""" name=request.name,
http_exceptions.raise_not_implemented() password=Password.hash(request.password),
root=request.root,
readonly=request.readonly,
use_proxy=request.use_proxy,
user_id=user_id,
)
account = await account.save(session)
@webdav_router.post( l.info(f"用户 {user_id} 创建 WebDAV 账户: {account.name}")
path='/mount', return _to_response(account)
summary='新建目录挂载',
description='Create a new WebDAV mount point.',
dependencies=[Depends(auth_required)],
)
def router_webdav_create_mount() -> ResponseBase:
"""
Create a new WebDAV mount point.
Returns:
ResponseBase: A model containing the response data for the created mount point.
"""
http_exceptions.raise_not_implemented()
@webdav_router.delete(
path='/mount/{id}',
summary='删除目录挂载',
description='Delete a WebDAV mount point by its ID.',
dependencies=[Depends(auth_required)],
)
def router_webdav_delete_mount(id: str) -> ResponseBase:
"""
Delete a WebDAV mount point by its ID.
Args:
id (str): The ID of the mount point to be deleted.
Returns:
ResponseBase: A model containing the response data for the deletion operation.
"""
http_exceptions.raise_not_implemented()
@webdav_router.patch( @webdav_router.patch(
path='accounts/{id}', path='/accounts/{account_id}',
summary='更新账号信息', summary='更新账号',
description='Update WebDAV account information by ID.',
dependencies=[Depends(auth_required)],
) )
def router_webdav_update_account(id: str) -> ResponseBase: async def update_account(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
account_id: int,
request: WebDAVUpdateRequest,
) -> WebDAVAccountResponse:
""" """
Update WebDAV account information by ID. 更新 WebDAV 账户
Args: 认证JWT Bearer Token
id (str): The ID of the account to be updated.
Returns: 错误处理:
ResponseBase: A model containing the response data for the updated account. - 403: WebDAV 功能未启用
- 404: 账户不存在
- 400: 根目录路径不存在或不是目录
""" """
http_exceptions.raise_not_implemented() _check_webdav_enabled(user)
user_id: UUID = user.id
account = await WebDAV.get(
session,
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
)
if not account:
http_exceptions.raise_not_found("WebDAV 账户不存在")
# 验证 root 路径
if request.root is not None:
root_obj = await Object.get_by_path(session, user_id, request.root)
if not root_obj or not root_obj.is_folder:
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
# 密码哈希后原地替换update() 会通过 model_dump(exclude_unset=True) 只取已设置字段
is_password_changed = request.password is not None
if is_password_changed:
request.password = Password.hash(request.password)
account = await account.update(session, request)
# 密码变更时清除认证缓存
if is_password_changed:
await WebDAVAuthCache.invalidate_account(user_id, account.name)
l.info(f"用户 {user_id} 更新 WebDAV 账户: {account.name}")
return _to_response(account)
@webdav_router.delete(
path='/accounts/{account_id}',
summary='删除账号',
status_code=204,
)
async def delete_account(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
account_id: int,
) -> None:
"""
删除 WebDAV 账户
认证JWT Bearer Token
错误处理:
- 403: WebDAV 功能未启用
- 404: 账户不存在
"""
_check_webdav_enabled(user)
user_id: UUID = user.id
account = await WebDAV.get(
session,
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
)
if not account:
http_exceptions.raise_not_found("WebDAV 账户不存在")
account_name = account.name
await WebDAV.delete(session, account)
# 清除认证缓存
await WebDAVAuthCache.invalidate_account(user_id, account_name)
l.info(f"用户 {user_id} 删除 WebDAV 账户: {account_name}")

View File

@@ -1 +0,0 @@
# WebDAV 操作路由

35
routers/dav/__init__.py Normal file
View File

@@ -0,0 +1,35 @@
"""
WebDAV 协议入口
使用 WsgiDAV + a2wsgi 提供 WebDAV 协议支持。
WsgiDAV 在 a2wsgi 的线程池中运行,不阻塞 FastAPI 事件循环。
"""
from a2wsgi import WSGIMiddleware
from wsgidav.wsgidav_app import WsgiDAVApp
from .domain_controller import DiskNextDomainController
from .provider import DiskNextDAVProvider
_wsgidav_config: dict[str, object] = {
"provider_mapping": {
"/": DiskNextDAVProvider(),
},
"http_authenticator": {
"domain_controller": DiskNextDomainController,
"accept_basic": True,
"accept_digest": False,
"default_to_digest": False,
},
"verbose": 1,
# 使用 WsgiDAV 内置的内存锁管理器
"lock_storage": True,
# 禁用 WsgiDAV 的目录浏览器(纯 DAV 协议)
"dir_browser": {
"enable": False,
},
}
_wsgidav_app = WsgiDAVApp(_wsgidav_config)
dav_app = WSGIMiddleware(_wsgidav_app, workers=10)
"""ASGI 应用,挂载到 /dav 路径"""

View File

@@ -0,0 +1,148 @@
"""
WebDAV 认证控制器
实现 WsgiDAV 的 BaseDomainController 接口,使用 HTTP Basic Auth
通过 DiskNext 的 WebDAV 账户模型进行认证。
用户名格式: {email}/{webdav_account_name}
"""
import asyncio
from uuid import UUID
from loguru import logger as l
from wsgidav.dc.base_dc import BaseDomainController
from routers.dav.provider import EventLoopRef, _get_session
from service.redis.webdav_auth_cache import WebDAVAuthCache
from sqlmodels.user import User, UserStatus
from sqlmodels.webdav import WebDAV
from utils.password.pwd import Password, PasswordStatus
async def _authenticate(
email: str,
account_name: str,
password: str,
) -> tuple[UUID, int] | None:
"""
异步认证 WebDAV 用户。
:param email: 用户邮箱
:param account_name: WebDAV 账户名
:param password: 明文密码
:return: (user_id, webdav_id) 或 None
"""
# 1. 查缓存
cached = await WebDAVAuthCache.get(email, account_name, password)
if cached is not None:
return cached
# 2. 缓存未命中,查库验证
async with _get_session() as session:
user = await User.get(session, User.email == email, load=User.group)
if not user:
return None
if user.status != UserStatus.ACTIVE:
return None
if not user.group.web_dav_enabled:
return None
account = await WebDAV.get(
session,
(WebDAV.name == account_name) & (WebDAV.user_id == user.id),
)
if not account:
return None
status = Password.verify(account.password, password)
if status == PasswordStatus.INVALID:
return None
user_id: UUID = user.id
webdav_id: int = account.id
# 3. 写入缓存
await WebDAVAuthCache.set(email, account_name, password, user_id, webdav_id)
return user_id, webdav_id
class DiskNextDomainController(BaseDomainController):
"""
DiskNext WebDAV 认证控制器
用户名格式: {email}/{webdav_account_name}
密码: WebDAV 账户密码(创建账户时设置)
"""
def __init__(self, wsgidav_app: object, config: dict[str, object]) -> None:
super().__init__(wsgidav_app, config)
def get_domain_realm(self, path_info: str, environ: dict[str, object]) -> str:
"""返回 realm 名称"""
return "DiskNext WebDAV"
def require_authentication(self, realm: str, environ: dict[str, object]) -> bool:
"""所有请求都需要认证"""
return True
def is_share_anonymous(self, path_info: str) -> bool:
"""不支持匿名访问"""
return False
def supports_http_digest_auth(self) -> bool:
"""不支持 Digest 认证(密码存的是 Argon2 哈希,无法反推)"""
return False
def basic_auth_user(
self,
realm: str,
user_name: str,
password: str,
environ: dict[str, object],
) -> bool:
"""
HTTP Basic Auth 认证。
用户名格式: {email}/{webdav_account_name}
在 WSGI 线程中通过 anyio.from_thread.run 调用异步认证逻辑。
"""
# 解析用户名
if "/" not in user_name:
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
return False
email, account_name = user_name.split("/", 1)
if not email or not account_name:
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
return False
# 在 WSGI 线程中调用异步认证
future = asyncio.run_coroutine_threadsafe(
_authenticate(email, account_name, password),
EventLoopRef.get(),
)
result = future.result()
if result is None:
l.debug(f"WebDAV 认证失败: {email}/{account_name}")
return False
user_id, webdav_id = result
# 将认证信息存入 environ供 Provider 使用
environ["disknext.user_id"] = user_id
environ["disknext.webdav_id"] = webdav_id
environ["disknext.email"] = email
environ["disknext.account_name"] = account_name
return True
def digest_auth_user(
self,
realm: str,
user_name: str,
environ: dict[str, object],
) -> bool:
"""不支持 Digest 认证"""
return False

645
routers/dav/provider.py Normal file
View File

@@ -0,0 +1,645 @@
"""
DiskNext WebDAV 存储 Provider
将 WsgiDAV 的文件操作映射到 DiskNext 的 Object 模型。
所有异步数据库/文件操作通过 asyncio.run_coroutine_threadsafe() 桥接。
"""
import asyncio
import io
import mimetypes
from pathlib import Path
from typing import ClassVar
from uuid import UUID
from loguru import logger as l
from wsgidav.dav_error import (
DAVError,
HTTP_FORBIDDEN,
HTTP_INSUFFICIENT_STORAGE,
HTTP_NOT_FOUND,
)
from wsgidav.dav_provider import DAVCollection, DAVNonCollection, DAVProvider
from service.storage import LocalStorageService, adjust_user_storage
from sqlmodels.database_connection import DatabaseManager
from sqlmodels.object import Object, ObjectType
from sqlmodels.physical_file import PhysicalFile
from sqlmodels.policy import Policy
from sqlmodels.user import User
from sqlmodels.webdav import WebDAV
class EventLoopRef:
"""持有主线程事件循环引用,供 WSGI 线程使用"""
_loop: ClassVar[asyncio.AbstractEventLoop | None] = None
@classmethod
async def capture(cls) -> None:
"""在 async 上下文中调用,捕获当前事件循环"""
cls._loop = asyncio.get_running_loop()
@classmethod
def get(cls) -> asyncio.AbstractEventLoop:
if cls._loop is None:
raise RuntimeError("事件循环尚未捕获,请先调用 EventLoopRef.capture()")
return cls._loop
def _run_async(coro): # type: ignore[no-untyped-def]
"""在 WSGI 线程中通过 run_coroutine_threadsafe 运行协程"""
future = asyncio.run_coroutine_threadsafe(coro, EventLoopRef.get())
return future.result()
def _get_session(): # type: ignore[no-untyped-def]
"""获取数据库会话上下文管理器"""
return DatabaseManager._async_session_factory()
# ==================== 异步辅助函数 ====================
async def _get_webdav_account(webdav_id: int) -> WebDAV | None:
"""获取 WebDAV 账户"""
async with _get_session() as session:
return await WebDAV.get(session, WebDAV.id == webdav_id)
async def _get_object_by_path(user_id: UUID, path: str) -> Object | None:
"""根据路径获取对象"""
async with _get_session() as session:
return await Object.get_by_path(session, user_id, path)
async def _get_children(user_id: UUID, parent_id: UUID) -> list[Object]:
"""获取目录子对象"""
async with _get_session() as session:
return await Object.get_children(session, user_id, parent_id)
async def _get_object_by_id(object_id: UUID) -> Object | None:
"""根据ID获取对象"""
async with _get_session() as session:
return await Object.get(session, Object.id == object_id, load=Object.physical_file)
async def _get_user(user_id: UUID) -> User | None:
"""获取用户(含 group 关系)"""
async with _get_session() as session:
return await User.get(session, User.id == user_id, load=User.group)
async def _get_policy(policy_id: UUID) -> Policy | None:
"""获取存储策略"""
async with _get_session() as session:
return await Policy.get(session, Policy.id == policy_id)
async def _create_folder(
name: str,
parent_id: UUID,
owner_id: UUID,
policy_id: UUID,
) -> Object:
"""创建目录对象"""
async with _get_session() as session:
obj = Object(
name=name,
type=ObjectType.FOLDER,
size=0,
parent_id=parent_id,
owner_id=owner_id,
policy_id=policy_id,
)
obj = await obj.save(session)
return obj
async def _create_file(
name: str,
parent_id: UUID,
owner_id: UUID,
policy_id: UUID,
) -> Object:
"""创建空文件对象"""
async with _get_session() as session:
obj = Object(
name=name,
type=ObjectType.FILE,
size=0,
parent_id=parent_id,
owner_id=owner_id,
policy_id=policy_id,
)
obj = await obj.save(session)
return obj
async def _soft_delete_object(object_id: UUID) -> None:
"""软删除对象(移入回收站)"""
from service.storage import soft_delete_objects
async with _get_session() as session:
obj = await Object.get(session, Object.id == object_id)
if obj:
await soft_delete_objects(session, [obj])
async def _finalize_upload(
object_id: UUID,
physical_path: str,
size: int,
owner_id: UUID,
policy_id: UUID,
) -> None:
"""上传完成后更新对象元数据和物理文件记录"""
async with _get_session() as session:
# 获取存储路径(相对路径)
policy = await Policy.get(session, Policy.id == policy_id)
if not policy or not policy.server:
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
base_path = Path(policy.server).resolve()
full_path = Path(physical_path).resolve()
storage_path = str(full_path.relative_to(base_path))
# 创建 PhysicalFile 记录
pf = PhysicalFile(
storage_path=storage_path,
size=size,
policy_id=policy_id,
reference_count=1,
)
pf = await pf.save(session)
# 更新 Object
obj = await Object.get(session, Object.id == object_id)
if obj:
obj.sqlmodel_update({'size': size, 'physical_file_id': pf.id})
obj = await obj.save(session)
# 更新用户存储用量
if size > 0:
await adjust_user_storage(session, owner_id, size)
async def _move_object(
object_id: UUID,
new_parent_id: UUID,
new_name: str,
) -> None:
"""移动/重命名对象"""
async with _get_session() as session:
obj = await Object.get(session, Object.id == object_id)
if obj:
obj.sqlmodel_update({'parent_id': new_parent_id, 'name': new_name})
obj = await obj.save(session)
async def _copy_object_recursive(
src_id: UUID,
dst_parent_id: UUID,
dst_name: str,
owner_id: UUID,
) -> None:
"""递归复制对象"""
from service.storage import copy_object_recursive
async with _get_session() as session:
src = await Object.get(session, Object.id == src_id)
if not src:
return
await copy_object_recursive(session, src, dst_parent_id, owner_id, new_name=dst_name)
# ==================== 辅助工具 ====================
def _get_environ_info(environ: dict[str, object]) -> tuple[UUID, int]:
"""从 environ 中提取认证信息"""
user_id: UUID = environ["disknext.user_id"] # type: ignore[assignment]
webdav_id: int = environ["disknext.webdav_id"] # type: ignore[assignment]
return user_id, webdav_id
def _resolve_dav_path(account_root: str, dav_path: str) -> str:
"""
将 DAV 相对路径映射到 DiskNext 绝对路径。
:param account_root: 账户挂载根路径,如 "/""/docs"
:param dav_path: DAV 请求路径,如 "/""/photos/cat.jpg"
:return: DiskNext 内部路径,如 "/docs/photos/cat.jpg"
"""
# 规范化根路径
root = account_root.rstrip("/")
if not root:
root = ""
# 规范化 DAV 路径
if not dav_path or dav_path == "/":
return root + "/" if root else "/"
if not dav_path.startswith("/"):
dav_path = "/" + dav_path
full = root + dav_path
return full if full else "/"
def _check_readonly(environ: dict[str, object]) -> None:
"""检查账户是否只读,只读则抛出 403"""
account = environ.get("disknext.webdav_account")
if account and getattr(account, 'readonly', False):
raise DAVError(HTTP_FORBIDDEN, "WebDAV 账户为只读模式")
def _check_storage_quota(user: User, additional_bytes: int) -> None:
"""检查存储配额"""
max_storage = user.group.max_storage
if max_storage > 0 and user.storage + additional_bytes > max_storage:
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
class QuotaLimitedWriter(io.RawIOBase):
"""带配额限制的写入流包装器"""
def __init__(self, stream: io.BufferedWriter, max_bytes: int) -> None:
self._stream = stream
self._max_bytes = max_bytes
self._bytes_written = 0
def writable(self) -> bool:
return True
def write(self, b: bytes | bytearray) -> int:
if self._bytes_written + len(b) > self._max_bytes:
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
written = self._stream.write(b)
self._bytes_written += written
return written
def close(self) -> None:
self._stream.close()
super().close()
@property
def bytes_written(self) -> int:
return self._bytes_written
# ==================== Provider ====================
class DiskNextDAVProvider(DAVProvider):
"""DiskNext WebDAV 存储 Provider"""
def __init__(self) -> None:
super().__init__()
def get_resource_inst(
self,
path: str,
environ: dict[str, object],
) -> 'DiskNextCollection | DiskNextFile | None':
"""
将 WebDAV 路径映射到资源对象。
首次调用时加载 WebDAV 账户信息并缓存到 environ。
"""
user_id, webdav_id = _get_environ_info(environ)
# 首次请求时加载账户信息
if "disknext.webdav_account" not in environ:
account = _run_async(_get_webdav_account(webdav_id))
if not account:
return None
environ["disknext.webdav_account"] = account
account: WebDAV = environ["disknext.webdav_account"] # type: ignore[no-redef]
disknext_path = _resolve_dav_path(account.root, path)
obj = _run_async(_get_object_by_path(user_id, disknext_path))
if not obj:
return None
if obj.is_folder:
return DiskNextCollection(path, environ, obj, user_id, account)
else:
return DiskNextFile(path, environ, obj, user_id, account)
def is_readonly(self) -> bool:
"""只读由账户级别控制,不在 provider 级别限制"""
return False
# ==================== Collection目录 ====================
class DiskNextCollection(DAVCollection):
"""DiskNext 目录资源"""
def __init__(
self,
path: str,
environ: dict[str, object],
obj: Object,
user_id: UUID,
account: WebDAV,
) -> None:
super().__init__(path, environ)
self._obj = obj
self._user_id = user_id
self._account = account
def get_display_info(self) -> dict[str, str]:
return {"type": "Directory"}
def get_member_names(self) -> list[str]:
"""获取子对象名称列表"""
children = _run_async(_get_children(self._user_id, self._obj.id))
return [c.name for c in children]
def get_member(self, name: str) -> 'DiskNextCollection | DiskNextFile | None':
"""获取指定名称的子资源"""
member_path = self.path.rstrip("/") + "/" + name
account_root = self._account.root
disknext_path = _resolve_dav_path(account_root, member_path)
obj = _run_async(_get_object_by_path(self._user_id, disknext_path))
if not obj:
return None
if obj.is_folder:
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
else:
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
def get_creation_date(self) -> float | None:
if self._obj.created_at:
return self._obj.created_at.timestamp()
return None
def get_last_modified(self) -> float | None:
if self._obj.updated_at:
return self._obj.updated_at.timestamp()
return None
def create_empty_resource(self, name: str) -> 'DiskNextFile':
"""创建空文件PUT 操作的第一步)"""
_check_readonly(self.environ)
obj = _run_async(_create_file(
name=name,
parent_id=self._obj.id,
owner_id=self._user_id,
policy_id=self._obj.policy_id,
))
member_path = self.path.rstrip("/") + "/" + name
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
def create_collection(self, name: str) -> 'DiskNextCollection':
"""创建子目录MKCOL"""
_check_readonly(self.environ)
obj = _run_async(_create_folder(
name=name,
parent_id=self._obj.id,
owner_id=self._user_id,
policy_id=self._obj.policy_id,
))
member_path = self.path.rstrip("/") + "/" + name
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
def delete(self) -> None:
"""软删除目录"""
_check_readonly(self.environ)
_run_async(_soft_delete_object(self._obj.id))
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
"""复制或移动目录"""
_check_readonly(self.environ)
account_root = self._account.root
dest_disknext = _resolve_dav_path(account_root, dest_path)
# 解析目标父路径和新名称
if "/" in dest_disknext.rstrip("/"):
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
new_name = dest_disknext.rsplit("/", 1)[1]
else:
parent_path = "/"
new_name = dest_disknext.lstrip("/")
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
if not dest_parent:
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
if is_move:
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
else:
_run_async(_copy_object_recursive(
self._obj.id, dest_parent.id, new_name, self._user_id,
))
return True
def support_recursive_delete(self) -> bool:
return True
def support_recursive_move(self, dest_path: str) -> bool:
return True
# ==================== NonCollection文件 ====================
class DiskNextFile(DAVNonCollection):
"""DiskNext 文件资源"""
def __init__(
self,
path: str,
environ: dict[str, object],
obj: Object,
user_id: UUID,
account: WebDAV,
) -> None:
super().__init__(path, environ)
self._obj = obj
self._user_id = user_id
self._account = account
self._write_path: str | None = None
self._write_stream: io.BufferedWriter | QuotaLimitedWriter | None = None
def get_content_length(self) -> int | None:
return self._obj.size if self._obj.size else 0
def get_content_type(self) -> str | None:
# 尝试从文件名推断 MIME 类型
mime, _ = mimetypes.guess_type(self._obj.name)
return mime or "application/octet-stream"
def get_creation_date(self) -> float | None:
if self._obj.created_at:
return self._obj.created_at.timestamp()
return None
def get_last_modified(self) -> float | None:
if self._obj.updated_at:
return self._obj.updated_at.timestamp()
return None
def get_display_info(self) -> dict[str, str]:
return {"type": "File"}
def get_content(self) -> io.BufferedReader | None:
"""
返回文件内容的可读流。
WsgiDAV 在线程中运行,可安全使用同步 open()。
"""
obj_with_file = _run_async(_get_object_by_id(self._obj.id))
if not obj_with_file or not obj_with_file.physical_file:
return None
pf = obj_with_file.physical_file
policy = _run_async(_get_policy(obj_with_file.policy_id))
if not policy or not policy.server:
return None
full_path = Path(policy.server).resolve() / pf.storage_path
if not full_path.is_file():
l.warning(f"WebDAV: 物理文件不存在: {full_path}")
return None
return open(full_path, "rb") # noqa: SIM115
def begin_write(
self,
*,
content_type: str | None = None,
) -> io.BufferedWriter | QuotaLimitedWriter:
"""
开始写入文件PUT 操作)。
返回一个可写的文件流WsgiDAV 将向其中写入请求体数据。
当用户有配额限制时,返回 QuotaLimitedWriter 在写入过程中实时检查配额。
"""
_check_readonly(self.environ)
# 检查配额
remaining_quota: int = 0
user = _run_async(_get_user(self._user_id))
if user:
max_storage = user.group.max_storage
if max_storage > 0:
remaining_quota = max_storage - user.storage
if remaining_quota <= 0:
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
# Content-Length 预检(如果有的话)
content_length = self.environ.get("CONTENT_LENGTH")
if content_length and int(content_length) > remaining_quota:
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
# 获取策略以确定存储路径
policy = _run_async(_get_policy(self._obj.policy_id))
if not policy or not policy.server:
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
storage_service = LocalStorageService(policy)
dir_path, storage_name, full_path = _run_async(
storage_service.generate_file_path(
user_id=self._user_id,
original_filename=self._obj.name,
)
)
self._write_path = full_path
raw_stream = open(full_path, "wb") # noqa: SIM115
# 有配额限制时使用包装流,实时检查写入量
if remaining_quota > 0:
self._write_stream = QuotaLimitedWriter(raw_stream, remaining_quota)
else:
self._write_stream = raw_stream
return self._write_stream
def end_write(self, *, with_errors: bool) -> None:
"""写入完成后的收尾工作"""
if self._write_stream:
self._write_stream.close()
self._write_stream = None
if with_errors:
if self._write_path:
file_path = Path(self._write_path)
if file_path.exists():
file_path.unlink()
return
if not self._write_path:
return
# 获取文件大小
file_path = Path(self._write_path)
if not file_path.exists():
return
size = file_path.stat().st_size
# 更新数据库记录
_run_async(_finalize_upload(
object_id=self._obj.id,
physical_path=self._write_path,
size=size,
owner_id=self._user_id,
policy_id=self._obj.policy_id,
))
l.debug(f"WebDAV 文件写入完成: {self._obj.name}, size={size}")
def delete(self) -> None:
"""软删除文件"""
_check_readonly(self.environ)
_run_async(_soft_delete_object(self._obj.id))
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
"""复制或移动文件"""
_check_readonly(self.environ)
account_root = self._account.root
dest_disknext = _resolve_dav_path(account_root, dest_path)
# 解析目标父路径和新名称
if "/" in dest_disknext.rstrip("/"):
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
new_name = dest_disknext.rsplit("/", 1)[1]
else:
parent_path = "/"
new_name = dest_disknext.lstrip("/")
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
if not dest_parent:
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
if is_move:
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
else:
_run_async(_copy_object_recursive(
self._obj.id, dest_parent.id, new_name, self._user_id,
))
return True
def support_content_length(self) -> bool:
return True
def get_etag(self) -> str | None:
"""返回 ETag基于ID和更新时间WsgiDAV 会自动加双引号"""
if self._obj.updated_at:
return f"{self._obj.id}-{int(self._obj.updated_at.timestamp())}"
return None
def support_etag(self) -> bool:
return True
def support_ranges(self) -> bool:
return True

View File

@@ -0,0 +1,128 @@
"""
WebDAV 认证缓存
缓存 HTTP Basic Auth 的认证结果,避免每次请求都查库 + Argon2 验证。
支持 Redis首选和内存缓存降级两种存储后端。
"""
import hashlib
from typing import ClassVar
from uuid import UUID
from cachetools import TTLCache
from loguru import logger as l
from . import RedisManager
_AUTH_TTL: int = 300
"""认证缓存 TTL5 分钟"""
class WebDAVAuthCache:
"""
WebDAV 认证结果缓存
缓存键格式: webdav_auth:{email}/{account_name}:{sha256(password)}
缓存值格式: {user_id}:{webdav_id}
密码的 SHA256 作为缓存键的一部分,密码变更后旧缓存自然 miss。
"""
_memory_cache: ClassVar[TTLCache[str, str]] = TTLCache(maxsize=10000, ttl=_AUTH_TTL)
"""内存缓存降级方案"""
@classmethod
def _build_key(cls, email: str, account_name: str, password: str) -> str:
"""构建缓存键"""
pwd_hash = hashlib.sha256(password.encode()).hexdigest()[:16]
return f"webdav_auth:{email}/{account_name}:{pwd_hash}"
@classmethod
async def get(
cls,
email: str,
account_name: str,
password: str,
) -> tuple[UUID, int] | None:
"""
查询缓存中的认证结果。
:param email: 用户邮箱
:param account_name: WebDAV 账户名
:param password: 用户提供的明文密码
:return: (user_id, webdav_id) 或 None缓存未命中
"""
key = cls._build_key(email, account_name, password)
client = RedisManager.get_client()
if client is not None:
value = await client.get(key)
if value is not None:
raw = value.decode() if isinstance(value, bytes) else value
user_id_str, webdav_id_str = raw.split(":", 1)
return UUID(user_id_str), int(webdav_id_str)
else:
raw = cls._memory_cache.get(key)
if raw is not None:
user_id_str, webdav_id_str = raw.split(":", 1)
return UUID(user_id_str), int(webdav_id_str)
return None
@classmethod
async def set(
cls,
email: str,
account_name: str,
password: str,
user_id: UUID,
webdav_id: int,
) -> None:
"""
写入认证结果到缓存。
:param email: 用户邮箱
:param account_name: WebDAV 账户名
:param password: 用户提供的明文密码
:param user_id: 用户UUID
:param webdav_id: WebDAV 账户ID
"""
key = cls._build_key(email, account_name, password)
value = f"{user_id}:{webdav_id}"
client = RedisManager.get_client()
if client is not None:
await client.set(key, value, ex=_AUTH_TTL)
else:
cls._memory_cache[key] = value
@classmethod
async def invalidate_account(cls, user_id: UUID, account_name: str) -> None:
"""
失效指定账户的所有缓存。
由于缓存键包含 password hash无法精确删除
Redis 端使用 pattern scan 删除,内存端清空全部。
:param user_id: 用户UUID
:param account_name: WebDAV 账户名
"""
client = RedisManager.get_client()
if client is not None:
pattern = f"webdav_auth:*/{account_name}:*"
cursor: int = 0
while True:
cursor, keys = await client.scan(cursor, match=pattern, count=100)
if keys:
await client.delete(*keys)
if cursor == 0:
break
else:
# 内存缓存无法按 pattern 删除,清除所有含该账户名的条目
keys_to_delete = [
k for k in cls._memory_cache
if f"/{account_name}:" in k
]
for k in keys_to_delete:
cls._memory_cache.pop(k, None)
l.debug(f"已清除 WebDAV 认证缓存: user={user_id}, account={account_name}")

View File

@@ -3,6 +3,7 @@
提供文件存储相关的服务,包括: 提供文件存储相关的服务,包括:
- 本地存储服务 - 本地存储服务
- S3 存储服务
- 命名规则解析器 - 命名规则解析器
- 存储异常定义 - 存储异常定义
""" """
@@ -11,6 +12,8 @@ from .exceptions import (
FileReadError, FileReadError,
FileWriteError, FileWriteError,
InvalidPathError, InvalidPathError,
S3APIError,
S3MultipartUploadError,
StorageException, StorageException,
StorageFileNotFoundError, StorageFileNotFoundError,
UploadSessionExpiredError, UploadSessionExpiredError,
@@ -26,3 +29,5 @@ from .object import (
restore_objects, restore_objects,
soft_delete_objects, soft_delete_objects,
) )
from .migrate import migrate_file_with_task, migrate_directory_files
from .s3_storage import S3StorageService

View File

@@ -43,3 +43,13 @@ class UploadSessionExpiredError(StorageException):
class InvalidPathError(StorageException): class InvalidPathError(StorageException):
"""无效的路径""" """无效的路径"""
pass pass
class S3APIError(StorageException):
"""S3 API 请求错误"""
pass
class S3MultipartUploadError(S3APIError):
"""S3 分片上传错误"""
pass

View File

@@ -263,15 +263,49 @@ class LocalStorageService:
""" """
删除文件(物理删除) 删除文件(物理删除)
删除文件后会尝试清理因此变空的父目录。
:param path: 完整文件路径 :param path: 完整文件路径
""" """
if await self.file_exists(path): if await self.file_exists(path):
try: try:
await aiofiles.os.remove(path) await aiofiles.os.remove(path)
l.debug(f"已删除文件: {path}") l.debug(f"已删除文件: {path}")
await self._cleanup_empty_parents(path)
except OSError as e: except OSError as e:
l.warning(f"删除文件失败 {path}: {e}") l.warning(f"删除文件失败 {path}: {e}")
async def _cleanup_empty_parents(self, file_path: str) -> None:
"""
从被删文件的父目录开始,向上逐级删除空目录
在以下情况停止:
- 到达存储根目录_base_path
- 遇到非空目录
- 遇到 .trash 目录
- 删除失败(权限、并发等)
:param file_path: 被删文件的完整路径
"""
current = Path(file_path).parent
while current != self._base_path and str(current).startswith(str(self._base_path)):
if current.name == '.trash':
break
try:
entries = await aiofiles.os.listdir(str(current))
if entries:
break
await aiofiles.os.rmdir(str(current))
l.debug(f"已清理空目录: {current}")
current = current.parent
except OSError as e:
l.debug(f"清理空目录失败(忽略): {current}: {e}")
break
async def move_to_trash( async def move_to_trash(
self, self,
source_path: str, source_path: str,
@@ -304,6 +338,7 @@ class LocalStorageService:
try: try:
await aiofiles.os.rename(source_path, str(trash_path)) await aiofiles.os.rename(source_path, str(trash_path))
l.info(f"文件已移动到回收站: {source_path} -> {trash_path}") l.info(f"文件已移动到回收站: {source_path} -> {trash_path}")
await self._cleanup_empty_parents(source_path)
return str(trash_path) return str(trash_path)
except OSError as e: except OSError as e:
raise StorageException(f"移动文件到回收站失败: {e}") raise StorageException(f"移动文件到回收站失败: {e}")

291
service/storage/migrate.py Normal file
View File

@@ -0,0 +1,291 @@
"""
存储策略迁移服务
提供跨存储策略的文件迁移功能:
- 单文件迁移:从源策略下载 → 上传到目标策略 → 更新数据库记录
- 目录批量迁移:递归遍历目录下所有文件逐个迁移,同时更新子目录的 policy_id
"""
from uuid import UUID
from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.object import Object, ObjectType
from sqlmodels.physical_file import PhysicalFile
from sqlmodels.policy import Policy, PolicyType
from sqlmodels.task import Task, TaskStatus
from .local_storage import LocalStorageService
from .s3_storage import S3StorageService
async def _get_storage_service(
policy: Policy,
) -> LocalStorageService | S3StorageService:
"""
根据策略类型创建对应的存储服务实例
:param policy: 存储策略
:return: 存储服务实例
"""
if policy.type == PolicyType.LOCAL:
return LocalStorageService(policy)
elif policy.type == PolicyType.S3:
return await S3StorageService.from_policy(policy)
else:
raise ValueError(f"不支持的存储策略类型: {policy.type}")
async def _read_file_from_storage(
service: LocalStorageService | S3StorageService,
storage_path: str,
) -> bytes:
"""
从存储服务读取文件内容
:param service: 存储服务实例
:param storage_path: 文件存储路径
:return: 文件二进制内容
"""
if isinstance(service, LocalStorageService):
return await service.read_file(storage_path)
else:
return await service.download_file(storage_path)
async def _write_file_to_storage(
service: LocalStorageService | S3StorageService,
storage_path: str,
data: bytes,
) -> None:
"""
将文件内容写入存储服务
:param service: 存储服务实例
:param storage_path: 文件存储路径
:param data: 文件二进制内容
"""
if isinstance(service, LocalStorageService):
await service.write_file(storage_path, data)
else:
await service.upload_file(storage_path, data)
async def _delete_file_from_storage(
service: LocalStorageService | S3StorageService,
storage_path: str,
) -> None:
"""
从存储服务删除文件
:param service: 存储服务实例
:param storage_path: 文件存储路径
"""
if isinstance(service, LocalStorageService):
await service.delete_file(storage_path)
else:
await service.delete_file(storage_path)
async def migrate_single_file(
session: AsyncSession,
obj: Object,
dest_policy: Policy,
) -> None:
"""
将单个文件对象从当前存储策略迁移到目标策略
流程:
1. 获取源物理文件和存储服务
2. 读取源文件内容
3. 在目标存储中生成新路径并写入
4. 创建新的 PhysicalFile 记录
5. 更新 Object 的 policy_id 和 physical_file_id
6. 旧 PhysicalFile 引用计数 -1如为 0 则删除源物理文件
:param session: 数据库会话
:param obj: 待迁移的文件对象(必须为文件类型)
:param dest_policy: 目标存储策略
"""
if obj.type != ObjectType.FILE:
raise ValueError(f"只能迁移文件对象,当前类型: {obj.type}")
# 获取源策略和物理文件
src_policy: Policy = await obj.awaitable_attrs.policy
old_physical: PhysicalFile | None = await obj.awaitable_attrs.physical_file
if not old_physical:
l.warning(f"文件 {obj.id} 没有关联物理文件,跳过迁移")
return
if src_policy.id == dest_policy.id:
l.debug(f"文件 {obj.id} 已在目标策略中,跳过")
return
# 1. 从源存储读取文件
src_service = await _get_storage_service(src_policy)
data = await _read_file_from_storage(src_service, old_physical.storage_path)
# 2. 在目标存储生成新路径并写入
dest_service = await _get_storage_service(dest_policy)
_dir_path, _storage_name, new_storage_path = await dest_service.generate_file_path(
user_id=obj.owner_id,
original_filename=obj.name,
)
await _write_file_to_storage(dest_service, new_storage_path, data)
# 3. 创建新的 PhysicalFile
new_physical = PhysicalFile(
storage_path=new_storage_path,
size=old_physical.size,
checksum_md5=old_physical.checksum_md5,
policy_id=dest_policy.id,
reference_count=1,
)
new_physical = await new_physical.save(session)
# 4. 更新 Object
obj.policy_id = dest_policy.id
obj.physical_file_id = new_physical.id
obj = await obj.save(session)
# 5. 旧 PhysicalFile 引用计数 -1
old_physical.decrement_reference()
if old_physical.can_be_deleted:
# 删除源存储中的物理文件
try:
await _delete_file_from_storage(src_service, old_physical.storage_path)
except Exception as e:
l.warning(f"删除源文件失败(不影响迁移结果): {old_physical.storage_path}: {e}")
await PhysicalFile.delete(session, old_physical)
else:
old_physical = await old_physical.save(session)
l.info(f"文件迁移完成: {obj.name} ({obj.id}), {src_policy.name}{dest_policy.name}")
async def migrate_file_with_task(
session: AsyncSession,
obj: Object,
dest_policy: Policy,
task: Task,
) -> None:
"""
迁移单个文件并更新任务状态
:param session: 数据库会话
:param obj: 待迁移的文件对象
:param dest_policy: 目标存储策略
:param task: 关联的任务记录
"""
try:
task.status = TaskStatus.RUNNING
task.progress = 0
task = await task.save(session)
await migrate_single_file(session, obj, dest_policy)
task.status = TaskStatus.COMPLETED
task.progress = 100
task = await task.save(session)
except Exception as e:
l.error(f"文件迁移任务失败: {obj.id}: {e}")
task.status = TaskStatus.ERROR
task.error = str(e)[:500]
task = await task.save(session)
async def migrate_directory_files(
session: AsyncSession,
folder: Object,
dest_policy: Policy,
task: Task,
) -> None:
"""
迁移目录下所有文件到目标存储策略
递归遍历目录树,将所有文件迁移到目标策略。
子目录的 policy_id 同步更新。
任务进度按文件数比例更新。
:param session: 数据库会话
:param folder: 目录对象
:param dest_policy: 目标存储策略
:param task: 关联的任务记录
"""
try:
task.status = TaskStatus.RUNNING
task.progress = 0
task = await task.save(session)
# 收集所有需要迁移的文件
files_to_migrate: list[Object] = []
folders_to_update: list[Object] = []
await _collect_objects_recursive(session, folder, files_to_migrate, folders_to_update)
total = len(files_to_migrate)
migrated = 0
errors: list[str] = []
for file_obj in files_to_migrate:
try:
await migrate_single_file(session, file_obj, dest_policy)
migrated += 1
except Exception as e:
error_msg = f"{file_obj.name}: {e}"
l.error(f"迁移文件失败: {error_msg}")
errors.append(error_msg)
# 更新进度
if total > 0:
task.progress = min(99, int(migrated / total * 100))
task = await task.save(session)
# 更新所有子目录的 policy_id
for sub_folder in folders_to_update:
sub_folder.policy_id = dest_policy.id
sub_folder = await sub_folder.save(session)
# 完成任务
if errors:
task.status = TaskStatus.ERROR
task.error = f"部分文件迁移失败 ({len(errors)}/{total}): " + "; ".join(errors[:5])
else:
task.status = TaskStatus.COMPLETED
task.progress = 100
task = await task.save(session)
l.info(
f"目录迁移完成: {folder.name} ({folder.id}), "
f"成功 {migrated}/{total}, 错误 {len(errors)}"
)
except Exception as e:
l.error(f"目录迁移任务失败: {folder.id}: {e}")
task.status = TaskStatus.ERROR
task.error = str(e)[:500]
task = await task.save(session)
async def _collect_objects_recursive(
session: AsyncSession,
folder: Object,
files: list[Object],
folders: list[Object],
) -> None:
"""
递归收集目录下所有文件和子目录
:param session: 数据库会话
:param folder: 当前目录
:param files: 文件列表(输出)
:param folders: 子目录列表(输出)
"""
children: list[Object] = await Object.get_children(session, folder.owner_id, folder.id)
for child in children:
if child.type == ObjectType.FILE:
files.append(child)
elif child.type == ObjectType.FOLDER:
folders.append(child)
await _collect_objects_recursive(session, child, files, folders)

View File

@@ -6,7 +6,8 @@ from sqlalchemy import update as sql_update
from sqlalchemy.sql.functions import func from sqlalchemy.sql.functions import func
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from service.storage import LocalStorageService from .local_storage import LocalStorageService
from .s3_storage import S3StorageService
from sqlmodels import ( from sqlmodels import (
Object, Object,
PhysicalFile, PhysicalFile,
@@ -271,10 +272,14 @@ async def permanently_delete_objects(
if physical_file.can_be_deleted: if physical_file.can_be_deleted:
# 物理删除文件 # 物理删除文件
policy = await Policy.get(session, Policy.id == physical_file.policy_id) policy = await Policy.get(session, Policy.id == physical_file.policy_id)
if policy and policy.type == PolicyType.LOCAL: if policy:
try: try:
storage_service = LocalStorageService(policy) if policy.type == PolicyType.LOCAL:
await storage_service.delete_file(physical_file.storage_path) storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
await s3_service.delete_file(physical_file.storage_path)
l.debug(f"物理文件已删除: {obj_name}") l.debug(f"物理文件已删除: {obj_name}")
except Exception as e: except Exception as e:
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}") l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
@@ -282,7 +287,7 @@ async def permanently_delete_objects(
await PhysicalFile.delete(session, physical_file, commit=False) await PhysicalFile.delete(session, physical_file, commit=False)
l.debug(f"物理文件记录已删除: {physical_file.storage_path}") l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
else: else:
await physical_file.save(session, commit=False) physical_file = await physical_file.save(session, commit=False)
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}") l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
# 更新用户存储配额 # 更新用户存储配额
@@ -374,10 +379,19 @@ async def delete_object_recursive(
if physical_file.can_be_deleted: if physical_file.can_be_deleted:
# 物理删除文件 # 物理删除文件
policy = await Policy.get(session, Policy.id == physical_file.policy_id) policy = await Policy.get(session, Policy.id == physical_file.policy_id)
if policy and policy.type == PolicyType.LOCAL: if policy:
try: try:
storage_service = LocalStorageService(policy) if policy.type == PolicyType.LOCAL:
await storage_service.delete_file(physical_file.storage_path) storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
elif policy.type == PolicyType.S3:
options = await policy.awaitable_attrs.options
s3_service = S3StorageService(
policy,
region=options.s3_region if options else 'us-east-1',
is_path_style=options.s3_path_style if options else False,
)
await s3_service.delete_file(physical_file.storage_path)
l.debug(f"物理文件已删除: {obj_name}") l.debug(f"物理文件已删除: {obj_name}")
except Exception as e: except Exception as e:
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}") l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
@@ -385,7 +399,7 @@ async def delete_object_recursive(
await PhysicalFile.delete(session, physical_file, commit=False) await PhysicalFile.delete(session, physical_file, commit=False)
l.debug(f"物理文件记录已删除: {physical_file.storage_path}") l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
else: else:
await physical_file.save(session, commit=False) physical_file = await physical_file.save(session, commit=False)
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}") l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
# 阶段三:更新用户存储配额(与删除在同一事务中) # 阶段三:更新用户存储配额(与删除在同一事务中)
@@ -444,7 +458,7 @@ async def _copy_object_recursive(
physical_file = await PhysicalFile.get(session, PhysicalFile.id == src_physical_file_id) physical_file = await PhysicalFile.get(session, PhysicalFile.id == src_physical_file_id)
if physical_file: if physical_file:
physical_file.increment_reference() physical_file.increment_reference()
await physical_file.save(session) physical_file = await physical_file.save(session)
total_copied_size += src_size total_copied_size += src_size
new_obj = await new_obj.save(session) new_obj = await new_obj.save(session)

View File

@@ -0,0 +1,709 @@
"""
S3 存储服务
使用 AWS Signature V4 签名的异步 S3 API 客户端。
从 Policy 配置中读取 S3 连接信息,提供文件上传/下载/删除及分片上传功能。
移植自 foxline-pro-backend-server 项目的 S3APIClient
适配 DiskNext 现有的 Service 架构(与 LocalStorageService 平行)。
"""
import hashlib
import hmac
import xml.etree.ElementTree as ET
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from typing import ClassVar, Literal
from urllib.parse import quote, urlencode
from uuid import UUID
import aiohttp
from yarl import URL
from loguru import logger as l
from sqlmodels.policy import Policy
from .exceptions import S3APIError, S3MultipartUploadError
from .naming_rule import NamingContext, NamingRuleParser
def _sign(key: bytes, msg: str) -> bytes:
"""HMAC-SHA256 签名"""
return hmac.new(key, msg.encode(), hashlib.sha256).digest()
_NS_AWS = "http://s3.amazonaws.com/doc/2006-03-01/"
class S3StorageService:
"""
S3 存储服务
使用 AWS Signature V4 签名的异步 S3 API 客户端。
从 Policy 配置中读取 S3 连接信息。
使用示例::
service = S3StorageService(policy, region='us-east-1')
await service.upload_file('path/to/file.txt', b'content')
data = await service.download_file('path/to/file.txt')
"""
_http_session: ClassVar[aiohttp.ClientSession | None] = None
def __init__(
self,
policy: Policy,
region: str = 'us-east-1',
is_path_style: bool = False,
):
"""
:param policy: 存储策略server=endpoint_url, bucket_name, access_key, secret_key
:param region: S3 区域
:param is_path_style: 是否使用路径风格 URL
"""
if not policy.server:
raise S3APIError("S3 策略必须指定 server (endpoint URL)")
if not policy.bucket_name:
raise S3APIError("S3 策略必须指定 bucket_name")
if not policy.access_key:
raise S3APIError("S3 策略必须指定 access_key")
if not policy.secret_key:
raise S3APIError("S3 策略必须指定 secret_key")
self._policy = policy
self._endpoint_url = policy.server.rstrip("/")
self._bucket_name = policy.bucket_name
self._access_key = policy.access_key
self._secret_key = policy.secret_key
self._region = region
self._is_path_style = is_path_style
self._base_url = policy.base_url
# 从 endpoint_url 提取 host
self._host = self._endpoint_url.replace("https://", "").replace("http://", "").split("/")[0]
# ==================== 工厂方法 ====================
@classmethod
async def from_policy(cls, policy: Policy) -> 'S3StorageService':
"""
根据 Policy 异步创建 S3StorageService自动加载 options
:param policy: 存储策略
:return: S3StorageService 实例
"""
options = await policy.awaitable_attrs.options
region = options.s3_region if options else 'us-east-1'
is_path_style = options.s3_path_style if options else False
return cls(policy, region=region, is_path_style=is_path_style)
# ==================== HTTP Session 管理 ====================
@classmethod
async def initialize_session(cls) -> None:
"""初始化全局 aiohttp ClientSession"""
if cls._http_session is None or cls._http_session.closed:
cls._http_session = aiohttp.ClientSession()
l.info("S3StorageService HTTP session 已初始化")
@classmethod
async def close_session(cls) -> None:
"""关闭全局 aiohttp ClientSession"""
if cls._http_session and not cls._http_session.closed:
await cls._http_session.close()
cls._http_session = None
l.info("S3StorageService HTTP session 已关闭")
@classmethod
def _get_session(cls) -> aiohttp.ClientSession:
"""获取 HTTP session"""
if cls._http_session is None or cls._http_session.closed:
# 懒初始化,以防 initialize_session 未被调用
cls._http_session = aiohttp.ClientSession()
return cls._http_session
# ==================== AWS Signature V4 签名 ====================
def _get_signature_key(self, date_stamp: str) -> bytes:
"""生成 AWS Signature V4 签名密钥"""
k_date = _sign(f"AWS4{self._secret_key}".encode(), date_stamp)
k_region = _sign(k_date, self._region)
k_service = _sign(k_region, "s3")
return _sign(k_service, "aws4_request")
def _create_authorization_header(
self,
method: str,
uri: str,
query_string: str,
headers: dict[str, str],
payload_hash: str,
amz_date: str,
date_stamp: str,
) -> str:
"""创建 AWS Signature V4 授权头"""
signed_headers = ";".join(sorted(k.lower() for k in headers.keys()))
canonical_headers = "".join(
f"{k.lower()}:{v.strip()}\n" for k, v in sorted(headers.items())
)
canonical_request = (
f"{method}\n{uri}\n{query_string}\n{canonical_headers}\n"
f"{signed_headers}\n{payload_hash}"
)
algorithm = "AWS4-HMAC-SHA256"
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
string_to_sign = (
f"{algorithm}\n{amz_date}\n{credential_scope}\n"
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
)
signing_key = self._get_signature_key(date_stamp)
signature = hmac.new(
signing_key, string_to_sign.encode(), hashlib.sha256
).hexdigest()
return (
f"{algorithm} Credential={self._access_key}/{credential_scope}, "
f"SignedHeaders={signed_headers}, Signature={signature}"
)
def _build_headers(
self,
method: str,
uri: str,
query_string: str = "",
payload: bytes = b"",
content_type: str | None = None,
extra_headers: dict[str, str] | None = None,
payload_hash: str | None = None,
host: str | None = None,
) -> dict[str, str]:
"""
构建包含 AWS V4 签名的完整请求头
:param method: HTTP 方法
:param uri: 请求 URI
:param query_string: 查询字符串
:param payload: 请求体字节(用于计算哈希)
:param content_type: Content-Type
:param extra_headers: 额外请求头
:param payload_hash: 预计算的 payload 哈希,流式上传时传 "UNSIGNED-PAYLOAD"
:param host: Host 头(默认使用 self._host
"""
now_utc = datetime.now(timezone.utc)
amz_date = now_utc.strftime("%Y%m%dT%H%M%SZ")
date_stamp = now_utc.strftime("%Y%m%d")
if payload_hash is None:
payload_hash = hashlib.sha256(payload).hexdigest()
effective_host = host or self._host
headers: dict[str, str] = {
"Host": effective_host,
"X-Amz-Date": amz_date,
"X-Amz-Content-Sha256": payload_hash,
}
if content_type:
headers["Content-Type"] = content_type
if extra_headers:
headers.update(extra_headers)
authorization = self._create_authorization_header(
method, uri, query_string, headers, payload_hash, amz_date, date_stamp
)
headers["Authorization"] = authorization
return headers
# ==================== 内部请求方法 ====================
def _build_uri(self, key: str | None = None) -> str:
"""
构建请求 URI
按 AWS S3 Signature V4 规范对路径进行 URI 编码S3 仅需一次)。
斜杠作为路径分隔符保留不编码。
"""
if self._is_path_style:
if key:
return f"/{self._bucket_name}/{quote(key, safe='/')}"
return f"/{self._bucket_name}"
else:
if key:
return f"/{quote(key, safe='/')}"
return "/"
def _build_url(self, uri: str, query_string: str = "") -> str:
"""构建完整请求 URL"""
if self._is_path_style:
base = self._endpoint_url
else:
# 虚拟主机风格bucket.endpoint
protocol = "https://" if self._endpoint_url.startswith("https://") else "http://"
base = f"{protocol}{self._bucket_name}.{self._host}"
url = f"{base}{uri}"
if query_string:
url = f"{url}?{query_string}"
return url
def _get_effective_host(self) -> str:
"""获取实际请求的 Host 头"""
if self._is_path_style:
return self._host
return f"{self._bucket_name}.{self._host}"
async def _request(
self,
method: str,
key: str | None = None,
query_params: dict[str, str] | None = None,
payload: bytes = b"",
content_type: str | None = None,
extra_headers: dict[str, str] | None = None,
) -> aiohttp.ClientResponse:
"""发送签名请求"""
uri = self._build_uri(key)
query_string = urlencode(sorted(query_params.items())) if query_params else ""
effective_host = self._get_effective_host()
headers = self._build_headers(
method, uri, query_string, payload, content_type,
extra_headers, host=effective_host,
)
url = self._build_url(uri, query_string)
try:
response = await self._get_session().request(
method, URL(url, encoded=True),
headers=headers, data=payload if payload else None,
)
return response
except Exception as e:
raise S3APIError(f"S3 请求失败: {method} {url}: {e}") from e
async def _request_streaming(
self,
method: str,
key: str,
data_stream: AsyncIterator[bytes],
content_length: int,
content_type: str | None = None,
) -> aiohttp.ClientResponse:
"""
发送流式签名请求(大文件上传)
使用 UNSIGNED-PAYLOAD 作为 payload hash。
"""
uri = self._build_uri(key)
effective_host = self._get_effective_host()
headers = self._build_headers(
method,
uri,
query_string="",
content_type=content_type,
extra_headers={"Content-Length": str(content_length)},
payload_hash="UNSIGNED-PAYLOAD",
host=effective_host,
)
url = self._build_url(uri)
try:
response = await self._get_session().request(
method, URL(url, encoded=True),
headers=headers, data=data_stream,
)
return response
except Exception as e:
raise S3APIError(f"S3 流式请求失败: {method} {url}: {e}") from e
# ==================== 文件操作 ====================
async def upload_file(
self,
key: str,
data: bytes,
content_type: str = 'application/octet-stream',
) -> None:
"""
上传文件
:param key: S3 对象键
:param data: 文件内容
:param content_type: MIME 类型
"""
async with await self._request(
"PUT", key=key, payload=data, content_type=content_type,
) as response:
if response.status not in (200, 201):
body = await response.text()
raise S3APIError(
f"上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
l.debug(f"S3 上传成功: {self._bucket_name}/{key}")
async def upload_file_streaming(
self,
key: str,
data_stream: AsyncIterator[bytes],
content_length: int,
content_type: str | None = None,
) -> None:
"""
流式上传文件(大文件,避免全部加载到内存)
:param key: S3 对象键
:param data_stream: 异步字节流迭代器
:param content_length: 数据总长度(必须准确)
:param content_type: MIME 类型
"""
async with await self._request_streaming(
"PUT", key=key, data_stream=data_stream,
content_length=content_length, content_type=content_type,
) as response:
if response.status not in (200, 201):
body = await response.text()
raise S3APIError(
f"流式上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
l.debug(f"S3 流式上传成功: {self._bucket_name}/{key}, 大小: {content_length}")
async def download_file(self, key: str) -> bytes:
"""
下载文件
:param key: S3 对象键
:return: 文件内容
"""
async with await self._request("GET", key=key) as response:
if response.status != 200:
body = await response.text()
raise S3APIError(
f"下载失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
data = await response.read()
l.debug(f"S3 下载成功: {self._bucket_name}/{key}, 大小: {len(data)}")
return data
async def delete_file(self, key: str) -> None:
"""
删除文件
:param key: S3 对象键
"""
async with await self._request("DELETE", key=key) as response:
if response.status in (200, 204):
l.debug(f"S3 删除成功: {self._bucket_name}/{key}")
else:
body = await response.text()
raise S3APIError(
f"删除失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
async def file_exists(self, key: str) -> bool:
"""
检查文件是否存在
:param key: S3 对象键
:return: 是否存在
"""
async with await self._request("HEAD", key=key) as response:
if response.status == 200:
return True
elif response.status == 404:
return False
else:
raise S3APIError(
f"检查文件存在性失败: {self._bucket_name}/{key}, 状态: {response.status}"
)
async def get_file_size(self, key: str) -> int:
"""
获取文件大小
:param key: S3 对象键
:return: 文件大小(字节)
"""
async with await self._request("HEAD", key=key) as response:
if response.status != 200:
raise S3APIError(
f"获取文件信息失败: {self._bucket_name}/{key}, 状态: {response.status}"
)
return int(response.headers.get("Content-Length", 0))
# ==================== Multipart Upload ====================
async def create_multipart_upload(
self,
key: str,
content_type: str = 'application/octet-stream',
) -> str:
"""
创建分片上传任务
:param key: S3 对象键
:param content_type: MIME 类型
:return: Upload ID
"""
async with await self._request(
"POST",
key=key,
query_params={"uploads": ""},
content_type=content_type,
) as response:
if response.status != 200:
body = await response.text()
raise S3MultipartUploadError(
f"创建分片上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
body = await response.text()
root = ET.fromstring(body)
# 查找 UploadId 元素(支持命名空间)
upload_id_elem = root.find("UploadId")
if upload_id_elem is None:
upload_id_elem = root.find(f"{{{_NS_AWS}}}UploadId")
if upload_id_elem is None or not upload_id_elem.text:
raise S3MultipartUploadError(
f"创建分片上传响应中未找到 UploadId: {body}"
)
upload_id = upload_id_elem.text
l.debug(f"S3 分片上传已创建: {self._bucket_name}/{key}, upload_id={upload_id}")
return upload_id
async def upload_part(
self,
key: str,
upload_id: str,
part_number: int,
data: bytes,
) -> str:
"""
上传单个分片
:param key: S3 对象键
:param upload_id: 分片上传 ID
:param part_number: 分片编号(从 1 开始)
:param data: 分片数据
:return: ETag
"""
async with await self._request(
"PUT",
key=key,
query_params={
"partNumber": str(part_number),
"uploadId": upload_id,
},
payload=data,
) as response:
if response.status != 200:
body = await response.text()
raise S3MultipartUploadError(
f"上传分片失败: {self._bucket_name}/{key}, "
f"part={part_number}, 状态: {response.status}, {body}"
)
etag = response.headers.get("ETag", "").strip('"')
l.debug(
f"S3 分片上传成功: {self._bucket_name}/{key}, "
f"part={part_number}, etag={etag}"
)
return etag
async def complete_multipart_upload(
self,
key: str,
upload_id: str,
parts: list[tuple[int, str]],
) -> None:
"""
完成分片上传
:param key: S3 对象键
:param upload_id: 分片上传 ID
:param parts: 分片列表 [(part_number, etag)]
"""
# 按 part_number 排序
parts_sorted = sorted(parts, key=lambda p: p[0])
# 构建 CompleteMultipartUpload XML
xml_parts = ''.join(
f"<Part><PartNumber>{pn}</PartNumber><ETag>{etag}</ETag></Part>"
for pn, etag in parts_sorted
)
payload = f'<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUpload>{xml_parts}</CompleteMultipartUpload>'
payload_bytes = payload.encode('utf-8')
async with await self._request(
"POST",
key=key,
query_params={"uploadId": upload_id},
payload=payload_bytes,
content_type="application/xml",
) as response:
if response.status != 200:
body = await response.text()
raise S3MultipartUploadError(
f"完成分片上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
l.info(
f"S3 分片上传已完成: {self._bucket_name}/{key}, "
f"{len(parts)} 个分片"
)
async def abort_multipart_upload(self, key: str, upload_id: str) -> None:
"""
取消分片上传
:param key: S3 对象键
:param upload_id: 分片上传 ID
"""
async with await self._request(
"DELETE",
key=key,
query_params={"uploadId": upload_id},
) as response:
if response.status in (200, 204):
l.debug(f"S3 分片上传已取消: {self._bucket_name}/{key}")
else:
body = await response.text()
l.warning(
f"取消分片上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
# ==================== 预签名 URL ====================
def generate_presigned_url(
self,
key: str,
method: Literal['GET', 'PUT'] = 'GET',
expires_in: int = 3600,
filename: str | None = None,
) -> str:
"""
生成 S3 预签名 URLAWS Signature V4 Query String
:param key: S3 对象键
:param method: HTTP 方法GET 下载PUT 上传)
:param expires_in: URL 有效期(秒)
:param filename: 文件名GET 请求时设置 Content-Disposition
:return: 预签名 URL
"""
current_time = datetime.now(timezone.utc)
amz_date = current_time.strftime("%Y%m%dT%H%M%SZ")
date_stamp = current_time.strftime("%Y%m%d")
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
credential = f"{self._access_key}/{credential_scope}"
uri = self._build_uri(key)
effective_host = self._get_effective_host()
query_params: dict[str, str] = {
'X-Amz-Algorithm': 'AWS4-HMAC-SHA256',
'X-Amz-Credential': credential,
'X-Amz-Date': amz_date,
'X-Amz-Expires': str(expires_in),
'X-Amz-SignedHeaders': 'host',
}
# GET 请求时添加 Content-Disposition
if method == "GET" and filename:
encoded_filename = quote(filename, safe='')
query_params['response-content-disposition'] = (
f"attachment; filename*=UTF-8''{encoded_filename}"
)
canonical_query_string = "&".join(
f"{quote(k, safe='')}={quote(v, safe='')}"
for k, v in sorted(query_params.items())
)
canonical_headers = f"host:{effective_host}\n"
signed_headers = "host"
payload_hash = "UNSIGNED-PAYLOAD"
canonical_request = (
f"{method}\n"
f"{uri}\n"
f"{canonical_query_string}\n"
f"{canonical_headers}\n"
f"{signed_headers}\n"
f"{payload_hash}"
)
algorithm = "AWS4-HMAC-SHA256"
string_to_sign = (
f"{algorithm}\n"
f"{amz_date}\n"
f"{credential_scope}\n"
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
)
signing_key = self._get_signature_key(date_stamp)
signature = hmac.new(
signing_key, string_to_sign.encode(), hashlib.sha256
).hexdigest()
base_url = self._build_url(uri)
return (
f"{base_url}?"
f"{canonical_query_string}&"
f"X-Amz-Signature={signature}"
)
# ==================== 路径生成 ====================
async def generate_file_path(
self,
user_id: UUID,
original_filename: str,
) -> tuple[str, str, str]:
"""
根据命名规则生成 S3 文件存储路径
与 LocalStorageService.generate_file_path 接口一致。
:param user_id: 用户UUID
:param original_filename: 原始文件名
:return: (相对目录路径, 存储文件名, 完整存储路径)
"""
context = NamingContext(
user_id=user_id,
original_filename=original_filename,
)
# 解析目录规则
dir_path = ""
if self._policy.dir_name_rule:
dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context)
# 解析文件名规则
if self._policy.auto_rename and self._policy.file_name_rule:
storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context)
# 确保有扩展名
if '.' in original_filename and '.' not in storage_name:
ext = original_filename.rsplit('.', 1)[1]
storage_name = f"{storage_name}.{ext}"
else:
storage_name = original_filename
# S3 不需要创建目录,直接拼接路径
if dir_path:
storage_path = f"{dir_path}/{storage_name}"
else:
storage_path = storage_name
return dir_path, storage_name, storage_path

View File

@@ -3,12 +3,14 @@
支持多种认证方式邮箱密码、GitHub OAuth、QQ OAuth、Passkey、Magic Link、手机短信预留 支持多种认证方式邮箱密码、GitHub OAuth、QQ OAuth、Passkey、Magic Link、手机短信预留
""" """
import hashlib
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from loguru import logger as l from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from service.redis.token_store import TokenStore
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.group import GroupClaims, GroupOptions from sqlmodels.group import GroupClaims, GroupOptions
from sqlmodels.object import Object, ObjectType from sqlmodels.object import Object, ObjectType
@@ -190,7 +192,7 @@ async def _login_oauth(
# 已绑定 → 更新 OAuth 信息并返回关联用户 # 已绑定 → 更新 OAuth 信息并返回关联用户
identity.display_name = nickname identity.display_name = nickname
identity.avatar_url = avatar_url identity.avatar_url = avatar_url
await identity.save(session) identity = await identity.save(session)
user: User = await User.get(session, User.id == identity.user_id, load=User.group) user: User = await User.get(session, User.id == identity.user_id, load=User.group)
if not user: if not user:
@@ -252,7 +254,7 @@ async def _auto_register_oauth_user(
is_verified=True, is_verified=True,
user_id=new_user_id, user_id=new_user_id,
) )
await identity.save(session) identity = await identity.save(session)
# 创建用户根目录 # 创建用户根目录
default_policy = await Policy.get(session, Policy.name == "本地存储") default_policy = await Policy.get(session, Policy.name == "本地存储")
@@ -333,7 +335,7 @@ async def _login_passkey(
# 更新签名计数 # 更新签名计数
authn.sign_count = verification.new_sign_count authn.sign_count = verification.new_sign_count
await authn.save(session) authn = await authn.save(session)
# 加载用户 # 加载用户
user: User = await User.get(session, User.id == authn.user_id, load=User.group) user: User = await User.get(session, User.id == authn.user_id, load=User.group)
@@ -363,6 +365,12 @@ async def _login_magic_link(
except BadSignature: except BadSignature:
http_exceptions.raise_unauthorized("Magic Link 无效") http_exceptions.raise_unauthorized("Magic Link 无效")
# 防重放:使用 token 哈希作为标识符
token_hash = hashlib.sha256(request.identifier.encode()).hexdigest()
is_first_use = await TokenStore.mark_used(f"magic_link:{token_hash}", ttl=600)
if not is_first_use:
http_exceptions.raise_unauthorized("Magic Link 已被使用")
# 查找绑定了该邮箱的 AuthIdentityemail_password 或 magic_link # 查找绑定了该邮箱的 AuthIdentityemail_password 或 magic_link
identity: AuthIdentity | None = await AuthIdentity.get( identity: AuthIdentity | None = await AuthIdentity.get(
session, session,
@@ -384,7 +392,7 @@ async def _login_magic_link(
# 标记邮箱已验证 # 标记邮箱已验证
if not identity.is_verified: if not identity.is_verified:
identity.is_verified = True identity.is_verified = True
await identity.save(session) identity = await identity.save(session)
return user return user

185
service/wopi/__init__.py Normal file
View File

@@ -0,0 +1,185 @@
"""
WOPI Discovery 服务模块
解析 WOPI 服务端Collabora / OnlyOffice 等)的 Discovery XML
提取支持的文件扩展名及对应的编辑器 URL 模板。
参考Cloudreve pkg/wopi/discovery.go 和 pkg/wopi/wopi.go
"""
import xml.etree.ElementTree as ET
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from loguru import logger as l
# WOPI URL 模板中已知的查询参数占位符及其替换值
# 值为 None 表示删除该参数,非 None 表示替换为该值
# 参考 Cloudreve pkg/wopi/wopi.go queryPlaceholders
_WOPI_QUERY_PLACEHOLDERS: dict[str, str | None] = {
'BUSINESS_USER': None,
'DC_LLCC': 'lng',
'DISABLE_ASYNC': None,
'DISABLE_CHAT': None,
'EMBEDDED': 'true',
'FULLSCREEN': 'true',
'HOST_SESSION_ID': None,
'SESSION_CONTEXT': None,
'RECORDING': None,
'THEME_ID': 'darkmode',
'UI_LLCC': 'lng',
'VALIDATOR_TEST_CATEGORY': None,
}
_WOPI_SRC_PLACEHOLDER = 'WOPI_SOURCE'
def process_wopi_action_url(raw_urlsrc: str) -> str:
"""
将 WOPI Discovery 中的原始 urlsrc 转换为 DiskNext 可用的 URL 模板。
处理流程(参考 Cloudreve generateActionUrl
1. 去除 ``<>`` 占位符标记
2. 解析查询参数,替换/删除已知占位符
3. ``WOPI_SOURCE`` → ``{wopi_src}``
注意access_token 和 access_token_ttl 不放在 URL 中,
根据 WOPI 规范它们通过 POST 表单字段传递给编辑器。
:param raw_urlsrc: WOPI Discovery XML 中的 urlsrc 原始值
:return: 处理后的 URL 模板字符串,包含 {wopi_src} 占位符
"""
# 去除 <> 标记
cleaned = raw_urlsrc.replace('<', '').replace('>', '')
parsed = urlparse(cleaned)
raw_params = parse_qs(parsed.query, keep_blank_values=True)
new_params: list[tuple[str, str]] = []
is_src_replaced = False
for key, values in raw_params.items():
value = values[0] if values else ''
# WOPI_SOURCE 占位符 → {wopi_src}
if value == _WOPI_SRC_PLACEHOLDER:
new_params.append((key, '{wopi_src}'))
is_src_replaced = True
continue
# 已知占位符
if value in _WOPI_QUERY_PLACEHOLDERS:
replacement = _WOPI_QUERY_PLACEHOLDERS[value]
if replacement is not None:
new_params.append((key, replacement))
# replacement 为 None 时删除该参数
continue
# 其他参数保留原值
new_params.append((key, value))
# 如果没有找到 WOPI_SOURCE 占位符,手动添加 WOPISrc
if not is_src_replaced:
new_params.append(('WOPISrc', '{wopi_src}'))
# LibreOffice/Collabora 需要 lang 参数(避免重复添加)
existing_keys = {k for k, _ in new_params}
if 'lang' not in existing_keys:
new_params.append(('lang', 'lng'))
# 注意access_token 和 access_token_ttl 不放在 URL 中
# 根据 WOPI 规范,它们通过 POST 表单字段传递给编辑器
# 重建 URL
new_query = urlencode(new_params, safe='{}')
result = urlunparse((
parsed.scheme,
parsed.netloc,
parsed.path,
parsed.params,
new_query,
'',
))
return result
def parse_wopi_discovery_xml(xml_content: str) -> tuple[dict[str, str], list[str]]:
"""
解析 WOPI Discovery XML提取扩展名到 URL 模板的映射。
XML 结构::
<wopi-discovery>
<net-zone name="external-https">
<app name="Writer" favIconUrl="...">
<action name="edit" ext="docx" urlsrc="https://..."/>
<action name="view" ext="docx" urlsrc="https://..."/>
</app>
</net-zone>
</wopi-discovery>
动作优先级edit > embedview > view参考 Cloudreve discovery.go
:param xml_content: WOPI Discovery 端点返回的 XML 字符串
:return: (action_urls, app_names) 元组
action_urls: {extension: processed_url_template}
app_names: 发现的应用名称列表
:raises ValueError: XML 解析失败或格式无效
"""
try:
root = ET.fromstring(xml_content)
except ET.ParseError as e:
raise ValueError(f"WOPI Discovery XML 解析失败: {e}")
# 查找 net-zone可能有多个取第一个非空的
net_zones = root.findall('net-zone')
if not net_zones:
raise ValueError("WOPI Discovery XML 缺少 net-zone 节点")
# ext_actions: {extension: {action_name: urlsrc}}
ext_actions: dict[str, dict[str, str]] = {}
app_names: list[str] = []
for net_zone in net_zones:
for app_elem in net_zone.findall('app'):
app_name = app_elem.get('name', '')
if app_name:
app_names.append(app_name)
for action_elem in app_elem.findall('action'):
action_name = action_elem.get('name', '')
ext = action_elem.get('ext', '')
urlsrc = action_elem.get('urlsrc', '')
if not ext or not urlsrc:
continue
# 只关注 edit / embedview / view 三种动作
if action_name not in ('edit', 'embedview', 'view'):
continue
if ext not in ext_actions:
ext_actions[ext] = {}
ext_actions[ext][action_name] = urlsrc
# 为每个扩展名选择最佳 URL: edit > embedview > view
action_urls: dict[str, str] = {}
for ext, actions_map in ext_actions.items():
selected_urlsrc: str | None = None
for preferred in ('edit', 'embedview', 'view'):
if preferred in actions_map:
selected_urlsrc = actions_map[preferred]
break
if selected_urlsrc:
action_urls[ext] = process_wopi_action_url(selected_urlsrc)
# 去重 app_names
seen: set[str] = set()
unique_names: list[str] = []
for name in app_names:
if name not in seen:
seen.add(name)
unique_names.append(name)
l.info(f"WOPI Discovery 解析完成: {len(action_urls)} 个扩展名, 应用: {unique_names}")
return action_urls, unique_names

View File

@@ -84,6 +84,7 @@ if __name__ == "__main__":
setup( setup(
name="disknext-ee", name="disknext-ee",
packages=[],
ext_modules=cythonize( ext_modules=cythonize(
extensions, extensions,
compiler_directives={'language_level': "3"}, compiler_directives={'language_level': "3"},

View File

@@ -954,18 +954,11 @@ class PolicyType(StrEnum):
S3 = "s3" # S3 兼容存储 S3 = "s3" # S3 兼容存储
``` ```
### StorageType ### PolicyType
```python ```python
class StorageType(StrEnum): class PolicyType(StrEnum):
LOCAL = "local" # 本地存储 LOCAL = "local" # 本地存储
QINIU = "qiniu" # 七牛云 S3 = "s3" # S3 兼容存储
TENCENT = "tencent" # 腾讯云
ALIYUN = "aliyun" # 阿里云
ONEDRIVE = "onedrive" # OneDrive
GOOGLE_DRIVE = "google_drive" # Google Drive
DROPBOX = "dropbox" # Dropbox
WEBDAV = "webdav" # WebDAV
REMOTE = "remote" # 远程存储
``` ```
### UserStatus ### UserStatus

View File

@@ -69,18 +69,20 @@ from .object import (
CreateUploadSessionRequest, CreateUploadSessionRequest,
DirectoryCreateRequest, DirectoryCreateRequest,
DirectoryResponse, DirectoryResponse,
FileMetadata,
FileMetadataBase,
Object, Object,
ObjectBase, ObjectBase,
ObjectCopyRequest, ObjectCopyRequest,
ObjectDeleteRequest, ObjectDeleteRequest,
ObjectFileFinalize,
ObjectMoveRequest, ObjectMoveRequest,
ObjectMoveUpdate,
ObjectPropertyDetailResponse, ObjectPropertyDetailResponse,
ObjectPropertyResponse, ObjectPropertyResponse,
ObjectRenameRequest, ObjectRenameRequest,
ObjectResponse, ObjectResponse,
ObjectSwitchPolicyRequest,
ObjectType, ObjectType,
FileCategory,
PolicyResponse, PolicyResponse,
UploadChunkResponse, UploadChunkResponse,
UploadSession, UploadSession,
@@ -95,11 +97,42 @@ from .object import (
TrashRestoreRequest, TrashRestoreRequest,
TrashDeleteRequest, TrashDeleteRequest,
) )
from .object_metadata import (
ObjectMetadata,
ObjectMetadataBase,
MetadataNamespace,
MetadataResponse,
MetadataPatchItem,
MetadataPatchRequest,
INTERNAL_NAMESPACES,
USER_WRITABLE_NAMESPACES,
)
from .custom_property import (
CustomPropertyDefinition,
CustomPropertyDefinitionBase,
CustomPropertyType,
CustomPropertyCreateRequest,
CustomPropertyUpdateRequest,
CustomPropertyResponse,
)
from .physical_file import PhysicalFile, PhysicalFileBase from .physical_file import PhysicalFile, PhysicalFileBase
from .uri import DiskNextURI, FileSystemNamespace from .uri import DiskNextURI, FileSystemNamespace
from .order import Order, OrderStatus, OrderType from .order import (
from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary Order, OrderStatus, OrderType,
from .redeem import Redeem, RedeemType CreateOrderRequest, OrderResponse,
)
from .policy import (
Policy, PolicyBase, PolicyCreateRequest, PolicyOptions, PolicyOptionsBase,
PolicyType, PolicySummary, PolicyUpdateRequest,
)
from .product import (
Product, ProductBase, ProductType, PaymentMethod,
ProductCreateRequest, ProductUpdateRequest, ProductResponse,
)
from .redeem import (
Redeem, RedeemType,
RedeemCreateRequest, RedeemUseRequest, RedeemInfoResponse, RedeemAdminResponse,
)
from .report import Report, ReportReason from .report import Report, ReportReason
from .setting import ( from .setting import (
Setting, SettingsType, SiteConfigResponse, AuthMethodConfig, Setting, SettingsType, SiteConfigResponse, AuthMethodConfig,
@@ -112,16 +145,20 @@ from .share import (
AdminShareListItem, AdminShareListItem,
) )
from .source_link import SourceLink from .source_link import SourceLink
from .storage_pack import StoragePack from .storage_pack import StoragePack, StoragePackResponse
from .tag import Tag, TagType from .tag import Tag, TagType
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary, TaskSummaryBase
from .webdav import WebDAV from .webdav import (
WebDAV, WebDAVBase,
WebDAVCreateRequest, WebDAVUpdateRequest, WebDAVAccountResponse,
)
from .file_app import ( from .file_app import (
FileApp, FileAppType, FileAppExtension, FileAppGroupLink, UserFileAppDefault, FileApp, FileAppType, FileAppExtension, FileAppGroupLink, UserFileAppDefault,
# DTO # DTO
FileAppSummary, FileViewersResponse, SetDefaultViewerRequest, UserFileAppDefaultResponse, FileAppSummary, FileViewersResponse, SetDefaultViewerRequest, UserFileAppDefaultResponse,
FileAppCreateRequest, FileAppUpdateRequest, FileAppResponse, FileAppListResponse, FileAppCreateRequest, FileAppUpdateRequest, FileAppResponse, FileAppListResponse,
ExtensionUpdateRequest, GroupAccessUpdateRequest, WopiSessionResponse, ExtensionUpdateRequest, GroupAccessUpdateRequest, WopiSessionResponse,
WopiDiscoveredExtension, WopiDiscoveryResponse,
) )
from .wopi import WopiFileInfo, WopiAccessTokenPayload from .wopi import WopiFileInfo, WopiAccessTokenPayload

View File

@@ -10,7 +10,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint from sqlmodel import Field, Relationship, UniqueConstraint
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100, Str128, Str255, Text1024
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import User from .user import User
@@ -87,7 +87,7 @@ class ChangePasswordRequest(SQLModelBase):
old_password: str = Field(min_length=1) old_password: str = Field(min_length=1)
"""当前密码""" """当前密码"""
new_password: str = Field(min_length=8, max_length=128) new_password: Str128 = Field(min_length=8)
"""新密码(至少 8 位)""" """新密码(至少 8 位)"""
@@ -103,13 +103,13 @@ class AuthIdentity(SQLModelBase, UUIDTableBaseMixin):
provider: AuthProviderType = Field(index=True) provider: AuthProviderType = Field(index=True)
"""提供者类型""" """提供者类型"""
identifier: str = Field(max_length=255, index=True) identifier: Str255 = Field(index=True)
"""标识符(邮箱/手机号/OAuth openid""" """标识符(邮箱/手机号/OAuth openid"""
credential: str | None = Field(default=None, max_length=1024) credential: Text1024 | None = None
"""凭证Argon2 哈希密码 / null""" """凭证Argon2 哈希密码 / null"""
display_name: str | None = Field(default=None, max_length=100) display_name: Str100 | None = None
"""OAuth 昵称""" """OAuth 昵称"""
avatar_url: str | None = Field(default=None, max_length=512) avatar_url: str | None = Field(default=None, max_length=512)

View File

@@ -0,0 +1,135 @@
"""
用户自定义属性定义模型
允许用户定义类型化的自定义属性模板(如标签、评分、分类等),
实际值通过 ObjectMetadata KV 表存储键名格式custom:{property_definition_id}
支持的属性类型text, number, boolean, select, multi_select, rating, link
"""
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import JSON
from sqlmodel import Field, Relationship
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100
if TYPE_CHECKING:
from .user import User
# ==================== 枚举 ====================
class CustomPropertyType(StrEnum):
"""自定义属性值类型枚举"""
TEXT = "text"
"""文本"""
NUMBER = "number"
"""数字"""
BOOLEAN = "boolean"
"""布尔值"""
SELECT = "select"
"""单选"""
MULTI_SELECT = "multi_select"
"""多选"""
RATING = "rating"
"""评分1-5"""
LINK = "link"
"""链接"""
# ==================== Base 模型 ====================
class CustomPropertyDefinitionBase(SQLModelBase):
"""自定义属性定义基础模型"""
name: Str100
"""属性显示名称"""
type: CustomPropertyType
"""属性值类型"""
icon: Str100 | None = None
"""图标标识iconify 名称)"""
options: list[str] | None = Field(default=None, sa_type=JSON)
"""可选值列表(仅 select/multi_select 类型)"""
default_value: str | None = Field(default=None, max_length=500)
"""默认值"""
# ==================== 数据库模型 ====================
class CustomPropertyDefinition(CustomPropertyDefinitionBase, UUIDTableBaseMixin):
"""
用户自定义属性定义
每个用户独立管理自己的属性模板。
实际属性值存储在 ObjectMetadata 表中键名格式custom:{id}
"""
owner_id: UUID = Field(
foreign_key="user.id",
ondelete="CASCADE",
index=True,
)
"""所有者用户UUID"""
sort_order: int = 0
"""排序顺序"""
# 关系
owner: "User" = Relationship()
"""所有者"""
# ==================== DTO 模型 ====================
class CustomPropertyCreateRequest(SQLModelBase):
"""创建自定义属性请求 DTO"""
name: Str100
"""属性显示名称"""
type: CustomPropertyType
"""属性值类型"""
icon: str | None = None
"""图标标识"""
options: list[str] | None = None
"""可选值列表(仅 select/multi_select 类型)"""
default_value: str | None = None
"""默认值"""
class CustomPropertyUpdateRequest(SQLModelBase):
"""更新自定义属性请求 DTO"""
name: str | None = None
"""属性显示名称"""
icon: str | None = None
"""图标标识"""
options: list[str] | None = None
"""可选值列表"""
default_value: str | None = None
"""默认值"""
sort_order: int | None = None
"""排序顺序"""
class CustomPropertyResponse(CustomPropertyDefinitionBase):
"""自定义属性响应 DTO"""
id: UUID
"""属性定义UUID"""
sort_order: int
"""排序顺序"""

View File

@@ -4,7 +4,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint, Index from sqlmodel import Field, Relationship, UniqueConstraint, Index
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableBaseMixin from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableBaseMixin, Str255
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import User from .user import User
@@ -141,7 +141,7 @@ class Download(DownloadBase, UUIDTableBaseMixin):
speed: int = Field(default=0) speed: int = Field(default=0)
"""下载速度bytes/s""" """下载速度bytes/s"""
parent: str | None = Field(default=None, max_length=255) parent: Str255 | None = None
"""父任务标识""" """父任务标识"""
error: str | None = Field(default=None) error: str | None = Field(default=None)

View File

@@ -20,7 +20,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint from sqlmodel import Field, Relationship, UniqueConstraint
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str100, Str255, Text1024
if TYPE_CHECKING: if TYPE_CHECKING:
from .group import Group from .group import Group
@@ -119,7 +119,7 @@ class UserFileAppDefaultResponse(SQLModelBase):
class FileAppCreateRequest(SQLModelBase): class FileAppCreateRequest(SQLModelBase):
"""管理员创建应用请求 DTO""" """管理员创建应用请求 DTO"""
name: str = Field(max_length=100) name: Str100
"""应用名称""" """应用名称"""
app_key: str = Field(max_length=50) app_key: str = Field(max_length=50)
@@ -128,7 +128,7 @@ class FileAppCreateRequest(SQLModelBase):
type: FileAppType type: FileAppType
"""应用类型""" """应用类型"""
icon: str | None = Field(default=None, max_length=255) icon: Str255 | None = None
"""图标名称/URL""" """图标名称/URL"""
description: str | None = Field(default=None, max_length=500) description: str | None = Field(default=None, max_length=500)
@@ -140,13 +140,13 @@ class FileAppCreateRequest(SQLModelBase):
is_restricted: bool = False is_restricted: bool = False
"""是否限制用户组访问""" """是否限制用户组访问"""
iframe_url_template: str | None = Field(default=None, max_length=1024) iframe_url_template: Text1024 | None = None
"""iframe URL 模板""" """iframe URL 模板"""
wopi_discovery_url: str | None = Field(default=None, max_length=512) wopi_discovery_url: str | None = Field(default=None, max_length=512)
"""WOPI 发现端点 URL""" """WOPI 发现端点 URL"""
wopi_editor_url_template: str | None = Field(default=None, max_length=1024) wopi_editor_url_template: Text1024 | None = None
"""WOPI 编辑器 URL 模板""" """WOPI 编辑器 URL 模板"""
extensions: list[str] = [] extensions: list[str] = []
@@ -159,7 +159,7 @@ class FileAppCreateRequest(SQLModelBase):
class FileAppUpdateRequest(SQLModelBase): class FileAppUpdateRequest(SQLModelBase):
"""管理员更新应用请求 DTO所有字段可选""" """管理员更新应用请求 DTO所有字段可选"""
name: str | None = Field(default=None, max_length=100) name: Str100 | None = None
"""应用名称""" """应用名称"""
app_key: str | None = Field(default=None, max_length=50) app_key: str | None = Field(default=None, max_length=50)
@@ -168,7 +168,7 @@ class FileAppUpdateRequest(SQLModelBase):
type: FileAppType | None = None type: FileAppType | None = None
"""应用类型""" """应用类型"""
icon: str | None = Field(default=None, max_length=255) icon: Str255 | None = None
"""图标名称/URL""" """图标名称/URL"""
description: str | None = Field(default=None, max_length=500) description: str | None = Field(default=None, max_length=500)
@@ -180,13 +180,13 @@ class FileAppUpdateRequest(SQLModelBase):
is_restricted: bool | None = None is_restricted: bool | None = None
"""是否限制用户组访问""" """是否限制用户组访问"""
iframe_url_template: str | None = Field(default=None, max_length=1024) iframe_url_template: Text1024 | None = None
"""iframe URL 模板""" """iframe URL 模板"""
wopi_discovery_url: str | None = Field(default=None, max_length=512) wopi_discovery_url: str | None = Field(default=None, max_length=512)
"""WOPI 发现端点 URL""" """WOPI 发现端点 URL"""
wopi_editor_url_template: str | None = Field(default=None, max_length=1024) wopi_editor_url_template: Text1024 | None = None
"""WOPI 编辑器 URL 模板""" """WOPI 编辑器 URL 模板"""
@@ -297,12 +297,35 @@ class WopiSessionResponse(SQLModelBase):
"""完整的编辑器 URL""" """完整的编辑器 URL"""
class WopiDiscoveredExtension(SQLModelBase):
"""单个 WOPI Discovery 发现的扩展名"""
extension: str
"""文件扩展名"""
action_url: str
"""处理后的动作 URL 模板"""
class WopiDiscoveryResponse(SQLModelBase):
"""WOPI Discovery 结果响应 DTO"""
discovered_extensions: list[WopiDiscoveredExtension] = []
"""发现的扩展名及其 URL 模板"""
app_names: list[str] = []
"""WOPI 服务端报告的应用名称(如 Writer、Calc、Impress"""
applied_count: int = 0
"""已应用到 FileAppExtension 的数量"""
# ==================== 数据库模型 ==================== # ==================== 数据库模型 ====================
class FileApp(SQLModelBase, UUIDTableBaseMixin): class FileApp(SQLModelBase, UUIDTableBaseMixin):
"""文件查看器应用注册表""" """文件查看器应用注册表"""
name: str = Field(max_length=100) name: Str100
"""应用名称""" """应用名称"""
app_key: str = Field(max_length=50, unique=True, index=True) app_key: str = Field(max_length=50, unique=True, index=True)
@@ -311,7 +334,7 @@ class FileApp(SQLModelBase, UUIDTableBaseMixin):
type: FileAppType type: FileAppType
"""应用类型""" """应用类型"""
icon: str | None = Field(default=None, max_length=255) icon: Str255 | None = None
"""图标名称/URL""" """图标名称/URL"""
description: str | None = Field(default=None, max_length=500) description: str | None = Field(default=None, max_length=500)
@@ -323,13 +346,13 @@ class FileApp(SQLModelBase, UUIDTableBaseMixin):
is_restricted: bool = False is_restricted: bool = False
"""是否限制用户组访问""" """是否限制用户组访问"""
iframe_url_template: str | None = Field(default=None, max_length=1024) iframe_url_template: Text1024 | None = None
"""iframe URL 模板,支持 {file_url} 占位符""" """iframe URL 模板,支持 {file_url} 占位符"""
wopi_discovery_url: str | None = Field(default=None, max_length=512) wopi_discovery_url: str | None = Field(default=None, max_length=512)
"""WOPI 客户端发现端点 URL""" """WOPI 客户端发现端点 URL"""
wopi_editor_url_template: str | None = Field(default=None, max_length=1024) wopi_editor_url_template: Text1024 | None = None
"""WOPI 编辑器 URL 模板,支持 {wopi_src} {access_token} {access_token_ttl}""" """WOPI 编辑器 URL 模板,支持 {wopi_src} {access_token} {access_token_ttl}"""
# 关系 # 关系
@@ -377,6 +400,9 @@ class FileAppExtension(SQLModelBase, TableBaseMixin):
priority: int = Field(default=0, ge=0) priority: int = Field(default=0, ge=0)
"""排序优先级(越小越优先)""" """排序优先级(越小越优先)"""
wopi_action_url: str | None = Field(default=None, max_length=2048)
"""WOPI 动作 URL 模板Discovery 自动填充),支持 {wopi_src} {access_token} {access_token_ttl}"""
# 关系 # 关系
app: FileApp = Relationship(back_populates="extensions") app: FileApp = Relationship(back_populates="extensions")

View File

@@ -2,9 +2,10 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, text from sqlmodel import Field, Relationship, text
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str255
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import User from .user import User
@@ -66,7 +67,7 @@ class GroupAllOptionsBase(GroupOptionsBase):
class GroupCreateRequest(GroupAllOptionsBase): class GroupCreateRequest(GroupAllOptionsBase):
"""创建用户组请求 DTO""" """创建用户组请求 DTO"""
name: str = Field(max_length=255) name: Str255
"""用户组名称""" """用户组名称"""
max_storage: int = Field(default=0, ge=0) max_storage: int = Field(default=0, ge=0)
@@ -91,7 +92,7 @@ class GroupCreateRequest(GroupAllOptionsBase):
class GroupUpdateRequest(SQLModelBase): class GroupUpdateRequest(SQLModelBase):
"""更新用户组请求 DTO所有字段可选""" """更新用户组请求 DTO所有字段可选"""
name: str | None = Field(default=None, max_length=255) name: Str255 | None = None
"""用户组名称""" """用户组名称"""
max_storage: int | None = Field(default=None, ge=0) max_storage: int | None = Field(default=None, ge=0)
@@ -257,10 +258,10 @@ class GroupOptions(GroupAllOptionsBase, TableBaseMixin):
class Group(GroupBase, UUIDTableBaseMixin): class Group(GroupBase, UUIDTableBaseMixin):
"""用户组模型""" """用户组模型"""
name: str = Field(max_length=255, unique=True) name: Str255 = Field(unique=True)
"""用户组名""" """用户组名"""
max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) max_storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
"""最大存储空间(字节)""" """最大存储空间(字节)"""
share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}) share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})

View File

@@ -130,6 +130,11 @@ default_settings: list[Setting] = [
Setting(name="sms_provider", value="", type=SettingsType.MOBILE), Setting(name="sms_provider", value="", type=SettingsType.MOBILE),
Setting(name="sms_access_key", value="", type=SettingsType.MOBILE), Setting(name="sms_access_key", value="", type=SettingsType.MOBILE),
Setting(name="sms_secret_key", value="", type=SettingsType.MOBILE), Setting(name="sms_secret_key", value="", type=SettingsType.MOBILE),
# ==================== 文件分类扩展名配置 ====================
Setting(name="image", value="jpg,jpeg,png,gif,bmp,webp,svg,ico,tiff,tif,avif,heic,heif,psd,raw", type=SettingsType.FILE_CATEGORY),
Setting(name="video", value="mp4,mkv,avi,mov,wmv,flv,webm,m4v,ts,3gp,mpg,mpeg", type=SettingsType.FILE_CATEGORY),
Setting(name="audio", value="mp3,wav,flac,aac,ogg,wma,m4a,opus,ape,aiff,mid,midi", type=SettingsType.FILE_CATEGORY),
Setting(name="document", value="pdf,doc,docx,odt,rtf,txt,tex,epub,pages,ppt,pptx,odp,key,xls,xlsx,csv,ods,numbers,tsv,md,markdown,mdx", type=SettingsType.FILE_CATEGORY),
] ]
async def init_default_settings() -> None: async def init_default_settings() -> None:
@@ -173,7 +178,7 @@ async def init_default_group() -> None:
admin=True, admin=True,
) )
admin_group_id = admin_group.id # 在 save 前保存 UUID admin_group_id = admin_group.id # 在 save 前保存 UUID
await admin_group.save(session) admin_group = await admin_group.save(session)
await GroupOptions( await GroupOptions(
group_id=admin_group_id, group_id=admin_group_id,
@@ -203,7 +208,7 @@ async def init_default_group() -> None:
web_dav_enabled=True, web_dav_enabled=True,
) )
member_group_id = member_group.id # 在 save 前保存 UUID member_group_id = member_group.id # 在 save 前保存 UUID
await member_group.save(session) member_group = await member_group.save(session)
await GroupOptions( await GroupOptions(
group_id=member_group_id, group_id=member_group_id,
@@ -222,7 +227,7 @@ async def init_default_group() -> None:
default_group_setting = await Setting.get(session, Setting.name == "default_group") default_group_setting = await Setting.get(session, Setting.name == "default_group")
if default_group_setting: if default_group_setting:
default_group_setting.value = str(member_group_id) default_group_setting.value = str(member_group_id)
await default_group_setting.save(session) default_group_setting = await default_group_setting.save(session)
# 未找到初始游客组时,则创建 # 未找到初始游客组时,则创建
if not await Group.get(session, Group.name == "游客"): if not await Group.get(session, Group.name == "游客"):
@@ -232,7 +237,7 @@ async def init_default_group() -> None:
web_dav_enabled=False, web_dav_enabled=False,
) )
guest_group_id = guest_group.id # 在 save 前保存 UUID guest_group_id = guest_group.id # 在 save 前保存 UUID
await guest_group.save(session) guest_group = await guest_group.save(session)
await GroupOptions( await GroupOptions(
group_id=guest_group_id, group_id=guest_group_id,
@@ -284,7 +289,7 @@ async def init_default_user() -> None:
group_id=admin_group.id, group_id=admin_group.id,
) )
admin_user_id = admin_user.id # 在 save 前保存 UUID admin_user_id = admin_user.id # 在 save 前保存 UUID
await admin_user.save(session) admin_user = await admin_user.save(session)
# 创建 AuthIdentity邮箱密码身份 # 创建 AuthIdentity邮箱密码身份
await AuthIdentity( await AuthIdentity(
@@ -373,7 +378,7 @@ async def init_default_theme_presets() -> None:
error=ChromaticColor.RED, error=ChromaticColor.RED,
neutral=NeutralColor.ZINC, neutral=NeutralColor.ZINC,
) )
await default_preset.save(session) default_preset = await default_preset.save(session)
log.info('已创建默认主题预设') log.info('已创建默认主题预设')
@@ -446,36 +451,43 @@ _DEFAULT_FILE_APPS: list[dict] = [
"is_enabled": True, "is_enabled": True,
"extensions": ["mp3", "wav", "ogg", "flac", "aac", "m4a", "opus"], "extensions": ["mp3", "wav", "ogg", "flac", "aac", "m4a", "opus"],
}, },
# iframe 应用(默认禁用) {
"name": "EPUB 阅读器",
"app_key": "epub_reader",
"type": "builtin",
"icon": "book-open",
"description": "阅读 EPUB 电子书",
"is_enabled": True,
"extensions": ["epub"],
},
{
"name": "3D 模型预览",
"app_key": "model_viewer",
"type": "builtin",
"icon": "cube",
"description": "预览 3D 模型",
"is_enabled": True,
"extensions": ["gltf", "glb", "stl", "obj", "fbx", "ply", "3mf"],
},
{
"name": "Font Viewer",
"app_key": "font_viewer",
"type": "builtin",
"icon": "type",
"description": "预览字体文件并显示元数据和文本样本",
"is_enabled": True,
"extensions": ["ttf", "otf", "woff", "woff2"],
},
{ {
"name": "Office 在线预览", "name": "Office 在线预览",
"app_key": "office_viewer", "app_key": "office_viewer",
"type": "iframe", "type": "iframe",
"icon": "file-word", "icon": "file-word",
"description": "使用 Microsoft Office Online 预览文档", "description": "使用 Microsoft Office Online 预览文档",
"is_enabled": False, "is_enabled": True,
"iframe_url_template": "https://view.officeapps.live.com/op/embed.aspx?src={file_url}", "iframe_url_template": "https://view.officeapps.live.com/op/embed.aspx?src={file_url}",
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"], "extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"],
}, },
# WOPI 应用(默认禁用)
{
"name": "Collabora Online",
"app_key": "collabora",
"type": "wopi",
"icon": "file-text",
"description": "Collabora Online 文档编辑器(需自行部署)",
"is_enabled": False,
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx", "odt", "ods", "odp"],
},
{
"name": "OnlyOffice",
"app_key": "onlyoffice",
"type": "wopi",
"icon": "file-text",
"description": "OnlyOffice 文档编辑器(需自行部署)",
"is_enabled": False,
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"],
},
] ]
@@ -493,7 +505,7 @@ async def init_default_file_apps() -> None:
return return
for app_data in _DEFAULT_FILE_APPS: for app_data in _DEFAULT_FILE_APPS:
extensions = app_data.pop("extensions") extensions = app_data["extensions"]
app = FileApp( app = FileApp(
name=app_data["name"], name=app_data["name"],
@@ -515,6 +527,6 @@ async def init_default_file_apps() -> None:
extension=ext.lower(), extension=ext.lower(),
priority=i, priority=i,
) )
await ext_record.save(session) ext_record = await ext_record.save(session)
log.info(f'已创建 {len(_DEFAULT_FILE_APPS)} 个默认文件查看器应用') log.info(f'已创建 {len(_DEFAULT_FILE_APPS)} 个默认文件查看器应用')

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, text, Index from sqlmodel import Field, Relationship, text, Index
from sqlmodel_ext import SQLModelBase, TableBaseMixin from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
if TYPE_CHECKING: if TYPE_CHECKING:
from .download import Download from .download import Download
@@ -28,13 +28,13 @@ class NodeType(StrEnum):
class Aria2ConfigurationBase(SQLModelBase): class Aria2ConfigurationBase(SQLModelBase):
"""Aria2配置基础模型""" """Aria2配置基础模型"""
rpc_url: str | None = Field(default=None, max_length=255) rpc_url: Str255 | None = None
"""RPC地址""" """RPC地址"""
rpc_secret: str | None = None rpc_secret: str | None = None
"""RPC密钥""" """RPC密钥"""
temp_path: str | None = Field(default=None, max_length=255) temp_path: Str255 | None = None
"""临时下载路径""" """临时下载路径"""
max_concurrent: int = Field(default=5, ge=1, le=50) max_concurrent: int = Field(default=5, ge=1, le=50)
@@ -70,19 +70,19 @@ class Node(SQLModelBase, TableBaseMixin):
status: NodeStatus = Field(default=NodeStatus.ONLINE) status: NodeStatus = Field(default=NodeStatus.ONLINE)
"""节点状态""" """节点状态"""
name: str = Field(max_length=255, unique=True) name: Str255 = Field(unique=True)
"""节点名称""" """节点名称"""
type: NodeType type: NodeType
"""节点类型""" """节点类型"""
server: str = Field(max_length=255) server: Str255
"""节点地址IP或域名""" """节点地址IP或域名"""
slave_key: str | None = Field(default=None, max_length=255) slave_key: Str255 | None = None
"""从机通讯密钥""" """从机通讯密钥"""
master_key: str | None = Field(default=None, max_length=255) master_key: Str255 | None = None
"""主机通讯密钥""" """主机通讯密钥"""
aria2_enabled: bool = False aria2_enabled: bool = False

View File

@@ -7,7 +7,9 @@ from enum import StrEnum
from sqlalchemy import BigInteger from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, CheckConstraint, Index, text from sqlmodel import Field, Relationship, CheckConstraint, Index, text
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255, Str256
from .policy import PolicyType
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import User from .user import User
@@ -16,6 +18,7 @@ if TYPE_CHECKING:
from .share import Share from .share import Share
from .physical_file import PhysicalFile from .physical_file import PhysicalFile
from .uri import DiskNextURI from .uri import DiskNextURI
from .object_metadata import ObjectMetadata
class ObjectType(StrEnum): class ObjectType(StrEnum):
@@ -23,42 +26,13 @@ class ObjectType(StrEnum):
FILE = "file" FILE = "file"
FOLDER = "folder" FOLDER = "folder"
class StorageType(StrEnum):
"""存储类型枚举"""
LOCAL = "local"
QINIU = "qiniu"
TENCENT = "tencent"
ALIYUN = "aliyun"
ONEDRIVE = "onedrive"
GOOGLE_DRIVE = "google_drive"
DROPBOX = "dropbox"
WEBDAV = "webdav"
REMOTE = "remote"
class FileCategory(StrEnum):
class FileMetadataBase(SQLModelBase): """文件类型分类枚举,用于按类别筛选文件"""
"""文件元数据基础模型""" IMAGE = "image"
VIDEO = "video"
width: int | None = Field(default=None) AUDIO = "audio"
"""图片宽度(像素)""" DOCUMENT = "document"
height: int | None = Field(default=None)
"""图片高度(像素)"""
duration: float | None = Field(default=None)
"""音视频时长(秒)"""
bitrate: int | None = Field(default=None)
"""比特率kbps"""
mime_type: str | None = Field(default=None, max_length=127)
"""MIME类型"""
checksum_md5: str | None = Field(default=None, max_length=32)
"""MD5校验和"""
checksum_sha256: str | None = Field(default=None, max_length=64)
"""SHA256校验和"""
# ==================== Base 模型 ==================== # ==================== Base 模型 ====================
@@ -75,9 +49,32 @@ class ObjectBase(SQLModelBase):
size: int | None = None size: int | None = None
"""文件大小(字节),目录为 None""" """文件大小(字节),目录为 None"""
mime_type: str | None = Field(default=None, max_length=127)
"""MIME类型仅文件有效"""
# ==================== DTO 模型 ==================== # ==================== DTO 模型 ====================
class ObjectFileFinalize(SQLModelBase):
"""文件上传完成后更新 Object 的 DTO"""
size: int
"""文件大小(字节)"""
physical_file_id: UUID
"""关联的物理文件UUID"""
class ObjectMoveUpdate(SQLModelBase):
"""移动/重命名 Object 的 DTO"""
parent_id: UUID
"""新的父目录UUID"""
name: str
"""新名称"""
class DirectoryCreateRequest(SQLModelBase): class DirectoryCreateRequest(SQLModelBase):
"""创建目录请求 DTO""" """创建目录请求 DTO"""
@@ -136,7 +133,7 @@ class PolicyResponse(SQLModelBase):
name: str name: str
"""策略名称""" """策略名称"""
type: StorageType type: PolicyType
"""存储类型""" """存储类型"""
max_size: int = Field(ge=0, default=0, sa_type=BigInteger) max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
@@ -164,22 +161,6 @@ class DirectoryResponse(SQLModelBase):
# ==================== 数据库模型 ==================== # ==================== 数据库模型 ====================
class FileMetadata(FileMetadataBase, UUIDTableBaseMixin):
"""文件元数据模型与Object一对一关联"""
object_id: UUID = Field(
foreign_key="object.id",
unique=True,
index=True,
ondelete="CASCADE"
)
"""关联的对象UUID"""
# 反向关系
object: "Object" = Relationship(back_populates="file_metadata")
"""关联的对象"""
class Object(ObjectBase, UUIDTableBaseMixin): class Object(ObjectBase, UUIDTableBaseMixin):
""" """
统一对象模型 统一对象模型
@@ -217,13 +198,13 @@ class Object(ObjectBase, UUIDTableBaseMixin):
# ==================== 基础字段 ==================== # ==================== 基础字段 ====================
name: str = Field(max_length=255) name: Str255
"""对象名称(文件名或目录名)""" """对象名称(文件名或目录名)"""
type: ObjectType type: ObjectType
"""对象类型file 或 folder""" """对象类型file 或 folder"""
password: str | None = Field(default=None, max_length=255) password: Str255 | None = None
"""对象独立密码(仅当用户为对象单独设置密码时有效)""" """对象独立密码(仅当用户为对象单独设置密码时有效)"""
# ==================== 文件专属字段 ==================== # ==================== 文件专属字段 ====================
@@ -231,7 +212,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"}) size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
"""文件大小(字节),目录为 0""" """文件大小(字节),目录为 0"""
upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True) upload_session_id: Str255 | None = Field(default=None, unique=True, index=True)
"""分块上传会话ID仅文件有效""" """分块上传会话ID仅文件有效"""
physical_file_id: UUID | None = Field( physical_file_id: UUID | None = Field(
@@ -334,11 +315,11 @@ class Object(ObjectBase, UUIDTableBaseMixin):
"""子对象(文件和子目录)""" """子对象(文件和子目录)"""
# 仅文件有效的关系 # 仅文件有效的关系
file_metadata: FileMetadata | None = Relationship( metadata_entries: list["ObjectMetadata"] = Relationship(
back_populates="object", back_populates="object",
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"}, sa_relationship_kwargs={"cascade": "all, delete-orphan"},
) )
"""文件元数据(仅文件有效)""" """元数据键值对列表"""
source_links: list["SourceLink"] = Relationship( source_links: list["SourceLink"] = Relationship(
back_populates="object", back_populates="object",
@@ -496,6 +477,37 @@ class Object(ObjectBase, UUIDTableBaseMixin):
fetch_mode="all" fetch_mode="all"
) )
@classmethod
async def get_by_category(
cls,
session: 'AsyncSession',
user_id: UUID,
extensions: list[str],
table_view: 'TableViewRequest | None' = None,
) -> 'ListResponse[Object]':
"""
按扩展名列表查询用户的所有文件(跨目录)
只查询未删除、未封禁的文件对象,使用 ILIKE 匹配文件名后缀。
:param session: 数据库会话
:param user_id: 用户UUID
:param extensions: 扩展名列表(不含点号)
:param table_view: 分页排序参数
:return: 分页文件列表
"""
from sqlalchemy import or_
ext_conditions = [cls.name.ilike(f"%.{ext}") for ext in extensions]
condition = (
(cls.owner_id == user_id) &
(cls.type == ObjectType.FILE) &
(cls.deleted_at == None) &
(cls.is_banned == False) &
or_(*ext_conditions)
)
return await cls.get_with_count(session, condition, table_view=table_view)
@classmethod @classmethod
async def resolve_uri( async def resolve_uri(
cls, cls,
@@ -573,7 +585,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
class UploadSessionBase(SQLModelBase): class UploadSessionBase(SQLModelBase):
"""上传会话基础字段""" """上传会话基础字段"""
file_name: str = Field(max_length=255) file_name: Str255
"""原始文件名""" """原始文件名"""
file_size: int = Field(ge=0, sa_type=BigInteger) file_size: int = Field(ge=0, sa_type=BigInteger)
@@ -604,6 +616,12 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
storage_path: str | None = Field(default=None, max_length=512) storage_path: str | None = Field(default=None, max_length=512)
"""文件存储路径""" """文件存储路径"""
s3_upload_id: Str256 | None = None
"""S3 Multipart Upload ID仅 S3 策略使用)"""
s3_part_etags: str | None = None
"""S3 已上传分片的 ETag 列表JSON 格式 [[1,"etag1"],[2,"etag2"]](仅 S3 策略使用)"""
expires_at: datetime expires_at: datetime
"""会话过期时间""" """会话过期时间"""
@@ -645,7 +663,7 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
class CreateUploadSessionRequest(SQLModelBase): class CreateUploadSessionRequest(SQLModelBase):
"""创建上传会话请求 DTO""" """创建上传会话请求 DTO"""
file_name: str = Field(max_length=255) file_name: Str255
"""文件名""" """文件名"""
file_size: int = Field(ge=0) file_size: int = Field(ge=0)
@@ -702,7 +720,7 @@ class UploadChunkResponse(SQLModelBase):
class CreateFileRequest(SQLModelBase): class CreateFileRequest(SQLModelBase):
"""创建空白文件请求 DTO""" """创建空白文件请求 DTO"""
name: str = Field(max_length=255) name: Str255
"""文件名""" """文件名"""
parent_id: UUID parent_id: UUID
@@ -712,6 +730,16 @@ class CreateFileRequest(SQLModelBase):
"""存储策略UUID不指定则使用父目录的策略""" """存储策略UUID不指定则使用父目录的策略"""
class ObjectSwitchPolicyRequest(SQLModelBase):
"""切换对象存储策略请求"""
policy_id: UUID
"""目标存储策略UUID"""
is_migrate_existing: bool = False
"""(仅目录)是否迁移已有文件,默认 false 只影响新文件"""
# ==================== 对象操作相关 DTO ==================== # ==================== 对象操作相关 DTO ====================
class ObjectCopyRequest(SQLModelBase): class ObjectCopyRequest(SQLModelBase):
@@ -730,7 +758,7 @@ class ObjectRenameRequest(SQLModelBase):
id: UUID id: UUID
"""对象UUID""" """对象UUID"""
new_name: str = Field(max_length=255) new_name: Str255
"""新名称""" """新名称"""
@@ -749,6 +777,9 @@ class ObjectPropertyResponse(SQLModelBase):
size: int size: int
"""文件大小(字节)""" """文件大小(字节)"""
mime_type: str | None = None
"""MIME类型"""
created_at: datetime created_at: datetime
"""创建时间""" """创建时间"""
@@ -762,22 +793,13 @@ class ObjectPropertyResponse(SQLModelBase):
class ObjectPropertyDetailResponse(ObjectPropertyResponse): class ObjectPropertyDetailResponse(ObjectPropertyResponse):
"""对象详细属性响应 DTO继承基本属性""" """对象详细属性响应 DTO继承基本属性"""
# 元数据信息 # 校验和(从 PhysicalFile 读取)
mime_type: str | None = None
"""MIME类型"""
width: int | None = None
"""图片宽度(像素)"""
height: int | None = None
"""图片高度(像素)"""
duration: float | None = None
"""音视频时长(秒)"""
checksum_md5: str | None = None checksum_md5: str | None = None
"""MD5校验和""" """MD5校验和"""
checksum_sha256: str | None = None
"""SHA256校验和"""
# 分享统计 # 分享统计
share_count: int = 0 share_count: int = 0
"""分享次数""" """分享次数"""
@@ -795,6 +817,10 @@ class ObjectPropertyDetailResponse(ObjectPropertyResponse):
reference_count: int = 1 reference_count: int = 1
"""物理文件引用计数(仅文件有效)""" """物理文件引用计数(仅文件有效)"""
# 元数据KV 格式)
metadatas: dict[str, str] = {}
"""所有元数据条目(键名 → 值)"""
# ==================== 管理员文件管理 DTO ==================== # ==================== 管理员文件管理 DTO ====================

View File

@@ -0,0 +1,127 @@
"""
对象元数据 KV 模型
以键值对形式存储文件的扩展元数据。键名使用命名空间前缀分类,
如 exif:width, stream:duration, music:artist 等。
架构:
ObjectMetadata (KV 表,与 Object 一对多关系)
└── 每个 Object 可以有多条元数据记录
└── (object_id, name) 组合唯一索引
命名空间:
- exif: 图片 EXIF 信息(尺寸、相机参数、拍摄时间等)
- stream: 音视频流信息(时长、比特率、视频尺寸、编解码等)
- music: 音乐标签(标题、艺术家、专辑等)
- geo: 地理位置(经纬度、地址)
- apk: Android 安装包信息
- custom: 用户自定义属性
- sys: 系统内部元数据
- thumb: 缩略图信息
"""
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, UniqueConstraint, Index, Relationship
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255
if TYPE_CHECKING:
from .object import Object
# ==================== 枚举 ====================
class MetadataNamespace(StrEnum):
"""元数据命名空间枚举"""
EXIF = "exif"
"""图片 EXIF 信息(含尺寸、相机参数、拍摄时间等)"""
MUSIC = "music"
"""音乐标签title/artist/album/genre 等)"""
STREAM = "stream"
"""音视频流信息codec/duration/bitrate/resolution 等)"""
GEO = "geo"
"""地理位置latitude/longitude/address"""
APK = "apk"
"""Android 安装包信息package_name/version 等)"""
THUMB = "thumb"
"""缩略图信息(内部使用)"""
SYS = "sys"
"""系统元数据(内部使用)"""
CUSTOM = "custom"
"""用户自定义属性"""
# 对外不可见的命名空间API 不返回给普通用户)
INTERNAL_NAMESPACES: set[str] = {MetadataNamespace.SYS, MetadataNamespace.THUMB}
# 用户可写的命名空间
USER_WRITABLE_NAMESPACES: set[str] = {MetadataNamespace.CUSTOM}
# ==================== Base 模型 ====================
class ObjectMetadataBase(SQLModelBase):
"""对象元数据 KV 基础模型"""
name: Str255
"""元数据键名格式namespace:key如 exif:width, stream:duration"""
value: str
"""元数据值(统一为字符串存储)"""
# ==================== 数据库模型 ====================
class ObjectMetadata(ObjectMetadataBase, UUIDTableBaseMixin):
"""
对象元数据 KV 模型
以键值对形式存储文件的扩展元数据。键名使用命名空间前缀分类,
每个对象的每个键名唯一(通过唯一索引保证)。
"""
__table_args__ = (
UniqueConstraint("object_id", "name", name="uq_object_metadata_object_name"),
Index("ix_object_metadata_object_id", "object_id"),
)
object_id: UUID = Field(
foreign_key="object.id",
ondelete="CASCADE",
)
"""关联的对象UUID"""
is_public: bool = False
"""是否对分享页面公开"""
# 关系
object: "Object" = Relationship(back_populates="metadata_entries")
"""关联的对象"""
# ==================== DTO 模型 ====================
class MetadataResponse(SQLModelBase):
"""元数据查询响应 DTO"""
metadatas: dict[str, str]
"""元数据字典(键名 → 值)"""
class MetadataPatchItem(SQLModelBase):
"""单条元数据补丁 DTO"""
key: Str255
"""元数据键名"""
value: str | None = None
"""None 表示删除此条目"""
class MetadataPatchRequest(SQLModelBase):
"""元数据批量更新请求 DTO"""
patches: list[MetadataPatchItem]
"""补丁列表"""

View File

@@ -1,54 +1,118 @@
from decimal import Decimal
from enum import StrEnum from enum import StrEnum
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from sqlalchemy import Numeric
from sqlmodel import Field, Relationship from sqlmodel import Field, Relationship
from sqlmodel_ext import SQLModelBase, TableBaseMixin from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
if TYPE_CHECKING: if TYPE_CHECKING:
from .product import Product
from .user import User from .user import User
class OrderStatus(StrEnum): class OrderStatus(StrEnum):
"""订单状态枚举""" """订单状态枚举"""
PENDING = "pending" PENDING = "pending"
"""待支付""" """待支付"""
COMPLETED = "completed" COMPLETED = "completed"
"""已完成""" """已完成"""
CANCELLED = "cancelled" CANCELLED = "cancelled"
"""已取消""" """已取消"""
class OrderType(StrEnum): class OrderType(StrEnum):
"""订单类型枚举""" """订单类型枚举"""
# [TODO] 补充具体订单类型
pass
STORAGE_PACK = "storage_pack"
"""容量包"""
GROUP_TIME = "group_time"
"""用户组时长"""
SCORE = "score"
"""积分充值"""
# ==================== DTO 模型 ====================
class CreateOrderRequest(SQLModelBase):
"""创建订单请求 DTO"""
product_id: UUID
"""商品UUID"""
num: int = Field(default=1, ge=1)
"""购买数量"""
method: str
"""支付方式"""
class OrderResponse(SQLModelBase):
"""订单响应 DTO"""
id: int
"""订单ID"""
order_no: str
"""订单号"""
type: OrderType
"""订单类型"""
method: str | None = None
"""支付方式"""
product_id: UUID | None = None
"""商品UUID"""
num: int
"""购买数量"""
name: str
"""商品名称"""
price: float
"""订单价格(元)"""
status: OrderStatus
"""订单状态"""
user_id: UUID
"""所属用户UUID"""
# ==================== 数据库模型 ====================
class Order(SQLModelBase, TableBaseMixin): class Order(SQLModelBase, TableBaseMixin):
"""订单模型""" """订单模型"""
order_no: str = Field(max_length=255, unique=True, index=True) order_no: Str255 = Field(unique=True, index=True)
"""订单号,唯一""" """订单号,唯一"""
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) type: OrderType
"""订单类型 [TODO] 待定义枚举""" """订单类型"""
method: str | None = Field(default=None, max_length=255) method: Str255 | None = None
"""支付方式""" """支付方式"""
product_id: int | None = Field(default=None) product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
"""商品ID""" """关联商品UUID"""
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}) num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
"""购买数量""" """购买数量"""
name: str = Field(max_length=255) name: Str255
"""商品名称""" """商品名称"""
price: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) price: Decimal = Field(sa_type=Numeric(12, 2), default=Decimal("0.00"))
"""订单价格(""" """订单价格("""
status: OrderStatus = Field(default=OrderStatus.PENDING) status: OrderStatus = Field(default=OrderStatus.PENDING)
"""订单状态""" """订单状态"""
@@ -63,3 +127,19 @@ class Order(SQLModelBase, TableBaseMixin):
# 关系 # 关系
user: "User" = Relationship(back_populates="orders") user: "User" = Relationship(back_populates="orders")
product: "Product" = Relationship(back_populates="orders")
def to_response(self) -> OrderResponse:
"""转换为响应 DTO"""
return OrderResponse(
id=self.id,
order_no=self.order_no,
type=self.type,
method=self.method,
product_id=self.product_id,
num=self.num,
name=self.name,
price=float(self.price),
status=self.status,
user_id=self.user_id,
)

View File

@@ -15,7 +15,7 @@ from uuid import UUID
from sqlalchemy import BigInteger from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, Index from sqlmodel import Field, Relationship, Index
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str32, Str64
if TYPE_CHECKING: if TYPE_CHECKING:
from .object import Object from .object import Object
@@ -31,9 +31,12 @@ class PhysicalFileBase(SQLModelBase):
size: int = Field(default=0, sa_type=BigInteger) size: int = Field(default=0, sa_type=BigInteger)
"""文件大小(字节)""" """文件大小(字节)"""
checksum_md5: str | None = Field(default=None, max_length=32) checksum_md5: Str32 | None = None
"""MD5校验和用于文件去重和完整性校验""" """MD5校验和用于文件去重和完整性校验"""
checksum_sha256: Str64 | None = None
"""SHA256校验和"""
class PhysicalFile(PhysicalFileBase, UUIDTableBaseMixin): class PhysicalFile(PhysicalFileBase, UUIDTableBaseMixin):
""" """

View File

@@ -4,7 +4,7 @@ from uuid import UUID
from enum import StrEnum from enum import StrEnum
from sqlmodel import Field, Relationship, text from sqlmodel import Field, Relationship, text
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str64, Str255
if TYPE_CHECKING: if TYPE_CHECKING:
from .object import Object from .object import Object
@@ -37,22 +37,22 @@ class PolicyType(StrEnum):
class PolicyBase(SQLModelBase): class PolicyBase(SQLModelBase):
"""存储策略基础字段,供 DTO 和数据库模型共享""" """存储策略基础字段,供 DTO 和数据库模型共享"""
name: str = Field(max_length=255) name: Str255
"""策略名称""" """策略名称"""
type: PolicyType type: PolicyType
"""存储策略类型""" """存储策略类型"""
server: str | None = Field(default=None, max_length=255) server: Str255 | None = None
"""服务器地址(本地策略为绝对路径)""" """服务器地址(本地策略为绝对路径)"""
bucket_name: str | None = Field(default=None, max_length=255) bucket_name: Str255 | None = None
"""存储桶名称""" """存储桶名称"""
is_private: bool = True is_private: bool = True
"""是否为私有空间""" """是否为私有空间"""
base_url: str | None = Field(default=None, max_length=255) base_url: Str255 | None = None
"""访问文件的基础URL""" """访问文件的基础URL"""
access_key: str | None = None access_key: str | None = None
@@ -67,10 +67,10 @@ class PolicyBase(SQLModelBase):
auto_rename: bool = False auto_rename: bool = False
"""是否自动重命名""" """是否自动重命名"""
dir_name_rule: str | None = Field(default=None, max_length=255) dir_name_rule: Str255 | None = None
"""目录命名规则""" """目录命名规则"""
file_name_rule: str | None = Field(default=None, max_length=255) file_name_rule: Str255 | None = None
"""文件命名规则""" """文件命名规则"""
is_origin_link_enable: bool = False is_origin_link_enable: bool = False
@@ -102,6 +102,94 @@ class PolicySummary(SQLModelBase):
"""是否私有""" """是否私有"""
class PolicyCreateRequest(PolicyBase):
"""创建存储策略请求 DTO包含 PolicyOptions 扁平字段"""
# PolicyOptions 字段(平铺到请求体中,与 GroupCreateRequest 模式一致)
token: str | None = None
"""访问令牌"""
file_type: str | None = None
"""允许的文件类型"""
mimetype: str | None = Field(default=None, max_length=127)
"""MIME类型"""
od_redirect: Str255 | None = None
"""OneDrive重定向地址"""
chunk_size: int = Field(default=52428800, ge=1)
"""分片上传大小字节默认50MB"""
s3_path_style: bool = False
"""是否使用S3路径风格"""
s3_region: Str64 = 'us-east-1'
"""S3 区域(如 us-east-1、ap-southeast-1仅 S3 策略使用"""
class PolicyUpdateRequest(SQLModelBase):
"""更新存储策略请求 DTO所有字段可选"""
name: Str255 | None = None
"""策略名称"""
server: Str255 | None = None
"""服务器地址"""
bucket_name: Str255 | None = None
"""存储桶名称"""
is_private: bool | None = None
"""是否为私有空间"""
base_url: Str255 | None = None
"""访问文件的基础URL"""
access_key: str | None = None
"""Access Key"""
secret_key: str | None = None
"""Secret Key"""
max_size: int | None = Field(default=None, ge=0)
"""允许上传的最大文件尺寸(字节)"""
auto_rename: bool | None = None
"""是否自动重命名"""
dir_name_rule: Str255 | None = None
"""目录命名规则"""
file_name_rule: Str255 | None = None
"""文件命名规则"""
is_origin_link_enable: bool | None = None
"""是否开启源链接访问"""
# PolicyOptions 字段
token: str | None = None
"""访问令牌"""
file_type: str | None = None
"""允许的文件类型"""
mimetype: str | None = Field(default=None, max_length=127)
"""MIME类型"""
od_redirect: Str255 | None = None
"""OneDrive重定向地址"""
chunk_size: int | None = Field(default=None, ge=1)
"""分片上传大小(字节)"""
s3_path_style: bool | None = None
"""是否使用S3路径风格"""
s3_region: Str64 | None = None
"""S3 区域"""
# ==================== 数据库模型 ==================== # ==================== 数据库模型 ====================
@@ -117,7 +205,7 @@ class PolicyOptionsBase(SQLModelBase):
mimetype: str | None = Field(default=None, max_length=127) mimetype: str | None = Field(default=None, max_length=127)
"""MIME类型""" """MIME类型"""
od_redirect: str | None = Field(default=None, max_length=255) od_redirect: Str255 | None = None
"""OneDrive重定向地址""" """OneDrive重定向地址"""
chunk_size: int = Field(default=52428800, sa_column_kwargs={"server_default": "52428800"}) chunk_size: int = Field(default=52428800, sa_column_kwargs={"server_default": "52428800"})
@@ -126,6 +214,9 @@ class PolicyOptionsBase(SQLModelBase):
s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}) s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否使用S3路径风格""" """是否使用S3路径风格"""
s3_region: Str64 = Field(default='us-east-1', sa_column_kwargs={"server_default": "'us-east-1'"})
"""S3 区域(如 us-east-1、ap-southeast-1仅 S3 策略使用"""
class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin): class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin):
"""存储策略选项模型与Policy一对一关联""" """存储策略选项模型与Policy一对一关联"""
@@ -146,7 +237,7 @@ class Policy(PolicyBase, UUIDTableBaseMixin):
"""存储策略模型""" """存储策略模型"""
# 覆盖基类字段以添加数据库专有配置 # 覆盖基类字段以添加数据库专有配置
name: str = Field(max_length=255, unique=True) name: Str255 = Field(unique=True)
"""策略名称""" """策略名称"""
is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")}) is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})

206
sqlmodels/product.py Normal file
View File

@@ -0,0 +1,206 @@
from decimal import Decimal
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import Numeric, BigInteger
from sqlmodel import Field, Relationship, text
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255
if TYPE_CHECKING:
from .order import Order
from .redeem import Redeem
class ProductType(StrEnum):
"""商品类型枚举"""
STORAGE_PACK = "storage_pack"
"""容量包"""
GROUP_TIME = "group_time"
"""用户组时长"""
SCORE = "score"
"""积分充值"""
class PaymentMethod(StrEnum):
"""支付方式枚举"""
ALIPAY = "alipay"
"""支付宝"""
WECHAT = "wechat"
"""微信支付"""
STRIPE = "stripe"
"""Stripe"""
EASYPAY = "easypay"
"""易支付"""
CUSTOM = "custom"
"""自定义支付"""
# ==================== DTO 模型 ====================
class ProductBase(SQLModelBase):
"""商品基础字段"""
name: str
"""商品名称"""
type: ProductType
"""商品类型"""
description: str | None = None
"""商品描述"""
class ProductCreateRequest(ProductBase):
"""创建商品请求 DTO"""
name: Str255
"""商品名称"""
price: Decimal = Field(ge=0, decimal_places=2)
"""商品价格(元)"""
is_active: bool = True
"""是否上架"""
sort_order: int = Field(default=0, ge=0)
"""排序权重(越大越靠前)"""
# storage_pack 专用
size: int | None = Field(default=None, ge=0)
"""容量大小字节type=storage_pack 时必填"""
duration_days: int | None = Field(default=None, ge=1)
"""有效天数type=storage_pack/group_time 时必填"""
# group_time 专用
group_id: UUID | None = None
"""目标用户组UUIDtype=group_time 时必填"""
# score 专用
score_amount: int | None = Field(default=None, ge=1)
"""积分数量type=score 时必填"""
class ProductUpdateRequest(SQLModelBase):
"""更新商品请求 DTO所有字段可选"""
name: Str255 | None = None
"""商品名称"""
description: str | None = None
"""商品描述"""
price: Decimal | None = Field(default=None, ge=0, decimal_places=2)
"""商品价格(元)"""
is_active: bool | None = None
"""是否上架"""
sort_order: int | None = Field(default=None, ge=0)
"""排序权重"""
size: int | None = Field(default=None, ge=0)
"""容量大小(字节)"""
duration_days: int | None = Field(default=None, ge=1)
"""有效天数"""
group_id: UUID | None = None
"""目标用户组UUID"""
score_amount: int | None = Field(default=None, ge=1)
"""积分数量"""
class ProductResponse(ProductBase):
"""商品响应 DTO"""
id: UUID
"""商品UUID"""
price: float
"""商品价格(元)"""
is_active: bool
"""是否上架"""
sort_order: int
"""排序权重"""
size: int | None = None
"""容量大小(字节)"""
duration_days: int | None = None
"""有效天数"""
group_id: UUID | None = None
"""目标用户组UUID"""
score_amount: int | None = None
"""积分数量"""
# ==================== 数据库模型 ====================
class Product(ProductBase, UUIDTableBaseMixin):
"""商品模型"""
name: Str255
"""商品名称"""
price: Decimal = Field(sa_type=Numeric(12, 2), default=Decimal("0.00"))
"""商品价格(元)"""
is_active: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
"""是否上架"""
sort_order: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""排序权重(越大越靠前)"""
# storage_pack 专用
size: int | None = Field(default=None, sa_type=BigInteger)
"""容量大小字节type=storage_pack 时必填"""
duration_days: int | None = None
"""有效天数type=storage_pack/group_time 时必填"""
# group_time 专用
group_id: UUID | None = Field(default=None, foreign_key="group.id", ondelete="SET NULL")
"""目标用户组UUIDtype=group_time 时必填"""
# score 专用
score_amount: int | None = None
"""积分数量type=score 时必填"""
# 关系
orders: list["Order"] = Relationship(back_populates="product")
"""关联的订单列表"""
redeems: list["Redeem"] = Relationship(back_populates="product")
"""关联的兑换码列表"""
def to_response(self) -> ProductResponse:
"""转换为响应 DTO"""
return ProductResponse(
id=self.id,
name=self.name,
type=self.type,
description=self.description,
price=float(self.price),
is_active=self.is_active,
sort_order=self.sort_order,
size=self.size,
duration_days=self.duration_days,
group_id=self.group_id,
score_amount=self.score_amount,
)

View File

@@ -1,22 +1,141 @@
from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, text from sqlmodel import Field, Relationship, text
from sqlmodel_ext import SQLModelBase, TableBaseMixin from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .product import Product
from .user import User
class RedeemType(StrEnum): class RedeemType(StrEnum):
"""兑换码类型枚举""" """兑换码类型枚举"""
# [TODO] 补充具体兑换码类型
pass
STORAGE_PACK = "storage_pack"
"""容量包"""
GROUP_TIME = "group_time"
"""用户组时长"""
SCORE = "score"
"""积分充值"""
# ==================== DTO 模型 ====================
class RedeemCreateRequest(SQLModelBase):
"""批量生成兑换码请求 DTO"""
product_id: UUID
"""关联商品UUID"""
count: int = Field(default=1, ge=1, le=100)
"""生成数量"""
class RedeemUseRequest(SQLModelBase):
"""使用兑换码请求 DTO"""
code: str
"""兑换码"""
class RedeemInfoResponse(SQLModelBase):
"""兑换码信息响应 DTO用户侧"""
type: RedeemType
"""兑换码类型"""
product_name: str | None = None
"""关联商品名称"""
num: int
"""可兑换数量"""
is_used: bool
"""是否已使用"""
class RedeemAdminResponse(SQLModelBase):
"""兑换码管理响应 DTO管理侧"""
id: int
"""兑换码ID"""
type: RedeemType
"""兑换码类型"""
product_id: UUID | None = None
"""关联商品UUID"""
num: int
"""可兑换数量"""
code: str
"""兑换码"""
is_used: bool
"""是否已使用"""
used_at: datetime | None = None
"""使用时间"""
used_by: UUID | None = None
"""使用者UUID"""
# ==================== 数据库模型 ====================
class Redeem(SQLModelBase, TableBaseMixin): class Redeem(SQLModelBase, TableBaseMixin):
"""兑换码模型""" """兑换码模型"""
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) type: RedeemType
"""兑换码类型 [TODO] 待定义枚举""" """兑换码类型"""
product_id: int | None = Field(default=None, description="关联的商品/权益ID")
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="可兑换数量/时长等") product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
code: str = Field(unique=True, index=True, description="兑换码,唯一") """关联商品UUID"""
used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否已使用")
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
"""可兑换数量/时长等"""
code: str = Field(unique=True, index=True)
"""兑换码,唯一"""
is_used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否已使用"""
used_at: datetime | None = None
"""使用时间"""
used_by: UUID | None = Field(default=None, foreign_key="user.id", ondelete="SET NULL")
"""使用者UUID"""
# 关系
product: "Product" = Relationship(back_populates="redeems")
user: "User" = Relationship(back_populates="redeems")
def to_admin_response(self) -> RedeemAdminResponse:
"""转换为管理侧响应 DTO"""
return RedeemAdminResponse(
id=self.id,
type=self.type,
product_id=self.product_id,
num=self.num,
code=self.code,
is_used=self.is_used,
used_at=self.used_at,
used_by=self.used_by,
)
def to_info_response(self, product_name: str | None = None) -> RedeemInfoResponse:
"""转换为用户侧响应 DTO"""
return RedeemInfoResponse(
type=self.type,
product_name=product_name,
num=self.num,
is_used=self.is_used,
)

View File

@@ -4,7 +4,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship from sqlmodel import Field, Relationship
from sqlmodel_ext import SQLModelBase, TableBaseMixin from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
if TYPE_CHECKING: if TYPE_CHECKING:
from .share import Share from .share import Share
@@ -21,7 +21,7 @@ class Report(SQLModelBase, TableBaseMixin):
reason: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) reason: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""举报原因 [TODO] 待定义枚举""" """举报原因 [TODO] 待定义枚举"""
description: str | None = Field(default=None, max_length=255, description="补充描述") description: Str255 | None = Field(default=None, description="补充描述")
# 外键 # 外键
share_id: UUID = Field( share_id: UUID = Field(

View File

@@ -76,6 +76,9 @@ class SiteConfigResponse(SQLModelBase):
email_binding_required: bool = True email_binding_required: bool = True
"""是否强制绑定邮箱""" """是否强制绑定邮箱"""
avatar_max_size: int = 2097152
"""头像文件最大字节数(默认 2MB"""
footer_code: str | None = None footer_code: str | None = None
"""自定义页脚代码""" """自定义页脚代码"""
@@ -160,6 +163,7 @@ class SettingsType(StrEnum):
VERSION = "version" VERSION = "version"
VIEW = "view" VIEW = "view"
WOPI = "wopi" WOPI = "wopi"
FILE_CATEGORY = "file_category"
# 数据库模型 # 数据库模型
class Setting(SettingItem, TableBaseMixin): class Setting(SettingItem, TableBaseMixin):

View File

@@ -5,7 +5,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str64, Str255
from .model_base import ResponseBase from .model_base import ResponseBase
from .object import ObjectType from .object import ObjectType
@@ -52,10 +52,10 @@ class Share(SQLModelBase, UUIDTableBaseMixin):
Index("ix_share_object", "object_id"), Index("ix_share_object", "object_id"),
) )
code: str = Field(max_length=64, nullable=False, index=True) code: Str64 = Field(nullable=False, index=True)
"""分享码""" """分享码"""
password: str | None = Field(default=None, max_length=255) password: Str255 | None = None
"""分享密码(加密后)""" """分享密码(加密后)"""
object_id: UUID = Field( object_id: UUID = Field(
@@ -80,7 +80,7 @@ class Share(SQLModelBase, UUIDTableBaseMixin):
preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")}) preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
"""是否允许预览""" """是否允许预览"""
source_name: str | None = Field(default=None, max_length=255) source_name: Str255 | None = None
"""源名称(冗余字段,便于展示)""" """源名称(冗余字段,便于展示)"""
score: int = Field(default=0, ge=0) score: int = Field(default=0, ge=0)

View File

@@ -4,7 +4,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, Index from sqlmodel import Field, Relationship, Index
from sqlmodel_ext import SQLModelBase, TableBaseMixin from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
if TYPE_CHECKING: if TYPE_CHECKING:
from .object import Object from .object import Object
@@ -17,7 +17,7 @@ class SourceLink(SQLModelBase, TableBaseMixin):
Index("ix_sourcelink_object_name", "object_id", "name"), Index("ix_sourcelink_object_name", "object_id", "name"),
) )
name: str = Field(max_length=255) name: Str255
"""链接名称""" """链接名称"""
downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"})

View File

@@ -1,22 +1,59 @@
from typing import TYPE_CHECKING
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from sqlmodel import Field, Relationship, Column, func, DateTime from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship
from sqlmodel_ext import SQLModelBase, TableBaseMixin from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import User from .user import User
# ==================== DTO 模型 ====================
class StoragePackResponse(SQLModelBase):
"""容量包响应 DTO"""
id: int
"""容量包ID"""
name: str
"""容量包名称"""
size: int
"""容量大小(字节)"""
active_time: datetime | None = None
"""激活时间"""
expired_time: datetime | None = None
"""过期时间"""
product_id: UUID | None = None
"""来源商品UUID"""
# ==================== 数据库模型 ====================
class StoragePack(SQLModelBase, TableBaseMixin): class StoragePack(SQLModelBase, TableBaseMixin):
"""容量包模型""" """容量包模型"""
name: str = Field(max_length=255, description="容量包名称") name: Str255
active_time: datetime | None = Field(default=None, description="激活时间") """容量包名称"""
expired_time: datetime | None = Field(default=None, index=True, description="过期时间")
size: int = Field(description="容量包大小(字节)") active_time: datetime | None = None
"""激活时间"""
expired_time: datetime | None = Field(default=None, index=True)
"""过期时间"""
size: int = Field(sa_type=BigInteger)
"""容量包大小(字节)"""
product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
"""来源商品UUID"""
# 外键 # 外键
user_id: UUID = Field( user_id: UUID = Field(
@@ -28,3 +65,14 @@ class StoragePack(SQLModelBase, TableBaseMixin):
# 关系 # 关系
user: "User" = Relationship(back_populates="storage_packs") user: "User" = Relationship(back_populates="storage_packs")
def to_response(self) -> StoragePackResponse:
"""转换为响应 DTO"""
return StoragePackResponse(
id=self.id,
name=self.name,
size=self.size,
active_time=self.active_time,
expired_time=self.expired_time,
product_id=self.product_id,
)

View File

@@ -5,7 +5,7 @@ from datetime import datetime
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
from sqlmodel_ext import SQLModelBase, TableBaseMixin from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import User from .user import User
@@ -24,13 +24,13 @@ class Tag(SQLModelBase, TableBaseMixin):
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_tag_name_user"),) __table_args__ = (UniqueConstraint("name", "user_id", name="uq_tag_name_user"),)
name: str = Field(max_length=255) name: Str255
"""标签名称""" """标签名称"""
icon: str | None = Field(default=None, max_length=255) icon: Str255 | None = None
"""标签图标""" """标签图标"""
color: str | None = Field(default=None, max_length=255) color: Str255 | None = None
"""标签颜色""" """标签颜色"""
type: TagType = Field(default=TagType.MANUAL) type: TagType = Field(default=TagType.MANUAL)

View File

@@ -26,8 +26,8 @@ class TaskStatus(StrEnum):
class TaskType(StrEnum): class TaskType(StrEnum):
"""任务类型枚举""" """任务类型枚举"""
# [TODO] 补充具体任务类型 POLICY_MIGRATE = "policy_migrate"
pass """存储策略迁移"""
# ==================== DTO 模型 ==================== # ==================== DTO 模型 ====================
@@ -39,7 +39,7 @@ class TaskSummaryBase(SQLModelBase):
id: int id: int
"""任务ID""" """任务ID"""
type: int type: TaskType
"""任务类型""" """任务类型"""
status: TaskStatus status: TaskStatus
@@ -91,7 +91,14 @@ class TaskPropsBase(SQLModelBase):
file_ids: str | None = None file_ids: str | None = None
"""文件ID列表逗号分隔""" """文件ID列表逗号分隔"""
# [TODO] 根据业务需求补充更多字段 source_policy_id: UUID | None = None
"""源存储策略UUID"""
dest_policy_id: UUID | None = None
"""目标存储策略UUID"""
object_id: UUID | None = None
"""关联的对象UUID"""
class TaskProps(TaskPropsBase, TableBaseMixin): class TaskProps(TaskPropsBase, TableBaseMixin):
@@ -99,7 +106,7 @@ class TaskProps(TaskPropsBase, TableBaseMixin):
task_id: int = Field( task_id: int = Field(
foreign_key="task.id", foreign_key="task.id",
primary_key=True, unique=True,
ondelete="CASCADE" ondelete="CASCADE"
) )
"""关联的任务ID""" """关联的任务ID"""
@@ -121,8 +128,8 @@ class Task(SQLModelBase, TableBaseMixin):
status: TaskStatus = Field(default=TaskStatus.QUEUED) status: TaskStatus = Field(default=TaskStatus.QUEUED)
"""任务状态""" """任务状态"""
type: int = Field(default=0) type: TaskType
"""任务类型 [TODO] 待定义枚举""" """任务类型"""
progress: int = Field(default=0, ge=0, le=100) progress: int = Field(default=0, ge=0, le=100)
"""任务进度0-100""" """任务进度0-100"""

View File

@@ -3,7 +3,7 @@ from uuid import UUID
from sqlmodel import Field from sqlmodel import Field
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100
from .color import ChromaticColor, NeutralColor, ThemeColorsBase from .color import ChromaticColor, NeutralColor, ThemeColorsBase
@@ -11,7 +11,7 @@ from .color import ChromaticColor, NeutralColor, ThemeColorsBase
class ThemePresetBase(SQLModelBase): class ThemePresetBase(SQLModelBase):
"""主题预设基础字段""" """主题预设基础字段"""
name: str = Field(max_length=100) name: Str100
"""预设名称""" """预设名称"""
is_default: bool = False is_default: bool = False
@@ -42,7 +42,7 @@ class ThemePresetBase(SQLModelBase):
class ThemePreset(ThemePresetBase, UUIDTableBaseMixin): class ThemePreset(ThemePresetBase, UUIDTableBaseMixin):
"""主题预设表""" """主题预设表"""
name: str = Field(max_length=100, unique=True) name: Str100 = Field(unique=True)
"""预设名称(唯一约束)""" """预设名称(唯一约束)"""
@@ -51,7 +51,7 @@ class ThemePreset(ThemePresetBase, UUIDTableBaseMixin):
class ThemePresetCreateRequest(SQLModelBase): class ThemePresetCreateRequest(SQLModelBase):
"""创建主题预设请求 DTO""" """创建主题预设请求 DTO"""
name: str = Field(max_length=100) name: Str100
"""预设名称""" """预设名称"""
colors: ThemeColorsBase colors: ThemeColorsBase
@@ -61,7 +61,7 @@ class ThemePresetCreateRequest(SQLModelBase):
class ThemePresetUpdateRequest(SQLModelBase): class ThemePresetUpdateRequest(SQLModelBase):
"""更新主题预设请求 DTO""" """更新主题预设请求 DTO"""
name: str | None = Field(default=None, max_length=100) name: Str100 | None = None
"""预设名称(可选)""" """预设名称(可选)"""
colors: ThemeColorsBase | None = None colors: ThemeColorsBase | None = None

View File

@@ -4,12 +4,12 @@ from typing import Literal, TYPE_CHECKING, TypeVar
from uuid import UUID from uuid import UUID
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import BinaryExpression, ClauseElement, and_ from sqlalchemy import BigInteger, BinaryExpression, ClauseElement, and_
from sqlmodel import Field, Relationship from sqlmodel import Field, Relationship
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import RelationshipInfo from sqlmodel.main import RelationshipInfo
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableViewRequest, ListResponse from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableViewRequest, ListResponse, Str255
from .auth_identity import AuthProviderType from .auth_identity import AuthProviderType
from .color import ChromaticColor, NeutralColor, ThemeColorsBase from .color import ChromaticColor, NeutralColor, ThemeColorsBase
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
from .download import Download from .download import Download
from .object import Object from .object import Object
from .order import Order from .order import Order
from .redeem import Redeem
from .share import Share from .share import Share
from .storage_pack import StoragePack from .storage_pack import StoragePack
from .tag import Tag from .tag import Tag
@@ -473,10 +474,10 @@ class User(UserBase, UUIDTableBaseMixin):
status: UserStatus = UserStatus.ACTIVE status: UserStatus = UserStatus.ACTIVE
"""用户状态""" """用户状态"""
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0) storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"}, ge=0)
"""已用存储空间(字节)""" """已用存储空间(字节)"""
avatar: str = Field(default="default", max_length=255) avatar: Str255 = Field(default="default")
"""头像地址""" """头像地址"""
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0) score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
@@ -570,6 +571,14 @@ class User(UserBase, UUIDTableBaseMixin):
back_populates="user", back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"} sa_relationship_kwargs={"cascade": "all, delete-orphan"}
) )
redeems: list["Redeem"] = Relationship(
back_populates="user",
sa_relationship_kwargs={
"cascade": "all, delete-orphan",
"foreign_keys": "[Redeem.used_by]"
}
)
"""用户使用过的兑换码列表"""
shares: list["Share"] = Relationship( shares: list["Share"] = Relationship(
back_populates="user", back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"} sa_relationship_kwargs={"cascade": "all, delete-orphan"}

View File

@@ -5,7 +5,7 @@ from uuid import UUID
from sqlalchemy import Column, Text from sqlalchemy import Column, Text
from sqlmodel import Field, Relationship from sqlmodel import Field, Relationship
from sqlmodel_ext import SQLModelBase, TableBaseMixin from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str32, Str100, Str255
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import User from .user import User
@@ -51,7 +51,7 @@ class AuthnDetailResponse(SQLModelBase):
class AuthnRenameRequest(SQLModelBase): class AuthnRenameRequest(SQLModelBase):
"""WebAuthn 凭证重命名请求 DTO""" """WebAuthn 凭证重命名请求 DTO"""
name: str = Field(max_length=100) name: Str100
"""新的凭证名称""" """新的凭证名称"""
@@ -60,7 +60,7 @@ class AuthnRenameRequest(SQLModelBase):
class UserAuthn(SQLModelBase, TableBaseMixin): class UserAuthn(SQLModelBase, TableBaseMixin):
"""用户 WebAuthn 凭证模型,与 User 为多对一关系""" """用户 WebAuthn 凭证模型,与 User 为多对一关系"""
credential_id: str = Field(max_length=255, unique=True, index=True) credential_id: Str255 = Field(unique=True, index=True)
"""凭证 IDBase64URL 编码""" """凭证 IDBase64URL 编码"""
credential_public_key: str = Field(sa_column=Column(Text)) credential_public_key: str = Field(sa_column=Column(Text))
@@ -69,16 +69,16 @@ class UserAuthn(SQLModelBase, TableBaseMixin):
sign_count: int = Field(default=0, ge=0) sign_count: int = Field(default=0, ge=0)
"""签名计数器,用于防重放攻击""" """签名计数器,用于防重放攻击"""
credential_device_type: str = Field(max_length=32) credential_device_type: Str32
"""凭证设备类型:'single_device''multi_device'""" """凭证设备类型:'single_device''multi_device'"""
credential_backed_up: bool = Field(default=False) credential_backed_up: bool = Field(default=False)
"""凭证是否已备份""" """凭证是否已备份"""
transports: str | None = Field(default=None, max_length=255) transports: Str255 | None = None
"""支持的传输方式,逗号分隔,如 'usb,nfc,ble,internal'""" """支持的传输方式,逗号分隔,如 'usb,nfc,ble,internal'"""
name: str | None = Field(default=None, max_length=100) name: Str100 | None = None
"""用户自定义的凭证名称,便于识别""" """用户自定义的凭证名称,便于识别"""
# 外键 # 外键

View File

@@ -1,32 +1,117 @@
"""
WebDAV 账户模型
管理用户的 WebDAV 连接账户,每个账户对应一个挂载根路径。
通过 HTTP Basic Auth 认证访问 DAV 协议端点。
"""
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint from sqlmodel import Field, Relationship, UniqueConstraint
from sqlmodel_ext import SQLModelBase, TableBaseMixin from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import User from .user import User
class WebDAV(SQLModelBase, TableBaseMixin):
"""WebDAV账户模型""" # ==================== Base 模型 ====================
class WebDAVBase(SQLModelBase):
"""WebDAV 账户基础字段"""
name: Str255
"""账户名称(同一用户下唯一)"""
root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"})
"""挂载根目录路径"""
readonly: bool = Field(default=False, sa_column_kwargs={"server_default": "false"})
"""是否只读"""
use_proxy: bool = Field(default=False, sa_column_kwargs={"server_default": "false"})
"""是否使用代理下载"""
# ==================== 数据库模型 ====================
class WebDAV(WebDAVBase, TableBaseMixin):
"""WebDAV 账户模型"""
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),) __table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),)
name: str = Field(max_length=255, description="WebDAV账户名") password: Str255
password: str = Field(max_length=255, description="WebDAV密码") """密码Argon2 哈希)"""
root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"}, description="根目录路径")
readonly: bool = Field(default=False, description="是否只读")
use_proxy: bool = Field(default=False, description="是否使用代理下载")
# 外键 # 外键
user_id: UUID = Field( user_id: UUID = Field(
foreign_key="user.id", foreign_key="user.id",
index=True, index=True,
ondelete="CASCADE" ondelete="CASCADE",
) )
"""所属用户UUID""" """所属用户UUID"""
# 关系 # 关系
user: "User" = Relationship(back_populates="webdavs") user: "User" = Relationship(back_populates="webdavs")
# ==================== DTO 模型 ====================
class WebDAVCreateRequest(SQLModelBase):
"""创建 WebDAV 账户请求"""
name: Str255
"""账户名称"""
password: Str255 = Field(min_length=1)
"""账户密码(明文,服务端哈希后存储)"""
root: str = "/"
"""挂载根目录路径"""
readonly: bool = False
"""是否只读"""
use_proxy: bool = False
"""是否使用代理下载"""
class WebDAVUpdateRequest(SQLModelBase):
"""更新 WebDAV 账户请求"""
password: Str255 | None = Field(default=None, min_length=1)
"""新密码(为 None 时不修改)"""
root: str | None = None
"""新挂载根目录路径(为 None 时不修改)"""
readonly: bool | None = None
"""是否只读(为 None 时不修改)"""
use_proxy: bool | None = None
"""是否使用代理下载(为 None 时不修改)"""
class WebDAVAccountResponse(SQLModelBase):
"""WebDAV 账户响应"""
id: int
"""账户ID"""
name: str
"""账户名称"""
root: str
"""挂载根目录路径"""
readonly: bool
"""是否只读"""
use_proxy: bool
"""是否使用代理下载"""
created_at: str
"""创建时间"""
updated_at: str
"""更新时间"""

View File

@@ -92,9 +92,9 @@ class ObjectFactory:
owner_id=owner_id, owner_id=owner_id,
policy_id=policy_id, policy_id=policy_id,
size=size, size=size,
mime_type=kwargs.get("mime_type"),
source_name=kwargs.get("source_name", name), source_name=kwargs.get("source_name", name),
upload_session_id=kwargs.get("upload_session_id"), upload_session_id=kwargs.get("upload_session_id"),
file_metadata=kwargs.get("file_metadata"),
password=kwargs.get("password"), password=kwargs.get("password"),
) )

View File

@@ -71,7 +71,7 @@ class UserFactory:
is_verified=True, is_verified=True,
user_id=user.id, user_id=user.id,
) )
await identity.save(session) identity = await identity.save(session)
return user return user
@@ -123,7 +123,7 @@ class UserFactory:
is_verified=True, is_verified=True,
user_id=admin.id, user_id=admin.id,
) )
await identity.save(session) identity = await identity.save(session)
return admin return admin
@@ -170,7 +170,7 @@ class UserFactory:
is_verified=True, is_verified=True,
user_id=banned_user.id, user_id=banned_user.id,
) )
await identity.save(session) identity = await identity.save(session)
return banned_user return banned_user
@@ -219,6 +219,6 @@ class UserFactory:
is_verified=True, is_verified=True,
user_id=user.id, user_id=user.id,
) )
await identity.save(session) identity = await identity.save(session)
return user return user

View File

@@ -0,0 +1,219 @@
"""
自定义属性定义端点集成测试
"""
import pytest
from httpx import AsyncClient
from uuid import UUID, uuid4
# ==================== 获取属性定义列表测试 ====================
@pytest.mark.asyncio
async def test_list_custom_properties_requires_auth(async_client: AsyncClient):
"""测试获取属性定义需要认证"""
response = await async_client.get("/api/v1/object/custom_property")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_list_custom_properties_empty(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试获取空的属性定义列表"""
response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data == []
# ==================== 创建属性定义测试 ====================
@pytest.mark.asyncio
async def test_create_custom_property(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试创建自定义属性"""
response = await async_client.post(
"/api/v1/object/custom_property",
headers=auth_headers,
json={
"name": "评分",
"type": "rating",
"icon": "mdi:star",
},
)
assert response.status_code == 204
# 验证已创建
list_response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
data = list_response.json()
assert len(data) == 1
assert data[0]["name"] == "评分"
assert data[0]["type"] == "rating"
assert data[0]["icon"] == "mdi:star"
@pytest.mark.asyncio
async def test_create_custom_property_with_options(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试创建带选项的自定义属性"""
response = await async_client.post(
"/api/v1/object/custom_property",
headers=auth_headers,
json={
"name": "分类",
"type": "select",
"options": ["工作", "个人", "归档"],
"default_value": "个人",
},
)
assert response.status_code == 204
list_response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
data = list_response.json()
prop = next(p for p in data if p["name"] == "分类")
assert prop["type"] == "select"
assert prop["options"] == ["工作", "个人", "归档"]
assert prop["default_value"] == "个人"
@pytest.mark.asyncio
async def test_create_custom_property_duplicate_name(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试创建同名属性返回 409"""
# 先创建
await async_client.post(
"/api/v1/object/custom_property",
headers=auth_headers,
json={"name": "标签", "type": "text"},
)
# 再创建同名
response = await async_client.post(
"/api/v1/object/custom_property",
headers=auth_headers,
json={"name": "标签", "type": "text"},
)
assert response.status_code == 409
# ==================== 更新属性定义测试 ====================
@pytest.mark.asyncio
async def test_update_custom_property(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试更新自定义属性"""
# 先创建
await async_client.post(
"/api/v1/object/custom_property",
headers=auth_headers,
json={"name": "备注", "type": "text"},
)
# 获取 ID
list_response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
prop_id = next(p["id"] for p in list_response.json() if p["name"] == "备注")
# 更新
response = await async_client.patch(
f"/api/v1/object/custom_property/{prop_id}",
headers=auth_headers,
json={"name": "详细备注", "icon": "mdi:note"},
)
assert response.status_code == 204
# 验证已更新
list_response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
prop = next(p for p in list_response.json() if p["id"] == prop_id)
assert prop["name"] == "详细备注"
assert prop["icon"] == "mdi:note"
@pytest.mark.asyncio
async def test_update_custom_property_not_found(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试更新不存在的属性返回 404"""
fake_id = str(uuid4())
response = await async_client.patch(
f"/api/v1/object/custom_property/{fake_id}",
headers=auth_headers,
json={"name": "不存在"},
)
assert response.status_code == 404
# ==================== 删除属性定义测试 ====================
@pytest.mark.asyncio
async def test_delete_custom_property(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试删除自定义属性"""
# 先创建
await async_client.post(
"/api/v1/object/custom_property",
headers=auth_headers,
json={"name": "待删除", "type": "text"},
)
# 获取 ID
list_response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
prop_id = next(p["id"] for p in list_response.json() if p["name"] == "待删除")
# 删除
response = await async_client.delete(
f"/api/v1/object/custom_property/{prop_id}",
headers=auth_headers,
)
assert response.status_code == 204
# 验证已删除
list_response = await async_client.get(
"/api/v1/object/custom_property",
headers=auth_headers,
)
prop_names = [p["name"] for p in list_response.json()]
assert "待删除" not in prop_names
@pytest.mark.asyncio
async def test_delete_custom_property_not_found(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试删除不存在的属性返回 404"""
fake_id = str(uuid4())
response = await async_client.delete(
f"/api/v1/object/custom_property/{fake_id}",
headers=auth_headers,
)
assert response.status_code == 404

View File

@@ -0,0 +1,239 @@
"""
对象元数据端点集成测试
"""
import pytest
from httpx import AsyncClient
from uuid import UUID, uuid4
from sqlmodels import ObjectMetadata
# ==================== 获取元数据测试 ====================
@pytest.mark.asyncio
async def test_get_metadata_requires_auth(async_client: AsyncClient):
"""测试获取元数据需要认证"""
fake_id = str(uuid4())
response = await async_client.get(f"/api/v1/object/{fake_id}/metadata")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_get_metadata_empty(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试获取无元数据的对象"""
file_id = test_directory_structure["file_id"]
response = await async_client.get(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["metadatas"] == {}
@pytest.mark.asyncio
async def test_get_metadata_with_entries(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
initialized_db,
):
"""测试获取有元数据的对象"""
file_id = test_directory_structure["file_id"]
# 直接写入元数据
entries = [
ObjectMetadata(object_id=file_id, name="exif:width", value="1920", is_public=True),
ObjectMetadata(object_id=file_id, name="exif:height", value="1080", is_public=True),
ObjectMetadata(object_id=file_id, name="sys:extract_status", value="done", is_public=False),
]
for entry in entries:
initialized_db.add(entry)
await initialized_db.commit()
response = await async_client.get(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
# sys: 命名空间应被过滤
assert "exif:width" in data["metadatas"]
assert "exif:height" in data["metadatas"]
assert "sys:extract_status" not in data["metadatas"]
assert data["metadatas"]["exif:width"] == "1920"
@pytest.mark.asyncio
async def test_get_metadata_ns_filter(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
initialized_db,
):
"""测试按命名空间过滤元数据"""
file_id = test_directory_structure["file_id"]
entries = [
ObjectMetadata(object_id=file_id, name="exif:width", value="1920", is_public=True),
ObjectMetadata(object_id=file_id, name="music:title", value="Test Song", is_public=True),
]
for entry in entries:
initialized_db.add(entry)
await initialized_db.commit()
# 只获取 exif 命名空间
response = await async_client.get(
f"/api/v1/object/{file_id}/metadata?ns=exif",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert "exif:width" in data["metadatas"]
assert "music:title" not in data["metadatas"]
@pytest.mark.asyncio
async def test_get_metadata_nonexistent_object(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试获取不存在对象的元数据"""
fake_id = str(uuid4())
response = await async_client.get(
f"/api/v1/object/{fake_id}/metadata",
headers=auth_headers,
)
assert response.status_code == 404
# ==================== 更新元数据测试 ====================
@pytest.mark.asyncio
async def test_patch_metadata_requires_auth(async_client: AsyncClient):
"""测试更新元数据需要认证"""
fake_id = str(uuid4())
response = await async_client.patch(
f"/api/v1/object/{fake_id}/metadata",
json={"patches": [{"key": "custom:tag", "value": "test"}]},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_patch_metadata_set_custom(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试设置自定义元数据"""
file_id = test_directory_structure["file_id"]
response = await async_client.patch(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
json={
"patches": [
{"key": "custom:tag1", "value": "旅游"},
{"key": "custom:tag2", "value": "风景"},
]
},
)
assert response.status_code == 204
# 验证已写入
get_response = await async_client.get(
f"/api/v1/object/{file_id}/metadata?ns=custom",
headers=auth_headers,
)
assert get_response.status_code == 200
data = get_response.json()
assert data["metadatas"]["custom:tag1"] == "旅游"
assert data["metadatas"]["custom:tag2"] == "风景"
@pytest.mark.asyncio
async def test_patch_metadata_update_existing(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试更新已有的元数据"""
file_id = test_directory_structure["file_id"]
# 先创建
await async_client.patch(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
json={"patches": [{"key": "custom:note", "value": "旧值"}]},
)
# 再更新
response = await async_client.patch(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
json={"patches": [{"key": "custom:note", "value": "新值"}]},
)
assert response.status_code == 204
# 验证已更新
get_response = await async_client.get(
f"/api/v1/object/{file_id}/metadata?ns=custom",
headers=auth_headers,
)
data = get_response.json()
assert data["metadatas"]["custom:note"] == "新值"
@pytest.mark.asyncio
async def test_patch_metadata_delete(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试删除元数据条目"""
file_id = test_directory_structure["file_id"]
# 先创建
await async_client.patch(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
json={"patches": [{"key": "custom:to_delete", "value": "temp"}]},
)
# 删除value 为 null
response = await async_client.patch(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
json={"patches": [{"key": "custom:to_delete", "value": None}]},
)
assert response.status_code == 204
# 验证已删除
get_response = await async_client.get(
f"/api/v1/object/{file_id}/metadata?ns=custom",
headers=auth_headers,
)
data = get_response.json()
assert "custom:to_delete" not in data["metadatas"]
@pytest.mark.asyncio
async def test_patch_metadata_reject_non_custom_namespace(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试拒绝修改非 custom: 命名空间"""
file_id = test_directory_structure["file_id"]
response = await async_client.patch(
f"/api/v1/object/{file_id}/metadata",
headers=auth_headers,
json={"patches": [{"key": "exif:width", "value": "1920"}]},
)
assert response.status_code == 400

View File

@@ -0,0 +1,591 @@
"""
WebDAV 账户管理端点集成测试
"""
from uuid import UUID, uuid4
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, User
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.user import UserStatus
from utils import Password
from utils.JWT import create_access_token
API_PREFIX = "/api/v1/webdav"
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def no_webdav_headers(initialized_db: AsyncSession) -> dict[str, str]:
"""创建一个 WebDAV 被禁用的用户,返回其认证头"""
group = Group(
id=uuid4(),
name="无WebDAV用户组",
max_storage=1024 * 1024 * 1024,
share_enabled=True,
web_dav_enabled=False,
admin=False,
speed_limit=0,
)
initialized_db.add(group)
await initialized_db.commit()
await initialized_db.refresh(group)
group_options = GroupOptions(
group_id=group.id,
share_download=True,
share_free=False,
relocate=False,
source_batch=0,
select_node=False,
advance_delete=False,
)
initialized_db.add(group_options)
await initialized_db.commit()
await initialized_db.refresh(group_options)
user = User(
id=uuid4(),
email="nowebdav@test.local",
nickname="无WebDAV用户",
status=UserStatus.ACTIVE,
storage=0,
score=0,
group_id=group.id,
avatar="default",
)
initialized_db.add(user)
await initialized_db.commit()
await initialized_db.refresh(user)
identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier="nowebdav@test.local",
credential=Password.hash("nowebdav123"),
is_primary=True,
is_verified=True,
user_id=user.id,
)
initialized_db.add(identity)
from sqlmodels import Policy
policy = await Policy.get(initialized_db, Policy.name == "本地存储")
root = Object(
id=uuid4(),
name="/",
type=ObjectType.FOLDER,
owner_id=user.id,
parent_id=None,
policy_id=policy.id,
size=0,
)
initialized_db.add(root)
await initialized_db.commit()
group.options = group_options
group_claims = GroupClaims.from_group(group)
result = create_access_token(
sub=user.id,
jti=uuid4(),
status=user.status.value,
group=group_claims,
)
return {"Authorization": f"Bearer {result.access_token}"}
# ==================== 认证测试 ====================
@pytest.mark.asyncio
async def test_list_accounts_requires_auth(async_client: AsyncClient):
"""测试获取账户列表需要认证"""
response = await async_client.get(f"{API_PREFIX}/accounts")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_create_account_requires_auth(async_client: AsyncClient):
"""测试创建账户需要认证"""
response = await async_client.post(
f"{API_PREFIX}/accounts",
json={"name": "test", "password": "testpass"},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_update_account_requires_auth(async_client: AsyncClient):
"""测试更新账户需要认证"""
response = await async_client.patch(
f"{API_PREFIX}/accounts/1",
json={"readonly": True},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_delete_account_requires_auth(async_client: AsyncClient):
"""测试删除账户需要认证"""
response = await async_client.delete(f"{API_PREFIX}/accounts/1")
assert response.status_code == 401
# ==================== WebDAV 禁用测试 ====================
@pytest.mark.asyncio
async def test_list_accounts_webdav_disabled(
async_client: AsyncClient,
no_webdav_headers: dict[str, str],
):
"""测试 WebDAV 被禁用时返回 403"""
response = await async_client.get(
f"{API_PREFIX}/accounts",
headers=no_webdav_headers,
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_create_account_webdav_disabled(
async_client: AsyncClient,
no_webdav_headers: dict[str, str],
):
"""测试 WebDAV 被禁用时创建账户返回 403"""
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=no_webdav_headers,
json={"name": "test", "password": "testpass"},
)
assert response.status_code == 403
# ==================== 获取账户列表测试 ====================
@pytest.mark.asyncio
async def test_list_accounts_empty(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试初始状态账户列表为空"""
response = await async_client.get(
f"{API_PREFIX}/accounts",
headers=auth_headers,
)
assert response.status_code == 200
assert response.json() == []
# ==================== 创建账户测试 ====================
@pytest.mark.asyncio
async def test_create_account_success(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试成功创建 WebDAV 账户"""
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "my-nas", "password": "secretpass"},
)
assert response.status_code == 201
data = response.json()
assert data["name"] == "my-nas"
assert data["root"] == "/"
assert data["readonly"] is False
assert data["use_proxy"] is False
assert "id" in data
assert "created_at" in data
assert "updated_at" in data
@pytest.mark.asyncio
async def test_create_account_with_options(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试创建带选项的 WebDAV 账户"""
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={
"name": "readonly-nas",
"password": "secretpass",
"readonly": True,
"use_proxy": True,
},
)
assert response.status_code == 201
data = response.json()
assert data["name"] == "readonly-nas"
assert data["readonly"] is True
assert data["use_proxy"] is True
@pytest.mark.asyncio
async def test_create_account_duplicate_name(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试重名账户返回 409"""
# 先创建一个
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "dup-test", "password": "pass1"},
)
assert response.status_code == 201
# 再创建同名的
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "dup-test", "password": "pass2"},
)
assert response.status_code == 409
@pytest.mark.asyncio
async def test_create_account_invalid_root(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试无效根目录路径返回 400"""
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={
"name": "bad-root",
"password": "secretpass",
"root": "/nonexistent/path",
},
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_create_account_with_valid_subdir(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试使用有效的子目录作为根路径"""
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={
"name": "docs-only",
"password": "secretpass",
"root": "/docs",
},
)
assert response.status_code == 201
assert response.json()["root"] == "/docs"
# ==================== 列表包含已创建账户测试 ====================
@pytest.mark.asyncio
async def test_list_accounts_after_create(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试创建后列表中包含该账户"""
# 创建
await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "list-test", "password": "pass"},
)
# 列表
response = await async_client.get(
f"{API_PREFIX}/accounts",
headers=auth_headers,
)
assert response.status_code == 200
accounts = response.json()
assert len(accounts) == 1
assert accounts[0]["name"] == "list-test"
# ==================== 更新账户测试 ====================
@pytest.mark.asyncio
async def test_update_account_success(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试成功更新 WebDAV 账户"""
# 创建
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "update-test", "password": "oldpass"},
)
account_id = create_resp.json()["id"]
# 更新
response = await async_client.patch(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
json={"readonly": True},
)
assert response.status_code == 200
data = response.json()
assert data["readonly"] is True
assert data["name"] == "update-test"
@pytest.mark.asyncio
async def test_update_account_password(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试更新密码"""
# 创建
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "pwd-test", "password": "oldpass"},
)
account_id = create_resp.json()["id"]
# 更新密码
response = await async_client.patch(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
json={"password": "newpass123"},
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_update_account_root(
async_client: AsyncClient,
auth_headers: dict[str, str],
test_directory_structure: dict[str, UUID],
):
"""测试更新根目录路径"""
# 创建
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "root-update", "password": "pass"},
)
account_id = create_resp.json()["id"]
# 更新 root 到有效子目录
response = await async_client.patch(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
json={"root": "/docs"},
)
assert response.status_code == 200
assert response.json()["root"] == "/docs"
@pytest.mark.asyncio
async def test_update_account_invalid_root(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试更新为无效根目录返回 400"""
# 创建
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "bad-root-update", "password": "pass"},
)
account_id = create_resp.json()["id"]
# 更新到无效路径
response = await async_client.patch(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
json={"root": "/nonexistent"},
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_update_account_not_found(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试更新不存在的账户返回 404"""
response = await async_client.patch(
f"{API_PREFIX}/accounts/99999",
headers=auth_headers,
json={"readonly": True},
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_update_other_user_account(
async_client: AsyncClient,
auth_headers: dict[str, str],
admin_headers: dict[str, str],
):
"""测试更新其他用户的账户返回 404"""
# 管理员创建账户
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=admin_headers,
json={"name": "admin-account", "password": "pass"},
)
account_id = create_resp.json()["id"]
# 普通用户尝试更新
response = await async_client.patch(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
json={"readonly": True},
)
assert response.status_code == 404
# ==================== 删除账户测试 ====================
@pytest.mark.asyncio
async def test_delete_account_success(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试成功删除 WebDAV 账户"""
# 创建
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "delete-test", "password": "pass"},
)
account_id = create_resp.json()["id"]
# 删除
response = await async_client.delete(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
)
assert response.status_code == 204
# 确认列表中已不存在
list_resp = await async_client.get(
f"{API_PREFIX}/accounts",
headers=auth_headers,
)
assert list_resp.status_code == 200
names = [a["name"] for a in list_resp.json()]
assert "delete-test" not in names
@pytest.mark.asyncio
async def test_delete_account_not_found(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试删除不存在的账户返回 404"""
response = await async_client.delete(
f"{API_PREFIX}/accounts/99999",
headers=auth_headers,
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_delete_other_user_account(
async_client: AsyncClient,
auth_headers: dict[str, str],
admin_headers: dict[str, str],
):
"""测试删除其他用户的账户返回 404"""
# 管理员创建账户
create_resp = await async_client.post(
f"{API_PREFIX}/accounts",
headers=admin_headers,
json={"name": "admin-del-test", "password": "pass"},
)
account_id = create_resp.json()["id"]
# 普通用户尝试删除
response = await async_client.delete(
f"{API_PREFIX}/accounts/{account_id}",
headers=auth_headers,
)
assert response.status_code == 404
# ==================== 多账户测试 ====================
@pytest.mark.asyncio
async def test_multiple_accounts(
async_client: AsyncClient,
auth_headers: dict[str, str],
):
"""测试同一用户可以创建多个账户"""
for name in ["account-1", "account-2", "account-3"]:
response = await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": name, "password": "pass"},
)
assert response.status_code == 201
# 列表应有3个
response = await async_client.get(
f"{API_PREFIX}/accounts",
headers=auth_headers,
)
assert response.status_code == 200
assert len(response.json()) == 3
# ==================== 用户隔离测试 ====================
@pytest.mark.asyncio
async def test_accounts_user_isolation(
async_client: AsyncClient,
auth_headers: dict[str, str],
admin_headers: dict[str, str],
):
"""测试不同用户的账户相互隔离"""
# 普通用户创建
await async_client.post(
f"{API_PREFIX}/accounts",
headers=auth_headers,
json={"name": "user-account", "password": "pass"},
)
# 管理员创建
await async_client.post(
f"{API_PREFIX}/accounts",
headers=admin_headers,
json={"name": "admin-account", "password": "pass"},
)
# 普通用户只看到自己的
response = await async_client.get(
f"{API_PREFIX}/accounts",
headers=auth_headers,
)
assert response.status_code == 200
accounts = response.json()
assert len(accounts) == 1
assert accounts[0]["name"] == "user-account"
# 管理员只看到自己的
response = await async_client.get(
f"{API_PREFIX}/accounts",
headers=admin_headers,
)
assert response.status_code == 200
accounts = response.json()
assert len(accounts) == 1
assert accounts[0]["name"] == "admin-account"

View File

@@ -23,6 +23,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../.
from main import app from main import app
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
from sqlmodels.policy import GroupPolicyLink
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.user import UserStatus from sqlmodels.user import UserStatus
from utils import Password from utils import Password
@@ -108,6 +109,12 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
Setting(type=SettingsType.AUTH, name="auth_email_binding_required", value="1"), Setting(type=SettingsType.AUTH, name="auth_email_binding_required", value="1"),
Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"), Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"),
Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"), Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"),
Setting(type=SettingsType.AVATAR, name="gravatar_server", value="https://www.gravatar.com/"),
Setting(type=SettingsType.AVATAR, name="avatar_size", value="2097152"),
Setting(type=SettingsType.AVATAR, name="avatar_size_l", value="200"),
Setting(type=SettingsType.AVATAR, name="avatar_size_m", value="130"),
Setting(type=SettingsType.AVATAR, name="avatar_size_s", value="50"),
Setting(type=SettingsType.PATH, name="avatar_path", value="avatar"),
] ]
for setting in settings: for setting in settings:
test_session.add(setting) test_session.add(setting)
@@ -156,7 +163,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
await test_session.refresh(admin_group) await test_session.refresh(admin_group)
await test_session.refresh(default_policy) await test_session.refresh(default_policy)
# 4. 创建用户组选项 # 4. 关联用户组与存储策略
test_session.add(GroupPolicyLink(group_id=default_group.id, policy_id=default_policy.id))
test_session.add(GroupPolicyLink(group_id=admin_group.id, policy_id=default_policy.id))
# 5. 创建用户组选项
default_group_options = GroupOptions( default_group_options = GroupOptions(
group_id=default_group.id, group_id=default_group.id,
share_download=True, share_download=True,

View File

@@ -37,6 +37,12 @@ async def load_secret_key() -> None:
if setting: if setting:
SECRET_KEY = setting.value SECRET_KEY = setting.value
if not SECRET_KEY:
raise RuntimeError(
"JWT SECRET_KEY 未配置,拒绝启动。"
"请在 Setting 表中添加 type='auth', name='secret_key' 的记录。"
)
def build_token_payload( def build_token_payload(
data: dict, data: dict,

View File

@@ -62,6 +62,10 @@ def raise_not_implemented(detail: str = "尚未支持这种方法") -> NoReturn:
"""Raises an HTTP 501 Not Implemented exception.""" """Raises an HTTP 501 Not Implemented exception."""
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=detail) raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=detail)
def raise_bad_gateway(detail: str | None = None) -> NoReturn:
"""Raises an HTTP 502 Bad Gateway exception."""
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=detail)
def raise_service_unavailable(detail: str | None = None) -> NoReturn: def raise_service_unavailable(detail: str | None = None) -> NoReturn:
"""Raises an HTTP 503 Service Unavailable exception.""" """Raises an HTTP 503 Service Unavailable exception."""
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=detail) raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=detail)

997
uv.lock generated

File diff suppressed because it is too large Load Diff