Compare commits
16 Commits
1ecc0fdc1c
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 15b2efe52a | |||
| 6c96c43bea | |||
| 9185f26b83 | |||
| f4052d229a | |||
| bc2182720d | |||
| eddf38d316 | |||
| 03e768d232 | |||
| bcb0a9b322 | |||
| 743a2c9d65 | |||
| 3639a31163 | |||
| 7200df6d87 | |||
| 40b6a31c98 | |||
| 19837b4817 | |||
| b5d09009e3 | |||
| 0b521ae8ab | |||
| eac0766e79 |
@@ -5,7 +5,8 @@
|
||||
"Bash(findstr:*)",
|
||||
"Bash(find:*)",
|
||||
"Bash(yarn tsc:*)",
|
||||
"Bash(dir:*)"
|
||||
"Bash(dir:*)",
|
||||
"mcp__server-notify__notify"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1,8 +1,6 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.so
|
||||
*.egg
|
||||
*.egg-info/
|
||||
@@ -79,3 +77,6 @@ statics/
|
||||
# 许可证密钥(保密)
|
||||
license_private.pem
|
||||
license.key
|
||||
|
||||
avatar/
|
||||
.dev/
|
||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "ee"]
|
||||
path = ee
|
||||
url = https://git.yxqi.cn/Yuerchu/disknext-ee.git
|
||||
1
ee
Submodule
1
ee
Submodule
Submodule ee added at cc32d8db91
@@ -1,42 +0,0 @@
|
||||
"""
|
||||
DiskNext Enterprise Edition (EE) 模块
|
||||
|
||||
通过 ``try: from ee import init_ee`` 检测是否存在。
|
||||
CE 版本中此目录不存在,ImportError 被 main.py 捕获。
|
||||
"""
|
||||
from loguru import logger as l
|
||||
|
||||
from utils.conf import appmeta
|
||||
|
||||
_ee_initialized: bool = False
|
||||
|
||||
|
||||
def is_pro() -> bool:
|
||||
"""当前实例是否以 Pro 版本运行。"""
|
||||
return _ee_initialized
|
||||
|
||||
|
||||
async def init_ee() -> None:
|
||||
"""
|
||||
初始化企业版功能。
|
||||
|
||||
1. 加载并验证许可证
|
||||
2. 设置 appmeta.IsPro = True
|
||||
3. 标记 EE 已初始化
|
||||
|
||||
许可证无效或缺失时抛出异常,阻止应用启动。
|
||||
"""
|
||||
global _ee_initialized
|
||||
|
||||
from ee.service.license_service import load_and_validate_license
|
||||
|
||||
payload = await load_and_validate_license()
|
||||
|
||||
appmeta.IsPro = True
|
||||
_ee_initialized = True
|
||||
|
||||
l.info(
|
||||
f"Pro 版本已激活 — 域名: {payload.domain}, "
|
||||
f"过期: {payload.expires_at.isoformat()}, "
|
||||
f"功能: {payload.features}"
|
||||
)
|
||||
@@ -1,86 +0,0 @@
|
||||
"""
|
||||
RSA-PSS 许可证验签核心(编译为 .so 后公钥藏入二进制)
|
||||
|
||||
此文件只包含纯函数和常量,不包含 SQLModel 类。
|
||||
"""
|
||||
import base64
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import orjson
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
|
||||
|
||||
_PUBLIC_KEY_PEM: bytes = b"""-----BEGIN PUBLIC KEY-----
|
||||
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAyNltXQ/Nuechx3kjj3T5
|
||||
oR6pZvTmpsDowqqxXJy7FXUI8d7XprhV+HrBQPsrT/Ngo9FwW3XyiK10m1WrzpGW
|
||||
eaf9990Z5Z2naEn5TzGrh71p/D7mZcNGVumo9uAuhtNEemm6xB3FoyGYZj7X0cwA
|
||||
VDvIiKAwYyRJX2LqVh1/tZM6tTO3oaGZXRMZzCNUPFSo4ZZudU3Boa5oQg08evu4
|
||||
vaOqeFrMX47R3MSUmO9hOh+NS53XNqO0f0zw5sv95CtyR5qvJ4gpkgYaRCSQFd19
|
||||
TnHU5saFVrH9jdADz1tdkMYcyYE+uJActZBapxCHSYB2tSCKWjDxeUFl/oY/ZFtY
|
||||
l4MNz1ovkjNhpmR3g+I5fbvN0cxDIjnZ9vJ84ozGqTGT9s1jHaLbpLri/vhuT4F2
|
||||
7kifXk8ImwtMZpZvzhmucH9/5VgcWKNuMATzEMif+YjFpuOGx8gc1XL1W/3q+dH0
|
||||
EFESp+/knjcVIfwpAkIKyV7XvDgFHsif1SeI0zZMW4utowVvGocP1ZzK5BGNTk2z
|
||||
CEtQDO7Rqo+UDckOJSG66VW3c2QO8o6uuy6fzx7q0MFEmUMwGf2iMVtR/KnXe99C
|
||||
enOT0BpU1EQvqssErUqivDss7jm98iD8M/TCE7pFboqZ+SC9G+QAqNIQNFWh8bWA
|
||||
R9hyXM/x5ysHd6MC4eEQnhMCAwEAAQ==
|
||||
-----END PUBLIC KEY-----"""
|
||||
|
||||
|
||||
class LicenseError(Exception):
|
||||
"""许可证验证基础异常"""
|
||||
|
||||
|
||||
class LicenseExpiredError(LicenseError):
|
||||
"""许可证已过期"""
|
||||
|
||||
|
||||
def verify_license(raw: str) -> dict:
|
||||
"""
|
||||
验证许可证字符串并返回载荷字典。
|
||||
|
||||
:param raw: 格式为 ``base64(json_payload).base64(signature)``
|
||||
:returns: 解析后的载荷字典
|
||||
:raises LicenseError: 格式无效或签名验证失败
|
||||
:raises LicenseExpiredError: 许可证已过期
|
||||
"""
|
||||
parts = raw.strip().split(".")
|
||||
if len(parts) != 2:
|
||||
raise LicenseError("许可证格式无效:需要 payload.signature")
|
||||
|
||||
payload_b64, signature_b64 = parts
|
||||
|
||||
try:
|
||||
payload_bytes = base64.urlsafe_b64decode(payload_b64)
|
||||
signature = base64.urlsafe_b64decode(signature_b64)
|
||||
except Exception as exc:
|
||||
raise LicenseError(f"许可证 base64 解码失败: {exc}") from exc
|
||||
|
||||
public_key = serialization.load_pem_public_key(_PUBLIC_KEY_PEM)
|
||||
|
||||
try:
|
||||
public_key.verify( # type: ignore[union-attr]
|
||||
signature,
|
||||
payload_bytes,
|
||||
padding.PSS(
|
||||
mgf=padding.MGF1(hashes.SHA256()),
|
||||
salt_length=padding.PSS.MAX_LENGTH,
|
||||
),
|
||||
hashes.SHA256(),
|
||||
)
|
||||
except Exception as exc:
|
||||
raise LicenseError(f"许可证签名验证失败: {exc}") from exc
|
||||
|
||||
data: dict = orjson.loads(payload_bytes)
|
||||
|
||||
expires_at_str: str | None = data.get('expires_at')
|
||||
if not expires_at_str:
|
||||
raise LicenseError("许可证缺少 expires_at 字段")
|
||||
|
||||
expires_at = datetime.fromisoformat(expires_at_str)
|
||||
if expires_at < datetime.now(timezone.utc):
|
||||
raise LicenseExpiredError(
|
||||
f"许可证已过期: {expires_at.isoformat()}"
|
||||
)
|
||||
|
||||
return data
|
||||
@@ -1,6 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .pro import router as pro_router
|
||||
|
||||
ee_router = APIRouter()
|
||||
ee_router.include_router(pro_router)
|
||||
@@ -1,69 +0,0 @@
|
||||
"""
|
||||
Pro 版本状态端点
|
||||
|
||||
提供许可证状态查询,需管理员权限。
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from ee.service import LicensePayload
|
||||
from ee.service.license_service import get_cached_license
|
||||
from middleware.auth import admin_required
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
|
||||
router = APIRouter(prefix="/pro")
|
||||
|
||||
|
||||
class ProStatusResponse(SQLModelBase):
|
||||
"""Pro 版本状态响应"""
|
||||
|
||||
is_active: bool
|
||||
"""许可证是否有效"""
|
||||
|
||||
domain: str
|
||||
"""授权域名"""
|
||||
|
||||
expires_at: datetime
|
||||
"""过期时间"""
|
||||
|
||||
max_users: int
|
||||
"""最大用户数(0 = 无限制)"""
|
||||
|
||||
features: list[str]
|
||||
"""已授权的功能列表"""
|
||||
|
||||
|
||||
@router.get(
|
||||
'/status',
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def get_pro_status() -> ProStatusResponse:
|
||||
"""
|
||||
查询 Pro 版本许可证状态
|
||||
|
||||
认证:
|
||||
- JWT token in Authorization header
|
||||
- 需要管理员权限
|
||||
|
||||
响应:
|
||||
- ProStatusResponse: 当前许可证信息
|
||||
|
||||
错误处理:
|
||||
- HTTPException 401: 未授权
|
||||
- HTTPException 403: 非管理员
|
||||
- HTTPException 500: 许可证缓存异常
|
||||
"""
|
||||
payload: LicensePayload | None = get_cached_license()
|
||||
if not payload:
|
||||
# init_ee 成功后 payload 一定存在,此处做防御性编程
|
||||
from utils.http.http_exceptions import raise_internal_error
|
||||
raise_internal_error()
|
||||
|
||||
return ProStatusResponse(
|
||||
is_active=True,
|
||||
domain=payload.domain,
|
||||
expires_at=payload.expires_at,
|
||||
max_users=payload.max_users,
|
||||
features=payload.features,
|
||||
)
|
||||
@@ -1,30 +0,0 @@
|
||||
"""
|
||||
EE 许可证服务模块
|
||||
|
||||
提供许可证加载、验证和缓存功能。
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
|
||||
|
||||
class LicensePayload(SQLModelBase):
|
||||
"""许可证载荷(RSA 验签后的明文数据)"""
|
||||
|
||||
domain: str
|
||||
"""授权域名"""
|
||||
|
||||
expires_at: datetime
|
||||
"""过期时间(UTC)"""
|
||||
|
||||
max_users: int
|
||||
"""最大用户数(0 = 无限制)"""
|
||||
|
||||
features: list[str]
|
||||
"""已授权的功能列表"""
|
||||
|
||||
issued_at: datetime
|
||||
"""签发时间(UTC)"""
|
||||
|
||||
|
||||
from .license_service import get_cached_license, load_and_validate_license
|
||||
@@ -1,49 +0,0 @@
|
||||
"""
|
||||
许可证加载与缓存服务(编译为 .so)
|
||||
|
||||
从环境变量 LICENSE_KEY 或 license.key 文件加载许可证,
|
||||
调用 verify_license() 验证后缓存结果。
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
|
||||
from ee.license import LicenseError, verify_license
|
||||
from ee.service import LicensePayload
|
||||
|
||||
_cached_payload: LicensePayload | None = None
|
||||
|
||||
|
||||
async def load_and_validate_license() -> LicensePayload:
|
||||
"""
|
||||
加载并验证许可证,成功后缓存。
|
||||
|
||||
加载优先级:
|
||||
1. 环境变量 ``LICENSE_KEY``
|
||||
2. 项目根目录 ``license.key`` 文件
|
||||
|
||||
:returns: 验证通过的 LicensePayload
|
||||
:raises LicenseError: 未找到许可证 / 验证失败 / 已过期
|
||||
"""
|
||||
global _cached_payload
|
||||
|
||||
raw: str | None = os.getenv("LICENSE_KEY")
|
||||
|
||||
if not raw:
|
||||
key_path = Path("license.key")
|
||||
if key_path.is_file():
|
||||
async with aiofiles.open(key_path, 'r') as f:
|
||||
raw = (await f.read()).strip()
|
||||
|
||||
if not raw:
|
||||
raise LicenseError("未找到许可证:请设置 LICENSE_KEY 环境变量或提供 license.key 文件")
|
||||
|
||||
data = verify_license(raw)
|
||||
_cached_payload = LicensePayload.model_validate(data)
|
||||
return _cached_payload
|
||||
|
||||
|
||||
def get_cached_license() -> LicensePayload | None:
|
||||
"""获取已缓存的许可证载荷(未加载时返回 None)。"""
|
||||
return _cached_payload
|
||||
@@ -1,5 +0,0 @@
|
||||
"""
|
||||
EE 版本数据库模型
|
||||
|
||||
后续 Pro 功能的 SQLModel 定义位置。
|
||||
"""
|
||||
31
main.py
31
main.py
@@ -5,7 +5,10 @@ from fastapi import FastAPI, Request
|
||||
from loguru import logger as l
|
||||
|
||||
from routers import router
|
||||
from routers.dav import dav_app
|
||||
from routers.dav.provider import EventLoopRef
|
||||
from service.redis import RedisManager
|
||||
from service.storage import S3StorageService
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodels.migration import migration
|
||||
from utils import JWT
|
||||
@@ -14,24 +17,26 @@ from utils.http.http_exceptions import raise_internal_error
|
||||
from utils.lifespan import lifespan
|
||||
|
||||
# 尝试加载企业版功能
|
||||
_has_ee: bool = False
|
||||
try:
|
||||
from ee import init_ee
|
||||
from ee.license import LicenseError
|
||||
from ee.routers import ee_router
|
||||
|
||||
async def _init_ee_and_routes() -> None:
|
||||
_has_ee = True
|
||||
|
||||
async def _init_ee() -> None:
|
||||
"""启动时验证许可证,路由由 license_valid_required 依赖保护"""
|
||||
try:
|
||||
await init_ee()
|
||||
except LicenseError as exc:
|
||||
l.critical(f"许可证验证失败: {exc}")
|
||||
raise SystemExit(1) from exc
|
||||
|
||||
from ee.routers import ee_router
|
||||
from routers.api.v1 import router as v1_router
|
||||
v1_router.include_router(ee_router)
|
||||
|
||||
lifespan.add_startup(_init_ee_and_routes)
|
||||
except ImportError:
|
||||
l.info("以 Community 版本运行")
|
||||
lifespan.add_startup(_init_ee)
|
||||
except ImportError as exc:
|
||||
ee_router = None
|
||||
l.info(f"以 Community 版本运行 (原因: {exc})")
|
||||
|
||||
STATICS_DIR: Path = (Path(__file__).parent / "statics").resolve()
|
||||
"""前端静态文件目录(由 Docker 构建时复制)"""
|
||||
@@ -40,13 +45,18 @@ async def _init_db() -> None:
|
||||
"""初始化数据库连接引擎"""
|
||||
await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug)
|
||||
|
||||
# 捕获事件循环引用(供 WSGI 线程桥接使用)
|
||||
lifespan.add_startup(EventLoopRef.capture)
|
||||
|
||||
# 添加初始化数据库启动项
|
||||
lifespan.add_startup(_init_db)
|
||||
lifespan.add_startup(migration)
|
||||
lifespan.add_startup(JWT.load_secret_key)
|
||||
lifespan.add_startup(RedisManager.connect)
|
||||
lifespan.add_startup(S3StorageService.initialize_session)
|
||||
|
||||
# 添加关闭项
|
||||
lifespan.add_shutdown(S3StorageService.close_session)
|
||||
lifespan.add_shutdown(DatabaseManager.close)
|
||||
lifespan.add_shutdown(RedisManager.disconnect)
|
||||
|
||||
@@ -87,6 +97,11 @@ async def handle_unexpected_exceptions(
|
||||
|
||||
# 挂载路由
|
||||
app.include_router(router)
|
||||
if _has_ee:
|
||||
app.include_router(ee_router, prefix="/api/v1")
|
||||
|
||||
# 挂载 WebDAV 协议端点(优先于 SPA catch-all)
|
||||
app.mount("/dav", dav_app)
|
||||
|
||||
# 挂载前端静态文件(仅当 statics/ 目录存在时,即 Docker 部署环境)
|
||||
if STATICS_DIR.is_dir():
|
||||
|
||||
@@ -33,6 +33,8 @@ dependencies = [
|
||||
"uvicorn>=0.38.0",
|
||||
"webauthn>=2.7.0",
|
||||
"whatthepatch>=1.0.6",
|
||||
"wsgidav>=4.3.0",
|
||||
"a2wsgi>=1.10.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -5,6 +5,7 @@ from utils.conf import appmeta
|
||||
from .admin import admin_router
|
||||
|
||||
from .callback import callback_router
|
||||
from .category import category_router
|
||||
from .directory import directory_router
|
||||
from .download import download_router
|
||||
from .file import router as file_router
|
||||
@@ -14,7 +15,6 @@ from .trash import trash_router
|
||||
from .site import site_router
|
||||
from .slave import slave_router
|
||||
from .user import user_router
|
||||
from .vas import vas_router
|
||||
from .webdav import webdav_router
|
||||
|
||||
router = APIRouter(prefix="/v1")
|
||||
@@ -24,6 +24,7 @@ router = APIRouter(prefix="/v1")
|
||||
if appmeta.mode == "master":
|
||||
router.include_router(admin_router)
|
||||
router.include_router(callback_router)
|
||||
router.include_router(category_router)
|
||||
router.include_router(directory_router)
|
||||
router.include_router(download_router)
|
||||
router.include_router(file_router)
|
||||
@@ -32,7 +33,6 @@ if appmeta.mode == "master":
|
||||
router.include_router(site_router)
|
||||
router.include_router(trash_router)
|
||||
router.include_router(user_router)
|
||||
router.include_router(vas_router)
|
||||
router.include_router(webdav_router)
|
||||
elif appmeta.mode == "slave":
|
||||
router.include_router(slave_router)
|
||||
|
||||
@@ -16,6 +16,12 @@ from sqlmodels.setting import (
|
||||
from sqlmodels.setting import SettingsType
|
||||
from utils import http_exceptions
|
||||
from utils.conf import appmeta
|
||||
|
||||
try:
|
||||
from ee.service import get_cached_license
|
||||
except ImportError:
|
||||
get_cached_license = None
|
||||
|
||||
from .file import admin_file_router
|
||||
from .file_app import admin_file_app_router
|
||||
from .group import admin_group_router
|
||||
@@ -24,7 +30,6 @@ from .share import admin_share_router
|
||||
from .task import admin_task_router
|
||||
from .user import admin_user_router
|
||||
from .theme import admin_theme_router
|
||||
from .vas import admin_vas_router
|
||||
|
||||
|
||||
class Aria2TestRequest(SQLModelBase):
|
||||
@@ -50,7 +55,6 @@ admin_router.include_router(admin_policy_router)
|
||||
admin_router.include_router(admin_share_router)
|
||||
admin_router.include_router(admin_task_router)
|
||||
admin_router.include_router(admin_theme_router)
|
||||
admin_router.include_router(admin_vas_router)
|
||||
|
||||
# 离线下载 /api/admin/aria2
|
||||
admin_aria2_router = APIRouter(
|
||||
@@ -159,14 +163,24 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
|
||||
if site_url_setting and site_url_setting.value:
|
||||
site_urls.append(site_url_setting.value)
|
||||
|
||||
# 许可证信息(从设置读取或使用默认值)
|
||||
license_info = LicenseInfo(
|
||||
expired_at=now + timedelta(days=365),
|
||||
signed_at=now,
|
||||
root_domains=[],
|
||||
domains=[],
|
||||
vol_domains=[],
|
||||
)
|
||||
# 许可证信息(Pro 版本从缓存读取,CE 版本永不过期)
|
||||
if appmeta.IsPro and get_cached_license:
|
||||
payload = get_cached_license()
|
||||
license_info = LicenseInfo(
|
||||
expired_at=payload.expires_at,
|
||||
signed_at=payload.issued_at,
|
||||
root_domains=[],
|
||||
domains=[payload.domain],
|
||||
vol_domains=[],
|
||||
)
|
||||
else:
|
||||
license_info = LicenseInfo(
|
||||
expired_at=datetime.max,
|
||||
signed_at=now,
|
||||
root_domains=[],
|
||||
domains=[],
|
||||
vol_domains=[],
|
||||
)
|
||||
|
||||
# 版本信息
|
||||
version_info = VersionInfo(
|
||||
@@ -225,11 +239,11 @@ async def router_admin_update_settings(
|
||||
|
||||
if existing:
|
||||
existing.value = item.value
|
||||
await existing.save(session)
|
||||
existing = await existing.save(session)
|
||||
updated_count += 1
|
||||
else:
|
||||
new_setting = Setting(type=item.type, name=item.name, value=item.value)
|
||||
await new_setting.save(session)
|
||||
new_setting = await new_setting.save(session)
|
||||
created_count += 1
|
||||
|
||||
l.info(f"管理员更新了 {updated_count} 个设置项,新建了 {created_count} 个设置项")
|
||||
|
||||
@@ -54,7 +54,7 @@ async def _set_ban_recursive(
|
||||
obj.banned_by = None
|
||||
obj.ban_reason = None
|
||||
|
||||
await obj.save(session)
|
||||
obj = await obj.save(session)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@@ -131,9 +131,7 @@ async def router_admin_preview_file(
|
||||
:param file_id: 文件UUID
|
||||
:return: 文件内容
|
||||
"""
|
||||
file_obj = await Object.get(session, Object.id == file_id)
|
||||
if not file_obj:
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
file_obj = await Object.get_exist_one(session, file_id)
|
||||
|
||||
if not file_obj.is_file:
|
||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||
@@ -182,9 +180,7 @@ async def router_admin_ban_file(
|
||||
:param claims: 当前管理员 JWT claims
|
||||
:return: 封禁结果
|
||||
"""
|
||||
file_obj = await Object.get(session, Object.id == file_id)
|
||||
if not file_obj:
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
file_obj = await Object.get_exist_one(session, file_id)
|
||||
|
||||
count = await _set_ban_recursive(session, file_obj, request.ban, claims.sub, request.reason)
|
||||
|
||||
@@ -212,9 +208,7 @@ async def router_admin_delete_file(
|
||||
:param delete_physical: 是否同时删除物理文件
|
||||
:return: 删除结果
|
||||
"""
|
||||
file_obj = await Object.get(session, Object.id == file_id)
|
||||
if not file_obj:
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
file_obj = await Object.get_exist_one(session, file_id)
|
||||
|
||||
if not file_obj.is_file:
|
||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
"""
|
||||
管理员文件应用管理端点
|
||||
|
||||
提供文件查看器应用的 CRUD、扩展名管理和用户组权限管理。
|
||||
提供文件查看器应用的 CRUD、扩展名管理、用户组权限管理和 WOPI Discovery。
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import select
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from service.wopi import parse_wopi_discovery_xml
|
||||
from sqlmodels import (
|
||||
FileApp,
|
||||
FileAppCreateRequest,
|
||||
@@ -21,7 +23,10 @@ from sqlmodels import (
|
||||
FileAppUpdateRequest,
|
||||
ExtensionUpdateRequest,
|
||||
GroupAccessUpdateRequest,
|
||||
WopiDiscoveredExtension,
|
||||
WopiDiscoveryResponse,
|
||||
)
|
||||
from sqlmodels.file_app import FileAppType
|
||||
from utils import http_exceptions
|
||||
|
||||
admin_file_app_router = APIRouter(
|
||||
@@ -123,6 +128,7 @@ async def create_file_app(
|
||||
group_links.append(link)
|
||||
if group_links:
|
||||
await session.commit()
|
||||
await session.refresh(app)
|
||||
|
||||
l.info(f"创建文件应用: {app.name} ({app.app_key})")
|
||||
|
||||
@@ -145,9 +151,7 @@ async def get_file_app(
|
||||
错误处理:
|
||||
- 404: 应用不存在
|
||||
"""
|
||||
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||
if not app:
|
||||
http_exceptions.raise_not_found("应用不存在")
|
||||
app = await FileApp.get_exist_one(session, app_id)
|
||||
|
||||
extensions = await FileAppExtension.get(
|
||||
session,
|
||||
@@ -180,9 +184,7 @@ async def update_file_app(
|
||||
- 404: 应用不存在
|
||||
- 409: 新 app_key 已被其他应用使用
|
||||
"""
|
||||
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||
if not app:
|
||||
http_exceptions.raise_not_found("应用不存在")
|
||||
app = await FileApp.get_exist_one(session, app_id)
|
||||
|
||||
# 检查 app_key 唯一性
|
||||
if request.app_key is not None and request.app_key != app.app_key:
|
||||
@@ -229,9 +231,7 @@ async def delete_file_app(
|
||||
错误处理:
|
||||
- 404: 应用不存在
|
||||
"""
|
||||
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||
if not app:
|
||||
http_exceptions.raise_not_found("应用不存在")
|
||||
app = await FileApp.get_exist_one(session, app_id)
|
||||
|
||||
app_name = app.app_key
|
||||
await FileApp.delete(session, app)
|
||||
@@ -257,20 +257,24 @@ async def update_extensions(
|
||||
错误处理:
|
||||
- 404: 应用不存在
|
||||
"""
|
||||
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||
if not app:
|
||||
http_exceptions.raise_not_found("应用不存在")
|
||||
app = await FileApp.get_exist_one(session, app_id)
|
||||
|
||||
# 删除旧的扩展名
|
||||
# 保留旧扩展名的 wopi_action_url(Discovery 填充的值)
|
||||
old_extensions: list[FileAppExtension] = await FileAppExtension.get(
|
||||
session,
|
||||
FileAppExtension.app_id == app_id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
old_url_map: dict[str, str] = {
|
||||
ext.extension: ext.wopi_action_url
|
||||
for ext in old_extensions
|
||||
if ext.wopi_action_url
|
||||
}
|
||||
for old_ext in old_extensions:
|
||||
await FileAppExtension.delete(session, old_ext, commit=False)
|
||||
await session.flush()
|
||||
|
||||
# 创建新的扩展名
|
||||
# 创建新的扩展名(保留已有的 wopi_action_url)
|
||||
new_extensions: list[FileAppExtension] = []
|
||||
for i, ext in enumerate(request.extensions):
|
||||
normalized = ext.lower().strip().lstrip('.')
|
||||
@@ -278,12 +282,14 @@ async def update_extensions(
|
||||
app_id=app_id,
|
||||
extension=normalized,
|
||||
priority=i,
|
||||
wopi_action_url=old_url_map.get(normalized),
|
||||
)
|
||||
session.add(ext_record)
|
||||
new_extensions.append(ext_record)
|
||||
|
||||
await session.commit()
|
||||
# refresh 新创建的记录
|
||||
# refresh commit 后过期的对象
|
||||
await session.refresh(app)
|
||||
for ext_record in new_extensions:
|
||||
await session.refresh(ext_record)
|
||||
|
||||
@@ -316,9 +322,7 @@ async def update_group_access(
|
||||
错误处理:
|
||||
- 404: 应用不存在
|
||||
"""
|
||||
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||
if not app:
|
||||
http_exceptions.raise_not_found("应用不存在")
|
||||
app = await FileApp.get_exist_one(session, app_id)
|
||||
|
||||
# 删除旧的用户组关联
|
||||
old_links_result = await session.exec(
|
||||
@@ -336,6 +340,7 @@ async def update_group_access(
|
||||
new_links.append(link)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(app)
|
||||
|
||||
extensions = await FileAppExtension.get(
|
||||
session,
|
||||
@@ -346,3 +351,100 @@ async def update_group_access(
|
||||
l.info(f"更新文件应用 {app.app_key} 的用户组权限: {request.group_ids}")
|
||||
|
||||
return FileAppResponse.from_app(app, extensions, new_links)
|
||||
|
||||
|
||||
@admin_file_app_router.post(
|
||||
path='/{app_id}/discover',
|
||||
summary='执行 WOPI Discovery',
|
||||
)
|
||||
async def discover_wopi(
|
||||
session: SessionDep,
|
||||
app_id: UUID,
|
||||
) -> WopiDiscoveryResponse:
|
||||
"""
|
||||
从 WOPI 服务端获取 Discovery XML 并自动配置扩展名和 URL 模板。
|
||||
|
||||
流程:
|
||||
1. 验证 FileApp 存在且为 WOPI 类型
|
||||
2. 使用 FileApp.wopi_discovery_url 获取 Discovery XML
|
||||
3. 解析 XML,提取扩展名和动作 URL
|
||||
4. 全量替换 FileAppExtension 记录(带 wopi_action_url)
|
||||
|
||||
认证:管理员权限
|
||||
|
||||
错误处理:
|
||||
- 404: 应用不存在
|
||||
- 400: 非 WOPI 类型 / discovery URL 未配置 / XML 解析失败
|
||||
- 502: WOPI 服务端不可达或返回无效响应
|
||||
"""
|
||||
app = await FileApp.get_exist_one(session, app_id)
|
||||
|
||||
if app.type != FileAppType.WOPI:
|
||||
http_exceptions.raise_bad_request("仅 WOPI 类型应用支持自动发现")
|
||||
|
||||
if not app.wopi_discovery_url:
|
||||
http_exceptions.raise_bad_request("未配置 WOPI Discovery URL")
|
||||
|
||||
# commit 后对象会过期,先保存需要的值
|
||||
discovery_url = app.wopi_discovery_url
|
||||
app_key = app.app_key
|
||||
|
||||
# 获取 Discovery XML
|
||||
try:
|
||||
async with aiohttp.ClientSession() as client:
|
||||
async with client.get(
|
||||
discovery_url,
|
||||
timeout=aiohttp.ClientTimeout(total=15),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
http_exceptions.raise_bad_gateway(
|
||||
f"WOPI 服务端返回 HTTP {resp.status}"
|
||||
)
|
||||
xml_content = await resp.text()
|
||||
except aiohttp.ClientError as e:
|
||||
http_exceptions.raise_bad_gateway(f"无法连接 WOPI 服务端: {e}")
|
||||
|
||||
# 解析 XML
|
||||
try:
|
||||
action_urls, app_names = parse_wopi_discovery_xml(xml_content)
|
||||
except ValueError as e:
|
||||
http_exceptions.raise_bad_request(str(e))
|
||||
|
||||
if not action_urls:
|
||||
return WopiDiscoveryResponse(app_names=app_names)
|
||||
|
||||
# 全量替换扩展名
|
||||
old_extensions: list[FileAppExtension] = await FileAppExtension.get(
|
||||
session,
|
||||
FileAppExtension.app_id == app_id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
for old_ext in old_extensions:
|
||||
await FileAppExtension.delete(session, old_ext, commit=False)
|
||||
await session.flush()
|
||||
|
||||
new_extensions: list[FileAppExtension] = []
|
||||
discovered: list[WopiDiscoveredExtension] = []
|
||||
for i, (ext, action_url) in enumerate(sorted(action_urls.items())):
|
||||
ext_record = FileAppExtension(
|
||||
app_id=app_id,
|
||||
extension=ext,
|
||||
priority=i,
|
||||
wopi_action_url=action_url,
|
||||
)
|
||||
session.add(ext_record)
|
||||
new_extensions.append(ext_record)
|
||||
discovered.append(WopiDiscoveredExtension(extension=ext, action_url=action_url))
|
||||
|
||||
await session.commit()
|
||||
|
||||
l.info(
|
||||
f"WOPI Discovery 完成: 应用 {app_key}, "
|
||||
f"发现 {len(discovered)} 个扩展名"
|
||||
)
|
||||
|
||||
return WopiDiscoveryResponse(
|
||||
discovered_extensions=discovered,
|
||||
app_names=app_names,
|
||||
applied_count=len(discovered),
|
||||
)
|
||||
|
||||
@@ -63,10 +63,7 @@ async def router_admin_get_group(
|
||||
:param group_id: 用户组UUID
|
||||
:return: 用户组详情
|
||||
"""
|
||||
group = await Group.get(session, Group.id == group_id, load=[Group.options, Group.policies])
|
||||
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
group = await Group.get_exist_one(session, group_id, load=[Group.options, Group.policies])
|
||||
|
||||
# 直接访问已加载的关系,无需额外查询
|
||||
policies = group.policies
|
||||
@@ -94,9 +91,7 @@ async def router_admin_get_group_members(
|
||||
:return: 分页成员列表
|
||||
"""
|
||||
# 验证组存在
|
||||
group = await Group.get(session, Group.id == group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
await Group.get_exist_one(session, group_id)
|
||||
|
||||
result = await User.get_with_count(session, User.group_id == group_id, table_view=table_view)
|
||||
|
||||
@@ -138,10 +133,11 @@ async def router_admin_create_group(
|
||||
speed_limit=request.speed_limit,
|
||||
)
|
||||
group = await group.save(session)
|
||||
group_id_val: UUID = group.id
|
||||
|
||||
# 创建选项
|
||||
options = GroupOptions(
|
||||
group_id=group.id,
|
||||
group_id=group_id_val,
|
||||
share_download=request.share_download,
|
||||
share_free=request.share_free,
|
||||
relocate=request.relocate,
|
||||
@@ -154,11 +150,11 @@ async def router_admin_create_group(
|
||||
aria2=request.aria2,
|
||||
redirected_source=request.redirected_source,
|
||||
)
|
||||
await options.save(session)
|
||||
options = await options.save(session)
|
||||
|
||||
# 关联存储策略
|
||||
for policy_id in request.policy_ids:
|
||||
link = GroupPolicyLink(group_id=group.id, policy_id=policy_id)
|
||||
link = GroupPolicyLink(group_id=group_id_val, policy_id=policy_id)
|
||||
session.add(link)
|
||||
await session.commit()
|
||||
|
||||
@@ -185,9 +181,7 @@ async def router_admin_update_group(
|
||||
:param request: 更新请求
|
||||
:return: 更新结果
|
||||
"""
|
||||
group = await Group.get(session, Group.id == group_id, load=Group.options)
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
group = await Group.get_exist_one(session, group_id, load=Group.options)
|
||||
|
||||
# 检查名称唯一性(如果要更新名称)
|
||||
if request.name and request.name != group.name:
|
||||
@@ -217,7 +211,7 @@ async def router_admin_update_group(
|
||||
if options_data:
|
||||
for key, value in options_data.items():
|
||||
setattr(group.options, key, value)
|
||||
await group.options.save(session)
|
||||
group.options = await group.options.save(session)
|
||||
|
||||
# 更新策略关联
|
||||
if request.policy_ids is not None:
|
||||
@@ -255,9 +249,7 @@ async def router_admin_delete_group(
|
||||
:param group_id: 用户组UUID
|
||||
:return: 删除结果
|
||||
"""
|
||||
group = await Group.get(session, Group.id == group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
group = await Group.get_exist_one(session, group_id)
|
||||
|
||||
# 检查是否有用户属于该组
|
||||
user_count = await User.count(session, User.group_id == group_id)
|
||||
|
||||
@@ -8,11 +8,11 @@ from sqlmodel import Field
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from sqlmodels import (
|
||||
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
|
||||
ListResponse, Object,
|
||||
Policy, PolicyCreateRequest, PolicyOptions, PolicyType, PolicySummary,
|
||||
PolicyUpdateRequest, ResponseBase, ListResponse, Object,
|
||||
)
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
from service.storage import DirectoryCreationError, LocalStorageService
|
||||
from service.storage import DirectoryCreationError, LocalStorageService, S3StorageService
|
||||
|
||||
admin_policy_router = APIRouter(
|
||||
prefix='/policy',
|
||||
@@ -67,6 +67,12 @@ class PolicyDetailResponse(SQLModelBase):
|
||||
base_url: str | None
|
||||
"""基础URL"""
|
||||
|
||||
access_key: str | None
|
||||
"""Access Key"""
|
||||
|
||||
secret_key: str | None
|
||||
"""Secret Key"""
|
||||
|
||||
max_size: int
|
||||
"""最大文件尺寸"""
|
||||
|
||||
@@ -107,9 +113,45 @@ class PolicyTestSlaveRequest(SQLModelBase):
|
||||
secret: str
|
||||
"""从机通信密钥"""
|
||||
|
||||
class PolicyCreateRequest(PolicyBase):
|
||||
"""创建存储策略请求 DTO,继承 PolicyBase 中的所有字段"""
|
||||
pass
|
||||
class PolicyTestS3Request(SQLModelBase):
|
||||
"""测试 S3 连接请求 DTO"""
|
||||
|
||||
server: str = Field(max_length=255)
|
||||
"""S3 端点地址"""
|
||||
|
||||
bucket_name: str = Field(max_length=255)
|
||||
"""存储桶名称"""
|
||||
|
||||
access_key: str
|
||||
"""Access Key"""
|
||||
|
||||
secret_key: str
|
||||
"""Secret Key"""
|
||||
|
||||
s3_region: str = Field(default='us-east-1', max_length=64)
|
||||
"""S3 区域"""
|
||||
|
||||
s3_path_style: bool = False
|
||||
"""是否使用路径风格"""
|
||||
|
||||
|
||||
class PolicyTestS3Response(SQLModelBase):
|
||||
"""S3 连接测试响应"""
|
||||
|
||||
is_connected: bool
|
||||
"""连接是否成功"""
|
||||
|
||||
message: str
|
||||
"""测试结果消息"""
|
||||
|
||||
|
||||
# ==================== Options 字段集合(用于分离 Policy 与 Options 字段) ====================
|
||||
|
||||
_OPTIONS_FIELDS: set[str] = {
|
||||
'token', 'file_type', 'mimetype', 'od_redirect',
|
||||
'chunk_size', 's3_path_style', 's3_region',
|
||||
}
|
||||
|
||||
|
||||
@admin_policy_router.get(
|
||||
path='/list',
|
||||
@@ -277,7 +319,20 @@ async def router_policy_add_policy(
|
||||
raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}")
|
||||
|
||||
# 保存到数据库
|
||||
await policy.save(session)
|
||||
policy = await policy.save(session)
|
||||
|
||||
# 创建策略选项
|
||||
options = PolicyOptions(
|
||||
policy_id=policy.id,
|
||||
token=request.token,
|
||||
file_type=request.file_type,
|
||||
mimetype=request.mimetype,
|
||||
od_redirect=request.od_redirect,
|
||||
chunk_size=request.chunk_size,
|
||||
s3_path_style=request.s3_path_style,
|
||||
s3_region=request.s3_region,
|
||||
)
|
||||
options = await options.save(session)
|
||||
|
||||
@admin_policy_router.post(
|
||||
path='/cors',
|
||||
@@ -328,9 +383,7 @@ async def router_policy_onddrive_oauth(
|
||||
:param policy_id: 存储策略UUID
|
||||
:return: OAuth URL
|
||||
"""
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
policy = await Policy.get_exist_one(session, policy_id)
|
||||
|
||||
# TODO: 实现OneDrive OAuth
|
||||
raise HTTPException(status_code=501, detail="OneDrive OAuth暂未实现")
|
||||
@@ -353,9 +406,7 @@ async def router_policy_get_policy(
|
||||
:param policy_id: 存储策略UUID
|
||||
:return: 策略详情
|
||||
"""
|
||||
policy = await Policy.get(session, Policy.id == policy_id, load=Policy.options)
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)
|
||||
|
||||
# 获取使用此策略的用户组
|
||||
groups = await policy.awaitable_attrs.groups
|
||||
@@ -371,6 +422,8 @@ async def router_policy_get_policy(
|
||||
bucket_name=policy.bucket_name,
|
||||
is_private=policy.is_private,
|
||||
base_url=policy.base_url,
|
||||
access_key=policy.access_key,
|
||||
secret_key=policy.secret_key,
|
||||
max_size=policy.max_size,
|
||||
auto_rename=policy.auto_rename,
|
||||
dir_name_rule=policy.dir_name_rule,
|
||||
@@ -402,9 +455,7 @@ async def router_policy_delete_policy(
|
||||
:param policy_id: 存储策略UUID
|
||||
:return: 删除结果
|
||||
"""
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
policy = await Policy.get_exist_one(session, policy_id)
|
||||
|
||||
# 检查是否有文件使用此策略
|
||||
file_count = await Object.count(session, Object.policy_id == policy_id)
|
||||
@@ -417,4 +468,106 @@ async def router_policy_delete_policy(
|
||||
policy_name = policy.name
|
||||
await Policy.delete(session, policy)
|
||||
|
||||
l.info(f"管理员删除了存储策略: {policy_name}")
|
||||
l.info(f"管理员删除了存储策略: {policy_name}")
|
||||
|
||||
|
||||
@admin_policy_router.patch(
|
||||
path='/{policy_id}',
|
||||
summary='更新存储策略',
|
||||
description='更新存储策略配置。策略类型创建后不可更改。',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_policy_update_policy(
|
||||
session: SessionDep,
|
||||
policy_id: UUID,
|
||||
request: PolicyUpdateRequest,
|
||||
) -> None:
|
||||
"""
|
||||
更新存储策略端点
|
||||
|
||||
功能:
|
||||
- 更新策略基础字段和扩展选项
|
||||
- 策略类型(type)不可更改
|
||||
|
||||
认证:
|
||||
- 需要管理员权限
|
||||
|
||||
:param session: 数据库会话
|
||||
:param policy_id: 存储策略UUID
|
||||
:param request: 更新请求
|
||||
"""
|
||||
policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)
|
||||
|
||||
# 检查名称唯一性(如果要更新名称)
|
||||
if request.name and request.name != policy.name:
|
||||
existing = await Policy.get(session, Policy.name == request.name)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="策略名称已存在")
|
||||
|
||||
# 分离 Policy 字段和 Options 字段
|
||||
all_data = request.model_dump(exclude_unset=True)
|
||||
policy_data = {k: v for k, v in all_data.items() if k not in _OPTIONS_FIELDS}
|
||||
options_data = {k: v for k, v in all_data.items() if k in _OPTIONS_FIELDS}
|
||||
|
||||
# 更新 Policy 基础字段
|
||||
if policy_data:
|
||||
for key, value in policy_data.items():
|
||||
setattr(policy, key, value)
|
||||
policy = await policy.save(session)
|
||||
|
||||
# 更新或创建 PolicyOptions
|
||||
if options_data:
|
||||
if policy.options:
|
||||
for key, value in options_data.items():
|
||||
setattr(policy.options, key, value)
|
||||
policy.options = await policy.options.save(session)
|
||||
else:
|
||||
options = PolicyOptions(policy_id=policy.id, **options_data)
|
||||
options = await options.save(session)
|
||||
|
||||
l.info(f"管理员更新了存储策略: {policy_id}")
|
||||
|
||||
|
||||
@admin_policy_router.post(
|
||||
path='/test/s3',
|
||||
summary='测试 S3 连接',
|
||||
description='测试 S3 存储端点的连通性和凭据有效性。',
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_policy_test_s3(
|
||||
request: PolicyTestS3Request,
|
||||
) -> PolicyTestS3Response:
|
||||
"""
|
||||
测试 S3 连接端点
|
||||
|
||||
通过向 S3 端点发送 HEAD Bucket 请求,验证凭据和网络连通性。
|
||||
|
||||
:param request: 测试请求
|
||||
:return: 测试结果
|
||||
"""
|
||||
from service.storage import S3APIError
|
||||
|
||||
# 构造临时 Policy 对象用于创建 S3StorageService
|
||||
temp_policy = Policy(
|
||||
name="__test__",
|
||||
type=PolicyType.S3,
|
||||
server=request.server,
|
||||
bucket_name=request.bucket_name,
|
||||
access_key=request.access_key,
|
||||
secret_key=request.secret_key,
|
||||
)
|
||||
s3_service = S3StorageService(
|
||||
temp_policy,
|
||||
region=request.s3_region,
|
||||
is_path_style=request.s3_path_style,
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用 file_exists 发送 HEAD 请求来验证连通性
|
||||
await s3_service.file_exists("__connection_test__")
|
||||
return PolicyTestS3Response(is_connected=True, message="连接成功")
|
||||
except S3APIError as e:
|
||||
return PolicyTestS3Response(is_connected=False, message=f"S3 API 错误: {e}")
|
||||
except Exception as e:
|
||||
return PolicyTestS3Response(is_connected=False, message=f"连接失败: {e}")
|
||||
@@ -155,9 +155,7 @@ async def router_admin_delete_share(
|
||||
:param share_id: 分享ID
|
||||
:return: 删除结果
|
||||
"""
|
||||
share = await Share.get(session, Share.id == share_id)
|
||||
if not share:
|
||||
raise HTTPException(status_code=404, detail="分享不存在")
|
||||
share = await Share.get_exist_one(session, share_id)
|
||||
|
||||
await Share.delete(session, share)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from sqlmodels import (
|
||||
ListResponse,
|
||||
Task, TaskSummary,
|
||||
Task, TaskSummary, TaskStatus, TaskType,
|
||||
)
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
|
||||
@@ -19,10 +19,10 @@ class TaskDetailResponse(SQLModelBase):
|
||||
id: int
|
||||
"""任务ID"""
|
||||
|
||||
status: int
|
||||
status: TaskStatus
|
||||
"""任务状态"""
|
||||
|
||||
type: int
|
||||
type: TaskType
|
||||
"""任务类型"""
|
||||
|
||||
progress: int
|
||||
@@ -150,9 +150,7 @@ async def router_admin_delete_task(
|
||||
:param task_id: 任务ID
|
||||
:return: 删除结果
|
||||
"""
|
||||
task = await Task.get(session, Task.id == task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
task = await Task.get_exist_one(session, task_id)
|
||||
|
||||
await Task.delete(session, task)
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ async def router_admin_theme_create(
|
||||
name=request.name,
|
||||
**request.colors.model_dump(),
|
||||
)
|
||||
await preset.save(session)
|
||||
preset = await preset.save(session)
|
||||
l.info(f"管理员创建了主题预设: {request.name}")
|
||||
|
||||
|
||||
@@ -101,11 +101,7 @@ async def router_admin_theme_update(
|
||||
- 404: 预设不存在
|
||||
- 409: 名称已被其他预设使用
|
||||
"""
|
||||
preset: ThemePreset | None = await ThemePreset.get(
|
||||
session, ThemePreset.id == preset_id
|
||||
)
|
||||
if not preset:
|
||||
http_exceptions.raise_not_found("主题预设不存在")
|
||||
preset = await ThemePreset.get_exist_one(session, preset_id)
|
||||
|
||||
# 检查名称唯一性(排除自身)
|
||||
if request.name is not None and request.name != preset.name:
|
||||
@@ -120,7 +116,7 @@ async def router_admin_theme_update(
|
||||
for key, value in color_data.items():
|
||||
setattr(preset, key, value)
|
||||
|
||||
await preset.save(session)
|
||||
preset = await preset.save(session)
|
||||
l.info(f"管理员更新了主题预设: {preset.name}")
|
||||
|
||||
|
||||
@@ -147,11 +143,7 @@ async def router_admin_theme_delete(
|
||||
副作用:
|
||||
- 关联用户的 theme_preset_id 会被数据库 SET NULL
|
||||
"""
|
||||
preset: ThemePreset | None = await ThemePreset.get(
|
||||
session, ThemePreset.id == preset_id
|
||||
)
|
||||
if not preset:
|
||||
http_exceptions.raise_not_found("主题预设不存在")
|
||||
preset = await ThemePreset.get_exist_one(session, preset_id)
|
||||
|
||||
await preset.delete(session)
|
||||
l.info(f"管理员删除了主题预设: {preset.name}")
|
||||
@@ -180,11 +172,7 @@ async def router_admin_theme_set_default(
|
||||
逻辑:
|
||||
- 事务中先清除所有旧默认,再设新默认
|
||||
"""
|
||||
preset: ThemePreset | None = await ThemePreset.get(
|
||||
session, ThemePreset.id == preset_id
|
||||
)
|
||||
if not preset:
|
||||
http_exceptions.raise_not_found("主题预设不存在")
|
||||
preset = await ThemePreset.get_exist_one(session, preset_id)
|
||||
|
||||
# 清除所有旧默认
|
||||
await session.execute(
|
||||
@@ -195,5 +183,5 @@ async def router_admin_theme_set_default(
|
||||
|
||||
# 设新默认
|
||||
preset.is_default = True
|
||||
await preset.save(session)
|
||||
preset = await preset.save(session)
|
||||
l.info(f"管理员将主题预设 '{preset.name}' 设为默认")
|
||||
|
||||
@@ -128,8 +128,9 @@ async def router_admin_create_user(
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
identity = await identity.save(session)
|
||||
|
||||
user = await User.get(session, User.id == user.id, load=User.group)
|
||||
return user.to_public()
|
||||
|
||||
|
||||
@@ -153,9 +154,7 @@ async def router_admin_update_user(
|
||||
:param request: 更新请求
|
||||
:return: 更新结果
|
||||
"""
|
||||
user = await User.get(session, User.id == user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
user = await User.get_exist_one(session, user_id)
|
||||
|
||||
# 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
|
||||
default_admin_setting = await Setting.get(
|
||||
@@ -252,9 +251,7 @@ async def router_admin_calibrate_storage(
|
||||
:param user_id: 用户UUID
|
||||
:return: 校准结果
|
||||
"""
|
||||
user = await User.get(session, User.id == user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
user = await User.get_exist_one(session, user_id)
|
||||
|
||||
previous_storage = user.storage
|
||||
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
ResponseBase,
|
||||
)
|
||||
|
||||
admin_vas_router = APIRouter(
|
||||
prefix='/vas',
|
||||
tags=['admin', 'admin_vas']
|
||||
)
|
||||
|
||||
@admin_vas_router.get(
|
||||
path='/list',
|
||||
summary='获取增值服务列表',
|
||||
description='Get VAS list (orders and storage packs)',
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_get_vas_list(
|
||||
session: SessionDep,
|
||||
user_id: UUID | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
获取增值服务列表(订单和存储包)。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 按用户筛选
|
||||
:param page: 页码
|
||||
:param page_size: 每页数量
|
||||
:return: 增值服务列表
|
||||
"""
|
||||
# TODO: 实现增值服务列表
|
||||
# 需要查询 Order 和 StoragePack 模型
|
||||
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
|
||||
|
||||
|
||||
@admin_vas_router.get(
|
||||
path='/{vas_id}',
|
||||
summary='获取增值服务详情',
|
||||
description='Get VAS detail by ID',
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_get_vas(
|
||||
session: SessionDep,
|
||||
vas_id: UUID,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
获取增值服务详情。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param vas_id: 增值服务UUID
|
||||
:return: 增值服务详情
|
||||
"""
|
||||
# TODO: 实现增值服务详情
|
||||
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
|
||||
|
||||
|
||||
@admin_vas_router.delete(
|
||||
path='/{vas_id}',
|
||||
summary='删除增值服务',
|
||||
description='Delete VAS by ID',
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_delete_vas(
|
||||
session: SessionDep,
|
||||
vas_id: UUID,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
删除增值服务。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param vas_id: 增值服务UUID
|
||||
:return: 删除结果
|
||||
"""
|
||||
# TODO: 实现增值服务删除
|
||||
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from loguru import logger as l
|
||||
|
||||
from sqlmodels import ResponseBase
|
||||
import service.oauth
|
||||
@@ -15,18 +16,12 @@ oauth_router = APIRouter(
|
||||
tags=["callback", "oauth"],
|
||||
)
|
||||
|
||||
pay_router = APIRouter(
|
||||
prefix='/callback/pay',
|
||||
tags=["callback", "pay"],
|
||||
)
|
||||
|
||||
upload_router = APIRouter(
|
||||
prefix='/callback/upload',
|
||||
tags=["callback", "upload"],
|
||||
)
|
||||
|
||||
callback_router.include_router(oauth_router)
|
||||
callback_router.include_router(pay_router)
|
||||
callback_router.include_router(upload_router)
|
||||
|
||||
@oauth_router.post(
|
||||
@@ -37,7 +32,7 @@ callback_router.include_router(upload_router)
|
||||
def router_callback_qq() -> ResponseBase:
|
||||
"""
|
||||
Handle QQ OAuth callback and return user information.
|
||||
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the QQ OAuth callback.
|
||||
"""
|
||||
@@ -54,101 +49,27 @@ async def router_callback_github(
|
||||
GitHub OAuth 回调处理
|
||||
- 错误响应示例:
|
||||
- {
|
||||
'error': 'bad_verification_code',
|
||||
'error_description': 'The code passed is incorrect or expired.',
|
||||
'error': 'bad_verification_code',
|
||||
'error_description': 'The code passed is incorrect or expired.',
|
||||
'error_uri': 'https://docs.github.com/apps/managing-oauth-apps/troubleshooting-oauth-app-access-token-request-errors/#bad-verification-code'
|
||||
}
|
||||
|
||||
|
||||
Returns:
|
||||
PlainTextResponse: A response containing the user information from GitHub.
|
||||
"""
|
||||
try:
|
||||
access_token = await service.oauth.github.get_access_token(code)
|
||||
# [TODO] 把access_token写数据库里
|
||||
if not access_token:
|
||||
return PlainTextResponse("Failed to retrieve access token from GitHub.", status_code=400)
|
||||
|
||||
return PlainTextResponse("GitHub 认证失败", status_code=400)
|
||||
|
||||
user_data = await service.oauth.github.get_user_info(access_token.access_token)
|
||||
# [TODO] 把user_data写数据库里
|
||||
|
||||
return PlainTextResponse(f"User information processed successfully, code: {code}, user_data: {user_data.json_dump()}", status_code=200)
|
||||
# [TODO] 把 access_token 和 user_data 写数据库,生成 JWT,重定向到前端
|
||||
l.info(f"GitHub OAuth 回调成功: user={user_data.user_data.login}")
|
||||
|
||||
return PlainTextResponse("认证成功,功能开发中", status_code=200)
|
||||
except Exception as e:
|
||||
return PlainTextResponse(f"An error occurred: {str(e)}", status_code=500)
|
||||
|
||||
@pay_router.post(
|
||||
path='/alipay',
|
||||
summary='支付宝支付回调',
|
||||
description='Handle Alipay payment callback and return payment status.',
|
||||
)
|
||||
def router_callback_alipay() -> ResponseBase:
|
||||
"""
|
||||
Handle Alipay payment callback and return payment status.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Alipay payment callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@pay_router.post(
|
||||
path='/wechat',
|
||||
summary='微信支付回调',
|
||||
description='Handle WeChat Pay payment callback and return payment status.',
|
||||
)
|
||||
def router_callback_wechat() -> ResponseBase:
|
||||
"""
|
||||
Handle WeChat Pay payment callback and return payment status.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the WeChat Pay payment callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@pay_router.post(
|
||||
path='/stripe',
|
||||
summary='Stripe支付回调',
|
||||
description='Handle Stripe payment callback and return payment status.',
|
||||
)
|
||||
def router_callback_stripe() -> ResponseBase:
|
||||
"""
|
||||
Handle Stripe payment callback and return payment status.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Stripe payment callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@pay_router.get(
|
||||
path='/easypay',
|
||||
summary='易支付回调',
|
||||
description='Handle EasyPay payment callback and return payment status.',
|
||||
)
|
||||
def router_callback_easypay() -> PlainTextResponse:
|
||||
"""
|
||||
Handle EasyPay payment callback and return payment status.
|
||||
|
||||
Returns:
|
||||
PlainTextResponse: A response containing the payment status for the EasyPay payment callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
# return PlainTextResponse("success", status_code=200)
|
||||
|
||||
@pay_router.get(
|
||||
path='/custom/{order_no}/{id}',
|
||||
summary='自定义支付回调',
|
||||
description='Handle custom payment callback and return payment status.',
|
||||
)
|
||||
def router_callback_custom(order_no: str, id: str) -> ResponseBase:
|
||||
"""
|
||||
Handle custom payment callback and return payment status.
|
||||
|
||||
Args:
|
||||
order_no (str): The order number for the payment.
|
||||
id (str): The ID associated with the payment.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the custom payment callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
l.error(f"GitHub OAuth 回调异常: {e}")
|
||||
return PlainTextResponse("认证过程中发生错误,请重试", status_code=500)
|
||||
|
||||
@upload_router.post(
|
||||
path='/remote/{session_id}/{key}',
|
||||
@@ -158,11 +79,11 @@ def router_callback_custom(order_no: str, id: str) -> ResponseBase:
|
||||
def router_callback_remote(session_id: str, key: str) -> ResponseBase:
|
||||
"""
|
||||
Handle remote upload callback and return upload status.
|
||||
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID for the upload.
|
||||
key (str): The key for the uploaded file.
|
||||
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the remote upload callback.
|
||||
"""
|
||||
@@ -176,15 +97,15 @@ def router_callback_remote(session_id: str, key: str) -> ResponseBase:
|
||||
def router_callback_qiniu(session_id: str) -> ResponseBase:
|
||||
"""
|
||||
Handle Qiniu Cloud upload callback and return upload status.
|
||||
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID for the upload.
|
||||
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Qiniu Cloud upload callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@upload_router.post(
|
||||
path='/tencent/{session_id}',
|
||||
summary='腾讯云上传回调',
|
||||
@@ -193,16 +114,16 @@ def router_callback_qiniu(session_id: str) -> ResponseBase:
|
||||
def router_callback_tencent(session_id: str) -> ResponseBase:
|
||||
"""
|
||||
Handle Tencent Cloud upload callback and return upload status.
|
||||
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID for the upload.
|
||||
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Tencent Cloud upload callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@upload_router.post(
|
||||
@upload_router.post(
|
||||
path='/aliyun/{session_id}',
|
||||
summary='阿里云上传回调',
|
||||
description='Handle Aliyun upload callback and return upload status.',
|
||||
@@ -210,16 +131,16 @@ def router_callback_tencent(session_id: str) -> ResponseBase:
|
||||
def router_callback_aliyun(session_id: str) -> ResponseBase:
|
||||
"""
|
||||
Handle Aliyun upload callback and return upload status.
|
||||
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID for the upload.
|
||||
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Aliyun upload callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@upload_router.post(
|
||||
@upload_router.post(
|
||||
path='/upyun/{session_id}',
|
||||
summary='又拍云上传回调',
|
||||
description='Handle Upyun upload callback and return upload status.',
|
||||
@@ -227,10 +148,10 @@ def router_callback_aliyun(session_id: str) -> ResponseBase:
|
||||
def router_callback_upyun(session_id: str) -> ResponseBase:
|
||||
"""
|
||||
Handle Upyun upload callback and return upload status.
|
||||
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID for the upload.
|
||||
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Upyun upload callback.
|
||||
"""
|
||||
@@ -244,10 +165,10 @@ def router_callback_upyun(session_id: str) -> ResponseBase:
|
||||
def router_callback_aws(session_id: str) -> ResponseBase:
|
||||
"""
|
||||
Handle AWS S3 upload callback and return upload status.
|
||||
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID for the upload.
|
||||
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the AWS S3 upload callback.
|
||||
"""
|
||||
@@ -261,10 +182,10 @@ def router_callback_aws(session_id: str) -> ResponseBase:
|
||||
def router_callback_onedrive_finish(session_id: str) -> ResponseBase:
|
||||
"""
|
||||
Handle OneDrive upload completion callback and return upload status.
|
||||
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID for the upload.
|
||||
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the OneDrive upload completion callback.
|
||||
"""
|
||||
@@ -278,7 +199,7 @@ def router_callback_onedrive_finish(session_id: str) -> ResponseBase:
|
||||
def router_callback_onedrive_auth() -> ResponseBase:
|
||||
"""
|
||||
Handle OneDrive authorization callback and return authorization status.
|
||||
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the OneDrive authorization callback.
|
||||
"""
|
||||
@@ -292,8 +213,8 @@ def router_callback_onedrive_auth() -> ResponseBase:
|
||||
def router_callback_google_auth() -> ResponseBase:
|
||||
"""
|
||||
Handle Google OAuth completion callback and return authorization status.
|
||||
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Google OAuth completion callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
100
routers/api/v1/category/__init__.py
Normal file
100
routers/api/v1/category/__init__.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
文件分类筛选端点
|
||||
|
||||
按文件类型分类(图片/视频/音频/文档)查询用户的所有文件,
|
||||
跨目录搜索,支持分页。扩展名映射从数据库 Setting 表读取。
|
||||
"""
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from sqlmodels import (
|
||||
FileCategory,
|
||||
ListResponse,
|
||||
Object,
|
||||
ObjectResponse,
|
||||
ObjectType,
|
||||
Setting,
|
||||
SettingsType,
|
||||
User,
|
||||
)
|
||||
|
||||
category_router = APIRouter(
|
||||
prefix="/category",
|
||||
tags=["category"],
|
||||
)
|
||||
|
||||
|
||||
@category_router.get(
|
||||
path="/{category}",
|
||||
summary="按分类获取文件列表",
|
||||
)
|
||||
async def router_category_list(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
category: FileCategory,
|
||||
table_view: TableViewRequestDep,
|
||||
) -> ListResponse[ObjectResponse]:
|
||||
"""
|
||||
按文件类型分类查询用户的所有文件
|
||||
|
||||
跨所有目录搜索,返回分页结果。
|
||||
扩展名配置从数据库 Setting 表读取(type=file_category)。
|
||||
|
||||
认证:
|
||||
- JWT token in Authorization header
|
||||
|
||||
路径参数:
|
||||
- category: 文件分类(image / video / audio / document)
|
||||
|
||||
查询参数:
|
||||
- offset: 分页偏移量(默认0)
|
||||
- limit: 每页数量(默认20,最大100)
|
||||
- desc: 是否降序(默认true)
|
||||
- order: 排序字段(created_at / updated_at)
|
||||
|
||||
响应:
|
||||
- ListResponse[ObjectResponse]: 分页文件列表
|
||||
|
||||
错误处理:
|
||||
- HTTPException 422: category 参数无效
|
||||
- HTTPException 404: 该分类未配置扩展名
|
||||
"""
|
||||
# 从数据库读取该分类的扩展名配置
|
||||
setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.FILE_CATEGORY) & (Setting.name == category.value),
|
||||
)
|
||||
if not setting or not setting.value:
|
||||
raise HTTPException(status_code=404, detail=f"分类 {category.value} 未配置扩展名")
|
||||
|
||||
extensions = [ext.strip() for ext in setting.value.split(",") if ext.strip()]
|
||||
if not extensions:
|
||||
raise HTTPException(status_code=404, detail=f"分类 {category.value} 扩展名列表为空")
|
||||
|
||||
result = await Object.get_by_category(
|
||||
session,
|
||||
user.id,
|
||||
extensions,
|
||||
table_view=table_view,
|
||||
)
|
||||
|
||||
items = [
|
||||
ObjectResponse(
|
||||
id=obj.id,
|
||||
name=obj.name,
|
||||
type=ObjectType.FILE,
|
||||
size=obj.size,
|
||||
mime_type=obj.mime_type,
|
||||
thumb=False,
|
||||
created_at=obj.created_at,
|
||||
updated_at=obj.updated_at,
|
||||
source_enabled=False,
|
||||
)
|
||||
for obj in result.items
|
||||
]
|
||||
|
||||
return ListResponse(count=result.count, items=items)
|
||||
@@ -57,7 +57,7 @@ async def _get_directory_response(
|
||||
policy_response = PolicyResponse(
|
||||
id=policy.id,
|
||||
name=policy.name,
|
||||
type=policy.type.value,
|
||||
type=policy.type,
|
||||
max_size=policy.max_size,
|
||||
)
|
||||
|
||||
@@ -189,6 +189,14 @@ async def router_directory_create(
|
||||
raise HTTPException(status_code=409, detail="同名文件或目录已存在")
|
||||
|
||||
policy_id = request.policy_id if request.policy_id else parent.policy_id
|
||||
|
||||
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
|
||||
if request.policy_id:
|
||||
group = await user.awaitable_attrs.group
|
||||
await session.refresh(group, ['policies'])
|
||||
if request.policy_id not in {p.id for p in group.policies}:
|
||||
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
|
||||
|
||||
parent_id = parent.id # 在 save 前保存
|
||||
|
||||
new_folder = Object(
|
||||
@@ -198,4 +206,4 @@ async def router_directory_create(
|
||||
parent_id=parent_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
await new_folder.save(session)
|
||||
new_folder = await new_folder.save(session)
|
||||
|
||||
@@ -13,9 +13,11 @@ from datetime import datetime, timedelta
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
import orjson
|
||||
import whatthepatch
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
from starlette.responses import Response
|
||||
from loguru import logger as l
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
from whatthepatch.exceptions import HunkApplyException
|
||||
@@ -44,7 +46,9 @@ from sqlmodels import (
|
||||
User,
|
||||
WopiSessionResponse,
|
||||
)
|
||||
from service.storage import LocalStorageService, adjust_user_storage
|
||||
import orjson
|
||||
|
||||
from service.storage import LocalStorageService, S3StorageService, adjust_user_storage
|
||||
from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL
|
||||
from utils.JWT.wopi_token import create_wopi_token
|
||||
from utils import http_exceptions
|
||||
@@ -180,9 +184,14 @@ async def create_upload_session(
|
||||
|
||||
# 确定存储策略
|
||||
policy_id = request.policy_id or parent.policy_id
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
policy = await Policy.get_exist_one(session, policy_id)
|
||||
|
||||
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
|
||||
if request.policy_id:
|
||||
group = await user.awaitable_attrs.group
|
||||
await session.refresh(group, ['policies'])
|
||||
if request.policy_id not in {p.id for p in group.policies}:
|
||||
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
|
||||
|
||||
# 验证文件大小限制
|
||||
_check_policy_size_limit(policy, request.file_size)
|
||||
@@ -210,6 +219,7 @@ async def create_upload_session(
|
||||
|
||||
# 生成存储路径
|
||||
storage_path: str | None = None
|
||||
s3_upload_id: str | None = None
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
dir_path, storage_name, full_path = await storage_service.generate_file_path(
|
||||
@@ -217,8 +227,25 @@ async def create_upload_session(
|
||||
original_filename=request.file_name,
|
||||
)
|
||||
storage_path = full_path
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = S3StorageService(
|
||||
policy,
|
||||
region=options.s3_region if options else 'us-east-1',
|
||||
is_path_style=options.s3_path_style if options else False,
|
||||
)
|
||||
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
|
||||
user_id=user.id,
|
||||
original_filename=request.file_name,
|
||||
)
|
||||
# 多分片时创建 multipart upload
|
||||
if total_chunks > 1:
|
||||
s3_upload_id = await s3_service.create_multipart_upload(
|
||||
storage_path, content_type='application/octet-stream',
|
||||
)
|
||||
|
||||
# 预扣存储空间(与创建会话在同一事务中提交,防止并发绕过配额)
|
||||
if request.file_size > 0:
|
||||
await adjust_user_storage(session, user.id, request.file_size, commit=False)
|
||||
|
||||
# 创建上传会话
|
||||
upload_session = UploadSession(
|
||||
@@ -227,6 +254,7 @@ async def create_upload_session(
|
||||
chunk_size=chunk_size,
|
||||
total_chunks=total_chunks,
|
||||
storage_path=storage_path,
|
||||
s3_upload_id=s3_upload_id,
|
||||
expires_at=datetime.now() + timedelta(hours=24),
|
||||
owner_id=user.id,
|
||||
parent_id=request.parent_id,
|
||||
@@ -302,8 +330,38 @@ async def upload_chunk(
|
||||
content,
|
||||
offset,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
elif policy.type == PolicyType.S3:
|
||||
if not upload_session.storage_path:
|
||||
raise HTTPException(status_code=500, detail="存储路径丢失")
|
||||
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
|
||||
if upload_session.total_chunks == 1:
|
||||
# 单分片:直接 PUT 上传
|
||||
await s3_service.upload_file(upload_session.storage_path, content)
|
||||
else:
|
||||
# 多分片:UploadPart
|
||||
if not upload_session.s3_upload_id:
|
||||
raise HTTPException(status_code=500, detail="S3 分片上传 ID 丢失")
|
||||
|
||||
etag = await s3_service.upload_part(
|
||||
upload_session.storage_path,
|
||||
upload_session.s3_upload_id,
|
||||
chunk_index + 1, # S3 part number 从 1 开始
|
||||
content,
|
||||
)
|
||||
# 追加 ETag 到 s3_part_etags
|
||||
etags: list[list[int | str]] = orjson.loads(upload_session.s3_part_etags or '[]')
|
||||
etags.append([chunk_index + 1, etag])
|
||||
upload_session.s3_part_etags = orjson.dumps(etags).decode()
|
||||
|
||||
# 在 save(commit)前缓存后续需要的属性(commit 后 ORM 对象会过期)
|
||||
policy_type = policy.type
|
||||
s3_upload_id = upload_session.s3_upload_id
|
||||
s3_part_etags = upload_session.s3_part_etags
|
||||
s3_service_for_complete: S3StorageService | None = None
|
||||
if policy_type == PolicyType.S3:
|
||||
s3_service_for_complete = await S3StorageService.from_policy(policy)
|
||||
|
||||
# 更新会话进度
|
||||
upload_session.uploaded_chunks += 1
|
||||
@@ -319,12 +377,26 @@ async def upload_chunk(
|
||||
if is_complete:
|
||||
# 保存 upload_session 属性(commit 后会过期)
|
||||
file_name = upload_session.file_name
|
||||
file_size = upload_session.file_size
|
||||
uploaded_size = upload_session.uploaded_size
|
||||
storage_path = upload_session.storage_path
|
||||
upload_session_id = upload_session.id
|
||||
parent_id = upload_session.parent_id
|
||||
policy_id = upload_session.policy_id
|
||||
|
||||
# S3 多分片上传完成:合并分片
|
||||
if (
|
||||
policy_type == PolicyType.S3
|
||||
and s3_upload_id
|
||||
and s3_part_etags
|
||||
and s3_service_for_complete
|
||||
):
|
||||
parts_data: list[list[int | str]] = orjson.loads(s3_part_etags)
|
||||
parts = [(int(pn), str(et)) for pn, et in parts_data]
|
||||
await s3_service_for_complete.complete_multipart_upload(
|
||||
storage_path, s3_upload_id, parts,
|
||||
)
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
physical_file = PhysicalFile(
|
||||
storage_path=storage_path,
|
||||
@@ -355,9 +427,10 @@ async def upload_chunk(
|
||||
commit=False
|
||||
)
|
||||
|
||||
# 更新用户存储配额
|
||||
if uploaded_size > 0:
|
||||
await adjust_user_storage(session, user_id, uploaded_size, commit=False)
|
||||
# 调整存储配额差值(创建会话时已预扣 file_size,这里只补差)
|
||||
size_diff = uploaded_size - file_size
|
||||
if size_diff != 0:
|
||||
await adjust_user_storage(session, user_id, size_diff, commit=False)
|
||||
|
||||
# 统一提交所有更改
|
||||
await session.commit()
|
||||
@@ -390,9 +463,25 @@ async def delete_upload_session(
|
||||
|
||||
# 删除临时文件
|
||||
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
|
||||
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(upload_session.storage_path)
|
||||
if policy and upload_session.storage_path:
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(upload_session.storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
# 如果有分片上传,先取消
|
||||
if upload_session.s3_upload_id:
|
||||
await s3_service.abort_multipart_upload(
|
||||
upload_session.storage_path, upload_session.s3_upload_id,
|
||||
)
|
||||
else:
|
||||
# 单分片上传已完成的话,删除已上传的文件
|
||||
if upload_session.uploaded_chunks > 0:
|
||||
await s3_service.delete_file(upload_session.storage_path)
|
||||
|
||||
# 释放预扣的存储空间
|
||||
if upload_session.file_size > 0:
|
||||
await adjust_user_storage(session, user.id, -upload_session.file_size)
|
||||
|
||||
# 删除会话记录
|
||||
await UploadSession.delete(session, upload_session)
|
||||
@@ -422,9 +511,22 @@ async def clear_upload_sessions(
|
||||
for upload_session in sessions:
|
||||
# 删除临时文件
|
||||
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
|
||||
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(upload_session.storage_path)
|
||||
if policy and upload_session.storage_path:
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(upload_session.storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
if upload_session.s3_upload_id:
|
||||
await s3_service.abort_multipart_upload(
|
||||
upload_session.storage_path, upload_session.s3_upload_id,
|
||||
)
|
||||
elif upload_session.uploaded_chunks > 0:
|
||||
await s3_service.delete_file(upload_session.storage_path)
|
||||
|
||||
# 释放预扣的存储空间
|
||||
if upload_session.file_size > 0:
|
||||
await adjust_user_storage(session, user.id, -upload_session.file_size)
|
||||
|
||||
await UploadSession.delete(session, upload_session)
|
||||
deleted_count += 1
|
||||
@@ -486,11 +588,12 @@ async def create_download_token_endpoint(
|
||||
path='/{token}',
|
||||
summary='下载文件',
|
||||
description='使用下载令牌下载文件,令牌在有效期内可重复使用。',
|
||||
response_model=None,
|
||||
)
|
||||
async def download_file(
|
||||
session: SessionDep,
|
||||
token: str,
|
||||
) -> FileResponse:
|
||||
) -> Response:
|
||||
"""
|
||||
下载文件端点
|
||||
|
||||
@@ -540,8 +643,15 @@ async def download_file(
|
||||
filename=file_obj.name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
# 302 重定向到预签名 URL
|
||||
presigned_url = s3_service.generate_presigned_url(
|
||||
storage_path, method='GET', expires_in=3600, filename=file_obj.name,
|
||||
)
|
||||
return RedirectResponse(url=presigned_url, status_code=302)
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
raise HTTPException(status_code=500, detail="不支持的存储类型")
|
||||
|
||||
|
||||
# ==================== 包含子路由 ====================
|
||||
@@ -599,9 +709,7 @@ async def create_empty_file(
|
||||
|
||||
# 确定存储策略
|
||||
policy_id = request.policy_id or parent.policy_id
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
policy = await Policy.get_exist_one(session, policy_id)
|
||||
|
||||
# 生成存储路径并创建空文件
|
||||
storage_path: str | None = None
|
||||
@@ -613,8 +721,13 @@ async def create_empty_file(
|
||||
)
|
||||
await storage_service.create_empty_file(full_path)
|
||||
storage_path = full_path
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
|
||||
user_id=user_id,
|
||||
original_filename=request.name,
|
||||
)
|
||||
await s3_service.upload_file(storage_path, b'')
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
physical_file = PhysicalFile(
|
||||
@@ -695,6 +808,7 @@ async def create_wopi_session(
|
||||
)
|
||||
|
||||
wopi_app: FileApp | None = None
|
||||
matched_ext_record: FileAppExtension | None = None
|
||||
for ext_record in ext_records:
|
||||
app = ext_record.app
|
||||
if app.type == FileAppType.WOPI and app.is_enabled:
|
||||
@@ -710,13 +824,20 @@ async def create_wopi_session(
|
||||
if not result.first():
|
||||
continue
|
||||
wopi_app = app
|
||||
matched_ext_record = ext_record
|
||||
break
|
||||
|
||||
if not wopi_app:
|
||||
http_exceptions.raise_not_found("无可用的 WOPI 查看器")
|
||||
|
||||
if not wopi_app.wopi_editor_url_template:
|
||||
http_exceptions.raise_bad_request("WOPI 应用未配置编辑器 URL 模板")
|
||||
# 优先使用 per-extension URL(Discovery 自动填充),回退到全局模板
|
||||
editor_url_template: str | None = None
|
||||
if matched_ext_record and matched_ext_record.wopi_action_url:
|
||||
editor_url_template = matched_ext_record.wopi_action_url
|
||||
if not editor_url_template:
|
||||
editor_url_template = wopi_app.wopi_editor_url_template
|
||||
if not editor_url_template:
|
||||
http_exceptions.raise_bad_request("WOPI 应用未配置编辑器 URL 模板,请先执行 Discovery 或手动配置")
|
||||
|
||||
# 获取站点 URL
|
||||
site_url_setting: Setting | None = await Setting.get(
|
||||
@@ -732,12 +853,8 @@ async def create_wopi_session(
|
||||
# 构建 wopi_src
|
||||
wopi_src = f"{site_url}/wopi/files/{file_id}"
|
||||
|
||||
# 构建 editor URL
|
||||
editor_url = wopi_app.wopi_editor_url_template.format(
|
||||
wopi_src=wopi_src,
|
||||
access_token=token,
|
||||
access_token_ttl=access_token_ttl,
|
||||
)
|
||||
# 构建 editor URL(只替换 wopi_src,token 通过 POST 表单传递)
|
||||
editor_url = editor_url_template.format(wopi_src=wopi_src)
|
||||
|
||||
return WopiSessionResponse(
|
||||
wopi_src=wopi_src,
|
||||
@@ -798,12 +915,13 @@ async def _validate_source_link(
|
||||
path='/get/{file_id}/{name}',
|
||||
summary='文件外链(直接输出文件数据)',
|
||||
description='通过外链直接获取文件内容,公开访问无需认证。',
|
||||
response_model=None,
|
||||
)
|
||||
async def file_get(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
name: str,
|
||||
) -> FileResponse:
|
||||
) -> Response:
|
||||
"""
|
||||
文件外链端点(直接输出)
|
||||
|
||||
@@ -815,25 +933,32 @@ async def file_get(
|
||||
"""
|
||||
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
if not await storage_service.file_exists(physical_file.storage_path):
|
||||
http_exceptions.raise_not_found("物理文件不存在")
|
||||
|
||||
# 缓存物理路径(save 后对象属性会过期)
|
||||
file_path = physical_file.storage_path
|
||||
|
||||
# 递增下载次数
|
||||
link.downloads += 1
|
||||
await link.save(session)
|
||||
link = await link.save(session)
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
if not await storage_service.file_exists(file_path):
|
||||
http_exceptions.raise_not_found("物理文件不存在")
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
elif policy.type == PolicyType.S3:
|
||||
# S3 外链直接输出:302 重定向到预签名 URL
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
presigned_url = s3_service.generate_presigned_url(
|
||||
file_path, method='GET', expires_in=3600, filename=name,
|
||||
)
|
||||
return RedirectResponse(url=presigned_url, status_code=302)
|
||||
else:
|
||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -846,7 +971,7 @@ async def file_source_redirect(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
name: str,
|
||||
) -> FileResponse | RedirectResponse:
|
||||
) -> Response:
|
||||
"""
|
||||
文件外链端点(重定向/直接输出)
|
||||
|
||||
@@ -860,13 +985,6 @@ async def file_source_redirect(
|
||||
"""
|
||||
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
if not await storage_service.file_exists(physical_file.storage_path):
|
||||
http_exceptions.raise_not_found("物理文件不存在")
|
||||
|
||||
# 缓存所有需要的值(save 后对象属性会过期)
|
||||
file_path = physical_file.storage_path
|
||||
is_private = policy.is_private
|
||||
@@ -874,20 +992,38 @@ async def file_source_redirect(
|
||||
|
||||
# 递增下载次数
|
||||
link.downloads += 1
|
||||
await link.save(session)
|
||||
link = await link.save(session)
|
||||
|
||||
# 公有存储:302 重定向到 base_url
|
||||
if not is_private and base_url:
|
||||
relative_path = storage_service.get_relative_path(file_path)
|
||||
redirect_url = f"{base_url}/{relative_path}"
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
if not await storage_service.file_exists(file_path):
|
||||
http_exceptions.raise_not_found("物理文件不存在")
|
||||
|
||||
# 私有存储或 base_url 为空:通过应用代理文件
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
# 公有存储:302 重定向到 base_url
|
||||
if not is_private and base_url:
|
||||
relative_path = storage_service.get_relative_path(file_path)
|
||||
redirect_url = f"{base_url}/{relative_path}"
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
||||
|
||||
# 私有存储或 base_url 为空:通过应用代理文件
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
# 公有存储且有 base_url:直接重定向到公开 URL
|
||||
if not is_private and base_url:
|
||||
redirect_url = f"{base_url.rstrip('/')}/{file_path}"
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
||||
# 私有存储:生成预签名 URL 重定向
|
||||
presigned_url = s3_service.generate_presigned_url(
|
||||
file_path, method='GET', expires_in=3600, filename=name,
|
||||
)
|
||||
return RedirectResponse(url=presigned_url, status_code=302)
|
||||
else:
|
||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
||||
|
||||
|
||||
@router.put(
|
||||
@@ -941,11 +1077,15 @@ async def file_content(
|
||||
if not policy:
|
||||
http_exceptions.raise_internal_error("存储策略不存在")
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(physical_file.storage_path)
|
||||
# 读取文件内容
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(physical_file.storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
raw_bytes = await s3_service.download_file(physical_file.storage_path)
|
||||
else:
|
||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
||||
|
||||
try:
|
||||
content = raw_bytes.decode('utf-8')
|
||||
@@ -1011,11 +1151,15 @@ async def patch_file_content(
|
||||
if not policy:
|
||||
http_exceptions.raise_internal_error("存储策略不存在")
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(storage_path)
|
||||
# 读取文件内容
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
raw_bytes = await s3_service.download_file(storage_path)
|
||||
else:
|
||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
||||
|
||||
# 解码 + 规范化
|
||||
original_text = raw_bytes.decode('utf-8')
|
||||
@@ -1049,7 +1193,10 @@ async def patch_file_content(
|
||||
_check_policy_size_limit(policy, len(new_bytes))
|
||||
|
||||
# 写入文件
|
||||
await storage_service.write_file(storage_path, new_bytes)
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
await storage_service.write_file(storage_path, new_bytes)
|
||||
elif policy.type == PolicyType.S3:
|
||||
await s3_service.upload_file(storage_path, new_bytes)
|
||||
|
||||
# 更新数据库
|
||||
owner_id = file_obj.owner_id
|
||||
|
||||
@@ -8,13 +8,14 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
CreateFileRequest,
|
||||
Group,
|
||||
Object,
|
||||
ObjectCopyRequest,
|
||||
ObjectDeleteRequest,
|
||||
@@ -22,24 +23,42 @@ from sqlmodels import (
|
||||
ObjectPropertyDetailResponse,
|
||||
ObjectPropertyResponse,
|
||||
ObjectRenameRequest,
|
||||
ObjectSwitchPolicyRequest,
|
||||
ObjectType,
|
||||
PhysicalFile,
|
||||
Policy,
|
||||
PolicyType,
|
||||
Task,
|
||||
TaskProps,
|
||||
TaskStatus,
|
||||
TaskSummaryBase,
|
||||
TaskType,
|
||||
User,
|
||||
# 元数据相关
|
||||
ObjectMetadata,
|
||||
MetadataResponse,
|
||||
MetadataPatchRequest,
|
||||
INTERNAL_NAMESPACES,
|
||||
USER_WRITABLE_NAMESPACES,
|
||||
)
|
||||
from service.storage import (
|
||||
LocalStorageService,
|
||||
adjust_user_storage,
|
||||
copy_object_recursive,
|
||||
migrate_file_with_task,
|
||||
migrate_directory_files,
|
||||
)
|
||||
from service.storage.object import soft_delete_objects
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from utils import http_exceptions
|
||||
|
||||
from .custom_property import router as custom_property_router
|
||||
|
||||
object_router = APIRouter(
|
||||
prefix="/object",
|
||||
tags=["object"]
|
||||
)
|
||||
object_router.include_router(custom_property_router)
|
||||
|
||||
@object_router.post(
|
||||
path='/',
|
||||
@@ -93,9 +112,7 @@ async def router_object_create(
|
||||
|
||||
# 确定存储策略
|
||||
policy_id = request.policy_id or parent.policy_id
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
policy = await Policy.get_exist_one(session, policy_id)
|
||||
|
||||
parent_id = parent.id
|
||||
|
||||
@@ -130,7 +147,7 @@ async def router_object_create(
|
||||
owner_id=user_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
await file_object.save(session)
|
||||
file_object = await file_object.save(session)
|
||||
|
||||
l.info(f"创建空白文件: {request.name}")
|
||||
|
||||
@@ -455,7 +472,7 @@ async def router_object_rename(
|
||||
|
||||
# 更新名称
|
||||
obj.name = new_name
|
||||
await obj.save(session)
|
||||
obj = await obj.save(session)
|
||||
|
||||
l.info(f"用户 {user_id} 将对象 {obj.id} 重命名为 {new_name}")
|
||||
|
||||
@@ -493,6 +510,7 @@ async def router_object_property(
|
||||
name=obj.name,
|
||||
type=obj.type,
|
||||
size=obj.size,
|
||||
mime_type=obj.mime_type,
|
||||
created_at=obj.created_at,
|
||||
updated_at=obj.updated_at,
|
||||
parent_id=obj.parent_id,
|
||||
@@ -520,7 +538,7 @@ async def router_object_property_detail(
|
||||
obj = await Object.get(
|
||||
session,
|
||||
(Object.id == id) & (Object.deleted_at == None),
|
||||
load=Object.file_metadata,
|
||||
load=Object.metadata_entries,
|
||||
)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail="对象不存在")
|
||||
@@ -543,35 +561,301 @@ async def router_object_property_detail(
|
||||
total_views = sum(s.views for s in shares)
|
||||
total_downloads = sum(s.downloads for s in shares)
|
||||
|
||||
# 获取物理文件引用计数
|
||||
# 获取物理文件信息(引用计数、校验和)
|
||||
reference_count = 1
|
||||
checksum_md5: str | None = None
|
||||
checksum_sha256: str | None = None
|
||||
if obj.physical_file_id:
|
||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
|
||||
if physical_file:
|
||||
reference_count = physical_file.reference_count
|
||||
checksum_md5 = physical_file.checksum_md5
|
||||
checksum_sha256 = physical_file.checksum_sha256
|
||||
|
||||
# 构建响应
|
||||
response = ObjectPropertyDetailResponse(
|
||||
# 构建元数据字典(排除内部命名空间)
|
||||
metadata: dict[str, str] = {}
|
||||
for entry in obj.metadata_entries:
|
||||
ns = entry.name.split(":")[0] if ":" in entry.name else ""
|
||||
if ns not in INTERNAL_NAMESPACES:
|
||||
metadata[entry.name] = entry.value
|
||||
|
||||
return ObjectPropertyDetailResponse(
|
||||
id=obj.id,
|
||||
name=obj.name,
|
||||
type=obj.type,
|
||||
size=obj.size,
|
||||
mime_type=obj.mime_type,
|
||||
created_at=obj.created_at,
|
||||
updated_at=obj.updated_at,
|
||||
parent_id=obj.parent_id,
|
||||
checksum_md5=checksum_md5,
|
||||
checksum_sha256=checksum_sha256,
|
||||
policy_name=policy_name,
|
||||
share_count=share_count,
|
||||
total_views=total_views,
|
||||
total_downloads=total_downloads,
|
||||
reference_count=reference_count,
|
||||
metadatas=metadata,
|
||||
)
|
||||
|
||||
# 添加文件元数据
|
||||
if obj.file_metadata:
|
||||
response.mime_type = obj.file_metadata.mime_type
|
||||
response.width = obj.file_metadata.width
|
||||
response.height = obj.file_metadata.height
|
||||
response.duration = obj.file_metadata.duration
|
||||
response.checksum_md5 = obj.file_metadata.checksum_md5
|
||||
|
||||
return response
|
||||
@object_router.patch(
|
||||
path='/{object_id}/policy',
|
||||
summary='切换对象存储策略',
|
||||
)
|
||||
async def router_object_switch_policy(
|
||||
session: SessionDep,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
object_id: UUID,
|
||||
request: ObjectSwitchPolicyRequest,
|
||||
) -> TaskSummaryBase:
|
||||
"""
|
||||
切换对象的存储策略
|
||||
|
||||
文件:立即创建后台迁移任务,将文件从源策略搬到目标策略。
|
||||
目录:更新目录 policy_id(新文件使用新策略);
|
||||
若 is_migrate_existing=True,额外创建后台任务迁移所有已有文件。
|
||||
|
||||
认证:JWT Bearer Token
|
||||
|
||||
错误处理:
|
||||
- 404: 对象不存在
|
||||
- 403: 无权操作此对象 / 用户组无权使用目标策略
|
||||
- 400: 目标策略与当前相同 / 不能对根目录操作
|
||||
"""
|
||||
user_id = user.id
|
||||
|
||||
# 查找对象
|
||||
obj = await Object.get(
|
||||
session,
|
||||
(Object.id == object_id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not obj:
|
||||
http_exceptions.raise_not_found("对象不存在")
|
||||
if obj.owner_id != user_id:
|
||||
http_exceptions.raise_forbidden("无权操作此对象")
|
||||
if obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 根目录不能直接切换策略(应通过子对象或子目录操作)
|
||||
if obj.parent_id is None:
|
||||
raise HTTPException(status_code=400, detail="不能对根目录切换存储策略,请对子目录操作")
|
||||
|
||||
# 校验目标策略存在
|
||||
dest_policy = await Policy.get(session, Policy.id == request.policy_id)
|
||||
if not dest_policy:
|
||||
http_exceptions.raise_not_found("目标存储策略不存在")
|
||||
|
||||
# 校验用户组权限
|
||||
group: Group = await user.awaitable_attrs.group
|
||||
await session.refresh(group, ['policies'])
|
||||
allowed_ids = {p.id for p in group.policies}
|
||||
if request.policy_id not in allowed_ids:
|
||||
http_exceptions.raise_forbidden("当前用户组无权使用该存储策略")
|
||||
|
||||
# 不能切换到相同策略
|
||||
if obj.policy_id == request.policy_id:
|
||||
raise HTTPException(status_code=400, detail="目标策略与当前策略相同")
|
||||
|
||||
# 保存必要的属性,避免 save 后对象过期
|
||||
src_policy_id = obj.policy_id
|
||||
obj_id = obj.id
|
||||
obj_is_file = obj.type == ObjectType.FILE
|
||||
dest_policy_id = request.policy_id
|
||||
dest_policy_name = dest_policy.name
|
||||
|
||||
# 创建任务记录
|
||||
task = Task(
|
||||
type=TaskType.POLICY_MIGRATE,
|
||||
status=TaskStatus.QUEUED,
|
||||
user_id=user_id,
|
||||
)
|
||||
task = await task.save(session)
|
||||
task_id = task.id
|
||||
|
||||
task_props = TaskProps(
|
||||
task_id=task_id,
|
||||
source_policy_id=src_policy_id,
|
||||
dest_policy_id=dest_policy_id,
|
||||
object_id=obj_id,
|
||||
)
|
||||
task_props = await task_props.save(session)
|
||||
|
||||
if obj_is_file:
|
||||
# 文件:后台迁移
|
||||
async def _run_file_migration() -> None:
|
||||
async with DatabaseManager.session() as bg_session:
|
||||
bg_obj = await Object.get(bg_session, Object.id == obj_id)
|
||||
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
|
||||
bg_task = await Task.get(bg_session, Task.id == task_id)
|
||||
await migrate_file_with_task(bg_session, bg_obj, bg_policy, bg_task)
|
||||
|
||||
background_tasks.add_task(_run_file_migration)
|
||||
else:
|
||||
# 目录:先更新目录自身的 policy_id
|
||||
obj = await Object.get(session, Object.id == obj_id)
|
||||
obj.policy_id = dest_policy_id
|
||||
obj = await obj.save(session)
|
||||
|
||||
if request.is_migrate_existing:
|
||||
# 后台迁移所有已有文件
|
||||
async def _run_dir_migration() -> None:
|
||||
async with DatabaseManager.session() as bg_session:
|
||||
bg_folder = await Object.get(bg_session, Object.id == obj_id)
|
||||
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
|
||||
bg_task = await Task.get(bg_session, Task.id == task_id)
|
||||
await migrate_directory_files(bg_session, bg_folder, bg_policy, bg_task)
|
||||
|
||||
background_tasks.add_task(_run_dir_migration)
|
||||
else:
|
||||
# 不迁移已有文件,直接完成任务
|
||||
task = await Task.get(session, Task.id == task_id)
|
||||
task.status = TaskStatus.COMPLETED
|
||||
task.progress = 100
|
||||
task = await task.save(session)
|
||||
|
||||
# 重新获取 task 以读取最新状态
|
||||
task = await Task.get(session, Task.id == task_id)
|
||||
|
||||
l.info(f"用户 {user_id} 请求切换对象 {obj_id} 存储策略 → {dest_policy_name}")
|
||||
|
||||
return TaskSummaryBase(
|
||||
id=task.id,
|
||||
type=task.type,
|
||||
status=task.status,
|
||||
progress=task.progress,
|
||||
error=task.error,
|
||||
user_id=task.user_id,
|
||||
created_at=task.created_at,
|
||||
updated_at=task.updated_at,
|
||||
)
|
||||
|
||||
|
||||
# ==================== 元数据端点 ====================
|
||||
|
||||
@object_router.get(
|
||||
path='/{object_id}/metadata',
|
||||
summary='获取对象元数据',
|
||||
description='获取对象的元数据键值对,可按命名空间过滤。',
|
||||
)
|
||||
async def router_get_object_metadata(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
object_id: UUID,
|
||||
ns: str | None = None,
|
||||
) -> MetadataResponse:
|
||||
"""
|
||||
获取对象元数据端点
|
||||
|
||||
认证:JWT token 必填
|
||||
|
||||
查询参数:
|
||||
- ns: 逗号分隔的命名空间列表(如 exif,stream),不传返回所有非内部命名空间
|
||||
|
||||
错误处理:
|
||||
- 404: 对象不存在
|
||||
- 403: 无权查看此对象
|
||||
"""
|
||||
obj = await Object.get(
|
||||
session,
|
||||
(Object.id == object_id) & (Object.deleted_at == None),
|
||||
load=Object.metadata_entries,
|
||||
)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail="对象不存在")
|
||||
|
||||
if obj.owner_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="无权查看此对象")
|
||||
|
||||
# 解析命名空间过滤
|
||||
ns_filter: set[str] | None = None
|
||||
if ns:
|
||||
ns_filter = {n.strip() for n in ns.split(",") if n.strip()}
|
||||
# 不允许查看内部命名空间
|
||||
ns_filter -= INTERNAL_NAMESPACES
|
||||
|
||||
# 构建元数据字典
|
||||
metadata: dict[str, str] = {}
|
||||
for entry in obj.metadata_entries:
|
||||
entry_ns = entry.name.split(":")[0] if ":" in entry.name else ""
|
||||
if entry_ns in INTERNAL_NAMESPACES:
|
||||
continue
|
||||
if ns_filter is not None and entry_ns not in ns_filter:
|
||||
continue
|
||||
metadata[entry.name] = entry.value
|
||||
|
||||
return MetadataResponse(metadatas=metadata)
|
||||
|
||||
|
||||
@object_router.patch(
|
||||
path='/{object_id}/metadata',
|
||||
summary='批量更新对象元数据',
|
||||
description='批量设置或删除对象的元数据条目。仅允许修改 custom: 命名空间。',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_patch_object_metadata(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
object_id: UUID,
|
||||
request: MetadataPatchRequest,
|
||||
) -> None:
|
||||
"""
|
||||
批量更新对象元数据端点
|
||||
|
||||
请求体中值为 None 的键将被删除,其余键将被设置/更新。
|
||||
用户只能修改 custom: 命名空间的条目。
|
||||
|
||||
认证:JWT token 必填
|
||||
|
||||
错误处理:
|
||||
- 400: 尝试修改非 custom: 命名空间的条目
|
||||
- 404: 对象不存在
|
||||
- 403: 无权操作此对象
|
||||
"""
|
||||
obj = await Object.get(
|
||||
session,
|
||||
(Object.id == object_id) & (Object.deleted_at == None),
|
||||
)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail="对象不存在")
|
||||
|
||||
if obj.owner_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="无权操作此对象")
|
||||
|
||||
for patch in request.patches:
|
||||
# 验证命名空间
|
||||
patch_ns = patch.key.split(":")[0] if ":" in patch.key else ""
|
||||
if patch_ns not in USER_WRITABLE_NAMESPACES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不允许修改命名空间 '{patch_ns}' 的元数据,仅允许 custom: 命名空间",
|
||||
)
|
||||
|
||||
if patch.value is None:
|
||||
# 删除元数据条目
|
||||
existing = await ObjectMetadata.get(
|
||||
session,
|
||||
(ObjectMetadata.object_id == object_id) & (ObjectMetadata.name == patch.key),
|
||||
)
|
||||
if existing:
|
||||
await ObjectMetadata.delete(session, instances=existing)
|
||||
else:
|
||||
# 设置/更新元数据条目
|
||||
existing = await ObjectMetadata.get(
|
||||
session,
|
||||
(ObjectMetadata.object_id == object_id) & (ObjectMetadata.name == patch.key),
|
||||
)
|
||||
if existing:
|
||||
existing.value = patch.value
|
||||
existing = await existing.save(session)
|
||||
else:
|
||||
entry = ObjectMetadata(
|
||||
object_id=object_id,
|
||||
name=patch.key,
|
||||
value=patch.value,
|
||||
is_public=True,
|
||||
)
|
||||
entry = await entry.save(session)
|
||||
|
||||
l.info(f"用户 {user.id} 更新了对象 {object_id} 的 {len(request.patches)} 条元数据")
|
||||
|
||||
168
routers/api/v1/object/custom_property/__init__.py
Normal file
168
routers/api/v1/object/custom_property/__init__.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
用户自定义属性定义路由
|
||||
|
||||
提供自定义属性模板的增删改查功能。
|
||||
用户可以定义类型化的属性模板(如标签、评分、分类等),
|
||||
然后通过元数据 PATCH 端点为对象设置属性值。
|
||||
|
||||
路由前缀:/custom_property
|
||||
"""
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
CustomPropertyDefinition,
|
||||
CustomPropertyCreateRequest,
|
||||
CustomPropertyUpdateRequest,
|
||||
CustomPropertyResponse,
|
||||
User,
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/custom_property",
|
||||
tags=["custom_property"],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
path='',
|
||||
summary='获取自定义属性定义列表',
|
||||
description='获取当前用户的所有自定义属性定义,按 sort_order 排序。',
|
||||
)
|
||||
async def router_list_custom_properties(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> list[CustomPropertyResponse]:
|
||||
"""
|
||||
获取自定义属性定义列表端点
|
||||
|
||||
认证:JWT token 必填
|
||||
|
||||
返回当前用户定义的所有自定义属性模板。
|
||||
"""
|
||||
definitions = await CustomPropertyDefinition.get(
|
||||
session,
|
||||
CustomPropertyDefinition.owner_id == user.id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
|
||||
return [
|
||||
CustomPropertyResponse(
|
||||
id=d.id,
|
||||
name=d.name,
|
||||
type=d.type,
|
||||
icon=d.icon,
|
||||
options=d.options,
|
||||
default_value=d.default_value,
|
||||
sort_order=d.sort_order,
|
||||
)
|
||||
for d in sorted(definitions, key=lambda x: x.sort_order)
|
||||
]
|
||||
|
||||
|
||||
@router.post(
|
||||
path='',
|
||||
summary='创建自定义属性定义',
|
||||
description='创建一个新的自定义属性模板。',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_create_custom_property(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: CustomPropertyCreateRequest,
|
||||
) -> None:
|
||||
"""
|
||||
创建自定义属性定义端点
|
||||
|
||||
认证:JWT token 必填
|
||||
|
||||
错误处理:
|
||||
- 400: 请求数据无效
|
||||
- 409: 同名属性已存在
|
||||
"""
|
||||
# 检查同名属性
|
||||
existing = await CustomPropertyDefinition.get(
|
||||
session,
|
||||
(CustomPropertyDefinition.owner_id == user.id) &
|
||||
(CustomPropertyDefinition.name == request.name),
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="同名自定义属性已存在")
|
||||
|
||||
definition = CustomPropertyDefinition(
|
||||
owner_id=user.id,
|
||||
name=request.name,
|
||||
type=request.type,
|
||||
icon=request.icon,
|
||||
options=request.options,
|
||||
default_value=request.default_value,
|
||||
)
|
||||
definition = await definition.save(session)
|
||||
|
||||
l.info(f"用户 {user.id} 创建了自定义属性: {request.name}")
|
||||
|
||||
|
||||
@router.patch(
|
||||
path='/{id}',
|
||||
summary='更新自定义属性定义',
|
||||
description='更新自定义属性模板的名称、图标、选项等。',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_update_custom_property(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
id: UUID,
|
||||
request: CustomPropertyUpdateRequest,
|
||||
) -> None:
|
||||
"""
|
||||
更新自定义属性定义端点
|
||||
|
||||
认证:JWT token 必填
|
||||
|
||||
错误处理:
|
||||
- 404: 属性定义不存在
|
||||
- 403: 无权操作此属性
|
||||
"""
|
||||
definition = await CustomPropertyDefinition.get_exist_one(session, id)
|
||||
|
||||
if definition.owner_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="无权操作此属性")
|
||||
|
||||
definition = await definition.update(session, request)
|
||||
|
||||
l.info(f"用户 {user.id} 更新了自定义属性: {id}")
|
||||
|
||||
|
||||
@router.delete(
|
||||
path='/{id}',
|
||||
summary='删除自定义属性定义',
|
||||
description='删除自定义属性模板。注意:不会自动清理已使用该属性的元数据条目。',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_delete_custom_property(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
id: UUID,
|
||||
) -> None:
|
||||
"""
|
||||
删除自定义属性定义端点
|
||||
|
||||
认证:JWT token 必填
|
||||
|
||||
错误处理:
|
||||
- 404: 属性定义不存在
|
||||
- 403: 无权操作此属性
|
||||
"""
|
||||
definition = await CustomPropertyDefinition.get_exist_one(session, id)
|
||||
|
||||
if definition.owner_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="无权操作此属性")
|
||||
|
||||
await CustomPropertyDefinition.delete(session, instances=definition)
|
||||
|
||||
l.info(f"用户 {user.id} 删除了自定义属性: {id}")
|
||||
@@ -45,12 +45,7 @@ async def router_share_get(
|
||||
4. 返回分享详情(含文件树和分享者信息)
|
||||
"""
|
||||
# 1. 查询分享(预加载 user 和 object)
|
||||
share = await Share.get(
|
||||
session, Share.id == id,
|
||||
load=[Share.user, Share.object],
|
||||
)
|
||||
if not share:
|
||||
http_exceptions.raise_not_found(detail="分享不存在或已被取消")
|
||||
share = await Share.get_exist_one(session, id, load=[Share.user, Share.object])
|
||||
|
||||
# 2. 检查过期
|
||||
now = datetime.now()
|
||||
@@ -474,16 +469,29 @@ def router_share_update(id: str) -> ResponseBase:
|
||||
path='/{id}',
|
||||
summary='删除分享',
|
||||
description='Delete a share by ID.',
|
||||
dependencies=[Depends(auth_required)]
|
||||
status_code=204,
|
||||
)
|
||||
def router_share_delete(id: str) -> ResponseBase:
|
||||
async def router_share_delete(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
id: UUID,
|
||||
) -> None:
|
||||
"""
|
||||
Delete a share by ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the share to be deleted.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the deleted share.
|
||||
删除分享
|
||||
|
||||
认证:需要 JWT token
|
||||
|
||||
流程:
|
||||
1. 通过分享ID查找分享
|
||||
2. 验证分享属于当前用户
|
||||
3. 删除分享记录
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
share = await Share.get_exist_one(session, id)
|
||||
if share.user_id != user.id:
|
||||
http_exceptions.raise_forbidden(detail="无权删除此分享")
|
||||
|
||||
user_id = user.id
|
||||
share_code = share.code
|
||||
await Share.delete(session, share)
|
||||
|
||||
l.info(f"用户 {user_id} 删除了分享: {share_code}")
|
||||
@@ -82,7 +82,8 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
(Setting.type == SettingsType.REGISTER) |
|
||||
(Setting.type == SettingsType.CAPTCHA) |
|
||||
(Setting.type == SettingsType.AUTH) |
|
||||
(Setting.type == SettingsType.OAUTH),
|
||||
(Setting.type == SettingsType.OAUTH) |
|
||||
(Setting.type == SettingsType.AVATAR),
|
||||
fetch_mode="all",
|
||||
)
|
||||
|
||||
@@ -122,6 +123,7 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
password_required=s.get("auth_password_required") == "1",
|
||||
phone_binding_required=s.get("auth_phone_binding_required") == "1",
|
||||
email_binding_required=s.get("auth_email_binding_required") == "1",
|
||||
avatar_max_size=int(s["avatar_size"]),
|
||||
footer_code=s.get("footer_code"),
|
||||
tos_url=s.get("tos_url"),
|
||||
privacy_url=s.get("privacy_url"),
|
||||
|
||||
@@ -5,6 +5,7 @@ import json
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
from itsdangerous import URLSafeTimedSerializer
|
||||
from loguru import logger
|
||||
from webauthn import (
|
||||
@@ -233,7 +234,7 @@ async def router_user_register(
|
||||
group_id=default_group.id,
|
||||
)
|
||||
new_user_id = new_user.id
|
||||
await new_user.save(session)
|
||||
new_user = await new_user.save(session)
|
||||
|
||||
# 7. 创建 AuthIdentity
|
||||
hashed_password = Password.hash(request.credential) if request.credential else None
|
||||
@@ -245,13 +246,14 @@ async def router_user_register(
|
||||
is_verified=False,
|
||||
user_id=new_user_id,
|
||||
)
|
||||
await identity.save(session)
|
||||
identity = await identity.save(session)
|
||||
|
||||
# 8. 创建用户根目录
|
||||
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
|
||||
if not default_policy:
|
||||
logger.error("默认存储策略不存在")
|
||||
# 8. 创建用户根目录(使用用户组关联的第一个存储策略)
|
||||
await session.refresh(default_group, ['policies'])
|
||||
if not default_group.policies:
|
||||
logger.error("默认用户组未关联任何存储策略")
|
||||
http_exceptions.raise_internal_error()
|
||||
default_policy = default_group.policies[0]
|
||||
|
||||
await sqlmodels.Object(
|
||||
name="/",
|
||||
@@ -318,7 +320,7 @@ async def router_user_magic_link(
|
||||
site_url = site_url_setting.value if site_url_setting else "http://localhost"
|
||||
|
||||
# TODO: 发送邮件(包含 {site_url}/auth/magic-link?token={token})
|
||||
logger.info(f"Magic Link token 已生成: {token} (邮件发送待实现)")
|
||||
logger.info(f"Magic Link token 已为 {request.email} 生成 (邮件发送待实现)")
|
||||
|
||||
|
||||
@user_router.post(
|
||||
@@ -357,20 +359,78 @@ def router_user_profile(id: str) -> sqlmodels.ResponseBase:
|
||||
@user_router.get(
|
||||
path='/avatar/{id}/{size}',
|
||||
summary='获取用户头像',
|
||||
description='Get user avatar by ID and size.',
|
||||
response_model=None,
|
||||
)
|
||||
def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase:
|
||||
async def router_user_avatar(
|
||||
session: SessionDep,
|
||||
id: UUID,
|
||||
size: int = 128,
|
||||
) -> FileResponse | RedirectResponse:
|
||||
"""
|
||||
Get user avatar by ID and size.
|
||||
获取指定用户指定尺寸的头像(公开端点,无需认证)
|
||||
|
||||
Args:
|
||||
id (str): The user ID.
|
||||
size (int): The size of the avatar image.
|
||||
路径参数:
|
||||
- id: 用户 UUID
|
||||
- size: 请求的头像尺寸(px),默认 128
|
||||
|
||||
Returns:
|
||||
str: A Base64 encoded string of the user avatar image.
|
||||
行为:
|
||||
- default: 302 重定向到 Gravatar identicon
|
||||
- gravatar: 302 重定向到 Gravatar(使用用户邮箱 MD5)
|
||||
- file: 返回本地 WebP 文件
|
||||
|
||||
响应:
|
||||
- 200: image/webp(file 模式)
|
||||
- 302: 重定向到外部 URL(default/gravatar 模式)
|
||||
- 404: 用户不存在
|
||||
|
||||
缓存:Cache-Control: public, max-age=3600
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
import aiofiles.os
|
||||
|
||||
from service.avatar import (
|
||||
get_avatar_file_path,
|
||||
get_avatar_settings,
|
||||
gravatar_url,
|
||||
resolve_avatar_size,
|
||||
)
|
||||
|
||||
user = await sqlmodels.User.get(session, sqlmodels.User.id == id)
|
||||
if not user:
|
||||
http_exceptions.raise_not_found("用户不存在")
|
||||
|
||||
avatar_path, _, size_l, size_m, size_s = await get_avatar_settings(session)
|
||||
|
||||
if user.avatar == "file":
|
||||
size_label = resolve_avatar_size(size, size_l, size_m, size_s)
|
||||
file_path = get_avatar_file_path(avatar_path, user.id, size_label)
|
||||
|
||||
if not await aiofiles.os.path.exists(file_path):
|
||||
# 文件丢失,降级为 identicon
|
||||
fallback_url = gravatar_url(str(user.id), size, "https://www.gravatar.com/")
|
||||
return RedirectResponse(url=fallback_url, status_code=302)
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
media_type="image/webp",
|
||||
headers={"Cache-Control": "public, max-age=3600"},
|
||||
)
|
||||
|
||||
elif user.avatar == "gravatar":
|
||||
gravatar_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.AVATAR)
|
||||
& (sqlmodels.Setting.name == "gravatar_server"),
|
||||
)
|
||||
server = gravatar_setting.value if gravatar_setting else "https://www.gravatar.com/"
|
||||
email = user.email or str(user.id)
|
||||
url = gravatar_url(email, size, server)
|
||||
return RedirectResponse(url=url, status_code=302)
|
||||
|
||||
else:
|
||||
# default: identicon
|
||||
email_or_id = user.email or str(user.id)
|
||||
url = gravatar_url(email_or_id, size, "https://www.gravatar.com/")
|
||||
return RedirectResponse(url=url, status_code=302)
|
||||
|
||||
#####################
|
||||
# 需要登录的接口
|
||||
@@ -434,9 +494,24 @@ async def router_user_storage(
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
|
||||
# [TODO] 总空间加上用户购买的额外空间
|
||||
# 查询用户所有未过期容量包的 size 总和
|
||||
from datetime import datetime
|
||||
from sqlalchemy import func, select, and_, or_
|
||||
|
||||
total: int = group.max_storage
|
||||
now = datetime.now()
|
||||
stmt = select(func.coalesce(func.sum(sqlmodels.StoragePack.size), 0)).where(
|
||||
and_(
|
||||
sqlmodels.StoragePack.user_id == user.id,
|
||||
or_(
|
||||
sqlmodels.StoragePack.expired_time.is_(None),
|
||||
sqlmodels.StoragePack.expired_time > now,
|
||||
),
|
||||
)
|
||||
)
|
||||
result = await session.exec(stmt)
|
||||
active_packs_total: int = result.scalar_one()
|
||||
|
||||
total: int = group.max_storage + active_packs_total
|
||||
used: int = user.storage
|
||||
free: int = max(0, total - used)
|
||||
|
||||
@@ -578,7 +653,7 @@ async def router_user_authn_finish(
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
identity = await identity.save(session)
|
||||
|
||||
return authn.to_detail_response()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
|
||||
import sqlmodels
|
||||
@@ -13,6 +13,7 @@ from sqlmodels import (
|
||||
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
|
||||
ChangePasswordRequest,
|
||||
AuthnDetailResponse, AuthnRenameRequest,
|
||||
PolicySummary,
|
||||
)
|
||||
from sqlmodels.color import ThemeColorsBase
|
||||
from sqlmodels.user_authn import UserAuthn
|
||||
@@ -31,16 +32,25 @@ user_settings_router.include_router(file_viewers_router)
|
||||
@user_settings_router.get(
|
||||
path='/policies',
|
||||
summary='获取用户可选存储策略',
|
||||
description='Get user selectable storage policies.',
|
||||
)
|
||||
def router_user_settings_policies() -> sqlmodels.ResponseBase:
|
||||
async def router_user_settings_policies(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> list[PolicySummary]:
|
||||
"""
|
||||
Get user selectable storage policies.
|
||||
获取当前用户所在组可选的存储策略列表
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available storage policies for the user.
|
||||
返回用户组关联的所有存储策略的摘要信息。
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
group = await user.awaitable_attrs.group
|
||||
await session.refresh(group, ['policies'])
|
||||
return [
|
||||
PolicySummary(
|
||||
id=p.id, name=p.name, type=p.type,
|
||||
server=p.server, max_size=p.max_size, is_private=p.is_private,
|
||||
)
|
||||
for p in group.policies
|
||||
]
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
@@ -155,34 +165,121 @@ async def router_user_settings(
|
||||
@user_settings_router.post(
|
||||
path='/avatar',
|
||||
summary='从文件上传头像',
|
||||
description='Upload user avatar from file.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
status_code=204,
|
||||
)
|
||||
def router_user_settings_avatar() -> sqlmodels.ResponseBase:
|
||||
async def router_user_settings_avatar(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
file: UploadFile = File(...),
|
||||
) -> None:
|
||||
"""
|
||||
Upload user avatar from file.
|
||||
上传头像文件
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the avatar upload.
|
||||
认证:JWT token
|
||||
请求体:multipart/form-data,file 字段
|
||||
|
||||
流程:
|
||||
1. 验证文件 MIME 类型(JPEG/PNG/GIF/WebP)
|
||||
2. 验证文件大小 <= avatar_size 设置(默认 2MB)
|
||||
3. 调用 Pillow 验证图片有效性并处理(居中裁剪、缩放 L/M/S)
|
||||
4. 保存三种尺寸的 WebP 文件
|
||||
5. 更新 User.avatar = "file"
|
||||
|
||||
错误处理:
|
||||
- 400: 文件类型不支持 / 图片无法解析
|
||||
- 413: 文件过大
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
from service.avatar import (
|
||||
ALLOWED_CONTENT_TYPES,
|
||||
get_avatar_settings,
|
||||
process_and_save_avatar,
|
||||
)
|
||||
|
||||
# 验证 MIME 类型
|
||||
if file.content_type not in ALLOWED_CONTENT_TYPES:
|
||||
http_exceptions.raise_bad_request(
|
||||
f"不支持的图片格式,允许: {', '.join(ALLOWED_CONTENT_TYPES)}"
|
||||
)
|
||||
|
||||
# 读取并验证大小
|
||||
_, max_upload_size, _, _, _ = await get_avatar_settings(session)
|
||||
raw_bytes = await file.read()
|
||||
if len(raw_bytes) > max_upload_size:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"文件过大,最大允许 {max_upload_size} 字节",
|
||||
)
|
||||
|
||||
# 处理并保存(内部会验证图片有效性,无效抛出 ValueError)
|
||||
try:
|
||||
await process_and_save_avatar(session, user.id, raw_bytes)
|
||||
except ValueError as e:
|
||||
http_exceptions.raise_bad_request(str(e))
|
||||
|
||||
# 更新用户头像字段
|
||||
user.avatar = "file"
|
||||
user = await user.save(session)
|
||||
|
||||
|
||||
@user_settings_router.put(
|
||||
path='/avatar',
|
||||
summary='设定为Gravatar头像',
|
||||
description='Set user avatar to Gravatar.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
summary='设定为 Gravatar 头像',
|
||||
status_code=204,
|
||||
)
|
||||
def router_user_settings_avatar_gravatar() -> None:
|
||||
async def router_user_settings_avatar_gravatar(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> None:
|
||||
"""
|
||||
Set user avatar to Gravatar.
|
||||
将头像切换为 Gravatar
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of setting the Gravatar avatar.
|
||||
认证:JWT token
|
||||
|
||||
流程:
|
||||
1. 验证用户有邮箱(Gravatar 基于邮箱 MD5)
|
||||
2. 如果当前是 FILE 头像,删除本地文件
|
||||
3. 更新 User.avatar = "gravatar"
|
||||
|
||||
错误处理:
|
||||
- 400: 用户没有邮箱
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
from service.avatar import delete_avatar_files
|
||||
|
||||
if not user.email:
|
||||
http_exceptions.raise_bad_request("Gravatar 需要邮箱,请先绑定邮箱")
|
||||
|
||||
if user.avatar == "file":
|
||||
await delete_avatar_files(session, user.id)
|
||||
|
||||
user.avatar = "gravatar"
|
||||
user = await user.save(session)
|
||||
|
||||
|
||||
@user_settings_router.delete(
|
||||
path='/avatar',
|
||||
summary='重置头像为默认',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_user_settings_avatar_delete(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> None:
|
||||
"""
|
||||
重置头像为默认
|
||||
|
||||
认证:JWT token
|
||||
|
||||
流程:
|
||||
1. 如果当前是 FILE 头像,删除本地文件
|
||||
2. 更新 User.avatar = "default"
|
||||
"""
|
||||
from service.avatar import delete_avatar_files
|
||||
|
||||
if user.avatar == "file":
|
||||
await delete_avatar_files(session, user.id)
|
||||
|
||||
user.avatar = "default"
|
||||
user = await user.save(session)
|
||||
|
||||
|
||||
@user_settings_router.patch(
|
||||
@@ -224,7 +321,7 @@ async def router_user_settings_theme(
|
||||
user.color_error = request.theme_colors.error
|
||||
user.color_neutral = request.theme_colors.neutral
|
||||
|
||||
await user.save(session)
|
||||
user = await user.save(session)
|
||||
|
||||
|
||||
@user_settings_router.patch(
|
||||
@@ -261,7 +358,7 @@ async def router_user_settings_change_password(
|
||||
http_exceptions.raise_forbidden("当前密码错误")
|
||||
|
||||
email_identity.credential = Password.hash(request.new_password)
|
||||
await email_identity.save(session)
|
||||
email_identity = await email_identity.save(session)
|
||||
|
||||
|
||||
@user_settings_router.patch(
|
||||
@@ -295,7 +392,7 @@ async def router_user_settings_patch(
|
||||
http_exceptions.raise_bad_request(f"设置项 {option.value} 不允许为空")
|
||||
|
||||
setattr(user, option.value, value)
|
||||
await user.save(session)
|
||||
user = await user.save(session)
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
@@ -357,7 +454,7 @@ async def router_user_settings_2fa_enable(
|
||||
extra: dict = orjson.loads(email_identity.extra_data) if email_identity.extra_data else {}
|
||||
extra["two_factor"] = secret
|
||||
email_identity.extra_data = orjson.dumps(extra).decode('utf-8')
|
||||
await email_identity.save(session)
|
||||
email_identity = await email_identity.save(session)
|
||||
|
||||
|
||||
# ==================== 认证身份管理 ====================
|
||||
|
||||
@@ -79,9 +79,7 @@ async def set_default_viewer(
|
||||
|
||||
if existing:
|
||||
existing.app_id = request.app_id
|
||||
existing = await existing.save(session)
|
||||
# 重新加载 app 关系
|
||||
await session.refresh(existing, attribute_names=["app"])
|
||||
existing = await existing.save(session, load=UserFileAppDefault.app)
|
||||
return existing.to_response()
|
||||
else:
|
||||
new_default = UserFileAppDefault(
|
||||
@@ -89,9 +87,7 @@ async def set_default_viewer(
|
||||
extension=normalized_ext,
|
||||
app_id=request.app_id,
|
||||
)
|
||||
new_default = await new_default.save(session)
|
||||
# 重新加载 app 关系
|
||||
await session.refresh(new_default, attribute_names=["app"])
|
||||
new_default = await new_default.save(session, load=UserFileAppDefault.app)
|
||||
return new_default.to_response()
|
||||
|
||||
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
vas_router = APIRouter(
|
||||
prefix="/vas",
|
||||
tags=["vas"]
|
||||
)
|
||||
|
||||
@vas_router.get(
|
||||
path='/pack',
|
||||
summary='获取容量包及配额信息',
|
||||
description='Get information about storage packs and quotas.',
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_vas_pack() -> ResponseBase:
|
||||
"""
|
||||
Get information about storage packs and quotas.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for storage packs and quotas.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.get(
|
||||
path='/product',
|
||||
summary='获取商品信息,同时返回支付信息',
|
||||
description='Get product information along with payment details.',
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_vas_product() -> ResponseBase:
|
||||
"""
|
||||
Get product information along with payment details.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for products and payment information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.post(
|
||||
path='/order',
|
||||
summary='新建支付订单',
|
||||
description='Create an order for a product.',
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_vas_order() -> ResponseBase:
|
||||
"""
|
||||
Create an order for a product.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the created order.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.get(
|
||||
path='/order/{id}',
|
||||
summary='查询订单状态',
|
||||
description='Get information about a specific payment order by ID.',
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_vas_order_get(id: str) -> ResponseBase:
|
||||
"""
|
||||
Get information about a specific payment order by ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the order to retrieve information for.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the specified order.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.get(
|
||||
path='/redeem',
|
||||
summary='获取兑换码信息',
|
||||
description='Get information about a specific redemption code.',
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_vas_redeem(code: str) -> ResponseBase:
|
||||
"""
|
||||
Get information about a specific redemption code.
|
||||
|
||||
Args:
|
||||
code (str): The redemption code to retrieve information for.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the specified redemption code.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.post(
|
||||
path='/redeem',
|
||||
summary='执行兑换',
|
||||
description='Redeem a redemption code for a product or service.',
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
def router_vas_redeem_post() -> ResponseBase:
|
||||
"""
|
||||
Redeem a redemption code for a product or service.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the redeemed code.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
@@ -1,110 +1,207 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from sqlmodels import ResponseBase
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
Object,
|
||||
User,
|
||||
WebDAV,
|
||||
WebDAVAccountResponse,
|
||||
WebDAVCreateRequest,
|
||||
WebDAVUpdateRequest,
|
||||
)
|
||||
from service.redis.webdav_auth_cache import WebDAVAuthCache
|
||||
from utils import http_exceptions
|
||||
from utils.password.pwd import Password
|
||||
|
||||
# WebDAV 管理路由
|
||||
webdav_router = APIRouter(
|
||||
prefix='/webdav',
|
||||
tags=["webdav"],
|
||||
)
|
||||
|
||||
|
||||
def _check_webdav_enabled(user: User) -> None:
|
||||
"""检查用户组是否启用了 WebDAV 功能"""
|
||||
if not user.group.web_dav_enabled:
|
||||
http_exceptions.raise_forbidden("WebDAV 功能未启用")
|
||||
|
||||
|
||||
def _to_response(account: WebDAV) -> WebDAVAccountResponse:
|
||||
"""将 WebDAV 数据库模型转换为响应 DTO"""
|
||||
return WebDAVAccountResponse(
|
||||
id=account.id,
|
||||
name=account.name,
|
||||
root=account.root,
|
||||
readonly=account.readonly,
|
||||
use_proxy=account.use_proxy,
|
||||
created_at=str(account.created_at),
|
||||
updated_at=str(account.updated_at),
|
||||
)
|
||||
|
||||
|
||||
@webdav_router.get(
|
||||
path='/accounts',
|
||||
summary='获取账号信息',
|
||||
description='Get account information for WebDAV.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
summary='获取账号列表',
|
||||
)
|
||||
def router_webdav_accounts() -> ResponseBase:
|
||||
async def list_accounts(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> list[WebDAVAccountResponse]:
|
||||
"""
|
||||
Get account information for WebDAV.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the account information.
|
||||
列出当前用户所有 WebDAV 账户
|
||||
|
||||
认证:JWT Bearer Token
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
_check_webdav_enabled(user)
|
||||
user_id: UUID = user.id
|
||||
|
||||
accounts: list[WebDAV] = await WebDAV.get(
|
||||
session,
|
||||
WebDAV.user_id == user_id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
return [_to_response(a) for a in accounts]
|
||||
|
||||
|
||||
@webdav_router.post(
|
||||
path='/accounts',
|
||||
summary='新建账号',
|
||||
description='Create a new WebDAV account.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
summary='创建账号',
|
||||
status_code=201,
|
||||
)
|
||||
def router_webdav_create_account() -> ResponseBase:
|
||||
async def create_account(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: WebDAVCreateRequest,
|
||||
) -> WebDAVAccountResponse:
|
||||
"""
|
||||
Create a new WebDAV account.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the created account.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
创建 WebDAV 账户
|
||||
|
||||
@webdav_router.delete(
|
||||
path='/accounts/{id}',
|
||||
summary='删除账号',
|
||||
description='Delete a WebDAV account by its ID.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_delete_account(id: str) -> ResponseBase:
|
||||
"""
|
||||
Delete a WebDAV account by its ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the account to be deleted.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the deletion operation.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
认证:JWT Bearer Token
|
||||
|
||||
@webdav_router.post(
|
||||
path='/mount',
|
||||
summary='新建目录挂载',
|
||||
description='Create a new WebDAV mount point.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_create_mount() -> ResponseBase:
|
||||
错误处理:
|
||||
- 403: WebDAV 功能未启用
|
||||
- 400: 根目录路径不存在或不是目录
|
||||
- 409: 账户名已存在
|
||||
"""
|
||||
Create a new WebDAV mount point.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the created mount point.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
_check_webdav_enabled(user)
|
||||
user_id: UUID = user.id
|
||||
|
||||
# 验证账户名唯一
|
||||
existing = await WebDAV.get(
|
||||
session,
|
||||
(WebDAV.name == request.name) & (WebDAV.user_id == user_id),
|
||||
)
|
||||
if existing:
|
||||
http_exceptions.raise_conflict("账户名已存在")
|
||||
|
||||
# 验证 root 路径存在且为目录
|
||||
root_obj = await Object.get_by_path(session, user_id, request.root)
|
||||
if not root_obj or not root_obj.is_folder:
|
||||
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
|
||||
|
||||
# 创建账户
|
||||
account = WebDAV(
|
||||
name=request.name,
|
||||
password=Password.hash(request.password),
|
||||
root=request.root,
|
||||
readonly=request.readonly,
|
||||
use_proxy=request.use_proxy,
|
||||
user_id=user_id,
|
||||
)
|
||||
account = await account.save(session)
|
||||
|
||||
l.info(f"用户 {user_id} 创建 WebDAV 账户: {account.name}")
|
||||
return _to_response(account)
|
||||
|
||||
@webdav_router.delete(
|
||||
path='/mount/{id}',
|
||||
summary='删除目录挂载',
|
||||
description='Delete a WebDAV mount point by its ID.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_delete_mount(id: str) -> ResponseBase:
|
||||
"""
|
||||
Delete a WebDAV mount point by its ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the mount point to be deleted.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the deletion operation.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@webdav_router.patch(
|
||||
path='accounts/{id}',
|
||||
summary='更新账号信息',
|
||||
description='Update WebDAV account information by ID.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
path='/accounts/{account_id}',
|
||||
summary='更新账号',
|
||||
)
|
||||
def router_webdav_update_account(id: str) -> ResponseBase:
|
||||
async def update_account(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
account_id: int,
|
||||
request: WebDAVUpdateRequest,
|
||||
) -> WebDAVAccountResponse:
|
||||
"""
|
||||
Update WebDAV account information by ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the account to be updated.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the updated account.
|
||||
更新 WebDAV 账户
|
||||
|
||||
认证:JWT Bearer Token
|
||||
|
||||
错误处理:
|
||||
- 403: WebDAV 功能未启用
|
||||
- 404: 账户不存在
|
||||
- 400: 根目录路径不存在或不是目录
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
_check_webdav_enabled(user)
|
||||
user_id: UUID = user.id
|
||||
|
||||
account = await WebDAV.get(
|
||||
session,
|
||||
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
|
||||
)
|
||||
if not account:
|
||||
http_exceptions.raise_not_found("WebDAV 账户不存在")
|
||||
|
||||
# 验证 root 路径
|
||||
if request.root is not None:
|
||||
root_obj = await Object.get_by_path(session, user_id, request.root)
|
||||
if not root_obj or not root_obj.is_folder:
|
||||
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
|
||||
|
||||
# 密码哈希后原地替换,update() 会通过 model_dump(exclude_unset=True) 只取已设置字段
|
||||
is_password_changed = request.password is not None
|
||||
if is_password_changed:
|
||||
request.password = Password.hash(request.password)
|
||||
|
||||
account = await account.update(session, request)
|
||||
|
||||
# 密码变更时清除认证缓存
|
||||
if is_password_changed:
|
||||
await WebDAVAuthCache.invalidate_account(user_id, account.name)
|
||||
|
||||
l.info(f"用户 {user_id} 更新 WebDAV 账户: {account.name}")
|
||||
return _to_response(account)
|
||||
|
||||
|
||||
@webdav_router.delete(
|
||||
path='/accounts/{account_id}',
|
||||
summary='删除账号',
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_account(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
account_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
删除 WebDAV 账户
|
||||
|
||||
认证:JWT Bearer Token
|
||||
|
||||
错误处理:
|
||||
- 403: WebDAV 功能未启用
|
||||
- 404: 账户不存在
|
||||
"""
|
||||
_check_webdav_enabled(user)
|
||||
user_id: UUID = user.id
|
||||
|
||||
account = await WebDAV.get(
|
||||
session,
|
||||
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
|
||||
)
|
||||
if not account:
|
||||
http_exceptions.raise_not_found("WebDAV 账户不存在")
|
||||
|
||||
account_name = account.name
|
||||
await WebDAV.delete(session, account)
|
||||
|
||||
# 清除认证缓存
|
||||
await WebDAVAuthCache.invalidate_account(user_id, account_name)
|
||||
|
||||
l.info(f"用户 {user_id} 删除 WebDAV 账户: {account_name}")
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# WebDAV 操作路由
|
||||
35
routers/dav/__init__.py
Normal file
35
routers/dav/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
WebDAV 协议入口
|
||||
|
||||
使用 WsgiDAV + a2wsgi 提供 WebDAV 协议支持。
|
||||
WsgiDAV 在 a2wsgi 的线程池中运行,不阻塞 FastAPI 事件循环。
|
||||
"""
|
||||
from a2wsgi import WSGIMiddleware
|
||||
from wsgidav.wsgidav_app import WsgiDAVApp
|
||||
|
||||
from .domain_controller import DiskNextDomainController
|
||||
from .provider import DiskNextDAVProvider
|
||||
|
||||
_wsgidav_config: dict[str, object] = {
|
||||
"provider_mapping": {
|
||||
"/": DiskNextDAVProvider(),
|
||||
},
|
||||
"http_authenticator": {
|
||||
"domain_controller": DiskNextDomainController,
|
||||
"accept_basic": True,
|
||||
"accept_digest": False,
|
||||
"default_to_digest": False,
|
||||
},
|
||||
"verbose": 1,
|
||||
# 使用 WsgiDAV 内置的内存锁管理器
|
||||
"lock_storage": True,
|
||||
# 禁用 WsgiDAV 的目录浏览器(纯 DAV 协议)
|
||||
"dir_browser": {
|
||||
"enable": False,
|
||||
},
|
||||
}
|
||||
|
||||
_wsgidav_app = WsgiDAVApp(_wsgidav_config)
|
||||
|
||||
dav_app = WSGIMiddleware(_wsgidav_app, workers=10)
|
||||
"""ASGI 应用,挂载到 /dav 路径"""
|
||||
148
routers/dav/domain_controller.py
Normal file
148
routers/dav/domain_controller.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
WebDAV 认证控制器
|
||||
|
||||
实现 WsgiDAV 的 BaseDomainController 接口,使用 HTTP Basic Auth
|
||||
通过 DiskNext 的 WebDAV 账户模型进行认证。
|
||||
|
||||
用户名格式: {email}/{webdav_account_name}
|
||||
"""
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from wsgidav.dc.base_dc import BaseDomainController
|
||||
|
||||
from routers.dav.provider import EventLoopRef, _get_session
|
||||
from service.redis.webdav_auth_cache import WebDAVAuthCache
|
||||
from sqlmodels.user import User, UserStatus
|
||||
from sqlmodels.webdav import WebDAV
|
||||
from utils.password.pwd import Password, PasswordStatus
|
||||
|
||||
|
||||
async def _authenticate(
|
||||
email: str,
|
||||
account_name: str,
|
||||
password: str,
|
||||
) -> tuple[UUID, int] | None:
|
||||
"""
|
||||
异步认证 WebDAV 用户。
|
||||
|
||||
:param email: 用户邮箱
|
||||
:param account_name: WebDAV 账户名
|
||||
:param password: 明文密码
|
||||
:return: (user_id, webdav_id) 或 None
|
||||
"""
|
||||
# 1. 查缓存
|
||||
cached = await WebDAVAuthCache.get(email, account_name, password)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
# 2. 缓存未命中,查库验证
|
||||
async with _get_session() as session:
|
||||
user = await User.get(session, User.email == email, load=User.group)
|
||||
if not user:
|
||||
return None
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
return None
|
||||
if not user.group.web_dav_enabled:
|
||||
return None
|
||||
|
||||
account = await WebDAV.get(
|
||||
session,
|
||||
(WebDAV.name == account_name) & (WebDAV.user_id == user.id),
|
||||
)
|
||||
if not account:
|
||||
return None
|
||||
|
||||
status = Password.verify(account.password, password)
|
||||
if status == PasswordStatus.INVALID:
|
||||
return None
|
||||
|
||||
user_id: UUID = user.id
|
||||
webdav_id: int = account.id
|
||||
|
||||
# 3. 写入缓存
|
||||
await WebDAVAuthCache.set(email, account_name, password, user_id, webdav_id)
|
||||
|
||||
return user_id, webdav_id
|
||||
|
||||
|
||||
class DiskNextDomainController(BaseDomainController):
|
||||
"""
|
||||
DiskNext WebDAV 认证控制器
|
||||
|
||||
用户名格式: {email}/{webdav_account_name}
|
||||
密码: WebDAV 账户密码(创建账户时设置)
|
||||
"""
|
||||
|
||||
def __init__(self, wsgidav_app: object, config: dict[str, object]) -> None:
|
||||
super().__init__(wsgidav_app, config)
|
||||
|
||||
def get_domain_realm(self, path_info: str, environ: dict[str, object]) -> str:
|
||||
"""返回 realm 名称"""
|
||||
return "DiskNext WebDAV"
|
||||
|
||||
def require_authentication(self, realm: str, environ: dict[str, object]) -> bool:
|
||||
"""所有请求都需要认证"""
|
||||
return True
|
||||
|
||||
def is_share_anonymous(self, path_info: str) -> bool:
|
||||
"""不支持匿名访问"""
|
||||
return False
|
||||
|
||||
def supports_http_digest_auth(self) -> bool:
|
||||
"""不支持 Digest 认证(密码存的是 Argon2 哈希,无法反推)"""
|
||||
return False
|
||||
|
||||
def basic_auth_user(
|
||||
self,
|
||||
realm: str,
|
||||
user_name: str,
|
||||
password: str,
|
||||
environ: dict[str, object],
|
||||
) -> bool:
|
||||
"""
|
||||
HTTP Basic Auth 认证。
|
||||
|
||||
用户名格式: {email}/{webdav_account_name}
|
||||
在 WSGI 线程中通过 anyio.from_thread.run 调用异步认证逻辑。
|
||||
"""
|
||||
# 解析用户名
|
||||
if "/" not in user_name:
|
||||
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
|
||||
return False
|
||||
|
||||
email, account_name = user_name.split("/", 1)
|
||||
if not email or not account_name:
|
||||
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
|
||||
return False
|
||||
|
||||
# 在 WSGI 线程中调用异步认证
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
_authenticate(email, account_name, password),
|
||||
EventLoopRef.get(),
|
||||
)
|
||||
result = future.result()
|
||||
|
||||
if result is None:
|
||||
l.debug(f"WebDAV 认证失败: {email}/{account_name}")
|
||||
return False
|
||||
|
||||
user_id, webdav_id = result
|
||||
|
||||
# 将认证信息存入 environ,供 Provider 使用
|
||||
environ["disknext.user_id"] = user_id
|
||||
environ["disknext.webdav_id"] = webdav_id
|
||||
environ["disknext.email"] = email
|
||||
environ["disknext.account_name"] = account_name
|
||||
|
||||
return True
|
||||
|
||||
def digest_auth_user(
|
||||
self,
|
||||
realm: str,
|
||||
user_name: str,
|
||||
environ: dict[str, object],
|
||||
) -> bool:
|
||||
"""不支持 Digest 认证"""
|
||||
return False
|
||||
645
routers/dav/provider.py
Normal file
645
routers/dav/provider.py
Normal file
@@ -0,0 +1,645 @@
|
||||
"""
|
||||
DiskNext WebDAV 存储 Provider
|
||||
|
||||
将 WsgiDAV 的文件操作映射到 DiskNext 的 Object 模型。
|
||||
所有异步数据库/文件操作通过 asyncio.run_coroutine_threadsafe() 桥接。
|
||||
"""
|
||||
import asyncio
|
||||
import io
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from wsgidav.dav_error import (
|
||||
DAVError,
|
||||
HTTP_FORBIDDEN,
|
||||
HTTP_INSUFFICIENT_STORAGE,
|
||||
HTTP_NOT_FOUND,
|
||||
)
|
||||
from wsgidav.dav_provider import DAVCollection, DAVNonCollection, DAVProvider
|
||||
|
||||
from service.storage import LocalStorageService, adjust_user_storage
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.physical_file import PhysicalFile
|
||||
from sqlmodels.policy import Policy
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.webdav import WebDAV
|
||||
|
||||
|
||||
class EventLoopRef:
|
||||
"""持有主线程事件循环引用,供 WSGI 线程使用"""
|
||||
_loop: ClassVar[asyncio.AbstractEventLoop | None] = None
|
||||
|
||||
@classmethod
|
||||
async def capture(cls) -> None:
|
||||
"""在 async 上下文中调用,捕获当前事件循环"""
|
||||
cls._loop = asyncio.get_running_loop()
|
||||
|
||||
@classmethod
|
||||
def get(cls) -> asyncio.AbstractEventLoop:
|
||||
if cls._loop is None:
|
||||
raise RuntimeError("事件循环尚未捕获,请先调用 EventLoopRef.capture()")
|
||||
return cls._loop
|
||||
|
||||
|
||||
def _run_async(coro): # type: ignore[no-untyped-def]
|
||||
"""在 WSGI 线程中通过 run_coroutine_threadsafe 运行协程"""
|
||||
future = asyncio.run_coroutine_threadsafe(coro, EventLoopRef.get())
|
||||
return future.result()
|
||||
|
||||
|
||||
def _get_session(): # type: ignore[no-untyped-def]
|
||||
"""获取数据库会话上下文管理器"""
|
||||
return DatabaseManager._async_session_factory()
|
||||
|
||||
|
||||
# ==================== 异步辅助函数 ====================
|
||||
|
||||
async def _get_webdav_account(webdav_id: int) -> WebDAV | None:
|
||||
"""获取 WebDAV 账户"""
|
||||
async with _get_session() as session:
|
||||
return await WebDAV.get(session, WebDAV.id == webdav_id)
|
||||
|
||||
|
||||
async def _get_object_by_path(user_id: UUID, path: str) -> Object | None:
|
||||
"""根据路径获取对象"""
|
||||
async with _get_session() as session:
|
||||
return await Object.get_by_path(session, user_id, path)
|
||||
|
||||
|
||||
async def _get_children(user_id: UUID, parent_id: UUID) -> list[Object]:
|
||||
"""获取目录子对象"""
|
||||
async with _get_session() as session:
|
||||
return await Object.get_children(session, user_id, parent_id)
|
||||
|
||||
|
||||
async def _get_object_by_id(object_id: UUID) -> Object | None:
|
||||
"""根据ID获取对象"""
|
||||
async with _get_session() as session:
|
||||
return await Object.get(session, Object.id == object_id, load=Object.physical_file)
|
||||
|
||||
|
||||
async def _get_user(user_id: UUID) -> User | None:
|
||||
"""获取用户(含 group 关系)"""
|
||||
async with _get_session() as session:
|
||||
return await User.get(session, User.id == user_id, load=User.group)
|
||||
|
||||
|
||||
async def _get_policy(policy_id: UUID) -> Policy | None:
|
||||
"""获取存储策略"""
|
||||
async with _get_session() as session:
|
||||
return await Policy.get(session, Policy.id == policy_id)
|
||||
|
||||
|
||||
async def _create_folder(
|
||||
name: str,
|
||||
parent_id: UUID,
|
||||
owner_id: UUID,
|
||||
policy_id: UUID,
|
||||
) -> Object:
|
||||
"""创建目录对象"""
|
||||
async with _get_session() as session:
|
||||
obj = Object(
|
||||
name=name,
|
||||
type=ObjectType.FOLDER,
|
||||
size=0,
|
||||
parent_id=parent_id,
|
||||
owner_id=owner_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
obj = await obj.save(session)
|
||||
return obj
|
||||
|
||||
|
||||
async def _create_file(
|
||||
name: str,
|
||||
parent_id: UUID,
|
||||
owner_id: UUID,
|
||||
policy_id: UUID,
|
||||
) -> Object:
|
||||
"""创建空文件对象"""
|
||||
async with _get_session() as session:
|
||||
obj = Object(
|
||||
name=name,
|
||||
type=ObjectType.FILE,
|
||||
size=0,
|
||||
parent_id=parent_id,
|
||||
owner_id=owner_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
obj = await obj.save(session)
|
||||
return obj
|
||||
|
||||
|
||||
async def _soft_delete_object(object_id: UUID) -> None:
|
||||
"""软删除对象(移入回收站)"""
|
||||
from service.storage import soft_delete_objects
|
||||
|
||||
async with _get_session() as session:
|
||||
obj = await Object.get(session, Object.id == object_id)
|
||||
if obj:
|
||||
await soft_delete_objects(session, [obj])
|
||||
|
||||
|
||||
async def _finalize_upload(
|
||||
object_id: UUID,
|
||||
physical_path: str,
|
||||
size: int,
|
||||
owner_id: UUID,
|
||||
policy_id: UUID,
|
||||
) -> None:
|
||||
"""上传完成后更新对象元数据和物理文件记录"""
|
||||
async with _get_session() as session:
|
||||
# 获取存储路径(相对路径)
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
if not policy or not policy.server:
|
||||
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
|
||||
|
||||
base_path = Path(policy.server).resolve()
|
||||
full_path = Path(physical_path).resolve()
|
||||
storage_path = str(full_path.relative_to(base_path))
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
pf = PhysicalFile(
|
||||
storage_path=storage_path,
|
||||
size=size,
|
||||
policy_id=policy_id,
|
||||
reference_count=1,
|
||||
)
|
||||
pf = await pf.save(session)
|
||||
|
||||
# 更新 Object
|
||||
obj = await Object.get(session, Object.id == object_id)
|
||||
if obj:
|
||||
obj.sqlmodel_update({'size': size, 'physical_file_id': pf.id})
|
||||
obj = await obj.save(session)
|
||||
|
||||
# 更新用户存储用量
|
||||
if size > 0:
|
||||
await adjust_user_storage(session, owner_id, size)
|
||||
|
||||
|
||||
async def _move_object(
|
||||
object_id: UUID,
|
||||
new_parent_id: UUID,
|
||||
new_name: str,
|
||||
) -> None:
|
||||
"""移动/重命名对象"""
|
||||
async with _get_session() as session:
|
||||
obj = await Object.get(session, Object.id == object_id)
|
||||
if obj:
|
||||
obj.sqlmodel_update({'parent_id': new_parent_id, 'name': new_name})
|
||||
obj = await obj.save(session)
|
||||
|
||||
|
||||
async def _copy_object_recursive(
|
||||
src_id: UUID,
|
||||
dst_parent_id: UUID,
|
||||
dst_name: str,
|
||||
owner_id: UUID,
|
||||
) -> None:
|
||||
"""递归复制对象"""
|
||||
from service.storage import copy_object_recursive
|
||||
|
||||
async with _get_session() as session:
|
||||
src = await Object.get(session, Object.id == src_id)
|
||||
if not src:
|
||||
return
|
||||
await copy_object_recursive(session, src, dst_parent_id, owner_id, new_name=dst_name)
|
||||
|
||||
|
||||
# ==================== 辅助工具 ====================
|
||||
|
||||
def _get_environ_info(environ: dict[str, object]) -> tuple[UUID, int]:
|
||||
"""从 environ 中提取认证信息"""
|
||||
user_id: UUID = environ["disknext.user_id"] # type: ignore[assignment]
|
||||
webdav_id: int = environ["disknext.webdav_id"] # type: ignore[assignment]
|
||||
return user_id, webdav_id
|
||||
|
||||
|
||||
def _resolve_dav_path(account_root: str, dav_path: str) -> str:
|
||||
"""
|
||||
将 DAV 相对路径映射到 DiskNext 绝对路径。
|
||||
|
||||
:param account_root: 账户挂载根路径,如 "/" 或 "/docs"
|
||||
:param dav_path: DAV 请求路径,如 "/" 或 "/photos/cat.jpg"
|
||||
:return: DiskNext 内部路径,如 "/docs/photos/cat.jpg"
|
||||
"""
|
||||
# 规范化根路径
|
||||
root = account_root.rstrip("/")
|
||||
if not root:
|
||||
root = ""
|
||||
|
||||
# 规范化 DAV 路径
|
||||
if not dav_path or dav_path == "/":
|
||||
return root + "/" if root else "/"
|
||||
|
||||
if not dav_path.startswith("/"):
|
||||
dav_path = "/" + dav_path
|
||||
|
||||
full = root + dav_path
|
||||
return full if full else "/"
|
||||
|
||||
|
||||
def _check_readonly(environ: dict[str, object]) -> None:
|
||||
"""检查账户是否只读,只读则抛出 403"""
|
||||
account = environ.get("disknext.webdav_account")
|
||||
if account and getattr(account, 'readonly', False):
|
||||
raise DAVError(HTTP_FORBIDDEN, "WebDAV 账户为只读模式")
|
||||
|
||||
|
||||
def _check_storage_quota(user: User, additional_bytes: int) -> None:
|
||||
"""检查存储配额"""
|
||||
max_storage = user.group.max_storage
|
||||
if max_storage > 0 and user.storage + additional_bytes > max_storage:
|
||||
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
||||
|
||||
|
||||
class QuotaLimitedWriter(io.RawIOBase):
|
||||
"""带配额限制的写入流包装器"""
|
||||
|
||||
def __init__(self, stream: io.BufferedWriter, max_bytes: int) -> None:
|
||||
self._stream = stream
|
||||
self._max_bytes = max_bytes
|
||||
self._bytes_written = 0
|
||||
|
||||
def writable(self) -> bool:
|
||||
return True
|
||||
|
||||
def write(self, b: bytes | bytearray) -> int:
|
||||
if self._bytes_written + len(b) > self._max_bytes:
|
||||
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
||||
written = self._stream.write(b)
|
||||
self._bytes_written += written
|
||||
return written
|
||||
|
||||
def close(self) -> None:
|
||||
self._stream.close()
|
||||
super().close()
|
||||
|
||||
@property
|
||||
def bytes_written(self) -> int:
|
||||
return self._bytes_written
|
||||
|
||||
|
||||
# ==================== Provider ====================
|
||||
|
||||
class DiskNextDAVProvider(DAVProvider):
|
||||
"""DiskNext WebDAV 存储 Provider"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def get_resource_inst(
|
||||
self,
|
||||
path: str,
|
||||
environ: dict[str, object],
|
||||
) -> 'DiskNextCollection | DiskNextFile | None':
|
||||
"""
|
||||
将 WebDAV 路径映射到资源对象。
|
||||
|
||||
首次调用时加载 WebDAV 账户信息并缓存到 environ。
|
||||
"""
|
||||
user_id, webdav_id = _get_environ_info(environ)
|
||||
|
||||
# 首次请求时加载账户信息
|
||||
if "disknext.webdav_account" not in environ:
|
||||
account = _run_async(_get_webdav_account(webdav_id))
|
||||
if not account:
|
||||
return None
|
||||
environ["disknext.webdav_account"] = account
|
||||
|
||||
account: WebDAV = environ["disknext.webdav_account"] # type: ignore[no-redef]
|
||||
disknext_path = _resolve_dav_path(account.root, path)
|
||||
|
||||
obj = _run_async(_get_object_by_path(user_id, disknext_path))
|
||||
if not obj:
|
||||
return None
|
||||
|
||||
if obj.is_folder:
|
||||
return DiskNextCollection(path, environ, obj, user_id, account)
|
||||
else:
|
||||
return DiskNextFile(path, environ, obj, user_id, account)
|
||||
|
||||
def is_readonly(self) -> bool:
|
||||
"""只读由账户级别控制,不在 provider 级别限制"""
|
||||
return False
|
||||
|
||||
|
||||
# ==================== Collection(目录) ====================
|
||||
|
||||
class DiskNextCollection(DAVCollection):
|
||||
"""DiskNext 目录资源"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
environ: dict[str, object],
|
||||
obj: Object,
|
||||
user_id: UUID,
|
||||
account: WebDAV,
|
||||
) -> None:
|
||||
super().__init__(path, environ)
|
||||
self._obj = obj
|
||||
self._user_id = user_id
|
||||
self._account = account
|
||||
|
||||
def get_display_info(self) -> dict[str, str]:
|
||||
return {"type": "Directory"}
|
||||
|
||||
def get_member_names(self) -> list[str]:
|
||||
"""获取子对象名称列表"""
|
||||
children = _run_async(_get_children(self._user_id, self._obj.id))
|
||||
return [c.name for c in children]
|
||||
|
||||
def get_member(self, name: str) -> 'DiskNextCollection | DiskNextFile | None':
|
||||
"""获取指定名称的子资源"""
|
||||
member_path = self.path.rstrip("/") + "/" + name
|
||||
account_root = self._account.root
|
||||
disknext_path = _resolve_dav_path(account_root, member_path)
|
||||
|
||||
obj = _run_async(_get_object_by_path(self._user_id, disknext_path))
|
||||
if not obj:
|
||||
return None
|
||||
|
||||
if obj.is_folder:
|
||||
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
|
||||
else:
|
||||
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
|
||||
|
||||
def get_creation_date(self) -> float | None:
|
||||
if self._obj.created_at:
|
||||
return self._obj.created_at.timestamp()
|
||||
return None
|
||||
|
||||
def get_last_modified(self) -> float | None:
|
||||
if self._obj.updated_at:
|
||||
return self._obj.updated_at.timestamp()
|
||||
return None
|
||||
|
||||
def create_empty_resource(self, name: str) -> 'DiskNextFile':
|
||||
"""创建空文件(PUT 操作的第一步)"""
|
||||
_check_readonly(self.environ)
|
||||
|
||||
obj = _run_async(_create_file(
|
||||
name=name,
|
||||
parent_id=self._obj.id,
|
||||
owner_id=self._user_id,
|
||||
policy_id=self._obj.policy_id,
|
||||
))
|
||||
|
||||
member_path = self.path.rstrip("/") + "/" + name
|
||||
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
|
||||
|
||||
def create_collection(self, name: str) -> 'DiskNextCollection':
|
||||
"""创建子目录(MKCOL)"""
|
||||
_check_readonly(self.environ)
|
||||
|
||||
obj = _run_async(_create_folder(
|
||||
name=name,
|
||||
parent_id=self._obj.id,
|
||||
owner_id=self._user_id,
|
||||
policy_id=self._obj.policy_id,
|
||||
))
|
||||
|
||||
member_path = self.path.rstrip("/") + "/" + name
|
||||
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
|
||||
|
||||
def delete(self) -> None:
|
||||
"""软删除目录"""
|
||||
_check_readonly(self.environ)
|
||||
_run_async(_soft_delete_object(self._obj.id))
|
||||
|
||||
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
|
||||
"""复制或移动目录"""
|
||||
_check_readonly(self.environ)
|
||||
|
||||
account_root = self._account.root
|
||||
dest_disknext = _resolve_dav_path(account_root, dest_path)
|
||||
|
||||
# 解析目标父路径和新名称
|
||||
if "/" in dest_disknext.rstrip("/"):
|
||||
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
|
||||
new_name = dest_disknext.rsplit("/", 1)[1]
|
||||
else:
|
||||
parent_path = "/"
|
||||
new_name = dest_disknext.lstrip("/")
|
||||
|
||||
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
|
||||
if not dest_parent:
|
||||
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
|
||||
|
||||
if is_move:
|
||||
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
|
||||
else:
|
||||
_run_async(_copy_object_recursive(
|
||||
self._obj.id, dest_parent.id, new_name, self._user_id,
|
||||
))
|
||||
|
||||
return True
|
||||
|
||||
def support_recursive_delete(self) -> bool:
|
||||
return True
|
||||
|
||||
def support_recursive_move(self, dest_path: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
# ==================== NonCollection(文件) ====================
|
||||
|
||||
class DiskNextFile(DAVNonCollection):
|
||||
"""DiskNext 文件资源"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
environ: dict[str, object],
|
||||
obj: Object,
|
||||
user_id: UUID,
|
||||
account: WebDAV,
|
||||
) -> None:
|
||||
super().__init__(path, environ)
|
||||
self._obj = obj
|
||||
self._user_id = user_id
|
||||
self._account = account
|
||||
self._write_path: str | None = None
|
||||
self._write_stream: io.BufferedWriter | QuotaLimitedWriter | None = None
|
||||
|
||||
def get_content_length(self) -> int | None:
|
||||
return self._obj.size if self._obj.size else 0
|
||||
|
||||
def get_content_type(self) -> str | None:
|
||||
# 尝试从文件名推断 MIME 类型
|
||||
mime, _ = mimetypes.guess_type(self._obj.name)
|
||||
return mime or "application/octet-stream"
|
||||
|
||||
def get_creation_date(self) -> float | None:
|
||||
if self._obj.created_at:
|
||||
return self._obj.created_at.timestamp()
|
||||
return None
|
||||
|
||||
def get_last_modified(self) -> float | None:
|
||||
if self._obj.updated_at:
|
||||
return self._obj.updated_at.timestamp()
|
||||
return None
|
||||
|
||||
def get_display_info(self) -> dict[str, str]:
|
||||
return {"type": "File"}
|
||||
|
||||
def get_content(self) -> io.BufferedReader | None:
|
||||
"""
|
||||
返回文件内容的可读流。
|
||||
|
||||
WsgiDAV 在线程中运行,可安全使用同步 open()。
|
||||
"""
|
||||
obj_with_file = _run_async(_get_object_by_id(self._obj.id))
|
||||
if not obj_with_file or not obj_with_file.physical_file:
|
||||
return None
|
||||
|
||||
pf = obj_with_file.physical_file
|
||||
policy = _run_async(_get_policy(obj_with_file.policy_id))
|
||||
if not policy or not policy.server:
|
||||
return None
|
||||
|
||||
full_path = Path(policy.server).resolve() / pf.storage_path
|
||||
if not full_path.is_file():
|
||||
l.warning(f"WebDAV: 物理文件不存在: {full_path}")
|
||||
return None
|
||||
|
||||
return open(full_path, "rb") # noqa: SIM115
|
||||
|
||||
def begin_write(
|
||||
self,
|
||||
*,
|
||||
content_type: str | None = None,
|
||||
) -> io.BufferedWriter | QuotaLimitedWriter:
|
||||
"""
|
||||
开始写入文件(PUT 操作)。
|
||||
|
||||
返回一个可写的文件流,WsgiDAV 将向其中写入请求体数据。
|
||||
当用户有配额限制时,返回 QuotaLimitedWriter 在写入过程中实时检查配额。
|
||||
"""
|
||||
_check_readonly(self.environ)
|
||||
|
||||
# 检查配额
|
||||
remaining_quota: int = 0
|
||||
user = _run_async(_get_user(self._user_id))
|
||||
if user:
|
||||
max_storage = user.group.max_storage
|
||||
if max_storage > 0:
|
||||
remaining_quota = max_storage - user.storage
|
||||
if remaining_quota <= 0:
|
||||
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
||||
# Content-Length 预检(如果有的话)
|
||||
content_length = self.environ.get("CONTENT_LENGTH")
|
||||
if content_length and int(content_length) > remaining_quota:
|
||||
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
||||
|
||||
# 获取策略以确定存储路径
|
||||
policy = _run_async(_get_policy(self._obj.policy_id))
|
||||
if not policy or not policy.server:
|
||||
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
dir_path, storage_name, full_path = _run_async(
|
||||
storage_service.generate_file_path(
|
||||
user_id=self._user_id,
|
||||
original_filename=self._obj.name,
|
||||
)
|
||||
)
|
||||
|
||||
self._write_path = full_path
|
||||
raw_stream = open(full_path, "wb") # noqa: SIM115
|
||||
|
||||
# 有配额限制时使用包装流,实时检查写入量
|
||||
if remaining_quota > 0:
|
||||
self._write_stream = QuotaLimitedWriter(raw_stream, remaining_quota)
|
||||
else:
|
||||
self._write_stream = raw_stream
|
||||
|
||||
return self._write_stream
|
||||
|
||||
def end_write(self, *, with_errors: bool) -> None:
|
||||
"""写入完成后的收尾工作"""
|
||||
if self._write_stream:
|
||||
self._write_stream.close()
|
||||
self._write_stream = None
|
||||
|
||||
if with_errors:
|
||||
if self._write_path:
|
||||
file_path = Path(self._write_path)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
return
|
||||
|
||||
if not self._write_path:
|
||||
return
|
||||
|
||||
# 获取文件大小
|
||||
file_path = Path(self._write_path)
|
||||
if not file_path.exists():
|
||||
return
|
||||
|
||||
size = file_path.stat().st_size
|
||||
|
||||
# 更新数据库记录
|
||||
_run_async(_finalize_upload(
|
||||
object_id=self._obj.id,
|
||||
physical_path=self._write_path,
|
||||
size=size,
|
||||
owner_id=self._user_id,
|
||||
policy_id=self._obj.policy_id,
|
||||
))
|
||||
|
||||
l.debug(f"WebDAV 文件写入完成: {self._obj.name}, size={size}")
|
||||
|
||||
def delete(self) -> None:
|
||||
"""软删除文件"""
|
||||
_check_readonly(self.environ)
|
||||
_run_async(_soft_delete_object(self._obj.id))
|
||||
|
||||
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
|
||||
"""复制或移动文件"""
|
||||
_check_readonly(self.environ)
|
||||
|
||||
account_root = self._account.root
|
||||
dest_disknext = _resolve_dav_path(account_root, dest_path)
|
||||
|
||||
# 解析目标父路径和新名称
|
||||
if "/" in dest_disknext.rstrip("/"):
|
||||
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
|
||||
new_name = dest_disknext.rsplit("/", 1)[1]
|
||||
else:
|
||||
parent_path = "/"
|
||||
new_name = dest_disknext.lstrip("/")
|
||||
|
||||
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
|
||||
if not dest_parent:
|
||||
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
|
||||
|
||||
if is_move:
|
||||
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
|
||||
else:
|
||||
_run_async(_copy_object_recursive(
|
||||
self._obj.id, dest_parent.id, new_name, self._user_id,
|
||||
))
|
||||
|
||||
return True
|
||||
|
||||
def support_content_length(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_etag(self) -> str | None:
|
||||
"""返回 ETag(基于ID和更新时间),WsgiDAV 会自动加双引号"""
|
||||
if self._obj.updated_at:
|
||||
return f"{self._obj.id}-{int(self._obj.updated_at.timestamp())}"
|
||||
return None
|
||||
|
||||
def support_etag(self) -> bool:
|
||||
return True
|
||||
|
||||
def support_ranges(self) -> bool:
|
||||
return True
|
||||
128
service/redis/webdav_auth_cache.py
Normal file
128
service/redis/webdav_auth_cache.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
WebDAV 认证缓存
|
||||
|
||||
缓存 HTTP Basic Auth 的认证结果,避免每次请求都查库 + Argon2 验证。
|
||||
支持 Redis(首选)和内存缓存(降级)两种存储后端。
|
||||
"""
|
||||
import hashlib
|
||||
from typing import ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from cachetools import TTLCache
|
||||
from loguru import logger as l
|
||||
|
||||
from . import RedisManager
|
||||
|
||||
_AUTH_TTL: int = 300
|
||||
"""认证缓存 TTL(秒),5 分钟"""
|
||||
|
||||
|
||||
class WebDAVAuthCache:
|
||||
"""
|
||||
WebDAV 认证结果缓存
|
||||
|
||||
缓存键格式: webdav_auth:{email}/{account_name}:{sha256(password)}
|
||||
缓存值格式: {user_id}:{webdav_id}
|
||||
|
||||
密码的 SHA256 作为缓存键的一部分,密码变更后旧缓存自然 miss。
|
||||
"""
|
||||
|
||||
_memory_cache: ClassVar[TTLCache[str, str]] = TTLCache(maxsize=10000, ttl=_AUTH_TTL)
|
||||
"""内存缓存降级方案"""
|
||||
|
||||
@classmethod
|
||||
def _build_key(cls, email: str, account_name: str, password: str) -> str:
|
||||
"""构建缓存键"""
|
||||
pwd_hash = hashlib.sha256(password.encode()).hexdigest()[:16]
|
||||
return f"webdav_auth:{email}/{account_name}:{pwd_hash}"
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls,
|
||||
email: str,
|
||||
account_name: str,
|
||||
password: str,
|
||||
) -> tuple[UUID, int] | None:
|
||||
"""
|
||||
查询缓存中的认证结果。
|
||||
|
||||
:param email: 用户邮箱
|
||||
:param account_name: WebDAV 账户名
|
||||
:param password: 用户提供的明文密码
|
||||
:return: (user_id, webdav_id) 或 None(缓存未命中)
|
||||
"""
|
||||
key = cls._build_key(email, account_name, password)
|
||||
|
||||
client = RedisManager.get_client()
|
||||
if client is not None:
|
||||
value = await client.get(key)
|
||||
if value is not None:
|
||||
raw = value.decode() if isinstance(value, bytes) else value
|
||||
user_id_str, webdav_id_str = raw.split(":", 1)
|
||||
return UUID(user_id_str), int(webdav_id_str)
|
||||
else:
|
||||
raw = cls._memory_cache.get(key)
|
||||
if raw is not None:
|
||||
user_id_str, webdav_id_str = raw.split(":", 1)
|
||||
return UUID(user_id_str), int(webdav_id_str)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def set(
|
||||
cls,
|
||||
email: str,
|
||||
account_name: str,
|
||||
password: str,
|
||||
user_id: UUID,
|
||||
webdav_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
写入认证结果到缓存。
|
||||
|
||||
:param email: 用户邮箱
|
||||
:param account_name: WebDAV 账户名
|
||||
:param password: 用户提供的明文密码
|
||||
:param user_id: 用户UUID
|
||||
:param webdav_id: WebDAV 账户ID
|
||||
"""
|
||||
key = cls._build_key(email, account_name, password)
|
||||
value = f"{user_id}:{webdav_id}"
|
||||
|
||||
client = RedisManager.get_client()
|
||||
if client is not None:
|
||||
await client.set(key, value, ex=_AUTH_TTL)
|
||||
else:
|
||||
cls._memory_cache[key] = value
|
||||
|
||||
@classmethod
|
||||
async def invalidate_account(cls, user_id: UUID, account_name: str) -> None:
|
||||
"""
|
||||
失效指定账户的所有缓存。
|
||||
|
||||
由于缓存键包含 password hash,无法精确删除,
|
||||
Redis 端使用 pattern scan 删除,内存端清空全部。
|
||||
|
||||
:param user_id: 用户UUID
|
||||
:param account_name: WebDAV 账户名
|
||||
"""
|
||||
client = RedisManager.get_client()
|
||||
if client is not None:
|
||||
pattern = f"webdav_auth:*/{account_name}:*"
|
||||
cursor: int = 0
|
||||
while True:
|
||||
cursor, keys = await client.scan(cursor, match=pattern, count=100)
|
||||
if keys:
|
||||
await client.delete(*keys)
|
||||
if cursor == 0:
|
||||
break
|
||||
else:
|
||||
# 内存缓存无法按 pattern 删除,清除所有含该账户名的条目
|
||||
keys_to_delete = [
|
||||
k for k in cls._memory_cache
|
||||
if f"/{account_name}:" in k
|
||||
]
|
||||
for k in keys_to_delete:
|
||||
cls._memory_cache.pop(k, None)
|
||||
|
||||
l.debug(f"已清除 WebDAV 认证缓存: user={user_id}, account={account_name}")
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
提供文件存储相关的服务,包括:
|
||||
- 本地存储服务
|
||||
- S3 存储服务
|
||||
- 命名规则解析器
|
||||
- 存储异常定义
|
||||
"""
|
||||
@@ -11,6 +12,8 @@ from .exceptions import (
|
||||
FileReadError,
|
||||
FileWriteError,
|
||||
InvalidPathError,
|
||||
S3APIError,
|
||||
S3MultipartUploadError,
|
||||
StorageException,
|
||||
StorageFileNotFoundError,
|
||||
UploadSessionExpiredError,
|
||||
@@ -25,4 +28,6 @@ from .object import (
|
||||
permanently_delete_objects,
|
||||
restore_objects,
|
||||
soft_delete_objects,
|
||||
)
|
||||
)
|
||||
from .migrate import migrate_file_with_task, migrate_directory_files
|
||||
from .s3_storage import S3StorageService
|
||||
@@ -43,3 +43,13 @@ class UploadSessionExpiredError(StorageException):
|
||||
class InvalidPathError(StorageException):
|
||||
"""无效的路径"""
|
||||
pass
|
||||
|
||||
|
||||
class S3APIError(StorageException):
|
||||
"""S3 API 请求错误"""
|
||||
pass
|
||||
|
||||
|
||||
class S3MultipartUploadError(S3APIError):
|
||||
"""S3 分片上传错误"""
|
||||
pass
|
||||
|
||||
@@ -263,15 +263,49 @@ class LocalStorageService:
|
||||
"""
|
||||
删除文件(物理删除)
|
||||
|
||||
删除文件后会尝试清理因此变空的父目录。
|
||||
|
||||
:param path: 完整文件路径
|
||||
"""
|
||||
if await self.file_exists(path):
|
||||
try:
|
||||
await aiofiles.os.remove(path)
|
||||
l.debug(f"已删除文件: {path}")
|
||||
await self._cleanup_empty_parents(path)
|
||||
except OSError as e:
|
||||
l.warning(f"删除文件失败 {path}: {e}")
|
||||
|
||||
async def _cleanup_empty_parents(self, file_path: str) -> None:
|
||||
"""
|
||||
从被删文件的父目录开始,向上逐级删除空目录
|
||||
|
||||
在以下情况停止:
|
||||
|
||||
- 到达存储根目录(_base_path)
|
||||
- 遇到非空目录
|
||||
- 遇到 .trash 目录
|
||||
- 删除失败(权限、并发等)
|
||||
|
||||
:param file_path: 被删文件的完整路径
|
||||
"""
|
||||
current = Path(file_path).parent
|
||||
|
||||
while current != self._base_path and str(current).startswith(str(self._base_path)):
|
||||
if current.name == '.trash':
|
||||
break
|
||||
|
||||
try:
|
||||
entries = await aiofiles.os.listdir(str(current))
|
||||
if entries:
|
||||
break
|
||||
|
||||
await aiofiles.os.rmdir(str(current))
|
||||
l.debug(f"已清理空目录: {current}")
|
||||
current = current.parent
|
||||
except OSError as e:
|
||||
l.debug(f"清理空目录失败(忽略): {current}: {e}")
|
||||
break
|
||||
|
||||
async def move_to_trash(
|
||||
self,
|
||||
source_path: str,
|
||||
@@ -304,6 +338,7 @@ class LocalStorageService:
|
||||
try:
|
||||
await aiofiles.os.rename(source_path, str(trash_path))
|
||||
l.info(f"文件已移动到回收站: {source_path} -> {trash_path}")
|
||||
await self._cleanup_empty_parents(source_path)
|
||||
return str(trash_path)
|
||||
except OSError as e:
|
||||
raise StorageException(f"移动文件到回收站失败: {e}")
|
||||
|
||||
291
service/storage/migrate.py
Normal file
291
service/storage/migrate.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
存储策略迁移服务
|
||||
|
||||
提供跨存储策略的文件迁移功能:
|
||||
- 单文件迁移:从源策略下载 → 上传到目标策略 → 更新数据库记录
|
||||
- 目录批量迁移:递归遍历目录下所有文件逐个迁移,同时更新子目录的 policy_id
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.physical_file import PhysicalFile
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
from sqlmodels.task import Task, TaskStatus
|
||||
|
||||
from .local_storage import LocalStorageService
|
||||
from .s3_storage import S3StorageService
|
||||
|
||||
|
||||
async def _get_storage_service(
|
||||
policy: Policy,
|
||||
) -> LocalStorageService | S3StorageService:
|
||||
"""
|
||||
根据策略类型创建对应的存储服务实例
|
||||
|
||||
:param policy: 存储策略
|
||||
:return: 存储服务实例
|
||||
"""
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
return LocalStorageService(policy)
|
||||
elif policy.type == PolicyType.S3:
|
||||
return await S3StorageService.from_policy(policy)
|
||||
else:
|
||||
raise ValueError(f"不支持的存储策略类型: {policy.type}")
|
||||
|
||||
|
||||
async def _read_file_from_storage(
|
||||
service: LocalStorageService | S3StorageService,
|
||||
storage_path: str,
|
||||
) -> bytes:
|
||||
"""
|
||||
从存储服务读取文件内容
|
||||
|
||||
:param service: 存储服务实例
|
||||
:param storage_path: 文件存储路径
|
||||
:return: 文件二进制内容
|
||||
"""
|
||||
if isinstance(service, LocalStorageService):
|
||||
return await service.read_file(storage_path)
|
||||
else:
|
||||
return await service.download_file(storage_path)
|
||||
|
||||
|
||||
async def _write_file_to_storage(
|
||||
service: LocalStorageService | S3StorageService,
|
||||
storage_path: str,
|
||||
data: bytes,
|
||||
) -> None:
|
||||
"""
|
||||
将文件内容写入存储服务
|
||||
|
||||
:param service: 存储服务实例
|
||||
:param storage_path: 文件存储路径
|
||||
:param data: 文件二进制内容
|
||||
"""
|
||||
if isinstance(service, LocalStorageService):
|
||||
await service.write_file(storage_path, data)
|
||||
else:
|
||||
await service.upload_file(storage_path, data)
|
||||
|
||||
|
||||
async def _delete_file_from_storage(
|
||||
service: LocalStorageService | S3StorageService,
|
||||
storage_path: str,
|
||||
) -> None:
|
||||
"""
|
||||
从存储服务删除文件
|
||||
|
||||
:param service: 存储服务实例
|
||||
:param storage_path: 文件存储路径
|
||||
"""
|
||||
if isinstance(service, LocalStorageService):
|
||||
await service.delete_file(storage_path)
|
||||
else:
|
||||
await service.delete_file(storage_path)
|
||||
|
||||
|
||||
async def migrate_single_file(
|
||||
session: AsyncSession,
|
||||
obj: Object,
|
||||
dest_policy: Policy,
|
||||
) -> None:
|
||||
"""
|
||||
将单个文件对象从当前存储策略迁移到目标策略
|
||||
|
||||
流程:
|
||||
1. 获取源物理文件和存储服务
|
||||
2. 读取源文件内容
|
||||
3. 在目标存储中生成新路径并写入
|
||||
4. 创建新的 PhysicalFile 记录
|
||||
5. 更新 Object 的 policy_id 和 physical_file_id
|
||||
6. 旧 PhysicalFile 引用计数 -1,如为 0 则删除源物理文件
|
||||
|
||||
:param session: 数据库会话
|
||||
:param obj: 待迁移的文件对象(必须为文件类型)
|
||||
:param dest_policy: 目标存储策略
|
||||
"""
|
||||
if obj.type != ObjectType.FILE:
|
||||
raise ValueError(f"只能迁移文件对象,当前类型: {obj.type}")
|
||||
|
||||
# 获取源策略和物理文件
|
||||
src_policy: Policy = await obj.awaitable_attrs.policy
|
||||
old_physical: PhysicalFile | None = await obj.awaitable_attrs.physical_file
|
||||
|
||||
if not old_physical:
|
||||
l.warning(f"文件 {obj.id} 没有关联物理文件,跳过迁移")
|
||||
return
|
||||
|
||||
if src_policy.id == dest_policy.id:
|
||||
l.debug(f"文件 {obj.id} 已在目标策略中,跳过")
|
||||
return
|
||||
|
||||
# 1. 从源存储读取文件
|
||||
src_service = await _get_storage_service(src_policy)
|
||||
data = await _read_file_from_storage(src_service, old_physical.storage_path)
|
||||
|
||||
# 2. 在目标存储生成新路径并写入
|
||||
dest_service = await _get_storage_service(dest_policy)
|
||||
_dir_path, _storage_name, new_storage_path = await dest_service.generate_file_path(
|
||||
user_id=obj.owner_id,
|
||||
original_filename=obj.name,
|
||||
)
|
||||
await _write_file_to_storage(dest_service, new_storage_path, data)
|
||||
|
||||
# 3. 创建新的 PhysicalFile
|
||||
new_physical = PhysicalFile(
|
||||
storage_path=new_storage_path,
|
||||
size=old_physical.size,
|
||||
checksum_md5=old_physical.checksum_md5,
|
||||
policy_id=dest_policy.id,
|
||||
reference_count=1,
|
||||
)
|
||||
new_physical = await new_physical.save(session)
|
||||
|
||||
# 4. 更新 Object
|
||||
obj.policy_id = dest_policy.id
|
||||
obj.physical_file_id = new_physical.id
|
||||
obj = await obj.save(session)
|
||||
|
||||
# 5. 旧 PhysicalFile 引用计数 -1
|
||||
old_physical.decrement_reference()
|
||||
if old_physical.can_be_deleted:
|
||||
# 删除源存储中的物理文件
|
||||
try:
|
||||
await _delete_file_from_storage(src_service, old_physical.storage_path)
|
||||
except Exception as e:
|
||||
l.warning(f"删除源文件失败(不影响迁移结果): {old_physical.storage_path}: {e}")
|
||||
await PhysicalFile.delete(session, old_physical)
|
||||
else:
|
||||
old_physical = await old_physical.save(session)
|
||||
|
||||
l.info(f"文件迁移完成: {obj.name} ({obj.id}), {src_policy.name} → {dest_policy.name}")
|
||||
|
||||
|
||||
async def migrate_file_with_task(
|
||||
session: AsyncSession,
|
||||
obj: Object,
|
||||
dest_policy: Policy,
|
||||
task: Task,
|
||||
) -> None:
|
||||
"""
|
||||
迁移单个文件并更新任务状态
|
||||
|
||||
:param session: 数据库会话
|
||||
:param obj: 待迁移的文件对象
|
||||
:param dest_policy: 目标存储策略
|
||||
:param task: 关联的任务记录
|
||||
"""
|
||||
try:
|
||||
task.status = TaskStatus.RUNNING
|
||||
task.progress = 0
|
||||
task = await task.save(session)
|
||||
|
||||
await migrate_single_file(session, obj, dest_policy)
|
||||
|
||||
task.status = TaskStatus.COMPLETED
|
||||
task.progress = 100
|
||||
task = await task.save(session)
|
||||
except Exception as e:
|
||||
l.error(f"文件迁移任务失败: {obj.id}: {e}")
|
||||
task.status = TaskStatus.ERROR
|
||||
task.error = str(e)[:500]
|
||||
task = await task.save(session)
|
||||
|
||||
|
||||
async def migrate_directory_files(
|
||||
session: AsyncSession,
|
||||
folder: Object,
|
||||
dest_policy: Policy,
|
||||
task: Task,
|
||||
) -> None:
|
||||
"""
|
||||
迁移目录下所有文件到目标存储策略
|
||||
|
||||
递归遍历目录树,将所有文件迁移到目标策略。
|
||||
子目录的 policy_id 同步更新。
|
||||
任务进度按文件数比例更新。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param folder: 目录对象
|
||||
:param dest_policy: 目标存储策略
|
||||
:param task: 关联的任务记录
|
||||
"""
|
||||
try:
|
||||
task.status = TaskStatus.RUNNING
|
||||
task.progress = 0
|
||||
task = await task.save(session)
|
||||
|
||||
# 收集所有需要迁移的文件
|
||||
files_to_migrate: list[Object] = []
|
||||
folders_to_update: list[Object] = []
|
||||
await _collect_objects_recursive(session, folder, files_to_migrate, folders_to_update)
|
||||
|
||||
total = len(files_to_migrate)
|
||||
migrated = 0
|
||||
errors: list[str] = []
|
||||
|
||||
for file_obj in files_to_migrate:
|
||||
try:
|
||||
await migrate_single_file(session, file_obj, dest_policy)
|
||||
migrated += 1
|
||||
except Exception as e:
|
||||
error_msg = f"{file_obj.name}: {e}"
|
||||
l.error(f"迁移文件失败: {error_msg}")
|
||||
errors.append(error_msg)
|
||||
|
||||
# 更新进度
|
||||
if total > 0:
|
||||
task.progress = min(99, int(migrated / total * 100))
|
||||
task = await task.save(session)
|
||||
|
||||
# 更新所有子目录的 policy_id
|
||||
for sub_folder in folders_to_update:
|
||||
sub_folder.policy_id = dest_policy.id
|
||||
sub_folder = await sub_folder.save(session)
|
||||
|
||||
# 完成任务
|
||||
if errors:
|
||||
task.status = TaskStatus.ERROR
|
||||
task.error = f"部分文件迁移失败 ({len(errors)}/{total}): " + "; ".join(errors[:5])
|
||||
else:
|
||||
task.status = TaskStatus.COMPLETED
|
||||
|
||||
task.progress = 100
|
||||
task = await task.save(session)
|
||||
|
||||
l.info(
|
||||
f"目录迁移完成: {folder.name} ({folder.id}), "
|
||||
f"成功 {migrated}/{total}, 错误 {len(errors)}"
|
||||
)
|
||||
except Exception as e:
|
||||
l.error(f"目录迁移任务失败: {folder.id}: {e}")
|
||||
task.status = TaskStatus.ERROR
|
||||
task.error = str(e)[:500]
|
||||
task = await task.save(session)
|
||||
|
||||
|
||||
async def _collect_objects_recursive(
|
||||
session: AsyncSession,
|
||||
folder: Object,
|
||||
files: list[Object],
|
||||
folders: list[Object],
|
||||
) -> None:
|
||||
"""
|
||||
递归收集目录下所有文件和子目录
|
||||
|
||||
:param session: 数据库会话
|
||||
:param folder: 当前目录
|
||||
:param files: 文件列表(输出)
|
||||
:param folders: 子目录列表(输出)
|
||||
"""
|
||||
children: list[Object] = await Object.get_children(session, folder.owner_id, folder.id)
|
||||
|
||||
for child in children:
|
||||
if child.type == ObjectType.FILE:
|
||||
files.append(child)
|
||||
elif child.type == ObjectType.FOLDER:
|
||||
folders.append(child)
|
||||
await _collect_objects_recursive(session, child, files, folders)
|
||||
@@ -6,7 +6,8 @@ from sqlalchemy import update as sql_update
|
||||
from sqlalchemy.sql.functions import func
|
||||
from middleware.dependencies import SessionDep
|
||||
|
||||
from service.storage import LocalStorageService
|
||||
from .local_storage import LocalStorageService
|
||||
from .s3_storage import S3StorageService
|
||||
from sqlmodels import (
|
||||
Object,
|
||||
PhysicalFile,
|
||||
@@ -271,10 +272,14 @@ async def permanently_delete_objects(
|
||||
if physical_file.can_be_deleted:
|
||||
# 物理删除文件
|
||||
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
|
||||
if policy and policy.type == PolicyType.LOCAL:
|
||||
if policy:
|
||||
try:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(physical_file.storage_path)
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(physical_file.storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
s3_service = await S3StorageService.from_policy(policy)
|
||||
await s3_service.delete_file(physical_file.storage_path)
|
||||
l.debug(f"物理文件已删除: {obj_name}")
|
||||
except Exception as e:
|
||||
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
|
||||
@@ -282,7 +287,7 @@ async def permanently_delete_objects(
|
||||
await PhysicalFile.delete(session, physical_file, commit=False)
|
||||
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
|
||||
else:
|
||||
await physical_file.save(session, commit=False)
|
||||
physical_file = await physical_file.save(session, commit=False)
|
||||
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
|
||||
|
||||
# 更新用户存储配额
|
||||
@@ -374,10 +379,19 @@ async def delete_object_recursive(
|
||||
if physical_file.can_be_deleted:
|
||||
# 物理删除文件
|
||||
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
|
||||
if policy and policy.type == PolicyType.LOCAL:
|
||||
if policy:
|
||||
try:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(physical_file.storage_path)
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.delete_file(physical_file.storage_path)
|
||||
elif policy.type == PolicyType.S3:
|
||||
options = await policy.awaitable_attrs.options
|
||||
s3_service = S3StorageService(
|
||||
policy,
|
||||
region=options.s3_region if options else 'us-east-1',
|
||||
is_path_style=options.s3_path_style if options else False,
|
||||
)
|
||||
await s3_service.delete_file(physical_file.storage_path)
|
||||
l.debug(f"物理文件已删除: {obj_name}")
|
||||
except Exception as e:
|
||||
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
|
||||
@@ -385,7 +399,7 @@ async def delete_object_recursive(
|
||||
await PhysicalFile.delete(session, physical_file, commit=False)
|
||||
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
|
||||
else:
|
||||
await physical_file.save(session, commit=False)
|
||||
physical_file = await physical_file.save(session, commit=False)
|
||||
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
|
||||
|
||||
# 阶段三:更新用户存储配额(与删除在同一事务中)
|
||||
@@ -444,7 +458,7 @@ async def _copy_object_recursive(
|
||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == src_physical_file_id)
|
||||
if physical_file:
|
||||
physical_file.increment_reference()
|
||||
await physical_file.save(session)
|
||||
physical_file = await physical_file.save(session)
|
||||
total_copied_size += src_size
|
||||
|
||||
new_obj = await new_obj.save(session)
|
||||
|
||||
709
service/storage/s3_storage.py
Normal file
709
service/storage/s3_storage.py
Normal file
@@ -0,0 +1,709 @@
|
||||
"""
|
||||
S3 存储服务
|
||||
|
||||
使用 AWS Signature V4 签名的异步 S3 API 客户端。
|
||||
从 Policy 配置中读取 S3 连接信息,提供文件上传/下载/删除及分片上传功能。
|
||||
|
||||
移植自 foxline-pro-backend-server 项目的 S3APIClient,
|
||||
适配 DiskNext 现有的 Service 架构(与 LocalStorageService 平行)。
|
||||
"""
|
||||
import hashlib
|
||||
import hmac
|
||||
import xml.etree.ElementTree as ET
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime, timezone
|
||||
from typing import ClassVar, Literal
|
||||
from urllib.parse import quote, urlencode
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
from loguru import logger as l
|
||||
|
||||
from sqlmodels.policy import Policy
|
||||
from .exceptions import S3APIError, S3MultipartUploadError
|
||||
from .naming_rule import NamingContext, NamingRuleParser
|
||||
|
||||
|
||||
def _sign(key: bytes, msg: str) -> bytes:
|
||||
"""HMAC-SHA256 签名"""
|
||||
return hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||
|
||||
|
||||
_NS_AWS = "http://s3.amazonaws.com/doc/2006-03-01/"
|
||||
|
||||
|
||||
class S3StorageService:
|
||||
"""
|
||||
S3 存储服务
|
||||
|
||||
使用 AWS Signature V4 签名的异步 S3 API 客户端。
|
||||
从 Policy 配置中读取 S3 连接信息。
|
||||
|
||||
使用示例::
|
||||
|
||||
service = S3StorageService(policy, region='us-east-1')
|
||||
await service.upload_file('path/to/file.txt', b'content')
|
||||
data = await service.download_file('path/to/file.txt')
|
||||
"""
|
||||
|
||||
_http_session: ClassVar[aiohttp.ClientSession | None] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Policy,
|
||||
region: str = 'us-east-1',
|
||||
is_path_style: bool = False,
|
||||
):
|
||||
"""
|
||||
:param policy: 存储策略(server=endpoint_url, bucket_name, access_key, secret_key)
|
||||
:param region: S3 区域
|
||||
:param is_path_style: 是否使用路径风格 URL
|
||||
"""
|
||||
if not policy.server:
|
||||
raise S3APIError("S3 策略必须指定 server (endpoint URL)")
|
||||
if not policy.bucket_name:
|
||||
raise S3APIError("S3 策略必须指定 bucket_name")
|
||||
if not policy.access_key:
|
||||
raise S3APIError("S3 策略必须指定 access_key")
|
||||
if not policy.secret_key:
|
||||
raise S3APIError("S3 策略必须指定 secret_key")
|
||||
|
||||
self._policy = policy
|
||||
self._endpoint_url = policy.server.rstrip("/")
|
||||
self._bucket_name = policy.bucket_name
|
||||
self._access_key = policy.access_key
|
||||
self._secret_key = policy.secret_key
|
||||
self._region = region
|
||||
self._is_path_style = is_path_style
|
||||
self._base_url = policy.base_url
|
||||
|
||||
# 从 endpoint_url 提取 host
|
||||
self._host = self._endpoint_url.replace("https://", "").replace("http://", "").split("/")[0]
|
||||
|
||||
# ==================== 工厂方法 ====================
|
||||
|
||||
@classmethod
|
||||
async def from_policy(cls, policy: Policy) -> 'S3StorageService':
|
||||
"""
|
||||
根据 Policy 异步创建 S3StorageService(自动加载 options)
|
||||
|
||||
:param policy: 存储策略
|
||||
:return: S3StorageService 实例
|
||||
"""
|
||||
options = await policy.awaitable_attrs.options
|
||||
region = options.s3_region if options else 'us-east-1'
|
||||
is_path_style = options.s3_path_style if options else False
|
||||
return cls(policy, region=region, is_path_style=is_path_style)
|
||||
|
||||
# ==================== HTTP Session 管理 ====================
|
||||
|
||||
@classmethod
|
||||
async def initialize_session(cls) -> None:
|
||||
"""初始化全局 aiohttp ClientSession"""
|
||||
if cls._http_session is None or cls._http_session.closed:
|
||||
cls._http_session = aiohttp.ClientSession()
|
||||
l.info("S3StorageService HTTP session 已初始化")
|
||||
|
||||
@classmethod
|
||||
async def close_session(cls) -> None:
|
||||
"""关闭全局 aiohttp ClientSession"""
|
||||
if cls._http_session and not cls._http_session.closed:
|
||||
await cls._http_session.close()
|
||||
cls._http_session = None
|
||||
l.info("S3StorageService HTTP session 已关闭")
|
||||
|
||||
@classmethod
|
||||
def _get_session(cls) -> aiohttp.ClientSession:
|
||||
"""获取 HTTP session"""
|
||||
if cls._http_session is None or cls._http_session.closed:
|
||||
# 懒初始化,以防 initialize_session 未被调用
|
||||
cls._http_session = aiohttp.ClientSession()
|
||||
return cls._http_session
|
||||
|
||||
# ==================== AWS Signature V4 签名 ====================
|
||||
|
||||
def _get_signature_key(self, date_stamp: str) -> bytes:
|
||||
"""生成 AWS Signature V4 签名密钥"""
|
||||
k_date = _sign(f"AWS4{self._secret_key}".encode(), date_stamp)
|
||||
k_region = _sign(k_date, self._region)
|
||||
k_service = _sign(k_region, "s3")
|
||||
return _sign(k_service, "aws4_request")
|
||||
|
||||
def _create_authorization_header(
|
||||
self,
|
||||
method: str,
|
||||
uri: str,
|
||||
query_string: str,
|
||||
headers: dict[str, str],
|
||||
payload_hash: str,
|
||||
amz_date: str,
|
||||
date_stamp: str,
|
||||
) -> str:
|
||||
"""创建 AWS Signature V4 授权头"""
|
||||
signed_headers = ";".join(sorted(k.lower() for k in headers.keys()))
|
||||
canonical_headers = "".join(
|
||||
f"{k.lower()}:{v.strip()}\n" for k, v in sorted(headers.items())
|
||||
)
|
||||
canonical_request = (
|
||||
f"{method}\n{uri}\n{query_string}\n{canonical_headers}\n"
|
||||
f"{signed_headers}\n{payload_hash}"
|
||||
)
|
||||
|
||||
algorithm = "AWS4-HMAC-SHA256"
|
||||
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
|
||||
string_to_sign = (
|
||||
f"{algorithm}\n{amz_date}\n{credential_scope}\n"
|
||||
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
|
||||
)
|
||||
|
||||
signing_key = self._get_signature_key(date_stamp)
|
||||
signature = hmac.new(
|
||||
signing_key, string_to_sign.encode(), hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
return (
|
||||
f"{algorithm} Credential={self._access_key}/{credential_scope}, "
|
||||
f"SignedHeaders={signed_headers}, Signature={signature}"
|
||||
)
|
||||
|
||||
def _build_headers(
|
||||
self,
|
||||
method: str,
|
||||
uri: str,
|
||||
query_string: str = "",
|
||||
payload: bytes = b"",
|
||||
content_type: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
payload_hash: str | None = None,
|
||||
host: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
构建包含 AWS V4 签名的完整请求头
|
||||
|
||||
:param method: HTTP 方法
|
||||
:param uri: 请求 URI
|
||||
:param query_string: 查询字符串
|
||||
:param payload: 请求体字节(用于计算哈希)
|
||||
:param content_type: Content-Type
|
||||
:param extra_headers: 额外请求头
|
||||
:param payload_hash: 预计算的 payload 哈希,流式上传时传 "UNSIGNED-PAYLOAD"
|
||||
:param host: Host 头(默认使用 self._host)
|
||||
"""
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
amz_date = now_utc.strftime("%Y%m%dT%H%M%SZ")
|
||||
date_stamp = now_utc.strftime("%Y%m%d")
|
||||
|
||||
if payload_hash is None:
|
||||
payload_hash = hashlib.sha256(payload).hexdigest()
|
||||
|
||||
effective_host = host or self._host
|
||||
|
||||
headers: dict[str, str] = {
|
||||
"Host": effective_host,
|
||||
"X-Amz-Date": amz_date,
|
||||
"X-Amz-Content-Sha256": payload_hash,
|
||||
}
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
authorization = self._create_authorization_header(
|
||||
method, uri, query_string, headers, payload_hash, amz_date, date_stamp
|
||||
)
|
||||
headers["Authorization"] = authorization
|
||||
return headers
|
||||
|
||||
# ==================== 内部请求方法 ====================
|
||||
|
||||
def _build_uri(self, key: str | None = None) -> str:
|
||||
"""
|
||||
构建请求 URI
|
||||
|
||||
按 AWS S3 Signature V4 规范对路径进行 URI 编码(S3 仅需一次)。
|
||||
斜杠作为路径分隔符保留不编码。
|
||||
"""
|
||||
if self._is_path_style:
|
||||
if key:
|
||||
return f"/{self._bucket_name}/{quote(key, safe='/')}"
|
||||
return f"/{self._bucket_name}"
|
||||
else:
|
||||
if key:
|
||||
return f"/{quote(key, safe='/')}"
|
||||
return "/"
|
||||
|
||||
def _build_url(self, uri: str, query_string: str = "") -> str:
|
||||
"""构建完整请求 URL"""
|
||||
if self._is_path_style:
|
||||
base = self._endpoint_url
|
||||
else:
|
||||
# 虚拟主机风格:bucket.endpoint
|
||||
protocol = "https://" if self._endpoint_url.startswith("https://") else "http://"
|
||||
base = f"{protocol}{self._bucket_name}.{self._host}"
|
||||
|
||||
url = f"{base}{uri}"
|
||||
if query_string:
|
||||
url = f"{url}?{query_string}"
|
||||
return url
|
||||
|
||||
def _get_effective_host(self) -> str:
|
||||
"""获取实际请求的 Host 头"""
|
||||
if self._is_path_style:
|
||||
return self._host
|
||||
return f"{self._bucket_name}.{self._host}"
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
key: str | None = None,
|
||||
query_params: dict[str, str] | None = None,
|
||||
payload: bytes = b"",
|
||||
content_type: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
) -> aiohttp.ClientResponse:
|
||||
"""发送签名请求"""
|
||||
uri = self._build_uri(key)
|
||||
query_string = urlencode(sorted(query_params.items())) if query_params else ""
|
||||
effective_host = self._get_effective_host()
|
||||
|
||||
headers = self._build_headers(
|
||||
method, uri, query_string, payload, content_type,
|
||||
extra_headers, host=effective_host,
|
||||
)
|
||||
|
||||
url = self._build_url(uri, query_string)
|
||||
|
||||
try:
|
||||
response = await self._get_session().request(
|
||||
method, URL(url, encoded=True),
|
||||
headers=headers, data=payload if payload else None,
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise S3APIError(f"S3 请求失败: {method} {url}: {e}") from e
|
||||
|
||||
async def _request_streaming(
|
||||
self,
|
||||
method: str,
|
||||
key: str,
|
||||
data_stream: AsyncIterator[bytes],
|
||||
content_length: int,
|
||||
content_type: str | None = None,
|
||||
) -> aiohttp.ClientResponse:
|
||||
"""
|
||||
发送流式签名请求(大文件上传)
|
||||
|
||||
使用 UNSIGNED-PAYLOAD 作为 payload hash。
|
||||
"""
|
||||
uri = self._build_uri(key)
|
||||
effective_host = self._get_effective_host()
|
||||
|
||||
headers = self._build_headers(
|
||||
method,
|
||||
uri,
|
||||
query_string="",
|
||||
content_type=content_type,
|
||||
extra_headers={"Content-Length": str(content_length)},
|
||||
payload_hash="UNSIGNED-PAYLOAD",
|
||||
host=effective_host,
|
||||
)
|
||||
|
||||
url = self._build_url(uri)
|
||||
|
||||
try:
|
||||
response = await self._get_session().request(
|
||||
method, URL(url, encoded=True),
|
||||
headers=headers, data=data_stream,
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise S3APIError(f"S3 流式请求失败: {method} {url}: {e}") from e
|
||||
|
||||
# ==================== 文件操作 ====================
|
||||
|
||||
async def upload_file(
|
||||
self,
|
||||
key: str,
|
||||
data: bytes,
|
||||
content_type: str = 'application/octet-stream',
|
||||
) -> None:
|
||||
"""
|
||||
上传文件
|
||||
|
||||
:param key: S3 对象键
|
||||
:param data: 文件内容
|
||||
:param content_type: MIME 类型
|
||||
"""
|
||||
async with await self._request(
|
||||
"PUT", key=key, payload=data, content_type=content_type,
|
||||
) as response:
|
||||
if response.status not in (200, 201):
|
||||
body = await response.text()
|
||||
raise S3APIError(
|
||||
f"上传失败: {self._bucket_name}/{key}, "
|
||||
f"状态: {response.status}, {body}"
|
||||
)
|
||||
l.debug(f"S3 上传成功: {self._bucket_name}/{key}")
|
||||
|
||||
async def upload_file_streaming(
|
||||
self,
|
||||
key: str,
|
||||
data_stream: AsyncIterator[bytes],
|
||||
content_length: int,
|
||||
content_type: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
流式上传文件(大文件,避免全部加载到内存)
|
||||
|
||||
:param key: S3 对象键
|
||||
:param data_stream: 异步字节流迭代器
|
||||
:param content_length: 数据总长度(必须准确)
|
||||
:param content_type: MIME 类型
|
||||
"""
|
||||
async with await self._request_streaming(
|
||||
"PUT", key=key, data_stream=data_stream,
|
||||
content_length=content_length, content_type=content_type,
|
||||
) as response:
|
||||
if response.status not in (200, 201):
|
||||
body = await response.text()
|
||||
raise S3APIError(
|
||||
f"流式上传失败: {self._bucket_name}/{key}, "
|
||||
f"状态: {response.status}, {body}"
|
||||
)
|
||||
l.debug(f"S3 流式上传成功: {self._bucket_name}/{key}, 大小: {content_length}")
|
||||
|
||||
async def download_file(self, key: str) -> bytes:
|
||||
"""
|
||||
下载文件
|
||||
|
||||
:param key: S3 对象键
|
||||
:return: 文件内容
|
||||
"""
|
||||
async with await self._request("GET", key=key) as response:
|
||||
if response.status != 200:
|
||||
body = await response.text()
|
||||
raise S3APIError(
|
||||
f"下载失败: {self._bucket_name}/{key}, "
|
||||
f"状态: {response.status}, {body}"
|
||||
)
|
||||
data = await response.read()
|
||||
l.debug(f"S3 下载成功: {self._bucket_name}/{key}, 大小: {len(data)}")
|
||||
return data
|
||||
|
||||
async def delete_file(self, key: str) -> None:
|
||||
"""
|
||||
删除文件
|
||||
|
||||
:param key: S3 对象键
|
||||
"""
|
||||
async with await self._request("DELETE", key=key) as response:
|
||||
if response.status in (200, 204):
|
||||
l.debug(f"S3 删除成功: {self._bucket_name}/{key}")
|
||||
else:
|
||||
body = await response.text()
|
||||
raise S3APIError(
|
||||
f"删除失败: {self._bucket_name}/{key}, "
|
||||
f"状态: {response.status}, {body}"
|
||||
)
|
||||
|
||||
async def file_exists(self, key: str) -> bool:
|
||||
"""
|
||||
检查文件是否存在
|
||||
|
||||
:param key: S3 对象键
|
||||
:return: 是否存在
|
||||
"""
|
||||
async with await self._request("HEAD", key=key) as response:
|
||||
if response.status == 200:
|
||||
return True
|
||||
elif response.status == 404:
|
||||
return False
|
||||
else:
|
||||
raise S3APIError(
|
||||
f"检查文件存在性失败: {self._bucket_name}/{key}, 状态: {response.status}"
|
||||
)
|
||||
|
||||
async def get_file_size(self, key: str) -> int:
|
||||
"""
|
||||
获取文件大小
|
||||
|
||||
:param key: S3 对象键
|
||||
:return: 文件大小(字节)
|
||||
"""
|
||||
async with await self._request("HEAD", key=key) as response:
|
||||
if response.status != 200:
|
||||
raise S3APIError(
|
||||
f"获取文件信息失败: {self._bucket_name}/{key}, 状态: {response.status}"
|
||||
)
|
||||
return int(response.headers.get("Content-Length", 0))
|
||||
|
||||
# ==================== Multipart Upload ====================
|
||||
|
||||
async def create_multipart_upload(
|
||||
self,
|
||||
key: str,
|
||||
content_type: str = 'application/octet-stream',
|
||||
) -> str:
|
||||
"""
|
||||
创建分片上传任务
|
||||
|
||||
:param key: S3 对象键
|
||||
:param content_type: MIME 类型
|
||||
:return: Upload ID
|
||||
"""
|
||||
async with await self._request(
|
||||
"POST",
|
||||
key=key,
|
||||
query_params={"uploads": ""},
|
||||
content_type=content_type,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
body = await response.text()
|
||||
raise S3MultipartUploadError(
|
||||
f"创建分片上传失败: {self._bucket_name}/{key}, "
|
||||
f"状态: {response.status}, {body}"
|
||||
)
|
||||
|
||||
body = await response.text()
|
||||
root = ET.fromstring(body)
|
||||
|
||||
# 查找 UploadId 元素(支持命名空间)
|
||||
upload_id_elem = root.find("UploadId")
|
||||
if upload_id_elem is None:
|
||||
upload_id_elem = root.find(f"{{{_NS_AWS}}}UploadId")
|
||||
if upload_id_elem is None or not upload_id_elem.text:
|
||||
raise S3MultipartUploadError(
|
||||
f"创建分片上传响应中未找到 UploadId: {body}"
|
||||
)
|
||||
|
||||
upload_id = upload_id_elem.text
|
||||
l.debug(f"S3 分片上传已创建: {self._bucket_name}/{key}, upload_id={upload_id}")
|
||||
return upload_id
|
||||
|
||||
async def upload_part(
|
||||
self,
|
||||
key: str,
|
||||
upload_id: str,
|
||||
part_number: int,
|
||||
data: bytes,
|
||||
) -> str:
|
||||
"""
|
||||
上传单个分片
|
||||
|
||||
:param key: S3 对象键
|
||||
:param upload_id: 分片上传 ID
|
||||
:param part_number: 分片编号(从 1 开始)
|
||||
:param data: 分片数据
|
||||
:return: ETag
|
||||
"""
|
||||
async with await self._request(
|
||||
"PUT",
|
||||
key=key,
|
||||
query_params={
|
||||
"partNumber": str(part_number),
|
||||
"uploadId": upload_id,
|
||||
},
|
||||
payload=data,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
body = await response.text()
|
||||
raise S3MultipartUploadError(
|
||||
f"上传分片失败: {self._bucket_name}/{key}, "
|
||||
f"part={part_number}, 状态: {response.status}, {body}"
|
||||
)
|
||||
|
||||
etag = response.headers.get("ETag", "").strip('"')
|
||||
l.debug(
|
||||
f"S3 分片上传成功: {self._bucket_name}/{key}, "
|
||||
f"part={part_number}, etag={etag}"
|
||||
)
|
||||
return etag
|
||||
|
||||
async def complete_multipart_upload(
|
||||
self,
|
||||
key: str,
|
||||
upload_id: str,
|
||||
parts: list[tuple[int, str]],
|
||||
) -> None:
|
||||
"""
|
||||
完成分片上传
|
||||
|
||||
:param key: S3 对象键
|
||||
:param upload_id: 分片上传 ID
|
||||
:param parts: 分片列表 [(part_number, etag)]
|
||||
"""
|
||||
# 按 part_number 排序
|
||||
parts_sorted = sorted(parts, key=lambda p: p[0])
|
||||
|
||||
# 构建 CompleteMultipartUpload XML
|
||||
xml_parts = ''.join(
|
||||
f"<Part><PartNumber>{pn}</PartNumber><ETag>{etag}</ETag></Part>"
|
||||
for pn, etag in parts_sorted
|
||||
)
|
||||
payload = f'<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUpload>{xml_parts}</CompleteMultipartUpload>'
|
||||
payload_bytes = payload.encode('utf-8')
|
||||
|
||||
async with await self._request(
|
||||
"POST",
|
||||
key=key,
|
||||
query_params={"uploadId": upload_id},
|
||||
payload=payload_bytes,
|
||||
content_type="application/xml",
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
body = await response.text()
|
||||
raise S3MultipartUploadError(
|
||||
f"完成分片上传失败: {self._bucket_name}/{key}, "
|
||||
f"状态: {response.status}, {body}"
|
||||
)
|
||||
l.info(
|
||||
f"S3 分片上传已完成: {self._bucket_name}/{key}, "
|
||||
f"共 {len(parts)} 个分片"
|
||||
)
|
||||
|
||||
async def abort_multipart_upload(self, key: str, upload_id: str) -> None:
|
||||
"""
|
||||
取消分片上传
|
||||
|
||||
:param key: S3 对象键
|
||||
:param upload_id: 分片上传 ID
|
||||
"""
|
||||
async with await self._request(
|
||||
"DELETE",
|
||||
key=key,
|
||||
query_params={"uploadId": upload_id},
|
||||
) as response:
|
||||
if response.status in (200, 204):
|
||||
l.debug(f"S3 分片上传已取消: {self._bucket_name}/{key}")
|
||||
else:
|
||||
body = await response.text()
|
||||
l.warning(
|
||||
f"取消分片上传失败: {self._bucket_name}/{key}, "
|
||||
f"状态: {response.status}, {body}"
|
||||
)
|
||||
|
||||
# ==================== 预签名 URL ====================
|
||||
|
||||
def generate_presigned_url(
|
||||
self,
|
||||
key: str,
|
||||
method: Literal['GET', 'PUT'] = 'GET',
|
||||
expires_in: int = 3600,
|
||||
filename: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
生成 S3 预签名 URL(AWS Signature V4 Query String)
|
||||
|
||||
:param key: S3 对象键
|
||||
:param method: HTTP 方法(GET 下载,PUT 上传)
|
||||
:param expires_in: URL 有效期(秒)
|
||||
:param filename: 文件名(GET 请求时设置 Content-Disposition)
|
||||
:return: 预签名 URL
|
||||
"""
|
||||
current_time = datetime.now(timezone.utc)
|
||||
amz_date = current_time.strftime("%Y%m%dT%H%M%SZ")
|
||||
date_stamp = current_time.strftime("%Y%m%d")
|
||||
|
||||
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
|
||||
credential = f"{self._access_key}/{credential_scope}"
|
||||
|
||||
uri = self._build_uri(key)
|
||||
effective_host = self._get_effective_host()
|
||||
|
||||
query_params: dict[str, str] = {
|
||||
'X-Amz-Algorithm': 'AWS4-HMAC-SHA256',
|
||||
'X-Amz-Credential': credential,
|
||||
'X-Amz-Date': amz_date,
|
||||
'X-Amz-Expires': str(expires_in),
|
||||
'X-Amz-SignedHeaders': 'host',
|
||||
}
|
||||
|
||||
# GET 请求时添加 Content-Disposition
|
||||
if method == "GET" and filename:
|
||||
encoded_filename = quote(filename, safe='')
|
||||
query_params['response-content-disposition'] = (
|
||||
f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
)
|
||||
|
||||
canonical_query_string = "&".join(
|
||||
f"{quote(k, safe='')}={quote(v, safe='')}"
|
||||
for k, v in sorted(query_params.items())
|
||||
)
|
||||
|
||||
canonical_headers = f"host:{effective_host}\n"
|
||||
signed_headers = "host"
|
||||
payload_hash = "UNSIGNED-PAYLOAD"
|
||||
|
||||
canonical_request = (
|
||||
f"{method}\n"
|
||||
f"{uri}\n"
|
||||
f"{canonical_query_string}\n"
|
||||
f"{canonical_headers}\n"
|
||||
f"{signed_headers}\n"
|
||||
f"{payload_hash}"
|
||||
)
|
||||
|
||||
algorithm = "AWS4-HMAC-SHA256"
|
||||
string_to_sign = (
|
||||
f"{algorithm}\n"
|
||||
f"{amz_date}\n"
|
||||
f"{credential_scope}\n"
|
||||
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
|
||||
)
|
||||
|
||||
signing_key = self._get_signature_key(date_stamp)
|
||||
signature = hmac.new(
|
||||
signing_key, string_to_sign.encode(), hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
base_url = self._build_url(uri)
|
||||
return (
|
||||
f"{base_url}?"
|
||||
f"{canonical_query_string}&"
|
||||
f"X-Amz-Signature={signature}"
|
||||
)
|
||||
|
||||
# ==================== 路径生成 ====================
|
||||
|
||||
async def generate_file_path(
|
||||
self,
|
||||
user_id: UUID,
|
||||
original_filename: str,
|
||||
) -> tuple[str, str, str]:
|
||||
"""
|
||||
根据命名规则生成 S3 文件存储路径
|
||||
|
||||
与 LocalStorageService.generate_file_path 接口一致。
|
||||
|
||||
:param user_id: 用户UUID
|
||||
:param original_filename: 原始文件名
|
||||
:return: (相对目录路径, 存储文件名, 完整存储路径)
|
||||
"""
|
||||
context = NamingContext(
|
||||
user_id=user_id,
|
||||
original_filename=original_filename,
|
||||
)
|
||||
|
||||
# 解析目录规则
|
||||
dir_path = ""
|
||||
if self._policy.dir_name_rule:
|
||||
dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context)
|
||||
|
||||
# 解析文件名规则
|
||||
if self._policy.auto_rename and self._policy.file_name_rule:
|
||||
storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context)
|
||||
# 确保有扩展名
|
||||
if '.' in original_filename and '.' not in storage_name:
|
||||
ext = original_filename.rsplit('.', 1)[1]
|
||||
storage_name = f"{storage_name}.{ext}"
|
||||
else:
|
||||
storage_name = original_filename
|
||||
|
||||
# S3 不需要创建目录,直接拼接路径
|
||||
if dir_path:
|
||||
storage_path = f"{dir_path}/{storage_name}"
|
||||
else:
|
||||
storage_path = storage_name
|
||||
|
||||
return dir_path, storage_name, storage_path
|
||||
@@ -3,12 +3,14 @@
|
||||
|
||||
支持多种认证方式:邮箱密码、GitHub OAuth、QQ OAuth、Passkey、Magic Link、手机短信(预留)。
|
||||
"""
|
||||
import hashlib
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
from loguru import logger as l
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from service.redis.token_store import TokenStore
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.group import GroupClaims, GroupOptions
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
@@ -190,7 +192,7 @@ async def _login_oauth(
|
||||
# 已绑定 → 更新 OAuth 信息并返回关联用户
|
||||
identity.display_name = nickname
|
||||
identity.avatar_url = avatar_url
|
||||
await identity.save(session)
|
||||
identity = await identity.save(session)
|
||||
|
||||
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||
if not user:
|
||||
@@ -252,7 +254,7 @@ async def _auto_register_oauth_user(
|
||||
is_verified=True,
|
||||
user_id=new_user_id,
|
||||
)
|
||||
await identity.save(session)
|
||||
identity = await identity.save(session)
|
||||
|
||||
# 创建用户根目录
|
||||
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||
@@ -333,7 +335,7 @@ async def _login_passkey(
|
||||
|
||||
# 更新签名计数
|
||||
authn.sign_count = verification.new_sign_count
|
||||
await authn.save(session)
|
||||
authn = await authn.save(session)
|
||||
|
||||
# 加载用户
|
||||
user: User = await User.get(session, User.id == authn.user_id, load=User.group)
|
||||
@@ -363,6 +365,12 @@ async def _login_magic_link(
|
||||
except BadSignature:
|
||||
http_exceptions.raise_unauthorized("Magic Link 无效")
|
||||
|
||||
# 防重放:使用 token 哈希作为标识符
|
||||
token_hash = hashlib.sha256(request.identifier.encode()).hexdigest()
|
||||
is_first_use = await TokenStore.mark_used(f"magic_link:{token_hash}", ttl=600)
|
||||
if not is_first_use:
|
||||
http_exceptions.raise_unauthorized("Magic Link 已被使用")
|
||||
|
||||
# 查找绑定了该邮箱的 AuthIdentity(email_password 或 magic_link)
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
@@ -384,7 +392,7 @@ async def _login_magic_link(
|
||||
# 标记邮箱已验证
|
||||
if not identity.is_verified:
|
||||
identity.is_verified = True
|
||||
await identity.save(session)
|
||||
identity = await identity.save(session)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
185
service/wopi/__init__.py
Normal file
185
service/wopi/__init__.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
WOPI Discovery 服务模块
|
||||
|
||||
解析 WOPI 服务端(Collabora / OnlyOffice 等)的 Discovery XML,
|
||||
提取支持的文件扩展名及对应的编辑器 URL 模板。
|
||||
|
||||
参考:Cloudreve pkg/wopi/discovery.go 和 pkg/wopi/wopi.go
|
||||
"""
|
||||
import xml.etree.ElementTree as ET
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
from loguru import logger as l
|
||||
|
||||
# WOPI URL 模板中已知的查询参数占位符及其替换值
|
||||
# 值为 None 表示删除该参数,非 None 表示替换为该值
|
||||
# 参考 Cloudreve pkg/wopi/wopi.go queryPlaceholders
|
||||
_WOPI_QUERY_PLACEHOLDERS: dict[str, str | None] = {
|
||||
'BUSINESS_USER': None,
|
||||
'DC_LLCC': 'lng',
|
||||
'DISABLE_ASYNC': None,
|
||||
'DISABLE_CHAT': None,
|
||||
'EMBEDDED': 'true',
|
||||
'FULLSCREEN': 'true',
|
||||
'HOST_SESSION_ID': None,
|
||||
'SESSION_CONTEXT': None,
|
||||
'RECORDING': None,
|
||||
'THEME_ID': 'darkmode',
|
||||
'UI_LLCC': 'lng',
|
||||
'VALIDATOR_TEST_CATEGORY': None,
|
||||
}
|
||||
|
||||
_WOPI_SRC_PLACEHOLDER = 'WOPI_SOURCE'
|
||||
|
||||
|
||||
def process_wopi_action_url(raw_urlsrc: str) -> str:
|
||||
"""
|
||||
将 WOPI Discovery 中的原始 urlsrc 转换为 DiskNext 可用的 URL 模板。
|
||||
|
||||
处理流程(参考 Cloudreve generateActionUrl):
|
||||
1. 去除 ``<>`` 占位符标记
|
||||
2. 解析查询参数,替换/删除已知占位符
|
||||
3. ``WOPI_SOURCE`` → ``{wopi_src}``
|
||||
|
||||
注意:access_token 和 access_token_ttl 不放在 URL 中,
|
||||
根据 WOPI 规范它们通过 POST 表单字段传递给编辑器。
|
||||
|
||||
:param raw_urlsrc: WOPI Discovery XML 中的 urlsrc 原始值
|
||||
:return: 处理后的 URL 模板字符串,包含 {wopi_src} 占位符
|
||||
"""
|
||||
# 去除 <> 标记
|
||||
cleaned = raw_urlsrc.replace('<', '').replace('>', '')
|
||||
parsed = urlparse(cleaned)
|
||||
raw_params = parse_qs(parsed.query, keep_blank_values=True)
|
||||
|
||||
new_params: list[tuple[str, str]] = []
|
||||
is_src_replaced = False
|
||||
|
||||
for key, values in raw_params.items():
|
||||
value = values[0] if values else ''
|
||||
|
||||
# WOPI_SOURCE 占位符 → {wopi_src}
|
||||
if value == _WOPI_SRC_PLACEHOLDER:
|
||||
new_params.append((key, '{wopi_src}'))
|
||||
is_src_replaced = True
|
||||
continue
|
||||
|
||||
# 已知占位符
|
||||
if value in _WOPI_QUERY_PLACEHOLDERS:
|
||||
replacement = _WOPI_QUERY_PLACEHOLDERS[value]
|
||||
if replacement is not None:
|
||||
new_params.append((key, replacement))
|
||||
# replacement 为 None 时删除该参数
|
||||
continue
|
||||
|
||||
# 其他参数保留原值
|
||||
new_params.append((key, value))
|
||||
|
||||
# 如果没有找到 WOPI_SOURCE 占位符,手动添加 WOPISrc
|
||||
if not is_src_replaced:
|
||||
new_params.append(('WOPISrc', '{wopi_src}'))
|
||||
|
||||
# LibreOffice/Collabora 需要 lang 参数(避免重复添加)
|
||||
existing_keys = {k for k, _ in new_params}
|
||||
if 'lang' not in existing_keys:
|
||||
new_params.append(('lang', 'lng'))
|
||||
|
||||
# 注意:access_token 和 access_token_ttl 不放在 URL 中
|
||||
# 根据 WOPI 规范,它们通过 POST 表单字段传递给编辑器
|
||||
|
||||
# 重建 URL
|
||||
new_query = urlencode(new_params, safe='{}')
|
||||
result = urlunparse((
|
||||
parsed.scheme,
|
||||
parsed.netloc,
|
||||
parsed.path,
|
||||
parsed.params,
|
||||
new_query,
|
||||
'',
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parse_wopi_discovery_xml(xml_content: str) -> tuple[dict[str, str], list[str]]:
|
||||
"""
|
||||
解析 WOPI Discovery XML,提取扩展名到 URL 模板的映射。
|
||||
|
||||
XML 结构::
|
||||
|
||||
<wopi-discovery>
|
||||
<net-zone name="external-https">
|
||||
<app name="Writer" favIconUrl="...">
|
||||
<action name="edit" ext="docx" urlsrc="https://..."/>
|
||||
<action name="view" ext="docx" urlsrc="https://..."/>
|
||||
</app>
|
||||
</net-zone>
|
||||
</wopi-discovery>
|
||||
|
||||
动作优先级:edit > embedview > view(参考 Cloudreve discovery.go)
|
||||
|
||||
:param xml_content: WOPI Discovery 端点返回的 XML 字符串
|
||||
:return: (action_urls, app_names) 元组
|
||||
action_urls: {extension: processed_url_template}
|
||||
app_names: 发现的应用名称列表
|
||||
:raises ValueError: XML 解析失败或格式无效
|
||||
"""
|
||||
try:
|
||||
root = ET.fromstring(xml_content)
|
||||
except ET.ParseError as e:
|
||||
raise ValueError(f"WOPI Discovery XML 解析失败: {e}")
|
||||
|
||||
# 查找 net-zone(可能有多个,取第一个非空的)
|
||||
net_zones = root.findall('net-zone')
|
||||
if not net_zones:
|
||||
raise ValueError("WOPI Discovery XML 缺少 net-zone 节点")
|
||||
|
||||
# ext_actions: {extension: {action_name: urlsrc}}
|
||||
ext_actions: dict[str, dict[str, str]] = {}
|
||||
app_names: list[str] = []
|
||||
|
||||
for net_zone in net_zones:
|
||||
for app_elem in net_zone.findall('app'):
|
||||
app_name = app_elem.get('name', '')
|
||||
if app_name:
|
||||
app_names.append(app_name)
|
||||
|
||||
for action_elem in app_elem.findall('action'):
|
||||
action_name = action_elem.get('name', '')
|
||||
ext = action_elem.get('ext', '')
|
||||
urlsrc = action_elem.get('urlsrc', '')
|
||||
|
||||
if not ext or not urlsrc:
|
||||
continue
|
||||
|
||||
# 只关注 edit / embedview / view 三种动作
|
||||
if action_name not in ('edit', 'embedview', 'view'):
|
||||
continue
|
||||
|
||||
if ext not in ext_actions:
|
||||
ext_actions[ext] = {}
|
||||
ext_actions[ext][action_name] = urlsrc
|
||||
|
||||
# 为每个扩展名选择最佳 URL: edit > embedview > view
|
||||
action_urls: dict[str, str] = {}
|
||||
for ext, actions_map in ext_actions.items():
|
||||
selected_urlsrc: str | None = None
|
||||
for preferred in ('edit', 'embedview', 'view'):
|
||||
if preferred in actions_map:
|
||||
selected_urlsrc = actions_map[preferred]
|
||||
break
|
||||
|
||||
if selected_urlsrc:
|
||||
action_urls[ext] = process_wopi_action_url(selected_urlsrc)
|
||||
|
||||
# 去重 app_names
|
||||
seen: set[str] = set()
|
||||
unique_names: list[str] = []
|
||||
for name in app_names:
|
||||
if name not in seen:
|
||||
seen.add(name)
|
||||
unique_names.append(name)
|
||||
|
||||
l.info(f"WOPI Discovery 解析完成: {len(action_urls)} 个扩展名, 应用: {unique_names}")
|
||||
|
||||
return action_urls, unique_names
|
||||
@@ -84,6 +84,7 @@ if __name__ == "__main__":
|
||||
|
||||
setup(
|
||||
name="disknext-ee",
|
||||
packages=[],
|
||||
ext_modules=cythonize(
|
||||
extensions,
|
||||
compiler_directives={'language_level': "3"},
|
||||
|
||||
@@ -954,18 +954,11 @@ class PolicyType(StrEnum):
|
||||
S3 = "s3" # S3 兼容存储
|
||||
```
|
||||
|
||||
### StorageType
|
||||
### PolicyType
|
||||
```python
|
||||
class StorageType(StrEnum):
|
||||
class PolicyType(StrEnum):
|
||||
LOCAL = "local" # 本地存储
|
||||
QINIU = "qiniu" # 七牛云
|
||||
TENCENT = "tencent" # 腾讯云
|
||||
ALIYUN = "aliyun" # 阿里云
|
||||
ONEDRIVE = "onedrive" # OneDrive
|
||||
GOOGLE_DRIVE = "google_drive" # Google Drive
|
||||
DROPBOX = "dropbox" # Dropbox
|
||||
WEBDAV = "webdav" # WebDAV
|
||||
REMOTE = "remote" # 远程存储
|
||||
S3 = "s3" # S3 兼容存储
|
||||
```
|
||||
|
||||
### UserStatus
|
||||
|
||||
@@ -69,18 +69,20 @@ from .object import (
|
||||
CreateUploadSessionRequest,
|
||||
DirectoryCreateRequest,
|
||||
DirectoryResponse,
|
||||
FileMetadata,
|
||||
FileMetadataBase,
|
||||
Object,
|
||||
ObjectBase,
|
||||
ObjectCopyRequest,
|
||||
ObjectDeleteRequest,
|
||||
ObjectFileFinalize,
|
||||
ObjectMoveRequest,
|
||||
ObjectMoveUpdate,
|
||||
ObjectPropertyDetailResponse,
|
||||
ObjectPropertyResponse,
|
||||
ObjectRenameRequest,
|
||||
ObjectResponse,
|
||||
ObjectSwitchPolicyRequest,
|
||||
ObjectType,
|
||||
FileCategory,
|
||||
PolicyResponse,
|
||||
UploadChunkResponse,
|
||||
UploadSession,
|
||||
@@ -95,11 +97,42 @@ from .object import (
|
||||
TrashRestoreRequest,
|
||||
TrashDeleteRequest,
|
||||
)
|
||||
from .object_metadata import (
|
||||
ObjectMetadata,
|
||||
ObjectMetadataBase,
|
||||
MetadataNamespace,
|
||||
MetadataResponse,
|
||||
MetadataPatchItem,
|
||||
MetadataPatchRequest,
|
||||
INTERNAL_NAMESPACES,
|
||||
USER_WRITABLE_NAMESPACES,
|
||||
)
|
||||
from .custom_property import (
|
||||
CustomPropertyDefinition,
|
||||
CustomPropertyDefinitionBase,
|
||||
CustomPropertyType,
|
||||
CustomPropertyCreateRequest,
|
||||
CustomPropertyUpdateRequest,
|
||||
CustomPropertyResponse,
|
||||
)
|
||||
from .physical_file import PhysicalFile, PhysicalFileBase
|
||||
from .uri import DiskNextURI, FileSystemNamespace
|
||||
from .order import Order, OrderStatus, OrderType
|
||||
from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary
|
||||
from .redeem import Redeem, RedeemType
|
||||
from .order import (
|
||||
Order, OrderStatus, OrderType,
|
||||
CreateOrderRequest, OrderResponse,
|
||||
)
|
||||
from .policy import (
|
||||
Policy, PolicyBase, PolicyCreateRequest, PolicyOptions, PolicyOptionsBase,
|
||||
PolicyType, PolicySummary, PolicyUpdateRequest,
|
||||
)
|
||||
from .product import (
|
||||
Product, ProductBase, ProductType, PaymentMethod,
|
||||
ProductCreateRequest, ProductUpdateRequest, ProductResponse,
|
||||
)
|
||||
from .redeem import (
|
||||
Redeem, RedeemType,
|
||||
RedeemCreateRequest, RedeemUseRequest, RedeemInfoResponse, RedeemAdminResponse,
|
||||
)
|
||||
from .report import Report, ReportReason
|
||||
from .setting import (
|
||||
Setting, SettingsType, SiteConfigResponse, AuthMethodConfig,
|
||||
@@ -112,16 +145,20 @@ from .share import (
|
||||
AdminShareListItem,
|
||||
)
|
||||
from .source_link import SourceLink
|
||||
from .storage_pack import StoragePack
|
||||
from .storage_pack import StoragePack, StoragePackResponse
|
||||
from .tag import Tag, TagType
|
||||
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
|
||||
from .webdav import WebDAV
|
||||
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary, TaskSummaryBase
|
||||
from .webdav import (
|
||||
WebDAV, WebDAVBase,
|
||||
WebDAVCreateRequest, WebDAVUpdateRequest, WebDAVAccountResponse,
|
||||
)
|
||||
from .file_app import (
|
||||
FileApp, FileAppType, FileAppExtension, FileAppGroupLink, UserFileAppDefault,
|
||||
# DTO
|
||||
FileAppSummary, FileViewersResponse, SetDefaultViewerRequest, UserFileAppDefaultResponse,
|
||||
FileAppCreateRequest, FileAppUpdateRequest, FileAppResponse, FileAppListResponse,
|
||||
ExtensionUpdateRequest, GroupAccessUpdateRequest, WopiSessionResponse,
|
||||
WopiDiscoveredExtension, WopiDiscoveryResponse,
|
||||
)
|
||||
from .wopi import WopiFileInfo, WopiAccessTokenPayload
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100, Str128, Str255, Text1024
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -87,7 +87,7 @@ class ChangePasswordRequest(SQLModelBase):
|
||||
old_password: str = Field(min_length=1)
|
||||
"""当前密码"""
|
||||
|
||||
new_password: str = Field(min_length=8, max_length=128)
|
||||
new_password: Str128 = Field(min_length=8)
|
||||
"""新密码(至少 8 位)"""
|
||||
|
||||
|
||||
@@ -103,13 +103,13 @@ class AuthIdentity(SQLModelBase, UUIDTableBaseMixin):
|
||||
provider: AuthProviderType = Field(index=True)
|
||||
"""提供者类型"""
|
||||
|
||||
identifier: str = Field(max_length=255, index=True)
|
||||
identifier: Str255 = Field(index=True)
|
||||
"""标识符(邮箱/手机号/OAuth openid)"""
|
||||
|
||||
credential: str | None = Field(default=None, max_length=1024)
|
||||
credential: Text1024 | None = None
|
||||
"""凭证(Argon2 哈希密码 / null)"""
|
||||
|
||||
display_name: str | None = Field(default=None, max_length=100)
|
||||
display_name: Str100 | None = None
|
||||
"""OAuth 昵称"""
|
||||
|
||||
avatar_url: str | None = Field(default=None, max_length=512)
|
||||
|
||||
135
sqlmodels/custom_property.py
Normal file
135
sqlmodels/custom_property.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
用户自定义属性定义模型
|
||||
|
||||
允许用户定义类型化的自定义属性模板(如标签、评分、分类等),
|
||||
实际值通过 ObjectMetadata KV 表存储,键名格式:custom:{property_definition_id}。
|
||||
|
||||
支持的属性类型:text, number, boolean, select, multi_select, rating, link
|
||||
"""
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import JSON
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
# ==================== 枚举 ====================
|
||||
|
||||
class CustomPropertyType(StrEnum):
|
||||
"""自定义属性值类型枚举"""
|
||||
TEXT = "text"
|
||||
"""文本"""
|
||||
NUMBER = "number"
|
||||
"""数字"""
|
||||
BOOLEAN = "boolean"
|
||||
"""布尔值"""
|
||||
SELECT = "select"
|
||||
"""单选"""
|
||||
MULTI_SELECT = "multi_select"
|
||||
"""多选"""
|
||||
RATING = "rating"
|
||||
"""评分(1-5)"""
|
||||
LINK = "link"
|
||||
"""链接"""
|
||||
|
||||
|
||||
# ==================== Base 模型 ====================
|
||||
|
||||
class CustomPropertyDefinitionBase(SQLModelBase):
|
||||
"""自定义属性定义基础模型"""
|
||||
|
||||
name: Str100
|
||||
"""属性显示名称"""
|
||||
|
||||
type: CustomPropertyType
|
||||
"""属性值类型"""
|
||||
|
||||
icon: Str100 | None = None
|
||||
"""图标标识(iconify 名称)"""
|
||||
|
||||
options: list[str] | None = Field(default=None, sa_type=JSON)
|
||||
"""可选值列表(仅 select/multi_select 类型)"""
|
||||
|
||||
default_value: str | None = Field(default=None, max_length=500)
|
||||
"""默认值"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class CustomPropertyDefinition(CustomPropertyDefinitionBase, UUIDTableBaseMixin):
|
||||
"""
|
||||
用户自定义属性定义
|
||||
|
||||
每个用户独立管理自己的属性模板。
|
||||
实际属性值存储在 ObjectMetadata 表中,键名格式:custom:{id}。
|
||||
"""
|
||||
|
||||
owner_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
ondelete="CASCADE",
|
||||
index=True,
|
||||
)
|
||||
"""所有者用户UUID"""
|
||||
|
||||
sort_order: int = 0
|
||||
"""排序顺序"""
|
||||
|
||||
# 关系
|
||||
owner: "User" = Relationship()
|
||||
"""所有者"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class CustomPropertyCreateRequest(SQLModelBase):
|
||||
"""创建自定义属性请求 DTO"""
|
||||
|
||||
name: Str100
|
||||
"""属性显示名称"""
|
||||
|
||||
type: CustomPropertyType
|
||||
"""属性值类型"""
|
||||
|
||||
icon: str | None = None
|
||||
"""图标标识"""
|
||||
|
||||
options: list[str] | None = None
|
||||
"""可选值列表(仅 select/multi_select 类型)"""
|
||||
|
||||
default_value: str | None = None
|
||||
"""默认值"""
|
||||
|
||||
|
||||
class CustomPropertyUpdateRequest(SQLModelBase):
|
||||
"""更新自定义属性请求 DTO"""
|
||||
|
||||
name: str | None = None
|
||||
"""属性显示名称"""
|
||||
|
||||
icon: str | None = None
|
||||
"""图标标识"""
|
||||
|
||||
options: list[str] | None = None
|
||||
"""可选值列表"""
|
||||
|
||||
default_value: str | None = None
|
||||
"""默认值"""
|
||||
|
||||
sort_order: int | None = None
|
||||
"""排序顺序"""
|
||||
|
||||
|
||||
class CustomPropertyResponse(CustomPropertyDefinitionBase):
|
||||
"""自定义属性响应 DTO"""
|
||||
|
||||
id: UUID
|
||||
"""属性定义UUID"""
|
||||
|
||||
sort_order: int
|
||||
"""排序顺序"""
|
||||
@@ -4,7 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, Index
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableBaseMixin, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -141,7 +141,7 @@ class Download(DownloadBase, UUIDTableBaseMixin):
|
||||
speed: int = Field(default=0)
|
||||
"""下载速度(bytes/s)"""
|
||||
|
||||
parent: str | None = Field(default=None, max_length=255)
|
||||
parent: Str255 | None = None
|
||||
"""父任务标识"""
|
||||
|
||||
error: str | None = Field(default=None)
|
||||
|
||||
@@ -20,7 +20,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str100, Str255, Text1024
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .group import Group
|
||||
@@ -119,7 +119,7 @@ class UserFileAppDefaultResponse(SQLModelBase):
|
||||
class FileAppCreateRequest(SQLModelBase):
|
||||
"""管理员创建应用请求 DTO"""
|
||||
|
||||
name: str = Field(max_length=100)
|
||||
name: Str100
|
||||
"""应用名称"""
|
||||
|
||||
app_key: str = Field(max_length=50)
|
||||
@@ -128,7 +128,7 @@ class FileAppCreateRequest(SQLModelBase):
|
||||
type: FileAppType
|
||||
"""应用类型"""
|
||||
|
||||
icon: str | None = Field(default=None, max_length=255)
|
||||
icon: Str255 | None = None
|
||||
"""图标名称/URL"""
|
||||
|
||||
description: str | None = Field(default=None, max_length=500)
|
||||
@@ -140,13 +140,13 @@ class FileAppCreateRequest(SQLModelBase):
|
||||
is_restricted: bool = False
|
||||
"""是否限制用户组访问"""
|
||||
|
||||
iframe_url_template: str | None = Field(default=None, max_length=1024)
|
||||
iframe_url_template: Text1024 | None = None
|
||||
"""iframe URL 模板"""
|
||||
|
||||
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
||||
"""WOPI 发现端点 URL"""
|
||||
|
||||
wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
|
||||
wopi_editor_url_template: Text1024 | None = None
|
||||
"""WOPI 编辑器 URL 模板"""
|
||||
|
||||
extensions: list[str] = []
|
||||
@@ -159,7 +159,7 @@ class FileAppCreateRequest(SQLModelBase):
|
||||
class FileAppUpdateRequest(SQLModelBase):
|
||||
"""管理员更新应用请求 DTO(所有字段可选)"""
|
||||
|
||||
name: str | None = Field(default=None, max_length=100)
|
||||
name: Str100 | None = None
|
||||
"""应用名称"""
|
||||
|
||||
app_key: str | None = Field(default=None, max_length=50)
|
||||
@@ -168,7 +168,7 @@ class FileAppUpdateRequest(SQLModelBase):
|
||||
type: FileAppType | None = None
|
||||
"""应用类型"""
|
||||
|
||||
icon: str | None = Field(default=None, max_length=255)
|
||||
icon: Str255 | None = None
|
||||
"""图标名称/URL"""
|
||||
|
||||
description: str | None = Field(default=None, max_length=500)
|
||||
@@ -180,13 +180,13 @@ class FileAppUpdateRequest(SQLModelBase):
|
||||
is_restricted: bool | None = None
|
||||
"""是否限制用户组访问"""
|
||||
|
||||
iframe_url_template: str | None = Field(default=None, max_length=1024)
|
||||
iframe_url_template: Text1024 | None = None
|
||||
"""iframe URL 模板"""
|
||||
|
||||
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
||||
"""WOPI 发现端点 URL"""
|
||||
|
||||
wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
|
||||
wopi_editor_url_template: Text1024 | None = None
|
||||
"""WOPI 编辑器 URL 模板"""
|
||||
|
||||
|
||||
@@ -297,12 +297,35 @@ class WopiSessionResponse(SQLModelBase):
|
||||
"""完整的编辑器 URL"""
|
||||
|
||||
|
||||
class WopiDiscoveredExtension(SQLModelBase):
|
||||
"""单个 WOPI Discovery 发现的扩展名"""
|
||||
|
||||
extension: str
|
||||
"""文件扩展名"""
|
||||
|
||||
action_url: str
|
||||
"""处理后的动作 URL 模板"""
|
||||
|
||||
|
||||
class WopiDiscoveryResponse(SQLModelBase):
|
||||
"""WOPI Discovery 结果响应 DTO"""
|
||||
|
||||
discovered_extensions: list[WopiDiscoveredExtension] = []
|
||||
"""发现的扩展名及其 URL 模板"""
|
||||
|
||||
app_names: list[str] = []
|
||||
"""WOPI 服务端报告的应用名称(如 Writer、Calc、Impress)"""
|
||||
|
||||
applied_count: int = 0
|
||||
"""已应用到 FileAppExtension 的数量"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class FileApp(SQLModelBase, UUIDTableBaseMixin):
|
||||
"""文件查看器应用注册表"""
|
||||
|
||||
name: str = Field(max_length=100)
|
||||
name: Str100
|
||||
"""应用名称"""
|
||||
|
||||
app_key: str = Field(max_length=50, unique=True, index=True)
|
||||
@@ -311,7 +334,7 @@ class FileApp(SQLModelBase, UUIDTableBaseMixin):
|
||||
type: FileAppType
|
||||
"""应用类型"""
|
||||
|
||||
icon: str | None = Field(default=None, max_length=255)
|
||||
icon: Str255 | None = None
|
||||
"""图标名称/URL"""
|
||||
|
||||
description: str | None = Field(default=None, max_length=500)
|
||||
@@ -323,13 +346,13 @@ class FileApp(SQLModelBase, UUIDTableBaseMixin):
|
||||
is_restricted: bool = False
|
||||
"""是否限制用户组访问"""
|
||||
|
||||
iframe_url_template: str | None = Field(default=None, max_length=1024)
|
||||
iframe_url_template: Text1024 | None = None
|
||||
"""iframe URL 模板,支持 {file_url} 占位符"""
|
||||
|
||||
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
||||
"""WOPI 客户端发现端点 URL"""
|
||||
|
||||
wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
|
||||
wopi_editor_url_template: Text1024 | None = None
|
||||
"""WOPI 编辑器 URL 模板,支持 {wopi_src} {access_token} {access_token_ttl}"""
|
||||
|
||||
# 关系
|
||||
@@ -377,6 +400,9 @@ class FileAppExtension(SQLModelBase, TableBaseMixin):
|
||||
priority: int = Field(default=0, ge=0)
|
||||
"""排序优先级(越小越优先)"""
|
||||
|
||||
wopi_action_url: str | None = Field(default=None, max_length=2048)
|
||||
"""WOPI 动作 URL 模板(Discovery 自动填充),支持 {wopi_src} {access_token} {access_token_ttl}"""
|
||||
|
||||
# 关系
|
||||
app: FileApp = Relationship(back_populates="extensions")
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import BigInteger
|
||||
from sqlmodel import Field, Relationship, text
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -66,7 +67,7 @@ class GroupAllOptionsBase(GroupOptionsBase):
|
||||
class GroupCreateRequest(GroupAllOptionsBase):
|
||||
"""创建用户组请求 DTO"""
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
name: Str255
|
||||
"""用户组名称"""
|
||||
|
||||
max_storage: int = Field(default=0, ge=0)
|
||||
@@ -91,7 +92,7 @@ class GroupCreateRequest(GroupAllOptionsBase):
|
||||
class GroupUpdateRequest(SQLModelBase):
|
||||
"""更新用户组请求 DTO(所有字段可选)"""
|
||||
|
||||
name: str | None = Field(default=None, max_length=255)
|
||||
name: Str255 | None = None
|
||||
"""用户组名称"""
|
||||
|
||||
max_storage: int | None = Field(default=None, ge=0)
|
||||
@@ -257,10 +258,10 @@ class GroupOptions(GroupAllOptionsBase, TableBaseMixin):
|
||||
class Group(GroupBase, UUIDTableBaseMixin):
|
||||
"""用户组模型"""
|
||||
|
||||
name: str = Field(max_length=255, unique=True)
|
||||
name: Str255 = Field(unique=True)
|
||||
"""用户组名"""
|
||||
|
||||
max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
max_storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
|
||||
"""最大存储空间(字节)"""
|
||||
|
||||
share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||
|
||||
@@ -130,6 +130,11 @@ default_settings: list[Setting] = [
|
||||
Setting(name="sms_provider", value="", type=SettingsType.MOBILE),
|
||||
Setting(name="sms_access_key", value="", type=SettingsType.MOBILE),
|
||||
Setting(name="sms_secret_key", value="", type=SettingsType.MOBILE),
|
||||
# ==================== 文件分类扩展名配置 ====================
|
||||
Setting(name="image", value="jpg,jpeg,png,gif,bmp,webp,svg,ico,tiff,tif,avif,heic,heif,psd,raw", type=SettingsType.FILE_CATEGORY),
|
||||
Setting(name="video", value="mp4,mkv,avi,mov,wmv,flv,webm,m4v,ts,3gp,mpg,mpeg", type=SettingsType.FILE_CATEGORY),
|
||||
Setting(name="audio", value="mp3,wav,flac,aac,ogg,wma,m4a,opus,ape,aiff,mid,midi", type=SettingsType.FILE_CATEGORY),
|
||||
Setting(name="document", value="pdf,doc,docx,odt,rtf,txt,tex,epub,pages,ppt,pptx,odp,key,xls,xlsx,csv,ods,numbers,tsv,md,markdown,mdx", type=SettingsType.FILE_CATEGORY),
|
||||
]
|
||||
|
||||
async def init_default_settings() -> None:
|
||||
@@ -173,7 +178,7 @@ async def init_default_group() -> None:
|
||||
admin=True,
|
||||
)
|
||||
admin_group_id = admin_group.id # 在 save 前保存 UUID
|
||||
await admin_group.save(session)
|
||||
admin_group = await admin_group.save(session)
|
||||
|
||||
await GroupOptions(
|
||||
group_id=admin_group_id,
|
||||
@@ -203,7 +208,7 @@ async def init_default_group() -> None:
|
||||
web_dav_enabled=True,
|
||||
)
|
||||
member_group_id = member_group.id # 在 save 前保存 UUID
|
||||
await member_group.save(session)
|
||||
member_group = await member_group.save(session)
|
||||
|
||||
await GroupOptions(
|
||||
group_id=member_group_id,
|
||||
@@ -222,7 +227,7 @@ async def init_default_group() -> None:
|
||||
default_group_setting = await Setting.get(session, Setting.name == "default_group")
|
||||
if default_group_setting:
|
||||
default_group_setting.value = str(member_group_id)
|
||||
await default_group_setting.save(session)
|
||||
default_group_setting = await default_group_setting.save(session)
|
||||
|
||||
# 未找到初始游客组时,则创建
|
||||
if not await Group.get(session, Group.name == "游客"):
|
||||
@@ -232,7 +237,7 @@ async def init_default_group() -> None:
|
||||
web_dav_enabled=False,
|
||||
)
|
||||
guest_group_id = guest_group.id # 在 save 前保存 UUID
|
||||
await guest_group.save(session)
|
||||
guest_group = await guest_group.save(session)
|
||||
|
||||
await GroupOptions(
|
||||
group_id=guest_group_id,
|
||||
@@ -284,7 +289,7 @@ async def init_default_user() -> None:
|
||||
group_id=admin_group.id,
|
||||
)
|
||||
admin_user_id = admin_user.id # 在 save 前保存 UUID
|
||||
await admin_user.save(session)
|
||||
admin_user = await admin_user.save(session)
|
||||
|
||||
# 创建 AuthIdentity(邮箱密码身份)
|
||||
await AuthIdentity(
|
||||
@@ -373,7 +378,7 @@ async def init_default_theme_presets() -> None:
|
||||
error=ChromaticColor.RED,
|
||||
neutral=NeutralColor.ZINC,
|
||||
)
|
||||
await default_preset.save(session)
|
||||
default_preset = await default_preset.save(session)
|
||||
log.info('已创建默认主题预设')
|
||||
|
||||
|
||||
@@ -446,36 +451,43 @@ _DEFAULT_FILE_APPS: list[dict] = [
|
||||
"is_enabled": True,
|
||||
"extensions": ["mp3", "wav", "ogg", "flac", "aac", "m4a", "opus"],
|
||||
},
|
||||
# iframe 应用(默认禁用)
|
||||
{
|
||||
"name": "EPUB 阅读器",
|
||||
"app_key": "epub_reader",
|
||||
"type": "builtin",
|
||||
"icon": "book-open",
|
||||
"description": "阅读 EPUB 电子书",
|
||||
"is_enabled": True,
|
||||
"extensions": ["epub"],
|
||||
},
|
||||
{
|
||||
"name": "3D 模型预览",
|
||||
"app_key": "model_viewer",
|
||||
"type": "builtin",
|
||||
"icon": "cube",
|
||||
"description": "预览 3D 模型",
|
||||
"is_enabled": True,
|
||||
"extensions": ["gltf", "glb", "stl", "obj", "fbx", "ply", "3mf"],
|
||||
},
|
||||
{
|
||||
"name": "Font Viewer",
|
||||
"app_key": "font_viewer",
|
||||
"type": "builtin",
|
||||
"icon": "type",
|
||||
"description": "预览字体文件并显示元数据和文本样本",
|
||||
"is_enabled": True,
|
||||
"extensions": ["ttf", "otf", "woff", "woff2"],
|
||||
},
|
||||
{
|
||||
"name": "Office 在线预览",
|
||||
"app_key": "office_viewer",
|
||||
"type": "iframe",
|
||||
"icon": "file-word",
|
||||
"description": "使用 Microsoft Office Online 预览文档",
|
||||
"is_enabled": False,
|
||||
"is_enabled": True,
|
||||
"iframe_url_template": "https://view.officeapps.live.com/op/embed.aspx?src={file_url}",
|
||||
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"],
|
||||
},
|
||||
# WOPI 应用(默认禁用)
|
||||
{
|
||||
"name": "Collabora Online",
|
||||
"app_key": "collabora",
|
||||
"type": "wopi",
|
||||
"icon": "file-text",
|
||||
"description": "Collabora Online 文档编辑器(需自行部署)",
|
||||
"is_enabled": False,
|
||||
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx", "odt", "ods", "odp"],
|
||||
},
|
||||
{
|
||||
"name": "OnlyOffice",
|
||||
"app_key": "onlyoffice",
|
||||
"type": "wopi",
|
||||
"icon": "file-text",
|
||||
"description": "OnlyOffice 文档编辑器(需自行部署)",
|
||||
"is_enabled": False,
|
||||
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -493,7 +505,7 @@ async def init_default_file_apps() -> None:
|
||||
return
|
||||
|
||||
for app_data in _DEFAULT_FILE_APPS:
|
||||
extensions = app_data.pop("extensions")
|
||||
extensions = app_data["extensions"]
|
||||
|
||||
app = FileApp(
|
||||
name=app_data["name"],
|
||||
@@ -515,6 +527,6 @@ async def init_default_file_apps() -> None:
|
||||
extension=ext.lower(),
|
||||
priority=i,
|
||||
)
|
||||
await ext_record.save(session)
|
||||
ext_record = await ext_record.save(session)
|
||||
|
||||
log.info(f'已创建 {len(_DEFAULT_FILE_APPS)} 个默认文件查看器应用')
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from sqlmodel import Field, Relationship, text, Index
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .download import Download
|
||||
@@ -28,13 +28,13 @@ class NodeType(StrEnum):
|
||||
class Aria2ConfigurationBase(SQLModelBase):
|
||||
"""Aria2配置基础模型"""
|
||||
|
||||
rpc_url: str | None = Field(default=None, max_length=255)
|
||||
rpc_url: Str255 | None = None
|
||||
"""RPC地址"""
|
||||
|
||||
rpc_secret: str | None = None
|
||||
"""RPC密钥"""
|
||||
|
||||
temp_path: str | None = Field(default=None, max_length=255)
|
||||
temp_path: Str255 | None = None
|
||||
"""临时下载路径"""
|
||||
|
||||
max_concurrent: int = Field(default=5, ge=1, le=50)
|
||||
@@ -70,19 +70,19 @@ class Node(SQLModelBase, TableBaseMixin):
|
||||
status: NodeStatus = Field(default=NodeStatus.ONLINE)
|
||||
"""节点状态"""
|
||||
|
||||
name: str = Field(max_length=255, unique=True)
|
||||
name: Str255 = Field(unique=True)
|
||||
"""节点名称"""
|
||||
|
||||
type: NodeType
|
||||
"""节点类型"""
|
||||
|
||||
server: str = Field(max_length=255)
|
||||
server: Str255
|
||||
"""节点地址(IP或域名)"""
|
||||
|
||||
slave_key: str | None = Field(default=None, max_length=255)
|
||||
slave_key: Str255 | None = None
|
||||
"""从机通讯密钥"""
|
||||
|
||||
master_key: str | None = Field(default=None, max_length=255)
|
||||
master_key: Str255 | None = None
|
||||
"""主机通讯密钥"""
|
||||
|
||||
aria2_enabled: bool = False
|
||||
|
||||
@@ -7,7 +7,9 @@ from enum import StrEnum
|
||||
from sqlalchemy import BigInteger
|
||||
from sqlmodel import Field, Relationship, CheckConstraint, Index, text
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255, Str256
|
||||
|
||||
from .policy import PolicyType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -16,49 +18,21 @@ if TYPE_CHECKING:
|
||||
from .share import Share
|
||||
from .physical_file import PhysicalFile
|
||||
from .uri import DiskNextURI
|
||||
from .object_metadata import ObjectMetadata
|
||||
|
||||
|
||||
class ObjectType(StrEnum):
|
||||
"""对象类型枚举"""
|
||||
FILE = "file"
|
||||
FOLDER = "folder"
|
||||
|
||||
class StorageType(StrEnum):
|
||||
"""存储类型枚举"""
|
||||
LOCAL = "local"
|
||||
QINIU = "qiniu"
|
||||
TENCENT = "tencent"
|
||||
ALIYUN = "aliyun"
|
||||
ONEDRIVE = "onedrive"
|
||||
GOOGLE_DRIVE = "google_drive"
|
||||
DROPBOX = "dropbox"
|
||||
WEBDAV = "webdav"
|
||||
REMOTE = "remote"
|
||||
|
||||
|
||||
class FileMetadataBase(SQLModelBase):
|
||||
"""文件元数据基础模型"""
|
||||
|
||||
width: int | None = Field(default=None)
|
||||
"""图片宽度(像素)"""
|
||||
|
||||
height: int | None = Field(default=None)
|
||||
"""图片高度(像素)"""
|
||||
|
||||
duration: float | None = Field(default=None)
|
||||
"""音视频时长(秒)"""
|
||||
|
||||
bitrate: int | None = Field(default=None)
|
||||
"""比特率(kbps)"""
|
||||
|
||||
mime_type: str | None = Field(default=None, max_length=127)
|
||||
"""MIME类型"""
|
||||
|
||||
checksum_md5: str | None = Field(default=None, max_length=32)
|
||||
"""MD5校验和"""
|
||||
|
||||
checksum_sha256: str | None = Field(default=None, max_length=64)
|
||||
"""SHA256校验和"""
|
||||
class FileCategory(StrEnum):
|
||||
"""文件类型分类枚举,用于按类别筛选文件"""
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
DOCUMENT = "document"
|
||||
|
||||
|
||||
# ==================== Base 模型 ====================
|
||||
@@ -75,9 +49,32 @@ class ObjectBase(SQLModelBase):
|
||||
size: int | None = None
|
||||
"""文件大小(字节),目录为 None"""
|
||||
|
||||
mime_type: str | None = Field(default=None, max_length=127)
|
||||
"""MIME类型(仅文件有效)"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class ObjectFileFinalize(SQLModelBase):
|
||||
"""文件上传完成后更新 Object 的 DTO"""
|
||||
|
||||
size: int
|
||||
"""文件大小(字节)"""
|
||||
|
||||
physical_file_id: UUID
|
||||
"""关联的物理文件UUID"""
|
||||
|
||||
|
||||
class ObjectMoveUpdate(SQLModelBase):
|
||||
"""移动/重命名 Object 的 DTO"""
|
||||
|
||||
parent_id: UUID
|
||||
"""新的父目录UUID"""
|
||||
|
||||
name: str
|
||||
"""新名称"""
|
||||
|
||||
|
||||
class DirectoryCreateRequest(SQLModelBase):
|
||||
"""创建目录请求 DTO"""
|
||||
|
||||
@@ -136,7 +133,7 @@ class PolicyResponse(SQLModelBase):
|
||||
name: str
|
||||
"""策略名称"""
|
||||
|
||||
type: StorageType
|
||||
type: PolicyType
|
||||
"""存储类型"""
|
||||
|
||||
max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
|
||||
@@ -164,22 +161,6 @@ class DirectoryResponse(SQLModelBase):
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class FileMetadata(FileMetadataBase, UUIDTableBaseMixin):
|
||||
"""文件元数据模型(与Object一对一关联)"""
|
||||
|
||||
object_id: UUID = Field(
|
||||
foreign_key="object.id",
|
||||
unique=True,
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""关联的对象UUID"""
|
||||
|
||||
# 反向关系
|
||||
object: "Object" = Relationship(back_populates="file_metadata")
|
||||
"""关联的对象"""
|
||||
|
||||
|
||||
class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
"""
|
||||
统一对象模型
|
||||
@@ -217,13 +198,13 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
|
||||
# ==================== 基础字段 ====================
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
name: Str255
|
||||
"""对象名称(文件名或目录名)"""
|
||||
|
||||
type: ObjectType
|
||||
"""对象类型:file 或 folder"""
|
||||
|
||||
password: str | None = Field(default=None, max_length=255)
|
||||
password: Str255 | None = None
|
||||
"""对象独立密码(仅当用户为对象单独设置密码时有效)"""
|
||||
|
||||
# ==================== 文件专属字段 ====================
|
||||
@@ -231,7 +212,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
|
||||
"""文件大小(字节),目录为 0"""
|
||||
|
||||
upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
|
||||
upload_session_id: Str255 | None = Field(default=None, unique=True, index=True)
|
||||
"""分块上传会话ID(仅文件有效)"""
|
||||
|
||||
physical_file_id: UUID | None = Field(
|
||||
@@ -334,11 +315,11 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
"""子对象(文件和子目录)"""
|
||||
|
||||
# 仅文件有效的关系
|
||||
file_metadata: FileMetadata | None = Relationship(
|
||||
metadata_entries: list["ObjectMetadata"] = Relationship(
|
||||
back_populates="object",
|
||||
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"},
|
||||
)
|
||||
"""文件元数据(仅文件有效)"""
|
||||
"""元数据键值对列表"""
|
||||
|
||||
source_links: list["SourceLink"] = Relationship(
|
||||
back_populates="object",
|
||||
@@ -496,6 +477,37 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
fetch_mode="all"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_by_category(
|
||||
cls,
|
||||
session: 'AsyncSession',
|
||||
user_id: UUID,
|
||||
extensions: list[str],
|
||||
table_view: 'TableViewRequest | None' = None,
|
||||
) -> 'ListResponse[Object]':
|
||||
"""
|
||||
按扩展名列表查询用户的所有文件(跨目录)
|
||||
|
||||
只查询未删除、未封禁的文件对象,使用 ILIKE 匹配文件名后缀。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户UUID
|
||||
:param extensions: 扩展名列表(不含点号)
|
||||
:param table_view: 分页排序参数
|
||||
:return: 分页文件列表
|
||||
"""
|
||||
from sqlalchemy import or_
|
||||
|
||||
ext_conditions = [cls.name.ilike(f"%.{ext}") for ext in extensions]
|
||||
condition = (
|
||||
(cls.owner_id == user_id) &
|
||||
(cls.type == ObjectType.FILE) &
|
||||
(cls.deleted_at == None) &
|
||||
(cls.is_banned == False) &
|
||||
or_(*ext_conditions)
|
||||
)
|
||||
return await cls.get_with_count(session, condition, table_view=table_view)
|
||||
|
||||
@classmethod
|
||||
async def resolve_uri(
|
||||
cls,
|
||||
@@ -573,7 +585,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
class UploadSessionBase(SQLModelBase):
|
||||
"""上传会话基础字段"""
|
||||
|
||||
file_name: str = Field(max_length=255)
|
||||
file_name: Str255
|
||||
"""原始文件名"""
|
||||
|
||||
file_size: int = Field(ge=0, sa_type=BigInteger)
|
||||
@@ -604,6 +616,12 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
|
||||
storage_path: str | None = Field(default=None, max_length=512)
|
||||
"""文件存储路径"""
|
||||
|
||||
s3_upload_id: Str256 | None = None
|
||||
"""S3 Multipart Upload ID(仅 S3 策略使用)"""
|
||||
|
||||
s3_part_etags: str | None = None
|
||||
"""S3 已上传分片的 ETag 列表,JSON 格式 [[1,"etag1"],[2,"etag2"]](仅 S3 策略使用)"""
|
||||
|
||||
expires_at: datetime
|
||||
"""会话过期时间"""
|
||||
|
||||
@@ -645,7 +663,7 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
|
||||
class CreateUploadSessionRequest(SQLModelBase):
|
||||
"""创建上传会话请求 DTO"""
|
||||
|
||||
file_name: str = Field(max_length=255)
|
||||
file_name: Str255
|
||||
"""文件名"""
|
||||
|
||||
file_size: int = Field(ge=0)
|
||||
@@ -702,7 +720,7 @@ class UploadChunkResponse(SQLModelBase):
|
||||
class CreateFileRequest(SQLModelBase):
|
||||
"""创建空白文件请求 DTO"""
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
name: Str255
|
||||
"""文件名"""
|
||||
|
||||
parent_id: UUID
|
||||
@@ -712,6 +730,16 @@ class CreateFileRequest(SQLModelBase):
|
||||
"""存储策略UUID,不指定则使用父目录的策略"""
|
||||
|
||||
|
||||
class ObjectSwitchPolicyRequest(SQLModelBase):
|
||||
"""切换对象存储策略请求"""
|
||||
|
||||
policy_id: UUID
|
||||
"""目标存储策略UUID"""
|
||||
|
||||
is_migrate_existing: bool = False
|
||||
"""(仅目录)是否迁移已有文件,默认 false 只影响新文件"""
|
||||
|
||||
|
||||
# ==================== 对象操作相关 DTO ====================
|
||||
|
||||
class ObjectCopyRequest(SQLModelBase):
|
||||
@@ -730,7 +758,7 @@ class ObjectRenameRequest(SQLModelBase):
|
||||
id: UUID
|
||||
"""对象UUID"""
|
||||
|
||||
new_name: str = Field(max_length=255)
|
||||
new_name: Str255
|
||||
"""新名称"""
|
||||
|
||||
|
||||
@@ -749,6 +777,9 @@ class ObjectPropertyResponse(SQLModelBase):
|
||||
size: int
|
||||
"""文件大小(字节)"""
|
||||
|
||||
mime_type: str | None = None
|
||||
"""MIME类型"""
|
||||
|
||||
created_at: datetime
|
||||
"""创建时间"""
|
||||
|
||||
@@ -762,22 +793,13 @@ class ObjectPropertyResponse(SQLModelBase):
|
||||
class ObjectPropertyDetailResponse(ObjectPropertyResponse):
|
||||
"""对象详细属性响应 DTO(继承基本属性)"""
|
||||
|
||||
# 元数据信息
|
||||
mime_type: str | None = None
|
||||
"""MIME类型"""
|
||||
|
||||
width: int | None = None
|
||||
"""图片宽度(像素)"""
|
||||
|
||||
height: int | None = None
|
||||
"""图片高度(像素)"""
|
||||
|
||||
duration: float | None = None
|
||||
"""音视频时长(秒)"""
|
||||
|
||||
# 校验和(从 PhysicalFile 读取)
|
||||
checksum_md5: str | None = None
|
||||
"""MD5校验和"""
|
||||
|
||||
checksum_sha256: str | None = None
|
||||
"""SHA256校验和"""
|
||||
|
||||
# 分享统计
|
||||
share_count: int = 0
|
||||
"""分享次数"""
|
||||
@@ -795,6 +817,10 @@ class ObjectPropertyDetailResponse(ObjectPropertyResponse):
|
||||
reference_count: int = 1
|
||||
"""物理文件引用计数(仅文件有效)"""
|
||||
|
||||
# 元数据(KV 格式)
|
||||
metadatas: dict[str, str] = {}
|
||||
"""所有元数据条目(键名 → 值)"""
|
||||
|
||||
|
||||
# ==================== 管理员文件管理 DTO ====================
|
||||
|
||||
|
||||
127
sqlmodels/object_metadata.py
Normal file
127
sqlmodels/object_metadata.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
对象元数据 KV 模型
|
||||
|
||||
以键值对形式存储文件的扩展元数据。键名使用命名空间前缀分类,
|
||||
如 exif:width, stream:duration, music:artist 等。
|
||||
|
||||
架构:
|
||||
ObjectMetadata (KV 表,与 Object 一对多关系)
|
||||
└── 每个 Object 可以有多条元数据记录
|
||||
└── (object_id, name) 组合唯一索引
|
||||
|
||||
命名空间:
|
||||
- exif: 图片 EXIF 信息(尺寸、相机参数、拍摄时间等)
|
||||
- stream: 音视频流信息(时长、比特率、视频尺寸、编解码等)
|
||||
- music: 音乐标签(标题、艺术家、专辑等)
|
||||
- geo: 地理位置(经纬度、地址)
|
||||
- apk: Android 安装包信息
|
||||
- custom: 用户自定义属性
|
||||
- sys: 系统内部元数据
|
||||
- thumb: 缩略图信息
|
||||
"""
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, UniqueConstraint, Index, Relationship
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .object import Object
|
||||
|
||||
|
||||
# ==================== 枚举 ====================
|
||||
|
||||
class MetadataNamespace(StrEnum):
|
||||
"""元数据命名空间枚举"""
|
||||
EXIF = "exif"
|
||||
"""图片 EXIF 信息(含尺寸、相机参数、拍摄时间等)"""
|
||||
MUSIC = "music"
|
||||
"""音乐标签(title/artist/album/genre 等)"""
|
||||
STREAM = "stream"
|
||||
"""音视频流信息(codec/duration/bitrate/resolution 等)"""
|
||||
GEO = "geo"
|
||||
"""地理位置(latitude/longitude/address)"""
|
||||
APK = "apk"
|
||||
"""Android 安装包信息(package_name/version 等)"""
|
||||
THUMB = "thumb"
|
||||
"""缩略图信息(内部使用)"""
|
||||
SYS = "sys"
|
||||
"""系统元数据(内部使用)"""
|
||||
CUSTOM = "custom"
|
||||
"""用户自定义属性"""
|
||||
|
||||
|
||||
# 对外不可见的命名空间(API 不返回给普通用户)
|
||||
INTERNAL_NAMESPACES: set[str] = {MetadataNamespace.SYS, MetadataNamespace.THUMB}
|
||||
|
||||
# 用户可写的命名空间
|
||||
USER_WRITABLE_NAMESPACES: set[str] = {MetadataNamespace.CUSTOM}
|
||||
|
||||
|
||||
# ==================== Base 模型 ====================
|
||||
|
||||
class ObjectMetadataBase(SQLModelBase):
|
||||
"""对象元数据 KV 基础模型"""
|
||||
|
||||
name: Str255
|
||||
"""元数据键名,格式:namespace:key(如 exif:width, stream:duration)"""
|
||||
|
||||
value: str
|
||||
"""元数据值(统一为字符串存储)"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class ObjectMetadata(ObjectMetadataBase, UUIDTableBaseMixin):
|
||||
"""
|
||||
对象元数据 KV 模型
|
||||
|
||||
以键值对形式存储文件的扩展元数据。键名使用命名空间前缀分类,
|
||||
每个对象的每个键名唯一(通过唯一索引保证)。
|
||||
"""
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("object_id", "name", name="uq_object_metadata_object_name"),
|
||||
Index("ix_object_metadata_object_id", "object_id"),
|
||||
)
|
||||
|
||||
object_id: UUID = Field(
|
||||
foreign_key="object.id",
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
"""关联的对象UUID"""
|
||||
|
||||
is_public: bool = False
|
||||
"""是否对分享页面公开"""
|
||||
|
||||
# 关系
|
||||
object: "Object" = Relationship(back_populates="metadata_entries")
|
||||
"""关联的对象"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class MetadataResponse(SQLModelBase):
|
||||
"""元数据查询响应 DTO"""
|
||||
|
||||
metadatas: dict[str, str]
|
||||
"""元数据字典(键名 → 值)"""
|
||||
|
||||
|
||||
class MetadataPatchItem(SQLModelBase):
|
||||
"""单条元数据补丁 DTO"""
|
||||
|
||||
key: Str255
|
||||
"""元数据键名"""
|
||||
|
||||
value: str | None = None
|
||||
"""值,None 表示删除此条目"""
|
||||
|
||||
|
||||
class MetadataPatchRequest(SQLModelBase):
|
||||
"""元数据批量更新请求 DTO"""
|
||||
|
||||
patches: list[MetadataPatchItem]
|
||||
"""补丁列表"""
|
||||
@@ -1,58 +1,122 @@
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Numeric
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .product import Product
|
||||
from .user import User
|
||||
|
||||
|
||||
class OrderStatus(StrEnum):
|
||||
"""订单状态枚举"""
|
||||
|
||||
PENDING = "pending"
|
||||
"""待支付"""
|
||||
|
||||
COMPLETED = "completed"
|
||||
"""已完成"""
|
||||
|
||||
CANCELLED = "cancelled"
|
||||
"""已取消"""
|
||||
|
||||
|
||||
class OrderType(StrEnum):
|
||||
"""订单类型枚举"""
|
||||
# [TODO] 补充具体订单类型
|
||||
pass
|
||||
|
||||
STORAGE_PACK = "storage_pack"
|
||||
"""容量包"""
|
||||
|
||||
GROUP_TIME = "group_time"
|
||||
"""用户组时长"""
|
||||
|
||||
SCORE = "score"
|
||||
"""积分充值"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class CreateOrderRequest(SQLModelBase):
|
||||
"""创建订单请求 DTO"""
|
||||
|
||||
product_id: UUID
|
||||
"""商品UUID"""
|
||||
|
||||
num: int = Field(default=1, ge=1)
|
||||
"""购买数量"""
|
||||
|
||||
method: str
|
||||
"""支付方式"""
|
||||
|
||||
|
||||
class OrderResponse(SQLModelBase):
|
||||
"""订单响应 DTO"""
|
||||
|
||||
id: int
|
||||
"""订单ID"""
|
||||
|
||||
order_no: str
|
||||
"""订单号"""
|
||||
|
||||
type: OrderType
|
||||
"""订单类型"""
|
||||
|
||||
method: str | None = None
|
||||
"""支付方式"""
|
||||
|
||||
product_id: UUID | None = None
|
||||
"""商品UUID"""
|
||||
|
||||
num: int
|
||||
"""购买数量"""
|
||||
|
||||
name: str
|
||||
"""商品名称"""
|
||||
|
||||
price: float
|
||||
"""订单价格(元)"""
|
||||
|
||||
status: OrderStatus
|
||||
"""订单状态"""
|
||||
|
||||
user_id: UUID
|
||||
"""所属用户UUID"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class Order(SQLModelBase, TableBaseMixin):
|
||||
"""订单模型"""
|
||||
|
||||
order_no: str = Field(max_length=255, unique=True, index=True)
|
||||
order_no: Str255 = Field(unique=True, index=True)
|
||||
"""订单号,唯一"""
|
||||
|
||||
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""订单类型 [TODO] 待定义枚举"""
|
||||
type: OrderType
|
||||
"""订单类型"""
|
||||
|
||||
method: str | None = Field(default=None, max_length=255)
|
||||
method: Str255 | None = None
|
||||
"""支付方式"""
|
||||
|
||||
product_id: int | None = Field(default=None)
|
||||
"""商品ID"""
|
||||
product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
|
||||
"""关联商品UUID"""
|
||||
|
||||
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
|
||||
"""购买数量"""
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
name: Str255
|
||||
"""商品名称"""
|
||||
|
||||
price: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""订单价格(分)"""
|
||||
price: Decimal = Field(sa_type=Numeric(12, 2), default=Decimal("0.00"))
|
||||
"""订单价格(元)"""
|
||||
|
||||
status: OrderStatus = Field(default=OrderStatus.PENDING)
|
||||
"""订单状态"""
|
||||
|
||||
|
||||
# 外键
|
||||
user_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
@@ -60,6 +124,22 @@ class Order(SQLModelBase, TableBaseMixin):
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""所属用户UUID"""
|
||||
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="orders")
|
||||
user: "User" = Relationship(back_populates="orders")
|
||||
product: "Product" = Relationship(back_populates="orders")
|
||||
|
||||
def to_response(self) -> OrderResponse:
|
||||
"""转换为响应 DTO"""
|
||||
return OrderResponse(
|
||||
id=self.id,
|
||||
order_no=self.order_no,
|
||||
type=self.type,
|
||||
method=self.method,
|
||||
product_id=self.product_id,
|
||||
num=self.num,
|
||||
name=self.name,
|
||||
price=float(self.price),
|
||||
status=self.status,
|
||||
user_id=self.user_id,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ from uuid import UUID
|
||||
from sqlalchemy import BigInteger
|
||||
from sqlmodel import Field, Relationship, Index
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str32, Str64
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .object import Object
|
||||
@@ -31,9 +31,12 @@ class PhysicalFileBase(SQLModelBase):
|
||||
size: int = Field(default=0, sa_type=BigInteger)
|
||||
"""文件大小(字节)"""
|
||||
|
||||
checksum_md5: str | None = Field(default=None, max_length=32)
|
||||
checksum_md5: Str32 | None = None
|
||||
"""MD5校验和(用于文件去重和完整性校验)"""
|
||||
|
||||
checksum_sha256: Str64 | None = None
|
||||
"""SHA256校验和"""
|
||||
|
||||
|
||||
class PhysicalFile(PhysicalFileBase, UUIDTableBaseMixin):
|
||||
"""
|
||||
|
||||
@@ -4,7 +4,7 @@ from uuid import UUID
|
||||
from enum import StrEnum
|
||||
from sqlmodel import Field, Relationship, text
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str64, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .object import Object
|
||||
@@ -37,22 +37,22 @@ class PolicyType(StrEnum):
|
||||
class PolicyBase(SQLModelBase):
|
||||
"""存储策略基础字段,供 DTO 和数据库模型共享"""
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
name: Str255
|
||||
"""策略名称"""
|
||||
|
||||
type: PolicyType
|
||||
"""存储策略类型"""
|
||||
|
||||
server: str | None = Field(default=None, max_length=255)
|
||||
server: Str255 | None = None
|
||||
"""服务器地址(本地策略为绝对路径)"""
|
||||
|
||||
bucket_name: str | None = Field(default=None, max_length=255)
|
||||
bucket_name: Str255 | None = None
|
||||
"""存储桶名称"""
|
||||
|
||||
is_private: bool = True
|
||||
"""是否为私有空间"""
|
||||
|
||||
base_url: str | None = Field(default=None, max_length=255)
|
||||
base_url: Str255 | None = None
|
||||
"""访问文件的基础URL"""
|
||||
|
||||
access_key: str | None = None
|
||||
@@ -67,10 +67,10 @@ class PolicyBase(SQLModelBase):
|
||||
auto_rename: bool = False
|
||||
"""是否自动重命名"""
|
||||
|
||||
dir_name_rule: str | None = Field(default=None, max_length=255)
|
||||
dir_name_rule: Str255 | None = None
|
||||
"""目录命名规则"""
|
||||
|
||||
file_name_rule: str | None = Field(default=None, max_length=255)
|
||||
file_name_rule: Str255 | None = None
|
||||
"""文件命名规则"""
|
||||
|
||||
is_origin_link_enable: bool = False
|
||||
@@ -102,6 +102,94 @@ class PolicySummary(SQLModelBase):
|
||||
"""是否私有"""
|
||||
|
||||
|
||||
class PolicyCreateRequest(PolicyBase):
|
||||
"""创建存储策略请求 DTO,包含 PolicyOptions 扁平字段"""
|
||||
|
||||
# PolicyOptions 字段(平铺到请求体中,与 GroupCreateRequest 模式一致)
|
||||
token: str | None = None
|
||||
"""访问令牌"""
|
||||
|
||||
file_type: str | None = None
|
||||
"""允许的文件类型"""
|
||||
|
||||
mimetype: str | None = Field(default=None, max_length=127)
|
||||
"""MIME类型"""
|
||||
|
||||
od_redirect: Str255 | None = None
|
||||
"""OneDrive重定向地址"""
|
||||
|
||||
chunk_size: int = Field(default=52428800, ge=1)
|
||||
"""分片上传大小(字节),默认50MB"""
|
||||
|
||||
s3_path_style: bool = False
|
||||
"""是否使用S3路径风格"""
|
||||
|
||||
s3_region: Str64 = 'us-east-1'
|
||||
"""S3 区域(如 us-east-1、ap-southeast-1),仅 S3 策略使用"""
|
||||
|
||||
|
||||
class PolicyUpdateRequest(SQLModelBase):
|
||||
"""更新存储策略请求 DTO(所有字段可选)"""
|
||||
|
||||
name: Str255 | None = None
|
||||
"""策略名称"""
|
||||
|
||||
server: Str255 | None = None
|
||||
"""服务器地址"""
|
||||
|
||||
bucket_name: Str255 | None = None
|
||||
"""存储桶名称"""
|
||||
|
||||
is_private: bool | None = None
|
||||
"""是否为私有空间"""
|
||||
|
||||
base_url: Str255 | None = None
|
||||
"""访问文件的基础URL"""
|
||||
|
||||
access_key: str | None = None
|
||||
"""Access Key"""
|
||||
|
||||
secret_key: str | None = None
|
||||
"""Secret Key"""
|
||||
|
||||
max_size: int | None = Field(default=None, ge=0)
|
||||
"""允许上传的最大文件尺寸(字节)"""
|
||||
|
||||
auto_rename: bool | None = None
|
||||
"""是否自动重命名"""
|
||||
|
||||
dir_name_rule: Str255 | None = None
|
||||
"""目录命名规则"""
|
||||
|
||||
file_name_rule: Str255 | None = None
|
||||
"""文件命名规则"""
|
||||
|
||||
is_origin_link_enable: bool | None = None
|
||||
"""是否开启源链接访问"""
|
||||
|
||||
# PolicyOptions 字段
|
||||
token: str | None = None
|
||||
"""访问令牌"""
|
||||
|
||||
file_type: str | None = None
|
||||
"""允许的文件类型"""
|
||||
|
||||
mimetype: str | None = Field(default=None, max_length=127)
|
||||
"""MIME类型"""
|
||||
|
||||
od_redirect: Str255 | None = None
|
||||
"""OneDrive重定向地址"""
|
||||
|
||||
chunk_size: int | None = Field(default=None, ge=1)
|
||||
"""分片上传大小(字节)"""
|
||||
|
||||
s3_path_style: bool | None = None
|
||||
"""是否使用S3路径风格"""
|
||||
|
||||
s3_region: Str64 | None = None
|
||||
"""S3 区域"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
|
||||
@@ -117,7 +205,7 @@ class PolicyOptionsBase(SQLModelBase):
|
||||
mimetype: str | None = Field(default=None, max_length=127)
|
||||
"""MIME类型"""
|
||||
|
||||
od_redirect: str | None = Field(default=None, max_length=255)
|
||||
od_redirect: Str255 | None = None
|
||||
"""OneDrive重定向地址"""
|
||||
|
||||
chunk_size: int = Field(default=52428800, sa_column_kwargs={"server_default": "52428800"})
|
||||
@@ -126,6 +214,9 @@ class PolicyOptionsBase(SQLModelBase):
|
||||
s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||
"""是否使用S3路径风格"""
|
||||
|
||||
s3_region: Str64 = Field(default='us-east-1', sa_column_kwargs={"server_default": "'us-east-1'"})
|
||||
"""S3 区域(如 us-east-1、ap-southeast-1),仅 S3 策略使用"""
|
||||
|
||||
|
||||
class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin):
|
||||
"""存储策略选项模型(与Policy一对一关联)"""
|
||||
@@ -146,7 +237,7 @@ class Policy(PolicyBase, UUIDTableBaseMixin):
|
||||
"""存储策略模型"""
|
||||
|
||||
# 覆盖基类字段以添加数据库专有配置
|
||||
name: str = Field(max_length=255, unique=True)
|
||||
name: Str255 = Field(unique=True)
|
||||
"""策略名称"""
|
||||
|
||||
is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
||||
|
||||
206
sqlmodels/product.py
Normal file
206
sqlmodels/product.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Numeric, BigInteger
|
||||
from sqlmodel import Field, Relationship, text
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .order import Order
|
||||
from .redeem import Redeem
|
||||
|
||||
|
||||
class ProductType(StrEnum):
|
||||
"""商品类型枚举"""
|
||||
|
||||
STORAGE_PACK = "storage_pack"
|
||||
"""容量包"""
|
||||
|
||||
GROUP_TIME = "group_time"
|
||||
"""用户组时长"""
|
||||
|
||||
SCORE = "score"
|
||||
"""积分充值"""
|
||||
|
||||
|
||||
class PaymentMethod(StrEnum):
|
||||
"""支付方式枚举"""
|
||||
|
||||
ALIPAY = "alipay"
|
||||
"""支付宝"""
|
||||
|
||||
WECHAT = "wechat"
|
||||
"""微信支付"""
|
||||
|
||||
STRIPE = "stripe"
|
||||
"""Stripe"""
|
||||
|
||||
EASYPAY = "easypay"
|
||||
"""易支付"""
|
||||
|
||||
CUSTOM = "custom"
|
||||
"""自定义支付"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class ProductBase(SQLModelBase):
|
||||
"""商品基础字段"""
|
||||
|
||||
name: str
|
||||
"""商品名称"""
|
||||
|
||||
type: ProductType
|
||||
"""商品类型"""
|
||||
|
||||
description: str | None = None
|
||||
"""商品描述"""
|
||||
|
||||
|
||||
class ProductCreateRequest(ProductBase):
|
||||
"""创建商品请求 DTO"""
|
||||
|
||||
name: Str255
|
||||
"""商品名称"""
|
||||
|
||||
price: Decimal = Field(ge=0, decimal_places=2)
|
||||
"""商品价格(元)"""
|
||||
|
||||
is_active: bool = True
|
||||
"""是否上架"""
|
||||
|
||||
sort_order: int = Field(default=0, ge=0)
|
||||
"""排序权重(越大越靠前)"""
|
||||
|
||||
# storage_pack 专用
|
||||
size: int | None = Field(default=None, ge=0)
|
||||
"""容量大小(字节),type=storage_pack 时必填"""
|
||||
|
||||
duration_days: int | None = Field(default=None, ge=1)
|
||||
"""有效天数,type=storage_pack/group_time 时必填"""
|
||||
|
||||
# group_time 专用
|
||||
group_id: UUID | None = None
|
||||
"""目标用户组UUID,type=group_time 时必填"""
|
||||
|
||||
# score 专用
|
||||
score_amount: int | None = Field(default=None, ge=1)
|
||||
"""积分数量,type=score 时必填"""
|
||||
|
||||
|
||||
class ProductUpdateRequest(SQLModelBase):
|
||||
"""更新商品请求 DTO(所有字段可选)"""
|
||||
|
||||
name: Str255 | None = None
|
||||
"""商品名称"""
|
||||
|
||||
description: str | None = None
|
||||
"""商品描述"""
|
||||
|
||||
price: Decimal | None = Field(default=None, ge=0, decimal_places=2)
|
||||
"""商品价格(元)"""
|
||||
|
||||
is_active: bool | None = None
|
||||
"""是否上架"""
|
||||
|
||||
sort_order: int | None = Field(default=None, ge=0)
|
||||
"""排序权重"""
|
||||
|
||||
size: int | None = Field(default=None, ge=0)
|
||||
"""容量大小(字节)"""
|
||||
|
||||
duration_days: int | None = Field(default=None, ge=1)
|
||||
"""有效天数"""
|
||||
|
||||
group_id: UUID | None = None
|
||||
"""目标用户组UUID"""
|
||||
|
||||
score_amount: int | None = Field(default=None, ge=1)
|
||||
"""积分数量"""
|
||||
|
||||
|
||||
class ProductResponse(ProductBase):
|
||||
"""商品响应 DTO"""
|
||||
|
||||
id: UUID
|
||||
"""商品UUID"""
|
||||
|
||||
price: float
|
||||
"""商品价格(元)"""
|
||||
|
||||
is_active: bool
|
||||
"""是否上架"""
|
||||
|
||||
sort_order: int
|
||||
"""排序权重"""
|
||||
|
||||
size: int | None = None
|
||||
"""容量大小(字节)"""
|
||||
|
||||
duration_days: int | None = None
|
||||
"""有效天数"""
|
||||
|
||||
group_id: UUID | None = None
|
||||
"""目标用户组UUID"""
|
||||
|
||||
score_amount: int | None = None
|
||||
"""积分数量"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class Product(ProductBase, UUIDTableBaseMixin):
|
||||
"""商品模型"""
|
||||
|
||||
name: Str255
|
||||
"""商品名称"""
|
||||
|
||||
price: Decimal = Field(sa_type=Numeric(12, 2), default=Decimal("0.00"))
|
||||
"""商品价格(元)"""
|
||||
|
||||
is_active: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
||||
"""是否上架"""
|
||||
|
||||
sort_order: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""排序权重(越大越靠前)"""
|
||||
|
||||
# storage_pack 专用
|
||||
size: int | None = Field(default=None, sa_type=BigInteger)
|
||||
"""容量大小(字节),type=storage_pack 时必填"""
|
||||
|
||||
duration_days: int | None = None
|
||||
"""有效天数,type=storage_pack/group_time 时必填"""
|
||||
|
||||
# group_time 专用
|
||||
group_id: UUID | None = Field(default=None, foreign_key="group.id", ondelete="SET NULL")
|
||||
"""目标用户组UUID,type=group_time 时必填"""
|
||||
|
||||
# score 专用
|
||||
score_amount: int | None = None
|
||||
"""积分数量,type=score 时必填"""
|
||||
|
||||
# 关系
|
||||
orders: list["Order"] = Relationship(back_populates="product")
|
||||
"""关联的订单列表"""
|
||||
|
||||
redeems: list["Redeem"] = Relationship(back_populates="product")
|
||||
"""关联的兑换码列表"""
|
||||
|
||||
def to_response(self) -> ProductResponse:
|
||||
"""转换为响应 DTO"""
|
||||
return ProductResponse(
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
type=self.type,
|
||||
description=self.description,
|
||||
price=float(self.price),
|
||||
is_active=self.is_active,
|
||||
sort_order=self.sort_order,
|
||||
size=self.size,
|
||||
duration_days=self.duration_days,
|
||||
group_id=self.group_id,
|
||||
score_amount=self.score_amount,
|
||||
)
|
||||
@@ -1,22 +1,141 @@
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, text
|
||||
from sqlmodel import Field, Relationship, text
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .product import Product
|
||||
from .user import User
|
||||
|
||||
|
||||
class RedeemType(StrEnum):
|
||||
"""兑换码类型枚举"""
|
||||
# [TODO] 补充具体兑换码类型
|
||||
pass
|
||||
|
||||
STORAGE_PACK = "storage_pack"
|
||||
"""容量包"""
|
||||
|
||||
GROUP_TIME = "group_time"
|
||||
"""用户组时长"""
|
||||
|
||||
SCORE = "score"
|
||||
"""积分充值"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class RedeemCreateRequest(SQLModelBase):
|
||||
"""批量生成兑换码请求 DTO"""
|
||||
|
||||
product_id: UUID
|
||||
"""关联商品UUID"""
|
||||
|
||||
count: int = Field(default=1, ge=1, le=100)
|
||||
"""生成数量"""
|
||||
|
||||
|
||||
class RedeemUseRequest(SQLModelBase):
|
||||
"""使用兑换码请求 DTO"""
|
||||
|
||||
code: str
|
||||
"""兑换码"""
|
||||
|
||||
|
||||
class RedeemInfoResponse(SQLModelBase):
|
||||
"""兑换码信息响应 DTO(用户侧)"""
|
||||
|
||||
type: RedeemType
|
||||
"""兑换码类型"""
|
||||
|
||||
product_name: str | None = None
|
||||
"""关联商品名称"""
|
||||
|
||||
num: int
|
||||
"""可兑换数量"""
|
||||
|
||||
is_used: bool
|
||||
"""是否已使用"""
|
||||
|
||||
|
||||
class RedeemAdminResponse(SQLModelBase):
|
||||
"""兑换码管理响应 DTO(管理侧)"""
|
||||
|
||||
id: int
|
||||
"""兑换码ID"""
|
||||
|
||||
type: RedeemType
|
||||
"""兑换码类型"""
|
||||
|
||||
product_id: UUID | None = None
|
||||
"""关联商品UUID"""
|
||||
|
||||
num: int
|
||||
"""可兑换数量"""
|
||||
|
||||
code: str
|
||||
"""兑换码"""
|
||||
|
||||
is_used: bool
|
||||
"""是否已使用"""
|
||||
|
||||
used_at: datetime | None = None
|
||||
"""使用时间"""
|
||||
|
||||
used_by: UUID | None = None
|
||||
"""使用者UUID"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class Redeem(SQLModelBase, TableBaseMixin):
|
||||
"""兑换码模型"""
|
||||
|
||||
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""兑换码类型 [TODO] 待定义枚举"""
|
||||
product_id: int | None = Field(default=None, description="关联的商品/权益ID")
|
||||
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="可兑换数量/时长等")
|
||||
code: str = Field(unique=True, index=True, description="兑换码,唯一")
|
||||
used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否已使用")
|
||||
type: RedeemType
|
||||
"""兑换码类型"""
|
||||
|
||||
product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
|
||||
"""关联商品UUID"""
|
||||
|
||||
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
|
||||
"""可兑换数量/时长等"""
|
||||
|
||||
code: str = Field(unique=True, index=True)
|
||||
"""兑换码,唯一"""
|
||||
|
||||
is_used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||
"""是否已使用"""
|
||||
|
||||
used_at: datetime | None = None
|
||||
"""使用时间"""
|
||||
|
||||
used_by: UUID | None = Field(default=None, foreign_key="user.id", ondelete="SET NULL")
|
||||
"""使用者UUID"""
|
||||
|
||||
# 关系
|
||||
product: "Product" = Relationship(back_populates="redeems")
|
||||
user: "User" = Relationship(back_populates="redeems")
|
||||
|
||||
def to_admin_response(self) -> RedeemAdminResponse:
|
||||
"""转换为管理侧响应 DTO"""
|
||||
return RedeemAdminResponse(
|
||||
id=self.id,
|
||||
type=self.type,
|
||||
product_id=self.product_id,
|
||||
num=self.num,
|
||||
code=self.code,
|
||||
is_used=self.is_used,
|
||||
used_at=self.used_at,
|
||||
used_by=self.used_by,
|
||||
)
|
||||
|
||||
def to_info_response(self, product_name: str | None = None) -> RedeemInfoResponse:
|
||||
"""转换为用户侧响应 DTO"""
|
||||
return RedeemInfoResponse(
|
||||
type=self.type,
|
||||
product_name=product_name,
|
||||
num=self.num,
|
||||
is_used=self.is_used,
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .share import Share
|
||||
@@ -21,7 +21,7 @@ class Report(SQLModelBase, TableBaseMixin):
|
||||
|
||||
reason: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""举报原因 [TODO] 待定义枚举"""
|
||||
description: str | None = Field(default=None, max_length=255, description="补充描述")
|
||||
description: Str255 | None = Field(default=None, description="补充描述")
|
||||
|
||||
# 外键
|
||||
share_id: UUID = Field(
|
||||
|
||||
@@ -76,6 +76,9 @@ class SiteConfigResponse(SQLModelBase):
|
||||
email_binding_required: bool = True
|
||||
"""是否强制绑定邮箱"""
|
||||
|
||||
avatar_max_size: int = 2097152
|
||||
"""头像文件最大字节数(默认 2MB)"""
|
||||
|
||||
footer_code: str | None = None
|
||||
"""自定义页脚代码"""
|
||||
|
||||
@@ -160,6 +163,7 @@ class SettingsType(StrEnum):
|
||||
VERSION = "version"
|
||||
VIEW = "view"
|
||||
WOPI = "wopi"
|
||||
FILE_CATEGORY = "file_category"
|
||||
|
||||
# 数据库模型
|
||||
class Setting(SettingItem, TableBaseMixin):
|
||||
|
||||
@@ -5,7 +5,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str64, Str255
|
||||
|
||||
from .model_base import ResponseBase
|
||||
from .object import ObjectType
|
||||
@@ -52,10 +52,10 @@ class Share(SQLModelBase, UUIDTableBaseMixin):
|
||||
Index("ix_share_object", "object_id"),
|
||||
)
|
||||
|
||||
code: str = Field(max_length=64, nullable=False, index=True)
|
||||
code: Str64 = Field(nullable=False, index=True)
|
||||
"""分享码"""
|
||||
|
||||
password: str | None = Field(default=None, max_length=255)
|
||||
password: Str255 | None = None
|
||||
"""分享密码(加密后)"""
|
||||
|
||||
object_id: UUID = Field(
|
||||
@@ -80,7 +80,7 @@ class Share(SQLModelBase, UUIDTableBaseMixin):
|
||||
preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
||||
"""是否允许预览"""
|
||||
|
||||
source_name: str | None = Field(default=None, max_length=255)
|
||||
source_name: Str255 | None = None
|
||||
"""源名称(冗余字段,便于展示)"""
|
||||
|
||||
score: int = Field(default=0, ge=0)
|
||||
|
||||
@@ -4,7 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, Index
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .object import Object
|
||||
@@ -17,7 +17,7 @@ class SourceLink(SQLModelBase, TableBaseMixin):
|
||||
Index("ix_sourcelink_object_name", "object_id", "name"),
|
||||
)
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
name: Str255
|
||||
"""链接名称"""
|
||||
|
||||
downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
|
||||
@@ -1,23 +1,60 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, Column, func, DateTime
|
||||
from sqlalchemy import BigInteger
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class StoragePackResponse(SQLModelBase):
|
||||
"""容量包响应 DTO"""
|
||||
|
||||
id: int
|
||||
"""容量包ID"""
|
||||
|
||||
name: str
|
||||
"""容量包名称"""
|
||||
|
||||
size: int
|
||||
"""容量大小(字节)"""
|
||||
|
||||
active_time: datetime | None = None
|
||||
"""激活时间"""
|
||||
|
||||
expired_time: datetime | None = None
|
||||
"""过期时间"""
|
||||
|
||||
product_id: UUID | None = None
|
||||
"""来源商品UUID"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class StoragePack(SQLModelBase, TableBaseMixin):
|
||||
"""容量包模型"""
|
||||
|
||||
name: str = Field(max_length=255, description="容量包名称")
|
||||
active_time: datetime | None = Field(default=None, description="激活时间")
|
||||
expired_time: datetime | None = Field(default=None, index=True, description="过期时间")
|
||||
size: int = Field(description="容量包大小(字节)")
|
||||
|
||||
name: Str255
|
||||
"""容量包名称"""
|
||||
|
||||
active_time: datetime | None = None
|
||||
"""激活时间"""
|
||||
|
||||
expired_time: datetime | None = Field(default=None, index=True)
|
||||
"""过期时间"""
|
||||
|
||||
size: int = Field(sa_type=BigInteger)
|
||||
"""容量包大小(字节)"""
|
||||
|
||||
product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
|
||||
"""来源商品UUID"""
|
||||
|
||||
# 外键
|
||||
user_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
@@ -25,6 +62,17 @@ class StoragePack(SQLModelBase, TableBaseMixin):
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""所属用户UUID"""
|
||||
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="storage_packs")
|
||||
user: "User" = Relationship(back_populates="storage_packs")
|
||||
|
||||
def to_response(self) -> StoragePackResponse:
|
||||
"""转换为响应 DTO"""
|
||||
return StoragePackResponse(
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
size=self.size,
|
||||
active_time=self.active_time,
|
||||
expired_time=self.expired_time,
|
||||
product_id=self.product_id,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from datetime import datetime
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -24,13 +24,13 @@ class Tag(SQLModelBase, TableBaseMixin):
|
||||
|
||||
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_tag_name_user"),)
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
name: Str255
|
||||
"""标签名称"""
|
||||
|
||||
icon: str | None = Field(default=None, max_length=255)
|
||||
icon: Str255 | None = None
|
||||
"""标签图标"""
|
||||
|
||||
color: str | None = Field(default=None, max_length=255)
|
||||
color: Str255 | None = None
|
||||
"""标签颜色"""
|
||||
|
||||
type: TagType = Field(default=TagType.MANUAL)
|
||||
|
||||
@@ -26,8 +26,8 @@ class TaskStatus(StrEnum):
|
||||
|
||||
class TaskType(StrEnum):
|
||||
"""任务类型枚举"""
|
||||
# [TODO] 补充具体任务类型
|
||||
pass
|
||||
POLICY_MIGRATE = "policy_migrate"
|
||||
"""存储策略迁移"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
@@ -39,7 +39,7 @@ class TaskSummaryBase(SQLModelBase):
|
||||
id: int
|
||||
"""任务ID"""
|
||||
|
||||
type: int
|
||||
type: TaskType
|
||||
"""任务类型"""
|
||||
|
||||
status: TaskStatus
|
||||
@@ -91,7 +91,14 @@ class TaskPropsBase(SQLModelBase):
|
||||
file_ids: str | None = None
|
||||
"""文件ID列表(逗号分隔)"""
|
||||
|
||||
# [TODO] 根据业务需求补充更多字段
|
||||
source_policy_id: UUID | None = None
|
||||
"""源存储策略UUID"""
|
||||
|
||||
dest_policy_id: UUID | None = None
|
||||
"""目标存储策略UUID"""
|
||||
|
||||
object_id: UUID | None = None
|
||||
"""关联的对象UUID"""
|
||||
|
||||
|
||||
class TaskProps(TaskPropsBase, TableBaseMixin):
|
||||
@@ -99,7 +106,7 @@ class TaskProps(TaskPropsBase, TableBaseMixin):
|
||||
|
||||
task_id: int = Field(
|
||||
foreign_key="task.id",
|
||||
primary_key=True,
|
||||
unique=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""关联的任务ID"""
|
||||
@@ -121,8 +128,8 @@ class Task(SQLModelBase, TableBaseMixin):
|
||||
status: TaskStatus = Field(default=TaskStatus.QUEUED)
|
||||
"""任务状态"""
|
||||
|
||||
type: int = Field(default=0)
|
||||
"""任务类型 [TODO] 待定义枚举"""
|
||||
type: TaskType
|
||||
"""任务类型"""
|
||||
|
||||
progress: int = Field(default=0, ge=0, le=100)
|
||||
"""任务进度(0-100)"""
|
||||
|
||||
@@ -3,7 +3,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100
|
||||
|
||||
from .color import ChromaticColor, NeutralColor, ThemeColorsBase
|
||||
|
||||
@@ -11,7 +11,7 @@ from .color import ChromaticColor, NeutralColor, ThemeColorsBase
|
||||
class ThemePresetBase(SQLModelBase):
|
||||
"""主题预设基础字段"""
|
||||
|
||||
name: str = Field(max_length=100)
|
||||
name: Str100
|
||||
"""预设名称"""
|
||||
|
||||
is_default: bool = False
|
||||
@@ -42,7 +42,7 @@ class ThemePresetBase(SQLModelBase):
|
||||
class ThemePreset(ThemePresetBase, UUIDTableBaseMixin):
|
||||
"""主题预设表"""
|
||||
|
||||
name: str = Field(max_length=100, unique=True)
|
||||
name: Str100 = Field(unique=True)
|
||||
"""预设名称(唯一约束)"""
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ class ThemePreset(ThemePresetBase, UUIDTableBaseMixin):
|
||||
class ThemePresetCreateRequest(SQLModelBase):
|
||||
"""创建主题预设请求 DTO"""
|
||||
|
||||
name: str = Field(max_length=100)
|
||||
name: Str100
|
||||
"""预设名称"""
|
||||
|
||||
colors: ThemeColorsBase
|
||||
@@ -61,7 +61,7 @@ class ThemePresetCreateRequest(SQLModelBase):
|
||||
class ThemePresetUpdateRequest(SQLModelBase):
|
||||
"""更新主题预设请求 DTO"""
|
||||
|
||||
name: str | None = Field(default=None, max_length=100)
|
||||
name: Str100 | None = None
|
||||
"""预设名称(可选)"""
|
||||
|
||||
colors: ThemeColorsBase | None = None
|
||||
|
||||
@@ -4,12 +4,12 @@ from typing import Literal, TYPE_CHECKING, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import BinaryExpression, ClauseElement, and_
|
||||
from sqlalchemy import BigInteger, BinaryExpression, ClauseElement, and_
|
||||
from sqlmodel import Field, Relationship
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableViewRequest, ListResponse
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableViewRequest, ListResponse, Str255
|
||||
|
||||
from .auth_identity import AuthProviderType
|
||||
from .color import ChromaticColor, NeutralColor, ThemeColorsBase
|
||||
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
|
||||
from .download import Download
|
||||
from .object import Object
|
||||
from .order import Order
|
||||
from .redeem import Redeem
|
||||
from .share import Share
|
||||
from .storage_pack import StoragePack
|
||||
from .tag import Tag
|
||||
@@ -473,10 +474,10 @@ class User(UserBase, UUIDTableBaseMixin):
|
||||
status: UserStatus = UserStatus.ACTIVE
|
||||
"""用户状态"""
|
||||
|
||||
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
|
||||
storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"}, ge=0)
|
||||
"""已用存储空间(字节)"""
|
||||
|
||||
avatar: str = Field(default="default", max_length=255)
|
||||
avatar: Str255 = Field(default="default")
|
||||
"""头像地址"""
|
||||
|
||||
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
|
||||
@@ -570,6 +571,14 @@ class User(UserBase, UUIDTableBaseMixin):
|
||||
back_populates="user",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
redeems: list["Redeem"] = Relationship(
|
||||
back_populates="user",
|
||||
sa_relationship_kwargs={
|
||||
"cascade": "all, delete-orphan",
|
||||
"foreign_keys": "[Redeem.used_by]"
|
||||
}
|
||||
)
|
||||
"""用户使用过的兑换码列表"""
|
||||
shares: list["Share"] = Relationship(
|
||||
back_populates="user",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
|
||||
@@ -5,7 +5,7 @@ from uuid import UUID
|
||||
from sqlalchemy import Column, Text
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str32, Str100, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -51,7 +51,7 @@ class AuthnDetailResponse(SQLModelBase):
|
||||
class AuthnRenameRequest(SQLModelBase):
|
||||
"""WebAuthn 凭证重命名请求 DTO"""
|
||||
|
||||
name: str = Field(max_length=100)
|
||||
name: Str100
|
||||
"""新的凭证名称"""
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class AuthnRenameRequest(SQLModelBase):
|
||||
class UserAuthn(SQLModelBase, TableBaseMixin):
|
||||
"""用户 WebAuthn 凭证模型,与 User 为多对一关系"""
|
||||
|
||||
credential_id: str = Field(max_length=255, unique=True, index=True)
|
||||
credential_id: Str255 = Field(unique=True, index=True)
|
||||
"""凭证 ID,Base64URL 编码"""
|
||||
|
||||
credential_public_key: str = Field(sa_column=Column(Text))
|
||||
@@ -69,16 +69,16 @@ class UserAuthn(SQLModelBase, TableBaseMixin):
|
||||
sign_count: int = Field(default=0, ge=0)
|
||||
"""签名计数器,用于防重放攻击"""
|
||||
|
||||
credential_device_type: str = Field(max_length=32)
|
||||
credential_device_type: Str32
|
||||
"""凭证设备类型:'single_device' 或 'multi_device'"""
|
||||
|
||||
credential_backed_up: bool = Field(default=False)
|
||||
"""凭证是否已备份"""
|
||||
|
||||
transports: str | None = Field(default=None, max_length=255)
|
||||
transports: Str255 | None = None
|
||||
"""支持的传输方式,逗号分隔,如 'usb,nfc,ble,internal'"""
|
||||
|
||||
name: str | None = Field(default=None, max_length=100)
|
||||
name: Str100 | None = None
|
||||
"""用户自定义的凭证名称,便于识别"""
|
||||
|
||||
# 外键
|
||||
|
||||
@@ -1,32 +1,117 @@
|
||||
"""
|
||||
WebDAV 账户模型
|
||||
|
||||
管理用户的 WebDAV 连接账户,每个账户对应一个挂载根路径。
|
||||
通过 HTTP Basic Auth 认证访问 DAV 协议端点。
|
||||
"""
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
class WebDAV(SQLModelBase, TableBaseMixin):
|
||||
"""WebDAV账户模型"""
|
||||
|
||||
# ==================== Base 模型 ====================
|
||||
|
||||
class WebDAVBase(SQLModelBase):
|
||||
"""WebDAV 账户基础字段"""
|
||||
|
||||
name: Str255
|
||||
"""账户名称(同一用户下唯一)"""
|
||||
|
||||
root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"})
|
||||
"""挂载根目录路径"""
|
||||
|
||||
readonly: bool = Field(default=False, sa_column_kwargs={"server_default": "false"})
|
||||
"""是否只读"""
|
||||
|
||||
use_proxy: bool = Field(default=False, sa_column_kwargs={"server_default": "false"})
|
||||
"""是否使用代理下载"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class WebDAV(WebDAVBase, TableBaseMixin):
|
||||
"""WebDAV 账户模型"""
|
||||
|
||||
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),)
|
||||
|
||||
name: str = Field(max_length=255, description="WebDAV账户名")
|
||||
password: str = Field(max_length=255, description="WebDAV密码")
|
||||
root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"}, description="根目录路径")
|
||||
readonly: bool = Field(default=False, description="是否只读")
|
||||
use_proxy: bool = Field(default=False, description="是否使用代理下载")
|
||||
|
||||
password: Str255
|
||||
"""密码(Argon2 哈希)"""
|
||||
|
||||
# 外键
|
||||
user_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
"""所属用户UUID"""
|
||||
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="webdavs")
|
||||
user: "User" = Relationship(back_populates="webdavs")
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class WebDAVCreateRequest(SQLModelBase):
|
||||
"""创建 WebDAV 账户请求"""
|
||||
|
||||
name: Str255
|
||||
"""账户名称"""
|
||||
|
||||
password: Str255 = Field(min_length=1)
|
||||
"""账户密码(明文,服务端哈希后存储)"""
|
||||
|
||||
root: str = "/"
|
||||
"""挂载根目录路径"""
|
||||
|
||||
readonly: bool = False
|
||||
"""是否只读"""
|
||||
|
||||
use_proxy: bool = False
|
||||
"""是否使用代理下载"""
|
||||
|
||||
|
||||
class WebDAVUpdateRequest(SQLModelBase):
|
||||
"""更新 WebDAV 账户请求"""
|
||||
|
||||
password: Str255 | None = Field(default=None, min_length=1)
|
||||
"""新密码(为 None 时不修改)"""
|
||||
|
||||
root: str | None = None
|
||||
"""新挂载根目录路径(为 None 时不修改)"""
|
||||
|
||||
readonly: bool | None = None
|
||||
"""是否只读(为 None 时不修改)"""
|
||||
|
||||
use_proxy: bool | None = None
|
||||
"""是否使用代理下载(为 None 时不修改)"""
|
||||
|
||||
|
||||
class WebDAVAccountResponse(SQLModelBase):
|
||||
"""WebDAV 账户响应"""
|
||||
|
||||
id: int
|
||||
"""账户ID"""
|
||||
|
||||
name: str
|
||||
"""账户名称"""
|
||||
|
||||
root: str
|
||||
"""挂载根目录路径"""
|
||||
|
||||
readonly: bool
|
||||
"""是否只读"""
|
||||
|
||||
use_proxy: bool
|
||||
"""是否使用代理下载"""
|
||||
|
||||
created_at: str
|
||||
"""创建时间"""
|
||||
|
||||
updated_at: str
|
||||
"""更新时间"""
|
||||
|
||||
2
tests/fixtures/objects.py
vendored
2
tests/fixtures/objects.py
vendored
@@ -92,9 +92,9 @@ class ObjectFactory:
|
||||
owner_id=owner_id,
|
||||
policy_id=policy_id,
|
||||
size=size,
|
||||
mime_type=kwargs.get("mime_type"),
|
||||
source_name=kwargs.get("source_name", name),
|
||||
upload_session_id=kwargs.get("upload_session_id"),
|
||||
file_metadata=kwargs.get("file_metadata"),
|
||||
password=kwargs.get("password"),
|
||||
)
|
||||
|
||||
|
||||
8
tests/fixtures/users.py
vendored
8
tests/fixtures/users.py
vendored
@@ -71,7 +71,7 @@ class UserFactory:
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
identity = await identity.save(session)
|
||||
|
||||
return user
|
||||
|
||||
@@ -123,7 +123,7 @@ class UserFactory:
|
||||
is_verified=True,
|
||||
user_id=admin.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
identity = await identity.save(session)
|
||||
|
||||
return admin
|
||||
|
||||
@@ -170,7 +170,7 @@ class UserFactory:
|
||||
is_verified=True,
|
||||
user_id=banned_user.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
identity = await identity.save(session)
|
||||
|
||||
return banned_user
|
||||
|
||||
@@ -219,6 +219,6 @@ class UserFactory:
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
identity = await identity.save(session)
|
||||
|
||||
return user
|
||||
|
||||
219
tests/integration/api/test_custom_property.py
Normal file
219
tests/integration/api/test_custom_property.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
自定义属性定义端点集成测试
|
||||
"""
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
|
||||
# ==================== 获取属性定义列表测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_custom_properties_requires_auth(async_client: AsyncClient):
|
||||
"""测试获取属性定义需要认证"""
|
||||
response = await async_client.get("/api/v1/object/custom_property")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_custom_properties_empty(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试获取空的属性定义列表"""
|
||||
response = await async_client.get(
|
||||
"/api/v1/object/custom_property",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data == []
|
||||
|
||||
|
||||
# ==================== 创建属性定义测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_custom_property(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试创建自定义属性"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/object/custom_property",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name": "评分",
|
||||
"type": "rating",
|
||||
"icon": "mdi:star",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
# 验证已创建
|
||||
list_response = await async_client.get(
|
||||
"/api/v1/object/custom_property",
|
||||
headers=auth_headers,
|
||||
)
|
||||
data = list_response.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["name"] == "评分"
|
||||
assert data[0]["type"] == "rating"
|
||||
assert data[0]["icon"] == "mdi:star"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_custom_property_with_options(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试创建带选项的自定义属性"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/object/custom_property",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name": "分类",
|
||||
"type": "select",
|
||||
"options": ["工作", "个人", "归档"],
|
||||
"default_value": "个人",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
list_response = await async_client.get(
|
||||
"/api/v1/object/custom_property",
|
||||
headers=auth_headers,
|
||||
)
|
||||
data = list_response.json()
|
||||
prop = next(p for p in data if p["name"] == "分类")
|
||||
assert prop["type"] == "select"
|
||||
assert prop["options"] == ["工作", "个人", "归档"]
|
||||
assert prop["default_value"] == "个人"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_custom_property_duplicate_name(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试创建同名属性返回 409"""
|
||||
# 先创建
|
||||
await async_client.post(
|
||||
"/api/v1/object/custom_property",
|
||||
headers=auth_headers,
|
||||
json={"name": "标签", "type": "text"},
|
||||
)
|
||||
|
||||
# 再创建同名
|
||||
response = await async_client.post(
|
||||
"/api/v1/object/custom_property",
|
||||
headers=auth_headers,
|
||||
json={"name": "标签", "type": "text"},
|
||||
)
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
# ==================== 更新属性定义测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_custom_property(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试更新自定义属性"""
|
||||
# 先创建
|
||||
await async_client.post(
|
||||
"/api/v1/object/custom_property",
|
||||
headers=auth_headers,
|
||||
json={"name": "备注", "type": "text"},
|
||||
)
|
||||
|
||||
# 获取 ID
|
||||
list_response = await async_client.get(
|
||||
"/api/v1/object/custom_property",
|
||||
headers=auth_headers,
|
||||
)
|
||||
prop_id = next(p["id"] for p in list_response.json() if p["name"] == "备注")
|
||||
|
||||
# 更新
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/object/custom_property/{prop_id}",
|
||||
headers=auth_headers,
|
||||
json={"name": "详细备注", "icon": "mdi:note"},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
# 验证已更新
|
||||
list_response = await async_client.get(
|
||||
"/api/v1/object/custom_property",
|
||||
headers=auth_headers,
|
||||
)
|
||||
prop = next(p for p in list_response.json() if p["id"] == prop_id)
|
||||
assert prop["name"] == "详细备注"
|
||||
assert prop["icon"] == "mdi:note"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_custom_property_not_found(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试更新不存在的属性返回 404"""
|
||||
fake_id = str(uuid4())
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/object/custom_property/{fake_id}",
|
||||
headers=auth_headers,
|
||||
json={"name": "不存在"},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ==================== 删除属性定义测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_custom_property(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试删除自定义属性"""
|
||||
# 先创建
|
||||
await async_client.post(
|
||||
"/api/v1/object/custom_property",
|
||||
headers=auth_headers,
|
||||
json={"name": "待删除", "type": "text"},
|
||||
)
|
||||
|
||||
# 获取 ID
|
||||
list_response = await async_client.get(
|
||||
"/api/v1/object/custom_property",
|
||||
headers=auth_headers,
|
||||
)
|
||||
prop_id = next(p["id"] for p in list_response.json() if p["name"] == "待删除")
|
||||
|
||||
# 删除
|
||||
response = await async_client.delete(
|
||||
f"/api/v1/object/custom_property/{prop_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
# 验证已删除
|
||||
list_response = await async_client.get(
|
||||
"/api/v1/object/custom_property",
|
||||
headers=auth_headers,
|
||||
)
|
||||
prop_names = [p["name"] for p in list_response.json()]
|
||||
assert "待删除" not in prop_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_custom_property_not_found(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试删除不存在的属性返回 404"""
|
||||
fake_id = str(uuid4())
|
||||
response = await async_client.delete(
|
||||
f"/api/v1/object/custom_property/{fake_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
239
tests/integration/api/test_object_metadata.py
Normal file
239
tests/integration/api/test_object_metadata.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
对象元数据端点集成测试
|
||||
"""
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodels import ObjectMetadata
|
||||
|
||||
|
||||
# ==================== 获取元数据测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metadata_requires_auth(async_client: AsyncClient):
|
||||
"""测试获取元数据需要认证"""
|
||||
fake_id = str(uuid4())
|
||||
response = await async_client.get(f"/api/v1/object/{fake_id}/metadata")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metadata_empty(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
test_directory_structure: dict[str, UUID],
|
||||
):
|
||||
"""测试获取无元数据的对象"""
|
||||
file_id = test_directory_structure["file_id"]
|
||||
response = await async_client.get(
|
||||
f"/api/v1/object/{file_id}/metadata",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["metadatas"] == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metadata_with_entries(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
test_directory_structure: dict[str, UUID],
|
||||
initialized_db,
|
||||
):
|
||||
"""测试获取有元数据的对象"""
|
||||
file_id = test_directory_structure["file_id"]
|
||||
|
||||
# 直接写入元数据
|
||||
entries = [
|
||||
ObjectMetadata(object_id=file_id, name="exif:width", value="1920", is_public=True),
|
||||
ObjectMetadata(object_id=file_id, name="exif:height", value="1080", is_public=True),
|
||||
ObjectMetadata(object_id=file_id, name="sys:extract_status", value="done", is_public=False),
|
||||
]
|
||||
for entry in entries:
|
||||
initialized_db.add(entry)
|
||||
await initialized_db.commit()
|
||||
|
||||
response = await async_client.get(
|
||||
f"/api/v1/object/{file_id}/metadata",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# sys: 命名空间应被过滤
|
||||
assert "exif:width" in data["metadatas"]
|
||||
assert "exif:height" in data["metadatas"]
|
||||
assert "sys:extract_status" not in data["metadatas"]
|
||||
assert data["metadatas"]["exif:width"] == "1920"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metadata_ns_filter(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
test_directory_structure: dict[str, UUID],
|
||||
initialized_db,
|
||||
):
|
||||
"""测试按命名空间过滤元数据"""
|
||||
file_id = test_directory_structure["file_id"]
|
||||
|
||||
entries = [
|
||||
ObjectMetadata(object_id=file_id, name="exif:width", value="1920", is_public=True),
|
||||
ObjectMetadata(object_id=file_id, name="music:title", value="Test Song", is_public=True),
|
||||
]
|
||||
for entry in entries:
|
||||
initialized_db.add(entry)
|
||||
await initialized_db.commit()
|
||||
|
||||
# 只获取 exif 命名空间
|
||||
response = await async_client.get(
|
||||
f"/api/v1/object/{file_id}/metadata?ns=exif",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "exif:width" in data["metadatas"]
|
||||
assert "music:title" not in data["metadatas"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metadata_nonexistent_object(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试获取不存在对象的元数据"""
|
||||
fake_id = str(uuid4())
|
||||
response = await async_client.get(
|
||||
f"/api/v1/object/{fake_id}/metadata",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ==================== 更新元数据测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_metadata_requires_auth(async_client: AsyncClient):
|
||||
"""测试更新元数据需要认证"""
|
||||
fake_id = str(uuid4())
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/object/{fake_id}/metadata",
|
||||
json={"patches": [{"key": "custom:tag", "value": "test"}]},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_metadata_set_custom(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
test_directory_structure: dict[str, UUID],
|
||||
):
|
||||
"""测试设置自定义元数据"""
|
||||
file_id = test_directory_structure["file_id"]
|
||||
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/object/{file_id}/metadata",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"patches": [
|
||||
{"key": "custom:tag1", "value": "旅游"},
|
||||
{"key": "custom:tag2", "value": "风景"},
|
||||
]
|
||||
},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
# 验证已写入
|
||||
get_response = await async_client.get(
|
||||
f"/api/v1/object/{file_id}/metadata?ns=custom",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert get_response.status_code == 200
|
||||
data = get_response.json()
|
||||
assert data["metadatas"]["custom:tag1"] == "旅游"
|
||||
assert data["metadatas"]["custom:tag2"] == "风景"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_metadata_update_existing(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
test_directory_structure: dict[str, UUID],
|
||||
):
|
||||
"""测试更新已有的元数据"""
|
||||
file_id = test_directory_structure["file_id"]
|
||||
|
||||
# 先创建
|
||||
await async_client.patch(
|
||||
f"/api/v1/object/{file_id}/metadata",
|
||||
headers=auth_headers,
|
||||
json={"patches": [{"key": "custom:note", "value": "旧值"}]},
|
||||
)
|
||||
|
||||
# 再更新
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/object/{file_id}/metadata",
|
||||
headers=auth_headers,
|
||||
json={"patches": [{"key": "custom:note", "value": "新值"}]},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
# 验证已更新
|
||||
get_response = await async_client.get(
|
||||
f"/api/v1/object/{file_id}/metadata?ns=custom",
|
||||
headers=auth_headers,
|
||||
)
|
||||
data = get_response.json()
|
||||
assert data["metadatas"]["custom:note"] == "新值"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_metadata_delete(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
test_directory_structure: dict[str, UUID],
|
||||
):
|
||||
"""测试删除元数据条目"""
|
||||
file_id = test_directory_structure["file_id"]
|
||||
|
||||
# 先创建
|
||||
await async_client.patch(
|
||||
f"/api/v1/object/{file_id}/metadata",
|
||||
headers=auth_headers,
|
||||
json={"patches": [{"key": "custom:to_delete", "value": "temp"}]},
|
||||
)
|
||||
|
||||
# 删除(value 为 null)
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/object/{file_id}/metadata",
|
||||
headers=auth_headers,
|
||||
json={"patches": [{"key": "custom:to_delete", "value": None}]},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
# 验证已删除
|
||||
get_response = await async_client.get(
|
||||
f"/api/v1/object/{file_id}/metadata?ns=custom",
|
||||
headers=auth_headers,
|
||||
)
|
||||
data = get_response.json()
|
||||
assert "custom:to_delete" not in data["metadatas"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_metadata_reject_non_custom_namespace(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
test_directory_structure: dict[str, UUID],
|
||||
):
|
||||
"""测试拒绝修改非 custom: 命名空间"""
|
||||
file_id = test_directory_structure["file_id"]
|
||||
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/object/{file_id}/metadata",
|
||||
headers=auth_headers,
|
||||
json={"patches": [{"key": "exif:width", "value": "1920"}]},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
591
tests/integration/api/test_webdav.py
Normal file
591
tests/integration/api/test_webdav.py
Normal file
@@ -0,0 +1,591 @@
|
||||
"""
|
||||
WebDAV 账户管理端点集成测试
|
||||
"""
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, User
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import Password
|
||||
from utils.JWT import create_access_token
|
||||
|
||||
API_PREFIX = "/api/v1/webdav"
|
||||
|
||||
|
||||
# ==================== Fixtures ====================
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def no_webdav_headers(initialized_db: AsyncSession) -> dict[str, str]:
|
||||
"""创建一个 WebDAV 被禁用的用户,返回其认证头"""
|
||||
group = Group(
|
||||
id=uuid4(),
|
||||
name="无WebDAV用户组",
|
||||
max_storage=1024 * 1024 * 1024,
|
||||
share_enabled=True,
|
||||
web_dav_enabled=False,
|
||||
admin=False,
|
||||
speed_limit=0,
|
||||
)
|
||||
initialized_db.add(group)
|
||||
await initialized_db.commit()
|
||||
await initialized_db.refresh(group)
|
||||
|
||||
group_options = GroupOptions(
|
||||
group_id=group.id,
|
||||
share_download=True,
|
||||
share_free=False,
|
||||
relocate=False,
|
||||
source_batch=0,
|
||||
select_node=False,
|
||||
advance_delete=False,
|
||||
)
|
||||
initialized_db.add(group_options)
|
||||
await initialized_db.commit()
|
||||
await initialized_db.refresh(group_options)
|
||||
|
||||
user = User(
|
||||
id=uuid4(),
|
||||
email="nowebdav@test.local",
|
||||
nickname="无WebDAV用户",
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=0,
|
||||
group_id=group.id,
|
||||
avatar="default",
|
||||
)
|
||||
initialized_db.add(user)
|
||||
await initialized_db.commit()
|
||||
await initialized_db.refresh(user)
|
||||
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="nowebdav@test.local",
|
||||
credential=Password.hash("nowebdav123"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
initialized_db.add(identity)
|
||||
|
||||
from sqlmodels import Policy
|
||||
policy = await Policy.get(initialized_db, Policy.name == "本地存储")
|
||||
|
||||
root = Object(
|
||||
id=uuid4(),
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=user.id,
|
||||
parent_id=None,
|
||||
policy_id=policy.id,
|
||||
size=0,
|
||||
)
|
||||
initialized_db.add(root)
|
||||
await initialized_db.commit()
|
||||
|
||||
group.options = group_options
|
||||
group_claims = GroupClaims.from_group(group)
|
||||
result = create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
)
|
||||
return {"Authorization": f"Bearer {result.access_token}"}
|
||||
|
||||
|
||||
# ==================== 认证测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_accounts_requires_auth(async_client: AsyncClient):
|
||||
"""测试获取账户列表需要认证"""
|
||||
response = await async_client.get(f"{API_PREFIX}/accounts")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_requires_auth(async_client: AsyncClient):
|
||||
"""测试创建账户需要认证"""
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
json={"name": "test", "password": "testpass"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_requires_auth(async_client: AsyncClient):
|
||||
"""测试更新账户需要认证"""
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/1",
|
||||
json={"readonly": True},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_account_requires_auth(async_client: AsyncClient):
|
||||
"""测试删除账户需要认证"""
|
||||
response = await async_client.delete(f"{API_PREFIX}/accounts/1")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ==================== WebDAV 禁用测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_accounts_webdav_disabled(
|
||||
async_client: AsyncClient,
|
||||
no_webdav_headers: dict[str, str],
|
||||
):
|
||||
"""测试 WebDAV 被禁用时返回 403"""
|
||||
response = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=no_webdav_headers,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_webdav_disabled(
|
||||
async_client: AsyncClient,
|
||||
no_webdav_headers: dict[str, str],
|
||||
):
|
||||
"""测试 WebDAV 被禁用时创建账户返回 403"""
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=no_webdav_headers,
|
||||
json={"name": "test", "password": "testpass"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ==================== 获取账户列表测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_accounts_empty(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试初始状态账户列表为空"""
|
||||
response = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
# ==================== 创建账户测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_success(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试成功创建 WebDAV 账户"""
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "my-nas", "password": "secretpass"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
data = response.json()
|
||||
assert data["name"] == "my-nas"
|
||||
assert data["root"] == "/"
|
||||
assert data["readonly"] is False
|
||||
assert data["use_proxy"] is False
|
||||
assert "id" in data
|
||||
assert "created_at" in data
|
||||
assert "updated_at" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_with_options(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试创建带选项的 WebDAV 账户"""
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name": "readonly-nas",
|
||||
"password": "secretpass",
|
||||
"readonly": True,
|
||||
"use_proxy": True,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
data = response.json()
|
||||
assert data["name"] == "readonly-nas"
|
||||
assert data["readonly"] is True
|
||||
assert data["use_proxy"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_duplicate_name(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试重名账户返回 409"""
|
||||
# 先创建一个
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "dup-test", "password": "pass1"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# 再创建同名的
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "dup-test", "password": "pass2"},
|
||||
)
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_invalid_root(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试无效根目录路径返回 400"""
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name": "bad-root",
|
||||
"password": "secretpass",
|
||||
"root": "/nonexistent/path",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_with_valid_subdir(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
test_directory_structure: dict[str, UUID],
|
||||
):
|
||||
"""测试使用有效的子目录作为根路径"""
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name": "docs-only",
|
||||
"password": "secretpass",
|
||||
"root": "/docs",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["root"] == "/docs"
|
||||
|
||||
|
||||
# ==================== 列表包含已创建账户测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_accounts_after_create(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试创建后列表中包含该账户"""
|
||||
# 创建
|
||||
await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "list-test", "password": "pass"},
|
||||
)
|
||||
|
||||
# 列表
|
||||
response = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
accounts = response.json()
|
||||
assert len(accounts) == 1
|
||||
assert accounts[0]["name"] == "list-test"
|
||||
|
||||
|
||||
# ==================== 更新账户测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_success(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试成功更新 WebDAV 账户"""
|
||||
# 创建
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "update-test", "password": "oldpass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 更新
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
json={"readonly": True},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert data["readonly"] is True
|
||||
assert data["name"] == "update-test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_password(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试更新密码"""
|
||||
# 创建
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "pwd-test", "password": "oldpass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 更新密码
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
json={"password": "newpass123"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_root(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
test_directory_structure: dict[str, UUID],
|
||||
):
|
||||
"""测试更新根目录路径"""
|
||||
# 创建
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "root-update", "password": "pass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 更新 root 到有效子目录
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
json={"root": "/docs"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["root"] == "/docs"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_invalid_root(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试更新为无效根目录返回 400"""
|
||||
# 创建
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "bad-root-update", "password": "pass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 更新到无效路径
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
json={"root": "/nonexistent"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_not_found(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试更新不存在的账户返回 404"""
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/99999",
|
||||
headers=auth_headers,
|
||||
json={"readonly": True},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_other_user_account(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
admin_headers: dict[str, str],
|
||||
):
|
||||
"""测试更新其他用户的账户返回 404"""
|
||||
# 管理员创建账户
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=admin_headers,
|
||||
json={"name": "admin-account", "password": "pass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 普通用户尝试更新
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
json={"readonly": True},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ==================== 删除账户测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_account_success(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试成功删除 WebDAV 账户"""
|
||||
# 创建
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "delete-test", "password": "pass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 删除
|
||||
response = await async_client.delete(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
# 确认列表中已不存在
|
||||
list_resp = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert list_resp.status_code == 200
|
||||
names = [a["name"] for a in list_resp.json()]
|
||||
assert "delete-test" not in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_account_not_found(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试删除不存在的账户返回 404"""
|
||||
response = await async_client.delete(
|
||||
f"{API_PREFIX}/accounts/99999",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_other_user_account(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
admin_headers: dict[str, str],
|
||||
):
|
||||
"""测试删除其他用户的账户返回 404"""
|
||||
# 管理员创建账户
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=admin_headers,
|
||||
json={"name": "admin-del-test", "password": "pass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 普通用户尝试删除
|
||||
response = await async_client.delete(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ==================== 多账户测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_accounts(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试同一用户可以创建多个账户"""
|
||||
for name in ["account-1", "account-2", "account-3"]:
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": name, "password": "pass"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# 列表应有3个
|
||||
response = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 3
|
||||
|
||||
|
||||
# ==================== 用户隔离测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accounts_user_isolation(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
admin_headers: dict[str, str],
|
||||
):
|
||||
"""测试不同用户的账户相互隔离"""
|
||||
# 普通用户创建
|
||||
await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "user-account", "password": "pass"},
|
||||
)
|
||||
|
||||
# 管理员创建
|
||||
await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=admin_headers,
|
||||
json={"name": "admin-account", "password": "pass"},
|
||||
)
|
||||
|
||||
# 普通用户只看到自己的
|
||||
response = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
accounts = response.json()
|
||||
assert len(accounts) == 1
|
||||
assert accounts[0]["name"] == "user-account"
|
||||
|
||||
# 管理员只看到自己的
|
||||
response = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=admin_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
accounts = response.json()
|
||||
assert len(accounts) == 1
|
||||
assert accounts[0]["name"] == "admin-account"
|
||||
@@ -23,6 +23,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../.
|
||||
|
||||
from main import app
|
||||
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||
from sqlmodels.policy import GroupPolicyLink
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import Password
|
||||
@@ -108,6 +109,12 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
Setting(type=SettingsType.AUTH, name="auth_email_binding_required", value="1"),
|
||||
Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"),
|
||||
Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"),
|
||||
Setting(type=SettingsType.AVATAR, name="gravatar_server", value="https://www.gravatar.com/"),
|
||||
Setting(type=SettingsType.AVATAR, name="avatar_size", value="2097152"),
|
||||
Setting(type=SettingsType.AVATAR, name="avatar_size_l", value="200"),
|
||||
Setting(type=SettingsType.AVATAR, name="avatar_size_m", value="130"),
|
||||
Setting(type=SettingsType.AVATAR, name="avatar_size_s", value="50"),
|
||||
Setting(type=SettingsType.PATH, name="avatar_path", value="avatar"),
|
||||
]
|
||||
for setting in settings:
|
||||
test_session.add(setting)
|
||||
@@ -156,7 +163,11 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
await test_session.refresh(admin_group)
|
||||
await test_session.refresh(default_policy)
|
||||
|
||||
# 4. 创建用户组选项
|
||||
# 4. 关联用户组与存储策略
|
||||
test_session.add(GroupPolicyLink(group_id=default_group.id, policy_id=default_policy.id))
|
||||
test_session.add(GroupPolicyLink(group_id=admin_group.id, policy_id=default_policy.id))
|
||||
|
||||
# 5. 创建用户组选项
|
||||
default_group_options = GroupOptions(
|
||||
group_id=default_group.id,
|
||||
share_download=True,
|
||||
|
||||
@@ -37,6 +37,12 @@ async def load_secret_key() -> None:
|
||||
if setting:
|
||||
SECRET_KEY = setting.value
|
||||
|
||||
if not SECRET_KEY:
|
||||
raise RuntimeError(
|
||||
"JWT SECRET_KEY 未配置,拒绝启动。"
|
||||
"请在 Setting 表中添加 type='auth', name='secret_key' 的记录。"
|
||||
)
|
||||
|
||||
|
||||
def build_token_payload(
|
||||
data: dict,
|
||||
|
||||
@@ -62,6 +62,10 @@ def raise_not_implemented(detail: str = "尚未支持这种方法") -> NoReturn:
|
||||
"""Raises an HTTP 501 Not Implemented exception."""
|
||||
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=detail)
|
||||
|
||||
def raise_bad_gateway(detail: str | None = None) -> NoReturn:
|
||||
"""Raises an HTTP 502 Bad Gateway exception."""
|
||||
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=detail)
|
||||
|
||||
def raise_service_unavailable(detail: str | None = None) -> NoReturn:
|
||||
"""Raises an HTTP 503 Service Unavailable exception."""
|
||||
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=detail)
|
||||
|
||||
Reference in New Issue
Block a user