Compare commits
3 Commits
main
...
1ecc0fdc1c
| Author | SHA1 | Date | |
|---|---|---|---|
| 1ecc0fdc1c | |||
| 71883d32c0 | |||
| ccadfe57cd |
@@ -5,8 +5,7 @@
|
|||||||
"Bash(findstr:*)",
|
"Bash(findstr:*)",
|
||||||
"Bash(find:*)",
|
"Bash(find:*)",
|
||||||
"Bash(yarn tsc:*)",
|
"Bash(yarn tsc:*)",
|
||||||
"Bash(dir:*)",
|
"Bash(dir:*)"
|
||||||
"mcp__server-notify__notify"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1,6 +1,8 @@
|
|||||||
# Python
|
# Python
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
*.so
|
*.so
|
||||||
*.egg
|
*.egg
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
@@ -77,6 +79,3 @@ statics/
|
|||||||
# 许可证密钥(保密)
|
# 许可证密钥(保密)
|
||||||
license_private.pem
|
license_private.pem
|
||||||
license.key
|
license.key
|
||||||
|
|
||||||
avatar/
|
|
||||||
.dev/
|
|
||||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,3 +0,0 @@
|
|||||||
[submodule "ee"]
|
|
||||||
path = ee
|
|
||||||
url = https://git.yxqi.cn/Yuerchu/disknext-ee.git
|
|
||||||
1
ee
1
ee
Submodule ee deleted from cc32d8db91
42
ee/__init__.py
Normal file
42
ee/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""
|
||||||
|
DiskNext Enterprise Edition (EE) 模块
|
||||||
|
|
||||||
|
通过 ``try: from ee import init_ee`` 检测是否存在。
|
||||||
|
CE 版本中此目录不存在,ImportError 被 main.py 捕获。
|
||||||
|
"""
|
||||||
|
from loguru import logger as l
|
||||||
|
|
||||||
|
from utils.conf import appmeta
|
||||||
|
|
||||||
|
_ee_initialized: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def is_pro() -> bool:
|
||||||
|
"""当前实例是否以 Pro 版本运行。"""
|
||||||
|
return _ee_initialized
|
||||||
|
|
||||||
|
|
||||||
|
async def init_ee() -> None:
|
||||||
|
"""
|
||||||
|
初始化企业版功能。
|
||||||
|
|
||||||
|
1. 加载并验证许可证
|
||||||
|
2. 设置 appmeta.IsPro = True
|
||||||
|
3. 标记 EE 已初始化
|
||||||
|
|
||||||
|
许可证无效或缺失时抛出异常,阻止应用启动。
|
||||||
|
"""
|
||||||
|
global _ee_initialized
|
||||||
|
|
||||||
|
from ee.service.license_service import load_and_validate_license
|
||||||
|
|
||||||
|
payload = await load_and_validate_license()
|
||||||
|
|
||||||
|
appmeta.IsPro = True
|
||||||
|
_ee_initialized = True
|
||||||
|
|
||||||
|
l.info(
|
||||||
|
f"Pro 版本已激活 — 域名: {payload.domain}, "
|
||||||
|
f"过期: {payload.expires_at.isoformat()}, "
|
||||||
|
f"功能: {payload.features}"
|
||||||
|
)
|
||||||
86
ee/license.py
Normal file
86
ee/license.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""
|
||||||
|
RSA-PSS 许可证验签核心(编译为 .so 后公钥藏入二进制)
|
||||||
|
|
||||||
|
此文件只包含纯函数和常量,不包含 SQLModel 类。
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
from cryptography.hazmat.primitives import hashes, serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import padding
|
||||||
|
|
||||||
|
|
||||||
|
_PUBLIC_KEY_PEM: bytes = b"""-----BEGIN PUBLIC KEY-----
|
||||||
|
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAyNltXQ/Nuechx3kjj3T5
|
||||||
|
oR6pZvTmpsDowqqxXJy7FXUI8d7XprhV+HrBQPsrT/Ngo9FwW3XyiK10m1WrzpGW
|
||||||
|
eaf9990Z5Z2naEn5TzGrh71p/D7mZcNGVumo9uAuhtNEemm6xB3FoyGYZj7X0cwA
|
||||||
|
VDvIiKAwYyRJX2LqVh1/tZM6tTO3oaGZXRMZzCNUPFSo4ZZudU3Boa5oQg08evu4
|
||||||
|
vaOqeFrMX47R3MSUmO9hOh+NS53XNqO0f0zw5sv95CtyR5qvJ4gpkgYaRCSQFd19
|
||||||
|
TnHU5saFVrH9jdADz1tdkMYcyYE+uJActZBapxCHSYB2tSCKWjDxeUFl/oY/ZFtY
|
||||||
|
l4MNz1ovkjNhpmR3g+I5fbvN0cxDIjnZ9vJ84ozGqTGT9s1jHaLbpLri/vhuT4F2
|
||||||
|
7kifXk8ImwtMZpZvzhmucH9/5VgcWKNuMATzEMif+YjFpuOGx8gc1XL1W/3q+dH0
|
||||||
|
EFESp+/knjcVIfwpAkIKyV7XvDgFHsif1SeI0zZMW4utowVvGocP1ZzK5BGNTk2z
|
||||||
|
CEtQDO7Rqo+UDckOJSG66VW3c2QO8o6uuy6fzx7q0MFEmUMwGf2iMVtR/KnXe99C
|
||||||
|
enOT0BpU1EQvqssErUqivDss7jm98iD8M/TCE7pFboqZ+SC9G+QAqNIQNFWh8bWA
|
||||||
|
R9hyXM/x5ysHd6MC4eEQnhMCAwEAAQ==
|
||||||
|
-----END PUBLIC KEY-----"""
|
||||||
|
|
||||||
|
|
||||||
|
class LicenseError(Exception):
|
||||||
|
"""许可证验证基础异常"""
|
||||||
|
|
||||||
|
|
||||||
|
class LicenseExpiredError(LicenseError):
|
||||||
|
"""许可证已过期"""
|
||||||
|
|
||||||
|
|
||||||
|
def verify_license(raw: str) -> dict:
|
||||||
|
"""
|
||||||
|
验证许可证字符串并返回载荷字典。
|
||||||
|
|
||||||
|
:param raw: 格式为 ``base64(json_payload).base64(signature)``
|
||||||
|
:returns: 解析后的载荷字典
|
||||||
|
:raises LicenseError: 格式无效或签名验证失败
|
||||||
|
:raises LicenseExpiredError: 许可证已过期
|
||||||
|
"""
|
||||||
|
parts = raw.strip().split(".")
|
||||||
|
if len(parts) != 2:
|
||||||
|
raise LicenseError("许可证格式无效:需要 payload.signature")
|
||||||
|
|
||||||
|
payload_b64, signature_b64 = parts
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload_bytes = base64.urlsafe_b64decode(payload_b64)
|
||||||
|
signature = base64.urlsafe_b64decode(signature_b64)
|
||||||
|
except Exception as exc:
|
||||||
|
raise LicenseError(f"许可证 base64 解码失败: {exc}") from exc
|
||||||
|
|
||||||
|
public_key = serialization.load_pem_public_key(_PUBLIC_KEY_PEM)
|
||||||
|
|
||||||
|
try:
|
||||||
|
public_key.verify( # type: ignore[union-attr]
|
||||||
|
signature,
|
||||||
|
payload_bytes,
|
||||||
|
padding.PSS(
|
||||||
|
mgf=padding.MGF1(hashes.SHA256()),
|
||||||
|
salt_length=padding.PSS.MAX_LENGTH,
|
||||||
|
),
|
||||||
|
hashes.SHA256(),
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
raise LicenseError(f"许可证签名验证失败: {exc}") from exc
|
||||||
|
|
||||||
|
data: dict = orjson.loads(payload_bytes)
|
||||||
|
|
||||||
|
expires_at_str: str | None = data.get('expires_at')
|
||||||
|
if not expires_at_str:
|
||||||
|
raise LicenseError("许可证缺少 expires_at 字段")
|
||||||
|
|
||||||
|
expires_at = datetime.fromisoformat(expires_at_str)
|
||||||
|
if expires_at < datetime.now(timezone.utc):
|
||||||
|
raise LicenseExpiredError(
|
||||||
|
f"许可证已过期: {expires_at.isoformat()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
6
ee/routers/__init__.py
Normal file
6
ee/routers/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from .pro import router as pro_router
|
||||||
|
|
||||||
|
ee_router = APIRouter()
|
||||||
|
ee_router.include_router(pro_router)
|
||||||
69
ee/routers/pro/__init__.py
Normal file
69
ee/routers/pro/__init__.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""
|
||||||
|
Pro 版本状态端点
|
||||||
|
|
||||||
|
提供许可证状态查询,需管理员权限。
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from ee.service import LicensePayload
|
||||||
|
from ee.service.license_service import get_cached_license
|
||||||
|
from middleware.auth import admin_required
|
||||||
|
from sqlmodel_ext import SQLModelBase
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/pro")
|
||||||
|
|
||||||
|
|
||||||
|
class ProStatusResponse(SQLModelBase):
|
||||||
|
"""Pro 版本状态响应"""
|
||||||
|
|
||||||
|
is_active: bool
|
||||||
|
"""许可证是否有效"""
|
||||||
|
|
||||||
|
domain: str
|
||||||
|
"""授权域名"""
|
||||||
|
|
||||||
|
expires_at: datetime
|
||||||
|
"""过期时间"""
|
||||||
|
|
||||||
|
max_users: int
|
||||||
|
"""最大用户数(0 = 无限制)"""
|
||||||
|
|
||||||
|
features: list[str]
|
||||||
|
"""已授权的功能列表"""
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
'/status',
|
||||||
|
dependencies=[Depends(admin_required)],
|
||||||
|
)
|
||||||
|
async def get_pro_status() -> ProStatusResponse:
|
||||||
|
"""
|
||||||
|
查询 Pro 版本许可证状态
|
||||||
|
|
||||||
|
认证:
|
||||||
|
- JWT token in Authorization header
|
||||||
|
- 需要管理员权限
|
||||||
|
|
||||||
|
响应:
|
||||||
|
- ProStatusResponse: 当前许可证信息
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- HTTPException 401: 未授权
|
||||||
|
- HTTPException 403: 非管理员
|
||||||
|
- HTTPException 500: 许可证缓存异常
|
||||||
|
"""
|
||||||
|
payload: LicensePayload | None = get_cached_license()
|
||||||
|
if not payload:
|
||||||
|
# init_ee 成功后 payload 一定存在,此处做防御性编程
|
||||||
|
from utils.http.http_exceptions import raise_internal_error
|
||||||
|
raise_internal_error()
|
||||||
|
|
||||||
|
return ProStatusResponse(
|
||||||
|
is_active=True,
|
||||||
|
domain=payload.domain,
|
||||||
|
expires_at=payload.expires_at,
|
||||||
|
max_users=payload.max_users,
|
||||||
|
features=payload.features,
|
||||||
|
)
|
||||||
30
ee/service/__init__.py
Normal file
30
ee/service/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""
|
||||||
|
EE 许可证服务模块
|
||||||
|
|
||||||
|
提供许可证加载、验证和缓存功能。
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlmodel_ext import SQLModelBase
|
||||||
|
|
||||||
|
|
||||||
|
class LicensePayload(SQLModelBase):
|
||||||
|
"""许可证载荷(RSA 验签后的明文数据)"""
|
||||||
|
|
||||||
|
domain: str
|
||||||
|
"""授权域名"""
|
||||||
|
|
||||||
|
expires_at: datetime
|
||||||
|
"""过期时间(UTC)"""
|
||||||
|
|
||||||
|
max_users: int
|
||||||
|
"""最大用户数(0 = 无限制)"""
|
||||||
|
|
||||||
|
features: list[str]
|
||||||
|
"""已授权的功能列表"""
|
||||||
|
|
||||||
|
issued_at: datetime
|
||||||
|
"""签发时间(UTC)"""
|
||||||
|
|
||||||
|
|
||||||
|
from .license_service import get_cached_license, load_and_validate_license
|
||||||
49
ee/service/license_service.py
Normal file
49
ee/service/license_service.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""
|
||||||
|
许可证加载与缓存服务(编译为 .so)
|
||||||
|
|
||||||
|
从环境变量 LICENSE_KEY 或 license.key 文件加载许可证,
|
||||||
|
调用 verify_license() 验证后缓存结果。
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
|
||||||
|
from ee.license import LicenseError, verify_license
|
||||||
|
from ee.service import LicensePayload
|
||||||
|
|
||||||
|
_cached_payload: LicensePayload | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def load_and_validate_license() -> LicensePayload:
|
||||||
|
"""
|
||||||
|
加载并验证许可证,成功后缓存。
|
||||||
|
|
||||||
|
加载优先级:
|
||||||
|
1. 环境变量 ``LICENSE_KEY``
|
||||||
|
2. 项目根目录 ``license.key`` 文件
|
||||||
|
|
||||||
|
:returns: 验证通过的 LicensePayload
|
||||||
|
:raises LicenseError: 未找到许可证 / 验证失败 / 已过期
|
||||||
|
"""
|
||||||
|
global _cached_payload
|
||||||
|
|
||||||
|
raw: str | None = os.getenv("LICENSE_KEY")
|
||||||
|
|
||||||
|
if not raw:
|
||||||
|
key_path = Path("license.key")
|
||||||
|
if key_path.is_file():
|
||||||
|
async with aiofiles.open(key_path, 'r') as f:
|
||||||
|
raw = (await f.read()).strip()
|
||||||
|
|
||||||
|
if not raw:
|
||||||
|
raise LicenseError("未找到许可证:请设置 LICENSE_KEY 环境变量或提供 license.key 文件")
|
||||||
|
|
||||||
|
data = verify_license(raw)
|
||||||
|
_cached_payload = LicensePayload.model_validate(data)
|
||||||
|
return _cached_payload
|
||||||
|
|
||||||
|
|
||||||
|
def get_cached_license() -> LicensePayload | None:
|
||||||
|
"""获取已缓存的许可证载荷(未加载时返回 None)。"""
|
||||||
|
return _cached_payload
|
||||||
5
ee/sqlmodels/__init__.py
Normal file
5
ee/sqlmodels/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""
|
||||||
|
EE 版本数据库模型
|
||||||
|
|
||||||
|
后续 Pro 功能的 SQLModel 定义位置。
|
||||||
|
"""
|
||||||
31
main.py
31
main.py
@@ -5,10 +5,7 @@ from fastapi import FastAPI, Request
|
|||||||
from loguru import logger as l
|
from loguru import logger as l
|
||||||
|
|
||||||
from routers import router
|
from routers import router
|
||||||
from routers.dav import dav_app
|
|
||||||
from routers.dav.provider import EventLoopRef
|
|
||||||
from service.redis import RedisManager
|
from service.redis import RedisManager
|
||||||
from service.storage import S3StorageService
|
|
||||||
from sqlmodels.database_connection import DatabaseManager
|
from sqlmodels.database_connection import DatabaseManager
|
||||||
from sqlmodels.migration import migration
|
from sqlmodels.migration import migration
|
||||||
from utils import JWT
|
from utils import JWT
|
||||||
@@ -17,26 +14,24 @@ from utils.http.http_exceptions import raise_internal_error
|
|||||||
from utils.lifespan import lifespan
|
from utils.lifespan import lifespan
|
||||||
|
|
||||||
# 尝试加载企业版功能
|
# 尝试加载企业版功能
|
||||||
_has_ee: bool = False
|
|
||||||
try:
|
try:
|
||||||
from ee import init_ee
|
from ee import init_ee
|
||||||
from ee.license import LicenseError
|
from ee.license import LicenseError
|
||||||
from ee.routers import ee_router
|
|
||||||
|
|
||||||
_has_ee = True
|
async def _init_ee_and_routes() -> None:
|
||||||
|
|
||||||
async def _init_ee() -> None:
|
|
||||||
"""启动时验证许可证,路由由 license_valid_required 依赖保护"""
|
|
||||||
try:
|
try:
|
||||||
await init_ee()
|
await init_ee()
|
||||||
except LicenseError as exc:
|
except LicenseError as exc:
|
||||||
l.critical(f"许可证验证失败: {exc}")
|
l.critical(f"许可证验证失败: {exc}")
|
||||||
raise SystemExit(1) from exc
|
raise SystemExit(1) from exc
|
||||||
|
|
||||||
lifespan.add_startup(_init_ee)
|
from ee.routers import ee_router
|
||||||
except ImportError as exc:
|
from routers.api.v1 import router as v1_router
|
||||||
ee_router = None
|
v1_router.include_router(ee_router)
|
||||||
l.info(f"以 Community 版本运行 (原因: {exc})")
|
|
||||||
|
lifespan.add_startup(_init_ee_and_routes)
|
||||||
|
except ImportError:
|
||||||
|
l.info("以 Community 版本运行")
|
||||||
|
|
||||||
STATICS_DIR: Path = (Path(__file__).parent / "statics").resolve()
|
STATICS_DIR: Path = (Path(__file__).parent / "statics").resolve()
|
||||||
"""前端静态文件目录(由 Docker 构建时复制)"""
|
"""前端静态文件目录(由 Docker 构建时复制)"""
|
||||||
@@ -45,18 +40,13 @@ async def _init_db() -> None:
|
|||||||
"""初始化数据库连接引擎"""
|
"""初始化数据库连接引擎"""
|
||||||
await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug)
|
await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug)
|
||||||
|
|
||||||
# 捕获事件循环引用(供 WSGI 线程桥接使用)
|
|
||||||
lifespan.add_startup(EventLoopRef.capture)
|
|
||||||
|
|
||||||
# 添加初始化数据库启动项
|
# 添加初始化数据库启动项
|
||||||
lifespan.add_startup(_init_db)
|
lifespan.add_startup(_init_db)
|
||||||
lifespan.add_startup(migration)
|
lifespan.add_startup(migration)
|
||||||
lifespan.add_startup(JWT.load_secret_key)
|
lifespan.add_startup(JWT.load_secret_key)
|
||||||
lifespan.add_startup(RedisManager.connect)
|
lifespan.add_startup(RedisManager.connect)
|
||||||
lifespan.add_startup(S3StorageService.initialize_session)
|
|
||||||
|
|
||||||
# 添加关闭项
|
# 添加关闭项
|
||||||
lifespan.add_shutdown(S3StorageService.close_session)
|
|
||||||
lifespan.add_shutdown(DatabaseManager.close)
|
lifespan.add_shutdown(DatabaseManager.close)
|
||||||
lifespan.add_shutdown(RedisManager.disconnect)
|
lifespan.add_shutdown(RedisManager.disconnect)
|
||||||
|
|
||||||
@@ -97,11 +87,6 @@ async def handle_unexpected_exceptions(
|
|||||||
|
|
||||||
# 挂载路由
|
# 挂载路由
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
if _has_ee:
|
|
||||||
app.include_router(ee_router, prefix="/api/v1")
|
|
||||||
|
|
||||||
# 挂载 WebDAV 协议端点(优先于 SPA catch-all)
|
|
||||||
app.mount("/dav", dav_app)
|
|
||||||
|
|
||||||
# 挂载前端静态文件(仅当 statics/ 目录存在时,即 Docker 部署环境)
|
# 挂载前端静态文件(仅当 statics/ 目录存在时,即 Docker 部署环境)
|
||||||
if STATICS_DIR.is_dir():
|
if STATICS_DIR.is_dir():
|
||||||
|
|||||||
@@ -33,8 +33,6 @@ dependencies = [
|
|||||||
"uvicorn>=0.38.0",
|
"uvicorn>=0.38.0",
|
||||||
"webauthn>=2.7.0",
|
"webauthn>=2.7.0",
|
||||||
"whatthepatch>=1.0.6",
|
"whatthepatch>=1.0.6",
|
||||||
"wsgidav>=4.3.0",
|
|
||||||
"a2wsgi>=1.10.0",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from utils.conf import appmeta
|
|||||||
from .admin import admin_router
|
from .admin import admin_router
|
||||||
|
|
||||||
from .callback import callback_router
|
from .callback import callback_router
|
||||||
from .category import category_router
|
|
||||||
from .directory import directory_router
|
from .directory import directory_router
|
||||||
from .download import download_router
|
from .download import download_router
|
||||||
from .file import router as file_router
|
from .file import router as file_router
|
||||||
@@ -15,6 +14,7 @@ from .trash import trash_router
|
|||||||
from .site import site_router
|
from .site import site_router
|
||||||
from .slave import slave_router
|
from .slave import slave_router
|
||||||
from .user import user_router
|
from .user import user_router
|
||||||
|
from .vas import vas_router
|
||||||
from .webdav import webdav_router
|
from .webdav import webdav_router
|
||||||
|
|
||||||
router = APIRouter(prefix="/v1")
|
router = APIRouter(prefix="/v1")
|
||||||
@@ -24,7 +24,6 @@ router = APIRouter(prefix="/v1")
|
|||||||
if appmeta.mode == "master":
|
if appmeta.mode == "master":
|
||||||
router.include_router(admin_router)
|
router.include_router(admin_router)
|
||||||
router.include_router(callback_router)
|
router.include_router(callback_router)
|
||||||
router.include_router(category_router)
|
|
||||||
router.include_router(directory_router)
|
router.include_router(directory_router)
|
||||||
router.include_router(download_router)
|
router.include_router(download_router)
|
||||||
router.include_router(file_router)
|
router.include_router(file_router)
|
||||||
@@ -33,6 +32,7 @@ if appmeta.mode == "master":
|
|||||||
router.include_router(site_router)
|
router.include_router(site_router)
|
||||||
router.include_router(trash_router)
|
router.include_router(trash_router)
|
||||||
router.include_router(user_router)
|
router.include_router(user_router)
|
||||||
|
router.include_router(vas_router)
|
||||||
router.include_router(webdav_router)
|
router.include_router(webdav_router)
|
||||||
elif appmeta.mode == "slave":
|
elif appmeta.mode == "slave":
|
||||||
router.include_router(slave_router)
|
router.include_router(slave_router)
|
||||||
|
|||||||
@@ -16,12 +16,6 @@ from sqlmodels.setting import (
|
|||||||
from sqlmodels.setting import SettingsType
|
from sqlmodels.setting import SettingsType
|
||||||
from utils import http_exceptions
|
from utils import http_exceptions
|
||||||
from utils.conf import appmeta
|
from utils.conf import appmeta
|
||||||
|
|
||||||
try:
|
|
||||||
from ee.service import get_cached_license
|
|
||||||
except ImportError:
|
|
||||||
get_cached_license = None
|
|
||||||
|
|
||||||
from .file import admin_file_router
|
from .file import admin_file_router
|
||||||
from .file_app import admin_file_app_router
|
from .file_app import admin_file_app_router
|
||||||
from .group import admin_group_router
|
from .group import admin_group_router
|
||||||
@@ -30,6 +24,7 @@ from .share import admin_share_router
|
|||||||
from .task import admin_task_router
|
from .task import admin_task_router
|
||||||
from .user import admin_user_router
|
from .user import admin_user_router
|
||||||
from .theme import admin_theme_router
|
from .theme import admin_theme_router
|
||||||
|
from .vas import admin_vas_router
|
||||||
|
|
||||||
|
|
||||||
class Aria2TestRequest(SQLModelBase):
|
class Aria2TestRequest(SQLModelBase):
|
||||||
@@ -55,6 +50,7 @@ admin_router.include_router(admin_policy_router)
|
|||||||
admin_router.include_router(admin_share_router)
|
admin_router.include_router(admin_share_router)
|
||||||
admin_router.include_router(admin_task_router)
|
admin_router.include_router(admin_task_router)
|
||||||
admin_router.include_router(admin_theme_router)
|
admin_router.include_router(admin_theme_router)
|
||||||
|
admin_router.include_router(admin_vas_router)
|
||||||
|
|
||||||
# 离线下载 /api/admin/aria2
|
# 离线下载 /api/admin/aria2
|
||||||
admin_aria2_router = APIRouter(
|
admin_aria2_router = APIRouter(
|
||||||
@@ -163,19 +159,9 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
|
|||||||
if site_url_setting and site_url_setting.value:
|
if site_url_setting and site_url_setting.value:
|
||||||
site_urls.append(site_url_setting.value)
|
site_urls.append(site_url_setting.value)
|
||||||
|
|
||||||
# 许可证信息(Pro 版本从缓存读取,CE 版本永不过期)
|
# 许可证信息(从设置读取或使用默认值)
|
||||||
if appmeta.IsPro and get_cached_license:
|
|
||||||
payload = get_cached_license()
|
|
||||||
license_info = LicenseInfo(
|
license_info = LicenseInfo(
|
||||||
expired_at=payload.expires_at,
|
expired_at=now + timedelta(days=365),
|
||||||
signed_at=payload.issued_at,
|
|
||||||
root_domains=[],
|
|
||||||
domains=[payload.domain],
|
|
||||||
vol_domains=[],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
license_info = LicenseInfo(
|
|
||||||
expired_at=datetime.max,
|
|
||||||
signed_at=now,
|
signed_at=now,
|
||||||
root_domains=[],
|
root_domains=[],
|
||||||
domains=[],
|
domains=[],
|
||||||
@@ -239,11 +225,11 @@ async def router_admin_update_settings(
|
|||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
existing.value = item.value
|
existing.value = item.value
|
||||||
existing = await existing.save(session)
|
await existing.save(session)
|
||||||
updated_count += 1
|
updated_count += 1
|
||||||
else:
|
else:
|
||||||
new_setting = Setting(type=item.type, name=item.name, value=item.value)
|
new_setting = Setting(type=item.type, name=item.name, value=item.value)
|
||||||
new_setting = await new_setting.save(session)
|
await new_setting.save(session)
|
||||||
created_count += 1
|
created_count += 1
|
||||||
|
|
||||||
l.info(f"管理员更新了 {updated_count} 个设置项,新建了 {created_count} 个设置项")
|
l.info(f"管理员更新了 {updated_count} 个设置项,新建了 {created_count} 个设置项")
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ async def _set_ban_recursive(
|
|||||||
obj.banned_by = None
|
obj.banned_by = None
|
||||||
obj.ban_reason = None
|
obj.ban_reason = None
|
||||||
|
|
||||||
obj = await obj.save(session)
|
await obj.save(session)
|
||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
@@ -131,7 +131,9 @@ async def router_admin_preview_file(
|
|||||||
:param file_id: 文件UUID
|
:param file_id: 文件UUID
|
||||||
:return: 文件内容
|
:return: 文件内容
|
||||||
"""
|
"""
|
||||||
file_obj = await Object.get_exist_one(session, file_id)
|
file_obj = await Object.get(session, Object.id == file_id)
|
||||||
|
if not file_obj:
|
||||||
|
raise HTTPException(status_code=404, detail="文件不存在")
|
||||||
|
|
||||||
if not file_obj.is_file:
|
if not file_obj.is_file:
|
||||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||||
@@ -180,7 +182,9 @@ async def router_admin_ban_file(
|
|||||||
:param claims: 当前管理员 JWT claims
|
:param claims: 当前管理员 JWT claims
|
||||||
:return: 封禁结果
|
:return: 封禁结果
|
||||||
"""
|
"""
|
||||||
file_obj = await Object.get_exist_one(session, file_id)
|
file_obj = await Object.get(session, Object.id == file_id)
|
||||||
|
if not file_obj:
|
||||||
|
raise HTTPException(status_code=404, detail="文件不存在")
|
||||||
|
|
||||||
count = await _set_ban_recursive(session, file_obj, request.ban, claims.sub, request.reason)
|
count = await _set_ban_recursive(session, file_obj, request.ban, claims.sub, request.reason)
|
||||||
|
|
||||||
@@ -208,7 +212,9 @@ async def router_admin_delete_file(
|
|||||||
:param delete_physical: 是否同时删除物理文件
|
:param delete_physical: 是否同时删除物理文件
|
||||||
:return: 删除结果
|
:return: 删除结果
|
||||||
"""
|
"""
|
||||||
file_obj = await Object.get_exist_one(session, file_id)
|
file_obj = await Object.get(session, Object.id == file_id)
|
||||||
|
if not file_obj:
|
||||||
|
raise HTTPException(status_code=404, detail="文件不存在")
|
||||||
|
|
||||||
if not file_obj.is_file:
|
if not file_obj.is_file:
|
||||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||||
|
|||||||
@@ -1,18 +1,16 @@
|
|||||||
"""
|
"""
|
||||||
管理员文件应用管理端点
|
管理员文件应用管理端点
|
||||||
|
|
||||||
提供文件查看器应用的 CRUD、扩展名管理、用户组权限管理和 WOPI Discovery。
|
提供文件查看器应用的 CRUD、扩展名管理和用户组权限管理。
|
||||||
"""
|
"""
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
from fastapi import APIRouter, Depends, status
|
from fastapi import APIRouter, Depends, status
|
||||||
from loguru import logger as l
|
from loguru import logger as l
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from middleware.auth import admin_required
|
from middleware.auth import admin_required
|
||||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
from service.wopi import parse_wopi_discovery_xml
|
|
||||||
from sqlmodels import (
|
from sqlmodels import (
|
||||||
FileApp,
|
FileApp,
|
||||||
FileAppCreateRequest,
|
FileAppCreateRequest,
|
||||||
@@ -23,10 +21,7 @@ from sqlmodels import (
|
|||||||
FileAppUpdateRequest,
|
FileAppUpdateRequest,
|
||||||
ExtensionUpdateRequest,
|
ExtensionUpdateRequest,
|
||||||
GroupAccessUpdateRequest,
|
GroupAccessUpdateRequest,
|
||||||
WopiDiscoveredExtension,
|
|
||||||
WopiDiscoveryResponse,
|
|
||||||
)
|
)
|
||||||
from sqlmodels.file_app import FileAppType
|
|
||||||
from utils import http_exceptions
|
from utils import http_exceptions
|
||||||
|
|
||||||
admin_file_app_router = APIRouter(
|
admin_file_app_router = APIRouter(
|
||||||
@@ -128,7 +123,6 @@ async def create_file_app(
|
|||||||
group_links.append(link)
|
group_links.append(link)
|
||||||
if group_links:
|
if group_links:
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(app)
|
|
||||||
|
|
||||||
l.info(f"创建文件应用: {app.name} ({app.app_key})")
|
l.info(f"创建文件应用: {app.name} ({app.app_key})")
|
||||||
|
|
||||||
@@ -151,7 +145,9 @@ async def get_file_app(
|
|||||||
错误处理:
|
错误处理:
|
||||||
- 404: 应用不存在
|
- 404: 应用不存在
|
||||||
"""
|
"""
|
||||||
app = await FileApp.get_exist_one(session, app_id)
|
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||||
|
if not app:
|
||||||
|
http_exceptions.raise_not_found("应用不存在")
|
||||||
|
|
||||||
extensions = await FileAppExtension.get(
|
extensions = await FileAppExtension.get(
|
||||||
session,
|
session,
|
||||||
@@ -184,7 +180,9 @@ async def update_file_app(
|
|||||||
- 404: 应用不存在
|
- 404: 应用不存在
|
||||||
- 409: 新 app_key 已被其他应用使用
|
- 409: 新 app_key 已被其他应用使用
|
||||||
"""
|
"""
|
||||||
app = await FileApp.get_exist_one(session, app_id)
|
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||||
|
if not app:
|
||||||
|
http_exceptions.raise_not_found("应用不存在")
|
||||||
|
|
||||||
# 检查 app_key 唯一性
|
# 检查 app_key 唯一性
|
||||||
if request.app_key is not None and request.app_key != app.app_key:
|
if request.app_key is not None and request.app_key != app.app_key:
|
||||||
@@ -231,7 +229,9 @@ async def delete_file_app(
|
|||||||
错误处理:
|
错误处理:
|
||||||
- 404: 应用不存在
|
- 404: 应用不存在
|
||||||
"""
|
"""
|
||||||
app = await FileApp.get_exist_one(session, app_id)
|
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||||
|
if not app:
|
||||||
|
http_exceptions.raise_not_found("应用不存在")
|
||||||
|
|
||||||
app_name = app.app_key
|
app_name = app.app_key
|
||||||
await FileApp.delete(session, app)
|
await FileApp.delete(session, app)
|
||||||
@@ -257,24 +257,20 @@ async def update_extensions(
|
|||||||
错误处理:
|
错误处理:
|
||||||
- 404: 应用不存在
|
- 404: 应用不存在
|
||||||
"""
|
"""
|
||||||
app = await FileApp.get_exist_one(session, app_id)
|
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||||
|
if not app:
|
||||||
|
http_exceptions.raise_not_found("应用不存在")
|
||||||
|
|
||||||
# 保留旧扩展名的 wopi_action_url(Discovery 填充的值)
|
# 删除旧的扩展名
|
||||||
old_extensions: list[FileAppExtension] = await FileAppExtension.get(
|
old_extensions: list[FileAppExtension] = await FileAppExtension.get(
|
||||||
session,
|
session,
|
||||||
FileAppExtension.app_id == app_id,
|
FileAppExtension.app_id == app_id,
|
||||||
fetch_mode="all",
|
fetch_mode="all",
|
||||||
)
|
)
|
||||||
old_url_map: dict[str, str] = {
|
|
||||||
ext.extension: ext.wopi_action_url
|
|
||||||
for ext in old_extensions
|
|
||||||
if ext.wopi_action_url
|
|
||||||
}
|
|
||||||
for old_ext in old_extensions:
|
for old_ext in old_extensions:
|
||||||
await FileAppExtension.delete(session, old_ext, commit=False)
|
await FileAppExtension.delete(session, old_ext, commit=False)
|
||||||
await session.flush()
|
|
||||||
|
|
||||||
# 创建新的扩展名(保留已有的 wopi_action_url)
|
# 创建新的扩展名
|
||||||
new_extensions: list[FileAppExtension] = []
|
new_extensions: list[FileAppExtension] = []
|
||||||
for i, ext in enumerate(request.extensions):
|
for i, ext in enumerate(request.extensions):
|
||||||
normalized = ext.lower().strip().lstrip('.')
|
normalized = ext.lower().strip().lstrip('.')
|
||||||
@@ -282,14 +278,12 @@ async def update_extensions(
|
|||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
extension=normalized,
|
extension=normalized,
|
||||||
priority=i,
|
priority=i,
|
||||||
wopi_action_url=old_url_map.get(normalized),
|
|
||||||
)
|
)
|
||||||
session.add(ext_record)
|
session.add(ext_record)
|
||||||
new_extensions.append(ext_record)
|
new_extensions.append(ext_record)
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
# refresh commit 后过期的对象
|
# refresh 新创建的记录
|
||||||
await session.refresh(app)
|
|
||||||
for ext_record in new_extensions:
|
for ext_record in new_extensions:
|
||||||
await session.refresh(ext_record)
|
await session.refresh(ext_record)
|
||||||
|
|
||||||
@@ -322,7 +316,9 @@ async def update_group_access(
|
|||||||
错误处理:
|
错误处理:
|
||||||
- 404: 应用不存在
|
- 404: 应用不存在
|
||||||
"""
|
"""
|
||||||
app = await FileApp.get_exist_one(session, app_id)
|
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||||
|
if not app:
|
||||||
|
http_exceptions.raise_not_found("应用不存在")
|
||||||
|
|
||||||
# 删除旧的用户组关联
|
# 删除旧的用户组关联
|
||||||
old_links_result = await session.exec(
|
old_links_result = await session.exec(
|
||||||
@@ -340,7 +336,6 @@ async def update_group_access(
|
|||||||
new_links.append(link)
|
new_links.append(link)
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(app)
|
|
||||||
|
|
||||||
extensions = await FileAppExtension.get(
|
extensions = await FileAppExtension.get(
|
||||||
session,
|
session,
|
||||||
@@ -351,100 +346,3 @@ async def update_group_access(
|
|||||||
l.info(f"更新文件应用 {app.app_key} 的用户组权限: {request.group_ids}")
|
l.info(f"更新文件应用 {app.app_key} 的用户组权限: {request.group_ids}")
|
||||||
|
|
||||||
return FileAppResponse.from_app(app, extensions, new_links)
|
return FileAppResponse.from_app(app, extensions, new_links)
|
||||||
|
|
||||||
|
|
||||||
@admin_file_app_router.post(
|
|
||||||
path='/{app_id}/discover',
|
|
||||||
summary='执行 WOPI Discovery',
|
|
||||||
)
|
|
||||||
async def discover_wopi(
|
|
||||||
session: SessionDep,
|
|
||||||
app_id: UUID,
|
|
||||||
) -> WopiDiscoveryResponse:
|
|
||||||
"""
|
|
||||||
从 WOPI 服务端获取 Discovery XML 并自动配置扩展名和 URL 模板。
|
|
||||||
|
|
||||||
流程:
|
|
||||||
1. 验证 FileApp 存在且为 WOPI 类型
|
|
||||||
2. 使用 FileApp.wopi_discovery_url 获取 Discovery XML
|
|
||||||
3. 解析 XML,提取扩展名和动作 URL
|
|
||||||
4. 全量替换 FileAppExtension 记录(带 wopi_action_url)
|
|
||||||
|
|
||||||
认证:管理员权限
|
|
||||||
|
|
||||||
错误处理:
|
|
||||||
- 404: 应用不存在
|
|
||||||
- 400: 非 WOPI 类型 / discovery URL 未配置 / XML 解析失败
|
|
||||||
- 502: WOPI 服务端不可达或返回无效响应
|
|
||||||
"""
|
|
||||||
app = await FileApp.get_exist_one(session, app_id)
|
|
||||||
|
|
||||||
if app.type != FileAppType.WOPI:
|
|
||||||
http_exceptions.raise_bad_request("仅 WOPI 类型应用支持自动发现")
|
|
||||||
|
|
||||||
if not app.wopi_discovery_url:
|
|
||||||
http_exceptions.raise_bad_request("未配置 WOPI Discovery URL")
|
|
||||||
|
|
||||||
# commit 后对象会过期,先保存需要的值
|
|
||||||
discovery_url = app.wopi_discovery_url
|
|
||||||
app_key = app.app_key
|
|
||||||
|
|
||||||
# 获取 Discovery XML
|
|
||||||
try:
|
|
||||||
async with aiohttp.ClientSession() as client:
|
|
||||||
async with client.get(
|
|
||||||
discovery_url,
|
|
||||||
timeout=aiohttp.ClientTimeout(total=15),
|
|
||||||
) as resp:
|
|
||||||
if resp.status != 200:
|
|
||||||
http_exceptions.raise_bad_gateway(
|
|
||||||
f"WOPI 服务端返回 HTTP {resp.status}"
|
|
||||||
)
|
|
||||||
xml_content = await resp.text()
|
|
||||||
except aiohttp.ClientError as e:
|
|
||||||
http_exceptions.raise_bad_gateway(f"无法连接 WOPI 服务端: {e}")
|
|
||||||
|
|
||||||
# 解析 XML
|
|
||||||
try:
|
|
||||||
action_urls, app_names = parse_wopi_discovery_xml(xml_content)
|
|
||||||
except ValueError as e:
|
|
||||||
http_exceptions.raise_bad_request(str(e))
|
|
||||||
|
|
||||||
if not action_urls:
|
|
||||||
return WopiDiscoveryResponse(app_names=app_names)
|
|
||||||
|
|
||||||
# 全量替换扩展名
|
|
||||||
old_extensions: list[FileAppExtension] = await FileAppExtension.get(
|
|
||||||
session,
|
|
||||||
FileAppExtension.app_id == app_id,
|
|
||||||
fetch_mode="all",
|
|
||||||
)
|
|
||||||
for old_ext in old_extensions:
|
|
||||||
await FileAppExtension.delete(session, old_ext, commit=False)
|
|
||||||
await session.flush()
|
|
||||||
|
|
||||||
new_extensions: list[FileAppExtension] = []
|
|
||||||
discovered: list[WopiDiscoveredExtension] = []
|
|
||||||
for i, (ext, action_url) in enumerate(sorted(action_urls.items())):
|
|
||||||
ext_record = FileAppExtension(
|
|
||||||
app_id=app_id,
|
|
||||||
extension=ext,
|
|
||||||
priority=i,
|
|
||||||
wopi_action_url=action_url,
|
|
||||||
)
|
|
||||||
session.add(ext_record)
|
|
||||||
new_extensions.append(ext_record)
|
|
||||||
discovered.append(WopiDiscoveredExtension(extension=ext, action_url=action_url))
|
|
||||||
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
l.info(
|
|
||||||
f"WOPI Discovery 完成: 应用 {app_key}, "
|
|
||||||
f"发现 {len(discovered)} 个扩展名"
|
|
||||||
)
|
|
||||||
|
|
||||||
return WopiDiscoveryResponse(
|
|
||||||
discovered_extensions=discovered,
|
|
||||||
app_names=app_names,
|
|
||||||
applied_count=len(discovered),
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -63,7 +63,10 @@ async def router_admin_get_group(
|
|||||||
:param group_id: 用户组UUID
|
:param group_id: 用户组UUID
|
||||||
:return: 用户组详情
|
:return: 用户组详情
|
||||||
"""
|
"""
|
||||||
group = await Group.get_exist_one(session, group_id, load=[Group.options, Group.policies])
|
group = await Group.get(session, Group.id == group_id, load=[Group.options, Group.policies])
|
||||||
|
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||||
|
|
||||||
# 直接访问已加载的关系,无需额外查询
|
# 直接访问已加载的关系,无需额外查询
|
||||||
policies = group.policies
|
policies = group.policies
|
||||||
@@ -91,7 +94,9 @@ async def router_admin_get_group_members(
|
|||||||
:return: 分页成员列表
|
:return: 分页成员列表
|
||||||
"""
|
"""
|
||||||
# 验证组存在
|
# 验证组存在
|
||||||
await Group.get_exist_one(session, group_id)
|
group = await Group.get(session, Group.id == group_id)
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||||
|
|
||||||
result = await User.get_with_count(session, User.group_id == group_id, table_view=table_view)
|
result = await User.get_with_count(session, User.group_id == group_id, table_view=table_view)
|
||||||
|
|
||||||
@@ -133,11 +138,10 @@ async def router_admin_create_group(
|
|||||||
speed_limit=request.speed_limit,
|
speed_limit=request.speed_limit,
|
||||||
)
|
)
|
||||||
group = await group.save(session)
|
group = await group.save(session)
|
||||||
group_id_val: UUID = group.id
|
|
||||||
|
|
||||||
# 创建选项
|
# 创建选项
|
||||||
options = GroupOptions(
|
options = GroupOptions(
|
||||||
group_id=group_id_val,
|
group_id=group.id,
|
||||||
share_download=request.share_download,
|
share_download=request.share_download,
|
||||||
share_free=request.share_free,
|
share_free=request.share_free,
|
||||||
relocate=request.relocate,
|
relocate=request.relocate,
|
||||||
@@ -150,11 +154,11 @@ async def router_admin_create_group(
|
|||||||
aria2=request.aria2,
|
aria2=request.aria2,
|
||||||
redirected_source=request.redirected_source,
|
redirected_source=request.redirected_source,
|
||||||
)
|
)
|
||||||
options = await options.save(session)
|
await options.save(session)
|
||||||
|
|
||||||
# 关联存储策略
|
# 关联存储策略
|
||||||
for policy_id in request.policy_ids:
|
for policy_id in request.policy_ids:
|
||||||
link = GroupPolicyLink(group_id=group_id_val, policy_id=policy_id)
|
link = GroupPolicyLink(group_id=group.id, policy_id=policy_id)
|
||||||
session.add(link)
|
session.add(link)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
@@ -181,7 +185,9 @@ async def router_admin_update_group(
|
|||||||
:param request: 更新请求
|
:param request: 更新请求
|
||||||
:return: 更新结果
|
:return: 更新结果
|
||||||
"""
|
"""
|
||||||
group = await Group.get_exist_one(session, group_id, load=Group.options)
|
group = await Group.get(session, Group.id == group_id, load=Group.options)
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||||
|
|
||||||
# 检查名称唯一性(如果要更新名称)
|
# 检查名称唯一性(如果要更新名称)
|
||||||
if request.name and request.name != group.name:
|
if request.name and request.name != group.name:
|
||||||
@@ -211,7 +217,7 @@ async def router_admin_update_group(
|
|||||||
if options_data:
|
if options_data:
|
||||||
for key, value in options_data.items():
|
for key, value in options_data.items():
|
||||||
setattr(group.options, key, value)
|
setattr(group.options, key, value)
|
||||||
group.options = await group.options.save(session)
|
await group.options.save(session)
|
||||||
|
|
||||||
# 更新策略关联
|
# 更新策略关联
|
||||||
if request.policy_ids is not None:
|
if request.policy_ids is not None:
|
||||||
@@ -249,7 +255,9 @@ async def router_admin_delete_group(
|
|||||||
:param group_id: 用户组UUID
|
:param group_id: 用户组UUID
|
||||||
:return: 删除结果
|
:return: 删除结果
|
||||||
"""
|
"""
|
||||||
group = await Group.get_exist_one(session, group_id)
|
group = await Group.get(session, Group.id == group_id)
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||||
|
|
||||||
# 检查是否有用户属于该组
|
# 检查是否有用户属于该组
|
||||||
user_count = await User.count(session, User.group_id == group_id)
|
user_count = await User.count(session, User.group_id == group_id)
|
||||||
|
|||||||
@@ -8,11 +8,11 @@ from sqlmodel import Field
|
|||||||
from middleware.auth import admin_required
|
from middleware.auth import admin_required
|
||||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
from sqlmodels import (
|
from sqlmodels import (
|
||||||
Policy, PolicyCreateRequest, PolicyOptions, PolicyType, PolicySummary,
|
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
|
||||||
PolicyUpdateRequest, ResponseBase, ListResponse, Object,
|
ListResponse, Object,
|
||||||
)
|
)
|
||||||
from sqlmodel_ext import SQLModelBase
|
from sqlmodel_ext import SQLModelBase
|
||||||
from service.storage import DirectoryCreationError, LocalStorageService, S3StorageService
|
from service.storage import DirectoryCreationError, LocalStorageService
|
||||||
|
|
||||||
admin_policy_router = APIRouter(
|
admin_policy_router = APIRouter(
|
||||||
prefix='/policy',
|
prefix='/policy',
|
||||||
@@ -67,12 +67,6 @@ class PolicyDetailResponse(SQLModelBase):
|
|||||||
base_url: str | None
|
base_url: str | None
|
||||||
"""基础URL"""
|
"""基础URL"""
|
||||||
|
|
||||||
access_key: str | None
|
|
||||||
"""Access Key"""
|
|
||||||
|
|
||||||
secret_key: str | None
|
|
||||||
"""Secret Key"""
|
|
||||||
|
|
||||||
max_size: int
|
max_size: int
|
||||||
"""最大文件尺寸"""
|
"""最大文件尺寸"""
|
||||||
|
|
||||||
@@ -113,45 +107,9 @@ class PolicyTestSlaveRequest(SQLModelBase):
|
|||||||
secret: str
|
secret: str
|
||||||
"""从机通信密钥"""
|
"""从机通信密钥"""
|
||||||
|
|
||||||
class PolicyTestS3Request(SQLModelBase):
|
class PolicyCreateRequest(PolicyBase):
|
||||||
"""测试 S3 连接请求 DTO"""
|
"""创建存储策略请求 DTO,继承 PolicyBase 中的所有字段"""
|
||||||
|
pass
|
||||||
server: str = Field(max_length=255)
|
|
||||||
"""S3 端点地址"""
|
|
||||||
|
|
||||||
bucket_name: str = Field(max_length=255)
|
|
||||||
"""存储桶名称"""
|
|
||||||
|
|
||||||
access_key: str
|
|
||||||
"""Access Key"""
|
|
||||||
|
|
||||||
secret_key: str
|
|
||||||
"""Secret Key"""
|
|
||||||
|
|
||||||
s3_region: str = Field(default='us-east-1', max_length=64)
|
|
||||||
"""S3 区域"""
|
|
||||||
|
|
||||||
s3_path_style: bool = False
|
|
||||||
"""是否使用路径风格"""
|
|
||||||
|
|
||||||
|
|
||||||
class PolicyTestS3Response(SQLModelBase):
|
|
||||||
"""S3 连接测试响应"""
|
|
||||||
|
|
||||||
is_connected: bool
|
|
||||||
"""连接是否成功"""
|
|
||||||
|
|
||||||
message: str
|
|
||||||
"""测试结果消息"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== Options 字段集合(用于分离 Policy 与 Options 字段) ====================
|
|
||||||
|
|
||||||
_OPTIONS_FIELDS: set[str] = {
|
|
||||||
'token', 'file_type', 'mimetype', 'od_redirect',
|
|
||||||
'chunk_size', 's3_path_style', 's3_region',
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@admin_policy_router.get(
|
@admin_policy_router.get(
|
||||||
path='/list',
|
path='/list',
|
||||||
@@ -319,20 +277,7 @@ async def router_policy_add_policy(
|
|||||||
raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}")
|
raise HTTPException(status_code=500, detail=f"创建存储目录失败: {e}")
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
policy = await policy.save(session)
|
await policy.save(session)
|
||||||
|
|
||||||
# 创建策略选项
|
|
||||||
options = PolicyOptions(
|
|
||||||
policy_id=policy.id,
|
|
||||||
token=request.token,
|
|
||||||
file_type=request.file_type,
|
|
||||||
mimetype=request.mimetype,
|
|
||||||
od_redirect=request.od_redirect,
|
|
||||||
chunk_size=request.chunk_size,
|
|
||||||
s3_path_style=request.s3_path_style,
|
|
||||||
s3_region=request.s3_region,
|
|
||||||
)
|
|
||||||
options = await options.save(session)
|
|
||||||
|
|
||||||
@admin_policy_router.post(
|
@admin_policy_router.post(
|
||||||
path='/cors',
|
path='/cors',
|
||||||
@@ -383,7 +328,9 @@ async def router_policy_onddrive_oauth(
|
|||||||
:param policy_id: 存储策略UUID
|
:param policy_id: 存储策略UUID
|
||||||
:return: OAuth URL
|
:return: OAuth URL
|
||||||
"""
|
"""
|
||||||
policy = await Policy.get_exist_one(session, policy_id)
|
policy = await Policy.get(session, Policy.id == policy_id)
|
||||||
|
if not policy:
|
||||||
|
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||||
|
|
||||||
# TODO: 实现OneDrive OAuth
|
# TODO: 实现OneDrive OAuth
|
||||||
raise HTTPException(status_code=501, detail="OneDrive OAuth暂未实现")
|
raise HTTPException(status_code=501, detail="OneDrive OAuth暂未实现")
|
||||||
@@ -406,7 +353,9 @@ async def router_policy_get_policy(
|
|||||||
:param policy_id: 存储策略UUID
|
:param policy_id: 存储策略UUID
|
||||||
:return: 策略详情
|
:return: 策略详情
|
||||||
"""
|
"""
|
||||||
policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)
|
policy = await Policy.get(session, Policy.id == policy_id, load=Policy.options)
|
||||||
|
if not policy:
|
||||||
|
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||||
|
|
||||||
# 获取使用此策略的用户组
|
# 获取使用此策略的用户组
|
||||||
groups = await policy.awaitable_attrs.groups
|
groups = await policy.awaitable_attrs.groups
|
||||||
@@ -422,8 +371,6 @@ async def router_policy_get_policy(
|
|||||||
bucket_name=policy.bucket_name,
|
bucket_name=policy.bucket_name,
|
||||||
is_private=policy.is_private,
|
is_private=policy.is_private,
|
||||||
base_url=policy.base_url,
|
base_url=policy.base_url,
|
||||||
access_key=policy.access_key,
|
|
||||||
secret_key=policy.secret_key,
|
|
||||||
max_size=policy.max_size,
|
max_size=policy.max_size,
|
||||||
auto_rename=policy.auto_rename,
|
auto_rename=policy.auto_rename,
|
||||||
dir_name_rule=policy.dir_name_rule,
|
dir_name_rule=policy.dir_name_rule,
|
||||||
@@ -455,7 +402,9 @@ async def router_policy_delete_policy(
|
|||||||
:param policy_id: 存储策略UUID
|
:param policy_id: 存储策略UUID
|
||||||
:return: 删除结果
|
:return: 删除结果
|
||||||
"""
|
"""
|
||||||
policy = await Policy.get_exist_one(session, policy_id)
|
policy = await Policy.get(session, Policy.id == policy_id)
|
||||||
|
if not policy:
|
||||||
|
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||||
|
|
||||||
# 检查是否有文件使用此策略
|
# 检查是否有文件使用此策略
|
||||||
file_count = await Object.count(session, Object.policy_id == policy_id)
|
file_count = await Object.count(session, Object.policy_id == policy_id)
|
||||||
@@ -469,105 +418,3 @@ async def router_policy_delete_policy(
|
|||||||
await Policy.delete(session, policy)
|
await Policy.delete(session, policy)
|
||||||
|
|
||||||
l.info(f"管理员删除了存储策略: {policy_name}")
|
l.info(f"管理员删除了存储策略: {policy_name}")
|
||||||
|
|
||||||
|
|
||||||
@admin_policy_router.patch(
|
|
||||||
path='/{policy_id}',
|
|
||||||
summary='更新存储策略',
|
|
||||||
description='更新存储策略配置。策略类型创建后不可更改。',
|
|
||||||
dependencies=[Depends(admin_required)],
|
|
||||||
status_code=204,
|
|
||||||
)
|
|
||||||
async def router_policy_update_policy(
|
|
||||||
session: SessionDep,
|
|
||||||
policy_id: UUID,
|
|
||||||
request: PolicyUpdateRequest,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
更新存储策略端点
|
|
||||||
|
|
||||||
功能:
|
|
||||||
- 更新策略基础字段和扩展选项
|
|
||||||
- 策略类型(type)不可更改
|
|
||||||
|
|
||||||
认证:
|
|
||||||
- 需要管理员权限
|
|
||||||
|
|
||||||
:param session: 数据库会话
|
|
||||||
:param policy_id: 存储策略UUID
|
|
||||||
:param request: 更新请求
|
|
||||||
"""
|
|
||||||
policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)
|
|
||||||
|
|
||||||
# 检查名称唯一性(如果要更新名称)
|
|
||||||
if request.name and request.name != policy.name:
|
|
||||||
existing = await Policy.get(session, Policy.name == request.name)
|
|
||||||
if existing:
|
|
||||||
raise HTTPException(status_code=409, detail="策略名称已存在")
|
|
||||||
|
|
||||||
# 分离 Policy 字段和 Options 字段
|
|
||||||
all_data = request.model_dump(exclude_unset=True)
|
|
||||||
policy_data = {k: v for k, v in all_data.items() if k not in _OPTIONS_FIELDS}
|
|
||||||
options_data = {k: v for k, v in all_data.items() if k in _OPTIONS_FIELDS}
|
|
||||||
|
|
||||||
# 更新 Policy 基础字段
|
|
||||||
if policy_data:
|
|
||||||
for key, value in policy_data.items():
|
|
||||||
setattr(policy, key, value)
|
|
||||||
policy = await policy.save(session)
|
|
||||||
|
|
||||||
# 更新或创建 PolicyOptions
|
|
||||||
if options_data:
|
|
||||||
if policy.options:
|
|
||||||
for key, value in options_data.items():
|
|
||||||
setattr(policy.options, key, value)
|
|
||||||
policy.options = await policy.options.save(session)
|
|
||||||
else:
|
|
||||||
options = PolicyOptions(policy_id=policy.id, **options_data)
|
|
||||||
options = await options.save(session)
|
|
||||||
|
|
||||||
l.info(f"管理员更新了存储策略: {policy_id}")
|
|
||||||
|
|
||||||
|
|
||||||
@admin_policy_router.post(
|
|
||||||
path='/test/s3',
|
|
||||||
summary='测试 S3 连接',
|
|
||||||
description='测试 S3 存储端点的连通性和凭据有效性。',
|
|
||||||
dependencies=[Depends(admin_required)],
|
|
||||||
)
|
|
||||||
async def router_policy_test_s3(
|
|
||||||
request: PolicyTestS3Request,
|
|
||||||
) -> PolicyTestS3Response:
|
|
||||||
"""
|
|
||||||
测试 S3 连接端点
|
|
||||||
|
|
||||||
通过向 S3 端点发送 HEAD Bucket 请求,验证凭据和网络连通性。
|
|
||||||
|
|
||||||
:param request: 测试请求
|
|
||||||
:return: 测试结果
|
|
||||||
"""
|
|
||||||
from service.storage import S3APIError
|
|
||||||
|
|
||||||
# 构造临时 Policy 对象用于创建 S3StorageService
|
|
||||||
temp_policy = Policy(
|
|
||||||
name="__test__",
|
|
||||||
type=PolicyType.S3,
|
|
||||||
server=request.server,
|
|
||||||
bucket_name=request.bucket_name,
|
|
||||||
access_key=request.access_key,
|
|
||||||
secret_key=request.secret_key,
|
|
||||||
)
|
|
||||||
s3_service = S3StorageService(
|
|
||||||
temp_policy,
|
|
||||||
region=request.s3_region,
|
|
||||||
is_path_style=request.s3_path_style,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 使用 file_exists 发送 HEAD 请求来验证连通性
|
|
||||||
await s3_service.file_exists("__connection_test__")
|
|
||||||
return PolicyTestS3Response(is_connected=True, message="连接成功")
|
|
||||||
except S3APIError as e:
|
|
||||||
return PolicyTestS3Response(is_connected=False, message=f"S3 API 错误: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
return PolicyTestS3Response(is_connected=False, message=f"连接失败: {e}")
|
|
||||||
@@ -155,7 +155,9 @@ async def router_admin_delete_share(
|
|||||||
:param share_id: 分享ID
|
:param share_id: 分享ID
|
||||||
:return: 删除结果
|
:return: 删除结果
|
||||||
"""
|
"""
|
||||||
share = await Share.get_exist_one(session, share_id)
|
share = await Share.get(session, Share.id == share_id)
|
||||||
|
if not share:
|
||||||
|
raise HTTPException(status_code=404, detail="分享不存在")
|
||||||
|
|
||||||
await Share.delete(session, share)
|
await Share.delete(session, share)
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from middleware.auth import admin_required
|
|||||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
from sqlmodels import (
|
from sqlmodels import (
|
||||||
ListResponse,
|
ListResponse,
|
||||||
Task, TaskSummary, TaskStatus, TaskType,
|
Task, TaskSummary,
|
||||||
)
|
)
|
||||||
from sqlmodel_ext import SQLModelBase
|
from sqlmodel_ext import SQLModelBase
|
||||||
|
|
||||||
@@ -19,10 +19,10 @@ class TaskDetailResponse(SQLModelBase):
|
|||||||
id: int
|
id: int
|
||||||
"""任务ID"""
|
"""任务ID"""
|
||||||
|
|
||||||
status: TaskStatus
|
status: int
|
||||||
"""任务状态"""
|
"""任务状态"""
|
||||||
|
|
||||||
type: TaskType
|
type: int
|
||||||
"""任务类型"""
|
"""任务类型"""
|
||||||
|
|
||||||
progress: int
|
progress: int
|
||||||
@@ -150,7 +150,9 @@ async def router_admin_delete_task(
|
|||||||
:param task_id: 任务ID
|
:param task_id: 任务ID
|
||||||
:return: 删除结果
|
:return: 删除结果
|
||||||
"""
|
"""
|
||||||
task = await Task.get_exist_one(session, task_id)
|
task = await Task.get(session, Task.id == task_id)
|
||||||
|
if not task:
|
||||||
|
raise HTTPException(status_code=404, detail="任务不存在")
|
||||||
|
|
||||||
await Task.delete(session, task)
|
await Task.delete(session, task)
|
||||||
|
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ async def router_admin_theme_create(
|
|||||||
name=request.name,
|
name=request.name,
|
||||||
**request.colors.model_dump(),
|
**request.colors.model_dump(),
|
||||||
)
|
)
|
||||||
preset = await preset.save(session)
|
await preset.save(session)
|
||||||
l.info(f"管理员创建了主题预设: {request.name}")
|
l.info(f"管理员创建了主题预设: {request.name}")
|
||||||
|
|
||||||
|
|
||||||
@@ -101,7 +101,11 @@ async def router_admin_theme_update(
|
|||||||
- 404: 预设不存在
|
- 404: 预设不存在
|
||||||
- 409: 名称已被其他预设使用
|
- 409: 名称已被其他预设使用
|
||||||
"""
|
"""
|
||||||
preset = await ThemePreset.get_exist_one(session, preset_id)
|
preset: ThemePreset | None = await ThemePreset.get(
|
||||||
|
session, ThemePreset.id == preset_id
|
||||||
|
)
|
||||||
|
if not preset:
|
||||||
|
http_exceptions.raise_not_found("主题预设不存在")
|
||||||
|
|
||||||
# 检查名称唯一性(排除自身)
|
# 检查名称唯一性(排除自身)
|
||||||
if request.name is not None and request.name != preset.name:
|
if request.name is not None and request.name != preset.name:
|
||||||
@@ -116,7 +120,7 @@ async def router_admin_theme_update(
|
|||||||
for key, value in color_data.items():
|
for key, value in color_data.items():
|
||||||
setattr(preset, key, value)
|
setattr(preset, key, value)
|
||||||
|
|
||||||
preset = await preset.save(session)
|
await preset.save(session)
|
||||||
l.info(f"管理员更新了主题预设: {preset.name}")
|
l.info(f"管理员更新了主题预设: {preset.name}")
|
||||||
|
|
||||||
|
|
||||||
@@ -143,7 +147,11 @@ async def router_admin_theme_delete(
|
|||||||
副作用:
|
副作用:
|
||||||
- 关联用户的 theme_preset_id 会被数据库 SET NULL
|
- 关联用户的 theme_preset_id 会被数据库 SET NULL
|
||||||
"""
|
"""
|
||||||
preset = await ThemePreset.get_exist_one(session, preset_id)
|
preset: ThemePreset | None = await ThemePreset.get(
|
||||||
|
session, ThemePreset.id == preset_id
|
||||||
|
)
|
||||||
|
if not preset:
|
||||||
|
http_exceptions.raise_not_found("主题预设不存在")
|
||||||
|
|
||||||
await preset.delete(session)
|
await preset.delete(session)
|
||||||
l.info(f"管理员删除了主题预设: {preset.name}")
|
l.info(f"管理员删除了主题预设: {preset.name}")
|
||||||
@@ -172,7 +180,11 @@ async def router_admin_theme_set_default(
|
|||||||
逻辑:
|
逻辑:
|
||||||
- 事务中先清除所有旧默认,再设新默认
|
- 事务中先清除所有旧默认,再设新默认
|
||||||
"""
|
"""
|
||||||
preset = await ThemePreset.get_exist_one(session, preset_id)
|
preset: ThemePreset | None = await ThemePreset.get(
|
||||||
|
session, ThemePreset.id == preset_id
|
||||||
|
)
|
||||||
|
if not preset:
|
||||||
|
http_exceptions.raise_not_found("主题预设不存在")
|
||||||
|
|
||||||
# 清除所有旧默认
|
# 清除所有旧默认
|
||||||
await session.execute(
|
await session.execute(
|
||||||
@@ -183,5 +195,5 @@ async def router_admin_theme_set_default(
|
|||||||
|
|
||||||
# 设新默认
|
# 设新默认
|
||||||
preset.is_default = True
|
preset.is_default = True
|
||||||
preset = await preset.save(session)
|
await preset.save(session)
|
||||||
l.info(f"管理员将主题预设 '{preset.name}' 设为默认")
|
l.info(f"管理员将主题预设 '{preset.name}' 设为默认")
|
||||||
|
|||||||
@@ -128,9 +128,8 @@ async def router_admin_create_user(
|
|||||||
is_verified=True,
|
is_verified=True,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
)
|
)
|
||||||
identity = await identity.save(session)
|
await identity.save(session)
|
||||||
|
|
||||||
user = await User.get(session, User.id == user.id, load=User.group)
|
|
||||||
return user.to_public()
|
return user.to_public()
|
||||||
|
|
||||||
|
|
||||||
@@ -154,7 +153,9 @@ async def router_admin_update_user(
|
|||||||
:param request: 更新请求
|
:param request: 更新请求
|
||||||
:return: 更新结果
|
:return: 更新结果
|
||||||
"""
|
"""
|
||||||
user = await User.get_exist_one(session, user_id)
|
user = await User.get(session, User.id == user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=404, detail="用户不存在")
|
||||||
|
|
||||||
# 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
|
# 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
|
||||||
default_admin_setting = await Setting.get(
|
default_admin_setting = await Setting.get(
|
||||||
@@ -251,7 +252,9 @@ async def router_admin_calibrate_storage(
|
|||||||
:param user_id: 用户UUID
|
:param user_id: 用户UUID
|
||||||
:return: 校准结果
|
:return: 校准结果
|
||||||
"""
|
"""
|
||||||
user = await User.get_exist_one(session, user_id)
|
user = await User.get(session, User.id == user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=404, detail="用户不存在")
|
||||||
|
|
||||||
previous_storage = user.storage
|
previous_storage = user.storage
|
||||||
|
|
||||||
|
|||||||
81
routers/api/v1/admin/vas/__init__.py
Normal file
81
routers/api/v1/admin/vas/__init__.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
|
||||||
|
from middleware.auth import admin_required
|
||||||
|
from middleware.dependencies import SessionDep
|
||||||
|
from sqlmodels import (
|
||||||
|
ResponseBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
admin_vas_router = APIRouter(
|
||||||
|
prefix='/vas',
|
||||||
|
tags=['admin', 'admin_vas']
|
||||||
|
)
|
||||||
|
|
||||||
|
@admin_vas_router.get(
|
||||||
|
path='/list',
|
||||||
|
summary='获取增值服务列表',
|
||||||
|
description='Get VAS list (orders and storage packs)',
|
||||||
|
dependencies=[Depends(admin_required)]
|
||||||
|
)
|
||||||
|
async def router_admin_get_vas_list(
|
||||||
|
session: SessionDep,
|
||||||
|
user_id: UUID | None = None,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
) -> ResponseBase:
|
||||||
|
"""
|
||||||
|
获取增值服务列表(订单和存储包)。
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param user_id: 按用户筛选
|
||||||
|
:param page: 页码
|
||||||
|
:param page_size: 每页数量
|
||||||
|
:return: 增值服务列表
|
||||||
|
"""
|
||||||
|
# TODO: 实现增值服务列表
|
||||||
|
# 需要查询 Order 和 StoragePack 模型
|
||||||
|
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
|
||||||
|
|
||||||
|
|
||||||
|
@admin_vas_router.get(
|
||||||
|
path='/{vas_id}',
|
||||||
|
summary='获取增值服务详情',
|
||||||
|
description='Get VAS detail by ID',
|
||||||
|
dependencies=[Depends(admin_required)]
|
||||||
|
)
|
||||||
|
async def router_admin_get_vas(
|
||||||
|
session: SessionDep,
|
||||||
|
vas_id: UUID,
|
||||||
|
) -> ResponseBase:
|
||||||
|
"""
|
||||||
|
获取增值服务详情。
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param vas_id: 增值服务UUID
|
||||||
|
:return: 增值服务详情
|
||||||
|
"""
|
||||||
|
# TODO: 实现增值服务详情
|
||||||
|
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
|
||||||
|
|
||||||
|
|
||||||
|
@admin_vas_router.delete(
|
||||||
|
path='/{vas_id}',
|
||||||
|
summary='删除增值服务',
|
||||||
|
description='Delete VAS by ID',
|
||||||
|
dependencies=[Depends(admin_required)]
|
||||||
|
)
|
||||||
|
async def router_admin_delete_vas(
|
||||||
|
session: SessionDep,
|
||||||
|
vas_id: UUID,
|
||||||
|
) -> ResponseBase:
|
||||||
|
"""
|
||||||
|
删除增值服务。
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param vas_id: 增值服务UUID
|
||||||
|
:return: 删除结果
|
||||||
|
"""
|
||||||
|
# TODO: 实现增值服务删除
|
||||||
|
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
from fastapi import APIRouter, Query
|
from fastapi import APIRouter, Query
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
from loguru import logger as l
|
|
||||||
|
|
||||||
from sqlmodels import ResponseBase
|
from sqlmodels import ResponseBase
|
||||||
import service.oauth
|
import service.oauth
|
||||||
@@ -16,12 +15,18 @@ oauth_router = APIRouter(
|
|||||||
tags=["callback", "oauth"],
|
tags=["callback", "oauth"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pay_router = APIRouter(
|
||||||
|
prefix='/callback/pay',
|
||||||
|
tags=["callback", "pay"],
|
||||||
|
)
|
||||||
|
|
||||||
upload_router = APIRouter(
|
upload_router = APIRouter(
|
||||||
prefix='/callback/upload',
|
prefix='/callback/upload',
|
||||||
tags=["callback", "upload"],
|
tags=["callback", "upload"],
|
||||||
)
|
)
|
||||||
|
|
||||||
callback_router.include_router(oauth_router)
|
callback_router.include_router(oauth_router)
|
||||||
|
callback_router.include_router(pay_router)
|
||||||
callback_router.include_router(upload_router)
|
callback_router.include_router(upload_router)
|
||||||
|
|
||||||
@oauth_router.post(
|
@oauth_router.post(
|
||||||
@@ -59,17 +64,91 @@ async def router_callback_github(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
access_token = await service.oauth.github.get_access_token(code)
|
access_token = await service.oauth.github.get_access_token(code)
|
||||||
|
# [TODO] 把access_token写数据库里
|
||||||
if not access_token:
|
if not access_token:
|
||||||
return PlainTextResponse("GitHub 认证失败", status_code=400)
|
return PlainTextResponse("Failed to retrieve access token from GitHub.", status_code=400)
|
||||||
|
|
||||||
user_data = await service.oauth.github.get_user_info(access_token.access_token)
|
user_data = await service.oauth.github.get_user_info(access_token.access_token)
|
||||||
# [TODO] 把 access_token 和 user_data 写数据库,生成 JWT,重定向到前端
|
# [TODO] 把user_data写数据库里
|
||||||
l.info(f"GitHub OAuth 回调成功: user={user_data.user_data.login}")
|
|
||||||
|
|
||||||
return PlainTextResponse("认证成功,功能开发中", status_code=200)
|
return PlainTextResponse(f"User information processed successfully, code: {code}, user_data: {user_data.json_dump()}", status_code=200)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
l.error(f"GitHub OAuth 回调异常: {e}")
|
return PlainTextResponse(f"An error occurred: {str(e)}", status_code=500)
|
||||||
return PlainTextResponse("认证过程中发生错误,请重试", status_code=500)
|
|
||||||
|
@pay_router.post(
|
||||||
|
path='/alipay',
|
||||||
|
summary='支付宝支付回调',
|
||||||
|
description='Handle Alipay payment callback and return payment status.',
|
||||||
|
)
|
||||||
|
def router_callback_alipay() -> ResponseBase:
|
||||||
|
"""
|
||||||
|
Handle Alipay payment callback and return payment status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for the Alipay payment callback.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
|
@pay_router.post(
|
||||||
|
path='/wechat',
|
||||||
|
summary='微信支付回调',
|
||||||
|
description='Handle WeChat Pay payment callback and return payment status.',
|
||||||
|
)
|
||||||
|
def router_callback_wechat() -> ResponseBase:
|
||||||
|
"""
|
||||||
|
Handle WeChat Pay payment callback and return payment status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for the WeChat Pay payment callback.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
|
@pay_router.post(
|
||||||
|
path='/stripe',
|
||||||
|
summary='Stripe支付回调',
|
||||||
|
description='Handle Stripe payment callback and return payment status.',
|
||||||
|
)
|
||||||
|
def router_callback_stripe() -> ResponseBase:
|
||||||
|
"""
|
||||||
|
Handle Stripe payment callback and return payment status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for the Stripe payment callback.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
|
@pay_router.get(
|
||||||
|
path='/easypay',
|
||||||
|
summary='易支付回调',
|
||||||
|
description='Handle EasyPay payment callback and return payment status.',
|
||||||
|
)
|
||||||
|
def router_callback_easypay() -> PlainTextResponse:
|
||||||
|
"""
|
||||||
|
Handle EasyPay payment callback and return payment status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PlainTextResponse: A response containing the payment status for the EasyPay payment callback.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
|
# return PlainTextResponse("success", status_code=200)
|
||||||
|
|
||||||
|
@pay_router.get(
|
||||||
|
path='/custom/{order_no}/{id}',
|
||||||
|
summary='自定义支付回调',
|
||||||
|
description='Handle custom payment callback and return payment status.',
|
||||||
|
)
|
||||||
|
def router_callback_custom(order_no: str, id: str) -> ResponseBase:
|
||||||
|
"""
|
||||||
|
Handle custom payment callback and return payment status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
order_no (str): The order number for the payment.
|
||||||
|
id (str): The ID associated with the payment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for the custom payment callback.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
@upload_router.post(
|
@upload_router.post(
|
||||||
path='/remote/{session_id}/{key}',
|
path='/remote/{session_id}/{key}',
|
||||||
|
|||||||
@@ -1,100 +0,0 @@
|
|||||||
"""
|
|
||||||
文件分类筛选端点
|
|
||||||
|
|
||||||
按文件类型分类(图片/视频/音频/文档)查询用户的所有文件,
|
|
||||||
跨目录搜索,支持分页。扩展名映射从数据库 Setting 表读取。
|
|
||||||
"""
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
from loguru import logger as l
|
|
||||||
|
|
||||||
from middleware.auth import auth_required
|
|
||||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
|
||||||
from sqlmodels import (
|
|
||||||
FileCategory,
|
|
||||||
ListResponse,
|
|
||||||
Object,
|
|
||||||
ObjectResponse,
|
|
||||||
ObjectType,
|
|
||||||
Setting,
|
|
||||||
SettingsType,
|
|
||||||
User,
|
|
||||||
)
|
|
||||||
|
|
||||||
category_router = APIRouter(
|
|
||||||
prefix="/category",
|
|
||||||
tags=["category"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@category_router.get(
|
|
||||||
path="/{category}",
|
|
||||||
summary="按分类获取文件列表",
|
|
||||||
)
|
|
||||||
async def router_category_list(
|
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[User, Depends(auth_required)],
|
|
||||||
category: FileCategory,
|
|
||||||
table_view: TableViewRequestDep,
|
|
||||||
) -> ListResponse[ObjectResponse]:
|
|
||||||
"""
|
|
||||||
按文件类型分类查询用户的所有文件
|
|
||||||
|
|
||||||
跨所有目录搜索,返回分页结果。
|
|
||||||
扩展名配置从数据库 Setting 表读取(type=file_category)。
|
|
||||||
|
|
||||||
认证:
|
|
||||||
- JWT token in Authorization header
|
|
||||||
|
|
||||||
路径参数:
|
|
||||||
- category: 文件分类(image / video / audio / document)
|
|
||||||
|
|
||||||
查询参数:
|
|
||||||
- offset: 分页偏移量(默认0)
|
|
||||||
- limit: 每页数量(默认20,最大100)
|
|
||||||
- desc: 是否降序(默认true)
|
|
||||||
- order: 排序字段(created_at / updated_at)
|
|
||||||
|
|
||||||
响应:
|
|
||||||
- ListResponse[ObjectResponse]: 分页文件列表
|
|
||||||
|
|
||||||
错误处理:
|
|
||||||
- HTTPException 422: category 参数无效
|
|
||||||
- HTTPException 404: 该分类未配置扩展名
|
|
||||||
"""
|
|
||||||
# 从数据库读取该分类的扩展名配置
|
|
||||||
setting = await Setting.get(
|
|
||||||
session,
|
|
||||||
(Setting.type == SettingsType.FILE_CATEGORY) & (Setting.name == category.value),
|
|
||||||
)
|
|
||||||
if not setting or not setting.value:
|
|
||||||
raise HTTPException(status_code=404, detail=f"分类 {category.value} 未配置扩展名")
|
|
||||||
|
|
||||||
extensions = [ext.strip() for ext in setting.value.split(",") if ext.strip()]
|
|
||||||
if not extensions:
|
|
||||||
raise HTTPException(status_code=404, detail=f"分类 {category.value} 扩展名列表为空")
|
|
||||||
|
|
||||||
result = await Object.get_by_category(
|
|
||||||
session,
|
|
||||||
user.id,
|
|
||||||
extensions,
|
|
||||||
table_view=table_view,
|
|
||||||
)
|
|
||||||
|
|
||||||
items = [
|
|
||||||
ObjectResponse(
|
|
||||||
id=obj.id,
|
|
||||||
name=obj.name,
|
|
||||||
type=ObjectType.FILE,
|
|
||||||
size=obj.size,
|
|
||||||
mime_type=obj.mime_type,
|
|
||||||
thumb=False,
|
|
||||||
created_at=obj.created_at,
|
|
||||||
updated_at=obj.updated_at,
|
|
||||||
source_enabled=False,
|
|
||||||
)
|
|
||||||
for obj in result.items
|
|
||||||
]
|
|
||||||
|
|
||||||
return ListResponse(count=result.count, items=items)
|
|
||||||
@@ -57,7 +57,7 @@ async def _get_directory_response(
|
|||||||
policy_response = PolicyResponse(
|
policy_response = PolicyResponse(
|
||||||
id=policy.id,
|
id=policy.id,
|
||||||
name=policy.name,
|
name=policy.name,
|
||||||
type=policy.type,
|
type=policy.type.value,
|
||||||
max_size=policy.max_size,
|
max_size=policy.max_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -189,14 +189,6 @@ async def router_directory_create(
|
|||||||
raise HTTPException(status_code=409, detail="同名文件或目录已存在")
|
raise HTTPException(status_code=409, detail="同名文件或目录已存在")
|
||||||
|
|
||||||
policy_id = request.policy_id if request.policy_id else parent.policy_id
|
policy_id = request.policy_id if request.policy_id else parent.policy_id
|
||||||
|
|
||||||
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
|
|
||||||
if request.policy_id:
|
|
||||||
group = await user.awaitable_attrs.group
|
|
||||||
await session.refresh(group, ['policies'])
|
|
||||||
if request.policy_id not in {p.id for p in group.policies}:
|
|
||||||
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
|
|
||||||
|
|
||||||
parent_id = parent.id # 在 save 前保存
|
parent_id = parent.id # 在 save 前保存
|
||||||
|
|
||||||
new_folder = Object(
|
new_folder = Object(
|
||||||
@@ -206,4 +198,4 @@ async def router_directory_create(
|
|||||||
parent_id=parent_id,
|
parent_id=parent_id,
|
||||||
policy_id=policy_id,
|
policy_id=policy_id,
|
||||||
)
|
)
|
||||||
new_folder = await new_folder.save(session)
|
await new_folder.save(session)
|
||||||
|
|||||||
@@ -13,11 +13,9 @@ from datetime import datetime, timedelta
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import orjson
|
|
||||||
import whatthepatch
|
import whatthepatch
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||||
from fastapi.responses import FileResponse, RedirectResponse
|
from fastapi.responses import FileResponse, RedirectResponse
|
||||||
from starlette.responses import Response
|
|
||||||
from loguru import logger as l
|
from loguru import logger as l
|
||||||
from sqlmodel_ext import SQLModelBase
|
from sqlmodel_ext import SQLModelBase
|
||||||
from whatthepatch.exceptions import HunkApplyException
|
from whatthepatch.exceptions import HunkApplyException
|
||||||
@@ -46,9 +44,7 @@ from sqlmodels import (
|
|||||||
User,
|
User,
|
||||||
WopiSessionResponse,
|
WopiSessionResponse,
|
||||||
)
|
)
|
||||||
import orjson
|
from service.storage import LocalStorageService, adjust_user_storage
|
||||||
|
|
||||||
from service.storage import LocalStorageService, S3StorageService, adjust_user_storage
|
|
||||||
from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL
|
from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL
|
||||||
from utils.JWT.wopi_token import create_wopi_token
|
from utils.JWT.wopi_token import create_wopi_token
|
||||||
from utils import http_exceptions
|
from utils import http_exceptions
|
||||||
@@ -184,14 +180,9 @@ async def create_upload_session(
|
|||||||
|
|
||||||
# 确定存储策略
|
# 确定存储策略
|
||||||
policy_id = request.policy_id or parent.policy_id
|
policy_id = request.policy_id or parent.policy_id
|
||||||
policy = await Policy.get_exist_one(session, policy_id)
|
policy = await Policy.get(session, Policy.id == policy_id)
|
||||||
|
if not policy:
|
||||||
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
|
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||||
if request.policy_id:
|
|
||||||
group = await user.awaitable_attrs.group
|
|
||||||
await session.refresh(group, ['policies'])
|
|
||||||
if request.policy_id not in {p.id for p in group.policies}:
|
|
||||||
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
|
|
||||||
|
|
||||||
# 验证文件大小限制
|
# 验证文件大小限制
|
||||||
_check_policy_size_limit(policy, request.file_size)
|
_check_policy_size_limit(policy, request.file_size)
|
||||||
@@ -219,7 +210,6 @@ async def create_upload_session(
|
|||||||
|
|
||||||
# 生成存储路径
|
# 生成存储路径
|
||||||
storage_path: str | None = None
|
storage_path: str | None = None
|
||||||
s3_upload_id: str | None = None
|
|
||||||
if policy.type == PolicyType.LOCAL:
|
if policy.type == PolicyType.LOCAL:
|
||||||
storage_service = LocalStorageService(policy)
|
storage_service = LocalStorageService(policy)
|
||||||
dir_path, storage_name, full_path = await storage_service.generate_file_path(
|
dir_path, storage_name, full_path = await storage_service.generate_file_path(
|
||||||
@@ -227,25 +217,8 @@ async def create_upload_session(
|
|||||||
original_filename=request.file_name,
|
original_filename=request.file_name,
|
||||||
)
|
)
|
||||||
storage_path = full_path
|
storage_path = full_path
|
||||||
elif policy.type == PolicyType.S3:
|
else:
|
||||||
s3_service = S3StorageService(
|
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||||
policy,
|
|
||||||
region=options.s3_region if options else 'us-east-1',
|
|
||||||
is_path_style=options.s3_path_style if options else False,
|
|
||||||
)
|
|
||||||
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
|
|
||||||
user_id=user.id,
|
|
||||||
original_filename=request.file_name,
|
|
||||||
)
|
|
||||||
# 多分片时创建 multipart upload
|
|
||||||
if total_chunks > 1:
|
|
||||||
s3_upload_id = await s3_service.create_multipart_upload(
|
|
||||||
storage_path, content_type='application/octet-stream',
|
|
||||||
)
|
|
||||||
|
|
||||||
# 预扣存储空间(与创建会话在同一事务中提交,防止并发绕过配额)
|
|
||||||
if request.file_size > 0:
|
|
||||||
await adjust_user_storage(session, user.id, request.file_size, commit=False)
|
|
||||||
|
|
||||||
# 创建上传会话
|
# 创建上传会话
|
||||||
upload_session = UploadSession(
|
upload_session = UploadSession(
|
||||||
@@ -254,7 +227,6 @@ async def create_upload_session(
|
|||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
total_chunks=total_chunks,
|
total_chunks=total_chunks,
|
||||||
storage_path=storage_path,
|
storage_path=storage_path,
|
||||||
s3_upload_id=s3_upload_id,
|
|
||||||
expires_at=datetime.now() + timedelta(hours=24),
|
expires_at=datetime.now() + timedelta(hours=24),
|
||||||
owner_id=user.id,
|
owner_id=user.id,
|
||||||
parent_id=request.parent_id,
|
parent_id=request.parent_id,
|
||||||
@@ -330,38 +302,8 @@ async def upload_chunk(
|
|||||||
content,
|
content,
|
||||||
offset,
|
offset,
|
||||||
)
|
)
|
||||||
elif policy.type == PolicyType.S3:
|
|
||||||
if not upload_session.storage_path:
|
|
||||||
raise HTTPException(status_code=500, detail="存储路径丢失")
|
|
||||||
|
|
||||||
s3_service = await S3StorageService.from_policy(policy)
|
|
||||||
|
|
||||||
if upload_session.total_chunks == 1:
|
|
||||||
# 单分片:直接 PUT 上传
|
|
||||||
await s3_service.upload_file(upload_session.storage_path, content)
|
|
||||||
else:
|
else:
|
||||||
# 多分片:UploadPart
|
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||||
if not upload_session.s3_upload_id:
|
|
||||||
raise HTTPException(status_code=500, detail="S3 分片上传 ID 丢失")
|
|
||||||
|
|
||||||
etag = await s3_service.upload_part(
|
|
||||||
upload_session.storage_path,
|
|
||||||
upload_session.s3_upload_id,
|
|
||||||
chunk_index + 1, # S3 part number 从 1 开始
|
|
||||||
content,
|
|
||||||
)
|
|
||||||
# 追加 ETag 到 s3_part_etags
|
|
||||||
etags: list[list[int | str]] = orjson.loads(upload_session.s3_part_etags or '[]')
|
|
||||||
etags.append([chunk_index + 1, etag])
|
|
||||||
upload_session.s3_part_etags = orjson.dumps(etags).decode()
|
|
||||||
|
|
||||||
# 在 save(commit)前缓存后续需要的属性(commit 后 ORM 对象会过期)
|
|
||||||
policy_type = policy.type
|
|
||||||
s3_upload_id = upload_session.s3_upload_id
|
|
||||||
s3_part_etags = upload_session.s3_part_etags
|
|
||||||
s3_service_for_complete: S3StorageService | None = None
|
|
||||||
if policy_type == PolicyType.S3:
|
|
||||||
s3_service_for_complete = await S3StorageService.from_policy(policy)
|
|
||||||
|
|
||||||
# 更新会话进度
|
# 更新会话进度
|
||||||
upload_session.uploaded_chunks += 1
|
upload_session.uploaded_chunks += 1
|
||||||
@@ -377,26 +319,12 @@ async def upload_chunk(
|
|||||||
if is_complete:
|
if is_complete:
|
||||||
# 保存 upload_session 属性(commit 后会过期)
|
# 保存 upload_session 属性(commit 后会过期)
|
||||||
file_name = upload_session.file_name
|
file_name = upload_session.file_name
|
||||||
file_size = upload_session.file_size
|
|
||||||
uploaded_size = upload_session.uploaded_size
|
uploaded_size = upload_session.uploaded_size
|
||||||
storage_path = upload_session.storage_path
|
storage_path = upload_session.storage_path
|
||||||
upload_session_id = upload_session.id
|
upload_session_id = upload_session.id
|
||||||
parent_id = upload_session.parent_id
|
parent_id = upload_session.parent_id
|
||||||
policy_id = upload_session.policy_id
|
policy_id = upload_session.policy_id
|
||||||
|
|
||||||
# S3 多分片上传完成:合并分片
|
|
||||||
if (
|
|
||||||
policy_type == PolicyType.S3
|
|
||||||
and s3_upload_id
|
|
||||||
and s3_part_etags
|
|
||||||
and s3_service_for_complete
|
|
||||||
):
|
|
||||||
parts_data: list[list[int | str]] = orjson.loads(s3_part_etags)
|
|
||||||
parts = [(int(pn), str(et)) for pn, et in parts_data]
|
|
||||||
await s3_service_for_complete.complete_multipart_upload(
|
|
||||||
storage_path, s3_upload_id, parts,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 PhysicalFile 记录
|
# 创建 PhysicalFile 记录
|
||||||
physical_file = PhysicalFile(
|
physical_file = PhysicalFile(
|
||||||
storage_path=storage_path,
|
storage_path=storage_path,
|
||||||
@@ -427,10 +355,9 @@ async def upload_chunk(
|
|||||||
commit=False
|
commit=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# 调整存储配额差值(创建会话时已预扣 file_size,这里只补差)
|
# 更新用户存储配额
|
||||||
size_diff = uploaded_size - file_size
|
if uploaded_size > 0:
|
||||||
if size_diff != 0:
|
await adjust_user_storage(session, user_id, uploaded_size, commit=False)
|
||||||
await adjust_user_storage(session, user_id, size_diff, commit=False)
|
|
||||||
|
|
||||||
# 统一提交所有更改
|
# 统一提交所有更改
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -463,25 +390,9 @@ async def delete_upload_session(
|
|||||||
|
|
||||||
# 删除临时文件
|
# 删除临时文件
|
||||||
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
|
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
|
||||||
if policy and upload_session.storage_path:
|
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
|
||||||
if policy.type == PolicyType.LOCAL:
|
|
||||||
storage_service = LocalStorageService(policy)
|
storage_service = LocalStorageService(policy)
|
||||||
await storage_service.delete_file(upload_session.storage_path)
|
await storage_service.delete_file(upload_session.storage_path)
|
||||||
elif policy.type == PolicyType.S3:
|
|
||||||
s3_service = await S3StorageService.from_policy(policy)
|
|
||||||
# 如果有分片上传,先取消
|
|
||||||
if upload_session.s3_upload_id:
|
|
||||||
await s3_service.abort_multipart_upload(
|
|
||||||
upload_session.storage_path, upload_session.s3_upload_id,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 单分片上传已完成的话,删除已上传的文件
|
|
||||||
if upload_session.uploaded_chunks > 0:
|
|
||||||
await s3_service.delete_file(upload_session.storage_path)
|
|
||||||
|
|
||||||
# 释放预扣的存储空间
|
|
||||||
if upload_session.file_size > 0:
|
|
||||||
await adjust_user_storage(session, user.id, -upload_session.file_size)
|
|
||||||
|
|
||||||
# 删除会话记录
|
# 删除会话记录
|
||||||
await UploadSession.delete(session, upload_session)
|
await UploadSession.delete(session, upload_session)
|
||||||
@@ -511,22 +422,9 @@ async def clear_upload_sessions(
|
|||||||
for upload_session in sessions:
|
for upload_session in sessions:
|
||||||
# 删除临时文件
|
# 删除临时文件
|
||||||
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
|
policy = await Policy.get(session, Policy.id == upload_session.policy_id)
|
||||||
if policy and upload_session.storage_path:
|
if policy and policy.type == PolicyType.LOCAL and upload_session.storage_path:
|
||||||
if policy.type == PolicyType.LOCAL:
|
|
||||||
storage_service = LocalStorageService(policy)
|
storage_service = LocalStorageService(policy)
|
||||||
await storage_service.delete_file(upload_session.storage_path)
|
await storage_service.delete_file(upload_session.storage_path)
|
||||||
elif policy.type == PolicyType.S3:
|
|
||||||
s3_service = await S3StorageService.from_policy(policy)
|
|
||||||
if upload_session.s3_upload_id:
|
|
||||||
await s3_service.abort_multipart_upload(
|
|
||||||
upload_session.storage_path, upload_session.s3_upload_id,
|
|
||||||
)
|
|
||||||
elif upload_session.uploaded_chunks > 0:
|
|
||||||
await s3_service.delete_file(upload_session.storage_path)
|
|
||||||
|
|
||||||
# 释放预扣的存储空间
|
|
||||||
if upload_session.file_size > 0:
|
|
||||||
await adjust_user_storage(session, user.id, -upload_session.file_size)
|
|
||||||
|
|
||||||
await UploadSession.delete(session, upload_session)
|
await UploadSession.delete(session, upload_session)
|
||||||
deleted_count += 1
|
deleted_count += 1
|
||||||
@@ -588,12 +486,11 @@ async def create_download_token_endpoint(
|
|||||||
path='/{token}',
|
path='/{token}',
|
||||||
summary='下载文件',
|
summary='下载文件',
|
||||||
description='使用下载令牌下载文件,令牌在有效期内可重复使用。',
|
description='使用下载令牌下载文件,令牌在有效期内可重复使用。',
|
||||||
response_model=None,
|
|
||||||
)
|
)
|
||||||
async def download_file(
|
async def download_file(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
token: str,
|
token: str,
|
||||||
) -> Response:
|
) -> FileResponse:
|
||||||
"""
|
"""
|
||||||
下载文件端点
|
下载文件端点
|
||||||
|
|
||||||
@@ -643,15 +540,8 @@ async def download_file(
|
|||||||
filename=file_obj.name,
|
filename=file_obj.name,
|
||||||
media_type="application/octet-stream",
|
media_type="application/octet-stream",
|
||||||
)
|
)
|
||||||
elif policy.type == PolicyType.S3:
|
|
||||||
s3_service = await S3StorageService.from_policy(policy)
|
|
||||||
# 302 重定向到预签名 URL
|
|
||||||
presigned_url = s3_service.generate_presigned_url(
|
|
||||||
storage_path, method='GET', expires_in=3600, filename=file_obj.name,
|
|
||||||
)
|
|
||||||
return RedirectResponse(url=presigned_url, status_code=302)
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="不支持的存储类型")
|
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||||
|
|
||||||
|
|
||||||
# ==================== 包含子路由 ====================
|
# ==================== 包含子路由 ====================
|
||||||
@@ -709,7 +599,9 @@ async def create_empty_file(
|
|||||||
|
|
||||||
# 确定存储策略
|
# 确定存储策略
|
||||||
policy_id = request.policy_id or parent.policy_id
|
policy_id = request.policy_id or parent.policy_id
|
||||||
policy = await Policy.get_exist_one(session, policy_id)
|
policy = await Policy.get(session, Policy.id == policy_id)
|
||||||
|
if not policy:
|
||||||
|
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||||
|
|
||||||
# 生成存储路径并创建空文件
|
# 生成存储路径并创建空文件
|
||||||
storage_path: str | None = None
|
storage_path: str | None = None
|
||||||
@@ -721,13 +613,8 @@ async def create_empty_file(
|
|||||||
)
|
)
|
||||||
await storage_service.create_empty_file(full_path)
|
await storage_service.create_empty_file(full_path)
|
||||||
storage_path = full_path
|
storage_path = full_path
|
||||||
elif policy.type == PolicyType.S3:
|
else:
|
||||||
s3_service = await S3StorageService.from_policy(policy)
|
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||||
dir_path, storage_name, storage_path = await s3_service.generate_file_path(
|
|
||||||
user_id=user_id,
|
|
||||||
original_filename=request.name,
|
|
||||||
)
|
|
||||||
await s3_service.upload_file(storage_path, b'')
|
|
||||||
|
|
||||||
# 创建 PhysicalFile 记录
|
# 创建 PhysicalFile 记录
|
||||||
physical_file = PhysicalFile(
|
physical_file = PhysicalFile(
|
||||||
@@ -808,7 +695,6 @@ async def create_wopi_session(
|
|||||||
)
|
)
|
||||||
|
|
||||||
wopi_app: FileApp | None = None
|
wopi_app: FileApp | None = None
|
||||||
matched_ext_record: FileAppExtension | None = None
|
|
||||||
for ext_record in ext_records:
|
for ext_record in ext_records:
|
||||||
app = ext_record.app
|
app = ext_record.app
|
||||||
if app.type == FileAppType.WOPI and app.is_enabled:
|
if app.type == FileAppType.WOPI and app.is_enabled:
|
||||||
@@ -824,20 +710,13 @@ async def create_wopi_session(
|
|||||||
if not result.first():
|
if not result.first():
|
||||||
continue
|
continue
|
||||||
wopi_app = app
|
wopi_app = app
|
||||||
matched_ext_record = ext_record
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if not wopi_app:
|
if not wopi_app:
|
||||||
http_exceptions.raise_not_found("无可用的 WOPI 查看器")
|
http_exceptions.raise_not_found("无可用的 WOPI 查看器")
|
||||||
|
|
||||||
# 优先使用 per-extension URL(Discovery 自动填充),回退到全局模板
|
if not wopi_app.wopi_editor_url_template:
|
||||||
editor_url_template: str | None = None
|
http_exceptions.raise_bad_request("WOPI 应用未配置编辑器 URL 模板")
|
||||||
if matched_ext_record and matched_ext_record.wopi_action_url:
|
|
||||||
editor_url_template = matched_ext_record.wopi_action_url
|
|
||||||
if not editor_url_template:
|
|
||||||
editor_url_template = wopi_app.wopi_editor_url_template
|
|
||||||
if not editor_url_template:
|
|
||||||
http_exceptions.raise_bad_request("WOPI 应用未配置编辑器 URL 模板,请先执行 Discovery 或手动配置")
|
|
||||||
|
|
||||||
# 获取站点 URL
|
# 获取站点 URL
|
||||||
site_url_setting: Setting | None = await Setting.get(
|
site_url_setting: Setting | None = await Setting.get(
|
||||||
@@ -853,8 +732,12 @@ async def create_wopi_session(
|
|||||||
# 构建 wopi_src
|
# 构建 wopi_src
|
||||||
wopi_src = f"{site_url}/wopi/files/{file_id}"
|
wopi_src = f"{site_url}/wopi/files/{file_id}"
|
||||||
|
|
||||||
# 构建 editor URL(只替换 wopi_src,token 通过 POST 表单传递)
|
# 构建 editor URL
|
||||||
editor_url = editor_url_template.format(wopi_src=wopi_src)
|
editor_url = wopi_app.wopi_editor_url_template.format(
|
||||||
|
wopi_src=wopi_src,
|
||||||
|
access_token=token,
|
||||||
|
access_token_ttl=access_token_ttl,
|
||||||
|
)
|
||||||
|
|
||||||
return WopiSessionResponse(
|
return WopiSessionResponse(
|
||||||
wopi_src=wopi_src,
|
wopi_src=wopi_src,
|
||||||
@@ -915,13 +798,12 @@ async def _validate_source_link(
|
|||||||
path='/get/{file_id}/{name}',
|
path='/get/{file_id}/{name}',
|
||||||
summary='文件外链(直接输出文件数据)',
|
summary='文件外链(直接输出文件数据)',
|
||||||
description='通过外链直接获取文件内容,公开访问无需认证。',
|
description='通过外链直接获取文件内容,公开访问无需认证。',
|
||||||
response_model=None,
|
|
||||||
)
|
)
|
||||||
async def file_get(
|
async def file_get(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
file_id: UUID,
|
file_id: UUID,
|
||||||
name: str,
|
name: str,
|
||||||
) -> Response:
|
) -> FileResponse:
|
||||||
"""
|
"""
|
||||||
文件外链端点(直接输出)
|
文件外链端点(直接输出)
|
||||||
|
|
||||||
@@ -933,32 +815,25 @@ async def file_get(
|
|||||||
"""
|
"""
|
||||||
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
|
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
|
||||||
|
|
||||||
|
if policy.type != PolicyType.LOCAL:
|
||||||
|
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||||
|
|
||||||
|
storage_service = LocalStorageService(policy)
|
||||||
|
if not await storage_service.file_exists(physical_file.storage_path):
|
||||||
|
http_exceptions.raise_not_found("物理文件不存在")
|
||||||
|
|
||||||
# 缓存物理路径(save 后对象属性会过期)
|
# 缓存物理路径(save 后对象属性会过期)
|
||||||
file_path = physical_file.storage_path
|
file_path = physical_file.storage_path
|
||||||
|
|
||||||
# 递增下载次数
|
# 递增下载次数
|
||||||
link.downloads += 1
|
link.downloads += 1
|
||||||
link = await link.save(session)
|
await link.save(session)
|
||||||
|
|
||||||
if policy.type == PolicyType.LOCAL:
|
|
||||||
storage_service = LocalStorageService(policy)
|
|
||||||
if not await storage_service.file_exists(file_path):
|
|
||||||
http_exceptions.raise_not_found("物理文件不存在")
|
|
||||||
|
|
||||||
return FileResponse(
|
return FileResponse(
|
||||||
path=file_path,
|
path=file_path,
|
||||||
filename=name,
|
filename=name,
|
||||||
media_type="application/octet-stream",
|
media_type="application/octet-stream",
|
||||||
)
|
)
|
||||||
elif policy.type == PolicyType.S3:
|
|
||||||
# S3 外链直接输出:302 重定向到预签名 URL
|
|
||||||
s3_service = await S3StorageService.from_policy(policy)
|
|
||||||
presigned_url = s3_service.generate_presigned_url(
|
|
||||||
file_path, method='GET', expires_in=3600, filename=name,
|
|
||||||
)
|
|
||||||
return RedirectResponse(url=presigned_url, status_code=302)
|
|
||||||
else:
|
|
||||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -971,7 +846,7 @@ async def file_source_redirect(
|
|||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
file_id: UUID,
|
file_id: UUID,
|
||||||
name: str,
|
name: str,
|
||||||
) -> Response:
|
) -> FileResponse | RedirectResponse:
|
||||||
"""
|
"""
|
||||||
文件外链端点(重定向/直接输出)
|
文件外链端点(重定向/直接输出)
|
||||||
|
|
||||||
@@ -985,6 +860,13 @@ async def file_source_redirect(
|
|||||||
"""
|
"""
|
||||||
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
|
file_obj, link, physical_file, policy = await _validate_source_link(session, file_id)
|
||||||
|
|
||||||
|
if policy.type != PolicyType.LOCAL:
|
||||||
|
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||||
|
|
||||||
|
storage_service = LocalStorageService(policy)
|
||||||
|
if not await storage_service.file_exists(physical_file.storage_path):
|
||||||
|
http_exceptions.raise_not_found("物理文件不存在")
|
||||||
|
|
||||||
# 缓存所有需要的值(save 后对象属性会过期)
|
# 缓存所有需要的值(save 后对象属性会过期)
|
||||||
file_path = physical_file.storage_path
|
file_path = physical_file.storage_path
|
||||||
is_private = policy.is_private
|
is_private = policy.is_private
|
||||||
@@ -992,12 +874,7 @@ async def file_source_redirect(
|
|||||||
|
|
||||||
# 递增下载次数
|
# 递增下载次数
|
||||||
link.downloads += 1
|
link.downloads += 1
|
||||||
link = await link.save(session)
|
await link.save(session)
|
||||||
|
|
||||||
if policy.type == PolicyType.LOCAL:
|
|
||||||
storage_service = LocalStorageService(policy)
|
|
||||||
if not await storage_service.file_exists(file_path):
|
|
||||||
http_exceptions.raise_not_found("物理文件不存在")
|
|
||||||
|
|
||||||
# 公有存储:302 重定向到 base_url
|
# 公有存储:302 重定向到 base_url
|
||||||
if not is_private and base_url:
|
if not is_private and base_url:
|
||||||
@@ -1011,19 +888,6 @@ async def file_source_redirect(
|
|||||||
filename=name,
|
filename=name,
|
||||||
media_type="application/octet-stream",
|
media_type="application/octet-stream",
|
||||||
)
|
)
|
||||||
elif policy.type == PolicyType.S3:
|
|
||||||
s3_service = await S3StorageService.from_policy(policy)
|
|
||||||
# 公有存储且有 base_url:直接重定向到公开 URL
|
|
||||||
if not is_private and base_url:
|
|
||||||
redirect_url = f"{base_url.rstrip('/')}/{file_path}"
|
|
||||||
return RedirectResponse(url=redirect_url, status_code=302)
|
|
||||||
# 私有存储:生成预签名 URL 重定向
|
|
||||||
presigned_url = s3_service.generate_presigned_url(
|
|
||||||
file_path, method='GET', expires_in=3600, filename=name,
|
|
||||||
)
|
|
||||||
return RedirectResponse(url=presigned_url, status_code=302)
|
|
||||||
else:
|
|
||||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
@router.put(
|
||||||
@@ -1077,15 +941,11 @@ async def file_content(
|
|||||||
if not policy:
|
if not policy:
|
||||||
http_exceptions.raise_internal_error("存储策略不存在")
|
http_exceptions.raise_internal_error("存储策略不存在")
|
||||||
|
|
||||||
# 读取文件内容
|
if policy.type != PolicyType.LOCAL:
|
||||||
if policy.type == PolicyType.LOCAL:
|
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||||
|
|
||||||
storage_service = LocalStorageService(policy)
|
storage_service = LocalStorageService(policy)
|
||||||
raw_bytes = await storage_service.read_file(physical_file.storage_path)
|
raw_bytes = await storage_service.read_file(physical_file.storage_path)
|
||||||
elif policy.type == PolicyType.S3:
|
|
||||||
s3_service = await S3StorageService.from_policy(policy)
|
|
||||||
raw_bytes = await s3_service.download_file(physical_file.storage_path)
|
|
||||||
else:
|
|
||||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
content = raw_bytes.decode('utf-8')
|
content = raw_bytes.decode('utf-8')
|
||||||
@@ -1151,15 +1011,11 @@ async def patch_file_content(
|
|||||||
if not policy:
|
if not policy:
|
||||||
http_exceptions.raise_internal_error("存储策略不存在")
|
http_exceptions.raise_internal_error("存储策略不存在")
|
||||||
|
|
||||||
# 读取文件内容
|
if policy.type != PolicyType.LOCAL:
|
||||||
if policy.type == PolicyType.LOCAL:
|
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||||
|
|
||||||
storage_service = LocalStorageService(policy)
|
storage_service = LocalStorageService(policy)
|
||||||
raw_bytes = await storage_service.read_file(storage_path)
|
raw_bytes = await storage_service.read_file(storage_path)
|
||||||
elif policy.type == PolicyType.S3:
|
|
||||||
s3_service = await S3StorageService.from_policy(policy)
|
|
||||||
raw_bytes = await s3_service.download_file(storage_path)
|
|
||||||
else:
|
|
||||||
http_exceptions.raise_internal_error("不支持的存储类型")
|
|
||||||
|
|
||||||
# 解码 + 规范化
|
# 解码 + 规范化
|
||||||
original_text = raw_bytes.decode('utf-8')
|
original_text = raw_bytes.decode('utf-8')
|
||||||
@@ -1193,10 +1049,7 @@ async def patch_file_content(
|
|||||||
_check_policy_size_limit(policy, len(new_bytes))
|
_check_policy_size_limit(policy, len(new_bytes))
|
||||||
|
|
||||||
# 写入文件
|
# 写入文件
|
||||||
if policy.type == PolicyType.LOCAL:
|
|
||||||
await storage_service.write_file(storage_path, new_bytes)
|
await storage_service.write_file(storage_path, new_bytes)
|
||||||
elif policy.type == PolicyType.S3:
|
|
||||||
await s3_service.upload_file(storage_path, new_bytes)
|
|
||||||
|
|
||||||
# 更新数据库
|
# 更新数据库
|
||||||
owner_id = file_obj.owner_id
|
owner_id = file_obj.owner_id
|
||||||
|
|||||||
@@ -8,14 +8,13 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from loguru import logger as l
|
from loguru import logger as l
|
||||||
|
|
||||||
from middleware.auth import auth_required
|
from middleware.auth import auth_required
|
||||||
from middleware.dependencies import SessionDep
|
from middleware.dependencies import SessionDep
|
||||||
from sqlmodels import (
|
from sqlmodels import (
|
||||||
CreateFileRequest,
|
CreateFileRequest,
|
||||||
Group,
|
|
||||||
Object,
|
Object,
|
||||||
ObjectCopyRequest,
|
ObjectCopyRequest,
|
||||||
ObjectDeleteRequest,
|
ObjectDeleteRequest,
|
||||||
@@ -23,42 +22,24 @@ from sqlmodels import (
|
|||||||
ObjectPropertyDetailResponse,
|
ObjectPropertyDetailResponse,
|
||||||
ObjectPropertyResponse,
|
ObjectPropertyResponse,
|
||||||
ObjectRenameRequest,
|
ObjectRenameRequest,
|
||||||
ObjectSwitchPolicyRequest,
|
|
||||||
ObjectType,
|
ObjectType,
|
||||||
PhysicalFile,
|
PhysicalFile,
|
||||||
Policy,
|
Policy,
|
||||||
PolicyType,
|
PolicyType,
|
||||||
Task,
|
|
||||||
TaskProps,
|
|
||||||
TaskStatus,
|
|
||||||
TaskSummaryBase,
|
|
||||||
TaskType,
|
|
||||||
User,
|
User,
|
||||||
# 元数据相关
|
|
||||||
ObjectMetadata,
|
|
||||||
MetadataResponse,
|
|
||||||
MetadataPatchRequest,
|
|
||||||
INTERNAL_NAMESPACES,
|
|
||||||
USER_WRITABLE_NAMESPACES,
|
|
||||||
)
|
)
|
||||||
from service.storage import (
|
from service.storage import (
|
||||||
LocalStorageService,
|
LocalStorageService,
|
||||||
adjust_user_storage,
|
adjust_user_storage,
|
||||||
copy_object_recursive,
|
copy_object_recursive,
|
||||||
migrate_file_with_task,
|
|
||||||
migrate_directory_files,
|
|
||||||
)
|
)
|
||||||
from service.storage.object import soft_delete_objects
|
from service.storage.object import soft_delete_objects
|
||||||
from sqlmodels.database_connection import DatabaseManager
|
|
||||||
from utils import http_exceptions
|
from utils import http_exceptions
|
||||||
|
|
||||||
from .custom_property import router as custom_property_router
|
|
||||||
|
|
||||||
object_router = APIRouter(
|
object_router = APIRouter(
|
||||||
prefix="/object",
|
prefix="/object",
|
||||||
tags=["object"]
|
tags=["object"]
|
||||||
)
|
)
|
||||||
object_router.include_router(custom_property_router)
|
|
||||||
|
|
||||||
@object_router.post(
|
@object_router.post(
|
||||||
path='/',
|
path='/',
|
||||||
@@ -112,7 +93,9 @@ async def router_object_create(
|
|||||||
|
|
||||||
# 确定存储策略
|
# 确定存储策略
|
||||||
policy_id = request.policy_id or parent.policy_id
|
policy_id = request.policy_id or parent.policy_id
|
||||||
policy = await Policy.get_exist_one(session, policy_id)
|
policy = await Policy.get(session, Policy.id == policy_id)
|
||||||
|
if not policy:
|
||||||
|
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||||
|
|
||||||
parent_id = parent.id
|
parent_id = parent.id
|
||||||
|
|
||||||
@@ -147,7 +130,7 @@ async def router_object_create(
|
|||||||
owner_id=user_id,
|
owner_id=user_id,
|
||||||
policy_id=policy_id,
|
policy_id=policy_id,
|
||||||
)
|
)
|
||||||
file_object = await file_object.save(session)
|
await file_object.save(session)
|
||||||
|
|
||||||
l.info(f"创建空白文件: {request.name}")
|
l.info(f"创建空白文件: {request.name}")
|
||||||
|
|
||||||
@@ -472,7 +455,7 @@ async def router_object_rename(
|
|||||||
|
|
||||||
# 更新名称
|
# 更新名称
|
||||||
obj.name = new_name
|
obj.name = new_name
|
||||||
obj = await obj.save(session)
|
await obj.save(session)
|
||||||
|
|
||||||
l.info(f"用户 {user_id} 将对象 {obj.id} 重命名为 {new_name}")
|
l.info(f"用户 {user_id} 将对象 {obj.id} 重命名为 {new_name}")
|
||||||
|
|
||||||
@@ -510,7 +493,6 @@ async def router_object_property(
|
|||||||
name=obj.name,
|
name=obj.name,
|
||||||
type=obj.type,
|
type=obj.type,
|
||||||
size=obj.size,
|
size=obj.size,
|
||||||
mime_type=obj.mime_type,
|
|
||||||
created_at=obj.created_at,
|
created_at=obj.created_at,
|
||||||
updated_at=obj.updated_at,
|
updated_at=obj.updated_at,
|
||||||
parent_id=obj.parent_id,
|
parent_id=obj.parent_id,
|
||||||
@@ -538,7 +520,7 @@ async def router_object_property_detail(
|
|||||||
obj = await Object.get(
|
obj = await Object.get(
|
||||||
session,
|
session,
|
||||||
(Object.id == id) & (Object.deleted_at == None),
|
(Object.id == id) & (Object.deleted_at == None),
|
||||||
load=Object.metadata_entries,
|
load=Object.file_metadata,
|
||||||
)
|
)
|
||||||
if not obj:
|
if not obj:
|
||||||
raise HTTPException(status_code=404, detail="对象不存在")
|
raise HTTPException(status_code=404, detail="对象不存在")
|
||||||
@@ -561,301 +543,35 @@ async def router_object_property_detail(
|
|||||||
total_views = sum(s.views for s in shares)
|
total_views = sum(s.views for s in shares)
|
||||||
total_downloads = sum(s.downloads for s in shares)
|
total_downloads = sum(s.downloads for s in shares)
|
||||||
|
|
||||||
# 获取物理文件信息(引用计数、校验和)
|
# 获取物理文件引用计数
|
||||||
reference_count = 1
|
reference_count = 1
|
||||||
checksum_md5: str | None = None
|
|
||||||
checksum_sha256: str | None = None
|
|
||||||
if obj.physical_file_id:
|
if obj.physical_file_id:
|
||||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
|
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
|
||||||
if physical_file:
|
if physical_file:
|
||||||
reference_count = physical_file.reference_count
|
reference_count = physical_file.reference_count
|
||||||
checksum_md5 = physical_file.checksum_md5
|
|
||||||
checksum_sha256 = physical_file.checksum_sha256
|
|
||||||
|
|
||||||
# 构建元数据字典(排除内部命名空间)
|
# 构建响应
|
||||||
metadata: dict[str, str] = {}
|
response = ObjectPropertyDetailResponse(
|
||||||
for entry in obj.metadata_entries:
|
|
||||||
ns = entry.name.split(":")[0] if ":" in entry.name else ""
|
|
||||||
if ns not in INTERNAL_NAMESPACES:
|
|
||||||
metadata[entry.name] = entry.value
|
|
||||||
|
|
||||||
return ObjectPropertyDetailResponse(
|
|
||||||
id=obj.id,
|
id=obj.id,
|
||||||
name=obj.name,
|
name=obj.name,
|
||||||
type=obj.type,
|
type=obj.type,
|
||||||
size=obj.size,
|
size=obj.size,
|
||||||
mime_type=obj.mime_type,
|
|
||||||
created_at=obj.created_at,
|
created_at=obj.created_at,
|
||||||
updated_at=obj.updated_at,
|
updated_at=obj.updated_at,
|
||||||
parent_id=obj.parent_id,
|
parent_id=obj.parent_id,
|
||||||
checksum_md5=checksum_md5,
|
|
||||||
checksum_sha256=checksum_sha256,
|
|
||||||
policy_name=policy_name,
|
policy_name=policy_name,
|
||||||
share_count=share_count,
|
share_count=share_count,
|
||||||
total_views=total_views,
|
total_views=total_views,
|
||||||
total_downloads=total_downloads,
|
total_downloads=total_downloads,
|
||||||
reference_count=reference_count,
|
reference_count=reference_count,
|
||||||
metadatas=metadata,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 添加文件元数据
|
||||||
|
if obj.file_metadata:
|
||||||
|
response.mime_type = obj.file_metadata.mime_type
|
||||||
|
response.width = obj.file_metadata.width
|
||||||
|
response.height = obj.file_metadata.height
|
||||||
|
response.duration = obj.file_metadata.duration
|
||||||
|
response.checksum_md5 = obj.file_metadata.checksum_md5
|
||||||
|
|
||||||
@object_router.patch(
|
return response
|
||||||
path='/{object_id}/policy',
|
|
||||||
summary='切换对象存储策略',
|
|
||||||
)
|
|
||||||
async def router_object_switch_policy(
|
|
||||||
session: SessionDep,
|
|
||||||
background_tasks: BackgroundTasks,
|
|
||||||
user: Annotated[User, Depends(auth_required)],
|
|
||||||
object_id: UUID,
|
|
||||||
request: ObjectSwitchPolicyRequest,
|
|
||||||
) -> TaskSummaryBase:
|
|
||||||
"""
|
|
||||||
切换对象的存储策略
|
|
||||||
|
|
||||||
文件:立即创建后台迁移任务,将文件从源策略搬到目标策略。
|
|
||||||
目录:更新目录 policy_id(新文件使用新策略);
|
|
||||||
若 is_migrate_existing=True,额外创建后台任务迁移所有已有文件。
|
|
||||||
|
|
||||||
认证:JWT Bearer Token
|
|
||||||
|
|
||||||
错误处理:
|
|
||||||
- 404: 对象不存在
|
|
||||||
- 403: 无权操作此对象 / 用户组无权使用目标策略
|
|
||||||
- 400: 目标策略与当前相同 / 不能对根目录操作
|
|
||||||
"""
|
|
||||||
user_id = user.id
|
|
||||||
|
|
||||||
# 查找对象
|
|
||||||
obj = await Object.get(
|
|
||||||
session,
|
|
||||||
(Object.id == object_id) & (Object.deleted_at == None)
|
|
||||||
)
|
|
||||||
if not obj:
|
|
||||||
http_exceptions.raise_not_found("对象不存在")
|
|
||||||
if obj.owner_id != user_id:
|
|
||||||
http_exceptions.raise_forbidden("无权操作此对象")
|
|
||||||
if obj.is_banned:
|
|
||||||
http_exceptions.raise_banned()
|
|
||||||
|
|
||||||
# 根目录不能直接切换策略(应通过子对象或子目录操作)
|
|
||||||
if obj.parent_id is None:
|
|
||||||
raise HTTPException(status_code=400, detail="不能对根目录切换存储策略,请对子目录操作")
|
|
||||||
|
|
||||||
# 校验目标策略存在
|
|
||||||
dest_policy = await Policy.get(session, Policy.id == request.policy_id)
|
|
||||||
if not dest_policy:
|
|
||||||
http_exceptions.raise_not_found("目标存储策略不存在")
|
|
||||||
|
|
||||||
# 校验用户组权限
|
|
||||||
group: Group = await user.awaitable_attrs.group
|
|
||||||
await session.refresh(group, ['policies'])
|
|
||||||
allowed_ids = {p.id for p in group.policies}
|
|
||||||
if request.policy_id not in allowed_ids:
|
|
||||||
http_exceptions.raise_forbidden("当前用户组无权使用该存储策略")
|
|
||||||
|
|
||||||
# 不能切换到相同策略
|
|
||||||
if obj.policy_id == request.policy_id:
|
|
||||||
raise HTTPException(status_code=400, detail="目标策略与当前策略相同")
|
|
||||||
|
|
||||||
# 保存必要的属性,避免 save 后对象过期
|
|
||||||
src_policy_id = obj.policy_id
|
|
||||||
obj_id = obj.id
|
|
||||||
obj_is_file = obj.type == ObjectType.FILE
|
|
||||||
dest_policy_id = request.policy_id
|
|
||||||
dest_policy_name = dest_policy.name
|
|
||||||
|
|
||||||
# 创建任务记录
|
|
||||||
task = Task(
|
|
||||||
type=TaskType.POLICY_MIGRATE,
|
|
||||||
status=TaskStatus.QUEUED,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
task = await task.save(session)
|
|
||||||
task_id = task.id
|
|
||||||
|
|
||||||
task_props = TaskProps(
|
|
||||||
task_id=task_id,
|
|
||||||
source_policy_id=src_policy_id,
|
|
||||||
dest_policy_id=dest_policy_id,
|
|
||||||
object_id=obj_id,
|
|
||||||
)
|
|
||||||
task_props = await task_props.save(session)
|
|
||||||
|
|
||||||
if obj_is_file:
|
|
||||||
# 文件:后台迁移
|
|
||||||
async def _run_file_migration() -> None:
|
|
||||||
async with DatabaseManager.session() as bg_session:
|
|
||||||
bg_obj = await Object.get(bg_session, Object.id == obj_id)
|
|
||||||
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
|
|
||||||
bg_task = await Task.get(bg_session, Task.id == task_id)
|
|
||||||
await migrate_file_with_task(bg_session, bg_obj, bg_policy, bg_task)
|
|
||||||
|
|
||||||
background_tasks.add_task(_run_file_migration)
|
|
||||||
else:
|
|
||||||
# 目录:先更新目录自身的 policy_id
|
|
||||||
obj = await Object.get(session, Object.id == obj_id)
|
|
||||||
obj.policy_id = dest_policy_id
|
|
||||||
obj = await obj.save(session)
|
|
||||||
|
|
||||||
if request.is_migrate_existing:
|
|
||||||
# 后台迁移所有已有文件
|
|
||||||
async def _run_dir_migration() -> None:
|
|
||||||
async with DatabaseManager.session() as bg_session:
|
|
||||||
bg_folder = await Object.get(bg_session, Object.id == obj_id)
|
|
||||||
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
|
|
||||||
bg_task = await Task.get(bg_session, Task.id == task_id)
|
|
||||||
await migrate_directory_files(bg_session, bg_folder, bg_policy, bg_task)
|
|
||||||
|
|
||||||
background_tasks.add_task(_run_dir_migration)
|
|
||||||
else:
|
|
||||||
# 不迁移已有文件,直接完成任务
|
|
||||||
task = await Task.get(session, Task.id == task_id)
|
|
||||||
task.status = TaskStatus.COMPLETED
|
|
||||||
task.progress = 100
|
|
||||||
task = await task.save(session)
|
|
||||||
|
|
||||||
# 重新获取 task 以读取最新状态
|
|
||||||
task = await Task.get(session, Task.id == task_id)
|
|
||||||
|
|
||||||
l.info(f"用户 {user_id} 请求切换对象 {obj_id} 存储策略 → {dest_policy_name}")
|
|
||||||
|
|
||||||
return TaskSummaryBase(
|
|
||||||
id=task.id,
|
|
||||||
type=task.type,
|
|
||||||
status=task.status,
|
|
||||||
progress=task.progress,
|
|
||||||
error=task.error,
|
|
||||||
user_id=task.user_id,
|
|
||||||
created_at=task.created_at,
|
|
||||||
updated_at=task.updated_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 元数据端点 ====================
|
|
||||||
|
|
||||||
@object_router.get(
|
|
||||||
path='/{object_id}/metadata',
|
|
||||||
summary='获取对象元数据',
|
|
||||||
description='获取对象的元数据键值对,可按命名空间过滤。',
|
|
||||||
)
|
|
||||||
async def router_get_object_metadata(
|
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[User, Depends(auth_required)],
|
|
||||||
object_id: UUID,
|
|
||||||
ns: str | None = None,
|
|
||||||
) -> MetadataResponse:
|
|
||||||
"""
|
|
||||||
获取对象元数据端点
|
|
||||||
|
|
||||||
认证:JWT token 必填
|
|
||||||
|
|
||||||
查询参数:
|
|
||||||
- ns: 逗号分隔的命名空间列表(如 exif,stream),不传返回所有非内部命名空间
|
|
||||||
|
|
||||||
错误处理:
|
|
||||||
- 404: 对象不存在
|
|
||||||
- 403: 无权查看此对象
|
|
||||||
"""
|
|
||||||
obj = await Object.get(
|
|
||||||
session,
|
|
||||||
(Object.id == object_id) & (Object.deleted_at == None),
|
|
||||||
load=Object.metadata_entries,
|
|
||||||
)
|
|
||||||
if not obj:
|
|
||||||
raise HTTPException(status_code=404, detail="对象不存在")
|
|
||||||
|
|
||||||
if obj.owner_id != user.id:
|
|
||||||
raise HTTPException(status_code=403, detail="无权查看此对象")
|
|
||||||
|
|
||||||
# 解析命名空间过滤
|
|
||||||
ns_filter: set[str] | None = None
|
|
||||||
if ns:
|
|
||||||
ns_filter = {n.strip() for n in ns.split(",") if n.strip()}
|
|
||||||
# 不允许查看内部命名空间
|
|
||||||
ns_filter -= INTERNAL_NAMESPACES
|
|
||||||
|
|
||||||
# 构建元数据字典
|
|
||||||
metadata: dict[str, str] = {}
|
|
||||||
for entry in obj.metadata_entries:
|
|
||||||
entry_ns = entry.name.split(":")[0] if ":" in entry.name else ""
|
|
||||||
if entry_ns in INTERNAL_NAMESPACES:
|
|
||||||
continue
|
|
||||||
if ns_filter is not None and entry_ns not in ns_filter:
|
|
||||||
continue
|
|
||||||
metadata[entry.name] = entry.value
|
|
||||||
|
|
||||||
return MetadataResponse(metadatas=metadata)
|
|
||||||
|
|
||||||
|
|
||||||
@object_router.patch(
|
|
||||||
path='/{object_id}/metadata',
|
|
||||||
summary='批量更新对象元数据',
|
|
||||||
description='批量设置或删除对象的元数据条目。仅允许修改 custom: 命名空间。',
|
|
||||||
status_code=204,
|
|
||||||
)
|
|
||||||
async def router_patch_object_metadata(
|
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[User, Depends(auth_required)],
|
|
||||||
object_id: UUID,
|
|
||||||
request: MetadataPatchRequest,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
批量更新对象元数据端点
|
|
||||||
|
|
||||||
请求体中值为 None 的键将被删除,其余键将被设置/更新。
|
|
||||||
用户只能修改 custom: 命名空间的条目。
|
|
||||||
|
|
||||||
认证:JWT token 必填
|
|
||||||
|
|
||||||
错误处理:
|
|
||||||
- 400: 尝试修改非 custom: 命名空间的条目
|
|
||||||
- 404: 对象不存在
|
|
||||||
- 403: 无权操作此对象
|
|
||||||
"""
|
|
||||||
obj = await Object.get(
|
|
||||||
session,
|
|
||||||
(Object.id == object_id) & (Object.deleted_at == None),
|
|
||||||
)
|
|
||||||
if not obj:
|
|
||||||
raise HTTPException(status_code=404, detail="对象不存在")
|
|
||||||
|
|
||||||
if obj.owner_id != user.id:
|
|
||||||
raise HTTPException(status_code=403, detail="无权操作此对象")
|
|
||||||
|
|
||||||
for patch in request.patches:
|
|
||||||
# 验证命名空间
|
|
||||||
patch_ns = patch.key.split(":")[0] if ":" in patch.key else ""
|
|
||||||
if patch_ns not in USER_WRITABLE_NAMESPACES:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"不允许修改命名空间 '{patch_ns}' 的元数据,仅允许 custom: 命名空间",
|
|
||||||
)
|
|
||||||
|
|
||||||
if patch.value is None:
|
|
||||||
# 删除元数据条目
|
|
||||||
existing = await ObjectMetadata.get(
|
|
||||||
session,
|
|
||||||
(ObjectMetadata.object_id == object_id) & (ObjectMetadata.name == patch.key),
|
|
||||||
)
|
|
||||||
if existing:
|
|
||||||
await ObjectMetadata.delete(session, instances=existing)
|
|
||||||
else:
|
|
||||||
# 设置/更新元数据条目
|
|
||||||
existing = await ObjectMetadata.get(
|
|
||||||
session,
|
|
||||||
(ObjectMetadata.object_id == object_id) & (ObjectMetadata.name == patch.key),
|
|
||||||
)
|
|
||||||
if existing:
|
|
||||||
existing.value = patch.value
|
|
||||||
existing = await existing.save(session)
|
|
||||||
else:
|
|
||||||
entry = ObjectMetadata(
|
|
||||||
object_id=object_id,
|
|
||||||
name=patch.key,
|
|
||||||
value=patch.value,
|
|
||||||
is_public=True,
|
|
||||||
)
|
|
||||||
entry = await entry.save(session)
|
|
||||||
|
|
||||||
l.info(f"用户 {user.id} 更新了对象 {object_id} 的 {len(request.patches)} 条元数据")
|
|
||||||
|
|||||||
@@ -1,168 +0,0 @@
|
|||||||
"""
|
|
||||||
用户自定义属性定义路由
|
|
||||||
|
|
||||||
提供自定义属性模板的增删改查功能。
|
|
||||||
用户可以定义类型化的属性模板(如标签、评分、分类等),
|
|
||||||
然后通过元数据 PATCH 端点为对象设置属性值。
|
|
||||||
|
|
||||||
路由前缀:/custom_property
|
|
||||||
"""
|
|
||||||
from typing import Annotated
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
from loguru import logger as l
|
|
||||||
|
|
||||||
from middleware.auth import auth_required
|
|
||||||
from middleware.dependencies import SessionDep
|
|
||||||
from sqlmodels import (
|
|
||||||
CustomPropertyDefinition,
|
|
||||||
CustomPropertyCreateRequest,
|
|
||||||
CustomPropertyUpdateRequest,
|
|
||||||
CustomPropertyResponse,
|
|
||||||
User,
|
|
||||||
)
|
|
||||||
|
|
||||||
router = APIRouter(
|
|
||||||
prefix="/custom_property",
|
|
||||||
tags=["custom_property"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
path='',
|
|
||||||
summary='获取自定义属性定义列表',
|
|
||||||
description='获取当前用户的所有自定义属性定义,按 sort_order 排序。',
|
|
||||||
)
|
|
||||||
async def router_list_custom_properties(
|
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[User, Depends(auth_required)],
|
|
||||||
) -> list[CustomPropertyResponse]:
|
|
||||||
"""
|
|
||||||
获取自定义属性定义列表端点
|
|
||||||
|
|
||||||
认证:JWT token 必填
|
|
||||||
|
|
||||||
返回当前用户定义的所有自定义属性模板。
|
|
||||||
"""
|
|
||||||
definitions = await CustomPropertyDefinition.get(
|
|
||||||
session,
|
|
||||||
CustomPropertyDefinition.owner_id == user.id,
|
|
||||||
fetch_mode="all",
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
CustomPropertyResponse(
|
|
||||||
id=d.id,
|
|
||||||
name=d.name,
|
|
||||||
type=d.type,
|
|
||||||
icon=d.icon,
|
|
||||||
options=d.options,
|
|
||||||
default_value=d.default_value,
|
|
||||||
sort_order=d.sort_order,
|
|
||||||
)
|
|
||||||
for d in sorted(definitions, key=lambda x: x.sort_order)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
path='',
|
|
||||||
summary='创建自定义属性定义',
|
|
||||||
description='创建一个新的自定义属性模板。',
|
|
||||||
status_code=204,
|
|
||||||
)
|
|
||||||
async def router_create_custom_property(
|
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[User, Depends(auth_required)],
|
|
||||||
request: CustomPropertyCreateRequest,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
创建自定义属性定义端点
|
|
||||||
|
|
||||||
认证:JWT token 必填
|
|
||||||
|
|
||||||
错误处理:
|
|
||||||
- 400: 请求数据无效
|
|
||||||
- 409: 同名属性已存在
|
|
||||||
"""
|
|
||||||
# 检查同名属性
|
|
||||||
existing = await CustomPropertyDefinition.get(
|
|
||||||
session,
|
|
||||||
(CustomPropertyDefinition.owner_id == user.id) &
|
|
||||||
(CustomPropertyDefinition.name == request.name),
|
|
||||||
)
|
|
||||||
if existing:
|
|
||||||
raise HTTPException(status_code=409, detail="同名自定义属性已存在")
|
|
||||||
|
|
||||||
definition = CustomPropertyDefinition(
|
|
||||||
owner_id=user.id,
|
|
||||||
name=request.name,
|
|
||||||
type=request.type,
|
|
||||||
icon=request.icon,
|
|
||||||
options=request.options,
|
|
||||||
default_value=request.default_value,
|
|
||||||
)
|
|
||||||
definition = await definition.save(session)
|
|
||||||
|
|
||||||
l.info(f"用户 {user.id} 创建了自定义属性: {request.name}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch(
|
|
||||||
path='/{id}',
|
|
||||||
summary='更新自定义属性定义',
|
|
||||||
description='更新自定义属性模板的名称、图标、选项等。',
|
|
||||||
status_code=204,
|
|
||||||
)
|
|
||||||
async def router_update_custom_property(
|
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[User, Depends(auth_required)],
|
|
||||||
id: UUID,
|
|
||||||
request: CustomPropertyUpdateRequest,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
更新自定义属性定义端点
|
|
||||||
|
|
||||||
认证:JWT token 必填
|
|
||||||
|
|
||||||
错误处理:
|
|
||||||
- 404: 属性定义不存在
|
|
||||||
- 403: 无权操作此属性
|
|
||||||
"""
|
|
||||||
definition = await CustomPropertyDefinition.get_exist_one(session, id)
|
|
||||||
|
|
||||||
if definition.owner_id != user.id:
|
|
||||||
raise HTTPException(status_code=403, detail="无权操作此属性")
|
|
||||||
|
|
||||||
definition = await definition.update(session, request)
|
|
||||||
|
|
||||||
l.info(f"用户 {user.id} 更新了自定义属性: {id}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
|
||||||
path='/{id}',
|
|
||||||
summary='删除自定义属性定义',
|
|
||||||
description='删除自定义属性模板。注意:不会自动清理已使用该属性的元数据条目。',
|
|
||||||
status_code=204,
|
|
||||||
)
|
|
||||||
async def router_delete_custom_property(
|
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[User, Depends(auth_required)],
|
|
||||||
id: UUID,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
删除自定义属性定义端点
|
|
||||||
|
|
||||||
认证:JWT token 必填
|
|
||||||
|
|
||||||
错误处理:
|
|
||||||
- 404: 属性定义不存在
|
|
||||||
- 403: 无权操作此属性
|
|
||||||
"""
|
|
||||||
definition = await CustomPropertyDefinition.get_exist_one(session, id)
|
|
||||||
|
|
||||||
if definition.owner_id != user.id:
|
|
||||||
raise HTTPException(status_code=403, detail="无权操作此属性")
|
|
||||||
|
|
||||||
await CustomPropertyDefinition.delete(session, instances=definition)
|
|
||||||
|
|
||||||
l.info(f"用户 {user.id} 删除了自定义属性: {id}")
|
|
||||||
@@ -45,7 +45,12 @@ async def router_share_get(
|
|||||||
4. 返回分享详情(含文件树和分享者信息)
|
4. 返回分享详情(含文件树和分享者信息)
|
||||||
"""
|
"""
|
||||||
# 1. 查询分享(预加载 user 和 object)
|
# 1. 查询分享(预加载 user 和 object)
|
||||||
share = await Share.get_exist_one(session, id, load=[Share.user, Share.object])
|
share = await Share.get(
|
||||||
|
session, Share.id == id,
|
||||||
|
load=[Share.user, Share.object],
|
||||||
|
)
|
||||||
|
if not share:
|
||||||
|
http_exceptions.raise_not_found(detail="分享不存在或已被取消")
|
||||||
|
|
||||||
# 2. 检查过期
|
# 2. 检查过期
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
@@ -469,29 +474,16 @@ def router_share_update(id: str) -> ResponseBase:
|
|||||||
path='/{id}',
|
path='/{id}',
|
||||||
summary='删除分享',
|
summary='删除分享',
|
||||||
description='Delete a share by ID.',
|
description='Delete a share by ID.',
|
||||||
status_code=204,
|
dependencies=[Depends(auth_required)]
|
||||||
)
|
)
|
||||||
async def router_share_delete(
|
def router_share_delete(id: str) -> ResponseBase:
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[User, Depends(auth_required)],
|
|
||||||
id: UUID,
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
删除分享
|
Delete a share by ID.
|
||||||
|
|
||||||
认证:需要 JWT token
|
Args:
|
||||||
|
id (str): The ID of the share to be deleted.
|
||||||
|
|
||||||
流程:
|
Returns:
|
||||||
1. 通过分享ID查找分享
|
ResponseBase: A model containing the response data for the deleted share.
|
||||||
2. 验证分享属于当前用户
|
|
||||||
3. 删除分享记录
|
|
||||||
"""
|
"""
|
||||||
share = await Share.get_exist_one(session, id)
|
http_exceptions.raise_not_implemented()
|
||||||
if share.user_id != user.id:
|
|
||||||
http_exceptions.raise_forbidden(detail="无权删除此分享")
|
|
||||||
|
|
||||||
user_id = user.id
|
|
||||||
share_code = share.code
|
|
||||||
await Share.delete(session, share)
|
|
||||||
|
|
||||||
l.info(f"用户 {user_id} 删除了分享: {share_code}")
|
|
||||||
@@ -82,8 +82,7 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
|||||||
(Setting.type == SettingsType.REGISTER) |
|
(Setting.type == SettingsType.REGISTER) |
|
||||||
(Setting.type == SettingsType.CAPTCHA) |
|
(Setting.type == SettingsType.CAPTCHA) |
|
||||||
(Setting.type == SettingsType.AUTH) |
|
(Setting.type == SettingsType.AUTH) |
|
||||||
(Setting.type == SettingsType.OAUTH) |
|
(Setting.type == SettingsType.OAUTH),
|
||||||
(Setting.type == SettingsType.AVATAR),
|
|
||||||
fetch_mode="all",
|
fetch_mode="all",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -123,7 +122,6 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
|||||||
password_required=s.get("auth_password_required") == "1",
|
password_required=s.get("auth_password_required") == "1",
|
||||||
phone_binding_required=s.get("auth_phone_binding_required") == "1",
|
phone_binding_required=s.get("auth_phone_binding_required") == "1",
|
||||||
email_binding_required=s.get("auth_email_binding_required") == "1",
|
email_binding_required=s.get("auth_email_binding_required") == "1",
|
||||||
avatar_max_size=int(s["avatar_size"]),
|
|
||||||
footer_code=s.get("footer_code"),
|
footer_code=s.get("footer_code"),
|
||||||
tos_url=s.get("tos_url"),
|
tos_url=s.get("tos_url"),
|
||||||
privacy_url=s.get("privacy_url"),
|
privacy_url=s.get("privacy_url"),
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import json
|
|||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.responses import FileResponse, RedirectResponse
|
|
||||||
from itsdangerous import URLSafeTimedSerializer
|
from itsdangerous import URLSafeTimedSerializer
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from webauthn import (
|
from webauthn import (
|
||||||
@@ -234,7 +233,7 @@ async def router_user_register(
|
|||||||
group_id=default_group.id,
|
group_id=default_group.id,
|
||||||
)
|
)
|
||||||
new_user_id = new_user.id
|
new_user_id = new_user.id
|
||||||
new_user = await new_user.save(session)
|
await new_user.save(session)
|
||||||
|
|
||||||
# 7. 创建 AuthIdentity
|
# 7. 创建 AuthIdentity
|
||||||
hashed_password = Password.hash(request.credential) if request.credential else None
|
hashed_password = Password.hash(request.credential) if request.credential else None
|
||||||
@@ -246,14 +245,13 @@ async def router_user_register(
|
|||||||
is_verified=False,
|
is_verified=False,
|
||||||
user_id=new_user_id,
|
user_id=new_user_id,
|
||||||
)
|
)
|
||||||
identity = await identity.save(session)
|
await identity.save(session)
|
||||||
|
|
||||||
# 8. 创建用户根目录(使用用户组关联的第一个存储策略)
|
# 8. 创建用户根目录
|
||||||
await session.refresh(default_group, ['policies'])
|
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
|
||||||
if not default_group.policies:
|
if not default_policy:
|
||||||
logger.error("默认用户组未关联任何存储策略")
|
logger.error("默认存储策略不存在")
|
||||||
http_exceptions.raise_internal_error()
|
http_exceptions.raise_internal_error()
|
||||||
default_policy = default_group.policies[0]
|
|
||||||
|
|
||||||
await sqlmodels.Object(
|
await sqlmodels.Object(
|
||||||
name="/",
|
name="/",
|
||||||
@@ -320,7 +318,7 @@ async def router_user_magic_link(
|
|||||||
site_url = site_url_setting.value if site_url_setting else "http://localhost"
|
site_url = site_url_setting.value if site_url_setting else "http://localhost"
|
||||||
|
|
||||||
# TODO: 发送邮件(包含 {site_url}/auth/magic-link?token={token})
|
# TODO: 发送邮件(包含 {site_url}/auth/magic-link?token={token})
|
||||||
logger.info(f"Magic Link token 已为 {request.email} 生成 (邮件发送待实现)")
|
logger.info(f"Magic Link token 已生成: {token} (邮件发送待实现)")
|
||||||
|
|
||||||
|
|
||||||
@user_router.post(
|
@user_router.post(
|
||||||
@@ -359,78 +357,20 @@ def router_user_profile(id: str) -> sqlmodels.ResponseBase:
|
|||||||
@user_router.get(
|
@user_router.get(
|
||||||
path='/avatar/{id}/{size}',
|
path='/avatar/{id}/{size}',
|
||||||
summary='获取用户头像',
|
summary='获取用户头像',
|
||||||
response_model=None,
|
description='Get user avatar by ID and size.',
|
||||||
)
|
)
|
||||||
async def router_user_avatar(
|
def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase:
|
||||||
session: SessionDep,
|
|
||||||
id: UUID,
|
|
||||||
size: int = 128,
|
|
||||||
) -> FileResponse | RedirectResponse:
|
|
||||||
"""
|
"""
|
||||||
获取指定用户指定尺寸的头像(公开端点,无需认证)
|
Get user avatar by ID and size.
|
||||||
|
|
||||||
路径参数:
|
Args:
|
||||||
- id: 用户 UUID
|
id (str): The user ID.
|
||||||
- size: 请求的头像尺寸(px),默认 128
|
size (int): The size of the avatar image.
|
||||||
|
|
||||||
行为:
|
Returns:
|
||||||
- default: 302 重定向到 Gravatar identicon
|
str: A Base64 encoded string of the user avatar image.
|
||||||
- gravatar: 302 重定向到 Gravatar(使用用户邮箱 MD5)
|
|
||||||
- file: 返回本地 WebP 文件
|
|
||||||
|
|
||||||
响应:
|
|
||||||
- 200: image/webp(file 模式)
|
|
||||||
- 302: 重定向到外部 URL(default/gravatar 模式)
|
|
||||||
- 404: 用户不存在
|
|
||||||
|
|
||||||
缓存:Cache-Control: public, max-age=3600
|
|
||||||
"""
|
"""
|
||||||
import aiofiles.os
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
from service.avatar import (
|
|
||||||
get_avatar_file_path,
|
|
||||||
get_avatar_settings,
|
|
||||||
gravatar_url,
|
|
||||||
resolve_avatar_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
user = await sqlmodels.User.get(session, sqlmodels.User.id == id)
|
|
||||||
if not user:
|
|
||||||
http_exceptions.raise_not_found("用户不存在")
|
|
||||||
|
|
||||||
avatar_path, _, size_l, size_m, size_s = await get_avatar_settings(session)
|
|
||||||
|
|
||||||
if user.avatar == "file":
|
|
||||||
size_label = resolve_avatar_size(size, size_l, size_m, size_s)
|
|
||||||
file_path = get_avatar_file_path(avatar_path, user.id, size_label)
|
|
||||||
|
|
||||||
if not await aiofiles.os.path.exists(file_path):
|
|
||||||
# 文件丢失,降级为 identicon
|
|
||||||
fallback_url = gravatar_url(str(user.id), size, "https://www.gravatar.com/")
|
|
||||||
return RedirectResponse(url=fallback_url, status_code=302)
|
|
||||||
|
|
||||||
return FileResponse(
|
|
||||||
path=file_path,
|
|
||||||
media_type="image/webp",
|
|
||||||
headers={"Cache-Control": "public, max-age=3600"},
|
|
||||||
)
|
|
||||||
|
|
||||||
elif user.avatar == "gravatar":
|
|
||||||
gravatar_setting = await sqlmodels.Setting.get(
|
|
||||||
session,
|
|
||||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.AVATAR)
|
|
||||||
& (sqlmodels.Setting.name == "gravatar_server"),
|
|
||||||
)
|
|
||||||
server = gravatar_setting.value if gravatar_setting else "https://www.gravatar.com/"
|
|
||||||
email = user.email or str(user.id)
|
|
||||||
url = gravatar_url(email, size, server)
|
|
||||||
return RedirectResponse(url=url, status_code=302)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# default: identicon
|
|
||||||
email_or_id = user.email or str(user.id)
|
|
||||||
url = gravatar_url(email_or_id, size, "https://www.gravatar.com/")
|
|
||||||
return RedirectResponse(url=url, status_code=302)
|
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
# 需要登录的接口
|
# 需要登录的接口
|
||||||
@@ -494,24 +434,9 @@ async def router_user_storage(
|
|||||||
if not group:
|
if not group:
|
||||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||||
|
|
||||||
# 查询用户所有未过期容量包的 size 总和
|
# [TODO] 总空间加上用户购买的额外空间
|
||||||
from datetime import datetime
|
|
||||||
from sqlalchemy import func, select, and_, or_
|
|
||||||
|
|
||||||
now = datetime.now()
|
total: int = group.max_storage
|
||||||
stmt = select(func.coalesce(func.sum(sqlmodels.StoragePack.size), 0)).where(
|
|
||||||
and_(
|
|
||||||
sqlmodels.StoragePack.user_id == user.id,
|
|
||||||
or_(
|
|
||||||
sqlmodels.StoragePack.expired_time.is_(None),
|
|
||||||
sqlmodels.StoragePack.expired_time > now,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result = await session.exec(stmt)
|
|
||||||
active_packs_total: int = result.scalar_one()
|
|
||||||
|
|
||||||
total: int = group.max_storage + active_packs_total
|
|
||||||
used: int = user.storage
|
used: int = user.storage
|
||||||
free: int = max(0, total - used)
|
free: int = max(0, total - used)
|
||||||
|
|
||||||
@@ -653,7 +578,7 @@ async def router_user_authn_finish(
|
|||||||
is_verified=True,
|
is_verified=True,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
)
|
)
|
||||||
identity = await identity.save(session)
|
await identity.save(session)
|
||||||
|
|
||||||
return authn.to_detail_response()
|
return authn.to_detail_response()
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||||
|
|
||||||
import sqlmodels
|
import sqlmodels
|
||||||
@@ -13,7 +13,6 @@ from sqlmodels import (
|
|||||||
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
|
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
|
||||||
ChangePasswordRequest,
|
ChangePasswordRequest,
|
||||||
AuthnDetailResponse, AuthnRenameRequest,
|
AuthnDetailResponse, AuthnRenameRequest,
|
||||||
PolicySummary,
|
|
||||||
)
|
)
|
||||||
from sqlmodels.color import ThemeColorsBase
|
from sqlmodels.color import ThemeColorsBase
|
||||||
from sqlmodels.user_authn import UserAuthn
|
from sqlmodels.user_authn import UserAuthn
|
||||||
@@ -32,25 +31,16 @@ user_settings_router.include_router(file_viewers_router)
|
|||||||
@user_settings_router.get(
|
@user_settings_router.get(
|
||||||
path='/policies',
|
path='/policies',
|
||||||
summary='获取用户可选存储策略',
|
summary='获取用户可选存储策略',
|
||||||
|
description='Get user selectable storage policies.',
|
||||||
)
|
)
|
||||||
async def router_user_settings_policies(
|
def router_user_settings_policies() -> sqlmodels.ResponseBase:
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
|
||||||
) -> list[PolicySummary]:
|
|
||||||
"""
|
"""
|
||||||
获取当前用户所在组可选的存储策略列表
|
Get user selectable storage policies.
|
||||||
|
|
||||||
返回用户组关联的所有存储策略的摘要信息。
|
Returns:
|
||||||
|
dict: A dictionary containing available storage policies for the user.
|
||||||
"""
|
"""
|
||||||
group = await user.awaitable_attrs.group
|
http_exceptions.raise_not_implemented()
|
||||||
await session.refresh(group, ['policies'])
|
|
||||||
return [
|
|
||||||
PolicySummary(
|
|
||||||
id=p.id, name=p.name, type=p.type,
|
|
||||||
server=p.server, max_size=p.max_size, is_private=p.is_private,
|
|
||||||
)
|
|
||||||
for p in group.policies
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.get(
|
@user_settings_router.get(
|
||||||
@@ -165,121 +155,34 @@ async def router_user_settings(
|
|||||||
@user_settings_router.post(
|
@user_settings_router.post(
|
||||||
path='/avatar',
|
path='/avatar',
|
||||||
summary='从文件上传头像',
|
summary='从文件上传头像',
|
||||||
status_code=204,
|
description='Upload user avatar from file.',
|
||||||
|
dependencies=[Depends(auth_required)],
|
||||||
)
|
)
|
||||||
async def router_user_settings_avatar(
|
def router_user_settings_avatar() -> sqlmodels.ResponseBase:
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
|
||||||
file: UploadFile = File(...),
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
上传头像文件
|
Upload user avatar from file.
|
||||||
|
|
||||||
认证:JWT token
|
Returns:
|
||||||
请求体:multipart/form-data,file 字段
|
dict: A dictionary containing the result of the avatar upload.
|
||||||
|
|
||||||
流程:
|
|
||||||
1. 验证文件 MIME 类型(JPEG/PNG/GIF/WebP)
|
|
||||||
2. 验证文件大小 <= avatar_size 设置(默认 2MB)
|
|
||||||
3. 调用 Pillow 验证图片有效性并处理(居中裁剪、缩放 L/M/S)
|
|
||||||
4. 保存三种尺寸的 WebP 文件
|
|
||||||
5. 更新 User.avatar = "file"
|
|
||||||
|
|
||||||
错误处理:
|
|
||||||
- 400: 文件类型不支持 / 图片无法解析
|
|
||||||
- 413: 文件过大
|
|
||||||
"""
|
"""
|
||||||
from service.avatar import (
|
http_exceptions.raise_not_implemented()
|
||||||
ALLOWED_CONTENT_TYPES,
|
|
||||||
get_avatar_settings,
|
|
||||||
process_and_save_avatar,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 验证 MIME 类型
|
|
||||||
if file.content_type not in ALLOWED_CONTENT_TYPES:
|
|
||||||
http_exceptions.raise_bad_request(
|
|
||||||
f"不支持的图片格式,允许: {', '.join(ALLOWED_CONTENT_TYPES)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 读取并验证大小
|
|
||||||
_, max_upload_size, _, _, _ = await get_avatar_settings(session)
|
|
||||||
raw_bytes = await file.read()
|
|
||||||
if len(raw_bytes) > max_upload_size:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=413,
|
|
||||||
detail=f"文件过大,最大允许 {max_upload_size} 字节",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 处理并保存(内部会验证图片有效性,无效抛出 ValueError)
|
|
||||||
try:
|
|
||||||
await process_and_save_avatar(session, user.id, raw_bytes)
|
|
||||||
except ValueError as e:
|
|
||||||
http_exceptions.raise_bad_request(str(e))
|
|
||||||
|
|
||||||
# 更新用户头像字段
|
|
||||||
user.avatar = "file"
|
|
||||||
user = await user.save(session)
|
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.put(
|
@user_settings_router.put(
|
||||||
path='/avatar',
|
path='/avatar',
|
||||||
summary='设定为 Gravatar 头像',
|
summary='设定为Gravatar头像',
|
||||||
|
description='Set user avatar to Gravatar.',
|
||||||
|
dependencies=[Depends(auth_required)],
|
||||||
status_code=204,
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_user_settings_avatar_gravatar(
|
def router_user_settings_avatar_gravatar() -> None:
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
将头像切换为 Gravatar
|
Set user avatar to Gravatar.
|
||||||
|
|
||||||
认证:JWT token
|
Returns:
|
||||||
|
dict: A dictionary containing the result of setting the Gravatar avatar.
|
||||||
流程:
|
|
||||||
1. 验证用户有邮箱(Gravatar 基于邮箱 MD5)
|
|
||||||
2. 如果当前是 FILE 头像,删除本地文件
|
|
||||||
3. 更新 User.avatar = "gravatar"
|
|
||||||
|
|
||||||
错误处理:
|
|
||||||
- 400: 用户没有邮箱
|
|
||||||
"""
|
"""
|
||||||
from service.avatar import delete_avatar_files
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
if not user.email:
|
|
||||||
http_exceptions.raise_bad_request("Gravatar 需要邮箱,请先绑定邮箱")
|
|
||||||
|
|
||||||
if user.avatar == "file":
|
|
||||||
await delete_avatar_files(session, user.id)
|
|
||||||
|
|
||||||
user.avatar = "gravatar"
|
|
||||||
user = await user.save(session)
|
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.delete(
|
|
||||||
path='/avatar',
|
|
||||||
summary='重置头像为默认',
|
|
||||||
status_code=204,
|
|
||||||
)
|
|
||||||
async def router_user_settings_avatar_delete(
|
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
重置头像为默认
|
|
||||||
|
|
||||||
认证:JWT token
|
|
||||||
|
|
||||||
流程:
|
|
||||||
1. 如果当前是 FILE 头像,删除本地文件
|
|
||||||
2. 更新 User.avatar = "default"
|
|
||||||
"""
|
|
||||||
from service.avatar import delete_avatar_files
|
|
||||||
|
|
||||||
if user.avatar == "file":
|
|
||||||
await delete_avatar_files(session, user.id)
|
|
||||||
|
|
||||||
user.avatar = "default"
|
|
||||||
user = await user.save(session)
|
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.patch(
|
@user_settings_router.patch(
|
||||||
@@ -321,7 +224,7 @@ async def router_user_settings_theme(
|
|||||||
user.color_error = request.theme_colors.error
|
user.color_error = request.theme_colors.error
|
||||||
user.color_neutral = request.theme_colors.neutral
|
user.color_neutral = request.theme_colors.neutral
|
||||||
|
|
||||||
user = await user.save(session)
|
await user.save(session)
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.patch(
|
@user_settings_router.patch(
|
||||||
@@ -358,7 +261,7 @@ async def router_user_settings_change_password(
|
|||||||
http_exceptions.raise_forbidden("当前密码错误")
|
http_exceptions.raise_forbidden("当前密码错误")
|
||||||
|
|
||||||
email_identity.credential = Password.hash(request.new_password)
|
email_identity.credential = Password.hash(request.new_password)
|
||||||
email_identity = await email_identity.save(session)
|
await email_identity.save(session)
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.patch(
|
@user_settings_router.patch(
|
||||||
@@ -392,7 +295,7 @@ async def router_user_settings_patch(
|
|||||||
http_exceptions.raise_bad_request(f"设置项 {option.value} 不允许为空")
|
http_exceptions.raise_bad_request(f"设置项 {option.value} 不允许为空")
|
||||||
|
|
||||||
setattr(user, option.value, value)
|
setattr(user, option.value, value)
|
||||||
user = await user.save(session)
|
await user.save(session)
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.get(
|
@user_settings_router.get(
|
||||||
@@ -454,7 +357,7 @@ async def router_user_settings_2fa_enable(
|
|||||||
extra: dict = orjson.loads(email_identity.extra_data) if email_identity.extra_data else {}
|
extra: dict = orjson.loads(email_identity.extra_data) if email_identity.extra_data else {}
|
||||||
extra["two_factor"] = secret
|
extra["two_factor"] = secret
|
||||||
email_identity.extra_data = orjson.dumps(extra).decode('utf-8')
|
email_identity.extra_data = orjson.dumps(extra).decode('utf-8')
|
||||||
email_identity = await email_identity.save(session)
|
await email_identity.save(session)
|
||||||
|
|
||||||
|
|
||||||
# ==================== 认证身份管理 ====================
|
# ==================== 认证身份管理 ====================
|
||||||
|
|||||||
@@ -79,7 +79,9 @@ async def set_default_viewer(
|
|||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
existing.app_id = request.app_id
|
existing.app_id = request.app_id
|
||||||
existing = await existing.save(session, load=UserFileAppDefault.app)
|
existing = await existing.save(session)
|
||||||
|
# 重新加载 app 关系
|
||||||
|
await session.refresh(existing, attribute_names=["app"])
|
||||||
return existing.to_response()
|
return existing.to_response()
|
||||||
else:
|
else:
|
||||||
new_default = UserFileAppDefault(
|
new_default = UserFileAppDefault(
|
||||||
@@ -87,7 +89,9 @@ async def set_default_viewer(
|
|||||||
extension=normalized_ext,
|
extension=normalized_ext,
|
||||||
app_id=request.app_id,
|
app_id=request.app_id,
|
||||||
)
|
)
|
||||||
new_default = await new_default.save(session, load=UserFileAppDefault.app)
|
new_default = await new_default.save(session)
|
||||||
|
# 重新加载 app 关系
|
||||||
|
await session.refresh(new_default, attribute_names=["app"])
|
||||||
return new_default.to_response()
|
return new_default.to_response()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
106
routers/api/v1/vas/__init__.py
Normal file
106
routers/api/v1/vas/__init__.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from middleware.auth import auth_required
|
||||||
|
from sqlmodels import ResponseBase
|
||||||
|
from utils import http_exceptions
|
||||||
|
|
||||||
|
vas_router = APIRouter(
|
||||||
|
prefix="/vas",
|
||||||
|
tags=["vas"]
|
||||||
|
)
|
||||||
|
|
||||||
|
@vas_router.get(
|
||||||
|
path='/pack',
|
||||||
|
summary='获取容量包及配额信息',
|
||||||
|
description='Get information about storage packs and quotas.',
|
||||||
|
dependencies=[Depends(auth_required)]
|
||||||
|
)
|
||||||
|
def router_vas_pack() -> ResponseBase:
|
||||||
|
"""
|
||||||
|
Get information about storage packs and quotas.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for storage packs and quotas.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
|
@vas_router.get(
|
||||||
|
path='/product',
|
||||||
|
summary='获取商品信息,同时返回支付信息',
|
||||||
|
description='Get product information along with payment details.',
|
||||||
|
dependencies=[Depends(auth_required)]
|
||||||
|
)
|
||||||
|
def router_vas_product() -> ResponseBase:
|
||||||
|
"""
|
||||||
|
Get product information along with payment details.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for products and payment information.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
|
@vas_router.post(
|
||||||
|
path='/order',
|
||||||
|
summary='新建支付订单',
|
||||||
|
description='Create an order for a product.',
|
||||||
|
dependencies=[Depends(auth_required)]
|
||||||
|
)
|
||||||
|
def router_vas_order() -> ResponseBase:
|
||||||
|
"""
|
||||||
|
Create an order for a product.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for the created order.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
|
@vas_router.get(
|
||||||
|
path='/order/{id}',
|
||||||
|
summary='查询订单状态',
|
||||||
|
description='Get information about a specific payment order by ID.',
|
||||||
|
dependencies=[Depends(auth_required)]
|
||||||
|
)
|
||||||
|
def router_vas_order_get(id: str) -> ResponseBase:
|
||||||
|
"""
|
||||||
|
Get information about a specific payment order by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id (str): The ID of the order to retrieve information for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for the specified order.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
|
@vas_router.get(
|
||||||
|
path='/redeem',
|
||||||
|
summary='获取兑换码信息',
|
||||||
|
description='Get information about a specific redemption code.',
|
||||||
|
dependencies=[Depends(auth_required)]
|
||||||
|
)
|
||||||
|
def router_vas_redeem(code: str) -> ResponseBase:
|
||||||
|
"""
|
||||||
|
Get information about a specific redemption code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code (str): The redemption code to retrieve information for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for the specified redemption code.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
|
@vas_router.post(
|
||||||
|
path='/redeem',
|
||||||
|
summary='执行兑换',
|
||||||
|
description='Redeem a redemption code for a product or service.',
|
||||||
|
dependencies=[Depends(auth_required)]
|
||||||
|
)
|
||||||
|
def router_vas_redeem_post() -> ResponseBase:
|
||||||
|
"""
|
||||||
|
Redeem a redemption code for a product or service.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for the redeemed code.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
@@ -1,207 +1,110 @@
|
|||||||
from typing import Annotated
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from loguru import logger as l
|
|
||||||
|
|
||||||
from middleware.auth import auth_required
|
from middleware.auth import auth_required
|
||||||
from middleware.dependencies import SessionDep
|
from sqlmodels import ResponseBase
|
||||||
from sqlmodels import (
|
|
||||||
Object,
|
|
||||||
User,
|
|
||||||
WebDAV,
|
|
||||||
WebDAVAccountResponse,
|
|
||||||
WebDAVCreateRequest,
|
|
||||||
WebDAVUpdateRequest,
|
|
||||||
)
|
|
||||||
from service.redis.webdav_auth_cache import WebDAVAuthCache
|
|
||||||
from utils import http_exceptions
|
from utils import http_exceptions
|
||||||
from utils.password.pwd import Password
|
|
||||||
|
|
||||||
|
# WebDAV 管理路由
|
||||||
webdav_router = APIRouter(
|
webdav_router = APIRouter(
|
||||||
prefix='/webdav',
|
prefix='/webdav',
|
||||||
tags=["webdav"],
|
tags=["webdav"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _check_webdav_enabled(user: User) -> None:
|
|
||||||
"""检查用户组是否启用了 WebDAV 功能"""
|
|
||||||
if not user.group.web_dav_enabled:
|
|
||||||
http_exceptions.raise_forbidden("WebDAV 功能未启用")
|
|
||||||
|
|
||||||
|
|
||||||
def _to_response(account: WebDAV) -> WebDAVAccountResponse:
|
|
||||||
"""将 WebDAV 数据库模型转换为响应 DTO"""
|
|
||||||
return WebDAVAccountResponse(
|
|
||||||
id=account.id,
|
|
||||||
name=account.name,
|
|
||||||
root=account.root,
|
|
||||||
readonly=account.readonly,
|
|
||||||
use_proxy=account.use_proxy,
|
|
||||||
created_at=str(account.created_at),
|
|
||||||
updated_at=str(account.updated_at),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@webdav_router.get(
|
@webdav_router.get(
|
||||||
path='/accounts',
|
path='/accounts',
|
||||||
summary='获取账号列表',
|
summary='获取账号信息',
|
||||||
|
description='Get account information for WebDAV.',
|
||||||
|
dependencies=[Depends(auth_required)],
|
||||||
)
|
)
|
||||||
async def list_accounts(
|
def router_webdav_accounts() -> ResponseBase:
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[User, Depends(auth_required)],
|
|
||||||
) -> list[WebDAVAccountResponse]:
|
|
||||||
"""
|
"""
|
||||||
列出当前用户所有 WebDAV 账户
|
Get account information for WebDAV.
|
||||||
|
|
||||||
认证:JWT Bearer Token
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for the account information.
|
||||||
"""
|
"""
|
||||||
_check_webdav_enabled(user)
|
http_exceptions.raise_not_implemented()
|
||||||
user_id: UUID = user.id
|
|
||||||
|
|
||||||
accounts: list[WebDAV] = await WebDAV.get(
|
|
||||||
session,
|
|
||||||
WebDAV.user_id == user_id,
|
|
||||||
fetch_mode="all",
|
|
||||||
)
|
|
||||||
return [_to_response(a) for a in accounts]
|
|
||||||
|
|
||||||
|
|
||||||
@webdav_router.post(
|
@webdav_router.post(
|
||||||
path='/accounts',
|
path='/accounts',
|
||||||
summary='创建账号',
|
summary='新建账号',
|
||||||
status_code=201,
|
description='Create a new WebDAV account.',
|
||||||
|
dependencies=[Depends(auth_required)],
|
||||||
)
|
)
|
||||||
async def create_account(
|
def router_webdav_create_account() -> ResponseBase:
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[User, Depends(auth_required)],
|
|
||||||
request: WebDAVCreateRequest,
|
|
||||||
) -> WebDAVAccountResponse:
|
|
||||||
"""
|
"""
|
||||||
创建 WebDAV 账户
|
Create a new WebDAV account.
|
||||||
|
|
||||||
认证:JWT Bearer Token
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for the created account.
|
||||||
错误处理:
|
|
||||||
- 403: WebDAV 功能未启用
|
|
||||||
- 400: 根目录路径不存在或不是目录
|
|
||||||
- 409: 账户名已存在
|
|
||||||
"""
|
"""
|
||||||
_check_webdav_enabled(user)
|
http_exceptions.raise_not_implemented()
|
||||||
user_id: UUID = user.id
|
|
||||||
|
|
||||||
# 验证账户名唯一
|
|
||||||
existing = await WebDAV.get(
|
|
||||||
session,
|
|
||||||
(WebDAV.name == request.name) & (WebDAV.user_id == user_id),
|
|
||||||
)
|
|
||||||
if existing:
|
|
||||||
http_exceptions.raise_conflict("账户名已存在")
|
|
||||||
|
|
||||||
# 验证 root 路径存在且为目录
|
|
||||||
root_obj = await Object.get_by_path(session, user_id, request.root)
|
|
||||||
if not root_obj or not root_obj.is_folder:
|
|
||||||
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
|
|
||||||
|
|
||||||
# 创建账户
|
|
||||||
account = WebDAV(
|
|
||||||
name=request.name,
|
|
||||||
password=Password.hash(request.password),
|
|
||||||
root=request.root,
|
|
||||||
readonly=request.readonly,
|
|
||||||
use_proxy=request.use_proxy,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
account = await account.save(session)
|
|
||||||
|
|
||||||
l.info(f"用户 {user_id} 创建 WebDAV 账户: {account.name}")
|
|
||||||
return _to_response(account)
|
|
||||||
|
|
||||||
|
|
||||||
@webdav_router.patch(
|
|
||||||
path='/accounts/{account_id}',
|
|
||||||
summary='更新账号',
|
|
||||||
)
|
|
||||||
async def update_account(
|
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[User, Depends(auth_required)],
|
|
||||||
account_id: int,
|
|
||||||
request: WebDAVUpdateRequest,
|
|
||||||
) -> WebDAVAccountResponse:
|
|
||||||
"""
|
|
||||||
更新 WebDAV 账户
|
|
||||||
|
|
||||||
认证:JWT Bearer Token
|
|
||||||
|
|
||||||
错误处理:
|
|
||||||
- 403: WebDAV 功能未启用
|
|
||||||
- 404: 账户不存在
|
|
||||||
- 400: 根目录路径不存在或不是目录
|
|
||||||
"""
|
|
||||||
_check_webdav_enabled(user)
|
|
||||||
user_id: UUID = user.id
|
|
||||||
|
|
||||||
account = await WebDAV.get(
|
|
||||||
session,
|
|
||||||
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
|
|
||||||
)
|
|
||||||
if not account:
|
|
||||||
http_exceptions.raise_not_found("WebDAV 账户不存在")
|
|
||||||
|
|
||||||
# 验证 root 路径
|
|
||||||
if request.root is not None:
|
|
||||||
root_obj = await Object.get_by_path(session, user_id, request.root)
|
|
||||||
if not root_obj or not root_obj.is_folder:
|
|
||||||
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
|
|
||||||
|
|
||||||
# 密码哈希后原地替换,update() 会通过 model_dump(exclude_unset=True) 只取已设置字段
|
|
||||||
is_password_changed = request.password is not None
|
|
||||||
if is_password_changed:
|
|
||||||
request.password = Password.hash(request.password)
|
|
||||||
|
|
||||||
account = await account.update(session, request)
|
|
||||||
|
|
||||||
# 密码变更时清除认证缓存
|
|
||||||
if is_password_changed:
|
|
||||||
await WebDAVAuthCache.invalidate_account(user_id, account.name)
|
|
||||||
|
|
||||||
l.info(f"用户 {user_id} 更新 WebDAV 账户: {account.name}")
|
|
||||||
return _to_response(account)
|
|
||||||
|
|
||||||
|
|
||||||
@webdav_router.delete(
|
@webdav_router.delete(
|
||||||
path='/accounts/{account_id}',
|
path='/accounts/{id}',
|
||||||
summary='删除账号',
|
summary='删除账号',
|
||||||
status_code=204,
|
description='Delete a WebDAV account by its ID.',
|
||||||
|
dependencies=[Depends(auth_required)],
|
||||||
)
|
)
|
||||||
async def delete_account(
|
def router_webdav_delete_account(id: str) -> ResponseBase:
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[User, Depends(auth_required)],
|
|
||||||
account_id: int,
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
删除 WebDAV 账户
|
Delete a WebDAV account by its ID.
|
||||||
|
|
||||||
认证:JWT Bearer Token
|
Args:
|
||||||
|
id (str): The ID of the account to be deleted.
|
||||||
|
|
||||||
错误处理:
|
Returns:
|
||||||
- 403: WebDAV 功能未启用
|
ResponseBase: A model containing the response data for the deletion operation.
|
||||||
- 404: 账户不存在
|
|
||||||
"""
|
"""
|
||||||
_check_webdav_enabled(user)
|
http_exceptions.raise_not_implemented()
|
||||||
user_id: UUID = user.id
|
|
||||||
|
|
||||||
account = await WebDAV.get(
|
@webdav_router.post(
|
||||||
session,
|
path='/mount',
|
||||||
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
|
summary='新建目录挂载',
|
||||||
)
|
description='Create a new WebDAV mount point.',
|
||||||
if not account:
|
dependencies=[Depends(auth_required)],
|
||||||
http_exceptions.raise_not_found("WebDAV 账户不存在")
|
)
|
||||||
|
def router_webdav_create_mount() -> ResponseBase:
|
||||||
|
"""
|
||||||
|
Create a new WebDAV mount point.
|
||||||
|
|
||||||
account_name = account.name
|
Returns:
|
||||||
await WebDAV.delete(session, account)
|
ResponseBase: A model containing the response data for the created mount point.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
# 清除认证缓存
|
@webdav_router.delete(
|
||||||
await WebDAVAuthCache.invalidate_account(user_id, account_name)
|
path='/mount/{id}',
|
||||||
|
summary='删除目录挂载',
|
||||||
|
description='Delete a WebDAV mount point by its ID.',
|
||||||
|
dependencies=[Depends(auth_required)],
|
||||||
|
)
|
||||||
|
def router_webdav_delete_mount(id: str) -> ResponseBase:
|
||||||
|
"""
|
||||||
|
Delete a WebDAV mount point by its ID.
|
||||||
|
|
||||||
l.info(f"用户 {user_id} 删除 WebDAV 账户: {account_name}")
|
Args:
|
||||||
|
id (str): The ID of the mount point to be deleted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for the deletion operation.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
|
@webdav_router.patch(
|
||||||
|
path='accounts/{id}',
|
||||||
|
summary='更新账号信息',
|
||||||
|
description='Update WebDAV account information by ID.',
|
||||||
|
dependencies=[Depends(auth_required)],
|
||||||
|
)
|
||||||
|
def router_webdav_update_account(id: str) -> ResponseBase:
|
||||||
|
"""
|
||||||
|
Update WebDAV account information by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id (str): The ID of the account to be updated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResponseBase: A model containing the response data for the updated account.
|
||||||
|
"""
|
||||||
|
http_exceptions.raise_not_implemented()
|
||||||
1
routers/dav/README.md
Normal file
1
routers/dav/README.md
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# WebDAV 操作路由
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
"""
|
|
||||||
WebDAV 协议入口
|
|
||||||
|
|
||||||
使用 WsgiDAV + a2wsgi 提供 WebDAV 协议支持。
|
|
||||||
WsgiDAV 在 a2wsgi 的线程池中运行,不阻塞 FastAPI 事件循环。
|
|
||||||
"""
|
|
||||||
from a2wsgi import WSGIMiddleware
|
|
||||||
from wsgidav.wsgidav_app import WsgiDAVApp
|
|
||||||
|
|
||||||
from .domain_controller import DiskNextDomainController
|
|
||||||
from .provider import DiskNextDAVProvider
|
|
||||||
|
|
||||||
_wsgidav_config: dict[str, object] = {
|
|
||||||
"provider_mapping": {
|
|
||||||
"/": DiskNextDAVProvider(),
|
|
||||||
},
|
|
||||||
"http_authenticator": {
|
|
||||||
"domain_controller": DiskNextDomainController,
|
|
||||||
"accept_basic": True,
|
|
||||||
"accept_digest": False,
|
|
||||||
"default_to_digest": False,
|
|
||||||
},
|
|
||||||
"verbose": 1,
|
|
||||||
# 使用 WsgiDAV 内置的内存锁管理器
|
|
||||||
"lock_storage": True,
|
|
||||||
# 禁用 WsgiDAV 的目录浏览器(纯 DAV 协议)
|
|
||||||
"dir_browser": {
|
|
||||||
"enable": False,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_wsgidav_app = WsgiDAVApp(_wsgidav_config)
|
|
||||||
|
|
||||||
dav_app = WSGIMiddleware(_wsgidav_app, workers=10)
|
|
||||||
"""ASGI 应用,挂载到 /dav 路径"""
|
|
||||||
@@ -1,148 +0,0 @@
|
|||||||
"""
|
|
||||||
WebDAV 认证控制器
|
|
||||||
|
|
||||||
实现 WsgiDAV 的 BaseDomainController 接口,使用 HTTP Basic Auth
|
|
||||||
通过 DiskNext 的 WebDAV 账户模型进行认证。
|
|
||||||
|
|
||||||
用户名格式: {email}/{webdav_account_name}
|
|
||||||
"""
|
|
||||||
import asyncio
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from loguru import logger as l
|
|
||||||
from wsgidav.dc.base_dc import BaseDomainController
|
|
||||||
|
|
||||||
from routers.dav.provider import EventLoopRef, _get_session
|
|
||||||
from service.redis.webdav_auth_cache import WebDAVAuthCache
|
|
||||||
from sqlmodels.user import User, UserStatus
|
|
||||||
from sqlmodels.webdav import WebDAV
|
|
||||||
from utils.password.pwd import Password, PasswordStatus
|
|
||||||
|
|
||||||
|
|
||||||
async def _authenticate(
|
|
||||||
email: str,
|
|
||||||
account_name: str,
|
|
||||||
password: str,
|
|
||||||
) -> tuple[UUID, int] | None:
|
|
||||||
"""
|
|
||||||
异步认证 WebDAV 用户。
|
|
||||||
|
|
||||||
:param email: 用户邮箱
|
|
||||||
:param account_name: WebDAV 账户名
|
|
||||||
:param password: 明文密码
|
|
||||||
:return: (user_id, webdav_id) 或 None
|
|
||||||
"""
|
|
||||||
# 1. 查缓存
|
|
||||||
cached = await WebDAVAuthCache.get(email, account_name, password)
|
|
||||||
if cached is not None:
|
|
||||||
return cached
|
|
||||||
|
|
||||||
# 2. 缓存未命中,查库验证
|
|
||||||
async with _get_session() as session:
|
|
||||||
user = await User.get(session, User.email == email, load=User.group)
|
|
||||||
if not user:
|
|
||||||
return None
|
|
||||||
if user.status != UserStatus.ACTIVE:
|
|
||||||
return None
|
|
||||||
if not user.group.web_dav_enabled:
|
|
||||||
return None
|
|
||||||
|
|
||||||
account = await WebDAV.get(
|
|
||||||
session,
|
|
||||||
(WebDAV.name == account_name) & (WebDAV.user_id == user.id),
|
|
||||||
)
|
|
||||||
if not account:
|
|
||||||
return None
|
|
||||||
|
|
||||||
status = Password.verify(account.password, password)
|
|
||||||
if status == PasswordStatus.INVALID:
|
|
||||||
return None
|
|
||||||
|
|
||||||
user_id: UUID = user.id
|
|
||||||
webdav_id: int = account.id
|
|
||||||
|
|
||||||
# 3. 写入缓存
|
|
||||||
await WebDAVAuthCache.set(email, account_name, password, user_id, webdav_id)
|
|
||||||
|
|
||||||
return user_id, webdav_id
|
|
||||||
|
|
||||||
|
|
||||||
class DiskNextDomainController(BaseDomainController):
|
|
||||||
"""
|
|
||||||
DiskNext WebDAV 认证控制器
|
|
||||||
|
|
||||||
用户名格式: {email}/{webdav_account_name}
|
|
||||||
密码: WebDAV 账户密码(创建账户时设置)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, wsgidav_app: object, config: dict[str, object]) -> None:
|
|
||||||
super().__init__(wsgidav_app, config)
|
|
||||||
|
|
||||||
def get_domain_realm(self, path_info: str, environ: dict[str, object]) -> str:
|
|
||||||
"""返回 realm 名称"""
|
|
||||||
return "DiskNext WebDAV"
|
|
||||||
|
|
||||||
def require_authentication(self, realm: str, environ: dict[str, object]) -> bool:
|
|
||||||
"""所有请求都需要认证"""
|
|
||||||
return True
|
|
||||||
|
|
||||||
def is_share_anonymous(self, path_info: str) -> bool:
|
|
||||||
"""不支持匿名访问"""
|
|
||||||
return False
|
|
||||||
|
|
||||||
def supports_http_digest_auth(self) -> bool:
|
|
||||||
"""不支持 Digest 认证(密码存的是 Argon2 哈希,无法反推)"""
|
|
||||||
return False
|
|
||||||
|
|
||||||
def basic_auth_user(
|
|
||||||
self,
|
|
||||||
realm: str,
|
|
||||||
user_name: str,
|
|
||||||
password: str,
|
|
||||||
environ: dict[str, object],
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
HTTP Basic Auth 认证。
|
|
||||||
|
|
||||||
用户名格式: {email}/{webdav_account_name}
|
|
||||||
在 WSGI 线程中通过 anyio.from_thread.run 调用异步认证逻辑。
|
|
||||||
"""
|
|
||||||
# 解析用户名
|
|
||||||
if "/" not in user_name:
|
|
||||||
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
|
|
||||||
return False
|
|
||||||
|
|
||||||
email, account_name = user_name.split("/", 1)
|
|
||||||
if not email or not account_name:
|
|
||||||
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 在 WSGI 线程中调用异步认证
|
|
||||||
future = asyncio.run_coroutine_threadsafe(
|
|
||||||
_authenticate(email, account_name, password),
|
|
||||||
EventLoopRef.get(),
|
|
||||||
)
|
|
||||||
result = future.result()
|
|
||||||
|
|
||||||
if result is None:
|
|
||||||
l.debug(f"WebDAV 认证失败: {email}/{account_name}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
user_id, webdav_id = result
|
|
||||||
|
|
||||||
# 将认证信息存入 environ,供 Provider 使用
|
|
||||||
environ["disknext.user_id"] = user_id
|
|
||||||
environ["disknext.webdav_id"] = webdav_id
|
|
||||||
environ["disknext.email"] = email
|
|
||||||
environ["disknext.account_name"] = account_name
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def digest_auth_user(
|
|
||||||
self,
|
|
||||||
realm: str,
|
|
||||||
user_name: str,
|
|
||||||
environ: dict[str, object],
|
|
||||||
) -> bool:
|
|
||||||
"""不支持 Digest 认证"""
|
|
||||||
return False
|
|
||||||
@@ -1,645 +0,0 @@
|
|||||||
"""
|
|
||||||
DiskNext WebDAV 存储 Provider
|
|
||||||
|
|
||||||
将 WsgiDAV 的文件操作映射到 DiskNext 的 Object 模型。
|
|
||||||
所有异步数据库/文件操作通过 asyncio.run_coroutine_threadsafe() 桥接。
|
|
||||||
"""
|
|
||||||
import asyncio
|
|
||||||
import io
|
|
||||||
import mimetypes
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import ClassVar
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from loguru import logger as l
|
|
||||||
from wsgidav.dav_error import (
|
|
||||||
DAVError,
|
|
||||||
HTTP_FORBIDDEN,
|
|
||||||
HTTP_INSUFFICIENT_STORAGE,
|
|
||||||
HTTP_NOT_FOUND,
|
|
||||||
)
|
|
||||||
from wsgidav.dav_provider import DAVCollection, DAVNonCollection, DAVProvider
|
|
||||||
|
|
||||||
from service.storage import LocalStorageService, adjust_user_storage
|
|
||||||
from sqlmodels.database_connection import DatabaseManager
|
|
||||||
from sqlmodels.object import Object, ObjectType
|
|
||||||
from sqlmodels.physical_file import PhysicalFile
|
|
||||||
from sqlmodels.policy import Policy
|
|
||||||
from sqlmodels.user import User
|
|
||||||
from sqlmodels.webdav import WebDAV
|
|
||||||
|
|
||||||
|
|
||||||
class EventLoopRef:
|
|
||||||
"""持有主线程事件循环引用,供 WSGI 线程使用"""
|
|
||||||
_loop: ClassVar[asyncio.AbstractEventLoop | None] = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def capture(cls) -> None:
|
|
||||||
"""在 async 上下文中调用,捕获当前事件循环"""
|
|
||||||
cls._loop = asyncio.get_running_loop()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get(cls) -> asyncio.AbstractEventLoop:
|
|
||||||
if cls._loop is None:
|
|
||||||
raise RuntimeError("事件循环尚未捕获,请先调用 EventLoopRef.capture()")
|
|
||||||
return cls._loop
|
|
||||||
|
|
||||||
|
|
||||||
def _run_async(coro): # type: ignore[no-untyped-def]
|
|
||||||
"""在 WSGI 线程中通过 run_coroutine_threadsafe 运行协程"""
|
|
||||||
future = asyncio.run_coroutine_threadsafe(coro, EventLoopRef.get())
|
|
||||||
return future.result()
|
|
||||||
|
|
||||||
|
|
||||||
def _get_session(): # type: ignore[no-untyped-def]
|
|
||||||
"""获取数据库会话上下文管理器"""
|
|
||||||
return DatabaseManager._async_session_factory()
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 异步辅助函数 ====================
|
|
||||||
|
|
||||||
async def _get_webdav_account(webdav_id: int) -> WebDAV | None:
|
|
||||||
"""获取 WebDAV 账户"""
|
|
||||||
async with _get_session() as session:
|
|
||||||
return await WebDAV.get(session, WebDAV.id == webdav_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_object_by_path(user_id: UUID, path: str) -> Object | None:
|
|
||||||
"""根据路径获取对象"""
|
|
||||||
async with _get_session() as session:
|
|
||||||
return await Object.get_by_path(session, user_id, path)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_children(user_id: UUID, parent_id: UUID) -> list[Object]:
|
|
||||||
"""获取目录子对象"""
|
|
||||||
async with _get_session() as session:
|
|
||||||
return await Object.get_children(session, user_id, parent_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_object_by_id(object_id: UUID) -> Object | None:
|
|
||||||
"""根据ID获取对象"""
|
|
||||||
async with _get_session() as session:
|
|
||||||
return await Object.get(session, Object.id == object_id, load=Object.physical_file)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_user(user_id: UUID) -> User | None:
|
|
||||||
"""获取用户(含 group 关系)"""
|
|
||||||
async with _get_session() as session:
|
|
||||||
return await User.get(session, User.id == user_id, load=User.group)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_policy(policy_id: UUID) -> Policy | None:
|
|
||||||
"""获取存储策略"""
|
|
||||||
async with _get_session() as session:
|
|
||||||
return await Policy.get(session, Policy.id == policy_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def _create_folder(
|
|
||||||
name: str,
|
|
||||||
parent_id: UUID,
|
|
||||||
owner_id: UUID,
|
|
||||||
policy_id: UUID,
|
|
||||||
) -> Object:
|
|
||||||
"""创建目录对象"""
|
|
||||||
async with _get_session() as session:
|
|
||||||
obj = Object(
|
|
||||||
name=name,
|
|
||||||
type=ObjectType.FOLDER,
|
|
||||||
size=0,
|
|
||||||
parent_id=parent_id,
|
|
||||||
owner_id=owner_id,
|
|
||||||
policy_id=policy_id,
|
|
||||||
)
|
|
||||||
obj = await obj.save(session)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
async def _create_file(
|
|
||||||
name: str,
|
|
||||||
parent_id: UUID,
|
|
||||||
owner_id: UUID,
|
|
||||||
policy_id: UUID,
|
|
||||||
) -> Object:
|
|
||||||
"""创建空文件对象"""
|
|
||||||
async with _get_session() as session:
|
|
||||||
obj = Object(
|
|
||||||
name=name,
|
|
||||||
type=ObjectType.FILE,
|
|
||||||
size=0,
|
|
||||||
parent_id=parent_id,
|
|
||||||
owner_id=owner_id,
|
|
||||||
policy_id=policy_id,
|
|
||||||
)
|
|
||||||
obj = await obj.save(session)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
async def _soft_delete_object(object_id: UUID) -> None:
|
|
||||||
"""软删除对象(移入回收站)"""
|
|
||||||
from service.storage import soft_delete_objects
|
|
||||||
|
|
||||||
async with _get_session() as session:
|
|
||||||
obj = await Object.get(session, Object.id == object_id)
|
|
||||||
if obj:
|
|
||||||
await soft_delete_objects(session, [obj])
|
|
||||||
|
|
||||||
|
|
||||||
async def _finalize_upload(
|
|
||||||
object_id: UUID,
|
|
||||||
physical_path: str,
|
|
||||||
size: int,
|
|
||||||
owner_id: UUID,
|
|
||||||
policy_id: UUID,
|
|
||||||
) -> None:
|
|
||||||
"""上传完成后更新对象元数据和物理文件记录"""
|
|
||||||
async with _get_session() as session:
|
|
||||||
# 获取存储路径(相对路径)
|
|
||||||
policy = await Policy.get(session, Policy.id == policy_id)
|
|
||||||
if not policy or not policy.server:
|
|
||||||
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
|
|
||||||
|
|
||||||
base_path = Path(policy.server).resolve()
|
|
||||||
full_path = Path(physical_path).resolve()
|
|
||||||
storage_path = str(full_path.relative_to(base_path))
|
|
||||||
|
|
||||||
# 创建 PhysicalFile 记录
|
|
||||||
pf = PhysicalFile(
|
|
||||||
storage_path=storage_path,
|
|
||||||
size=size,
|
|
||||||
policy_id=policy_id,
|
|
||||||
reference_count=1,
|
|
||||||
)
|
|
||||||
pf = await pf.save(session)
|
|
||||||
|
|
||||||
# 更新 Object
|
|
||||||
obj = await Object.get(session, Object.id == object_id)
|
|
||||||
if obj:
|
|
||||||
obj.sqlmodel_update({'size': size, 'physical_file_id': pf.id})
|
|
||||||
obj = await obj.save(session)
|
|
||||||
|
|
||||||
# 更新用户存储用量
|
|
||||||
if size > 0:
|
|
||||||
await adjust_user_storage(session, owner_id, size)
|
|
||||||
|
|
||||||
|
|
||||||
async def _move_object(
|
|
||||||
object_id: UUID,
|
|
||||||
new_parent_id: UUID,
|
|
||||||
new_name: str,
|
|
||||||
) -> None:
|
|
||||||
"""移动/重命名对象"""
|
|
||||||
async with _get_session() as session:
|
|
||||||
obj = await Object.get(session, Object.id == object_id)
|
|
||||||
if obj:
|
|
||||||
obj.sqlmodel_update({'parent_id': new_parent_id, 'name': new_name})
|
|
||||||
obj = await obj.save(session)
|
|
||||||
|
|
||||||
|
|
||||||
async def _copy_object_recursive(
|
|
||||||
src_id: UUID,
|
|
||||||
dst_parent_id: UUID,
|
|
||||||
dst_name: str,
|
|
||||||
owner_id: UUID,
|
|
||||||
) -> None:
|
|
||||||
"""递归复制对象"""
|
|
||||||
from service.storage import copy_object_recursive
|
|
||||||
|
|
||||||
async with _get_session() as session:
|
|
||||||
src = await Object.get(session, Object.id == src_id)
|
|
||||||
if not src:
|
|
||||||
return
|
|
||||||
await copy_object_recursive(session, src, dst_parent_id, owner_id, new_name=dst_name)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 辅助工具 ====================
|
|
||||||
|
|
||||||
def _get_environ_info(environ: dict[str, object]) -> tuple[UUID, int]:
|
|
||||||
"""从 environ 中提取认证信息"""
|
|
||||||
user_id: UUID = environ["disknext.user_id"] # type: ignore[assignment]
|
|
||||||
webdav_id: int = environ["disknext.webdav_id"] # type: ignore[assignment]
|
|
||||||
return user_id, webdav_id
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_dav_path(account_root: str, dav_path: str) -> str:
|
|
||||||
"""
|
|
||||||
将 DAV 相对路径映射到 DiskNext 绝对路径。
|
|
||||||
|
|
||||||
:param account_root: 账户挂载根路径,如 "/" 或 "/docs"
|
|
||||||
:param dav_path: DAV 请求路径,如 "/" 或 "/photos/cat.jpg"
|
|
||||||
:return: DiskNext 内部路径,如 "/docs/photos/cat.jpg"
|
|
||||||
"""
|
|
||||||
# 规范化根路径
|
|
||||||
root = account_root.rstrip("/")
|
|
||||||
if not root:
|
|
||||||
root = ""
|
|
||||||
|
|
||||||
# 规范化 DAV 路径
|
|
||||||
if not dav_path or dav_path == "/":
|
|
||||||
return root + "/" if root else "/"
|
|
||||||
|
|
||||||
if not dav_path.startswith("/"):
|
|
||||||
dav_path = "/" + dav_path
|
|
||||||
|
|
||||||
full = root + dav_path
|
|
||||||
return full if full else "/"
|
|
||||||
|
|
||||||
|
|
||||||
def _check_readonly(environ: dict[str, object]) -> None:
|
|
||||||
"""检查账户是否只读,只读则抛出 403"""
|
|
||||||
account = environ.get("disknext.webdav_account")
|
|
||||||
if account and getattr(account, 'readonly', False):
|
|
||||||
raise DAVError(HTTP_FORBIDDEN, "WebDAV 账户为只读模式")
|
|
||||||
|
|
||||||
|
|
||||||
def _check_storage_quota(user: User, additional_bytes: int) -> None:
|
|
||||||
"""检查存储配额"""
|
|
||||||
max_storage = user.group.max_storage
|
|
||||||
if max_storage > 0 and user.storage + additional_bytes > max_storage:
|
|
||||||
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
|
||||||
|
|
||||||
|
|
||||||
class QuotaLimitedWriter(io.RawIOBase):
|
|
||||||
"""带配额限制的写入流包装器"""
|
|
||||||
|
|
||||||
def __init__(self, stream: io.BufferedWriter, max_bytes: int) -> None:
|
|
||||||
self._stream = stream
|
|
||||||
self._max_bytes = max_bytes
|
|
||||||
self._bytes_written = 0
|
|
||||||
|
|
||||||
def writable(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def write(self, b: bytes | bytearray) -> int:
|
|
||||||
if self._bytes_written + len(b) > self._max_bytes:
|
|
||||||
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
|
||||||
written = self._stream.write(b)
|
|
||||||
self._bytes_written += written
|
|
||||||
return written
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
self._stream.close()
|
|
||||||
super().close()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bytes_written(self) -> int:
|
|
||||||
return self._bytes_written
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== Provider ====================
|
|
||||||
|
|
||||||
class DiskNextDAVProvider(DAVProvider):
|
|
||||||
"""DiskNext WebDAV 存储 Provider"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def get_resource_inst(
|
|
||||||
self,
|
|
||||||
path: str,
|
|
||||||
environ: dict[str, object],
|
|
||||||
) -> 'DiskNextCollection | DiskNextFile | None':
|
|
||||||
"""
|
|
||||||
将 WebDAV 路径映射到资源对象。
|
|
||||||
|
|
||||||
首次调用时加载 WebDAV 账户信息并缓存到 environ。
|
|
||||||
"""
|
|
||||||
user_id, webdav_id = _get_environ_info(environ)
|
|
||||||
|
|
||||||
# 首次请求时加载账户信息
|
|
||||||
if "disknext.webdav_account" not in environ:
|
|
||||||
account = _run_async(_get_webdav_account(webdav_id))
|
|
||||||
if not account:
|
|
||||||
return None
|
|
||||||
environ["disknext.webdav_account"] = account
|
|
||||||
|
|
||||||
account: WebDAV = environ["disknext.webdav_account"] # type: ignore[no-redef]
|
|
||||||
disknext_path = _resolve_dav_path(account.root, path)
|
|
||||||
|
|
||||||
obj = _run_async(_get_object_by_path(user_id, disknext_path))
|
|
||||||
if not obj:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if obj.is_folder:
|
|
||||||
return DiskNextCollection(path, environ, obj, user_id, account)
|
|
||||||
else:
|
|
||||||
return DiskNextFile(path, environ, obj, user_id, account)
|
|
||||||
|
|
||||||
def is_readonly(self) -> bool:
|
|
||||||
"""只读由账户级别控制,不在 provider 级别限制"""
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== Collection(目录) ====================
|
|
||||||
|
|
||||||
class DiskNextCollection(DAVCollection):
|
|
||||||
"""DiskNext 目录资源"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
path: str,
|
|
||||||
environ: dict[str, object],
|
|
||||||
obj: Object,
|
|
||||||
user_id: UUID,
|
|
||||||
account: WebDAV,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(path, environ)
|
|
||||||
self._obj = obj
|
|
||||||
self._user_id = user_id
|
|
||||||
self._account = account
|
|
||||||
|
|
||||||
def get_display_info(self) -> dict[str, str]:
|
|
||||||
return {"type": "Directory"}
|
|
||||||
|
|
||||||
def get_member_names(self) -> list[str]:
|
|
||||||
"""获取子对象名称列表"""
|
|
||||||
children = _run_async(_get_children(self._user_id, self._obj.id))
|
|
||||||
return [c.name for c in children]
|
|
||||||
|
|
||||||
def get_member(self, name: str) -> 'DiskNextCollection | DiskNextFile | None':
|
|
||||||
"""获取指定名称的子资源"""
|
|
||||||
member_path = self.path.rstrip("/") + "/" + name
|
|
||||||
account_root = self._account.root
|
|
||||||
disknext_path = _resolve_dav_path(account_root, member_path)
|
|
||||||
|
|
||||||
obj = _run_async(_get_object_by_path(self._user_id, disknext_path))
|
|
||||||
if not obj:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if obj.is_folder:
|
|
||||||
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
|
|
||||||
else:
|
|
||||||
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
|
|
||||||
|
|
||||||
def get_creation_date(self) -> float | None:
|
|
||||||
if self._obj.created_at:
|
|
||||||
return self._obj.created_at.timestamp()
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_last_modified(self) -> float | None:
|
|
||||||
if self._obj.updated_at:
|
|
||||||
return self._obj.updated_at.timestamp()
|
|
||||||
return None
|
|
||||||
|
|
||||||
def create_empty_resource(self, name: str) -> 'DiskNextFile':
|
|
||||||
"""创建空文件(PUT 操作的第一步)"""
|
|
||||||
_check_readonly(self.environ)
|
|
||||||
|
|
||||||
obj = _run_async(_create_file(
|
|
||||||
name=name,
|
|
||||||
parent_id=self._obj.id,
|
|
||||||
owner_id=self._user_id,
|
|
||||||
policy_id=self._obj.policy_id,
|
|
||||||
))
|
|
||||||
|
|
||||||
member_path = self.path.rstrip("/") + "/" + name
|
|
||||||
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
|
|
||||||
|
|
||||||
def create_collection(self, name: str) -> 'DiskNextCollection':
|
|
||||||
"""创建子目录(MKCOL)"""
|
|
||||||
_check_readonly(self.environ)
|
|
||||||
|
|
||||||
obj = _run_async(_create_folder(
|
|
||||||
name=name,
|
|
||||||
parent_id=self._obj.id,
|
|
||||||
owner_id=self._user_id,
|
|
||||||
policy_id=self._obj.policy_id,
|
|
||||||
))
|
|
||||||
|
|
||||||
member_path = self.path.rstrip("/") + "/" + name
|
|
||||||
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
|
|
||||||
|
|
||||||
def delete(self) -> None:
|
|
||||||
"""软删除目录"""
|
|
||||||
_check_readonly(self.environ)
|
|
||||||
_run_async(_soft_delete_object(self._obj.id))
|
|
||||||
|
|
||||||
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
|
|
||||||
"""复制或移动目录"""
|
|
||||||
_check_readonly(self.environ)
|
|
||||||
|
|
||||||
account_root = self._account.root
|
|
||||||
dest_disknext = _resolve_dav_path(account_root, dest_path)
|
|
||||||
|
|
||||||
# 解析目标父路径和新名称
|
|
||||||
if "/" in dest_disknext.rstrip("/"):
|
|
||||||
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
|
|
||||||
new_name = dest_disknext.rsplit("/", 1)[1]
|
|
||||||
else:
|
|
||||||
parent_path = "/"
|
|
||||||
new_name = dest_disknext.lstrip("/")
|
|
||||||
|
|
||||||
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
|
|
||||||
if not dest_parent:
|
|
||||||
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
|
|
||||||
|
|
||||||
if is_move:
|
|
||||||
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
|
|
||||||
else:
|
|
||||||
_run_async(_copy_object_recursive(
|
|
||||||
self._obj.id, dest_parent.id, new_name, self._user_id,
|
|
||||||
))
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def support_recursive_delete(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def support_recursive_move(self, dest_path: str) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== NonCollection(文件) ====================
|
|
||||||
|
|
||||||
class DiskNextFile(DAVNonCollection):
|
|
||||||
"""DiskNext 文件资源"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
path: str,
|
|
||||||
environ: dict[str, object],
|
|
||||||
obj: Object,
|
|
||||||
user_id: UUID,
|
|
||||||
account: WebDAV,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(path, environ)
|
|
||||||
self._obj = obj
|
|
||||||
self._user_id = user_id
|
|
||||||
self._account = account
|
|
||||||
self._write_path: str | None = None
|
|
||||||
self._write_stream: io.BufferedWriter | QuotaLimitedWriter | None = None
|
|
||||||
|
|
||||||
def get_content_length(self) -> int | None:
|
|
||||||
return self._obj.size if self._obj.size else 0
|
|
||||||
|
|
||||||
def get_content_type(self) -> str | None:
|
|
||||||
# 尝试从文件名推断 MIME 类型
|
|
||||||
mime, _ = mimetypes.guess_type(self._obj.name)
|
|
||||||
return mime or "application/octet-stream"
|
|
||||||
|
|
||||||
def get_creation_date(self) -> float | None:
|
|
||||||
if self._obj.created_at:
|
|
||||||
return self._obj.created_at.timestamp()
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_last_modified(self) -> float | None:
|
|
||||||
if self._obj.updated_at:
|
|
||||||
return self._obj.updated_at.timestamp()
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_display_info(self) -> dict[str, str]:
|
|
||||||
return {"type": "File"}
|
|
||||||
|
|
||||||
def get_content(self) -> io.BufferedReader | None:
|
|
||||||
"""
|
|
||||||
返回文件内容的可读流。
|
|
||||||
|
|
||||||
WsgiDAV 在线程中运行,可安全使用同步 open()。
|
|
||||||
"""
|
|
||||||
obj_with_file = _run_async(_get_object_by_id(self._obj.id))
|
|
||||||
if not obj_with_file or not obj_with_file.physical_file:
|
|
||||||
return None
|
|
||||||
|
|
||||||
pf = obj_with_file.physical_file
|
|
||||||
policy = _run_async(_get_policy(obj_with_file.policy_id))
|
|
||||||
if not policy or not policy.server:
|
|
||||||
return None
|
|
||||||
|
|
||||||
full_path = Path(policy.server).resolve() / pf.storage_path
|
|
||||||
if not full_path.is_file():
|
|
||||||
l.warning(f"WebDAV: 物理文件不存在: {full_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return open(full_path, "rb") # noqa: SIM115
|
|
||||||
|
|
||||||
def begin_write(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
content_type: str | None = None,
|
|
||||||
) -> io.BufferedWriter | QuotaLimitedWriter:
|
|
||||||
"""
|
|
||||||
开始写入文件(PUT 操作)。
|
|
||||||
|
|
||||||
返回一个可写的文件流,WsgiDAV 将向其中写入请求体数据。
|
|
||||||
当用户有配额限制时,返回 QuotaLimitedWriter 在写入过程中实时检查配额。
|
|
||||||
"""
|
|
||||||
_check_readonly(self.environ)
|
|
||||||
|
|
||||||
# 检查配额
|
|
||||||
remaining_quota: int = 0
|
|
||||||
user = _run_async(_get_user(self._user_id))
|
|
||||||
if user:
|
|
||||||
max_storage = user.group.max_storage
|
|
||||||
if max_storage > 0:
|
|
||||||
remaining_quota = max_storage - user.storage
|
|
||||||
if remaining_quota <= 0:
|
|
||||||
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
|
||||||
# Content-Length 预检(如果有的话)
|
|
||||||
content_length = self.environ.get("CONTENT_LENGTH")
|
|
||||||
if content_length and int(content_length) > remaining_quota:
|
|
||||||
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
|
||||||
|
|
||||||
# 获取策略以确定存储路径
|
|
||||||
policy = _run_async(_get_policy(self._obj.policy_id))
|
|
||||||
if not policy or not policy.server:
|
|
||||||
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
|
|
||||||
|
|
||||||
storage_service = LocalStorageService(policy)
|
|
||||||
dir_path, storage_name, full_path = _run_async(
|
|
||||||
storage_service.generate_file_path(
|
|
||||||
user_id=self._user_id,
|
|
||||||
original_filename=self._obj.name,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._write_path = full_path
|
|
||||||
raw_stream = open(full_path, "wb") # noqa: SIM115
|
|
||||||
|
|
||||||
# 有配额限制时使用包装流,实时检查写入量
|
|
||||||
if remaining_quota > 0:
|
|
||||||
self._write_stream = QuotaLimitedWriter(raw_stream, remaining_quota)
|
|
||||||
else:
|
|
||||||
self._write_stream = raw_stream
|
|
||||||
|
|
||||||
return self._write_stream
|
|
||||||
|
|
||||||
def end_write(self, *, with_errors: bool) -> None:
|
|
||||||
"""写入完成后的收尾工作"""
|
|
||||||
if self._write_stream:
|
|
||||||
self._write_stream.close()
|
|
||||||
self._write_stream = None
|
|
||||||
|
|
||||||
if with_errors:
|
|
||||||
if self._write_path:
|
|
||||||
file_path = Path(self._write_path)
|
|
||||||
if file_path.exists():
|
|
||||||
file_path.unlink()
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self._write_path:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 获取文件大小
|
|
||||||
file_path = Path(self._write_path)
|
|
||||||
if not file_path.exists():
|
|
||||||
return
|
|
||||||
|
|
||||||
size = file_path.stat().st_size
|
|
||||||
|
|
||||||
# 更新数据库记录
|
|
||||||
_run_async(_finalize_upload(
|
|
||||||
object_id=self._obj.id,
|
|
||||||
physical_path=self._write_path,
|
|
||||||
size=size,
|
|
||||||
owner_id=self._user_id,
|
|
||||||
policy_id=self._obj.policy_id,
|
|
||||||
))
|
|
||||||
|
|
||||||
l.debug(f"WebDAV 文件写入完成: {self._obj.name}, size={size}")
|
|
||||||
|
|
||||||
def delete(self) -> None:
|
|
||||||
"""软删除文件"""
|
|
||||||
_check_readonly(self.environ)
|
|
||||||
_run_async(_soft_delete_object(self._obj.id))
|
|
||||||
|
|
||||||
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
|
|
||||||
"""复制或移动文件"""
|
|
||||||
_check_readonly(self.environ)
|
|
||||||
|
|
||||||
account_root = self._account.root
|
|
||||||
dest_disknext = _resolve_dav_path(account_root, dest_path)
|
|
||||||
|
|
||||||
# 解析目标父路径和新名称
|
|
||||||
if "/" in dest_disknext.rstrip("/"):
|
|
||||||
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
|
|
||||||
new_name = dest_disknext.rsplit("/", 1)[1]
|
|
||||||
else:
|
|
||||||
parent_path = "/"
|
|
||||||
new_name = dest_disknext.lstrip("/")
|
|
||||||
|
|
||||||
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
|
|
||||||
if not dest_parent:
|
|
||||||
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
|
|
||||||
|
|
||||||
if is_move:
|
|
||||||
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
|
|
||||||
else:
|
|
||||||
_run_async(_copy_object_recursive(
|
|
||||||
self._obj.id, dest_parent.id, new_name, self._user_id,
|
|
||||||
))
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def support_content_length(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_etag(self) -> str | None:
|
|
||||||
"""返回 ETag(基于ID和更新时间),WsgiDAV 会自动加双引号"""
|
|
||||||
if self._obj.updated_at:
|
|
||||||
return f"{self._obj.id}-{int(self._obj.updated_at.timestamp())}"
|
|
||||||
return None
|
|
||||||
|
|
||||||
def support_etag(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def support_ranges(self) -> bool:
|
|
||||||
return True
|
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
"""
|
|
||||||
WebDAV 认证缓存
|
|
||||||
|
|
||||||
缓存 HTTP Basic Auth 的认证结果,避免每次请求都查库 + Argon2 验证。
|
|
||||||
支持 Redis(首选)和内存缓存(降级)两种存储后端。
|
|
||||||
"""
|
|
||||||
import hashlib
|
|
||||||
from typing import ClassVar
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from cachetools import TTLCache
|
|
||||||
from loguru import logger as l
|
|
||||||
|
|
||||||
from . import RedisManager
|
|
||||||
|
|
||||||
_AUTH_TTL: int = 300
|
|
||||||
"""认证缓存 TTL(秒),5 分钟"""
|
|
||||||
|
|
||||||
|
|
||||||
class WebDAVAuthCache:
|
|
||||||
"""
|
|
||||||
WebDAV 认证结果缓存
|
|
||||||
|
|
||||||
缓存键格式: webdav_auth:{email}/{account_name}:{sha256(password)}
|
|
||||||
缓存值格式: {user_id}:{webdav_id}
|
|
||||||
|
|
||||||
密码的 SHA256 作为缓存键的一部分,密码变更后旧缓存自然 miss。
|
|
||||||
"""
|
|
||||||
|
|
||||||
_memory_cache: ClassVar[TTLCache[str, str]] = TTLCache(maxsize=10000, ttl=_AUTH_TTL)
|
|
||||||
"""内存缓存降级方案"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _build_key(cls, email: str, account_name: str, password: str) -> str:
|
|
||||||
"""构建缓存键"""
|
|
||||||
pwd_hash = hashlib.sha256(password.encode()).hexdigest()[:16]
|
|
||||||
return f"webdav_auth:{email}/{account_name}:{pwd_hash}"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get(
|
|
||||||
cls,
|
|
||||||
email: str,
|
|
||||||
account_name: str,
|
|
||||||
password: str,
|
|
||||||
) -> tuple[UUID, int] | None:
|
|
||||||
"""
|
|
||||||
查询缓存中的认证结果。
|
|
||||||
|
|
||||||
:param email: 用户邮箱
|
|
||||||
:param account_name: WebDAV 账户名
|
|
||||||
:param password: 用户提供的明文密码
|
|
||||||
:return: (user_id, webdav_id) 或 None(缓存未命中)
|
|
||||||
"""
|
|
||||||
key = cls._build_key(email, account_name, password)
|
|
||||||
|
|
||||||
client = RedisManager.get_client()
|
|
||||||
if client is not None:
|
|
||||||
value = await client.get(key)
|
|
||||||
if value is not None:
|
|
||||||
raw = value.decode() if isinstance(value, bytes) else value
|
|
||||||
user_id_str, webdav_id_str = raw.split(":", 1)
|
|
||||||
return UUID(user_id_str), int(webdav_id_str)
|
|
||||||
else:
|
|
||||||
raw = cls._memory_cache.get(key)
|
|
||||||
if raw is not None:
|
|
||||||
user_id_str, webdav_id_str = raw.split(":", 1)
|
|
||||||
return UUID(user_id_str), int(webdav_id_str)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def set(
|
|
||||||
cls,
|
|
||||||
email: str,
|
|
||||||
account_name: str,
|
|
||||||
password: str,
|
|
||||||
user_id: UUID,
|
|
||||||
webdav_id: int,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
写入认证结果到缓存。
|
|
||||||
|
|
||||||
:param email: 用户邮箱
|
|
||||||
:param account_name: WebDAV 账户名
|
|
||||||
:param password: 用户提供的明文密码
|
|
||||||
:param user_id: 用户UUID
|
|
||||||
:param webdav_id: WebDAV 账户ID
|
|
||||||
"""
|
|
||||||
key = cls._build_key(email, account_name, password)
|
|
||||||
value = f"{user_id}:{webdav_id}"
|
|
||||||
|
|
||||||
client = RedisManager.get_client()
|
|
||||||
if client is not None:
|
|
||||||
await client.set(key, value, ex=_AUTH_TTL)
|
|
||||||
else:
|
|
||||||
cls._memory_cache[key] = value
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def invalidate_account(cls, user_id: UUID, account_name: str) -> None:
|
|
||||||
"""
|
|
||||||
失效指定账户的所有缓存。
|
|
||||||
|
|
||||||
由于缓存键包含 password hash,无法精确删除,
|
|
||||||
Redis 端使用 pattern scan 删除,内存端清空全部。
|
|
||||||
|
|
||||||
:param user_id: 用户UUID
|
|
||||||
:param account_name: WebDAV 账户名
|
|
||||||
"""
|
|
||||||
client = RedisManager.get_client()
|
|
||||||
if client is not None:
|
|
||||||
pattern = f"webdav_auth:*/{account_name}:*"
|
|
||||||
cursor: int = 0
|
|
||||||
while True:
|
|
||||||
cursor, keys = await client.scan(cursor, match=pattern, count=100)
|
|
||||||
if keys:
|
|
||||||
await client.delete(*keys)
|
|
||||||
if cursor == 0:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# 内存缓存无法按 pattern 删除,清除所有含该账户名的条目
|
|
||||||
keys_to_delete = [
|
|
||||||
k for k in cls._memory_cache
|
|
||||||
if f"/{account_name}:" in k
|
|
||||||
]
|
|
||||||
for k in keys_to_delete:
|
|
||||||
cls._memory_cache.pop(k, None)
|
|
||||||
|
|
||||||
l.debug(f"已清除 WebDAV 认证缓存: user={user_id}, account={account_name}")
|
|
||||||
@@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
提供文件存储相关的服务,包括:
|
提供文件存储相关的服务,包括:
|
||||||
- 本地存储服务
|
- 本地存储服务
|
||||||
- S3 存储服务
|
|
||||||
- 命名规则解析器
|
- 命名规则解析器
|
||||||
- 存储异常定义
|
- 存储异常定义
|
||||||
"""
|
"""
|
||||||
@@ -12,8 +11,6 @@ from .exceptions import (
|
|||||||
FileReadError,
|
FileReadError,
|
||||||
FileWriteError,
|
FileWriteError,
|
||||||
InvalidPathError,
|
InvalidPathError,
|
||||||
S3APIError,
|
|
||||||
S3MultipartUploadError,
|
|
||||||
StorageException,
|
StorageException,
|
||||||
StorageFileNotFoundError,
|
StorageFileNotFoundError,
|
||||||
UploadSessionExpiredError,
|
UploadSessionExpiredError,
|
||||||
@@ -29,5 +26,3 @@ from .object import (
|
|||||||
restore_objects,
|
restore_objects,
|
||||||
soft_delete_objects,
|
soft_delete_objects,
|
||||||
)
|
)
|
||||||
from .migrate import migrate_file_with_task, migrate_directory_files
|
|
||||||
from .s3_storage import S3StorageService
|
|
||||||
@@ -43,13 +43,3 @@ class UploadSessionExpiredError(StorageException):
|
|||||||
class InvalidPathError(StorageException):
|
class InvalidPathError(StorageException):
|
||||||
"""无效的路径"""
|
"""无效的路径"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class S3APIError(StorageException):
|
|
||||||
"""S3 API 请求错误"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class S3MultipartUploadError(S3APIError):
|
|
||||||
"""S3 分片上传错误"""
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -263,49 +263,15 @@ class LocalStorageService:
|
|||||||
"""
|
"""
|
||||||
删除文件(物理删除)
|
删除文件(物理删除)
|
||||||
|
|
||||||
删除文件后会尝试清理因此变空的父目录。
|
|
||||||
|
|
||||||
:param path: 完整文件路径
|
:param path: 完整文件路径
|
||||||
"""
|
"""
|
||||||
if await self.file_exists(path):
|
if await self.file_exists(path):
|
||||||
try:
|
try:
|
||||||
await aiofiles.os.remove(path)
|
await aiofiles.os.remove(path)
|
||||||
l.debug(f"已删除文件: {path}")
|
l.debug(f"已删除文件: {path}")
|
||||||
await self._cleanup_empty_parents(path)
|
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
l.warning(f"删除文件失败 {path}: {e}")
|
l.warning(f"删除文件失败 {path}: {e}")
|
||||||
|
|
||||||
async def _cleanup_empty_parents(self, file_path: str) -> None:
|
|
||||||
"""
|
|
||||||
从被删文件的父目录开始,向上逐级删除空目录
|
|
||||||
|
|
||||||
在以下情况停止:
|
|
||||||
|
|
||||||
- 到达存储根目录(_base_path)
|
|
||||||
- 遇到非空目录
|
|
||||||
- 遇到 .trash 目录
|
|
||||||
- 删除失败(权限、并发等)
|
|
||||||
|
|
||||||
:param file_path: 被删文件的完整路径
|
|
||||||
"""
|
|
||||||
current = Path(file_path).parent
|
|
||||||
|
|
||||||
while current != self._base_path and str(current).startswith(str(self._base_path)):
|
|
||||||
if current.name == '.trash':
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
entries = await aiofiles.os.listdir(str(current))
|
|
||||||
if entries:
|
|
||||||
break
|
|
||||||
|
|
||||||
await aiofiles.os.rmdir(str(current))
|
|
||||||
l.debug(f"已清理空目录: {current}")
|
|
||||||
current = current.parent
|
|
||||||
except OSError as e:
|
|
||||||
l.debug(f"清理空目录失败(忽略): {current}: {e}")
|
|
||||||
break
|
|
||||||
|
|
||||||
async def move_to_trash(
|
async def move_to_trash(
|
||||||
self,
|
self,
|
||||||
source_path: str,
|
source_path: str,
|
||||||
@@ -338,7 +304,6 @@ class LocalStorageService:
|
|||||||
try:
|
try:
|
||||||
await aiofiles.os.rename(source_path, str(trash_path))
|
await aiofiles.os.rename(source_path, str(trash_path))
|
||||||
l.info(f"文件已移动到回收站: {source_path} -> {trash_path}")
|
l.info(f"文件已移动到回收站: {source_path} -> {trash_path}")
|
||||||
await self._cleanup_empty_parents(source_path)
|
|
||||||
return str(trash_path)
|
return str(trash_path)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
raise StorageException(f"移动文件到回收站失败: {e}")
|
raise StorageException(f"移动文件到回收站失败: {e}")
|
||||||
|
|||||||
@@ -1,291 +0,0 @@
|
|||||||
"""
|
|
||||||
存储策略迁移服务
|
|
||||||
|
|
||||||
提供跨存储策略的文件迁移功能:
|
|
||||||
- 单文件迁移:从源策略下载 → 上传到目标策略 → 更新数据库记录
|
|
||||||
- 目录批量迁移:递归遍历目录下所有文件逐个迁移,同时更新子目录的 policy_id
|
|
||||||
"""
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from loguru import logger as l
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
from sqlmodels.object import Object, ObjectType
|
|
||||||
from sqlmodels.physical_file import PhysicalFile
|
|
||||||
from sqlmodels.policy import Policy, PolicyType
|
|
||||||
from sqlmodels.task import Task, TaskStatus
|
|
||||||
|
|
||||||
from .local_storage import LocalStorageService
|
|
||||||
from .s3_storage import S3StorageService
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_storage_service(
|
|
||||||
policy: Policy,
|
|
||||||
) -> LocalStorageService | S3StorageService:
|
|
||||||
"""
|
|
||||||
根据策略类型创建对应的存储服务实例
|
|
||||||
|
|
||||||
:param policy: 存储策略
|
|
||||||
:return: 存储服务实例
|
|
||||||
"""
|
|
||||||
if policy.type == PolicyType.LOCAL:
|
|
||||||
return LocalStorageService(policy)
|
|
||||||
elif policy.type == PolicyType.S3:
|
|
||||||
return await S3StorageService.from_policy(policy)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"不支持的存储策略类型: {policy.type}")
|
|
||||||
|
|
||||||
|
|
||||||
async def _read_file_from_storage(
|
|
||||||
service: LocalStorageService | S3StorageService,
|
|
||||||
storage_path: str,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
从存储服务读取文件内容
|
|
||||||
|
|
||||||
:param service: 存储服务实例
|
|
||||||
:param storage_path: 文件存储路径
|
|
||||||
:return: 文件二进制内容
|
|
||||||
"""
|
|
||||||
if isinstance(service, LocalStorageService):
|
|
||||||
return await service.read_file(storage_path)
|
|
||||||
else:
|
|
||||||
return await service.download_file(storage_path)
|
|
||||||
|
|
||||||
|
|
||||||
async def _write_file_to_storage(
|
|
||||||
service: LocalStorageService | S3StorageService,
|
|
||||||
storage_path: str,
|
|
||||||
data: bytes,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
将文件内容写入存储服务
|
|
||||||
|
|
||||||
:param service: 存储服务实例
|
|
||||||
:param storage_path: 文件存储路径
|
|
||||||
:param data: 文件二进制内容
|
|
||||||
"""
|
|
||||||
if isinstance(service, LocalStorageService):
|
|
||||||
await service.write_file(storage_path, data)
|
|
||||||
else:
|
|
||||||
await service.upload_file(storage_path, data)
|
|
||||||
|
|
||||||
|
|
||||||
async def _delete_file_from_storage(
|
|
||||||
service: LocalStorageService | S3StorageService,
|
|
||||||
storage_path: str,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
从存储服务删除文件
|
|
||||||
|
|
||||||
:param service: 存储服务实例
|
|
||||||
:param storage_path: 文件存储路径
|
|
||||||
"""
|
|
||||||
if isinstance(service, LocalStorageService):
|
|
||||||
await service.delete_file(storage_path)
|
|
||||||
else:
|
|
||||||
await service.delete_file(storage_path)
|
|
||||||
|
|
||||||
|
|
||||||
async def migrate_single_file(
|
|
||||||
session: AsyncSession,
|
|
||||||
obj: Object,
|
|
||||||
dest_policy: Policy,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
将单个文件对象从当前存储策略迁移到目标策略
|
|
||||||
|
|
||||||
流程:
|
|
||||||
1. 获取源物理文件和存储服务
|
|
||||||
2. 读取源文件内容
|
|
||||||
3. 在目标存储中生成新路径并写入
|
|
||||||
4. 创建新的 PhysicalFile 记录
|
|
||||||
5. 更新 Object 的 policy_id 和 physical_file_id
|
|
||||||
6. 旧 PhysicalFile 引用计数 -1,如为 0 则删除源物理文件
|
|
||||||
|
|
||||||
:param session: 数据库会话
|
|
||||||
:param obj: 待迁移的文件对象(必须为文件类型)
|
|
||||||
:param dest_policy: 目标存储策略
|
|
||||||
"""
|
|
||||||
if obj.type != ObjectType.FILE:
|
|
||||||
raise ValueError(f"只能迁移文件对象,当前类型: {obj.type}")
|
|
||||||
|
|
||||||
# 获取源策略和物理文件
|
|
||||||
src_policy: Policy = await obj.awaitable_attrs.policy
|
|
||||||
old_physical: PhysicalFile | None = await obj.awaitable_attrs.physical_file
|
|
||||||
|
|
||||||
if not old_physical:
|
|
||||||
l.warning(f"文件 {obj.id} 没有关联物理文件,跳过迁移")
|
|
||||||
return
|
|
||||||
|
|
||||||
if src_policy.id == dest_policy.id:
|
|
||||||
l.debug(f"文件 {obj.id} 已在目标策略中,跳过")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 1. 从源存储读取文件
|
|
||||||
src_service = await _get_storage_service(src_policy)
|
|
||||||
data = await _read_file_from_storage(src_service, old_physical.storage_path)
|
|
||||||
|
|
||||||
# 2. 在目标存储生成新路径并写入
|
|
||||||
dest_service = await _get_storage_service(dest_policy)
|
|
||||||
_dir_path, _storage_name, new_storage_path = await dest_service.generate_file_path(
|
|
||||||
user_id=obj.owner_id,
|
|
||||||
original_filename=obj.name,
|
|
||||||
)
|
|
||||||
await _write_file_to_storage(dest_service, new_storage_path, data)
|
|
||||||
|
|
||||||
# 3. 创建新的 PhysicalFile
|
|
||||||
new_physical = PhysicalFile(
|
|
||||||
storage_path=new_storage_path,
|
|
||||||
size=old_physical.size,
|
|
||||||
checksum_md5=old_physical.checksum_md5,
|
|
||||||
policy_id=dest_policy.id,
|
|
||||||
reference_count=1,
|
|
||||||
)
|
|
||||||
new_physical = await new_physical.save(session)
|
|
||||||
|
|
||||||
# 4. 更新 Object
|
|
||||||
obj.policy_id = dest_policy.id
|
|
||||||
obj.physical_file_id = new_physical.id
|
|
||||||
obj = await obj.save(session)
|
|
||||||
|
|
||||||
# 5. 旧 PhysicalFile 引用计数 -1
|
|
||||||
old_physical.decrement_reference()
|
|
||||||
if old_physical.can_be_deleted:
|
|
||||||
# 删除源存储中的物理文件
|
|
||||||
try:
|
|
||||||
await _delete_file_from_storage(src_service, old_physical.storage_path)
|
|
||||||
except Exception as e:
|
|
||||||
l.warning(f"删除源文件失败(不影响迁移结果): {old_physical.storage_path}: {e}")
|
|
||||||
await PhysicalFile.delete(session, old_physical)
|
|
||||||
else:
|
|
||||||
old_physical = await old_physical.save(session)
|
|
||||||
|
|
||||||
l.info(f"文件迁移完成: {obj.name} ({obj.id}), {src_policy.name} → {dest_policy.name}")
|
|
||||||
|
|
||||||
|
|
||||||
async def migrate_file_with_task(
|
|
||||||
session: AsyncSession,
|
|
||||||
obj: Object,
|
|
||||||
dest_policy: Policy,
|
|
||||||
task: Task,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
迁移单个文件并更新任务状态
|
|
||||||
|
|
||||||
:param session: 数据库会话
|
|
||||||
:param obj: 待迁移的文件对象
|
|
||||||
:param dest_policy: 目标存储策略
|
|
||||||
:param task: 关联的任务记录
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
task.status = TaskStatus.RUNNING
|
|
||||||
task.progress = 0
|
|
||||||
task = await task.save(session)
|
|
||||||
|
|
||||||
await migrate_single_file(session, obj, dest_policy)
|
|
||||||
|
|
||||||
task.status = TaskStatus.COMPLETED
|
|
||||||
task.progress = 100
|
|
||||||
task = await task.save(session)
|
|
||||||
except Exception as e:
|
|
||||||
l.error(f"文件迁移任务失败: {obj.id}: {e}")
|
|
||||||
task.status = TaskStatus.ERROR
|
|
||||||
task.error = str(e)[:500]
|
|
||||||
task = await task.save(session)
|
|
||||||
|
|
||||||
|
|
||||||
async def migrate_directory_files(
|
|
||||||
session: AsyncSession,
|
|
||||||
folder: Object,
|
|
||||||
dest_policy: Policy,
|
|
||||||
task: Task,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
迁移目录下所有文件到目标存储策略
|
|
||||||
|
|
||||||
递归遍历目录树,将所有文件迁移到目标策略。
|
|
||||||
子目录的 policy_id 同步更新。
|
|
||||||
任务进度按文件数比例更新。
|
|
||||||
|
|
||||||
:param session: 数据库会话
|
|
||||||
:param folder: 目录对象
|
|
||||||
:param dest_policy: 目标存储策略
|
|
||||||
:param task: 关联的任务记录
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
task.status = TaskStatus.RUNNING
|
|
||||||
task.progress = 0
|
|
||||||
task = await task.save(session)
|
|
||||||
|
|
||||||
# 收集所有需要迁移的文件
|
|
||||||
files_to_migrate: list[Object] = []
|
|
||||||
folders_to_update: list[Object] = []
|
|
||||||
await _collect_objects_recursive(session, folder, files_to_migrate, folders_to_update)
|
|
||||||
|
|
||||||
total = len(files_to_migrate)
|
|
||||||
migrated = 0
|
|
||||||
errors: list[str] = []
|
|
||||||
|
|
||||||
for file_obj in files_to_migrate:
|
|
||||||
try:
|
|
||||||
await migrate_single_file(session, file_obj, dest_policy)
|
|
||||||
migrated += 1
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"{file_obj.name}: {e}"
|
|
||||||
l.error(f"迁移文件失败: {error_msg}")
|
|
||||||
errors.append(error_msg)
|
|
||||||
|
|
||||||
# 更新进度
|
|
||||||
if total > 0:
|
|
||||||
task.progress = min(99, int(migrated / total * 100))
|
|
||||||
task = await task.save(session)
|
|
||||||
|
|
||||||
# 更新所有子目录的 policy_id
|
|
||||||
for sub_folder in folders_to_update:
|
|
||||||
sub_folder.policy_id = dest_policy.id
|
|
||||||
sub_folder = await sub_folder.save(session)
|
|
||||||
|
|
||||||
# 完成任务
|
|
||||||
if errors:
|
|
||||||
task.status = TaskStatus.ERROR
|
|
||||||
task.error = f"部分文件迁移失败 ({len(errors)}/{total}): " + "; ".join(errors[:5])
|
|
||||||
else:
|
|
||||||
task.status = TaskStatus.COMPLETED
|
|
||||||
|
|
||||||
task.progress = 100
|
|
||||||
task = await task.save(session)
|
|
||||||
|
|
||||||
l.info(
|
|
||||||
f"目录迁移完成: {folder.name} ({folder.id}), "
|
|
||||||
f"成功 {migrated}/{total}, 错误 {len(errors)}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
l.error(f"目录迁移任务失败: {folder.id}: {e}")
|
|
||||||
task.status = TaskStatus.ERROR
|
|
||||||
task.error = str(e)[:500]
|
|
||||||
task = await task.save(session)
|
|
||||||
|
|
||||||
|
|
||||||
async def _collect_objects_recursive(
|
|
||||||
session: AsyncSession,
|
|
||||||
folder: Object,
|
|
||||||
files: list[Object],
|
|
||||||
folders: list[Object],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
递归收集目录下所有文件和子目录
|
|
||||||
|
|
||||||
:param session: 数据库会话
|
|
||||||
:param folder: 当前目录
|
|
||||||
:param files: 文件列表(输出)
|
|
||||||
:param folders: 子目录列表(输出)
|
|
||||||
"""
|
|
||||||
children: list[Object] = await Object.get_children(session, folder.owner_id, folder.id)
|
|
||||||
|
|
||||||
for child in children:
|
|
||||||
if child.type == ObjectType.FILE:
|
|
||||||
files.append(child)
|
|
||||||
elif child.type == ObjectType.FOLDER:
|
|
||||||
folders.append(child)
|
|
||||||
await _collect_objects_recursive(session, child, files, folders)
|
|
||||||
@@ -6,8 +6,7 @@ from sqlalchemy import update as sql_update
|
|||||||
from sqlalchemy.sql.functions import func
|
from sqlalchemy.sql.functions import func
|
||||||
from middleware.dependencies import SessionDep
|
from middleware.dependencies import SessionDep
|
||||||
|
|
||||||
from .local_storage import LocalStorageService
|
from service.storage import LocalStorageService
|
||||||
from .s3_storage import S3StorageService
|
|
||||||
from sqlmodels import (
|
from sqlmodels import (
|
||||||
Object,
|
Object,
|
||||||
PhysicalFile,
|
PhysicalFile,
|
||||||
@@ -272,14 +271,10 @@ async def permanently_delete_objects(
|
|||||||
if physical_file.can_be_deleted:
|
if physical_file.can_be_deleted:
|
||||||
# 物理删除文件
|
# 物理删除文件
|
||||||
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
|
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
|
||||||
if policy:
|
if policy and policy.type == PolicyType.LOCAL:
|
||||||
try:
|
try:
|
||||||
if policy.type == PolicyType.LOCAL:
|
|
||||||
storage_service = LocalStorageService(policy)
|
storage_service = LocalStorageService(policy)
|
||||||
await storage_service.delete_file(physical_file.storage_path)
|
await storage_service.delete_file(physical_file.storage_path)
|
||||||
elif policy.type == PolicyType.S3:
|
|
||||||
s3_service = await S3StorageService.from_policy(policy)
|
|
||||||
await s3_service.delete_file(physical_file.storage_path)
|
|
||||||
l.debug(f"物理文件已删除: {obj_name}")
|
l.debug(f"物理文件已删除: {obj_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
|
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
|
||||||
@@ -287,7 +282,7 @@ async def permanently_delete_objects(
|
|||||||
await PhysicalFile.delete(session, physical_file, commit=False)
|
await PhysicalFile.delete(session, physical_file, commit=False)
|
||||||
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
|
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
|
||||||
else:
|
else:
|
||||||
physical_file = await physical_file.save(session, commit=False)
|
await physical_file.save(session, commit=False)
|
||||||
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
|
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
|
||||||
|
|
||||||
# 更新用户存储配额
|
# 更新用户存储配额
|
||||||
@@ -379,19 +374,10 @@ async def delete_object_recursive(
|
|||||||
if physical_file.can_be_deleted:
|
if physical_file.can_be_deleted:
|
||||||
# 物理删除文件
|
# 物理删除文件
|
||||||
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
|
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
|
||||||
if policy:
|
if policy and policy.type == PolicyType.LOCAL:
|
||||||
try:
|
try:
|
||||||
if policy.type == PolicyType.LOCAL:
|
|
||||||
storage_service = LocalStorageService(policy)
|
storage_service = LocalStorageService(policy)
|
||||||
await storage_service.delete_file(physical_file.storage_path)
|
await storage_service.delete_file(physical_file.storage_path)
|
||||||
elif policy.type == PolicyType.S3:
|
|
||||||
options = await policy.awaitable_attrs.options
|
|
||||||
s3_service = S3StorageService(
|
|
||||||
policy,
|
|
||||||
region=options.s3_region if options else 'us-east-1',
|
|
||||||
is_path_style=options.s3_path_style if options else False,
|
|
||||||
)
|
|
||||||
await s3_service.delete_file(physical_file.storage_path)
|
|
||||||
l.debug(f"物理文件已删除: {obj_name}")
|
l.debug(f"物理文件已删除: {obj_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
|
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
|
||||||
@@ -399,7 +385,7 @@ async def delete_object_recursive(
|
|||||||
await PhysicalFile.delete(session, physical_file, commit=False)
|
await PhysicalFile.delete(session, physical_file, commit=False)
|
||||||
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
|
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
|
||||||
else:
|
else:
|
||||||
physical_file = await physical_file.save(session, commit=False)
|
await physical_file.save(session, commit=False)
|
||||||
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
|
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
|
||||||
|
|
||||||
# 阶段三:更新用户存储配额(与删除在同一事务中)
|
# 阶段三:更新用户存储配额(与删除在同一事务中)
|
||||||
@@ -458,7 +444,7 @@ async def _copy_object_recursive(
|
|||||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == src_physical_file_id)
|
physical_file = await PhysicalFile.get(session, PhysicalFile.id == src_physical_file_id)
|
||||||
if physical_file:
|
if physical_file:
|
||||||
physical_file.increment_reference()
|
physical_file.increment_reference()
|
||||||
physical_file = await physical_file.save(session)
|
await physical_file.save(session)
|
||||||
total_copied_size += src_size
|
total_copied_size += src_size
|
||||||
|
|
||||||
new_obj = await new_obj.save(session)
|
new_obj = await new_obj.save(session)
|
||||||
|
|||||||
@@ -1,709 +0,0 @@
|
|||||||
"""
|
|
||||||
S3 存储服务
|
|
||||||
|
|
||||||
使用 AWS Signature V4 签名的异步 S3 API 客户端。
|
|
||||||
从 Policy 配置中读取 S3 连接信息,提供文件上传/下载/删除及分片上传功能。
|
|
||||||
|
|
||||||
移植自 foxline-pro-backend-server 项目的 S3APIClient,
|
|
||||||
适配 DiskNext 现有的 Service 架构(与 LocalStorageService 平行)。
|
|
||||||
"""
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
from collections.abc import AsyncIterator
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import ClassVar, Literal
|
|
||||||
from urllib.parse import quote, urlencode
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
from yarl import URL
|
|
||||||
from loguru import logger as l
|
|
||||||
|
|
||||||
from sqlmodels.policy import Policy
|
|
||||||
from .exceptions import S3APIError, S3MultipartUploadError
|
|
||||||
from .naming_rule import NamingContext, NamingRuleParser
|
|
||||||
|
|
||||||
|
|
||||||
def _sign(key: bytes, msg: str) -> bytes:
|
|
||||||
"""HMAC-SHA256 签名"""
|
|
||||||
return hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
|
||||||
|
|
||||||
|
|
||||||
_NS_AWS = "http://s3.amazonaws.com/doc/2006-03-01/"
|
|
||||||
|
|
||||||
|
|
||||||
class S3StorageService:
|
|
||||||
"""
|
|
||||||
S3 存储服务
|
|
||||||
|
|
||||||
使用 AWS Signature V4 签名的异步 S3 API 客户端。
|
|
||||||
从 Policy 配置中读取 S3 连接信息。
|
|
||||||
|
|
||||||
使用示例::
|
|
||||||
|
|
||||||
service = S3StorageService(policy, region='us-east-1')
|
|
||||||
await service.upload_file('path/to/file.txt', b'content')
|
|
||||||
data = await service.download_file('path/to/file.txt')
|
|
||||||
"""
|
|
||||||
|
|
||||||
_http_session: ClassVar[aiohttp.ClientSession | None] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
policy: Policy,
|
|
||||||
region: str = 'us-east-1',
|
|
||||||
is_path_style: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
:param policy: 存储策略(server=endpoint_url, bucket_name, access_key, secret_key)
|
|
||||||
:param region: S3 区域
|
|
||||||
:param is_path_style: 是否使用路径风格 URL
|
|
||||||
"""
|
|
||||||
if not policy.server:
|
|
||||||
raise S3APIError("S3 策略必须指定 server (endpoint URL)")
|
|
||||||
if not policy.bucket_name:
|
|
||||||
raise S3APIError("S3 策略必须指定 bucket_name")
|
|
||||||
if not policy.access_key:
|
|
||||||
raise S3APIError("S3 策略必须指定 access_key")
|
|
||||||
if not policy.secret_key:
|
|
||||||
raise S3APIError("S3 策略必须指定 secret_key")
|
|
||||||
|
|
||||||
self._policy = policy
|
|
||||||
self._endpoint_url = policy.server.rstrip("/")
|
|
||||||
self._bucket_name = policy.bucket_name
|
|
||||||
self._access_key = policy.access_key
|
|
||||||
self._secret_key = policy.secret_key
|
|
||||||
self._region = region
|
|
||||||
self._is_path_style = is_path_style
|
|
||||||
self._base_url = policy.base_url
|
|
||||||
|
|
||||||
# 从 endpoint_url 提取 host
|
|
||||||
self._host = self._endpoint_url.replace("https://", "").replace("http://", "").split("/")[0]
|
|
||||||
|
|
||||||
# ==================== 工厂方法 ====================
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_policy(cls, policy: Policy) -> 'S3StorageService':
|
|
||||||
"""
|
|
||||||
根据 Policy 异步创建 S3StorageService(自动加载 options)
|
|
||||||
|
|
||||||
:param policy: 存储策略
|
|
||||||
:return: S3StorageService 实例
|
|
||||||
"""
|
|
||||||
options = await policy.awaitable_attrs.options
|
|
||||||
region = options.s3_region if options else 'us-east-1'
|
|
||||||
is_path_style = options.s3_path_style if options else False
|
|
||||||
return cls(policy, region=region, is_path_style=is_path_style)
|
|
||||||
|
|
||||||
# ==================== HTTP Session 管理 ====================
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def initialize_session(cls) -> None:
|
|
||||||
"""初始化全局 aiohttp ClientSession"""
|
|
||||||
if cls._http_session is None or cls._http_session.closed:
|
|
||||||
cls._http_session = aiohttp.ClientSession()
|
|
||||||
l.info("S3StorageService HTTP session 已初始化")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def close_session(cls) -> None:
|
|
||||||
"""关闭全局 aiohttp ClientSession"""
|
|
||||||
if cls._http_session and not cls._http_session.closed:
|
|
||||||
await cls._http_session.close()
|
|
||||||
cls._http_session = None
|
|
||||||
l.info("S3StorageService HTTP session 已关闭")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_session(cls) -> aiohttp.ClientSession:
|
|
||||||
"""获取 HTTP session"""
|
|
||||||
if cls._http_session is None or cls._http_session.closed:
|
|
||||||
# 懒初始化,以防 initialize_session 未被调用
|
|
||||||
cls._http_session = aiohttp.ClientSession()
|
|
||||||
return cls._http_session
|
|
||||||
|
|
||||||
# ==================== AWS Signature V4 签名 ====================
|
|
||||||
|
|
||||||
def _get_signature_key(self, date_stamp: str) -> bytes:
|
|
||||||
"""生成 AWS Signature V4 签名密钥"""
|
|
||||||
k_date = _sign(f"AWS4{self._secret_key}".encode(), date_stamp)
|
|
||||||
k_region = _sign(k_date, self._region)
|
|
||||||
k_service = _sign(k_region, "s3")
|
|
||||||
return _sign(k_service, "aws4_request")
|
|
||||||
|
|
||||||
def _create_authorization_header(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
uri: str,
|
|
||||||
query_string: str,
|
|
||||||
headers: dict[str, str],
|
|
||||||
payload_hash: str,
|
|
||||||
amz_date: str,
|
|
||||||
date_stamp: str,
|
|
||||||
) -> str:
|
|
||||||
"""创建 AWS Signature V4 授权头"""
|
|
||||||
signed_headers = ";".join(sorted(k.lower() for k in headers.keys()))
|
|
||||||
canonical_headers = "".join(
|
|
||||||
f"{k.lower()}:{v.strip()}\n" for k, v in sorted(headers.items())
|
|
||||||
)
|
|
||||||
canonical_request = (
|
|
||||||
f"{method}\n{uri}\n{query_string}\n{canonical_headers}\n"
|
|
||||||
f"{signed_headers}\n{payload_hash}"
|
|
||||||
)
|
|
||||||
|
|
||||||
algorithm = "AWS4-HMAC-SHA256"
|
|
||||||
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
|
|
||||||
string_to_sign = (
|
|
||||||
f"{algorithm}\n{amz_date}\n{credential_scope}\n"
|
|
||||||
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
signing_key = self._get_signature_key(date_stamp)
|
|
||||||
signature = hmac.new(
|
|
||||||
signing_key, string_to_sign.encode(), hashlib.sha256
|
|
||||||
).hexdigest()
|
|
||||||
|
|
||||||
return (
|
|
||||||
f"{algorithm} Credential={self._access_key}/{credential_scope}, "
|
|
||||||
f"SignedHeaders={signed_headers}, Signature={signature}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _build_headers(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
uri: str,
|
|
||||||
query_string: str = "",
|
|
||||||
payload: bytes = b"",
|
|
||||||
content_type: str | None = None,
|
|
||||||
extra_headers: dict[str, str] | None = None,
|
|
||||||
payload_hash: str | None = None,
|
|
||||||
host: str | None = None,
|
|
||||||
) -> dict[str, str]:
|
|
||||||
"""
|
|
||||||
构建包含 AWS V4 签名的完整请求头
|
|
||||||
|
|
||||||
:param method: HTTP 方法
|
|
||||||
:param uri: 请求 URI
|
|
||||||
:param query_string: 查询字符串
|
|
||||||
:param payload: 请求体字节(用于计算哈希)
|
|
||||||
:param content_type: Content-Type
|
|
||||||
:param extra_headers: 额外请求头
|
|
||||||
:param payload_hash: 预计算的 payload 哈希,流式上传时传 "UNSIGNED-PAYLOAD"
|
|
||||||
:param host: Host 头(默认使用 self._host)
|
|
||||||
"""
|
|
||||||
now_utc = datetime.now(timezone.utc)
|
|
||||||
amz_date = now_utc.strftime("%Y%m%dT%H%M%SZ")
|
|
||||||
date_stamp = now_utc.strftime("%Y%m%d")
|
|
||||||
|
|
||||||
if payload_hash is None:
|
|
||||||
payload_hash = hashlib.sha256(payload).hexdigest()
|
|
||||||
|
|
||||||
effective_host = host or self._host
|
|
||||||
|
|
||||||
headers: dict[str, str] = {
|
|
||||||
"Host": effective_host,
|
|
||||||
"X-Amz-Date": amz_date,
|
|
||||||
"X-Amz-Content-Sha256": payload_hash,
|
|
||||||
}
|
|
||||||
if content_type:
|
|
||||||
headers["Content-Type"] = content_type
|
|
||||||
if extra_headers:
|
|
||||||
headers.update(extra_headers)
|
|
||||||
|
|
||||||
authorization = self._create_authorization_header(
|
|
||||||
method, uri, query_string, headers, payload_hash, amz_date, date_stamp
|
|
||||||
)
|
|
||||||
headers["Authorization"] = authorization
|
|
||||||
return headers
|
|
||||||
|
|
||||||
# ==================== 内部请求方法 ====================
|
|
||||||
|
|
||||||
def _build_uri(self, key: str | None = None) -> str:
|
|
||||||
"""
|
|
||||||
构建请求 URI
|
|
||||||
|
|
||||||
按 AWS S3 Signature V4 规范对路径进行 URI 编码(S3 仅需一次)。
|
|
||||||
斜杠作为路径分隔符保留不编码。
|
|
||||||
"""
|
|
||||||
if self._is_path_style:
|
|
||||||
if key:
|
|
||||||
return f"/{self._bucket_name}/{quote(key, safe='/')}"
|
|
||||||
return f"/{self._bucket_name}"
|
|
||||||
else:
|
|
||||||
if key:
|
|
||||||
return f"/{quote(key, safe='/')}"
|
|
||||||
return "/"
|
|
||||||
|
|
||||||
def _build_url(self, uri: str, query_string: str = "") -> str:
|
|
||||||
"""构建完整请求 URL"""
|
|
||||||
if self._is_path_style:
|
|
||||||
base = self._endpoint_url
|
|
||||||
else:
|
|
||||||
# 虚拟主机风格:bucket.endpoint
|
|
||||||
protocol = "https://" if self._endpoint_url.startswith("https://") else "http://"
|
|
||||||
base = f"{protocol}{self._bucket_name}.{self._host}"
|
|
||||||
|
|
||||||
url = f"{base}{uri}"
|
|
||||||
if query_string:
|
|
||||||
url = f"{url}?{query_string}"
|
|
||||||
return url
|
|
||||||
|
|
||||||
def _get_effective_host(self) -> str:
|
|
||||||
"""获取实际请求的 Host 头"""
|
|
||||||
if self._is_path_style:
|
|
||||||
return self._host
|
|
||||||
return f"{self._bucket_name}.{self._host}"
|
|
||||||
|
|
||||||
async def _request(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
key: str | None = None,
|
|
||||||
query_params: dict[str, str] | None = None,
|
|
||||||
payload: bytes = b"",
|
|
||||||
content_type: str | None = None,
|
|
||||||
extra_headers: dict[str, str] | None = None,
|
|
||||||
) -> aiohttp.ClientResponse:
|
|
||||||
"""发送签名请求"""
|
|
||||||
uri = self._build_uri(key)
|
|
||||||
query_string = urlencode(sorted(query_params.items())) if query_params else ""
|
|
||||||
effective_host = self._get_effective_host()
|
|
||||||
|
|
||||||
headers = self._build_headers(
|
|
||||||
method, uri, query_string, payload, content_type,
|
|
||||||
extra_headers, host=effective_host,
|
|
||||||
)
|
|
||||||
|
|
||||||
url = self._build_url(uri, query_string)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self._get_session().request(
|
|
||||||
method, URL(url, encoded=True),
|
|
||||||
headers=headers, data=payload if payload else None,
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
raise S3APIError(f"S3 请求失败: {method} {url}: {e}") from e
|
|
||||||
|
|
||||||
async def _request_streaming(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
key: str,
|
|
||||||
data_stream: AsyncIterator[bytes],
|
|
||||||
content_length: int,
|
|
||||||
content_type: str | None = None,
|
|
||||||
) -> aiohttp.ClientResponse:
|
|
||||||
"""
|
|
||||||
发送流式签名请求(大文件上传)
|
|
||||||
|
|
||||||
使用 UNSIGNED-PAYLOAD 作为 payload hash。
|
|
||||||
"""
|
|
||||||
uri = self._build_uri(key)
|
|
||||||
effective_host = self._get_effective_host()
|
|
||||||
|
|
||||||
headers = self._build_headers(
|
|
||||||
method,
|
|
||||||
uri,
|
|
||||||
query_string="",
|
|
||||||
content_type=content_type,
|
|
||||||
extra_headers={"Content-Length": str(content_length)},
|
|
||||||
payload_hash="UNSIGNED-PAYLOAD",
|
|
||||||
host=effective_host,
|
|
||||||
)
|
|
||||||
|
|
||||||
url = self._build_url(uri)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self._get_session().request(
|
|
||||||
method, URL(url, encoded=True),
|
|
||||||
headers=headers, data=data_stream,
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
raise S3APIError(f"S3 流式请求失败: {method} {url}: {e}") from e
|
|
||||||
|
|
||||||
# ==================== 文件操作 ====================
|
|
||||||
|
|
||||||
async def upload_file(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
data: bytes,
|
|
||||||
content_type: str = 'application/octet-stream',
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
上传文件
|
|
||||||
|
|
||||||
:param key: S3 对象键
|
|
||||||
:param data: 文件内容
|
|
||||||
:param content_type: MIME 类型
|
|
||||||
"""
|
|
||||||
async with await self._request(
|
|
||||||
"PUT", key=key, payload=data, content_type=content_type,
|
|
||||||
) as response:
|
|
||||||
if response.status not in (200, 201):
|
|
||||||
body = await response.text()
|
|
||||||
raise S3APIError(
|
|
||||||
f"上传失败: {self._bucket_name}/{key}, "
|
|
||||||
f"状态: {response.status}, {body}"
|
|
||||||
)
|
|
||||||
l.debug(f"S3 上传成功: {self._bucket_name}/{key}")
|
|
||||||
|
|
||||||
async def upload_file_streaming(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
data_stream: AsyncIterator[bytes],
|
|
||||||
content_length: int,
|
|
||||||
content_type: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
流式上传文件(大文件,避免全部加载到内存)
|
|
||||||
|
|
||||||
:param key: S3 对象键
|
|
||||||
:param data_stream: 异步字节流迭代器
|
|
||||||
:param content_length: 数据总长度(必须准确)
|
|
||||||
:param content_type: MIME 类型
|
|
||||||
"""
|
|
||||||
async with await self._request_streaming(
|
|
||||||
"PUT", key=key, data_stream=data_stream,
|
|
||||||
content_length=content_length, content_type=content_type,
|
|
||||||
) as response:
|
|
||||||
if response.status not in (200, 201):
|
|
||||||
body = await response.text()
|
|
||||||
raise S3APIError(
|
|
||||||
f"流式上传失败: {self._bucket_name}/{key}, "
|
|
||||||
f"状态: {response.status}, {body}"
|
|
||||||
)
|
|
||||||
l.debug(f"S3 流式上传成功: {self._bucket_name}/{key}, 大小: {content_length}")
|
|
||||||
|
|
||||||
async def download_file(self, key: str) -> bytes:
|
|
||||||
"""
|
|
||||||
下载文件
|
|
||||||
|
|
||||||
:param key: S3 对象键
|
|
||||||
:return: 文件内容
|
|
||||||
"""
|
|
||||||
async with await self._request("GET", key=key) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
body = await response.text()
|
|
||||||
raise S3APIError(
|
|
||||||
f"下载失败: {self._bucket_name}/{key}, "
|
|
||||||
f"状态: {response.status}, {body}"
|
|
||||||
)
|
|
||||||
data = await response.read()
|
|
||||||
l.debug(f"S3 下载成功: {self._bucket_name}/{key}, 大小: {len(data)}")
|
|
||||||
return data
|
|
||||||
|
|
||||||
async def delete_file(self, key: str) -> None:
|
|
||||||
"""
|
|
||||||
删除文件
|
|
||||||
|
|
||||||
:param key: S3 对象键
|
|
||||||
"""
|
|
||||||
async with await self._request("DELETE", key=key) as response:
|
|
||||||
if response.status in (200, 204):
|
|
||||||
l.debug(f"S3 删除成功: {self._bucket_name}/{key}")
|
|
||||||
else:
|
|
||||||
body = await response.text()
|
|
||||||
raise S3APIError(
|
|
||||||
f"删除失败: {self._bucket_name}/{key}, "
|
|
||||||
f"状态: {response.status}, {body}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def file_exists(self, key: str) -> bool:
|
|
||||||
"""
|
|
||||||
检查文件是否存在
|
|
||||||
|
|
||||||
:param key: S3 对象键
|
|
||||||
:return: 是否存在
|
|
||||||
"""
|
|
||||||
async with await self._request("HEAD", key=key) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
return True
|
|
||||||
elif response.status == 404:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
raise S3APIError(
|
|
||||||
f"检查文件存在性失败: {self._bucket_name}/{key}, 状态: {response.status}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_file_size(self, key: str) -> int:
|
|
||||||
"""
|
|
||||||
获取文件大小
|
|
||||||
|
|
||||||
:param key: S3 对象键
|
|
||||||
:return: 文件大小(字节)
|
|
||||||
"""
|
|
||||||
async with await self._request("HEAD", key=key) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise S3APIError(
|
|
||||||
f"获取文件信息失败: {self._bucket_name}/{key}, 状态: {response.status}"
|
|
||||||
)
|
|
||||||
return int(response.headers.get("Content-Length", 0))
|
|
||||||
|
|
||||||
# ==================== Multipart Upload ====================
|
|
||||||
|
|
||||||
async def create_multipart_upload(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
content_type: str = 'application/octet-stream',
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
创建分片上传任务
|
|
||||||
|
|
||||||
:param key: S3 对象键
|
|
||||||
:param content_type: MIME 类型
|
|
||||||
:return: Upload ID
|
|
||||||
"""
|
|
||||||
async with await self._request(
|
|
||||||
"POST",
|
|
||||||
key=key,
|
|
||||||
query_params={"uploads": ""},
|
|
||||||
content_type=content_type,
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
body = await response.text()
|
|
||||||
raise S3MultipartUploadError(
|
|
||||||
f"创建分片上传失败: {self._bucket_name}/{key}, "
|
|
||||||
f"状态: {response.status}, {body}"
|
|
||||||
)
|
|
||||||
|
|
||||||
body = await response.text()
|
|
||||||
root = ET.fromstring(body)
|
|
||||||
|
|
||||||
# 查找 UploadId 元素(支持命名空间)
|
|
||||||
upload_id_elem = root.find("UploadId")
|
|
||||||
if upload_id_elem is None:
|
|
||||||
upload_id_elem = root.find(f"{{{_NS_AWS}}}UploadId")
|
|
||||||
if upload_id_elem is None or not upload_id_elem.text:
|
|
||||||
raise S3MultipartUploadError(
|
|
||||||
f"创建分片上传响应中未找到 UploadId: {body}"
|
|
||||||
)
|
|
||||||
|
|
||||||
upload_id = upload_id_elem.text
|
|
||||||
l.debug(f"S3 分片上传已创建: {self._bucket_name}/{key}, upload_id={upload_id}")
|
|
||||||
return upload_id
|
|
||||||
|
|
||||||
async def upload_part(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
upload_id: str,
|
|
||||||
part_number: int,
|
|
||||||
data: bytes,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
上传单个分片
|
|
||||||
|
|
||||||
:param key: S3 对象键
|
|
||||||
:param upload_id: 分片上传 ID
|
|
||||||
:param part_number: 分片编号(从 1 开始)
|
|
||||||
:param data: 分片数据
|
|
||||||
:return: ETag
|
|
||||||
"""
|
|
||||||
async with await self._request(
|
|
||||||
"PUT",
|
|
||||||
key=key,
|
|
||||||
query_params={
|
|
||||||
"partNumber": str(part_number),
|
|
||||||
"uploadId": upload_id,
|
|
||||||
},
|
|
||||||
payload=data,
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
body = await response.text()
|
|
||||||
raise S3MultipartUploadError(
|
|
||||||
f"上传分片失败: {self._bucket_name}/{key}, "
|
|
||||||
f"part={part_number}, 状态: {response.status}, {body}"
|
|
||||||
)
|
|
||||||
|
|
||||||
etag = response.headers.get("ETag", "").strip('"')
|
|
||||||
l.debug(
|
|
||||||
f"S3 分片上传成功: {self._bucket_name}/{key}, "
|
|
||||||
f"part={part_number}, etag={etag}"
|
|
||||||
)
|
|
||||||
return etag
|
|
||||||
|
|
||||||
async def complete_multipart_upload(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
upload_id: str,
|
|
||||||
parts: list[tuple[int, str]],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
完成分片上传
|
|
||||||
|
|
||||||
:param key: S3 对象键
|
|
||||||
:param upload_id: 分片上传 ID
|
|
||||||
:param parts: 分片列表 [(part_number, etag)]
|
|
||||||
"""
|
|
||||||
# 按 part_number 排序
|
|
||||||
parts_sorted = sorted(parts, key=lambda p: p[0])
|
|
||||||
|
|
||||||
# 构建 CompleteMultipartUpload XML
|
|
||||||
xml_parts = ''.join(
|
|
||||||
f"<Part><PartNumber>{pn}</PartNumber><ETag>{etag}</ETag></Part>"
|
|
||||||
for pn, etag in parts_sorted
|
|
||||||
)
|
|
||||||
payload = f'<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUpload>{xml_parts}</CompleteMultipartUpload>'
|
|
||||||
payload_bytes = payload.encode('utf-8')
|
|
||||||
|
|
||||||
async with await self._request(
|
|
||||||
"POST",
|
|
||||||
key=key,
|
|
||||||
query_params={"uploadId": upload_id},
|
|
||||||
payload=payload_bytes,
|
|
||||||
content_type="application/xml",
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
body = await response.text()
|
|
||||||
raise S3MultipartUploadError(
|
|
||||||
f"完成分片上传失败: {self._bucket_name}/{key}, "
|
|
||||||
f"状态: {response.status}, {body}"
|
|
||||||
)
|
|
||||||
l.info(
|
|
||||||
f"S3 分片上传已完成: {self._bucket_name}/{key}, "
|
|
||||||
f"共 {len(parts)} 个分片"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def abort_multipart_upload(self, key: str, upload_id: str) -> None:
|
|
||||||
"""
|
|
||||||
取消分片上传
|
|
||||||
|
|
||||||
:param key: S3 对象键
|
|
||||||
:param upload_id: 分片上传 ID
|
|
||||||
"""
|
|
||||||
async with await self._request(
|
|
||||||
"DELETE",
|
|
||||||
key=key,
|
|
||||||
query_params={"uploadId": upload_id},
|
|
||||||
) as response:
|
|
||||||
if response.status in (200, 204):
|
|
||||||
l.debug(f"S3 分片上传已取消: {self._bucket_name}/{key}")
|
|
||||||
else:
|
|
||||||
body = await response.text()
|
|
||||||
l.warning(
|
|
||||||
f"取消分片上传失败: {self._bucket_name}/{key}, "
|
|
||||||
f"状态: {response.status}, {body}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ==================== 预签名 URL ====================
|
|
||||||
|
|
||||||
def generate_presigned_url(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
method: Literal['GET', 'PUT'] = 'GET',
|
|
||||||
expires_in: int = 3600,
|
|
||||||
filename: str | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
生成 S3 预签名 URL(AWS Signature V4 Query String)
|
|
||||||
|
|
||||||
:param key: S3 对象键
|
|
||||||
:param method: HTTP 方法(GET 下载,PUT 上传)
|
|
||||||
:param expires_in: URL 有效期(秒)
|
|
||||||
:param filename: 文件名(GET 请求时设置 Content-Disposition)
|
|
||||||
:return: 预签名 URL
|
|
||||||
"""
|
|
||||||
current_time = datetime.now(timezone.utc)
|
|
||||||
amz_date = current_time.strftime("%Y%m%dT%H%M%SZ")
|
|
||||||
date_stamp = current_time.strftime("%Y%m%d")
|
|
||||||
|
|
||||||
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
|
|
||||||
credential = f"{self._access_key}/{credential_scope}"
|
|
||||||
|
|
||||||
uri = self._build_uri(key)
|
|
||||||
effective_host = self._get_effective_host()
|
|
||||||
|
|
||||||
query_params: dict[str, str] = {
|
|
||||||
'X-Amz-Algorithm': 'AWS4-HMAC-SHA256',
|
|
||||||
'X-Amz-Credential': credential,
|
|
||||||
'X-Amz-Date': amz_date,
|
|
||||||
'X-Amz-Expires': str(expires_in),
|
|
||||||
'X-Amz-SignedHeaders': 'host',
|
|
||||||
}
|
|
||||||
|
|
||||||
# GET 请求时添加 Content-Disposition
|
|
||||||
if method == "GET" and filename:
|
|
||||||
encoded_filename = quote(filename, safe='')
|
|
||||||
query_params['response-content-disposition'] = (
|
|
||||||
f"attachment; filename*=UTF-8''{encoded_filename}"
|
|
||||||
)
|
|
||||||
|
|
||||||
canonical_query_string = "&".join(
|
|
||||||
f"{quote(k, safe='')}={quote(v, safe='')}"
|
|
||||||
for k, v in sorted(query_params.items())
|
|
||||||
)
|
|
||||||
|
|
||||||
canonical_headers = f"host:{effective_host}\n"
|
|
||||||
signed_headers = "host"
|
|
||||||
payload_hash = "UNSIGNED-PAYLOAD"
|
|
||||||
|
|
||||||
canonical_request = (
|
|
||||||
f"{method}\n"
|
|
||||||
f"{uri}\n"
|
|
||||||
f"{canonical_query_string}\n"
|
|
||||||
f"{canonical_headers}\n"
|
|
||||||
f"{signed_headers}\n"
|
|
||||||
f"{payload_hash}"
|
|
||||||
)
|
|
||||||
|
|
||||||
algorithm = "AWS4-HMAC-SHA256"
|
|
||||||
string_to_sign = (
|
|
||||||
f"{algorithm}\n"
|
|
||||||
f"{amz_date}\n"
|
|
||||||
f"{credential_scope}\n"
|
|
||||||
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
signing_key = self._get_signature_key(date_stamp)
|
|
||||||
signature = hmac.new(
|
|
||||||
signing_key, string_to_sign.encode(), hashlib.sha256
|
|
||||||
).hexdigest()
|
|
||||||
|
|
||||||
base_url = self._build_url(uri)
|
|
||||||
return (
|
|
||||||
f"{base_url}?"
|
|
||||||
f"{canonical_query_string}&"
|
|
||||||
f"X-Amz-Signature={signature}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ==================== 路径生成 ====================
|
|
||||||
|
|
||||||
async def generate_file_path(
|
|
||||||
self,
|
|
||||||
user_id: UUID,
|
|
||||||
original_filename: str,
|
|
||||||
) -> tuple[str, str, str]:
|
|
||||||
"""
|
|
||||||
根据命名规则生成 S3 文件存储路径
|
|
||||||
|
|
||||||
与 LocalStorageService.generate_file_path 接口一致。
|
|
||||||
|
|
||||||
:param user_id: 用户UUID
|
|
||||||
:param original_filename: 原始文件名
|
|
||||||
:return: (相对目录路径, 存储文件名, 完整存储路径)
|
|
||||||
"""
|
|
||||||
context = NamingContext(
|
|
||||||
user_id=user_id,
|
|
||||||
original_filename=original_filename,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 解析目录规则
|
|
||||||
dir_path = ""
|
|
||||||
if self._policy.dir_name_rule:
|
|
||||||
dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context)
|
|
||||||
|
|
||||||
# 解析文件名规则
|
|
||||||
if self._policy.auto_rename and self._policy.file_name_rule:
|
|
||||||
storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context)
|
|
||||||
# 确保有扩展名
|
|
||||||
if '.' in original_filename and '.' not in storage_name:
|
|
||||||
ext = original_filename.rsplit('.', 1)[1]
|
|
||||||
storage_name = f"{storage_name}.{ext}"
|
|
||||||
else:
|
|
||||||
storage_name = original_filename
|
|
||||||
|
|
||||||
# S3 不需要创建目录,直接拼接路径
|
|
||||||
if dir_path:
|
|
||||||
storage_path = f"{dir_path}/{storage_name}"
|
|
||||||
else:
|
|
||||||
storage_path = storage_name
|
|
||||||
|
|
||||||
return dir_path, storage_name, storage_path
|
|
||||||
@@ -3,14 +3,12 @@
|
|||||||
|
|
||||||
支持多种认证方式:邮箱密码、GitHub OAuth、QQ OAuth、Passkey、Magic Link、手机短信(预留)。
|
支持多种认证方式:邮箱密码、GitHub OAuth、QQ OAuth、Passkey、Magic Link、手机短信(预留)。
|
||||||
"""
|
"""
|
||||||
import hashlib
|
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||||
from loguru import logger as l
|
from loguru import logger as l
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from service.redis.token_store import TokenStore
|
|
||||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||||
from sqlmodels.group import GroupClaims, GroupOptions
|
from sqlmodels.group import GroupClaims, GroupOptions
|
||||||
from sqlmodels.object import Object, ObjectType
|
from sqlmodels.object import Object, ObjectType
|
||||||
@@ -192,7 +190,7 @@ async def _login_oauth(
|
|||||||
# 已绑定 → 更新 OAuth 信息并返回关联用户
|
# 已绑定 → 更新 OAuth 信息并返回关联用户
|
||||||
identity.display_name = nickname
|
identity.display_name = nickname
|
||||||
identity.avatar_url = avatar_url
|
identity.avatar_url = avatar_url
|
||||||
identity = await identity.save(session)
|
await identity.save(session)
|
||||||
|
|
||||||
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||||
if not user:
|
if not user:
|
||||||
@@ -254,7 +252,7 @@ async def _auto_register_oauth_user(
|
|||||||
is_verified=True,
|
is_verified=True,
|
||||||
user_id=new_user_id,
|
user_id=new_user_id,
|
||||||
)
|
)
|
||||||
identity = await identity.save(session)
|
await identity.save(session)
|
||||||
|
|
||||||
# 创建用户根目录
|
# 创建用户根目录
|
||||||
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||||
@@ -335,7 +333,7 @@ async def _login_passkey(
|
|||||||
|
|
||||||
# 更新签名计数
|
# 更新签名计数
|
||||||
authn.sign_count = verification.new_sign_count
|
authn.sign_count = verification.new_sign_count
|
||||||
authn = await authn.save(session)
|
await authn.save(session)
|
||||||
|
|
||||||
# 加载用户
|
# 加载用户
|
||||||
user: User = await User.get(session, User.id == authn.user_id, load=User.group)
|
user: User = await User.get(session, User.id == authn.user_id, load=User.group)
|
||||||
@@ -365,12 +363,6 @@ async def _login_magic_link(
|
|||||||
except BadSignature:
|
except BadSignature:
|
||||||
http_exceptions.raise_unauthorized("Magic Link 无效")
|
http_exceptions.raise_unauthorized("Magic Link 无效")
|
||||||
|
|
||||||
# 防重放:使用 token 哈希作为标识符
|
|
||||||
token_hash = hashlib.sha256(request.identifier.encode()).hexdigest()
|
|
||||||
is_first_use = await TokenStore.mark_used(f"magic_link:{token_hash}", ttl=600)
|
|
||||||
if not is_first_use:
|
|
||||||
http_exceptions.raise_unauthorized("Magic Link 已被使用")
|
|
||||||
|
|
||||||
# 查找绑定了该邮箱的 AuthIdentity(email_password 或 magic_link)
|
# 查找绑定了该邮箱的 AuthIdentity(email_password 或 magic_link)
|
||||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||||
session,
|
session,
|
||||||
@@ -392,7 +384,7 @@ async def _login_magic_link(
|
|||||||
# 标记邮箱已验证
|
# 标记邮箱已验证
|
||||||
if not identity.is_verified:
|
if not identity.is_verified:
|
||||||
identity.is_verified = True
|
identity.is_verified = True
|
||||||
identity = await identity.save(session)
|
await identity.save(session)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|||||||
@@ -1,185 +0,0 @@
|
|||||||
"""
|
|
||||||
WOPI Discovery 服务模块
|
|
||||||
|
|
||||||
解析 WOPI 服务端(Collabora / OnlyOffice 等)的 Discovery XML,
|
|
||||||
提取支持的文件扩展名及对应的编辑器 URL 模板。
|
|
||||||
|
|
||||||
参考:Cloudreve pkg/wopi/discovery.go 和 pkg/wopi/wopi.go
|
|
||||||
"""
|
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
|
||||||
|
|
||||||
from loguru import logger as l
|
|
||||||
|
|
||||||
# WOPI URL 模板中已知的查询参数占位符及其替换值
|
|
||||||
# 值为 None 表示删除该参数,非 None 表示替换为该值
|
|
||||||
# 参考 Cloudreve pkg/wopi/wopi.go queryPlaceholders
|
|
||||||
_WOPI_QUERY_PLACEHOLDERS: dict[str, str | None] = {
|
|
||||||
'BUSINESS_USER': None,
|
|
||||||
'DC_LLCC': 'lng',
|
|
||||||
'DISABLE_ASYNC': None,
|
|
||||||
'DISABLE_CHAT': None,
|
|
||||||
'EMBEDDED': 'true',
|
|
||||||
'FULLSCREEN': 'true',
|
|
||||||
'HOST_SESSION_ID': None,
|
|
||||||
'SESSION_CONTEXT': None,
|
|
||||||
'RECORDING': None,
|
|
||||||
'THEME_ID': 'darkmode',
|
|
||||||
'UI_LLCC': 'lng',
|
|
||||||
'VALIDATOR_TEST_CATEGORY': None,
|
|
||||||
}
|
|
||||||
|
|
||||||
_WOPI_SRC_PLACEHOLDER = 'WOPI_SOURCE'
|
|
||||||
|
|
||||||
|
|
||||||
def process_wopi_action_url(raw_urlsrc: str) -> str:
|
|
||||||
"""
|
|
||||||
将 WOPI Discovery 中的原始 urlsrc 转换为 DiskNext 可用的 URL 模板。
|
|
||||||
|
|
||||||
处理流程(参考 Cloudreve generateActionUrl):
|
|
||||||
1. 去除 ``<>`` 占位符标记
|
|
||||||
2. 解析查询参数,替换/删除已知占位符
|
|
||||||
3. ``WOPI_SOURCE`` → ``{wopi_src}``
|
|
||||||
|
|
||||||
注意:access_token 和 access_token_ttl 不放在 URL 中,
|
|
||||||
根据 WOPI 规范它们通过 POST 表单字段传递给编辑器。
|
|
||||||
|
|
||||||
:param raw_urlsrc: WOPI Discovery XML 中的 urlsrc 原始值
|
|
||||||
:return: 处理后的 URL 模板字符串,包含 {wopi_src} 占位符
|
|
||||||
"""
|
|
||||||
# 去除 <> 标记
|
|
||||||
cleaned = raw_urlsrc.replace('<', '').replace('>', '')
|
|
||||||
parsed = urlparse(cleaned)
|
|
||||||
raw_params = parse_qs(parsed.query, keep_blank_values=True)
|
|
||||||
|
|
||||||
new_params: list[tuple[str, str]] = []
|
|
||||||
is_src_replaced = False
|
|
||||||
|
|
||||||
for key, values in raw_params.items():
|
|
||||||
value = values[0] if values else ''
|
|
||||||
|
|
||||||
# WOPI_SOURCE 占位符 → {wopi_src}
|
|
||||||
if value == _WOPI_SRC_PLACEHOLDER:
|
|
||||||
new_params.append((key, '{wopi_src}'))
|
|
||||||
is_src_replaced = True
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 已知占位符
|
|
||||||
if value in _WOPI_QUERY_PLACEHOLDERS:
|
|
||||||
replacement = _WOPI_QUERY_PLACEHOLDERS[value]
|
|
||||||
if replacement is not None:
|
|
||||||
new_params.append((key, replacement))
|
|
||||||
# replacement 为 None 时删除该参数
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 其他参数保留原值
|
|
||||||
new_params.append((key, value))
|
|
||||||
|
|
||||||
# 如果没有找到 WOPI_SOURCE 占位符,手动添加 WOPISrc
|
|
||||||
if not is_src_replaced:
|
|
||||||
new_params.append(('WOPISrc', '{wopi_src}'))
|
|
||||||
|
|
||||||
# LibreOffice/Collabora 需要 lang 参数(避免重复添加)
|
|
||||||
existing_keys = {k for k, _ in new_params}
|
|
||||||
if 'lang' not in existing_keys:
|
|
||||||
new_params.append(('lang', 'lng'))
|
|
||||||
|
|
||||||
# 注意:access_token 和 access_token_ttl 不放在 URL 中
|
|
||||||
# 根据 WOPI 规范,它们通过 POST 表单字段传递给编辑器
|
|
||||||
|
|
||||||
# 重建 URL
|
|
||||||
new_query = urlencode(new_params, safe='{}')
|
|
||||||
result = urlunparse((
|
|
||||||
parsed.scheme,
|
|
||||||
parsed.netloc,
|
|
||||||
parsed.path,
|
|
||||||
parsed.params,
|
|
||||||
new_query,
|
|
||||||
'',
|
|
||||||
))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def parse_wopi_discovery_xml(xml_content: str) -> tuple[dict[str, str], list[str]]:
|
|
||||||
"""
|
|
||||||
解析 WOPI Discovery XML,提取扩展名到 URL 模板的映射。
|
|
||||||
|
|
||||||
XML 结构::
|
|
||||||
|
|
||||||
<wopi-discovery>
|
|
||||||
<net-zone name="external-https">
|
|
||||||
<app name="Writer" favIconUrl="...">
|
|
||||||
<action name="edit" ext="docx" urlsrc="https://..."/>
|
|
||||||
<action name="view" ext="docx" urlsrc="https://..."/>
|
|
||||||
</app>
|
|
||||||
</net-zone>
|
|
||||||
</wopi-discovery>
|
|
||||||
|
|
||||||
动作优先级:edit > embedview > view(参考 Cloudreve discovery.go)
|
|
||||||
|
|
||||||
:param xml_content: WOPI Discovery 端点返回的 XML 字符串
|
|
||||||
:return: (action_urls, app_names) 元组
|
|
||||||
action_urls: {extension: processed_url_template}
|
|
||||||
app_names: 发现的应用名称列表
|
|
||||||
:raises ValueError: XML 解析失败或格式无效
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
root = ET.fromstring(xml_content)
|
|
||||||
except ET.ParseError as e:
|
|
||||||
raise ValueError(f"WOPI Discovery XML 解析失败: {e}")
|
|
||||||
|
|
||||||
# 查找 net-zone(可能有多个,取第一个非空的)
|
|
||||||
net_zones = root.findall('net-zone')
|
|
||||||
if not net_zones:
|
|
||||||
raise ValueError("WOPI Discovery XML 缺少 net-zone 节点")
|
|
||||||
|
|
||||||
# ext_actions: {extension: {action_name: urlsrc}}
|
|
||||||
ext_actions: dict[str, dict[str, str]] = {}
|
|
||||||
app_names: list[str] = []
|
|
||||||
|
|
||||||
for net_zone in net_zones:
|
|
||||||
for app_elem in net_zone.findall('app'):
|
|
||||||
app_name = app_elem.get('name', '')
|
|
||||||
if app_name:
|
|
||||||
app_names.append(app_name)
|
|
||||||
|
|
||||||
for action_elem in app_elem.findall('action'):
|
|
||||||
action_name = action_elem.get('name', '')
|
|
||||||
ext = action_elem.get('ext', '')
|
|
||||||
urlsrc = action_elem.get('urlsrc', '')
|
|
||||||
|
|
||||||
if not ext or not urlsrc:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 只关注 edit / embedview / view 三种动作
|
|
||||||
if action_name not in ('edit', 'embedview', 'view'):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if ext not in ext_actions:
|
|
||||||
ext_actions[ext] = {}
|
|
||||||
ext_actions[ext][action_name] = urlsrc
|
|
||||||
|
|
||||||
# 为每个扩展名选择最佳 URL: edit > embedview > view
|
|
||||||
action_urls: dict[str, str] = {}
|
|
||||||
for ext, actions_map in ext_actions.items():
|
|
||||||
selected_urlsrc: str | None = None
|
|
||||||
for preferred in ('edit', 'embedview', 'view'):
|
|
||||||
if preferred in actions_map:
|
|
||||||
selected_urlsrc = actions_map[preferred]
|
|
||||||
break
|
|
||||||
|
|
||||||
if selected_urlsrc:
|
|
||||||
action_urls[ext] = process_wopi_action_url(selected_urlsrc)
|
|
||||||
|
|
||||||
# 去重 app_names
|
|
||||||
seen: set[str] = set()
|
|
||||||
unique_names: list[str] = []
|
|
||||||
for name in app_names:
|
|
||||||
if name not in seen:
|
|
||||||
seen.add(name)
|
|
||||||
unique_names.append(name)
|
|
||||||
|
|
||||||
l.info(f"WOPI Discovery 解析完成: {len(action_urls)} 个扩展名, 应用: {unique_names}")
|
|
||||||
|
|
||||||
return action_urls, unique_names
|
|
||||||
@@ -84,7 +84,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="disknext-ee",
|
name="disknext-ee",
|
||||||
packages=[],
|
|
||||||
ext_modules=cythonize(
|
ext_modules=cythonize(
|
||||||
extensions,
|
extensions,
|
||||||
compiler_directives={'language_level': "3"},
|
compiler_directives={'language_level': "3"},
|
||||||
|
|||||||
@@ -954,11 +954,18 @@ class PolicyType(StrEnum):
|
|||||||
S3 = "s3" # S3 兼容存储
|
S3 = "s3" # S3 兼容存储
|
||||||
```
|
```
|
||||||
|
|
||||||
### PolicyType
|
### StorageType
|
||||||
```python
|
```python
|
||||||
class PolicyType(StrEnum):
|
class StorageType(StrEnum):
|
||||||
LOCAL = "local" # 本地存储
|
LOCAL = "local" # 本地存储
|
||||||
S3 = "s3" # S3 兼容存储
|
QINIU = "qiniu" # 七牛云
|
||||||
|
TENCENT = "tencent" # 腾讯云
|
||||||
|
ALIYUN = "aliyun" # 阿里云
|
||||||
|
ONEDRIVE = "onedrive" # OneDrive
|
||||||
|
GOOGLE_DRIVE = "google_drive" # Google Drive
|
||||||
|
DROPBOX = "dropbox" # Dropbox
|
||||||
|
WEBDAV = "webdav" # WebDAV
|
||||||
|
REMOTE = "remote" # 远程存储
|
||||||
```
|
```
|
||||||
|
|
||||||
### UserStatus
|
### UserStatus
|
||||||
|
|||||||
@@ -69,20 +69,18 @@ from .object import (
|
|||||||
CreateUploadSessionRequest,
|
CreateUploadSessionRequest,
|
||||||
DirectoryCreateRequest,
|
DirectoryCreateRequest,
|
||||||
DirectoryResponse,
|
DirectoryResponse,
|
||||||
|
FileMetadata,
|
||||||
|
FileMetadataBase,
|
||||||
Object,
|
Object,
|
||||||
ObjectBase,
|
ObjectBase,
|
||||||
ObjectCopyRequest,
|
ObjectCopyRequest,
|
||||||
ObjectDeleteRequest,
|
ObjectDeleteRequest,
|
||||||
ObjectFileFinalize,
|
|
||||||
ObjectMoveRequest,
|
ObjectMoveRequest,
|
||||||
ObjectMoveUpdate,
|
|
||||||
ObjectPropertyDetailResponse,
|
ObjectPropertyDetailResponse,
|
||||||
ObjectPropertyResponse,
|
ObjectPropertyResponse,
|
||||||
ObjectRenameRequest,
|
ObjectRenameRequest,
|
||||||
ObjectResponse,
|
ObjectResponse,
|
||||||
ObjectSwitchPolicyRequest,
|
|
||||||
ObjectType,
|
ObjectType,
|
||||||
FileCategory,
|
|
||||||
PolicyResponse,
|
PolicyResponse,
|
||||||
UploadChunkResponse,
|
UploadChunkResponse,
|
||||||
UploadSession,
|
UploadSession,
|
||||||
@@ -97,42 +95,11 @@ from .object import (
|
|||||||
TrashRestoreRequest,
|
TrashRestoreRequest,
|
||||||
TrashDeleteRequest,
|
TrashDeleteRequest,
|
||||||
)
|
)
|
||||||
from .object_metadata import (
|
|
||||||
ObjectMetadata,
|
|
||||||
ObjectMetadataBase,
|
|
||||||
MetadataNamespace,
|
|
||||||
MetadataResponse,
|
|
||||||
MetadataPatchItem,
|
|
||||||
MetadataPatchRequest,
|
|
||||||
INTERNAL_NAMESPACES,
|
|
||||||
USER_WRITABLE_NAMESPACES,
|
|
||||||
)
|
|
||||||
from .custom_property import (
|
|
||||||
CustomPropertyDefinition,
|
|
||||||
CustomPropertyDefinitionBase,
|
|
||||||
CustomPropertyType,
|
|
||||||
CustomPropertyCreateRequest,
|
|
||||||
CustomPropertyUpdateRequest,
|
|
||||||
CustomPropertyResponse,
|
|
||||||
)
|
|
||||||
from .physical_file import PhysicalFile, PhysicalFileBase
|
from .physical_file import PhysicalFile, PhysicalFileBase
|
||||||
from .uri import DiskNextURI, FileSystemNamespace
|
from .uri import DiskNextURI, FileSystemNamespace
|
||||||
from .order import (
|
from .order import Order, OrderStatus, OrderType
|
||||||
Order, OrderStatus, OrderType,
|
from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary
|
||||||
CreateOrderRequest, OrderResponse,
|
from .redeem import Redeem, RedeemType
|
||||||
)
|
|
||||||
from .policy import (
|
|
||||||
Policy, PolicyBase, PolicyCreateRequest, PolicyOptions, PolicyOptionsBase,
|
|
||||||
PolicyType, PolicySummary, PolicyUpdateRequest,
|
|
||||||
)
|
|
||||||
from .product import (
|
|
||||||
Product, ProductBase, ProductType, PaymentMethod,
|
|
||||||
ProductCreateRequest, ProductUpdateRequest, ProductResponse,
|
|
||||||
)
|
|
||||||
from .redeem import (
|
|
||||||
Redeem, RedeemType,
|
|
||||||
RedeemCreateRequest, RedeemUseRequest, RedeemInfoResponse, RedeemAdminResponse,
|
|
||||||
)
|
|
||||||
from .report import Report, ReportReason
|
from .report import Report, ReportReason
|
||||||
from .setting import (
|
from .setting import (
|
||||||
Setting, SettingsType, SiteConfigResponse, AuthMethodConfig,
|
Setting, SettingsType, SiteConfigResponse, AuthMethodConfig,
|
||||||
@@ -145,20 +112,16 @@ from .share import (
|
|||||||
AdminShareListItem,
|
AdminShareListItem,
|
||||||
)
|
)
|
||||||
from .source_link import SourceLink
|
from .source_link import SourceLink
|
||||||
from .storage_pack import StoragePack, StoragePackResponse
|
from .storage_pack import StoragePack
|
||||||
from .tag import Tag, TagType
|
from .tag import Tag, TagType
|
||||||
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary, TaskSummaryBase
|
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
|
||||||
from .webdav import (
|
from .webdav import WebDAV
|
||||||
WebDAV, WebDAVBase,
|
|
||||||
WebDAVCreateRequest, WebDAVUpdateRequest, WebDAVAccountResponse,
|
|
||||||
)
|
|
||||||
from .file_app import (
|
from .file_app import (
|
||||||
FileApp, FileAppType, FileAppExtension, FileAppGroupLink, UserFileAppDefault,
|
FileApp, FileAppType, FileAppExtension, FileAppGroupLink, UserFileAppDefault,
|
||||||
# DTO
|
# DTO
|
||||||
FileAppSummary, FileViewersResponse, SetDefaultViewerRequest, UserFileAppDefaultResponse,
|
FileAppSummary, FileViewersResponse, SetDefaultViewerRequest, UserFileAppDefaultResponse,
|
||||||
FileAppCreateRequest, FileAppUpdateRequest, FileAppResponse, FileAppListResponse,
|
FileAppCreateRequest, FileAppUpdateRequest, FileAppResponse, FileAppListResponse,
|
||||||
ExtensionUpdateRequest, GroupAccessUpdateRequest, WopiSessionResponse,
|
ExtensionUpdateRequest, GroupAccessUpdateRequest, WopiSessionResponse,
|
||||||
WopiDiscoveredExtension, WopiDiscoveryResponse,
|
|
||||||
)
|
)
|
||||||
from .wopi import WopiFileInfo, WopiAccessTokenPayload
|
from .wopi import WopiFileInfo, WopiAccessTokenPayload
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100, Str128, Str255, Text1024
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
@@ -87,7 +87,7 @@ class ChangePasswordRequest(SQLModelBase):
|
|||||||
old_password: str = Field(min_length=1)
|
old_password: str = Field(min_length=1)
|
||||||
"""当前密码"""
|
"""当前密码"""
|
||||||
|
|
||||||
new_password: Str128 = Field(min_length=8)
|
new_password: str = Field(min_length=8, max_length=128)
|
||||||
"""新密码(至少 8 位)"""
|
"""新密码(至少 8 位)"""
|
||||||
|
|
||||||
|
|
||||||
@@ -103,13 +103,13 @@ class AuthIdentity(SQLModelBase, UUIDTableBaseMixin):
|
|||||||
provider: AuthProviderType = Field(index=True)
|
provider: AuthProviderType = Field(index=True)
|
||||||
"""提供者类型"""
|
"""提供者类型"""
|
||||||
|
|
||||||
identifier: Str255 = Field(index=True)
|
identifier: str = Field(max_length=255, index=True)
|
||||||
"""标识符(邮箱/手机号/OAuth openid)"""
|
"""标识符(邮箱/手机号/OAuth openid)"""
|
||||||
|
|
||||||
credential: Text1024 | None = None
|
credential: str | None = Field(default=None, max_length=1024)
|
||||||
"""凭证(Argon2 哈希密码 / null)"""
|
"""凭证(Argon2 哈希密码 / null)"""
|
||||||
|
|
||||||
display_name: Str100 | None = None
|
display_name: str | None = Field(default=None, max_length=100)
|
||||||
"""OAuth 昵称"""
|
"""OAuth 昵称"""
|
||||||
|
|
||||||
avatar_url: str | None = Field(default=None, max_length=512)
|
avatar_url: str | None = Field(default=None, max_length=512)
|
||||||
|
|||||||
@@ -1,135 +0,0 @@
|
|||||||
"""
|
|
||||||
用户自定义属性定义模型
|
|
||||||
|
|
||||||
允许用户定义类型化的自定义属性模板(如标签、评分、分类等),
|
|
||||||
实际值通过 ObjectMetadata KV 表存储,键名格式:custom:{property_definition_id}。
|
|
||||||
|
|
||||||
支持的属性类型:text, number, boolean, select, multi_select, rating, link
|
|
||||||
"""
|
|
||||||
from enum import StrEnum
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlalchemy import JSON
|
|
||||||
from sqlmodel import Field, Relationship
|
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .user import User
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 枚举 ====================
|
|
||||||
|
|
||||||
class CustomPropertyType(StrEnum):
|
|
||||||
"""自定义属性值类型枚举"""
|
|
||||||
TEXT = "text"
|
|
||||||
"""文本"""
|
|
||||||
NUMBER = "number"
|
|
||||||
"""数字"""
|
|
||||||
BOOLEAN = "boolean"
|
|
||||||
"""布尔值"""
|
|
||||||
SELECT = "select"
|
|
||||||
"""单选"""
|
|
||||||
MULTI_SELECT = "multi_select"
|
|
||||||
"""多选"""
|
|
||||||
RATING = "rating"
|
|
||||||
"""评分(1-5)"""
|
|
||||||
LINK = "link"
|
|
||||||
"""链接"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== Base 模型 ====================
|
|
||||||
|
|
||||||
class CustomPropertyDefinitionBase(SQLModelBase):
|
|
||||||
"""自定义属性定义基础模型"""
|
|
||||||
|
|
||||||
name: Str100
|
|
||||||
"""属性显示名称"""
|
|
||||||
|
|
||||||
type: CustomPropertyType
|
|
||||||
"""属性值类型"""
|
|
||||||
|
|
||||||
icon: Str100 | None = None
|
|
||||||
"""图标标识(iconify 名称)"""
|
|
||||||
|
|
||||||
options: list[str] | None = Field(default=None, sa_type=JSON)
|
|
||||||
"""可选值列表(仅 select/multi_select 类型)"""
|
|
||||||
|
|
||||||
default_value: str | None = Field(default=None, max_length=500)
|
|
||||||
"""默认值"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库模型 ====================
|
|
||||||
|
|
||||||
class CustomPropertyDefinition(CustomPropertyDefinitionBase, UUIDTableBaseMixin):
|
|
||||||
"""
|
|
||||||
用户自定义属性定义
|
|
||||||
|
|
||||||
每个用户独立管理自己的属性模板。
|
|
||||||
实际属性值存储在 ObjectMetadata 表中,键名格式:custom:{id}。
|
|
||||||
"""
|
|
||||||
|
|
||||||
owner_id: UUID = Field(
|
|
||||||
foreign_key="user.id",
|
|
||||||
ondelete="CASCADE",
|
|
||||||
index=True,
|
|
||||||
)
|
|
||||||
"""所有者用户UUID"""
|
|
||||||
|
|
||||||
sort_order: int = 0
|
|
||||||
"""排序顺序"""
|
|
||||||
|
|
||||||
# 关系
|
|
||||||
owner: "User" = Relationship()
|
|
||||||
"""所有者"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== DTO 模型 ====================
|
|
||||||
|
|
||||||
class CustomPropertyCreateRequest(SQLModelBase):
|
|
||||||
"""创建自定义属性请求 DTO"""
|
|
||||||
|
|
||||||
name: Str100
|
|
||||||
"""属性显示名称"""
|
|
||||||
|
|
||||||
type: CustomPropertyType
|
|
||||||
"""属性值类型"""
|
|
||||||
|
|
||||||
icon: str | None = None
|
|
||||||
"""图标标识"""
|
|
||||||
|
|
||||||
options: list[str] | None = None
|
|
||||||
"""可选值列表(仅 select/multi_select 类型)"""
|
|
||||||
|
|
||||||
default_value: str | None = None
|
|
||||||
"""默认值"""
|
|
||||||
|
|
||||||
|
|
||||||
class CustomPropertyUpdateRequest(SQLModelBase):
|
|
||||||
"""更新自定义属性请求 DTO"""
|
|
||||||
|
|
||||||
name: str | None = None
|
|
||||||
"""属性显示名称"""
|
|
||||||
|
|
||||||
icon: str | None = None
|
|
||||||
"""图标标识"""
|
|
||||||
|
|
||||||
options: list[str] | None = None
|
|
||||||
"""可选值列表"""
|
|
||||||
|
|
||||||
default_value: str | None = None
|
|
||||||
"""默认值"""
|
|
||||||
|
|
||||||
sort_order: int | None = None
|
|
||||||
"""排序顺序"""
|
|
||||||
|
|
||||||
|
|
||||||
class CustomPropertyResponse(CustomPropertyDefinitionBase):
|
|
||||||
"""自定义属性响应 DTO"""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
"""属性定义UUID"""
|
|
||||||
|
|
||||||
sort_order: int
|
|
||||||
"""排序顺序"""
|
|
||||||
@@ -4,7 +4,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlmodel import Field, Relationship, UniqueConstraint, Index
|
from sqlmodel import Field, Relationship, UniqueConstraint, Index
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableBaseMixin, Str255
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
@@ -141,7 +141,7 @@ class Download(DownloadBase, UUIDTableBaseMixin):
|
|||||||
speed: int = Field(default=0)
|
speed: int = Field(default=0)
|
||||||
"""下载速度(bytes/s)"""
|
"""下载速度(bytes/s)"""
|
||||||
|
|
||||||
parent: Str255 | None = None
|
parent: str | None = Field(default=None, max_length=255)
|
||||||
"""父任务标识"""
|
"""父任务标识"""
|
||||||
|
|
||||||
error: str | None = Field(default=None)
|
error: str | None = Field(default=None)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str100, Str255, Text1024
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .group import Group
|
from .group import Group
|
||||||
@@ -119,7 +119,7 @@ class UserFileAppDefaultResponse(SQLModelBase):
|
|||||||
class FileAppCreateRequest(SQLModelBase):
|
class FileAppCreateRequest(SQLModelBase):
|
||||||
"""管理员创建应用请求 DTO"""
|
"""管理员创建应用请求 DTO"""
|
||||||
|
|
||||||
name: Str100
|
name: str = Field(max_length=100)
|
||||||
"""应用名称"""
|
"""应用名称"""
|
||||||
|
|
||||||
app_key: str = Field(max_length=50)
|
app_key: str = Field(max_length=50)
|
||||||
@@ -128,7 +128,7 @@ class FileAppCreateRequest(SQLModelBase):
|
|||||||
type: FileAppType
|
type: FileAppType
|
||||||
"""应用类型"""
|
"""应用类型"""
|
||||||
|
|
||||||
icon: Str255 | None = None
|
icon: str | None = Field(default=None, max_length=255)
|
||||||
"""图标名称/URL"""
|
"""图标名称/URL"""
|
||||||
|
|
||||||
description: str | None = Field(default=None, max_length=500)
|
description: str | None = Field(default=None, max_length=500)
|
||||||
@@ -140,13 +140,13 @@ class FileAppCreateRequest(SQLModelBase):
|
|||||||
is_restricted: bool = False
|
is_restricted: bool = False
|
||||||
"""是否限制用户组访问"""
|
"""是否限制用户组访问"""
|
||||||
|
|
||||||
iframe_url_template: Text1024 | None = None
|
iframe_url_template: str | None = Field(default=None, max_length=1024)
|
||||||
"""iframe URL 模板"""
|
"""iframe URL 模板"""
|
||||||
|
|
||||||
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
||||||
"""WOPI 发现端点 URL"""
|
"""WOPI 发现端点 URL"""
|
||||||
|
|
||||||
wopi_editor_url_template: Text1024 | None = None
|
wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
|
||||||
"""WOPI 编辑器 URL 模板"""
|
"""WOPI 编辑器 URL 模板"""
|
||||||
|
|
||||||
extensions: list[str] = []
|
extensions: list[str] = []
|
||||||
@@ -159,7 +159,7 @@ class FileAppCreateRequest(SQLModelBase):
|
|||||||
class FileAppUpdateRequest(SQLModelBase):
|
class FileAppUpdateRequest(SQLModelBase):
|
||||||
"""管理员更新应用请求 DTO(所有字段可选)"""
|
"""管理员更新应用请求 DTO(所有字段可选)"""
|
||||||
|
|
||||||
name: Str100 | None = None
|
name: str | None = Field(default=None, max_length=100)
|
||||||
"""应用名称"""
|
"""应用名称"""
|
||||||
|
|
||||||
app_key: str | None = Field(default=None, max_length=50)
|
app_key: str | None = Field(default=None, max_length=50)
|
||||||
@@ -168,7 +168,7 @@ class FileAppUpdateRequest(SQLModelBase):
|
|||||||
type: FileAppType | None = None
|
type: FileAppType | None = None
|
||||||
"""应用类型"""
|
"""应用类型"""
|
||||||
|
|
||||||
icon: Str255 | None = None
|
icon: str | None = Field(default=None, max_length=255)
|
||||||
"""图标名称/URL"""
|
"""图标名称/URL"""
|
||||||
|
|
||||||
description: str | None = Field(default=None, max_length=500)
|
description: str | None = Field(default=None, max_length=500)
|
||||||
@@ -180,13 +180,13 @@ class FileAppUpdateRequest(SQLModelBase):
|
|||||||
is_restricted: bool | None = None
|
is_restricted: bool | None = None
|
||||||
"""是否限制用户组访问"""
|
"""是否限制用户组访问"""
|
||||||
|
|
||||||
iframe_url_template: Text1024 | None = None
|
iframe_url_template: str | None = Field(default=None, max_length=1024)
|
||||||
"""iframe URL 模板"""
|
"""iframe URL 模板"""
|
||||||
|
|
||||||
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
||||||
"""WOPI 发现端点 URL"""
|
"""WOPI 发现端点 URL"""
|
||||||
|
|
||||||
wopi_editor_url_template: Text1024 | None = None
|
wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
|
||||||
"""WOPI 编辑器 URL 模板"""
|
"""WOPI 编辑器 URL 模板"""
|
||||||
|
|
||||||
|
|
||||||
@@ -297,35 +297,12 @@ class WopiSessionResponse(SQLModelBase):
|
|||||||
"""完整的编辑器 URL"""
|
"""完整的编辑器 URL"""
|
||||||
|
|
||||||
|
|
||||||
class WopiDiscoveredExtension(SQLModelBase):
|
|
||||||
"""单个 WOPI Discovery 发现的扩展名"""
|
|
||||||
|
|
||||||
extension: str
|
|
||||||
"""文件扩展名"""
|
|
||||||
|
|
||||||
action_url: str
|
|
||||||
"""处理后的动作 URL 模板"""
|
|
||||||
|
|
||||||
|
|
||||||
class WopiDiscoveryResponse(SQLModelBase):
|
|
||||||
"""WOPI Discovery 结果响应 DTO"""
|
|
||||||
|
|
||||||
discovered_extensions: list[WopiDiscoveredExtension] = []
|
|
||||||
"""发现的扩展名及其 URL 模板"""
|
|
||||||
|
|
||||||
app_names: list[str] = []
|
|
||||||
"""WOPI 服务端报告的应用名称(如 Writer、Calc、Impress)"""
|
|
||||||
|
|
||||||
applied_count: int = 0
|
|
||||||
"""已应用到 FileAppExtension 的数量"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库模型 ====================
|
# ==================== 数据库模型 ====================
|
||||||
|
|
||||||
class FileApp(SQLModelBase, UUIDTableBaseMixin):
|
class FileApp(SQLModelBase, UUIDTableBaseMixin):
|
||||||
"""文件查看器应用注册表"""
|
"""文件查看器应用注册表"""
|
||||||
|
|
||||||
name: Str100
|
name: str = Field(max_length=100)
|
||||||
"""应用名称"""
|
"""应用名称"""
|
||||||
|
|
||||||
app_key: str = Field(max_length=50, unique=True, index=True)
|
app_key: str = Field(max_length=50, unique=True, index=True)
|
||||||
@@ -334,7 +311,7 @@ class FileApp(SQLModelBase, UUIDTableBaseMixin):
|
|||||||
type: FileAppType
|
type: FileAppType
|
||||||
"""应用类型"""
|
"""应用类型"""
|
||||||
|
|
||||||
icon: Str255 | None = None
|
icon: str | None = Field(default=None, max_length=255)
|
||||||
"""图标名称/URL"""
|
"""图标名称/URL"""
|
||||||
|
|
||||||
description: str | None = Field(default=None, max_length=500)
|
description: str | None = Field(default=None, max_length=500)
|
||||||
@@ -346,13 +323,13 @@ class FileApp(SQLModelBase, UUIDTableBaseMixin):
|
|||||||
is_restricted: bool = False
|
is_restricted: bool = False
|
||||||
"""是否限制用户组访问"""
|
"""是否限制用户组访问"""
|
||||||
|
|
||||||
iframe_url_template: Text1024 | None = None
|
iframe_url_template: str | None = Field(default=None, max_length=1024)
|
||||||
"""iframe URL 模板,支持 {file_url} 占位符"""
|
"""iframe URL 模板,支持 {file_url} 占位符"""
|
||||||
|
|
||||||
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
||||||
"""WOPI 客户端发现端点 URL"""
|
"""WOPI 客户端发现端点 URL"""
|
||||||
|
|
||||||
wopi_editor_url_template: Text1024 | None = None
|
wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
|
||||||
"""WOPI 编辑器 URL 模板,支持 {wopi_src} {access_token} {access_token_ttl}"""
|
"""WOPI 编辑器 URL 模板,支持 {wopi_src} {access_token} {access_token_ttl}"""
|
||||||
|
|
||||||
# 关系
|
# 关系
|
||||||
@@ -400,9 +377,6 @@ class FileAppExtension(SQLModelBase, TableBaseMixin):
|
|||||||
priority: int = Field(default=0, ge=0)
|
priority: int = Field(default=0, ge=0)
|
||||||
"""排序优先级(越小越优先)"""
|
"""排序优先级(越小越优先)"""
|
||||||
|
|
||||||
wopi_action_url: str | None = Field(default=None, max_length=2048)
|
|
||||||
"""WOPI 动作 URL 模板(Discovery 自动填充),支持 {wopi_src} {access_token} {access_token_ttl}"""
|
|
||||||
|
|
||||||
# 关系
|
# 关系
|
||||||
app: FileApp = Relationship(back_populates="extensions")
|
app: FileApp = Relationship(back_populates="extensions")
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,9 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import BigInteger
|
|
||||||
from sqlmodel import Field, Relationship, text
|
from sqlmodel import Field, Relationship, text
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str255
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
@@ -67,7 +66,7 @@ class GroupAllOptionsBase(GroupOptionsBase):
|
|||||||
class GroupCreateRequest(GroupAllOptionsBase):
|
class GroupCreateRequest(GroupAllOptionsBase):
|
||||||
"""创建用户组请求 DTO"""
|
"""创建用户组请求 DTO"""
|
||||||
|
|
||||||
name: Str255
|
name: str = Field(max_length=255)
|
||||||
"""用户组名称"""
|
"""用户组名称"""
|
||||||
|
|
||||||
max_storage: int = Field(default=0, ge=0)
|
max_storage: int = Field(default=0, ge=0)
|
||||||
@@ -92,7 +91,7 @@ class GroupCreateRequest(GroupAllOptionsBase):
|
|||||||
class GroupUpdateRequest(SQLModelBase):
|
class GroupUpdateRequest(SQLModelBase):
|
||||||
"""更新用户组请求 DTO(所有字段可选)"""
|
"""更新用户组请求 DTO(所有字段可选)"""
|
||||||
|
|
||||||
name: Str255 | None = None
|
name: str | None = Field(default=None, max_length=255)
|
||||||
"""用户组名称"""
|
"""用户组名称"""
|
||||||
|
|
||||||
max_storage: int | None = Field(default=None, ge=0)
|
max_storage: int | None = Field(default=None, ge=0)
|
||||||
@@ -258,10 +257,10 @@ class GroupOptions(GroupAllOptionsBase, TableBaseMixin):
|
|||||||
class Group(GroupBase, UUIDTableBaseMixin):
|
class Group(GroupBase, UUIDTableBaseMixin):
|
||||||
"""用户组模型"""
|
"""用户组模型"""
|
||||||
|
|
||||||
name: Str255 = Field(unique=True)
|
name: str = Field(max_length=255, unique=True)
|
||||||
"""用户组名"""
|
"""用户组名"""
|
||||||
|
|
||||||
max_storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
|
max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||||
"""最大存储空间(字节)"""
|
"""最大存储空间(字节)"""
|
||||||
|
|
||||||
share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||||
|
|||||||
@@ -130,11 +130,6 @@ default_settings: list[Setting] = [
|
|||||||
Setting(name="sms_provider", value="", type=SettingsType.MOBILE),
|
Setting(name="sms_provider", value="", type=SettingsType.MOBILE),
|
||||||
Setting(name="sms_access_key", value="", type=SettingsType.MOBILE),
|
Setting(name="sms_access_key", value="", type=SettingsType.MOBILE),
|
||||||
Setting(name="sms_secret_key", value="", type=SettingsType.MOBILE),
|
Setting(name="sms_secret_key", value="", type=SettingsType.MOBILE),
|
||||||
# ==================== 文件分类扩展名配置 ====================
|
|
||||||
Setting(name="image", value="jpg,jpeg,png,gif,bmp,webp,svg,ico,tiff,tif,avif,heic,heif,psd,raw", type=SettingsType.FILE_CATEGORY),
|
|
||||||
Setting(name="video", value="mp4,mkv,avi,mov,wmv,flv,webm,m4v,ts,3gp,mpg,mpeg", type=SettingsType.FILE_CATEGORY),
|
|
||||||
Setting(name="audio", value="mp3,wav,flac,aac,ogg,wma,m4a,opus,ape,aiff,mid,midi", type=SettingsType.FILE_CATEGORY),
|
|
||||||
Setting(name="document", value="pdf,doc,docx,odt,rtf,txt,tex,epub,pages,ppt,pptx,odp,key,xls,xlsx,csv,ods,numbers,tsv,md,markdown,mdx", type=SettingsType.FILE_CATEGORY),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
async def init_default_settings() -> None:
|
async def init_default_settings() -> None:
|
||||||
@@ -178,7 +173,7 @@ async def init_default_group() -> None:
|
|||||||
admin=True,
|
admin=True,
|
||||||
)
|
)
|
||||||
admin_group_id = admin_group.id # 在 save 前保存 UUID
|
admin_group_id = admin_group.id # 在 save 前保存 UUID
|
||||||
admin_group = await admin_group.save(session)
|
await admin_group.save(session)
|
||||||
|
|
||||||
await GroupOptions(
|
await GroupOptions(
|
||||||
group_id=admin_group_id,
|
group_id=admin_group_id,
|
||||||
@@ -208,7 +203,7 @@ async def init_default_group() -> None:
|
|||||||
web_dav_enabled=True,
|
web_dav_enabled=True,
|
||||||
)
|
)
|
||||||
member_group_id = member_group.id # 在 save 前保存 UUID
|
member_group_id = member_group.id # 在 save 前保存 UUID
|
||||||
member_group = await member_group.save(session)
|
await member_group.save(session)
|
||||||
|
|
||||||
await GroupOptions(
|
await GroupOptions(
|
||||||
group_id=member_group_id,
|
group_id=member_group_id,
|
||||||
@@ -227,7 +222,7 @@ async def init_default_group() -> None:
|
|||||||
default_group_setting = await Setting.get(session, Setting.name == "default_group")
|
default_group_setting = await Setting.get(session, Setting.name == "default_group")
|
||||||
if default_group_setting:
|
if default_group_setting:
|
||||||
default_group_setting.value = str(member_group_id)
|
default_group_setting.value = str(member_group_id)
|
||||||
default_group_setting = await default_group_setting.save(session)
|
await default_group_setting.save(session)
|
||||||
|
|
||||||
# 未找到初始游客组时,则创建
|
# 未找到初始游客组时,则创建
|
||||||
if not await Group.get(session, Group.name == "游客"):
|
if not await Group.get(session, Group.name == "游客"):
|
||||||
@@ -237,7 +232,7 @@ async def init_default_group() -> None:
|
|||||||
web_dav_enabled=False,
|
web_dav_enabled=False,
|
||||||
)
|
)
|
||||||
guest_group_id = guest_group.id # 在 save 前保存 UUID
|
guest_group_id = guest_group.id # 在 save 前保存 UUID
|
||||||
guest_group = await guest_group.save(session)
|
await guest_group.save(session)
|
||||||
|
|
||||||
await GroupOptions(
|
await GroupOptions(
|
||||||
group_id=guest_group_id,
|
group_id=guest_group_id,
|
||||||
@@ -289,7 +284,7 @@ async def init_default_user() -> None:
|
|||||||
group_id=admin_group.id,
|
group_id=admin_group.id,
|
||||||
)
|
)
|
||||||
admin_user_id = admin_user.id # 在 save 前保存 UUID
|
admin_user_id = admin_user.id # 在 save 前保存 UUID
|
||||||
admin_user = await admin_user.save(session)
|
await admin_user.save(session)
|
||||||
|
|
||||||
# 创建 AuthIdentity(邮箱密码身份)
|
# 创建 AuthIdentity(邮箱密码身份)
|
||||||
await AuthIdentity(
|
await AuthIdentity(
|
||||||
@@ -378,7 +373,7 @@ async def init_default_theme_presets() -> None:
|
|||||||
error=ChromaticColor.RED,
|
error=ChromaticColor.RED,
|
||||||
neutral=NeutralColor.ZINC,
|
neutral=NeutralColor.ZINC,
|
||||||
)
|
)
|
||||||
default_preset = await default_preset.save(session)
|
await default_preset.save(session)
|
||||||
log.info('已创建默认主题预设')
|
log.info('已创建默认主题预设')
|
||||||
|
|
||||||
|
|
||||||
@@ -451,43 +446,36 @@ _DEFAULT_FILE_APPS: list[dict] = [
|
|||||||
"is_enabled": True,
|
"is_enabled": True,
|
||||||
"extensions": ["mp3", "wav", "ogg", "flac", "aac", "m4a", "opus"],
|
"extensions": ["mp3", "wav", "ogg", "flac", "aac", "m4a", "opus"],
|
||||||
},
|
},
|
||||||
{
|
# iframe 应用(默认禁用)
|
||||||
"name": "EPUB 阅读器",
|
|
||||||
"app_key": "epub_reader",
|
|
||||||
"type": "builtin",
|
|
||||||
"icon": "book-open",
|
|
||||||
"description": "阅读 EPUB 电子书",
|
|
||||||
"is_enabled": True,
|
|
||||||
"extensions": ["epub"],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "3D 模型预览",
|
|
||||||
"app_key": "model_viewer",
|
|
||||||
"type": "builtin",
|
|
||||||
"icon": "cube",
|
|
||||||
"description": "预览 3D 模型",
|
|
||||||
"is_enabled": True,
|
|
||||||
"extensions": ["gltf", "glb", "stl", "obj", "fbx", "ply", "3mf"],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Font Viewer",
|
|
||||||
"app_key": "font_viewer",
|
|
||||||
"type": "builtin",
|
|
||||||
"icon": "type",
|
|
||||||
"description": "预览字体文件并显示元数据和文本样本",
|
|
||||||
"is_enabled": True,
|
|
||||||
"extensions": ["ttf", "otf", "woff", "woff2"],
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "Office 在线预览",
|
"name": "Office 在线预览",
|
||||||
"app_key": "office_viewer",
|
"app_key": "office_viewer",
|
||||||
"type": "iframe",
|
"type": "iframe",
|
||||||
"icon": "file-word",
|
"icon": "file-word",
|
||||||
"description": "使用 Microsoft Office Online 预览文档",
|
"description": "使用 Microsoft Office Online 预览文档",
|
||||||
"is_enabled": True,
|
"is_enabled": False,
|
||||||
"iframe_url_template": "https://view.officeapps.live.com/op/embed.aspx?src={file_url}",
|
"iframe_url_template": "https://view.officeapps.live.com/op/embed.aspx?src={file_url}",
|
||||||
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"],
|
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"],
|
||||||
},
|
},
|
||||||
|
# WOPI 应用(默认禁用)
|
||||||
|
{
|
||||||
|
"name": "Collabora Online",
|
||||||
|
"app_key": "collabora",
|
||||||
|
"type": "wopi",
|
||||||
|
"icon": "file-text",
|
||||||
|
"description": "Collabora Online 文档编辑器(需自行部署)",
|
||||||
|
"is_enabled": False,
|
||||||
|
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx", "odt", "ods", "odp"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "OnlyOffice",
|
||||||
|
"app_key": "onlyoffice",
|
||||||
|
"type": "wopi",
|
||||||
|
"icon": "file-text",
|
||||||
|
"description": "OnlyOffice 文档编辑器(需自行部署)",
|
||||||
|
"is_enabled": False,
|
||||||
|
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"],
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -505,7 +493,7 @@ async def init_default_file_apps() -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
for app_data in _DEFAULT_FILE_APPS:
|
for app_data in _DEFAULT_FILE_APPS:
|
||||||
extensions = app_data["extensions"]
|
extensions = app_data.pop("extensions")
|
||||||
|
|
||||||
app = FileApp(
|
app = FileApp(
|
||||||
name=app_data["name"],
|
name=app_data["name"],
|
||||||
@@ -527,6 +515,6 @@ async def init_default_file_apps() -> None:
|
|||||||
extension=ext.lower(),
|
extension=ext.lower(),
|
||||||
priority=i,
|
priority=i,
|
||||||
)
|
)
|
||||||
ext_record = await ext_record.save(session)
|
await ext_record.save(session)
|
||||||
|
|
||||||
log.info(f'已创建 {len(_DEFAULT_FILE_APPS)} 个默认文件查看器应用')
|
log.info(f'已创建 {len(_DEFAULT_FILE_APPS)} 个默认文件查看器应用')
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from sqlmodel import Field, Relationship, text, Index
|
from sqlmodel import Field, Relationship, text, Index
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .download import Download
|
from .download import Download
|
||||||
@@ -28,13 +28,13 @@ class NodeType(StrEnum):
|
|||||||
class Aria2ConfigurationBase(SQLModelBase):
|
class Aria2ConfigurationBase(SQLModelBase):
|
||||||
"""Aria2配置基础模型"""
|
"""Aria2配置基础模型"""
|
||||||
|
|
||||||
rpc_url: Str255 | None = None
|
rpc_url: str | None = Field(default=None, max_length=255)
|
||||||
"""RPC地址"""
|
"""RPC地址"""
|
||||||
|
|
||||||
rpc_secret: str | None = None
|
rpc_secret: str | None = None
|
||||||
"""RPC密钥"""
|
"""RPC密钥"""
|
||||||
|
|
||||||
temp_path: Str255 | None = None
|
temp_path: str | None = Field(default=None, max_length=255)
|
||||||
"""临时下载路径"""
|
"""临时下载路径"""
|
||||||
|
|
||||||
max_concurrent: int = Field(default=5, ge=1, le=50)
|
max_concurrent: int = Field(default=5, ge=1, le=50)
|
||||||
@@ -70,19 +70,19 @@ class Node(SQLModelBase, TableBaseMixin):
|
|||||||
status: NodeStatus = Field(default=NodeStatus.ONLINE)
|
status: NodeStatus = Field(default=NodeStatus.ONLINE)
|
||||||
"""节点状态"""
|
"""节点状态"""
|
||||||
|
|
||||||
name: Str255 = Field(unique=True)
|
name: str = Field(max_length=255, unique=True)
|
||||||
"""节点名称"""
|
"""节点名称"""
|
||||||
|
|
||||||
type: NodeType
|
type: NodeType
|
||||||
"""节点类型"""
|
"""节点类型"""
|
||||||
|
|
||||||
server: Str255
|
server: str = Field(max_length=255)
|
||||||
"""节点地址(IP或域名)"""
|
"""节点地址(IP或域名)"""
|
||||||
|
|
||||||
slave_key: Str255 | None = None
|
slave_key: str | None = Field(default=None, max_length=255)
|
||||||
"""从机通讯密钥"""
|
"""从机通讯密钥"""
|
||||||
|
|
||||||
master_key: Str255 | None = None
|
master_key: str | None = Field(default=None, max_length=255)
|
||||||
"""主机通讯密钥"""
|
"""主机通讯密钥"""
|
||||||
|
|
||||||
aria2_enabled: bool = False
|
aria2_enabled: bool = False
|
||||||
|
|||||||
@@ -7,9 +7,7 @@ from enum import StrEnum
|
|||||||
from sqlalchemy import BigInteger
|
from sqlalchemy import BigInteger
|
||||||
from sqlmodel import Field, Relationship, CheckConstraint, Index, text
|
from sqlmodel import Field, Relationship, CheckConstraint, Index, text
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255, Str256
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||||
|
|
||||||
from .policy import PolicyType
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
@@ -18,7 +16,6 @@ if TYPE_CHECKING:
|
|||||||
from .share import Share
|
from .share import Share
|
||||||
from .physical_file import PhysicalFile
|
from .physical_file import PhysicalFile
|
||||||
from .uri import DiskNextURI
|
from .uri import DiskNextURI
|
||||||
from .object_metadata import ObjectMetadata
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectType(StrEnum):
|
class ObjectType(StrEnum):
|
||||||
@@ -26,13 +23,42 @@ class ObjectType(StrEnum):
|
|||||||
FILE = "file"
|
FILE = "file"
|
||||||
FOLDER = "folder"
|
FOLDER = "folder"
|
||||||
|
|
||||||
|
class StorageType(StrEnum):
|
||||||
|
"""存储类型枚举"""
|
||||||
|
LOCAL = "local"
|
||||||
|
QINIU = "qiniu"
|
||||||
|
TENCENT = "tencent"
|
||||||
|
ALIYUN = "aliyun"
|
||||||
|
ONEDRIVE = "onedrive"
|
||||||
|
GOOGLE_DRIVE = "google_drive"
|
||||||
|
DROPBOX = "dropbox"
|
||||||
|
WEBDAV = "webdav"
|
||||||
|
REMOTE = "remote"
|
||||||
|
|
||||||
class FileCategory(StrEnum):
|
|
||||||
"""文件类型分类枚举,用于按类别筛选文件"""
|
class FileMetadataBase(SQLModelBase):
|
||||||
IMAGE = "image"
|
"""文件元数据基础模型"""
|
||||||
VIDEO = "video"
|
|
||||||
AUDIO = "audio"
|
width: int | None = Field(default=None)
|
||||||
DOCUMENT = "document"
|
"""图片宽度(像素)"""
|
||||||
|
|
||||||
|
height: int | None = Field(default=None)
|
||||||
|
"""图片高度(像素)"""
|
||||||
|
|
||||||
|
duration: float | None = Field(default=None)
|
||||||
|
"""音视频时长(秒)"""
|
||||||
|
|
||||||
|
bitrate: int | None = Field(default=None)
|
||||||
|
"""比特率(kbps)"""
|
||||||
|
|
||||||
|
mime_type: str | None = Field(default=None, max_length=127)
|
||||||
|
"""MIME类型"""
|
||||||
|
|
||||||
|
checksum_md5: str | None = Field(default=None, max_length=32)
|
||||||
|
"""MD5校验和"""
|
||||||
|
|
||||||
|
checksum_sha256: str | None = Field(default=None, max_length=64)
|
||||||
|
"""SHA256校验和"""
|
||||||
|
|
||||||
|
|
||||||
# ==================== Base 模型 ====================
|
# ==================== Base 模型 ====================
|
||||||
@@ -49,32 +75,9 @@ class ObjectBase(SQLModelBase):
|
|||||||
size: int | None = None
|
size: int | None = None
|
||||||
"""文件大小(字节),目录为 None"""
|
"""文件大小(字节),目录为 None"""
|
||||||
|
|
||||||
mime_type: str | None = Field(default=None, max_length=127)
|
|
||||||
"""MIME类型(仅文件有效)"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== DTO 模型 ====================
|
# ==================== DTO 模型 ====================
|
||||||
|
|
||||||
class ObjectFileFinalize(SQLModelBase):
|
|
||||||
"""文件上传完成后更新 Object 的 DTO"""
|
|
||||||
|
|
||||||
size: int
|
|
||||||
"""文件大小(字节)"""
|
|
||||||
|
|
||||||
physical_file_id: UUID
|
|
||||||
"""关联的物理文件UUID"""
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectMoveUpdate(SQLModelBase):
|
|
||||||
"""移动/重命名 Object 的 DTO"""
|
|
||||||
|
|
||||||
parent_id: UUID
|
|
||||||
"""新的父目录UUID"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
"""新名称"""
|
|
||||||
|
|
||||||
|
|
||||||
class DirectoryCreateRequest(SQLModelBase):
|
class DirectoryCreateRequest(SQLModelBase):
|
||||||
"""创建目录请求 DTO"""
|
"""创建目录请求 DTO"""
|
||||||
|
|
||||||
@@ -133,7 +136,7 @@ class PolicyResponse(SQLModelBase):
|
|||||||
name: str
|
name: str
|
||||||
"""策略名称"""
|
"""策略名称"""
|
||||||
|
|
||||||
type: PolicyType
|
type: StorageType
|
||||||
"""存储类型"""
|
"""存储类型"""
|
||||||
|
|
||||||
max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
|
max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
|
||||||
@@ -161,6 +164,22 @@ class DirectoryResponse(SQLModelBase):
|
|||||||
|
|
||||||
# ==================== 数据库模型 ====================
|
# ==================== 数据库模型 ====================
|
||||||
|
|
||||||
|
class FileMetadata(FileMetadataBase, UUIDTableBaseMixin):
|
||||||
|
"""文件元数据模型(与Object一对一关联)"""
|
||||||
|
|
||||||
|
object_id: UUID = Field(
|
||||||
|
foreign_key="object.id",
|
||||||
|
unique=True,
|
||||||
|
index=True,
|
||||||
|
ondelete="CASCADE"
|
||||||
|
)
|
||||||
|
"""关联的对象UUID"""
|
||||||
|
|
||||||
|
# 反向关系
|
||||||
|
object: "Object" = Relationship(back_populates="file_metadata")
|
||||||
|
"""关联的对象"""
|
||||||
|
|
||||||
|
|
||||||
class Object(ObjectBase, UUIDTableBaseMixin):
|
class Object(ObjectBase, UUIDTableBaseMixin):
|
||||||
"""
|
"""
|
||||||
统一对象模型
|
统一对象模型
|
||||||
@@ -198,13 +217,13 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
|
|
||||||
# ==================== 基础字段 ====================
|
# ==================== 基础字段 ====================
|
||||||
|
|
||||||
name: Str255
|
name: str = Field(max_length=255)
|
||||||
"""对象名称(文件名或目录名)"""
|
"""对象名称(文件名或目录名)"""
|
||||||
|
|
||||||
type: ObjectType
|
type: ObjectType
|
||||||
"""对象类型:file 或 folder"""
|
"""对象类型:file 或 folder"""
|
||||||
|
|
||||||
password: Str255 | None = None
|
password: str | None = Field(default=None, max_length=255)
|
||||||
"""对象独立密码(仅当用户为对象单独设置密码时有效)"""
|
"""对象独立密码(仅当用户为对象单独设置密码时有效)"""
|
||||||
|
|
||||||
# ==================== 文件专属字段 ====================
|
# ==================== 文件专属字段 ====================
|
||||||
@@ -212,7 +231,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
|
size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
|
||||||
"""文件大小(字节),目录为 0"""
|
"""文件大小(字节),目录为 0"""
|
||||||
|
|
||||||
upload_session_id: Str255 | None = Field(default=None, unique=True, index=True)
|
upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
|
||||||
"""分块上传会话ID(仅文件有效)"""
|
"""分块上传会话ID(仅文件有效)"""
|
||||||
|
|
||||||
physical_file_id: UUID | None = Field(
|
physical_file_id: UUID | None = Field(
|
||||||
@@ -315,11 +334,11 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
"""子对象(文件和子目录)"""
|
"""子对象(文件和子目录)"""
|
||||||
|
|
||||||
# 仅文件有效的关系
|
# 仅文件有效的关系
|
||||||
metadata_entries: list["ObjectMetadata"] = Relationship(
|
file_metadata: FileMetadata | None = Relationship(
|
||||||
back_populates="object",
|
back_populates="object",
|
||||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"},
|
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
|
||||||
)
|
)
|
||||||
"""元数据键值对列表"""
|
"""文件元数据(仅文件有效)"""
|
||||||
|
|
||||||
source_links: list["SourceLink"] = Relationship(
|
source_links: list["SourceLink"] = Relationship(
|
||||||
back_populates="object",
|
back_populates="object",
|
||||||
@@ -477,37 +496,6 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
fetch_mode="all"
|
fetch_mode="all"
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_by_category(
|
|
||||||
cls,
|
|
||||||
session: 'AsyncSession',
|
|
||||||
user_id: UUID,
|
|
||||||
extensions: list[str],
|
|
||||||
table_view: 'TableViewRequest | None' = None,
|
|
||||||
) -> 'ListResponse[Object]':
|
|
||||||
"""
|
|
||||||
按扩展名列表查询用户的所有文件(跨目录)
|
|
||||||
|
|
||||||
只查询未删除、未封禁的文件对象,使用 ILIKE 匹配文件名后缀。
|
|
||||||
|
|
||||||
:param session: 数据库会话
|
|
||||||
:param user_id: 用户UUID
|
|
||||||
:param extensions: 扩展名列表(不含点号)
|
|
||||||
:param table_view: 分页排序参数
|
|
||||||
:return: 分页文件列表
|
|
||||||
"""
|
|
||||||
from sqlalchemy import or_
|
|
||||||
|
|
||||||
ext_conditions = [cls.name.ilike(f"%.{ext}") for ext in extensions]
|
|
||||||
condition = (
|
|
||||||
(cls.owner_id == user_id) &
|
|
||||||
(cls.type == ObjectType.FILE) &
|
|
||||||
(cls.deleted_at == None) &
|
|
||||||
(cls.is_banned == False) &
|
|
||||||
or_(*ext_conditions)
|
|
||||||
)
|
|
||||||
return await cls.get_with_count(session, condition, table_view=table_view)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def resolve_uri(
|
async def resolve_uri(
|
||||||
cls,
|
cls,
|
||||||
@@ -585,7 +573,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
class UploadSessionBase(SQLModelBase):
|
class UploadSessionBase(SQLModelBase):
|
||||||
"""上传会话基础字段"""
|
"""上传会话基础字段"""
|
||||||
|
|
||||||
file_name: Str255
|
file_name: str = Field(max_length=255)
|
||||||
"""原始文件名"""
|
"""原始文件名"""
|
||||||
|
|
||||||
file_size: int = Field(ge=0, sa_type=BigInteger)
|
file_size: int = Field(ge=0, sa_type=BigInteger)
|
||||||
@@ -616,12 +604,6 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
|
|||||||
storage_path: str | None = Field(default=None, max_length=512)
|
storage_path: str | None = Field(default=None, max_length=512)
|
||||||
"""文件存储路径"""
|
"""文件存储路径"""
|
||||||
|
|
||||||
s3_upload_id: Str256 | None = None
|
|
||||||
"""S3 Multipart Upload ID(仅 S3 策略使用)"""
|
|
||||||
|
|
||||||
s3_part_etags: str | None = None
|
|
||||||
"""S3 已上传分片的 ETag 列表,JSON 格式 [[1,"etag1"],[2,"etag2"]](仅 S3 策略使用)"""
|
|
||||||
|
|
||||||
expires_at: datetime
|
expires_at: datetime
|
||||||
"""会话过期时间"""
|
"""会话过期时间"""
|
||||||
|
|
||||||
@@ -663,7 +645,7 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
|
|||||||
class CreateUploadSessionRequest(SQLModelBase):
|
class CreateUploadSessionRequest(SQLModelBase):
|
||||||
"""创建上传会话请求 DTO"""
|
"""创建上传会话请求 DTO"""
|
||||||
|
|
||||||
file_name: Str255
|
file_name: str = Field(max_length=255)
|
||||||
"""文件名"""
|
"""文件名"""
|
||||||
|
|
||||||
file_size: int = Field(ge=0)
|
file_size: int = Field(ge=0)
|
||||||
@@ -720,7 +702,7 @@ class UploadChunkResponse(SQLModelBase):
|
|||||||
class CreateFileRequest(SQLModelBase):
|
class CreateFileRequest(SQLModelBase):
|
||||||
"""创建空白文件请求 DTO"""
|
"""创建空白文件请求 DTO"""
|
||||||
|
|
||||||
name: Str255
|
name: str = Field(max_length=255)
|
||||||
"""文件名"""
|
"""文件名"""
|
||||||
|
|
||||||
parent_id: UUID
|
parent_id: UUID
|
||||||
@@ -730,16 +712,6 @@ class CreateFileRequest(SQLModelBase):
|
|||||||
"""存储策略UUID,不指定则使用父目录的策略"""
|
"""存储策略UUID,不指定则使用父目录的策略"""
|
||||||
|
|
||||||
|
|
||||||
class ObjectSwitchPolicyRequest(SQLModelBase):
|
|
||||||
"""切换对象存储策略请求"""
|
|
||||||
|
|
||||||
policy_id: UUID
|
|
||||||
"""目标存储策略UUID"""
|
|
||||||
|
|
||||||
is_migrate_existing: bool = False
|
|
||||||
"""(仅目录)是否迁移已有文件,默认 false 只影响新文件"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 对象操作相关 DTO ====================
|
# ==================== 对象操作相关 DTO ====================
|
||||||
|
|
||||||
class ObjectCopyRequest(SQLModelBase):
|
class ObjectCopyRequest(SQLModelBase):
|
||||||
@@ -758,7 +730,7 @@ class ObjectRenameRequest(SQLModelBase):
|
|||||||
id: UUID
|
id: UUID
|
||||||
"""对象UUID"""
|
"""对象UUID"""
|
||||||
|
|
||||||
new_name: Str255
|
new_name: str = Field(max_length=255)
|
||||||
"""新名称"""
|
"""新名称"""
|
||||||
|
|
||||||
|
|
||||||
@@ -777,9 +749,6 @@ class ObjectPropertyResponse(SQLModelBase):
|
|||||||
size: int
|
size: int
|
||||||
"""文件大小(字节)"""
|
"""文件大小(字节)"""
|
||||||
|
|
||||||
mime_type: str | None = None
|
|
||||||
"""MIME类型"""
|
|
||||||
|
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
"""创建时间"""
|
"""创建时间"""
|
||||||
|
|
||||||
@@ -793,13 +762,22 @@ class ObjectPropertyResponse(SQLModelBase):
|
|||||||
class ObjectPropertyDetailResponse(ObjectPropertyResponse):
|
class ObjectPropertyDetailResponse(ObjectPropertyResponse):
|
||||||
"""对象详细属性响应 DTO(继承基本属性)"""
|
"""对象详细属性响应 DTO(继承基本属性)"""
|
||||||
|
|
||||||
# 校验和(从 PhysicalFile 读取)
|
# 元数据信息
|
||||||
|
mime_type: str | None = None
|
||||||
|
"""MIME类型"""
|
||||||
|
|
||||||
|
width: int | None = None
|
||||||
|
"""图片宽度(像素)"""
|
||||||
|
|
||||||
|
height: int | None = None
|
||||||
|
"""图片高度(像素)"""
|
||||||
|
|
||||||
|
duration: float | None = None
|
||||||
|
"""音视频时长(秒)"""
|
||||||
|
|
||||||
checksum_md5: str | None = None
|
checksum_md5: str | None = None
|
||||||
"""MD5校验和"""
|
"""MD5校验和"""
|
||||||
|
|
||||||
checksum_sha256: str | None = None
|
|
||||||
"""SHA256校验和"""
|
|
||||||
|
|
||||||
# 分享统计
|
# 分享统计
|
||||||
share_count: int = 0
|
share_count: int = 0
|
||||||
"""分享次数"""
|
"""分享次数"""
|
||||||
@@ -817,10 +795,6 @@ class ObjectPropertyDetailResponse(ObjectPropertyResponse):
|
|||||||
reference_count: int = 1
|
reference_count: int = 1
|
||||||
"""物理文件引用计数(仅文件有效)"""
|
"""物理文件引用计数(仅文件有效)"""
|
||||||
|
|
||||||
# 元数据(KV 格式)
|
|
||||||
metadatas: dict[str, str] = {}
|
|
||||||
"""所有元数据条目(键名 → 值)"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 管理员文件管理 DTO ====================
|
# ==================== 管理员文件管理 DTO ====================
|
||||||
|
|
||||||
|
|||||||
@@ -1,127 +0,0 @@
|
|||||||
"""
|
|
||||||
对象元数据 KV 模型
|
|
||||||
|
|
||||||
以键值对形式存储文件的扩展元数据。键名使用命名空间前缀分类,
|
|
||||||
如 exif:width, stream:duration, music:artist 等。
|
|
||||||
|
|
||||||
架构:
|
|
||||||
ObjectMetadata (KV 表,与 Object 一对多关系)
|
|
||||||
└── 每个 Object 可以有多条元数据记录
|
|
||||||
└── (object_id, name) 组合唯一索引
|
|
||||||
|
|
||||||
命名空间:
|
|
||||||
- exif: 图片 EXIF 信息(尺寸、相机参数、拍摄时间等)
|
|
||||||
- stream: 音视频流信息(时长、比特率、视频尺寸、编解码等)
|
|
||||||
- music: 音乐标签(标题、艺术家、专辑等)
|
|
||||||
- geo: 地理位置(经纬度、地址)
|
|
||||||
- apk: Android 安装包信息
|
|
||||||
- custom: 用户自定义属性
|
|
||||||
- sys: 系统内部元数据
|
|
||||||
- thumb: 缩略图信息
|
|
||||||
"""
|
|
||||||
from enum import StrEnum
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlmodel import Field, UniqueConstraint, Index, Relationship
|
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .object import Object
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 枚举 ====================
|
|
||||||
|
|
||||||
class MetadataNamespace(StrEnum):
|
|
||||||
"""元数据命名空间枚举"""
|
|
||||||
EXIF = "exif"
|
|
||||||
"""图片 EXIF 信息(含尺寸、相机参数、拍摄时间等)"""
|
|
||||||
MUSIC = "music"
|
|
||||||
"""音乐标签(title/artist/album/genre 等)"""
|
|
||||||
STREAM = "stream"
|
|
||||||
"""音视频流信息(codec/duration/bitrate/resolution 等)"""
|
|
||||||
GEO = "geo"
|
|
||||||
"""地理位置(latitude/longitude/address)"""
|
|
||||||
APK = "apk"
|
|
||||||
"""Android 安装包信息(package_name/version 等)"""
|
|
||||||
THUMB = "thumb"
|
|
||||||
"""缩略图信息(内部使用)"""
|
|
||||||
SYS = "sys"
|
|
||||||
"""系统元数据(内部使用)"""
|
|
||||||
CUSTOM = "custom"
|
|
||||||
"""用户自定义属性"""
|
|
||||||
|
|
||||||
|
|
||||||
# 对外不可见的命名空间(API 不返回给普通用户)
|
|
||||||
INTERNAL_NAMESPACES: set[str] = {MetadataNamespace.SYS, MetadataNamespace.THUMB}
|
|
||||||
|
|
||||||
# 用户可写的命名空间
|
|
||||||
USER_WRITABLE_NAMESPACES: set[str] = {MetadataNamespace.CUSTOM}
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== Base 模型 ====================
|
|
||||||
|
|
||||||
class ObjectMetadataBase(SQLModelBase):
|
|
||||||
"""对象元数据 KV 基础模型"""
|
|
||||||
|
|
||||||
name: Str255
|
|
||||||
"""元数据键名,格式:namespace:key(如 exif:width, stream:duration)"""
|
|
||||||
|
|
||||||
value: str
|
|
||||||
"""元数据值(统一为字符串存储)"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库模型 ====================
|
|
||||||
|
|
||||||
class ObjectMetadata(ObjectMetadataBase, UUIDTableBaseMixin):
|
|
||||||
"""
|
|
||||||
对象元数据 KV 模型
|
|
||||||
|
|
||||||
以键值对形式存储文件的扩展元数据。键名使用命名空间前缀分类,
|
|
||||||
每个对象的每个键名唯一(通过唯一索引保证)。
|
|
||||||
"""
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
UniqueConstraint("object_id", "name", name="uq_object_metadata_object_name"),
|
|
||||||
Index("ix_object_metadata_object_id", "object_id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
object_id: UUID = Field(
|
|
||||||
foreign_key="object.id",
|
|
||||||
ondelete="CASCADE",
|
|
||||||
)
|
|
||||||
"""关联的对象UUID"""
|
|
||||||
|
|
||||||
is_public: bool = False
|
|
||||||
"""是否对分享页面公开"""
|
|
||||||
|
|
||||||
# 关系
|
|
||||||
object: "Object" = Relationship(back_populates="metadata_entries")
|
|
||||||
"""关联的对象"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== DTO 模型 ====================
|
|
||||||
|
|
||||||
class MetadataResponse(SQLModelBase):
|
|
||||||
"""元数据查询响应 DTO"""
|
|
||||||
|
|
||||||
metadatas: dict[str, str]
|
|
||||||
"""元数据字典(键名 → 值)"""
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataPatchItem(SQLModelBase):
|
|
||||||
"""单条元数据补丁 DTO"""
|
|
||||||
|
|
||||||
key: Str255
|
|
||||||
"""元数据键名"""
|
|
||||||
|
|
||||||
value: str | None = None
|
|
||||||
"""值,None 表示删除此条目"""
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataPatchRequest(SQLModelBase):
|
|
||||||
"""元数据批量更新请求 DTO"""
|
|
||||||
|
|
||||||
patches: list[MetadataPatchItem]
|
|
||||||
"""补丁列表"""
|
|
||||||
@@ -1,118 +1,54 @@
|
|||||||
from decimal import Decimal
|
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import Numeric
|
|
||||||
from sqlmodel import Field, Relationship
|
from sqlmodel import Field, Relationship
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .product import Product
|
|
||||||
from .user import User
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
class OrderStatus(StrEnum):
|
class OrderStatus(StrEnum):
|
||||||
"""订单状态枚举"""
|
"""订单状态枚举"""
|
||||||
|
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
"""待支付"""
|
"""待支付"""
|
||||||
|
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
"""已完成"""
|
"""已完成"""
|
||||||
|
|
||||||
CANCELLED = "cancelled"
|
CANCELLED = "cancelled"
|
||||||
"""已取消"""
|
"""已取消"""
|
||||||
|
|
||||||
|
|
||||||
class OrderType(StrEnum):
|
class OrderType(StrEnum):
|
||||||
"""订单类型枚举"""
|
"""订单类型枚举"""
|
||||||
|
# [TODO] 补充具体订单类型
|
||||||
|
pass
|
||||||
|
|
||||||
STORAGE_PACK = "storage_pack"
|
|
||||||
"""容量包"""
|
|
||||||
|
|
||||||
GROUP_TIME = "group_time"
|
|
||||||
"""用户组时长"""
|
|
||||||
|
|
||||||
SCORE = "score"
|
|
||||||
"""积分充值"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== DTO 模型 ====================
|
|
||||||
|
|
||||||
class CreateOrderRequest(SQLModelBase):
|
|
||||||
"""创建订单请求 DTO"""
|
|
||||||
|
|
||||||
product_id: UUID
|
|
||||||
"""商品UUID"""
|
|
||||||
|
|
||||||
num: int = Field(default=1, ge=1)
|
|
||||||
"""购买数量"""
|
|
||||||
|
|
||||||
method: str
|
|
||||||
"""支付方式"""
|
|
||||||
|
|
||||||
|
|
||||||
class OrderResponse(SQLModelBase):
|
|
||||||
"""订单响应 DTO"""
|
|
||||||
|
|
||||||
id: int
|
|
||||||
"""订单ID"""
|
|
||||||
|
|
||||||
order_no: str
|
|
||||||
"""订单号"""
|
|
||||||
|
|
||||||
type: OrderType
|
|
||||||
"""订单类型"""
|
|
||||||
|
|
||||||
method: str | None = None
|
|
||||||
"""支付方式"""
|
|
||||||
|
|
||||||
product_id: UUID | None = None
|
|
||||||
"""商品UUID"""
|
|
||||||
|
|
||||||
num: int
|
|
||||||
"""购买数量"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
"""商品名称"""
|
|
||||||
|
|
||||||
price: float
|
|
||||||
"""订单价格(元)"""
|
|
||||||
|
|
||||||
status: OrderStatus
|
|
||||||
"""订单状态"""
|
|
||||||
|
|
||||||
user_id: UUID
|
|
||||||
"""所属用户UUID"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库模型 ====================
|
|
||||||
|
|
||||||
class Order(SQLModelBase, TableBaseMixin):
|
class Order(SQLModelBase, TableBaseMixin):
|
||||||
"""订单模型"""
|
"""订单模型"""
|
||||||
|
|
||||||
order_no: Str255 = Field(unique=True, index=True)
|
order_no: str = Field(max_length=255, unique=True, index=True)
|
||||||
"""订单号,唯一"""
|
"""订单号,唯一"""
|
||||||
|
|
||||||
type: OrderType
|
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||||
"""订单类型"""
|
"""订单类型 [TODO] 待定义枚举"""
|
||||||
|
|
||||||
method: Str255 | None = None
|
method: str | None = Field(default=None, max_length=255)
|
||||||
"""支付方式"""
|
"""支付方式"""
|
||||||
|
|
||||||
product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
|
product_id: int | None = Field(default=None)
|
||||||
"""关联商品UUID"""
|
"""商品ID"""
|
||||||
|
|
||||||
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
|
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
|
||||||
"""购买数量"""
|
"""购买数量"""
|
||||||
|
|
||||||
name: Str255
|
name: str = Field(max_length=255)
|
||||||
"""商品名称"""
|
"""商品名称"""
|
||||||
|
|
||||||
price: Decimal = Field(sa_type=Numeric(12, 2), default=Decimal("0.00"))
|
price: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||||
"""订单价格(元)"""
|
"""订单价格(分)"""
|
||||||
|
|
||||||
status: OrderStatus = Field(default=OrderStatus.PENDING)
|
status: OrderStatus = Field(default=OrderStatus.PENDING)
|
||||||
"""订单状态"""
|
"""订单状态"""
|
||||||
@@ -127,19 +63,3 @@ class Order(SQLModelBase, TableBaseMixin):
|
|||||||
|
|
||||||
# 关系
|
# 关系
|
||||||
user: "User" = Relationship(back_populates="orders")
|
user: "User" = Relationship(back_populates="orders")
|
||||||
product: "Product" = Relationship(back_populates="orders")
|
|
||||||
|
|
||||||
def to_response(self) -> OrderResponse:
|
|
||||||
"""转换为响应 DTO"""
|
|
||||||
return OrderResponse(
|
|
||||||
id=self.id,
|
|
||||||
order_no=self.order_no,
|
|
||||||
type=self.type,
|
|
||||||
method=self.method,
|
|
||||||
product_id=self.product_id,
|
|
||||||
num=self.num,
|
|
||||||
name=self.name,
|
|
||||||
price=float(self.price),
|
|
||||||
status=self.status,
|
|
||||||
user_id=self.user_id,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from uuid import UUID
|
|||||||
from sqlalchemy import BigInteger
|
from sqlalchemy import BigInteger
|
||||||
from sqlmodel import Field, Relationship, Index
|
from sqlmodel import Field, Relationship, Index
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str32, Str64
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .object import Object
|
from .object import Object
|
||||||
@@ -31,12 +31,9 @@ class PhysicalFileBase(SQLModelBase):
|
|||||||
size: int = Field(default=0, sa_type=BigInteger)
|
size: int = Field(default=0, sa_type=BigInteger)
|
||||||
"""文件大小(字节)"""
|
"""文件大小(字节)"""
|
||||||
|
|
||||||
checksum_md5: Str32 | None = None
|
checksum_md5: str | None = Field(default=None, max_length=32)
|
||||||
"""MD5校验和(用于文件去重和完整性校验)"""
|
"""MD5校验和(用于文件去重和完整性校验)"""
|
||||||
|
|
||||||
checksum_sha256: Str64 | None = None
|
|
||||||
"""SHA256校验和"""
|
|
||||||
|
|
||||||
|
|
||||||
class PhysicalFile(PhysicalFileBase, UUIDTableBaseMixin):
|
class PhysicalFile(PhysicalFileBase, UUIDTableBaseMixin):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from uuid import UUID
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from sqlmodel import Field, Relationship, text
|
from sqlmodel import Field, Relationship, text
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str64, Str255
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .object import Object
|
from .object import Object
|
||||||
@@ -37,22 +37,22 @@ class PolicyType(StrEnum):
|
|||||||
class PolicyBase(SQLModelBase):
|
class PolicyBase(SQLModelBase):
|
||||||
"""存储策略基础字段,供 DTO 和数据库模型共享"""
|
"""存储策略基础字段,供 DTO 和数据库模型共享"""
|
||||||
|
|
||||||
name: Str255
|
name: str = Field(max_length=255)
|
||||||
"""策略名称"""
|
"""策略名称"""
|
||||||
|
|
||||||
type: PolicyType
|
type: PolicyType
|
||||||
"""存储策略类型"""
|
"""存储策略类型"""
|
||||||
|
|
||||||
server: Str255 | None = None
|
server: str | None = Field(default=None, max_length=255)
|
||||||
"""服务器地址(本地策略为绝对路径)"""
|
"""服务器地址(本地策略为绝对路径)"""
|
||||||
|
|
||||||
bucket_name: Str255 | None = None
|
bucket_name: str | None = Field(default=None, max_length=255)
|
||||||
"""存储桶名称"""
|
"""存储桶名称"""
|
||||||
|
|
||||||
is_private: bool = True
|
is_private: bool = True
|
||||||
"""是否为私有空间"""
|
"""是否为私有空间"""
|
||||||
|
|
||||||
base_url: Str255 | None = None
|
base_url: str | None = Field(default=None, max_length=255)
|
||||||
"""访问文件的基础URL"""
|
"""访问文件的基础URL"""
|
||||||
|
|
||||||
access_key: str | None = None
|
access_key: str | None = None
|
||||||
@@ -67,10 +67,10 @@ class PolicyBase(SQLModelBase):
|
|||||||
auto_rename: bool = False
|
auto_rename: bool = False
|
||||||
"""是否自动重命名"""
|
"""是否自动重命名"""
|
||||||
|
|
||||||
dir_name_rule: Str255 | None = None
|
dir_name_rule: str | None = Field(default=None, max_length=255)
|
||||||
"""目录命名规则"""
|
"""目录命名规则"""
|
||||||
|
|
||||||
file_name_rule: Str255 | None = None
|
file_name_rule: str | None = Field(default=None, max_length=255)
|
||||||
"""文件命名规则"""
|
"""文件命名规则"""
|
||||||
|
|
||||||
is_origin_link_enable: bool = False
|
is_origin_link_enable: bool = False
|
||||||
@@ -102,94 +102,6 @@ class PolicySummary(SQLModelBase):
|
|||||||
"""是否私有"""
|
"""是否私有"""
|
||||||
|
|
||||||
|
|
||||||
class PolicyCreateRequest(PolicyBase):
|
|
||||||
"""创建存储策略请求 DTO,包含 PolicyOptions 扁平字段"""
|
|
||||||
|
|
||||||
# PolicyOptions 字段(平铺到请求体中,与 GroupCreateRequest 模式一致)
|
|
||||||
token: str | None = None
|
|
||||||
"""访问令牌"""
|
|
||||||
|
|
||||||
file_type: str | None = None
|
|
||||||
"""允许的文件类型"""
|
|
||||||
|
|
||||||
mimetype: str | None = Field(default=None, max_length=127)
|
|
||||||
"""MIME类型"""
|
|
||||||
|
|
||||||
od_redirect: Str255 | None = None
|
|
||||||
"""OneDrive重定向地址"""
|
|
||||||
|
|
||||||
chunk_size: int = Field(default=52428800, ge=1)
|
|
||||||
"""分片上传大小(字节),默认50MB"""
|
|
||||||
|
|
||||||
s3_path_style: bool = False
|
|
||||||
"""是否使用S3路径风格"""
|
|
||||||
|
|
||||||
s3_region: Str64 = 'us-east-1'
|
|
||||||
"""S3 区域(如 us-east-1、ap-southeast-1),仅 S3 策略使用"""
|
|
||||||
|
|
||||||
|
|
||||||
class PolicyUpdateRequest(SQLModelBase):
|
|
||||||
"""更新存储策略请求 DTO(所有字段可选)"""
|
|
||||||
|
|
||||||
name: Str255 | None = None
|
|
||||||
"""策略名称"""
|
|
||||||
|
|
||||||
server: Str255 | None = None
|
|
||||||
"""服务器地址"""
|
|
||||||
|
|
||||||
bucket_name: Str255 | None = None
|
|
||||||
"""存储桶名称"""
|
|
||||||
|
|
||||||
is_private: bool | None = None
|
|
||||||
"""是否为私有空间"""
|
|
||||||
|
|
||||||
base_url: Str255 | None = None
|
|
||||||
"""访问文件的基础URL"""
|
|
||||||
|
|
||||||
access_key: str | None = None
|
|
||||||
"""Access Key"""
|
|
||||||
|
|
||||||
secret_key: str | None = None
|
|
||||||
"""Secret Key"""
|
|
||||||
|
|
||||||
max_size: int | None = Field(default=None, ge=0)
|
|
||||||
"""允许上传的最大文件尺寸(字节)"""
|
|
||||||
|
|
||||||
auto_rename: bool | None = None
|
|
||||||
"""是否自动重命名"""
|
|
||||||
|
|
||||||
dir_name_rule: Str255 | None = None
|
|
||||||
"""目录命名规则"""
|
|
||||||
|
|
||||||
file_name_rule: Str255 | None = None
|
|
||||||
"""文件命名规则"""
|
|
||||||
|
|
||||||
is_origin_link_enable: bool | None = None
|
|
||||||
"""是否开启源链接访问"""
|
|
||||||
|
|
||||||
# PolicyOptions 字段
|
|
||||||
token: str | None = None
|
|
||||||
"""访问令牌"""
|
|
||||||
|
|
||||||
file_type: str | None = None
|
|
||||||
"""允许的文件类型"""
|
|
||||||
|
|
||||||
mimetype: str | None = Field(default=None, max_length=127)
|
|
||||||
"""MIME类型"""
|
|
||||||
|
|
||||||
od_redirect: Str255 | None = None
|
|
||||||
"""OneDrive重定向地址"""
|
|
||||||
|
|
||||||
chunk_size: int | None = Field(default=None, ge=1)
|
|
||||||
"""分片上传大小(字节)"""
|
|
||||||
|
|
||||||
s3_path_style: bool | None = None
|
|
||||||
"""是否使用S3路径风格"""
|
|
||||||
|
|
||||||
s3_region: Str64 | None = None
|
|
||||||
"""S3 区域"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库模型 ====================
|
# ==================== 数据库模型 ====================
|
||||||
|
|
||||||
|
|
||||||
@@ -205,7 +117,7 @@ class PolicyOptionsBase(SQLModelBase):
|
|||||||
mimetype: str | None = Field(default=None, max_length=127)
|
mimetype: str | None = Field(default=None, max_length=127)
|
||||||
"""MIME类型"""
|
"""MIME类型"""
|
||||||
|
|
||||||
od_redirect: Str255 | None = None
|
od_redirect: str | None = Field(default=None, max_length=255)
|
||||||
"""OneDrive重定向地址"""
|
"""OneDrive重定向地址"""
|
||||||
|
|
||||||
chunk_size: int = Field(default=52428800, sa_column_kwargs={"server_default": "52428800"})
|
chunk_size: int = Field(default=52428800, sa_column_kwargs={"server_default": "52428800"})
|
||||||
@@ -214,9 +126,6 @@ class PolicyOptionsBase(SQLModelBase):
|
|||||||
s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||||
"""是否使用S3路径风格"""
|
"""是否使用S3路径风格"""
|
||||||
|
|
||||||
s3_region: Str64 = Field(default='us-east-1', sa_column_kwargs={"server_default": "'us-east-1'"})
|
|
||||||
"""S3 区域(如 us-east-1、ap-southeast-1),仅 S3 策略使用"""
|
|
||||||
|
|
||||||
|
|
||||||
class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin):
|
class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin):
|
||||||
"""存储策略选项模型(与Policy一对一关联)"""
|
"""存储策略选项模型(与Policy一对一关联)"""
|
||||||
@@ -237,7 +146,7 @@ class Policy(PolicyBase, UUIDTableBaseMixin):
|
|||||||
"""存储策略模型"""
|
"""存储策略模型"""
|
||||||
|
|
||||||
# 覆盖基类字段以添加数据库专有配置
|
# 覆盖基类字段以添加数据库专有配置
|
||||||
name: Str255 = Field(unique=True)
|
name: str = Field(max_length=255, unique=True)
|
||||||
"""策略名称"""
|
"""策略名称"""
|
||||||
|
|
||||||
is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
||||||
|
|||||||
@@ -1,206 +0,0 @@
|
|||||||
from decimal import Decimal
|
|
||||||
from enum import StrEnum
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlalchemy import Numeric, BigInteger
|
|
||||||
from sqlmodel import Field, Relationship, text
|
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .order import Order
|
|
||||||
from .redeem import Redeem
|
|
||||||
|
|
||||||
|
|
||||||
class ProductType(StrEnum):
|
|
||||||
"""商品类型枚举"""
|
|
||||||
|
|
||||||
STORAGE_PACK = "storage_pack"
|
|
||||||
"""容量包"""
|
|
||||||
|
|
||||||
GROUP_TIME = "group_time"
|
|
||||||
"""用户组时长"""
|
|
||||||
|
|
||||||
SCORE = "score"
|
|
||||||
"""积分充值"""
|
|
||||||
|
|
||||||
|
|
||||||
class PaymentMethod(StrEnum):
|
|
||||||
"""支付方式枚举"""
|
|
||||||
|
|
||||||
ALIPAY = "alipay"
|
|
||||||
"""支付宝"""
|
|
||||||
|
|
||||||
WECHAT = "wechat"
|
|
||||||
"""微信支付"""
|
|
||||||
|
|
||||||
STRIPE = "stripe"
|
|
||||||
"""Stripe"""
|
|
||||||
|
|
||||||
EASYPAY = "easypay"
|
|
||||||
"""易支付"""
|
|
||||||
|
|
||||||
CUSTOM = "custom"
|
|
||||||
"""自定义支付"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== DTO 模型 ====================
|
|
||||||
|
|
||||||
class ProductBase(SQLModelBase):
|
|
||||||
"""商品基础字段"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
"""商品名称"""
|
|
||||||
|
|
||||||
type: ProductType
|
|
||||||
"""商品类型"""
|
|
||||||
|
|
||||||
description: str | None = None
|
|
||||||
"""商品描述"""
|
|
||||||
|
|
||||||
|
|
||||||
class ProductCreateRequest(ProductBase):
|
|
||||||
"""创建商品请求 DTO"""
|
|
||||||
|
|
||||||
name: Str255
|
|
||||||
"""商品名称"""
|
|
||||||
|
|
||||||
price: Decimal = Field(ge=0, decimal_places=2)
|
|
||||||
"""商品价格(元)"""
|
|
||||||
|
|
||||||
is_active: bool = True
|
|
||||||
"""是否上架"""
|
|
||||||
|
|
||||||
sort_order: int = Field(default=0, ge=0)
|
|
||||||
"""排序权重(越大越靠前)"""
|
|
||||||
|
|
||||||
# storage_pack 专用
|
|
||||||
size: int | None = Field(default=None, ge=0)
|
|
||||||
"""容量大小(字节),type=storage_pack 时必填"""
|
|
||||||
|
|
||||||
duration_days: int | None = Field(default=None, ge=1)
|
|
||||||
"""有效天数,type=storage_pack/group_time 时必填"""
|
|
||||||
|
|
||||||
# group_time 专用
|
|
||||||
group_id: UUID | None = None
|
|
||||||
"""目标用户组UUID,type=group_time 时必填"""
|
|
||||||
|
|
||||||
# score 专用
|
|
||||||
score_amount: int | None = Field(default=None, ge=1)
|
|
||||||
"""积分数量,type=score 时必填"""
|
|
||||||
|
|
||||||
|
|
||||||
class ProductUpdateRequest(SQLModelBase):
|
|
||||||
"""更新商品请求 DTO(所有字段可选)"""
|
|
||||||
|
|
||||||
name: Str255 | None = None
|
|
||||||
"""商品名称"""
|
|
||||||
|
|
||||||
description: str | None = None
|
|
||||||
"""商品描述"""
|
|
||||||
|
|
||||||
price: Decimal | None = Field(default=None, ge=0, decimal_places=2)
|
|
||||||
"""商品价格(元)"""
|
|
||||||
|
|
||||||
is_active: bool | None = None
|
|
||||||
"""是否上架"""
|
|
||||||
|
|
||||||
sort_order: int | None = Field(default=None, ge=0)
|
|
||||||
"""排序权重"""
|
|
||||||
|
|
||||||
size: int | None = Field(default=None, ge=0)
|
|
||||||
"""容量大小(字节)"""
|
|
||||||
|
|
||||||
duration_days: int | None = Field(default=None, ge=1)
|
|
||||||
"""有效天数"""
|
|
||||||
|
|
||||||
group_id: UUID | None = None
|
|
||||||
"""目标用户组UUID"""
|
|
||||||
|
|
||||||
score_amount: int | None = Field(default=None, ge=1)
|
|
||||||
"""积分数量"""
|
|
||||||
|
|
||||||
|
|
||||||
class ProductResponse(ProductBase):
|
|
||||||
"""商品响应 DTO"""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
"""商品UUID"""
|
|
||||||
|
|
||||||
price: float
|
|
||||||
"""商品价格(元)"""
|
|
||||||
|
|
||||||
is_active: bool
|
|
||||||
"""是否上架"""
|
|
||||||
|
|
||||||
sort_order: int
|
|
||||||
"""排序权重"""
|
|
||||||
|
|
||||||
size: int | None = None
|
|
||||||
"""容量大小(字节)"""
|
|
||||||
|
|
||||||
duration_days: int | None = None
|
|
||||||
"""有效天数"""
|
|
||||||
|
|
||||||
group_id: UUID | None = None
|
|
||||||
"""目标用户组UUID"""
|
|
||||||
|
|
||||||
score_amount: int | None = None
|
|
||||||
"""积分数量"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库模型 ====================
|
|
||||||
|
|
||||||
class Product(ProductBase, UUIDTableBaseMixin):
|
|
||||||
"""商品模型"""
|
|
||||||
|
|
||||||
name: Str255
|
|
||||||
"""商品名称"""
|
|
||||||
|
|
||||||
price: Decimal = Field(sa_type=Numeric(12, 2), default=Decimal("0.00"))
|
|
||||||
"""商品价格(元)"""
|
|
||||||
|
|
||||||
is_active: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
|
||||||
"""是否上架"""
|
|
||||||
|
|
||||||
sort_order: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
|
||||||
"""排序权重(越大越靠前)"""
|
|
||||||
|
|
||||||
# storage_pack 专用
|
|
||||||
size: int | None = Field(default=None, sa_type=BigInteger)
|
|
||||||
"""容量大小(字节),type=storage_pack 时必填"""
|
|
||||||
|
|
||||||
duration_days: int | None = None
|
|
||||||
"""有效天数,type=storage_pack/group_time 时必填"""
|
|
||||||
|
|
||||||
# group_time 专用
|
|
||||||
group_id: UUID | None = Field(default=None, foreign_key="group.id", ondelete="SET NULL")
|
|
||||||
"""目标用户组UUID,type=group_time 时必填"""
|
|
||||||
|
|
||||||
# score 专用
|
|
||||||
score_amount: int | None = None
|
|
||||||
"""积分数量,type=score 时必填"""
|
|
||||||
|
|
||||||
# 关系
|
|
||||||
orders: list["Order"] = Relationship(back_populates="product")
|
|
||||||
"""关联的订单列表"""
|
|
||||||
|
|
||||||
redeems: list["Redeem"] = Relationship(back_populates="product")
|
|
||||||
"""关联的兑换码列表"""
|
|
||||||
|
|
||||||
def to_response(self) -> ProductResponse:
|
|
||||||
"""转换为响应 DTO"""
|
|
||||||
return ProductResponse(
|
|
||||||
id=self.id,
|
|
||||||
name=self.name,
|
|
||||||
type=self.type,
|
|
||||||
description=self.description,
|
|
||||||
price=float(self.price),
|
|
||||||
is_active=self.is_active,
|
|
||||||
sort_order=self.sort_order,
|
|
||||||
size=self.size,
|
|
||||||
duration_days=self.duration_days,
|
|
||||||
group_id=self.group_id,
|
|
||||||
score_amount=self.score_amount,
|
|
||||||
)
|
|
||||||
@@ -1,141 +1,22 @@
|
|||||||
from datetime import datetime
|
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlmodel import Field, Relationship, text
|
from sqlmodel import Field, text
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .product import Product
|
|
||||||
from .user import User
|
|
||||||
|
|
||||||
|
|
||||||
class RedeemType(StrEnum):
|
class RedeemType(StrEnum):
|
||||||
"""兑换码类型枚举"""
|
"""兑换码类型枚举"""
|
||||||
|
# [TODO] 补充具体兑换码类型
|
||||||
|
pass
|
||||||
|
|
||||||
STORAGE_PACK = "storage_pack"
|
|
||||||
"""容量包"""
|
|
||||||
|
|
||||||
GROUP_TIME = "group_time"
|
|
||||||
"""用户组时长"""
|
|
||||||
|
|
||||||
SCORE = "score"
|
|
||||||
"""积分充值"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== DTO 模型 ====================
|
|
||||||
|
|
||||||
class RedeemCreateRequest(SQLModelBase):
|
|
||||||
"""批量生成兑换码请求 DTO"""
|
|
||||||
|
|
||||||
product_id: UUID
|
|
||||||
"""关联商品UUID"""
|
|
||||||
|
|
||||||
count: int = Field(default=1, ge=1, le=100)
|
|
||||||
"""生成数量"""
|
|
||||||
|
|
||||||
|
|
||||||
class RedeemUseRequest(SQLModelBase):
|
|
||||||
"""使用兑换码请求 DTO"""
|
|
||||||
|
|
||||||
code: str
|
|
||||||
"""兑换码"""
|
|
||||||
|
|
||||||
|
|
||||||
class RedeemInfoResponse(SQLModelBase):
|
|
||||||
"""兑换码信息响应 DTO(用户侧)"""
|
|
||||||
|
|
||||||
type: RedeemType
|
|
||||||
"""兑换码类型"""
|
|
||||||
|
|
||||||
product_name: str | None = None
|
|
||||||
"""关联商品名称"""
|
|
||||||
|
|
||||||
num: int
|
|
||||||
"""可兑换数量"""
|
|
||||||
|
|
||||||
is_used: bool
|
|
||||||
"""是否已使用"""
|
|
||||||
|
|
||||||
|
|
||||||
class RedeemAdminResponse(SQLModelBase):
|
|
||||||
"""兑换码管理响应 DTO(管理侧)"""
|
|
||||||
|
|
||||||
id: int
|
|
||||||
"""兑换码ID"""
|
|
||||||
|
|
||||||
type: RedeemType
|
|
||||||
"""兑换码类型"""
|
|
||||||
|
|
||||||
product_id: UUID | None = None
|
|
||||||
"""关联商品UUID"""
|
|
||||||
|
|
||||||
num: int
|
|
||||||
"""可兑换数量"""
|
|
||||||
|
|
||||||
code: str
|
|
||||||
"""兑换码"""
|
|
||||||
|
|
||||||
is_used: bool
|
|
||||||
"""是否已使用"""
|
|
||||||
|
|
||||||
used_at: datetime | None = None
|
|
||||||
"""使用时间"""
|
|
||||||
|
|
||||||
used_by: UUID | None = None
|
|
||||||
"""使用者UUID"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库模型 ====================
|
|
||||||
|
|
||||||
class Redeem(SQLModelBase, TableBaseMixin):
|
class Redeem(SQLModelBase, TableBaseMixin):
|
||||||
"""兑换码模型"""
|
"""兑换码模型"""
|
||||||
|
|
||||||
type: RedeemType
|
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||||
"""兑换码类型"""
|
"""兑换码类型 [TODO] 待定义枚举"""
|
||||||
|
product_id: int | None = Field(default=None, description="关联的商品/权益ID")
|
||||||
product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
|
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="可兑换数量/时长等")
|
||||||
"""关联商品UUID"""
|
code: str = Field(unique=True, index=True, description="兑换码,唯一")
|
||||||
|
used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否已使用")
|
||||||
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
|
|
||||||
"""可兑换数量/时长等"""
|
|
||||||
|
|
||||||
code: str = Field(unique=True, index=True)
|
|
||||||
"""兑换码,唯一"""
|
|
||||||
|
|
||||||
is_used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
|
||||||
"""是否已使用"""
|
|
||||||
|
|
||||||
used_at: datetime | None = None
|
|
||||||
"""使用时间"""
|
|
||||||
|
|
||||||
used_by: UUID | None = Field(default=None, foreign_key="user.id", ondelete="SET NULL")
|
|
||||||
"""使用者UUID"""
|
|
||||||
|
|
||||||
# 关系
|
|
||||||
product: "Product" = Relationship(back_populates="redeems")
|
|
||||||
user: "User" = Relationship(back_populates="redeems")
|
|
||||||
|
|
||||||
def to_admin_response(self) -> RedeemAdminResponse:
|
|
||||||
"""转换为管理侧响应 DTO"""
|
|
||||||
return RedeemAdminResponse(
|
|
||||||
id=self.id,
|
|
||||||
type=self.type,
|
|
||||||
product_id=self.product_id,
|
|
||||||
num=self.num,
|
|
||||||
code=self.code,
|
|
||||||
is_used=self.is_used,
|
|
||||||
used_at=self.used_at,
|
|
||||||
used_by=self.used_by,
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_info_response(self, product_name: str | None = None) -> RedeemInfoResponse:
|
|
||||||
"""转换为用户侧响应 DTO"""
|
|
||||||
return RedeemInfoResponse(
|
|
||||||
type=self.type,
|
|
||||||
product_name=product_name,
|
|
||||||
num=self.num,
|
|
||||||
is_used=self.is_used,
|
|
||||||
)
|
|
||||||
@@ -4,7 +4,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlmodel import Field, Relationship
|
from sqlmodel import Field, Relationship
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .share import Share
|
from .share import Share
|
||||||
@@ -21,7 +21,7 @@ class Report(SQLModelBase, TableBaseMixin):
|
|||||||
|
|
||||||
reason: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
reason: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||||
"""举报原因 [TODO] 待定义枚举"""
|
"""举报原因 [TODO] 待定义枚举"""
|
||||||
description: Str255 | None = Field(default=None, description="补充描述")
|
description: str | None = Field(default=None, max_length=255, description="补充描述")
|
||||||
|
|
||||||
# 外键
|
# 外键
|
||||||
share_id: UUID = Field(
|
share_id: UUID = Field(
|
||||||
|
|||||||
@@ -76,9 +76,6 @@ class SiteConfigResponse(SQLModelBase):
|
|||||||
email_binding_required: bool = True
|
email_binding_required: bool = True
|
||||||
"""是否强制绑定邮箱"""
|
"""是否强制绑定邮箱"""
|
||||||
|
|
||||||
avatar_max_size: int = 2097152
|
|
||||||
"""头像文件最大字节数(默认 2MB)"""
|
|
||||||
|
|
||||||
footer_code: str | None = None
|
footer_code: str | None = None
|
||||||
"""自定义页脚代码"""
|
"""自定义页脚代码"""
|
||||||
|
|
||||||
@@ -163,7 +160,6 @@ class SettingsType(StrEnum):
|
|||||||
VERSION = "version"
|
VERSION = "version"
|
||||||
VIEW = "view"
|
VIEW = "view"
|
||||||
WOPI = "wopi"
|
WOPI = "wopi"
|
||||||
FILE_CATEGORY = "file_category"
|
|
||||||
|
|
||||||
# 数据库模型
|
# 数据库模型
|
||||||
class Setting(SettingItem, TableBaseMixin):
|
class Setting(SettingItem, TableBaseMixin):
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
|
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str64, Str255
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||||
|
|
||||||
from .model_base import ResponseBase
|
from .model_base import ResponseBase
|
||||||
from .object import ObjectType
|
from .object import ObjectType
|
||||||
@@ -52,10 +52,10 @@ class Share(SQLModelBase, UUIDTableBaseMixin):
|
|||||||
Index("ix_share_object", "object_id"),
|
Index("ix_share_object", "object_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
code: Str64 = Field(nullable=False, index=True)
|
code: str = Field(max_length=64, nullable=False, index=True)
|
||||||
"""分享码"""
|
"""分享码"""
|
||||||
|
|
||||||
password: Str255 | None = None
|
password: str | None = Field(default=None, max_length=255)
|
||||||
"""分享密码(加密后)"""
|
"""分享密码(加密后)"""
|
||||||
|
|
||||||
object_id: UUID = Field(
|
object_id: UUID = Field(
|
||||||
@@ -80,7 +80,7 @@ class Share(SQLModelBase, UUIDTableBaseMixin):
|
|||||||
preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
||||||
"""是否允许预览"""
|
"""是否允许预览"""
|
||||||
|
|
||||||
source_name: Str255 | None = None
|
source_name: str | None = Field(default=None, max_length=255)
|
||||||
"""源名称(冗余字段,便于展示)"""
|
"""源名称(冗余字段,便于展示)"""
|
||||||
|
|
||||||
score: int = Field(default=0, ge=0)
|
score: int = Field(default=0, ge=0)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlmodel import Field, Relationship, Index
|
from sqlmodel import Field, Relationship, Index
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .object import Object
|
from .object import Object
|
||||||
@@ -17,7 +17,7 @@ class SourceLink(SQLModelBase, TableBaseMixin):
|
|||||||
Index("ix_sourcelink_object_name", "object_id", "name"),
|
Index("ix_sourcelink_object_name", "object_id", "name"),
|
||||||
)
|
)
|
||||||
|
|
||||||
name: Str255
|
name: str = Field(max_length=255)
|
||||||
"""链接名称"""
|
"""链接名称"""
|
||||||
|
|
||||||
downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||||
|
|||||||
@@ -1,59 +1,22 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import BigInteger
|
from sqlmodel import Field, Relationship, Column, func, DateTime
|
||||||
from sqlmodel import Field, Relationship
|
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
# ==================== DTO 模型 ====================
|
|
||||||
|
|
||||||
class StoragePackResponse(SQLModelBase):
|
|
||||||
"""容量包响应 DTO"""
|
|
||||||
|
|
||||||
id: int
|
|
||||||
"""容量包ID"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
"""容量包名称"""
|
|
||||||
|
|
||||||
size: int
|
|
||||||
"""容量大小(字节)"""
|
|
||||||
|
|
||||||
active_time: datetime | None = None
|
|
||||||
"""激活时间"""
|
|
||||||
|
|
||||||
expired_time: datetime | None = None
|
|
||||||
"""过期时间"""
|
|
||||||
|
|
||||||
product_id: UUID | None = None
|
|
||||||
"""来源商品UUID"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库模型 ====================
|
|
||||||
|
|
||||||
class StoragePack(SQLModelBase, TableBaseMixin):
|
class StoragePack(SQLModelBase, TableBaseMixin):
|
||||||
"""容量包模型"""
|
"""容量包模型"""
|
||||||
|
|
||||||
name: Str255
|
name: str = Field(max_length=255, description="容量包名称")
|
||||||
"""容量包名称"""
|
active_time: datetime | None = Field(default=None, description="激活时间")
|
||||||
|
expired_time: datetime | None = Field(default=None, index=True, description="过期时间")
|
||||||
active_time: datetime | None = None
|
size: int = Field(description="容量包大小(字节)")
|
||||||
"""激活时间"""
|
|
||||||
|
|
||||||
expired_time: datetime | None = Field(default=None, index=True)
|
|
||||||
"""过期时间"""
|
|
||||||
|
|
||||||
size: int = Field(sa_type=BigInteger)
|
|
||||||
"""容量包大小(字节)"""
|
|
||||||
|
|
||||||
product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
|
|
||||||
"""来源商品UUID"""
|
|
||||||
|
|
||||||
# 外键
|
# 外键
|
||||||
user_id: UUID = Field(
|
user_id: UUID = Field(
|
||||||
@@ -65,14 +28,3 @@ class StoragePack(SQLModelBase, TableBaseMixin):
|
|||||||
|
|
||||||
# 关系
|
# 关系
|
||||||
user: "User" = Relationship(back_populates="storage_packs")
|
user: "User" = Relationship(back_populates="storage_packs")
|
||||||
|
|
||||||
def to_response(self) -> StoragePackResponse:
|
|
||||||
"""转换为响应 DTO"""
|
|
||||||
return StoragePackResponse(
|
|
||||||
id=self.id,
|
|
||||||
name=self.name,
|
|
||||||
size=self.size,
|
|
||||||
active_time=self.active_time,
|
|
||||||
expired_time=self.expired_time,
|
|
||||||
product_id=self.product_id,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from datetime import datetime
|
|||||||
|
|
||||||
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
|
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
@@ -24,13 +24,13 @@ class Tag(SQLModelBase, TableBaseMixin):
|
|||||||
|
|
||||||
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_tag_name_user"),)
|
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_tag_name_user"),)
|
||||||
|
|
||||||
name: Str255
|
name: str = Field(max_length=255)
|
||||||
"""标签名称"""
|
"""标签名称"""
|
||||||
|
|
||||||
icon: Str255 | None = None
|
icon: str | None = Field(default=None, max_length=255)
|
||||||
"""标签图标"""
|
"""标签图标"""
|
||||||
|
|
||||||
color: Str255 | None = None
|
color: str | None = Field(default=None, max_length=255)
|
||||||
"""标签颜色"""
|
"""标签颜色"""
|
||||||
|
|
||||||
type: TagType = Field(default=TagType.MANUAL)
|
type: TagType = Field(default=TagType.MANUAL)
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ class TaskStatus(StrEnum):
|
|||||||
|
|
||||||
class TaskType(StrEnum):
|
class TaskType(StrEnum):
|
||||||
"""任务类型枚举"""
|
"""任务类型枚举"""
|
||||||
POLICY_MIGRATE = "policy_migrate"
|
# [TODO] 补充具体任务类型
|
||||||
"""存储策略迁移"""
|
pass
|
||||||
|
|
||||||
|
|
||||||
# ==================== DTO 模型 ====================
|
# ==================== DTO 模型 ====================
|
||||||
@@ -39,7 +39,7 @@ class TaskSummaryBase(SQLModelBase):
|
|||||||
id: int
|
id: int
|
||||||
"""任务ID"""
|
"""任务ID"""
|
||||||
|
|
||||||
type: TaskType
|
type: int
|
||||||
"""任务类型"""
|
"""任务类型"""
|
||||||
|
|
||||||
status: TaskStatus
|
status: TaskStatus
|
||||||
@@ -91,14 +91,7 @@ class TaskPropsBase(SQLModelBase):
|
|||||||
file_ids: str | None = None
|
file_ids: str | None = None
|
||||||
"""文件ID列表(逗号分隔)"""
|
"""文件ID列表(逗号分隔)"""
|
||||||
|
|
||||||
source_policy_id: UUID | None = None
|
# [TODO] 根据业务需求补充更多字段
|
||||||
"""源存储策略UUID"""
|
|
||||||
|
|
||||||
dest_policy_id: UUID | None = None
|
|
||||||
"""目标存储策略UUID"""
|
|
||||||
|
|
||||||
object_id: UUID | None = None
|
|
||||||
"""关联的对象UUID"""
|
|
||||||
|
|
||||||
|
|
||||||
class TaskProps(TaskPropsBase, TableBaseMixin):
|
class TaskProps(TaskPropsBase, TableBaseMixin):
|
||||||
@@ -106,7 +99,7 @@ class TaskProps(TaskPropsBase, TableBaseMixin):
|
|||||||
|
|
||||||
task_id: int = Field(
|
task_id: int = Field(
|
||||||
foreign_key="task.id",
|
foreign_key="task.id",
|
||||||
unique=True,
|
primary_key=True,
|
||||||
ondelete="CASCADE"
|
ondelete="CASCADE"
|
||||||
)
|
)
|
||||||
"""关联的任务ID"""
|
"""关联的任务ID"""
|
||||||
@@ -128,8 +121,8 @@ class Task(SQLModelBase, TableBaseMixin):
|
|||||||
status: TaskStatus = Field(default=TaskStatus.QUEUED)
|
status: TaskStatus = Field(default=TaskStatus.QUEUED)
|
||||||
"""任务状态"""
|
"""任务状态"""
|
||||||
|
|
||||||
type: TaskType
|
type: int = Field(default=0)
|
||||||
"""任务类型"""
|
"""任务类型 [TODO] 待定义枚举"""
|
||||||
|
|
||||||
progress: int = Field(default=0, ge=0, le=100)
|
progress: int = Field(default=0, ge=0, le=100)
|
||||||
"""任务进度(0-100)"""
|
"""任务进度(0-100)"""
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlmodel import Field
|
from sqlmodel import Field
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||||
|
|
||||||
from .color import ChromaticColor, NeutralColor, ThemeColorsBase
|
from .color import ChromaticColor, NeutralColor, ThemeColorsBase
|
||||||
|
|
||||||
@@ -11,7 +11,7 @@ from .color import ChromaticColor, NeutralColor, ThemeColorsBase
|
|||||||
class ThemePresetBase(SQLModelBase):
|
class ThemePresetBase(SQLModelBase):
|
||||||
"""主题预设基础字段"""
|
"""主题预设基础字段"""
|
||||||
|
|
||||||
name: Str100
|
name: str = Field(max_length=100)
|
||||||
"""预设名称"""
|
"""预设名称"""
|
||||||
|
|
||||||
is_default: bool = False
|
is_default: bool = False
|
||||||
@@ -42,7 +42,7 @@ class ThemePresetBase(SQLModelBase):
|
|||||||
class ThemePreset(ThemePresetBase, UUIDTableBaseMixin):
|
class ThemePreset(ThemePresetBase, UUIDTableBaseMixin):
|
||||||
"""主题预设表"""
|
"""主题预设表"""
|
||||||
|
|
||||||
name: Str100 = Field(unique=True)
|
name: str = Field(max_length=100, unique=True)
|
||||||
"""预设名称(唯一约束)"""
|
"""预设名称(唯一约束)"""
|
||||||
|
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ class ThemePreset(ThemePresetBase, UUIDTableBaseMixin):
|
|||||||
class ThemePresetCreateRequest(SQLModelBase):
|
class ThemePresetCreateRequest(SQLModelBase):
|
||||||
"""创建主题预设请求 DTO"""
|
"""创建主题预设请求 DTO"""
|
||||||
|
|
||||||
name: Str100
|
name: str = Field(max_length=100)
|
||||||
"""预设名称"""
|
"""预设名称"""
|
||||||
|
|
||||||
colors: ThemeColorsBase
|
colors: ThemeColorsBase
|
||||||
@@ -61,7 +61,7 @@ class ThemePresetCreateRequest(SQLModelBase):
|
|||||||
class ThemePresetUpdateRequest(SQLModelBase):
|
class ThemePresetUpdateRequest(SQLModelBase):
|
||||||
"""更新主题预设请求 DTO"""
|
"""更新主题预设请求 DTO"""
|
||||||
|
|
||||||
name: Str100 | None = None
|
name: str | None = Field(default=None, max_length=100)
|
||||||
"""预设名称(可选)"""
|
"""预设名称(可选)"""
|
||||||
|
|
||||||
colors: ThemeColorsBase | None = None
|
colors: ThemeColorsBase | None = None
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ from typing import Literal, TYPE_CHECKING, TypeVar
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import BigInteger, BinaryExpression, ClauseElement, and_
|
from sqlalchemy import BinaryExpression, ClauseElement, and_
|
||||||
from sqlmodel import Field, Relationship
|
from sqlmodel import Field, Relationship
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlmodel.main import RelationshipInfo
|
from sqlmodel.main import RelationshipInfo
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableViewRequest, ListResponse, Str255
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableViewRequest, ListResponse
|
||||||
|
|
||||||
from .auth_identity import AuthProviderType
|
from .auth_identity import AuthProviderType
|
||||||
from .color import ChromaticColor, NeutralColor, ThemeColorsBase
|
from .color import ChromaticColor, NeutralColor, ThemeColorsBase
|
||||||
@@ -23,7 +23,6 @@ if TYPE_CHECKING:
|
|||||||
from .download import Download
|
from .download import Download
|
||||||
from .object import Object
|
from .object import Object
|
||||||
from .order import Order
|
from .order import Order
|
||||||
from .redeem import Redeem
|
|
||||||
from .share import Share
|
from .share import Share
|
||||||
from .storage_pack import StoragePack
|
from .storage_pack import StoragePack
|
||||||
from .tag import Tag
|
from .tag import Tag
|
||||||
@@ -474,10 +473,10 @@ class User(UserBase, UUIDTableBaseMixin):
|
|||||||
status: UserStatus = UserStatus.ACTIVE
|
status: UserStatus = UserStatus.ACTIVE
|
||||||
"""用户状态"""
|
"""用户状态"""
|
||||||
|
|
||||||
storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"}, ge=0)
|
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
|
||||||
"""已用存储空间(字节)"""
|
"""已用存储空间(字节)"""
|
||||||
|
|
||||||
avatar: Str255 = Field(default="default")
|
avatar: str = Field(default="default", max_length=255)
|
||||||
"""头像地址"""
|
"""头像地址"""
|
||||||
|
|
||||||
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
|
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
|
||||||
@@ -571,14 +570,6 @@ class User(UserBase, UUIDTableBaseMixin):
|
|||||||
back_populates="user",
|
back_populates="user",
|
||||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||||
)
|
)
|
||||||
redeems: list["Redeem"] = Relationship(
|
|
||||||
back_populates="user",
|
|
||||||
sa_relationship_kwargs={
|
|
||||||
"cascade": "all, delete-orphan",
|
|
||||||
"foreign_keys": "[Redeem.used_by]"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
"""用户使用过的兑换码列表"""
|
|
||||||
shares: list["Share"] = Relationship(
|
shares: list["Share"] = Relationship(
|
||||||
back_populates="user",
|
back_populates="user",
|
||||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from uuid import UUID
|
|||||||
from sqlalchemy import Column, Text
|
from sqlalchemy import Column, Text
|
||||||
from sqlmodel import Field, Relationship
|
from sqlmodel import Field, Relationship
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str32, Str100, Str255
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
@@ -51,7 +51,7 @@ class AuthnDetailResponse(SQLModelBase):
|
|||||||
class AuthnRenameRequest(SQLModelBase):
|
class AuthnRenameRequest(SQLModelBase):
|
||||||
"""WebAuthn 凭证重命名请求 DTO"""
|
"""WebAuthn 凭证重命名请求 DTO"""
|
||||||
|
|
||||||
name: Str100
|
name: str = Field(max_length=100)
|
||||||
"""新的凭证名称"""
|
"""新的凭证名称"""
|
||||||
|
|
||||||
|
|
||||||
@@ -60,7 +60,7 @@ class AuthnRenameRequest(SQLModelBase):
|
|||||||
class UserAuthn(SQLModelBase, TableBaseMixin):
|
class UserAuthn(SQLModelBase, TableBaseMixin):
|
||||||
"""用户 WebAuthn 凭证模型,与 User 为多对一关系"""
|
"""用户 WebAuthn 凭证模型,与 User 为多对一关系"""
|
||||||
|
|
||||||
credential_id: Str255 = Field(unique=True, index=True)
|
credential_id: str = Field(max_length=255, unique=True, index=True)
|
||||||
"""凭证 ID,Base64URL 编码"""
|
"""凭证 ID,Base64URL 编码"""
|
||||||
|
|
||||||
credential_public_key: str = Field(sa_column=Column(Text))
|
credential_public_key: str = Field(sa_column=Column(Text))
|
||||||
@@ -69,16 +69,16 @@ class UserAuthn(SQLModelBase, TableBaseMixin):
|
|||||||
sign_count: int = Field(default=0, ge=0)
|
sign_count: int = Field(default=0, ge=0)
|
||||||
"""签名计数器,用于防重放攻击"""
|
"""签名计数器,用于防重放攻击"""
|
||||||
|
|
||||||
credential_device_type: Str32
|
credential_device_type: str = Field(max_length=32)
|
||||||
"""凭证设备类型:'single_device' 或 'multi_device'"""
|
"""凭证设备类型:'single_device' 或 'multi_device'"""
|
||||||
|
|
||||||
credential_backed_up: bool = Field(default=False)
|
credential_backed_up: bool = Field(default=False)
|
||||||
"""凭证是否已备份"""
|
"""凭证是否已备份"""
|
||||||
|
|
||||||
transports: Str255 | None = None
|
transports: str | None = Field(default=None, max_length=255)
|
||||||
"""支持的传输方式,逗号分隔,如 'usb,nfc,ble,internal'"""
|
"""支持的传输方式,逗号分隔,如 'usb,nfc,ble,internal'"""
|
||||||
|
|
||||||
name: Str100 | None = None
|
name: str | None = Field(default=None, max_length=100)
|
||||||
"""用户自定义的凭证名称,便于识别"""
|
"""用户自定义的凭证名称,便于识别"""
|
||||||
|
|
||||||
# 外键
|
# 外键
|
||||||
|
|||||||
@@ -1,117 +1,32 @@
|
|||||||
"""
|
|
||||||
WebDAV 账户模型
|
|
||||||
|
|
||||||
管理用户的 WebDAV 连接账户,每个账户对应一个挂载根路径。
|
|
||||||
通过 HTTP Basic Auth 认证访问 DAV 协议端点。
|
|
||||||
"""
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||||
|
|
||||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
|
|
||||||
|
class WebDAV(SQLModelBase, TableBaseMixin):
|
||||||
# ==================== Base 模型 ====================
|
"""WebDAV账户模型"""
|
||||||
|
|
||||||
class WebDAVBase(SQLModelBase):
|
|
||||||
"""WebDAV 账户基础字段"""
|
|
||||||
|
|
||||||
name: Str255
|
|
||||||
"""账户名称(同一用户下唯一)"""
|
|
||||||
|
|
||||||
root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"})
|
|
||||||
"""挂载根目录路径"""
|
|
||||||
|
|
||||||
readonly: bool = Field(default=False, sa_column_kwargs={"server_default": "false"})
|
|
||||||
"""是否只读"""
|
|
||||||
|
|
||||||
use_proxy: bool = Field(default=False, sa_column_kwargs={"server_default": "false"})
|
|
||||||
"""是否使用代理下载"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库模型 ====================
|
|
||||||
|
|
||||||
class WebDAV(WebDAVBase, TableBaseMixin):
|
|
||||||
"""WebDAV 账户模型"""
|
|
||||||
|
|
||||||
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),)
|
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),)
|
||||||
|
|
||||||
password: Str255
|
name: str = Field(max_length=255, description="WebDAV账户名")
|
||||||
"""密码(Argon2 哈希)"""
|
password: str = Field(max_length=255, description="WebDAV密码")
|
||||||
|
root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"}, description="根目录路径")
|
||||||
|
readonly: bool = Field(default=False, description="是否只读")
|
||||||
|
use_proxy: bool = Field(default=False, description="是否使用代理下载")
|
||||||
|
|
||||||
# 外键
|
# 外键
|
||||||
user_id: UUID = Field(
|
user_id: UUID = Field(
|
||||||
foreign_key="user.id",
|
foreign_key="user.id",
|
||||||
index=True,
|
index=True,
|
||||||
ondelete="CASCADE",
|
ondelete="CASCADE"
|
||||||
)
|
)
|
||||||
"""所属用户UUID"""
|
"""所属用户UUID"""
|
||||||
|
|
||||||
# 关系
|
# 关系
|
||||||
user: "User" = Relationship(back_populates="webdavs")
|
user: "User" = Relationship(back_populates="webdavs")
|
||||||
|
|
||||||
|
|
||||||
# ==================== DTO 模型 ====================
|
|
||||||
|
|
||||||
class WebDAVCreateRequest(SQLModelBase):
|
|
||||||
"""创建 WebDAV 账户请求"""
|
|
||||||
|
|
||||||
name: Str255
|
|
||||||
"""账户名称"""
|
|
||||||
|
|
||||||
password: Str255 = Field(min_length=1)
|
|
||||||
"""账户密码(明文,服务端哈希后存储)"""
|
|
||||||
|
|
||||||
root: str = "/"
|
|
||||||
"""挂载根目录路径"""
|
|
||||||
|
|
||||||
readonly: bool = False
|
|
||||||
"""是否只读"""
|
|
||||||
|
|
||||||
use_proxy: bool = False
|
|
||||||
"""是否使用代理下载"""
|
|
||||||
|
|
||||||
|
|
||||||
class WebDAVUpdateRequest(SQLModelBase):
|
|
||||||
"""更新 WebDAV 账户请求"""
|
|
||||||
|
|
||||||
password: Str255 | None = Field(default=None, min_length=1)
|
|
||||||
"""新密码(为 None 时不修改)"""
|
|
||||||
|
|
||||||
root: str | None = None
|
|
||||||
"""新挂载根目录路径(为 None 时不修改)"""
|
|
||||||
|
|
||||||
readonly: bool | None = None
|
|
||||||
"""是否只读(为 None 时不修改)"""
|
|
||||||
|
|
||||||
use_proxy: bool | None = None
|
|
||||||
"""是否使用代理下载(为 None 时不修改)"""
|
|
||||||
|
|
||||||
|
|
||||||
class WebDAVAccountResponse(SQLModelBase):
|
|
||||||
"""WebDAV 账户响应"""
|
|
||||||
|
|
||||||
id: int
|
|
||||||
"""账户ID"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
"""账户名称"""
|
|
||||||
|
|
||||||
root: str
|
|
||||||
"""挂载根目录路径"""
|
|
||||||
|
|
||||||
readonly: bool
|
|
||||||
"""是否只读"""
|
|
||||||
|
|
||||||
use_proxy: bool
|
|
||||||
"""是否使用代理下载"""
|
|
||||||
|
|
||||||
created_at: str
|
|
||||||
"""创建时间"""
|
|
||||||
|
|
||||||
updated_at: str
|
|
||||||
"""更新时间"""
|
|
||||||
|
|||||||
2
tests/fixtures/objects.py
vendored
2
tests/fixtures/objects.py
vendored
@@ -92,9 +92,9 @@ class ObjectFactory:
|
|||||||
owner_id=owner_id,
|
owner_id=owner_id,
|
||||||
policy_id=policy_id,
|
policy_id=policy_id,
|
||||||
size=size,
|
size=size,
|
||||||
mime_type=kwargs.get("mime_type"),
|
|
||||||
source_name=kwargs.get("source_name", name),
|
source_name=kwargs.get("source_name", name),
|
||||||
upload_session_id=kwargs.get("upload_session_id"),
|
upload_session_id=kwargs.get("upload_session_id"),
|
||||||
|
file_metadata=kwargs.get("file_metadata"),
|
||||||
password=kwargs.get("password"),
|
password=kwargs.get("password"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
8
tests/fixtures/users.py
vendored
8
tests/fixtures/users.py
vendored
@@ -71,7 +71,7 @@ class UserFactory:
|
|||||||
is_verified=True,
|
is_verified=True,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
)
|
)
|
||||||
identity = await identity.save(session)
|
await identity.save(session)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@@ -123,7 +123,7 @@ class UserFactory:
|
|||||||
is_verified=True,
|
is_verified=True,
|
||||||
user_id=admin.id,
|
user_id=admin.id,
|
||||||
)
|
)
|
||||||
identity = await identity.save(session)
|
await identity.save(session)
|
||||||
|
|
||||||
return admin
|
return admin
|
||||||
|
|
||||||
@@ -170,7 +170,7 @@ class UserFactory:
|
|||||||
is_verified=True,
|
is_verified=True,
|
||||||
user_id=banned_user.id,
|
user_id=banned_user.id,
|
||||||
)
|
)
|
||||||
identity = await identity.save(session)
|
await identity.save(session)
|
||||||
|
|
||||||
return banned_user
|
return banned_user
|
||||||
|
|
||||||
@@ -219,6 +219,6 @@ class UserFactory:
|
|||||||
is_verified=True,
|
is_verified=True,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
)
|
)
|
||||||
identity = await identity.save(session)
|
await identity.save(session)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|||||||
@@ -1,219 +0,0 @@
|
|||||||
"""
|
|
||||||
自定义属性定义端点集成测试
|
|
||||||
"""
|
|
||||||
import pytest
|
|
||||||
from httpx import AsyncClient
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 获取属性定义列表测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_custom_properties_requires_auth(async_client: AsyncClient):
|
|
||||||
"""测试获取属性定义需要认证"""
|
|
||||||
response = await async_client.get("/api/v1/object/custom_property")
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_custom_properties_empty(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试获取空的属性定义列表"""
|
|
||||||
response = await async_client.get(
|
|
||||||
"/api/v1/object/custom_property",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data == []
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 创建属性定义测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_custom_property(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试创建自定义属性"""
|
|
||||||
response = await async_client.post(
|
|
||||||
"/api/v1/object/custom_property",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={
|
|
||||||
"name": "评分",
|
|
||||||
"type": "rating",
|
|
||||||
"icon": "mdi:star",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert response.status_code == 204
|
|
||||||
|
|
||||||
# 验证已创建
|
|
||||||
list_response = await async_client.get(
|
|
||||||
"/api/v1/object/custom_property",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
data = list_response.json()
|
|
||||||
assert len(data) == 1
|
|
||||||
assert data[0]["name"] == "评分"
|
|
||||||
assert data[0]["type"] == "rating"
|
|
||||||
assert data[0]["icon"] == "mdi:star"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_custom_property_with_options(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试创建带选项的自定义属性"""
|
|
||||||
response = await async_client.post(
|
|
||||||
"/api/v1/object/custom_property",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={
|
|
||||||
"name": "分类",
|
|
||||||
"type": "select",
|
|
||||||
"options": ["工作", "个人", "归档"],
|
|
||||||
"default_value": "个人",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert response.status_code == 204
|
|
||||||
|
|
||||||
list_response = await async_client.get(
|
|
||||||
"/api/v1/object/custom_property",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
data = list_response.json()
|
|
||||||
prop = next(p for p in data if p["name"] == "分类")
|
|
||||||
assert prop["type"] == "select"
|
|
||||||
assert prop["options"] == ["工作", "个人", "归档"]
|
|
||||||
assert prop["default_value"] == "个人"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_custom_property_duplicate_name(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试创建同名属性返回 409"""
|
|
||||||
# 先创建
|
|
||||||
await async_client.post(
|
|
||||||
"/api/v1/object/custom_property",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "标签", "type": "text"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 再创建同名
|
|
||||||
response = await async_client.post(
|
|
||||||
"/api/v1/object/custom_property",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "标签", "type": "text"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 409
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 更新属性定义测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_custom_property(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试更新自定义属性"""
|
|
||||||
# 先创建
|
|
||||||
await async_client.post(
|
|
||||||
"/api/v1/object/custom_property",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "备注", "type": "text"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取 ID
|
|
||||||
list_response = await async_client.get(
|
|
||||||
"/api/v1/object/custom_property",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
prop_id = next(p["id"] for p in list_response.json() if p["name"] == "备注")
|
|
||||||
|
|
||||||
# 更新
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"/api/v1/object/custom_property/{prop_id}",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "详细备注", "icon": "mdi:note"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 204
|
|
||||||
|
|
||||||
# 验证已更新
|
|
||||||
list_response = await async_client.get(
|
|
||||||
"/api/v1/object/custom_property",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
prop = next(p for p in list_response.json() if p["id"] == prop_id)
|
|
||||||
assert prop["name"] == "详细备注"
|
|
||||||
assert prop["icon"] == "mdi:note"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_custom_property_not_found(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试更新不存在的属性返回 404"""
|
|
||||||
fake_id = str(uuid4())
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"/api/v1/object/custom_property/{fake_id}",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "不存在"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 删除属性定义测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_custom_property(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试删除自定义属性"""
|
|
||||||
# 先创建
|
|
||||||
await async_client.post(
|
|
||||||
"/api/v1/object/custom_property",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "待删除", "type": "text"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取 ID
|
|
||||||
list_response = await async_client.get(
|
|
||||||
"/api/v1/object/custom_property",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
prop_id = next(p["id"] for p in list_response.json() if p["name"] == "待删除")
|
|
||||||
|
|
||||||
# 删除
|
|
||||||
response = await async_client.delete(
|
|
||||||
f"/api/v1/object/custom_property/{prop_id}",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 204
|
|
||||||
|
|
||||||
# 验证已删除
|
|
||||||
list_response = await async_client.get(
|
|
||||||
"/api/v1/object/custom_property",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
prop_names = [p["name"] for p in list_response.json()]
|
|
||||||
assert "待删除" not in prop_names
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_custom_property_not_found(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试删除不存在的属性返回 404"""
|
|
||||||
fake_id = str(uuid4())
|
|
||||||
response = await async_client.delete(
|
|
||||||
f"/api/v1/object/custom_property/{fake_id}",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 404
|
|
||||||
@@ -1,239 +0,0 @@
|
|||||||
"""
|
|
||||||
对象元数据端点集成测试
|
|
||||||
"""
|
|
||||||
import pytest
|
|
||||||
from httpx import AsyncClient
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
from sqlmodels import ObjectMetadata
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 获取元数据测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_metadata_requires_auth(async_client: AsyncClient):
|
|
||||||
"""测试获取元数据需要认证"""
|
|
||||||
fake_id = str(uuid4())
|
|
||||||
response = await async_client.get(f"/api/v1/object/{fake_id}/metadata")
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_metadata_empty(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
test_directory_structure: dict[str, UUID],
|
|
||||||
):
|
|
||||||
"""测试获取无元数据的对象"""
|
|
||||||
file_id = test_directory_structure["file_id"]
|
|
||||||
response = await async_client.get(
|
|
||||||
f"/api/v1/object/{file_id}/metadata",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["metadatas"] == {}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_metadata_with_entries(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
test_directory_structure: dict[str, UUID],
|
|
||||||
initialized_db,
|
|
||||||
):
|
|
||||||
"""测试获取有元数据的对象"""
|
|
||||||
file_id = test_directory_structure["file_id"]
|
|
||||||
|
|
||||||
# 直接写入元数据
|
|
||||||
entries = [
|
|
||||||
ObjectMetadata(object_id=file_id, name="exif:width", value="1920", is_public=True),
|
|
||||||
ObjectMetadata(object_id=file_id, name="exif:height", value="1080", is_public=True),
|
|
||||||
ObjectMetadata(object_id=file_id, name="sys:extract_status", value="done", is_public=False),
|
|
||||||
]
|
|
||||||
for entry in entries:
|
|
||||||
initialized_db.add(entry)
|
|
||||||
await initialized_db.commit()
|
|
||||||
|
|
||||||
response = await async_client.get(
|
|
||||||
f"/api/v1/object/{file_id}/metadata",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
# sys: 命名空间应被过滤
|
|
||||||
assert "exif:width" in data["metadatas"]
|
|
||||||
assert "exif:height" in data["metadatas"]
|
|
||||||
assert "sys:extract_status" not in data["metadatas"]
|
|
||||||
assert data["metadatas"]["exif:width"] == "1920"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_metadata_ns_filter(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
test_directory_structure: dict[str, UUID],
|
|
||||||
initialized_db,
|
|
||||||
):
|
|
||||||
"""测试按命名空间过滤元数据"""
|
|
||||||
file_id = test_directory_structure["file_id"]
|
|
||||||
|
|
||||||
entries = [
|
|
||||||
ObjectMetadata(object_id=file_id, name="exif:width", value="1920", is_public=True),
|
|
||||||
ObjectMetadata(object_id=file_id, name="music:title", value="Test Song", is_public=True),
|
|
||||||
]
|
|
||||||
for entry in entries:
|
|
||||||
initialized_db.add(entry)
|
|
||||||
await initialized_db.commit()
|
|
||||||
|
|
||||||
# 只获取 exif 命名空间
|
|
||||||
response = await async_client.get(
|
|
||||||
f"/api/v1/object/{file_id}/metadata?ns=exif",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "exif:width" in data["metadatas"]
|
|
||||||
assert "music:title" not in data["metadatas"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_metadata_nonexistent_object(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试获取不存在对象的元数据"""
|
|
||||||
fake_id = str(uuid4())
|
|
||||||
response = await async_client.get(
|
|
||||||
f"/api/v1/object/{fake_id}/metadata",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 更新元数据测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_patch_metadata_requires_auth(async_client: AsyncClient):
|
|
||||||
"""测试更新元数据需要认证"""
|
|
||||||
fake_id = str(uuid4())
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"/api/v1/object/{fake_id}/metadata",
|
|
||||||
json={"patches": [{"key": "custom:tag", "value": "test"}]},
|
|
||||||
)
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_patch_metadata_set_custom(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
test_directory_structure: dict[str, UUID],
|
|
||||||
):
|
|
||||||
"""测试设置自定义元数据"""
|
|
||||||
file_id = test_directory_structure["file_id"]
|
|
||||||
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"/api/v1/object/{file_id}/metadata",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={
|
|
||||||
"patches": [
|
|
||||||
{"key": "custom:tag1", "value": "旅游"},
|
|
||||||
{"key": "custom:tag2", "value": "风景"},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert response.status_code == 204
|
|
||||||
|
|
||||||
# 验证已写入
|
|
||||||
get_response = await async_client.get(
|
|
||||||
f"/api/v1/object/{file_id}/metadata?ns=custom",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert get_response.status_code == 200
|
|
||||||
data = get_response.json()
|
|
||||||
assert data["metadatas"]["custom:tag1"] == "旅游"
|
|
||||||
assert data["metadatas"]["custom:tag2"] == "风景"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_patch_metadata_update_existing(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
test_directory_structure: dict[str, UUID],
|
|
||||||
):
|
|
||||||
"""测试更新已有的元数据"""
|
|
||||||
file_id = test_directory_structure["file_id"]
|
|
||||||
|
|
||||||
# 先创建
|
|
||||||
await async_client.patch(
|
|
||||||
f"/api/v1/object/{file_id}/metadata",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"patches": [{"key": "custom:note", "value": "旧值"}]},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 再更新
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"/api/v1/object/{file_id}/metadata",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"patches": [{"key": "custom:note", "value": "新值"}]},
|
|
||||||
)
|
|
||||||
assert response.status_code == 204
|
|
||||||
|
|
||||||
# 验证已更新
|
|
||||||
get_response = await async_client.get(
|
|
||||||
f"/api/v1/object/{file_id}/metadata?ns=custom",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
data = get_response.json()
|
|
||||||
assert data["metadatas"]["custom:note"] == "新值"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_patch_metadata_delete(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
test_directory_structure: dict[str, UUID],
|
|
||||||
):
|
|
||||||
"""测试删除元数据条目"""
|
|
||||||
file_id = test_directory_structure["file_id"]
|
|
||||||
|
|
||||||
# 先创建
|
|
||||||
await async_client.patch(
|
|
||||||
f"/api/v1/object/{file_id}/metadata",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"patches": [{"key": "custom:to_delete", "value": "temp"}]},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 删除(value 为 null)
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"/api/v1/object/{file_id}/metadata",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"patches": [{"key": "custom:to_delete", "value": None}]},
|
|
||||||
)
|
|
||||||
assert response.status_code == 204
|
|
||||||
|
|
||||||
# 验证已删除
|
|
||||||
get_response = await async_client.get(
|
|
||||||
f"/api/v1/object/{file_id}/metadata?ns=custom",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
data = get_response.json()
|
|
||||||
assert "custom:to_delete" not in data["metadatas"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_patch_metadata_reject_non_custom_namespace(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
test_directory_structure: dict[str, UUID],
|
|
||||||
):
|
|
||||||
"""测试拒绝修改非 custom: 命名空间"""
|
|
||||||
file_id = test_directory_structure["file_id"]
|
|
||||||
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"/api/v1/object/{file_id}/metadata",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"patches": [{"key": "exif:width", "value": "1920"}]},
|
|
||||||
)
|
|
||||||
assert response.status_code == 400
|
|
||||||
@@ -1,591 +0,0 @@
|
|||||||
"""
|
|
||||||
WebDAV 账户管理端点集成测试
|
|
||||||
"""
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
from httpx import AsyncClient
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, User
|
|
||||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
|
||||||
from sqlmodels.user import UserStatus
|
|
||||||
from utils import Password
|
|
||||||
from utils.JWT import create_access_token
|
|
||||||
|
|
||||||
API_PREFIX = "/api/v1/webdav"
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== Fixtures ====================
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
|
||||||
async def no_webdav_headers(initialized_db: AsyncSession) -> dict[str, str]:
|
|
||||||
"""创建一个 WebDAV 被禁用的用户,返回其认证头"""
|
|
||||||
group = Group(
|
|
||||||
id=uuid4(),
|
|
||||||
name="无WebDAV用户组",
|
|
||||||
max_storage=1024 * 1024 * 1024,
|
|
||||||
share_enabled=True,
|
|
||||||
web_dav_enabled=False,
|
|
||||||
admin=False,
|
|
||||||
speed_limit=0,
|
|
||||||
)
|
|
||||||
initialized_db.add(group)
|
|
||||||
await initialized_db.commit()
|
|
||||||
await initialized_db.refresh(group)
|
|
||||||
|
|
||||||
group_options = GroupOptions(
|
|
||||||
group_id=group.id,
|
|
||||||
share_download=True,
|
|
||||||
share_free=False,
|
|
||||||
relocate=False,
|
|
||||||
source_batch=0,
|
|
||||||
select_node=False,
|
|
||||||
advance_delete=False,
|
|
||||||
)
|
|
||||||
initialized_db.add(group_options)
|
|
||||||
await initialized_db.commit()
|
|
||||||
await initialized_db.refresh(group_options)
|
|
||||||
|
|
||||||
user = User(
|
|
||||||
id=uuid4(),
|
|
||||||
email="nowebdav@test.local",
|
|
||||||
nickname="无WebDAV用户",
|
|
||||||
status=UserStatus.ACTIVE,
|
|
||||||
storage=0,
|
|
||||||
score=0,
|
|
||||||
group_id=group.id,
|
|
||||||
avatar="default",
|
|
||||||
)
|
|
||||||
initialized_db.add(user)
|
|
||||||
await initialized_db.commit()
|
|
||||||
await initialized_db.refresh(user)
|
|
||||||
|
|
||||||
identity = AuthIdentity(
|
|
||||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
|
||||||
identifier="nowebdav@test.local",
|
|
||||||
credential=Password.hash("nowebdav123"),
|
|
||||||
is_primary=True,
|
|
||||||
is_verified=True,
|
|
||||||
user_id=user.id,
|
|
||||||
)
|
|
||||||
initialized_db.add(identity)
|
|
||||||
|
|
||||||
from sqlmodels import Policy
|
|
||||||
policy = await Policy.get(initialized_db, Policy.name == "本地存储")
|
|
||||||
|
|
||||||
root = Object(
|
|
||||||
id=uuid4(),
|
|
||||||
name="/",
|
|
||||||
type=ObjectType.FOLDER,
|
|
||||||
owner_id=user.id,
|
|
||||||
parent_id=None,
|
|
||||||
policy_id=policy.id,
|
|
||||||
size=0,
|
|
||||||
)
|
|
||||||
initialized_db.add(root)
|
|
||||||
await initialized_db.commit()
|
|
||||||
|
|
||||||
group.options = group_options
|
|
||||||
group_claims = GroupClaims.from_group(group)
|
|
||||||
result = create_access_token(
|
|
||||||
sub=user.id,
|
|
||||||
jti=uuid4(),
|
|
||||||
status=user.status.value,
|
|
||||||
group=group_claims,
|
|
||||||
)
|
|
||||||
return {"Authorization": f"Bearer {result.access_token}"}
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 认证测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_accounts_requires_auth(async_client: AsyncClient):
|
|
||||||
"""测试获取账户列表需要认证"""
|
|
||||||
response = await async_client.get(f"{API_PREFIX}/accounts")
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_account_requires_auth(async_client: AsyncClient):
|
|
||||||
"""测试创建账户需要认证"""
|
|
||||||
response = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
json={"name": "test", "password": "testpass"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_account_requires_auth(async_client: AsyncClient):
|
|
||||||
"""测试更新账户需要认证"""
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"{API_PREFIX}/accounts/1",
|
|
||||||
json={"readonly": True},
|
|
||||||
)
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_account_requires_auth(async_client: AsyncClient):
|
|
||||||
"""测试删除账户需要认证"""
|
|
||||||
response = await async_client.delete(f"{API_PREFIX}/accounts/1")
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== WebDAV 禁用测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_accounts_webdav_disabled(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
no_webdav_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试 WebDAV 被禁用时返回 403"""
|
|
||||||
response = await async_client.get(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=no_webdav_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 403
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_account_webdav_disabled(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
no_webdav_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试 WebDAV 被禁用时创建账户返回 403"""
|
|
||||||
response = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=no_webdav_headers,
|
|
||||||
json={"name": "test", "password": "testpass"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 403
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 获取账户列表测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_accounts_empty(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试初始状态账户列表为空"""
|
|
||||||
response = await async_client.get(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == []
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 创建账户测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_account_success(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试成功创建 WebDAV 账户"""
|
|
||||||
response = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "my-nas", "password": "secretpass"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 201
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["name"] == "my-nas"
|
|
||||||
assert data["root"] == "/"
|
|
||||||
assert data["readonly"] is False
|
|
||||||
assert data["use_proxy"] is False
|
|
||||||
assert "id" in data
|
|
||||||
assert "created_at" in data
|
|
||||||
assert "updated_at" in data
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_account_with_options(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试创建带选项的 WebDAV 账户"""
|
|
||||||
response = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={
|
|
||||||
"name": "readonly-nas",
|
|
||||||
"password": "secretpass",
|
|
||||||
"readonly": True,
|
|
||||||
"use_proxy": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert response.status_code == 201
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["name"] == "readonly-nas"
|
|
||||||
assert data["readonly"] is True
|
|
||||||
assert data["use_proxy"] is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_account_duplicate_name(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试重名账户返回 409"""
|
|
||||||
# 先创建一个
|
|
||||||
response = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "dup-test", "password": "pass1"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 201
|
|
||||||
|
|
||||||
# 再创建同名的
|
|
||||||
response = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "dup-test", "password": "pass2"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 409
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_account_invalid_root(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试无效根目录路径返回 400"""
|
|
||||||
response = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={
|
|
||||||
"name": "bad-root",
|
|
||||||
"password": "secretpass",
|
|
||||||
"root": "/nonexistent/path",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert response.status_code == 400
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_account_with_valid_subdir(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
test_directory_structure: dict[str, UUID],
|
|
||||||
):
|
|
||||||
"""测试使用有效的子目录作为根路径"""
|
|
||||||
response = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={
|
|
||||||
"name": "docs-only",
|
|
||||||
"password": "secretpass",
|
|
||||||
"root": "/docs",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert response.status_code == 201
|
|
||||||
assert response.json()["root"] == "/docs"
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 列表包含已创建账户测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_accounts_after_create(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试创建后列表中包含该账户"""
|
|
||||||
# 创建
|
|
||||||
await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "list-test", "password": "pass"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 列表
|
|
||||||
response = await async_client.get(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
accounts = response.json()
|
|
||||||
assert len(accounts) == 1
|
|
||||||
assert accounts[0]["name"] == "list-test"
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 更新账户测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_account_success(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试成功更新 WebDAV 账户"""
|
|
||||||
# 创建
|
|
||||||
create_resp = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "update-test", "password": "oldpass"},
|
|
||||||
)
|
|
||||||
account_id = create_resp.json()["id"]
|
|
||||||
|
|
||||||
# 更新
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"{API_PREFIX}/accounts/{account_id}",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"readonly": True},
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["readonly"] is True
|
|
||||||
assert data["name"] == "update-test"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_account_password(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试更新密码"""
|
|
||||||
# 创建
|
|
||||||
create_resp = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "pwd-test", "password": "oldpass"},
|
|
||||||
)
|
|
||||||
account_id = create_resp.json()["id"]
|
|
||||||
|
|
||||||
# 更新密码
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"{API_PREFIX}/accounts/{account_id}",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"password": "newpass123"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_account_root(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
test_directory_structure: dict[str, UUID],
|
|
||||||
):
|
|
||||||
"""测试更新根目录路径"""
|
|
||||||
# 创建
|
|
||||||
create_resp = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "root-update", "password": "pass"},
|
|
||||||
)
|
|
||||||
account_id = create_resp.json()["id"]
|
|
||||||
|
|
||||||
# 更新 root 到有效子目录
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"{API_PREFIX}/accounts/{account_id}",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"root": "/docs"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["root"] == "/docs"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_account_invalid_root(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试更新为无效根目录返回 400"""
|
|
||||||
# 创建
|
|
||||||
create_resp = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "bad-root-update", "password": "pass"},
|
|
||||||
)
|
|
||||||
account_id = create_resp.json()["id"]
|
|
||||||
|
|
||||||
# 更新到无效路径
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"{API_PREFIX}/accounts/{account_id}",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"root": "/nonexistent"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 400
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_account_not_found(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试更新不存在的账户返回 404"""
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"{API_PREFIX}/accounts/99999",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"readonly": True},
|
|
||||||
)
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_other_user_account(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
admin_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试更新其他用户的账户返回 404"""
|
|
||||||
# 管理员创建账户
|
|
||||||
create_resp = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=admin_headers,
|
|
||||||
json={"name": "admin-account", "password": "pass"},
|
|
||||||
)
|
|
||||||
account_id = create_resp.json()["id"]
|
|
||||||
|
|
||||||
# 普通用户尝试更新
|
|
||||||
response = await async_client.patch(
|
|
||||||
f"{API_PREFIX}/accounts/{account_id}",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"readonly": True},
|
|
||||||
)
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 删除账户测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_account_success(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试成功删除 WebDAV 账户"""
|
|
||||||
# 创建
|
|
||||||
create_resp = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "delete-test", "password": "pass"},
|
|
||||||
)
|
|
||||||
account_id = create_resp.json()["id"]
|
|
||||||
|
|
||||||
# 删除
|
|
||||||
response = await async_client.delete(
|
|
||||||
f"{API_PREFIX}/accounts/{account_id}",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 204
|
|
||||||
|
|
||||||
# 确认列表中已不存在
|
|
||||||
list_resp = await async_client.get(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert list_resp.status_code == 200
|
|
||||||
names = [a["name"] for a in list_resp.json()]
|
|
||||||
assert "delete-test" not in names
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_account_not_found(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试删除不存在的账户返回 404"""
|
|
||||||
response = await async_client.delete(
|
|
||||||
f"{API_PREFIX}/accounts/99999",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_other_user_account(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
admin_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试删除其他用户的账户返回 404"""
|
|
||||||
# 管理员创建账户
|
|
||||||
create_resp = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=admin_headers,
|
|
||||||
json={"name": "admin-del-test", "password": "pass"},
|
|
||||||
)
|
|
||||||
account_id = create_resp.json()["id"]
|
|
||||||
|
|
||||||
# 普通用户尝试删除
|
|
||||||
response = await async_client.delete(
|
|
||||||
f"{API_PREFIX}/accounts/{account_id}",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 多账户测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_multiple_accounts(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试同一用户可以创建多个账户"""
|
|
||||||
for name in ["account-1", "account-2", "account-3"]:
|
|
||||||
response = await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": name, "password": "pass"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 201
|
|
||||||
|
|
||||||
# 列表应有3个
|
|
||||||
response = await async_client.get(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert len(response.json()) == 3
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 用户隔离测试 ====================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_accounts_user_isolation(
|
|
||||||
async_client: AsyncClient,
|
|
||||||
auth_headers: dict[str, str],
|
|
||||||
admin_headers: dict[str, str],
|
|
||||||
):
|
|
||||||
"""测试不同用户的账户相互隔离"""
|
|
||||||
# 普通用户创建
|
|
||||||
await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"name": "user-account", "password": "pass"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 管理员创建
|
|
||||||
await async_client.post(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=admin_headers,
|
|
||||||
json={"name": "admin-account", "password": "pass"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 普通用户只看到自己的
|
|
||||||
response = await async_client.get(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=auth_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
accounts = response.json()
|
|
||||||
assert len(accounts) == 1
|
|
||||||
assert accounts[0]["name"] == "user-account"
|
|
||||||
|
|
||||||
# 管理员只看到自己的
|
|
||||||
response = await async_client.get(
|
|
||||||
f"{API_PREFIX}/accounts",
|
|
||||||
headers=admin_headers,
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
accounts = response.json()
|
|
||||||
assert len(accounts) == 1
|
|
||||||
assert accounts[0]["name"] == "admin-account"
|
|
||||||
@@ -23,7 +23,6 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../.
|
|||||||
|
|
||||||
from main import app
|
from main import app
|
||||||
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||||
from sqlmodels.policy import GroupPolicyLink
|
|
||||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||||
from sqlmodels.user import UserStatus
|
from sqlmodels.user import UserStatus
|
||||||
from utils import Password
|
from utils import Password
|
||||||
@@ -109,12 +108,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
|||||||
Setting(type=SettingsType.AUTH, name="auth_email_binding_required", value="1"),
|
Setting(type=SettingsType.AUTH, name="auth_email_binding_required", value="1"),
|
||||||
Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"),
|
Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"),
|
||||||
Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"),
|
Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"),
|
||||||
Setting(type=SettingsType.AVATAR, name="gravatar_server", value="https://www.gravatar.com/"),
|
|
||||||
Setting(type=SettingsType.AVATAR, name="avatar_size", value="2097152"),
|
|
||||||
Setting(type=SettingsType.AVATAR, name="avatar_size_l", value="200"),
|
|
||||||
Setting(type=SettingsType.AVATAR, name="avatar_size_m", value="130"),
|
|
||||||
Setting(type=SettingsType.AVATAR, name="avatar_size_s", value="50"),
|
|
||||||
Setting(type=SettingsType.PATH, name="avatar_path", value="avatar"),
|
|
||||||
]
|
]
|
||||||
for setting in settings:
|
for setting in settings:
|
||||||
test_session.add(setting)
|
test_session.add(setting)
|
||||||
@@ -163,11 +156,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
|||||||
await test_session.refresh(admin_group)
|
await test_session.refresh(admin_group)
|
||||||
await test_session.refresh(default_policy)
|
await test_session.refresh(default_policy)
|
||||||
|
|
||||||
# 4. 关联用户组与存储策略
|
# 4. 创建用户组选项
|
||||||
test_session.add(GroupPolicyLink(group_id=default_group.id, policy_id=default_policy.id))
|
|
||||||
test_session.add(GroupPolicyLink(group_id=admin_group.id, policy_id=default_policy.id))
|
|
||||||
|
|
||||||
# 5. 创建用户组选项
|
|
||||||
default_group_options = GroupOptions(
|
default_group_options = GroupOptions(
|
||||||
group_id=default_group.id,
|
group_id=default_group.id,
|
||||||
share_download=True,
|
share_download=True,
|
||||||
|
|||||||
@@ -37,12 +37,6 @@ async def load_secret_key() -> None:
|
|||||||
if setting:
|
if setting:
|
||||||
SECRET_KEY = setting.value
|
SECRET_KEY = setting.value
|
||||||
|
|
||||||
if not SECRET_KEY:
|
|
||||||
raise RuntimeError(
|
|
||||||
"JWT SECRET_KEY 未配置,拒绝启动。"
|
|
||||||
"请在 Setting 表中添加 type='auth', name='secret_key' 的记录。"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_token_payload(
|
def build_token_payload(
|
||||||
data: dict,
|
data: dict,
|
||||||
|
|||||||
@@ -62,10 +62,6 @@ def raise_not_implemented(detail: str = "尚未支持这种方法") -> NoReturn:
|
|||||||
"""Raises an HTTP 501 Not Implemented exception."""
|
"""Raises an HTTP 501 Not Implemented exception."""
|
||||||
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=detail)
|
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=detail)
|
||||||
|
|
||||||
def raise_bad_gateway(detail: str | None = None) -> NoReturn:
|
|
||||||
"""Raises an HTTP 502 Bad Gateway exception."""
|
|
||||||
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=detail)
|
|
||||||
|
|
||||||
def raise_service_unavailable(detail: str | None = None) -> NoReturn:
|
def raise_service_unavailable(detail: str | None = None) -> NoReturn:
|
||||||
"""Raises an HTTP 503 Service Unavailable exception."""
|
"""Raises an HTTP 503 Service Unavailable exception."""
|
||||||
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=detail)
|
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=detail)
|
||||||
|
|||||||
Reference in New Issue
Block a user