feat: implement WebDAV protocol support with WsgiDAV + account management API
All checks were successful
Test / test (push) Successful in 2m14s
All checks were successful
Test / test (push) Successful in 2m14s
Add complete WebDAV support: management REST API (CRUD accounts at /api/v1/webdav/accounts) and DAV protocol endpoint (/dav) using WsgiDAV + a2wsgi bridge for client access via HTTP Basic Auth. Includes Redis+TTLCache auth caching and integration tests (24 cases). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
8
main.py
8
main.py
@@ -5,6 +5,8 @@ from fastapi import FastAPI, Request
|
||||
from loguru import logger as l
|
||||
|
||||
from routers import router
|
||||
from routers.dav import dav_app
|
||||
from routers.dav.provider import EventLoopRef
|
||||
from service.redis import RedisManager
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodels.migration import migration
|
||||
@@ -40,6 +42,9 @@ async def _init_db() -> None:
|
||||
"""初始化数据库连接引擎"""
|
||||
await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug)
|
||||
|
||||
# 捕获事件循环引用(供 WSGI 线程桥接使用)
|
||||
lifespan.add_startup(EventLoopRef.capture)
|
||||
|
||||
# 添加初始化数据库启动项
|
||||
lifespan.add_startup(_init_db)
|
||||
lifespan.add_startup(migration)
|
||||
@@ -88,6 +93,9 @@ async def handle_unexpected_exceptions(
|
||||
# 挂载路由
|
||||
app.include_router(router)
|
||||
|
||||
# 挂载 WebDAV 协议端点(优先于 SPA catch-all)
|
||||
app.mount("/dav", dav_app)
|
||||
|
||||
# 挂载前端静态文件(仅当 statics/ 目录存在时,即 Docker 部署环境)
|
||||
if STATICS_DIR.is_dir():
|
||||
from starlette.staticfiles import StaticFiles
|
||||
|
||||
@@ -33,6 +33,8 @@ dependencies = [
|
||||
"uvicorn>=0.38.0",
|
||||
"webauthn>=2.7.0",
|
||||
"whatthepatch>=1.0.6",
|
||||
"wsgidav>=4.3.0",
|
||||
"a2wsgi>=1.10.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -1,110 +1,207 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from sqlmodels import ResponseBase
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
Object,
|
||||
User,
|
||||
WebDAV,
|
||||
WebDAVAccountResponse,
|
||||
WebDAVCreateRequest,
|
||||
WebDAVUpdateRequest,
|
||||
)
|
||||
from service.redis.webdav_auth_cache import WebDAVAuthCache
|
||||
from utils import http_exceptions
|
||||
from utils.password.pwd import Password
|
||||
|
||||
# WebDAV 管理路由
|
||||
webdav_router = APIRouter(
|
||||
prefix='/webdav',
|
||||
tags=["webdav"],
|
||||
)
|
||||
|
||||
|
||||
def _check_webdav_enabled(user: User) -> None:
|
||||
"""检查用户组是否启用了 WebDAV 功能"""
|
||||
if not user.group.web_dav_enabled:
|
||||
http_exceptions.raise_forbidden("WebDAV 功能未启用")
|
||||
|
||||
|
||||
def _to_response(account: WebDAV) -> WebDAVAccountResponse:
|
||||
"""将 WebDAV 数据库模型转换为响应 DTO"""
|
||||
return WebDAVAccountResponse(
|
||||
id=account.id,
|
||||
name=account.name,
|
||||
root=account.root,
|
||||
readonly=account.readonly,
|
||||
use_proxy=account.use_proxy,
|
||||
created_at=str(account.created_at),
|
||||
updated_at=str(account.updated_at),
|
||||
)
|
||||
|
||||
|
||||
@webdav_router.get(
|
||||
path='/accounts',
|
||||
summary='获取账号信息',
|
||||
description='Get account information for WebDAV.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
summary='获取账号列表',
|
||||
)
|
||||
def router_webdav_accounts() -> ResponseBase:
|
||||
async def list_accounts(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> list[WebDAVAccountResponse]:
|
||||
"""
|
||||
Get account information for WebDAV.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the account information.
|
||||
列出当前用户所有 WebDAV 账户
|
||||
|
||||
认证:JWT Bearer Token
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
_check_webdav_enabled(user)
|
||||
user_id: UUID = user.id
|
||||
|
||||
accounts: list[WebDAV] = await WebDAV.get(
|
||||
session,
|
||||
WebDAV.user_id == user_id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
return [_to_response(a) for a in accounts]
|
||||
|
||||
|
||||
@webdav_router.post(
|
||||
path='/accounts',
|
||||
summary='新建账号',
|
||||
description='Create a new WebDAV account.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
summary='创建账号',
|
||||
status_code=201,
|
||||
)
|
||||
def router_webdav_create_account() -> ResponseBase:
|
||||
async def create_account(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: WebDAVCreateRequest,
|
||||
) -> WebDAVAccountResponse:
|
||||
"""
|
||||
Create a new WebDAV account.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the created account.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
创建 WebDAV 账户
|
||||
|
||||
@webdav_router.delete(
|
||||
path='/accounts/{id}',
|
||||
summary='删除账号',
|
||||
description='Delete a WebDAV account by its ID.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_delete_account(id: str) -> ResponseBase:
|
||||
"""
|
||||
Delete a WebDAV account by its ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the account to be deleted.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the deletion operation.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
认证:JWT Bearer Token
|
||||
|
||||
@webdav_router.post(
|
||||
path='/mount',
|
||||
summary='新建目录挂载',
|
||||
description='Create a new WebDAV mount point.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_create_mount() -> ResponseBase:
|
||||
错误处理:
|
||||
- 403: WebDAV 功能未启用
|
||||
- 400: 根目录路径不存在或不是目录
|
||||
- 409: 账户名已存在
|
||||
"""
|
||||
Create a new WebDAV mount point.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the created mount point.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
_check_webdav_enabled(user)
|
||||
user_id: UUID = user.id
|
||||
|
||||
# 验证账户名唯一
|
||||
existing = await WebDAV.get(
|
||||
session,
|
||||
(WebDAV.name == request.name) & (WebDAV.user_id == user_id),
|
||||
)
|
||||
if existing:
|
||||
http_exceptions.raise_conflict("账户名已存在")
|
||||
|
||||
# 验证 root 路径存在且为目录
|
||||
root_obj = await Object.get_by_path(session, user_id, request.root)
|
||||
if not root_obj or not root_obj.is_folder:
|
||||
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
|
||||
|
||||
# 创建账户
|
||||
account = WebDAV(
|
||||
name=request.name,
|
||||
password=Password.hash(request.password),
|
||||
root=request.root,
|
||||
readonly=request.readonly,
|
||||
use_proxy=request.use_proxy,
|
||||
user_id=user_id,
|
||||
)
|
||||
account = await account.save(session)
|
||||
|
||||
l.info(f"用户 {user_id} 创建 WebDAV 账户: {account.name}")
|
||||
return _to_response(account)
|
||||
|
||||
@webdav_router.delete(
|
||||
path='/mount/{id}',
|
||||
summary='删除目录挂载',
|
||||
description='Delete a WebDAV mount point by its ID.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_delete_mount(id: str) -> ResponseBase:
|
||||
"""
|
||||
Delete a WebDAV mount point by its ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the mount point to be deleted.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the deletion operation.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@webdav_router.patch(
|
||||
path='accounts/{id}',
|
||||
summary='更新账号信息',
|
||||
description='Update WebDAV account information by ID.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
path='/accounts/{account_id}',
|
||||
summary='更新账号',
|
||||
)
|
||||
def router_webdav_update_account(id: str) -> ResponseBase:
|
||||
async def update_account(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
account_id: int,
|
||||
request: WebDAVUpdateRequest,
|
||||
) -> WebDAVAccountResponse:
|
||||
"""
|
||||
Update WebDAV account information by ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the account to be updated.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the updated account.
|
||||
更新 WebDAV 账户
|
||||
|
||||
认证:JWT Bearer Token
|
||||
|
||||
错误处理:
|
||||
- 403: WebDAV 功能未启用
|
||||
- 404: 账户不存在
|
||||
- 400: 根目录路径不存在或不是目录
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
_check_webdav_enabled(user)
|
||||
user_id: UUID = user.id
|
||||
|
||||
account = await WebDAV.get(
|
||||
session,
|
||||
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
|
||||
)
|
||||
if not account:
|
||||
http_exceptions.raise_not_found("WebDAV 账户不存在")
|
||||
|
||||
# 验证 root 路径
|
||||
if request.root is not None:
|
||||
root_obj = await Object.get_by_path(session, user_id, request.root)
|
||||
if not root_obj or not root_obj.is_folder:
|
||||
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
|
||||
|
||||
# 密码哈希后原地替换,update() 会通过 model_dump(exclude_unset=True) 只取已设置字段
|
||||
is_password_changed = request.password is not None
|
||||
if is_password_changed:
|
||||
request.password = Password.hash(request.password)
|
||||
|
||||
account = await account.update(session, request)
|
||||
|
||||
# 密码变更时清除认证缓存
|
||||
if is_password_changed:
|
||||
await WebDAVAuthCache.invalidate_account(user_id, account.name)
|
||||
|
||||
l.info(f"用户 {user_id} 更新 WebDAV 账户: {account.name}")
|
||||
return _to_response(account)
|
||||
|
||||
|
||||
@webdav_router.delete(
|
||||
path='/accounts/{account_id}',
|
||||
summary='删除账号',
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_account(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
account_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
删除 WebDAV 账户
|
||||
|
||||
认证:JWT Bearer Token
|
||||
|
||||
错误处理:
|
||||
- 403: WebDAV 功能未启用
|
||||
- 404: 账户不存在
|
||||
"""
|
||||
_check_webdav_enabled(user)
|
||||
user_id: UUID = user.id
|
||||
|
||||
account = await WebDAV.get(
|
||||
session,
|
||||
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
|
||||
)
|
||||
if not account:
|
||||
http_exceptions.raise_not_found("WebDAV 账户不存在")
|
||||
|
||||
account_name = account.name
|
||||
await WebDAV.delete(session, account)
|
||||
|
||||
# 清除认证缓存
|
||||
await WebDAVAuthCache.invalidate_account(user_id, account_name)
|
||||
|
||||
l.info(f"用户 {user_id} 删除 WebDAV 账户: {account_name}")
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# WebDAV 操作路由
|
||||
35
routers/dav/__init__.py
Normal file
35
routers/dav/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
WebDAV 协议入口
|
||||
|
||||
使用 WsgiDAV + a2wsgi 提供 WebDAV 协议支持。
|
||||
WsgiDAV 在 a2wsgi 的线程池中运行,不阻塞 FastAPI 事件循环。
|
||||
"""
|
||||
from a2wsgi import WSGIMiddleware
|
||||
from wsgidav.wsgidav_app import WsgiDAVApp
|
||||
|
||||
from .domain_controller import DiskNextDomainController
|
||||
from .provider import DiskNextDAVProvider
|
||||
|
||||
_wsgidav_config: dict[str, object] = {
|
||||
"provider_mapping": {
|
||||
"/": DiskNextDAVProvider(),
|
||||
},
|
||||
"http_authenticator": {
|
||||
"domain_controller": DiskNextDomainController,
|
||||
"accept_basic": True,
|
||||
"accept_digest": False,
|
||||
"default_to_digest": False,
|
||||
},
|
||||
"verbose": 1,
|
||||
# 使用 WsgiDAV 内置的内存锁管理器
|
||||
"lock_storage": True,
|
||||
# 禁用 WsgiDAV 的目录浏览器(纯 DAV 协议)
|
||||
"dir_browser": {
|
||||
"enable": False,
|
||||
},
|
||||
}
|
||||
|
||||
_wsgidav_app = WsgiDAVApp(_wsgidav_config)
|
||||
|
||||
dav_app = WSGIMiddleware(_wsgidav_app, workers=10)
|
||||
"""ASGI 应用,挂载到 /dav 路径"""
|
||||
148
routers/dav/domain_controller.py
Normal file
148
routers/dav/domain_controller.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
WebDAV 认证控制器
|
||||
|
||||
实现 WsgiDAV 的 BaseDomainController 接口,使用 HTTP Basic Auth
|
||||
通过 DiskNext 的 WebDAV 账户模型进行认证。
|
||||
|
||||
用户名格式: {email}/{webdav_account_name}
|
||||
"""
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from wsgidav.dc.base_dc import BaseDomainController
|
||||
|
||||
from routers.dav.provider import EventLoopRef, _get_session
|
||||
from service.redis.webdav_auth_cache import WebDAVAuthCache
|
||||
from sqlmodels.user import User, UserStatus
|
||||
from sqlmodels.webdav import WebDAV
|
||||
from utils.password.pwd import Password, PasswordStatus
|
||||
|
||||
|
||||
async def _authenticate(
|
||||
email: str,
|
||||
account_name: str,
|
||||
password: str,
|
||||
) -> tuple[UUID, int] | None:
|
||||
"""
|
||||
异步认证 WebDAV 用户。
|
||||
|
||||
:param email: 用户邮箱
|
||||
:param account_name: WebDAV 账户名
|
||||
:param password: 明文密码
|
||||
:return: (user_id, webdav_id) 或 None
|
||||
"""
|
||||
# 1. 查缓存
|
||||
cached = await WebDAVAuthCache.get(email, account_name, password)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
# 2. 缓存未命中,查库验证
|
||||
async with _get_session() as session:
|
||||
user = await User.get(session, User.email == email, load=User.group)
|
||||
if not user:
|
||||
return None
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
return None
|
||||
if not user.group.web_dav_enabled:
|
||||
return None
|
||||
|
||||
account = await WebDAV.get(
|
||||
session,
|
||||
(WebDAV.name == account_name) & (WebDAV.user_id == user.id),
|
||||
)
|
||||
if not account:
|
||||
return None
|
||||
|
||||
status = Password.verify(account.password, password)
|
||||
if status == PasswordStatus.INVALID:
|
||||
return None
|
||||
|
||||
user_id: UUID = user.id
|
||||
webdav_id: int = account.id
|
||||
|
||||
# 3. 写入缓存
|
||||
await WebDAVAuthCache.set(email, account_name, password, user_id, webdav_id)
|
||||
|
||||
return user_id, webdav_id
|
||||
|
||||
|
||||
class DiskNextDomainController(BaseDomainController):
|
||||
"""
|
||||
DiskNext WebDAV 认证控制器
|
||||
|
||||
用户名格式: {email}/{webdav_account_name}
|
||||
密码: WebDAV 账户密码(创建账户时设置)
|
||||
"""
|
||||
|
||||
def __init__(self, wsgidav_app: object, config: dict[str, object]) -> None:
|
||||
super().__init__(wsgidav_app, config)
|
||||
|
||||
def get_domain_realm(self, path_info: str, environ: dict[str, object]) -> str:
|
||||
"""返回 realm 名称"""
|
||||
return "DiskNext WebDAV"
|
||||
|
||||
def require_authentication(self, realm: str, environ: dict[str, object]) -> bool:
|
||||
"""所有请求都需要认证"""
|
||||
return True
|
||||
|
||||
def is_share_anonymous(self, path_info: str) -> bool:
|
||||
"""不支持匿名访问"""
|
||||
return False
|
||||
|
||||
def supports_http_digest_auth(self) -> bool:
|
||||
"""不支持 Digest 认证(密码存的是 Argon2 哈希,无法反推)"""
|
||||
return False
|
||||
|
||||
def basic_auth_user(
|
||||
self,
|
||||
realm: str,
|
||||
user_name: str,
|
||||
password: str,
|
||||
environ: dict[str, object],
|
||||
) -> bool:
|
||||
"""
|
||||
HTTP Basic Auth 认证。
|
||||
|
||||
用户名格式: {email}/{webdav_account_name}
|
||||
在 WSGI 线程中通过 anyio.from_thread.run 调用异步认证逻辑。
|
||||
"""
|
||||
# 解析用户名
|
||||
if "/" not in user_name:
|
||||
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
|
||||
return False
|
||||
|
||||
email, account_name = user_name.split("/", 1)
|
||||
if not email or not account_name:
|
||||
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
|
||||
return False
|
||||
|
||||
# 在 WSGI 线程中调用异步认证
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
_authenticate(email, account_name, password),
|
||||
EventLoopRef.get(),
|
||||
)
|
||||
result = future.result()
|
||||
|
||||
if result is None:
|
||||
l.debug(f"WebDAV 认证失败: {email}/{account_name}")
|
||||
return False
|
||||
|
||||
user_id, webdav_id = result
|
||||
|
||||
# 将认证信息存入 environ,供 Provider 使用
|
||||
environ["disknext.user_id"] = user_id
|
||||
environ["disknext.webdav_id"] = webdav_id
|
||||
environ["disknext.email"] = email
|
||||
environ["disknext.account_name"] = account_name
|
||||
|
||||
return True
|
||||
|
||||
def digest_auth_user(
|
||||
self,
|
||||
realm: str,
|
||||
user_name: str,
|
||||
environ: dict[str, object],
|
||||
) -> bool:
|
||||
"""不支持 Digest 认证"""
|
||||
return False
|
||||
594
routers/dav/provider.py
Normal file
594
routers/dav/provider.py
Normal file
@@ -0,0 +1,594 @@
|
||||
"""
|
||||
DiskNext WebDAV 存储 Provider
|
||||
|
||||
将 WsgiDAV 的文件操作映射到 DiskNext 的 Object 模型。
|
||||
所有异步数据库/文件操作通过 asyncio.run_coroutine_threadsafe() 桥接。
|
||||
"""
|
||||
import asyncio
|
||||
import io
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from wsgidav.dav_error import (
|
||||
DAVError,
|
||||
HTTP_FORBIDDEN,
|
||||
HTTP_INSUFFICIENT_STORAGE,
|
||||
HTTP_NOT_FOUND,
|
||||
)
|
||||
from wsgidav.dav_provider import DAVCollection, DAVNonCollection, DAVProvider
|
||||
|
||||
from service.storage import LocalStorageService, adjust_user_storage
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.physical_file import PhysicalFile
|
||||
from sqlmodels.policy import Policy
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.webdav import WebDAV
|
||||
|
||||
|
||||
class EventLoopRef:
|
||||
"""持有主线程事件循环引用,供 WSGI 线程使用"""
|
||||
_loop: ClassVar[asyncio.AbstractEventLoop | None] = None
|
||||
|
||||
@classmethod
|
||||
async def capture(cls) -> None:
|
||||
"""在 async 上下文中调用,捕获当前事件循环"""
|
||||
cls._loop = asyncio.get_running_loop()
|
||||
|
||||
@classmethod
|
||||
def get(cls) -> asyncio.AbstractEventLoop:
|
||||
if cls._loop is None:
|
||||
raise RuntimeError("事件循环尚未捕获,请先调用 EventLoopRef.capture()")
|
||||
return cls._loop
|
||||
|
||||
|
||||
def _run_async(coro): # type: ignore[no-untyped-def]
|
||||
"""在 WSGI 线程中通过 run_coroutine_threadsafe 运行协程"""
|
||||
future = asyncio.run_coroutine_threadsafe(coro, EventLoopRef.get())
|
||||
return future.result()
|
||||
|
||||
|
||||
def _get_session(): # type: ignore[no-untyped-def]
|
||||
"""获取数据库会话上下文管理器"""
|
||||
return DatabaseManager._async_session_factory()
|
||||
|
||||
|
||||
# ==================== 异步辅助函数 ====================
|
||||
|
||||
async def _get_webdav_account(webdav_id: int) -> WebDAV | None:
|
||||
"""获取 WebDAV 账户"""
|
||||
async with _get_session() as session:
|
||||
return await WebDAV.get(session, WebDAV.id == webdav_id)
|
||||
|
||||
|
||||
async def _get_object_by_path(user_id: UUID, path: str) -> Object | None:
|
||||
"""根据路径获取对象"""
|
||||
async with _get_session() as session:
|
||||
return await Object.get_by_path(session, user_id, path)
|
||||
|
||||
|
||||
async def _get_children(user_id: UUID, parent_id: UUID) -> list[Object]:
|
||||
"""获取目录子对象"""
|
||||
async with _get_session() as session:
|
||||
return await Object.get_children(session, user_id, parent_id)
|
||||
|
||||
|
||||
async def _get_object_by_id(object_id: UUID) -> Object | None:
|
||||
"""根据ID获取对象"""
|
||||
async with _get_session() as session:
|
||||
return await Object.get(session, Object.id == object_id, load=Object.physical_file)
|
||||
|
||||
|
||||
async def _get_user(user_id: UUID) -> User | None:
|
||||
"""获取用户(含 group 关系)"""
|
||||
async with _get_session() as session:
|
||||
return await User.get(session, User.id == user_id, load=User.group)
|
||||
|
||||
|
||||
async def _get_policy(policy_id: UUID) -> Policy | None:
|
||||
"""获取存储策略"""
|
||||
async with _get_session() as session:
|
||||
return await Policy.get(session, Policy.id == policy_id)
|
||||
|
||||
|
||||
async def _create_folder(
|
||||
name: str,
|
||||
parent_id: UUID,
|
||||
owner_id: UUID,
|
||||
policy_id: UUID,
|
||||
) -> Object:
|
||||
"""创建目录对象"""
|
||||
async with _get_session() as session:
|
||||
obj = Object(
|
||||
name=name,
|
||||
type=ObjectType.FOLDER,
|
||||
size=0,
|
||||
parent_id=parent_id,
|
||||
owner_id=owner_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
obj = await obj.save(session)
|
||||
return obj
|
||||
|
||||
|
||||
async def _create_file(
|
||||
name: str,
|
||||
parent_id: UUID,
|
||||
owner_id: UUID,
|
||||
policy_id: UUID,
|
||||
) -> Object:
|
||||
"""创建空文件对象"""
|
||||
async with _get_session() as session:
|
||||
obj = Object(
|
||||
name=name,
|
||||
type=ObjectType.FILE,
|
||||
size=0,
|
||||
parent_id=parent_id,
|
||||
owner_id=owner_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
obj = await obj.save(session)
|
||||
return obj
|
||||
|
||||
|
||||
async def _soft_delete_object(object_id: UUID) -> None:
|
||||
"""软删除对象(移入回收站)"""
|
||||
from service.storage import soft_delete_objects
|
||||
|
||||
async with _get_session() as session:
|
||||
obj = await Object.get(session, Object.id == object_id)
|
||||
if obj:
|
||||
await soft_delete_objects(session, [obj])
|
||||
|
||||
|
||||
async def _finalize_upload(
|
||||
object_id: UUID,
|
||||
physical_path: str,
|
||||
size: int,
|
||||
owner_id: UUID,
|
||||
policy_id: UUID,
|
||||
) -> None:
|
||||
"""上传完成后更新对象元数据和物理文件记录"""
|
||||
async with _get_session() as session:
|
||||
# 获取存储路径(相对路径)
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
if not policy or not policy.server:
|
||||
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
|
||||
|
||||
base_path = Path(policy.server).resolve()
|
||||
full_path = Path(physical_path).resolve()
|
||||
storage_path = str(full_path.relative_to(base_path))
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
pf = PhysicalFile(
|
||||
storage_path=storage_path,
|
||||
size=size,
|
||||
policy_id=policy_id,
|
||||
reference_count=1,
|
||||
)
|
||||
pf = await pf.save(session)
|
||||
|
||||
# 更新 Object
|
||||
obj = await Object.get(session, Object.id == object_id)
|
||||
if obj:
|
||||
obj.sqlmodel_update({'size': size, 'physical_file_id': pf.id})
|
||||
session.add(obj)
|
||||
await session.commit()
|
||||
|
||||
# 更新用户存储用量
|
||||
if size > 0:
|
||||
await adjust_user_storage(session, owner_id, size)
|
||||
|
||||
|
||||
async def _move_object(
|
||||
object_id: UUID,
|
||||
new_parent_id: UUID,
|
||||
new_name: str,
|
||||
) -> None:
|
||||
"""移动/重命名对象"""
|
||||
async with _get_session() as session:
|
||||
obj = await Object.get(session, Object.id == object_id)
|
||||
if obj:
|
||||
obj.sqlmodel_update({'parent_id': new_parent_id, 'name': new_name})
|
||||
session.add(obj)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def _copy_object_recursive(
|
||||
src_id: UUID,
|
||||
dst_parent_id: UUID,
|
||||
dst_name: str,
|
||||
owner_id: UUID,
|
||||
) -> None:
|
||||
"""递归复制对象"""
|
||||
from service.storage import copy_object_recursive
|
||||
|
||||
async with _get_session() as session:
|
||||
src = await Object.get(session, Object.id == src_id)
|
||||
if not src:
|
||||
return
|
||||
await copy_object_recursive(session, src, dst_parent_id, owner_id, new_name=dst_name)
|
||||
|
||||
|
||||
# ==================== 辅助工具 ====================
|
||||
|
||||
def _get_environ_info(environ: dict[str, object]) -> tuple[UUID, int]:
|
||||
"""从 environ 中提取认证信息"""
|
||||
user_id: UUID = environ["disknext.user_id"] # type: ignore[assignment]
|
||||
webdav_id: int = environ["disknext.webdav_id"] # type: ignore[assignment]
|
||||
return user_id, webdav_id
|
||||
|
||||
|
||||
def _resolve_dav_path(account_root: str, dav_path: str) -> str:
|
||||
"""
|
||||
将 DAV 相对路径映射到 DiskNext 绝对路径。
|
||||
|
||||
:param account_root: 账户挂载根路径,如 "/" 或 "/docs"
|
||||
:param dav_path: DAV 请求路径,如 "/" 或 "/photos/cat.jpg"
|
||||
:return: DiskNext 内部路径,如 "/docs/photos/cat.jpg"
|
||||
"""
|
||||
# 规范化根路径
|
||||
root = account_root.rstrip("/")
|
||||
if not root:
|
||||
root = ""
|
||||
|
||||
# 规范化 DAV 路径
|
||||
if not dav_path or dav_path == "/":
|
||||
return root + "/" if root else "/"
|
||||
|
||||
if not dav_path.startswith("/"):
|
||||
dav_path = "/" + dav_path
|
||||
|
||||
full = root + dav_path
|
||||
return full if full else "/"
|
||||
|
||||
|
||||
def _check_readonly(environ: dict[str, object]) -> None:
|
||||
"""检查账户是否只读,只读则抛出 403"""
|
||||
account = environ.get("disknext.webdav_account")
|
||||
if account and getattr(account, 'readonly', False):
|
||||
raise DAVError(HTTP_FORBIDDEN, "WebDAV 账户为只读模式")
|
||||
|
||||
|
||||
def _check_storage_quota(user: User, additional_bytes: int) -> None:
|
||||
"""检查存储配额"""
|
||||
max_storage = user.group.max_storage
|
||||
if max_storage > 0 and user.storage + additional_bytes > max_storage:
|
||||
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
||||
|
||||
|
||||
# ==================== Provider ====================
|
||||
|
||||
class DiskNextDAVProvider(DAVProvider):
|
||||
"""DiskNext WebDAV 存储 Provider"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def get_resource_inst(
|
||||
self,
|
||||
path: str,
|
||||
environ: dict[str, object],
|
||||
) -> 'DiskNextCollection | DiskNextFile | None':
|
||||
"""
|
||||
将 WebDAV 路径映射到资源对象。
|
||||
|
||||
首次调用时加载 WebDAV 账户信息并缓存到 environ。
|
||||
"""
|
||||
user_id, webdav_id = _get_environ_info(environ)
|
||||
|
||||
# 首次请求时加载账户信息
|
||||
if "disknext.webdav_account" not in environ:
|
||||
account = _run_async(_get_webdav_account(webdav_id))
|
||||
if not account:
|
||||
return None
|
||||
environ["disknext.webdav_account"] = account
|
||||
|
||||
account: WebDAV = environ["disknext.webdav_account"] # type: ignore[no-redef]
|
||||
disknext_path = _resolve_dav_path(account.root, path)
|
||||
|
||||
obj = _run_async(_get_object_by_path(user_id, disknext_path))
|
||||
if not obj:
|
||||
return None
|
||||
|
||||
if obj.is_folder:
|
||||
return DiskNextCollection(path, environ, obj, user_id, account)
|
||||
else:
|
||||
return DiskNextFile(path, environ, obj, user_id, account)
|
||||
|
||||
def is_readonly(self) -> bool:
|
||||
"""只读由账户级别控制,不在 provider 级别限制"""
|
||||
return False
|
||||
|
||||
|
||||
# ==================== Collection(目录) ====================
|
||||
|
||||
class DiskNextCollection(DAVCollection):
|
||||
"""DiskNext 目录资源"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
environ: dict[str, object],
|
||||
obj: Object,
|
||||
user_id: UUID,
|
||||
account: WebDAV,
|
||||
) -> None:
|
||||
super().__init__(path, environ)
|
||||
self._obj = obj
|
||||
self._user_id = user_id
|
||||
self._account = account
|
||||
|
||||
def get_display_info(self) -> dict[str, str]:
|
||||
return {"type": "Directory"}
|
||||
|
||||
def get_member_names(self) -> list[str]:
|
||||
"""获取子对象名称列表"""
|
||||
children = _run_async(_get_children(self._user_id, self._obj.id))
|
||||
return [c.name for c in children]
|
||||
|
||||
def get_member(self, name: str) -> 'DiskNextCollection | DiskNextFile | None':
|
||||
"""获取指定名称的子资源"""
|
||||
member_path = self.path.rstrip("/") + "/" + name
|
||||
account_root = self._account.root
|
||||
disknext_path = _resolve_dav_path(account_root, member_path)
|
||||
|
||||
obj = _run_async(_get_object_by_path(self._user_id, disknext_path))
|
||||
if not obj:
|
||||
return None
|
||||
|
||||
if obj.is_folder:
|
||||
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
|
||||
else:
|
||||
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
|
||||
|
||||
def get_creation_date(self) -> float | None:
|
||||
if self._obj.created_at:
|
||||
return self._obj.created_at.timestamp()
|
||||
return None
|
||||
|
||||
def get_last_modified(self) -> float | None:
|
||||
if self._obj.updated_at:
|
||||
return self._obj.updated_at.timestamp()
|
||||
return None
|
||||
|
||||
def create_empty_resource(self, name: str) -> 'DiskNextFile':
|
||||
"""创建空文件(PUT 操作的第一步)"""
|
||||
_check_readonly(self.environ)
|
||||
|
||||
obj = _run_async(_create_file(
|
||||
name=name,
|
||||
parent_id=self._obj.id,
|
||||
owner_id=self._user_id,
|
||||
policy_id=self._obj.policy_id,
|
||||
))
|
||||
|
||||
member_path = self.path.rstrip("/") + "/" + name
|
||||
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
|
||||
|
||||
def create_collection(self, name: str) -> 'DiskNextCollection':
|
||||
"""创建子目录(MKCOL)"""
|
||||
_check_readonly(self.environ)
|
||||
|
||||
obj = _run_async(_create_folder(
|
||||
name=name,
|
||||
parent_id=self._obj.id,
|
||||
owner_id=self._user_id,
|
||||
policy_id=self._obj.policy_id,
|
||||
))
|
||||
|
||||
member_path = self.path.rstrip("/") + "/" + name
|
||||
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
|
||||
|
||||
def delete(self) -> None:
|
||||
"""软删除目录"""
|
||||
_check_readonly(self.environ)
|
||||
_run_async(_soft_delete_object(self._obj.id))
|
||||
|
||||
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
|
||||
"""复制或移动目录"""
|
||||
_check_readonly(self.environ)
|
||||
|
||||
account_root = self._account.root
|
||||
dest_disknext = _resolve_dav_path(account_root, dest_path)
|
||||
|
||||
# 解析目标父路径和新名称
|
||||
if "/" in dest_disknext.rstrip("/"):
|
||||
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
|
||||
new_name = dest_disknext.rsplit("/", 1)[1]
|
||||
else:
|
||||
parent_path = "/"
|
||||
new_name = dest_disknext.lstrip("/")
|
||||
|
||||
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
|
||||
if not dest_parent:
|
||||
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
|
||||
|
||||
if is_move:
|
||||
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
|
||||
else:
|
||||
_run_async(_copy_object_recursive(
|
||||
self._obj.id, dest_parent.id, new_name, self._user_id,
|
||||
))
|
||||
|
||||
return True
|
||||
|
||||
def support_recursive_delete(self) -> bool:
|
||||
return True
|
||||
|
||||
def support_recursive_move(self, dest_path: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
# ==================== NonCollection(文件) ====================
|
||||
|
||||
class DiskNextFile(DAVNonCollection):
|
||||
"""DiskNext 文件资源"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
environ: dict[str, object],
|
||||
obj: Object,
|
||||
user_id: UUID,
|
||||
account: WebDAV,
|
||||
) -> None:
|
||||
super().__init__(path, environ)
|
||||
self._obj = obj
|
||||
self._user_id = user_id
|
||||
self._account = account
|
||||
self._write_path: str | None = None
|
||||
self._write_stream: io.BufferedWriter | None = None
|
||||
|
||||
def get_content_length(self) -> int | None:
|
||||
return self._obj.size if self._obj.size else 0
|
||||
|
||||
def get_content_type(self) -> str | None:
|
||||
# 尝试从文件名推断 MIME 类型
|
||||
mime, _ = mimetypes.guess_type(self._obj.name)
|
||||
return mime or "application/octet-stream"
|
||||
|
||||
def get_creation_date(self) -> float | None:
|
||||
if self._obj.created_at:
|
||||
return self._obj.created_at.timestamp()
|
||||
return None
|
||||
|
||||
def get_last_modified(self) -> float | None:
|
||||
if self._obj.updated_at:
|
||||
return self._obj.updated_at.timestamp()
|
||||
return None
|
||||
|
||||
def get_display_info(self) -> dict[str, str]:
|
||||
return {"type": "File"}
|
||||
|
||||
def get_content(self) -> io.BufferedReader | None:
|
||||
"""
|
||||
返回文件内容的可读流。
|
||||
|
||||
WsgiDAV 在线程中运行,可安全使用同步 open()。
|
||||
"""
|
||||
obj_with_file = _run_async(_get_object_by_id(self._obj.id))
|
||||
if not obj_with_file or not obj_with_file.physical_file:
|
||||
return None
|
||||
|
||||
pf = obj_with_file.physical_file
|
||||
policy = _run_async(_get_policy(obj_with_file.policy_id))
|
||||
if not policy or not policy.server:
|
||||
return None
|
||||
|
||||
full_path = Path(policy.server).resolve() / pf.storage_path
|
||||
if not full_path.is_file():
|
||||
l.warning(f"WebDAV: 物理文件不存在: {full_path}")
|
||||
return None
|
||||
|
||||
return open(full_path, "rb") # noqa: SIM115
|
||||
|
||||
def begin_write(self, *, content_type: str | None = None) -> io.BufferedWriter:
|
||||
"""
|
||||
开始写入文件(PUT 操作)。
|
||||
|
||||
返回一个可写的文件流,WsgiDAV 将向其中写入请求体数据。
|
||||
"""
|
||||
_check_readonly(self.environ)
|
||||
|
||||
# 检查配额
|
||||
user = _run_async(_get_user(self._user_id))
|
||||
if user:
|
||||
content_length = self.environ.get("CONTENT_LENGTH")
|
||||
if content_length:
|
||||
_check_storage_quota(user, int(content_length))
|
||||
|
||||
# 获取策略以确定存储路径
|
||||
policy = _run_async(_get_policy(self._obj.policy_id))
|
||||
if not policy or not policy.server:
|
||||
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
dir_path, storage_name, full_path = _run_async(
|
||||
storage_service.generate_file_path(
|
||||
user_id=self._user_id,
|
||||
original_filename=self._obj.name,
|
||||
)
|
||||
)
|
||||
|
||||
self._write_path = full_path
|
||||
self._write_stream = open(full_path, "wb") # noqa: SIM115
|
||||
return self._write_stream
|
||||
|
||||
def end_write(self, *, with_errors: bool) -> None:
|
||||
"""写入完成后的收尾工作"""
|
||||
if self._write_stream:
|
||||
self._write_stream.close()
|
||||
self._write_stream = None
|
||||
|
||||
if with_errors or not self._write_path:
|
||||
return
|
||||
|
||||
# 获取文件大小
|
||||
file_path = Path(self._write_path)
|
||||
if not file_path.exists():
|
||||
return
|
||||
|
||||
size = file_path.stat().st_size
|
||||
|
||||
# 更新数据库记录
|
||||
_run_async(_finalize_upload(
|
||||
object_id=self._obj.id,
|
||||
physical_path=self._write_path,
|
||||
size=size,
|
||||
owner_id=self._user_id,
|
||||
policy_id=self._obj.policy_id,
|
||||
))
|
||||
|
||||
l.debug(f"WebDAV 文件写入完成: {self._obj.name}, size={size}")
|
||||
|
||||
def delete(self) -> None:
|
||||
"""软删除文件"""
|
||||
_check_readonly(self.environ)
|
||||
_run_async(_soft_delete_object(self._obj.id))
|
||||
|
||||
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
|
||||
"""复制或移动文件"""
|
||||
_check_readonly(self.environ)
|
||||
|
||||
account_root = self._account.root
|
||||
dest_disknext = _resolve_dav_path(account_root, dest_path)
|
||||
|
||||
# 解析目标父路径和新名称
|
||||
if "/" in dest_disknext.rstrip("/"):
|
||||
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
|
||||
new_name = dest_disknext.rsplit("/", 1)[1]
|
||||
else:
|
||||
parent_path = "/"
|
||||
new_name = dest_disknext.lstrip("/")
|
||||
|
||||
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
|
||||
if not dest_parent:
|
||||
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
|
||||
|
||||
if is_move:
|
||||
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
|
||||
else:
|
||||
_run_async(_copy_object_recursive(
|
||||
self._obj.id, dest_parent.id, new_name, self._user_id,
|
||||
))
|
||||
|
||||
return True
|
||||
|
||||
def support_content_length(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_etag(self) -> str | None:
|
||||
"""返回 ETag(基于ID和更新时间),WsgiDAV 会自动加双引号"""
|
||||
if self._obj.updated_at:
|
||||
return f"{self._obj.id}-{int(self._obj.updated_at.timestamp())}"
|
||||
return None
|
||||
|
||||
def support_etag(self) -> bool:
|
||||
return True
|
||||
|
||||
def support_ranges(self) -> bool:
|
||||
return True
|
||||
128
service/redis/webdav_auth_cache.py
Normal file
128
service/redis/webdav_auth_cache.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
WebDAV 认证缓存
|
||||
|
||||
缓存 HTTP Basic Auth 的认证结果,避免每次请求都查库 + Argon2 验证。
|
||||
支持 Redis(首选)和内存缓存(降级)两种存储后端。
|
||||
"""
|
||||
import hashlib
|
||||
from typing import ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from cachetools import TTLCache
|
||||
from loguru import logger as l
|
||||
|
||||
from . import RedisManager
|
||||
|
||||
_AUTH_TTL: int = 300
|
||||
"""认证缓存 TTL(秒),5 分钟"""
|
||||
|
||||
|
||||
class WebDAVAuthCache:
|
||||
"""
|
||||
WebDAV 认证结果缓存
|
||||
|
||||
缓存键格式: webdav_auth:{email}/{account_name}:{sha256(password)}
|
||||
缓存值格式: {user_id}:{webdav_id}
|
||||
|
||||
密码的 SHA256 作为缓存键的一部分,密码变更后旧缓存自然 miss。
|
||||
"""
|
||||
|
||||
_memory_cache: ClassVar[TTLCache[str, str]] = TTLCache(maxsize=10000, ttl=_AUTH_TTL)
|
||||
"""内存缓存降级方案"""
|
||||
|
||||
@classmethod
|
||||
def _build_key(cls, email: str, account_name: str, password: str) -> str:
|
||||
"""构建缓存键"""
|
||||
pwd_hash = hashlib.sha256(password.encode()).hexdigest()[:16]
|
||||
return f"webdav_auth:{email}/{account_name}:{pwd_hash}"
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls,
|
||||
email: str,
|
||||
account_name: str,
|
||||
password: str,
|
||||
) -> tuple[UUID, int] | None:
|
||||
"""
|
||||
查询缓存中的认证结果。
|
||||
|
||||
:param email: 用户邮箱
|
||||
:param account_name: WebDAV 账户名
|
||||
:param password: 用户提供的明文密码
|
||||
:return: (user_id, webdav_id) 或 None(缓存未命中)
|
||||
"""
|
||||
key = cls._build_key(email, account_name, password)
|
||||
|
||||
client = RedisManager.get_client()
|
||||
if client is not None:
|
||||
value = await client.get(key)
|
||||
if value is not None:
|
||||
raw = value.decode() if isinstance(value, bytes) else value
|
||||
user_id_str, webdav_id_str = raw.split(":", 1)
|
||||
return UUID(user_id_str), int(webdav_id_str)
|
||||
else:
|
||||
raw = cls._memory_cache.get(key)
|
||||
if raw is not None:
|
||||
user_id_str, webdav_id_str = raw.split(":", 1)
|
||||
return UUID(user_id_str), int(webdav_id_str)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def set(
|
||||
cls,
|
||||
email: str,
|
||||
account_name: str,
|
||||
password: str,
|
||||
user_id: UUID,
|
||||
webdav_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
写入认证结果到缓存。
|
||||
|
||||
:param email: 用户邮箱
|
||||
:param account_name: WebDAV 账户名
|
||||
:param password: 用户提供的明文密码
|
||||
:param user_id: 用户UUID
|
||||
:param webdav_id: WebDAV 账户ID
|
||||
"""
|
||||
key = cls._build_key(email, account_name, password)
|
||||
value = f"{user_id}:{webdav_id}"
|
||||
|
||||
client = RedisManager.get_client()
|
||||
if client is not None:
|
||||
await client.set(key, value, ex=_AUTH_TTL)
|
||||
else:
|
||||
cls._memory_cache[key] = value
|
||||
|
||||
@classmethod
|
||||
async def invalidate_account(cls, user_id: UUID, account_name: str) -> None:
|
||||
"""
|
||||
失效指定账户的所有缓存。
|
||||
|
||||
由于缓存键包含 password hash,无法精确删除,
|
||||
Redis 端使用 pattern scan 删除,内存端清空全部。
|
||||
|
||||
:param user_id: 用户UUID
|
||||
:param account_name: WebDAV 账户名
|
||||
"""
|
||||
client = RedisManager.get_client()
|
||||
if client is not None:
|
||||
pattern = f"webdav_auth:*/{account_name}:*"
|
||||
cursor: int = 0
|
||||
while True:
|
||||
cursor, keys = await client.scan(cursor, match=pattern, count=100)
|
||||
if keys:
|
||||
await client.delete(*keys)
|
||||
if cursor == 0:
|
||||
break
|
||||
else:
|
||||
# 内存缓存无法按 pattern 删除,清除所有含该账户名的条目
|
||||
keys_to_delete = [
|
||||
k for k in cls._memory_cache
|
||||
if f"/{account_name}:" in k
|
||||
]
|
||||
for k in keys_to_delete:
|
||||
cls._memory_cache.pop(k, None)
|
||||
|
||||
l.debug(f"已清除 WebDAV 认证缓存: user={user_id}, account={account_name}")
|
||||
@@ -75,7 +75,9 @@ from .object import (
|
||||
ObjectBase,
|
||||
ObjectCopyRequest,
|
||||
ObjectDeleteRequest,
|
||||
ObjectFileFinalize,
|
||||
ObjectMoveRequest,
|
||||
ObjectMoveUpdate,
|
||||
ObjectPropertyDetailResponse,
|
||||
ObjectPropertyResponse,
|
||||
ObjectRenameRequest,
|
||||
@@ -115,7 +117,10 @@ from .source_link import SourceLink
|
||||
from .storage_pack import StoragePack
|
||||
from .tag import Tag, TagType
|
||||
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
|
||||
from .webdav import WebDAV
|
||||
from .webdav import (
|
||||
WebDAV, WebDAVBase,
|
||||
WebDAVCreateRequest, WebDAVUpdateRequest, WebDAVAccountResponse,
|
||||
)
|
||||
from .file_app import (
|
||||
FileApp, FileAppType, FileAppExtension, FileAppGroupLink, UserFileAppDefault,
|
||||
# DTO
|
||||
|
||||
@@ -78,6 +78,26 @@ class ObjectBase(SQLModelBase):
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class ObjectFileFinalize(SQLModelBase):
|
||||
"""文件上传完成后更新 Object 的 DTO"""
|
||||
|
||||
size: int
|
||||
"""文件大小(字节)"""
|
||||
|
||||
physical_file_id: UUID
|
||||
"""关联的物理文件UUID"""
|
||||
|
||||
|
||||
class ObjectMoveUpdate(SQLModelBase):
|
||||
"""移动/重命名 Object 的 DTO"""
|
||||
|
||||
parent_id: UUID
|
||||
"""新的父目录UUID"""
|
||||
|
||||
name: str
|
||||
"""新名称"""
|
||||
|
||||
|
||||
class DirectoryCreateRequest(SQLModelBase):
|
||||
"""创建目录请求 DTO"""
|
||||
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
"""
|
||||
WebDAV 账户模型
|
||||
|
||||
管理用户的 WebDAV 连接账户,每个账户对应一个挂载根路径。
|
||||
通过 HTTP Basic Auth 认证访问 DAV 协议端点。
|
||||
"""
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
@@ -9,24 +14,104 @@ from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
class WebDAV(SQLModelBase, TableBaseMixin):
|
||||
"""WebDAV账户模型"""
|
||||
|
||||
# ==================== Base 模型 ====================
|
||||
|
||||
class WebDAVBase(SQLModelBase):
|
||||
"""WebDAV 账户基础字段"""
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
"""账户名称(同一用户下唯一)"""
|
||||
|
||||
root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"})
|
||||
"""挂载根目录路径"""
|
||||
|
||||
readonly: bool = Field(default=False, sa_column_kwargs={"server_default": "false"})
|
||||
"""是否只读"""
|
||||
|
||||
use_proxy: bool = Field(default=False, sa_column_kwargs={"server_default": "false"})
|
||||
"""是否使用代理下载"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class WebDAV(WebDAVBase, TableBaseMixin):
|
||||
"""WebDAV 账户模型"""
|
||||
|
||||
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),)
|
||||
|
||||
name: str = Field(max_length=255, description="WebDAV账户名")
|
||||
password: str = Field(max_length=255, description="WebDAV密码")
|
||||
root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"}, description="根目录路径")
|
||||
readonly: bool = Field(default=False, description="是否只读")
|
||||
use_proxy: bool = Field(default=False, description="是否使用代理下载")
|
||||
|
||||
password: str = Field(max_length=255)
|
||||
"""密码(Argon2 哈希)"""
|
||||
|
||||
# 外键
|
||||
user_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
"""所属用户UUID"""
|
||||
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="webdavs")
|
||||
user: "User" = Relationship(back_populates="webdavs")
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class WebDAVCreateRequest(SQLModelBase):
|
||||
"""创建 WebDAV 账户请求"""
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
"""账户名称"""
|
||||
|
||||
password: str = Field(min_length=1, max_length=255)
|
||||
"""账户密码(明文,服务端哈希后存储)"""
|
||||
|
||||
root: str = "/"
|
||||
"""挂载根目录路径"""
|
||||
|
||||
readonly: bool = False
|
||||
"""是否只读"""
|
||||
|
||||
use_proxy: bool = False
|
||||
"""是否使用代理下载"""
|
||||
|
||||
|
||||
class WebDAVUpdateRequest(SQLModelBase):
|
||||
"""更新 WebDAV 账户请求"""
|
||||
|
||||
password: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
"""新密码(为 None 时不修改)"""
|
||||
|
||||
root: str | None = None
|
||||
"""新挂载根目录路径(为 None 时不修改)"""
|
||||
|
||||
readonly: bool | None = None
|
||||
"""是否只读(为 None 时不修改)"""
|
||||
|
||||
use_proxy: bool | None = None
|
||||
"""是否使用代理下载(为 None 时不修改)"""
|
||||
|
||||
|
||||
class WebDAVAccountResponse(SQLModelBase):
|
||||
"""WebDAV 账户响应"""
|
||||
|
||||
id: int
|
||||
"""账户ID"""
|
||||
|
||||
name: str
|
||||
"""账户名称"""
|
||||
|
||||
root: str
|
||||
"""挂载根目录路径"""
|
||||
|
||||
readonly: bool
|
||||
"""是否只读"""
|
||||
|
||||
use_proxy: bool
|
||||
"""是否使用代理下载"""
|
||||
|
||||
created_at: str
|
||||
"""创建时间"""
|
||||
|
||||
updated_at: str
|
||||
"""更新时间"""
|
||||
|
||||
591
tests/integration/api/test_webdav.py
Normal file
591
tests/integration/api/test_webdav.py
Normal file
@@ -0,0 +1,591 @@
|
||||
"""
|
||||
WebDAV 账户管理端点集成测试
|
||||
"""
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, User
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import Password
|
||||
from utils.JWT import create_access_token
|
||||
|
||||
API_PREFIX = "/api/v1/webdav"
|
||||
|
||||
|
||||
# ==================== Fixtures ====================
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def no_webdav_headers(initialized_db: AsyncSession) -> dict[str, str]:
|
||||
"""创建一个 WebDAV 被禁用的用户,返回其认证头"""
|
||||
group = Group(
|
||||
id=uuid4(),
|
||||
name="无WebDAV用户组",
|
||||
max_storage=1024 * 1024 * 1024,
|
||||
share_enabled=True,
|
||||
web_dav_enabled=False,
|
||||
admin=False,
|
||||
speed_limit=0,
|
||||
)
|
||||
initialized_db.add(group)
|
||||
await initialized_db.commit()
|
||||
await initialized_db.refresh(group)
|
||||
|
||||
group_options = GroupOptions(
|
||||
group_id=group.id,
|
||||
share_download=True,
|
||||
share_free=False,
|
||||
relocate=False,
|
||||
source_batch=0,
|
||||
select_node=False,
|
||||
advance_delete=False,
|
||||
)
|
||||
initialized_db.add(group_options)
|
||||
await initialized_db.commit()
|
||||
await initialized_db.refresh(group_options)
|
||||
|
||||
user = User(
|
||||
id=uuid4(),
|
||||
email="nowebdav@test.local",
|
||||
nickname="无WebDAV用户",
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=0,
|
||||
group_id=group.id,
|
||||
avatar="default",
|
||||
)
|
||||
initialized_db.add(user)
|
||||
await initialized_db.commit()
|
||||
await initialized_db.refresh(user)
|
||||
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="nowebdav@test.local",
|
||||
credential=Password.hash("nowebdav123"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
initialized_db.add(identity)
|
||||
|
||||
from sqlmodels import Policy
|
||||
policy = await Policy.get(initialized_db, Policy.name == "本地存储")
|
||||
|
||||
root = Object(
|
||||
id=uuid4(),
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=user.id,
|
||||
parent_id=None,
|
||||
policy_id=policy.id,
|
||||
size=0,
|
||||
)
|
||||
initialized_db.add(root)
|
||||
await initialized_db.commit()
|
||||
|
||||
group.options = group_options
|
||||
group_claims = GroupClaims.from_group(group)
|
||||
result = create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
)
|
||||
return {"Authorization": f"Bearer {result.access_token}"}
|
||||
|
||||
|
||||
# ==================== 认证测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_accounts_requires_auth(async_client: AsyncClient):
|
||||
"""测试获取账户列表需要认证"""
|
||||
response = await async_client.get(f"{API_PREFIX}/accounts")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_requires_auth(async_client: AsyncClient):
|
||||
"""测试创建账户需要认证"""
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
json={"name": "test", "password": "testpass"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_requires_auth(async_client: AsyncClient):
|
||||
"""测试更新账户需要认证"""
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/1",
|
||||
json={"readonly": True},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_account_requires_auth(async_client: AsyncClient):
|
||||
"""测试删除账户需要认证"""
|
||||
response = await async_client.delete(f"{API_PREFIX}/accounts/1")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ==================== WebDAV 禁用测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_accounts_webdav_disabled(
|
||||
async_client: AsyncClient,
|
||||
no_webdav_headers: dict[str, str],
|
||||
):
|
||||
"""测试 WebDAV 被禁用时返回 403"""
|
||||
response = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=no_webdav_headers,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_webdav_disabled(
|
||||
async_client: AsyncClient,
|
||||
no_webdav_headers: dict[str, str],
|
||||
):
|
||||
"""测试 WebDAV 被禁用时创建账户返回 403"""
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=no_webdav_headers,
|
||||
json={"name": "test", "password": "testpass"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ==================== 获取账户列表测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_accounts_empty(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试初始状态账户列表为空"""
|
||||
response = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
# ==================== 创建账户测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_success(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试成功创建 WebDAV 账户"""
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "my-nas", "password": "secretpass"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
data = response.json()
|
||||
assert data["name"] == "my-nas"
|
||||
assert data["root"] == "/"
|
||||
assert data["readonly"] is False
|
||||
assert data["use_proxy"] is False
|
||||
assert "id" in data
|
||||
assert "created_at" in data
|
||||
assert "updated_at" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_with_options(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试创建带选项的 WebDAV 账户"""
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name": "readonly-nas",
|
||||
"password": "secretpass",
|
||||
"readonly": True,
|
||||
"use_proxy": True,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
data = response.json()
|
||||
assert data["name"] == "readonly-nas"
|
||||
assert data["readonly"] is True
|
||||
assert data["use_proxy"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_duplicate_name(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试重名账户返回 409"""
|
||||
# 先创建一个
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "dup-test", "password": "pass1"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# 再创建同名的
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "dup-test", "password": "pass2"},
|
||||
)
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_invalid_root(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试无效根目录路径返回 400"""
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name": "bad-root",
|
||||
"password": "secretpass",
|
||||
"root": "/nonexistent/path",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_with_valid_subdir(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
test_directory_structure: dict[str, UUID],
|
||||
):
|
||||
"""测试使用有效的子目录作为根路径"""
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name": "docs-only",
|
||||
"password": "secretpass",
|
||||
"root": "/docs",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["root"] == "/docs"
|
||||
|
||||
|
||||
# ==================== 列表包含已创建账户测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_accounts_after_create(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试创建后列表中包含该账户"""
|
||||
# 创建
|
||||
await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "list-test", "password": "pass"},
|
||||
)
|
||||
|
||||
# 列表
|
||||
response = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
accounts = response.json()
|
||||
assert len(accounts) == 1
|
||||
assert accounts[0]["name"] == "list-test"
|
||||
|
||||
|
||||
# ==================== 更新账户测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_success(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试成功更新 WebDAV 账户"""
|
||||
# 创建
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "update-test", "password": "oldpass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 更新
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
json={"readonly": True},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert data["readonly"] is True
|
||||
assert data["name"] == "update-test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_password(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试更新密码"""
|
||||
# 创建
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "pwd-test", "password": "oldpass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 更新密码
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
json={"password": "newpass123"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_root(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
test_directory_structure: dict[str, UUID],
|
||||
):
|
||||
"""测试更新根目录路径"""
|
||||
# 创建
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "root-update", "password": "pass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 更新 root 到有效子目录
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
json={"root": "/docs"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["root"] == "/docs"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_invalid_root(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试更新为无效根目录返回 400"""
|
||||
# 创建
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "bad-root-update", "password": "pass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 更新到无效路径
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
json={"root": "/nonexistent"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_not_found(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试更新不存在的账户返回 404"""
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/99999",
|
||||
headers=auth_headers,
|
||||
json={"readonly": True},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_other_user_account(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
admin_headers: dict[str, str],
|
||||
):
|
||||
"""测试更新其他用户的账户返回 404"""
|
||||
# 管理员创建账户
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=admin_headers,
|
||||
json={"name": "admin-account", "password": "pass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 普通用户尝试更新
|
||||
response = await async_client.patch(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
json={"readonly": True},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ==================== 删除账户测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_account_success(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试成功删除 WebDAV 账户"""
|
||||
# 创建
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "delete-test", "password": "pass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 删除
|
||||
response = await async_client.delete(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
# 确认列表中已不存在
|
||||
list_resp = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert list_resp.status_code == 200
|
||||
names = [a["name"] for a in list_resp.json()]
|
||||
assert "delete-test" not in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_account_not_found(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试删除不存在的账户返回 404"""
|
||||
response = await async_client.delete(
|
||||
f"{API_PREFIX}/accounts/99999",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_other_user_account(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
admin_headers: dict[str, str],
|
||||
):
|
||||
"""测试删除其他用户的账户返回 404"""
|
||||
# 管理员创建账户
|
||||
create_resp = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=admin_headers,
|
||||
json={"name": "admin-del-test", "password": "pass"},
|
||||
)
|
||||
account_id = create_resp.json()["id"]
|
||||
|
||||
# 普通用户尝试删除
|
||||
response = await async_client.delete(
|
||||
f"{API_PREFIX}/accounts/{account_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ==================== 多账户测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_accounts(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
):
|
||||
"""测试同一用户可以创建多个账户"""
|
||||
for name in ["account-1", "account-2", "account-3"]:
|
||||
response = await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": name, "password": "pass"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# 列表应有3个
|
||||
response = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 3
|
||||
|
||||
|
||||
# ==================== 用户隔离测试 ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accounts_user_isolation(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
admin_headers: dict[str, str],
|
||||
):
|
||||
"""测试不同用户的账户相互隔离"""
|
||||
# 普通用户创建
|
||||
await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
json={"name": "user-account", "password": "pass"},
|
||||
)
|
||||
|
||||
# 管理员创建
|
||||
await async_client.post(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=admin_headers,
|
||||
json={"name": "admin-account", "password": "pass"},
|
||||
)
|
||||
|
||||
# 普通用户只看到自己的
|
||||
response = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
accounts = response.json()
|
||||
assert len(accounts) == 1
|
||||
assert accounts[0]["name"] == "user-account"
|
||||
|
||||
# 管理员只看到自己的
|
||||
response = await async_client.get(
|
||||
f"{API_PREFIX}/accounts",
|
||||
headers=admin_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
accounts = response.json()
|
||||
assert len(accounts) == 1
|
||||
assert accounts[0]["name"] == "admin-account"
|
||||
46
uv.lock
generated
46
uv.lock
generated
@@ -6,6 +6,15 @@ resolution-markers = [
|
||||
"python_full_version < '3.14'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "a2wsgi"
|
||||
version = "1.10.10"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9a/cb/822c56fbea97e9eee201a2e434a80437f6750ebcb1ed307ee3a0a7505b14/a2wsgi-1.10.10.tar.gz", hash = "sha256:a5bcffb52081ba39df0d5e9a884fc6f819d92e3a42389343ba77cbf809fe1f45", size = 18799, upload-time = "2025-06-18T09:00:10.843Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/02/d5/349aba3dc421e73cbd4958c0ce0a4f1aa3a738bc0d7de75d2f40ed43a535/a2wsgi-1.10.10-py3-none-any.whl", hash = "sha256:d2b21379479718539dc15fce53b876251a0efe7615352dfe49f6ad1bc507848d", size = 17389, upload-time = "2025-06-18T09:00:09.676Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiofiles"
|
||||
version = "25.1.0"
|
||||
@@ -500,11 +509,21 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/fa/d3c15189f7c52aaefbaea76fb012119b04b9013f4bf446cb4eb4c26c4e6b/cython-3.2.4-py3-none-any.whl", hash = "sha256:732fc93bc33ae4b14f6afaca663b916c2fdd5dcbfad7114e17fb2434eeaea45c", size = 1257078, upload-time = "2026-01-04T14:14:12.373Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "defusedxml"
|
||||
version = "0.7.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0f/d5/c66da9b79e5bdb124974bfe172b4daf3c984ebd9c2a06e2b8a4dc7331c72/defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69", size = 75520, upload-time = "2021-03-08T10:59:26.269Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "disknext-server"
|
||||
version = "0.0.1"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "a2wsgi" },
|
||||
{ name = "aiofiles" },
|
||||
{ name = "aiohttp" },
|
||||
{ name = "aiosqlite" },
|
||||
@@ -533,6 +552,7 @@ dependencies = [
|
||||
{ name = "uvicorn" },
|
||||
{ name = "webauthn" },
|
||||
{ name = "whatthepatch" },
|
||||
{ name = "wsgidav" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
@@ -543,6 +563,7 @@ build = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "a2wsgi", specifier = ">=1.10.0" },
|
||||
{ name = "aiofiles", specifier = ">=25.1.0" },
|
||||
{ name = "aiohttp", specifier = ">=3.13.2" },
|
||||
{ name = "aiosqlite", specifier = "==0.22.1" },
|
||||
@@ -573,6 +594,7 @@ requires-dist = [
|
||||
{ name = "uvicorn", specifier = ">=0.38.0" },
|
||||
{ name = "webauthn", specifier = ">=2.7.0" },
|
||||
{ name = "whatthepatch", specifier = ">=1.0.6" },
|
||||
{ name = "wsgidav", specifier = ">=4.3.0" },
|
||||
]
|
||||
provides-extras = ["build"]
|
||||
|
||||
@@ -975,6 +997,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "json5"
|
||||
version = "0.13.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/77/e8/a3f261a66e4663f22700bc8a17c08cb83e91fbf086726e7a228398968981/json5-0.13.0.tar.gz", hash = "sha256:b1edf8d487721c0bf64d83c28e91280781f6e21f4a797d3261c7c828d4c165bf", size = 52441, upload-time = "2026-01-01T19:42:14.99Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d7/9e/038522f50ceb7e74f1f991bf1b699f24b0c2bbe7c390dd36ad69f4582258/json5-0.13.0-py3-none-any.whl", hash = "sha256:9a08e1dd65f6a4d4c6fa82d216cf2477349ec2346a38fd70cc11d2557499fbcc", size = 36163, upload-time = "2026-01-01T19:42:13.962Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "loguru"
|
||||
version = "0.7.3"
|
||||
@@ -2049,6 +2080,21 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e1/07/c6fe3ad3e685340704d314d765b7912993bcb8dc198f0e7a89382d37974b/win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390", size = 4083, upload-time = "2024-12-07T15:28:26.465Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wsgidav"
|
||||
version = "4.3.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "defusedxml" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "json5" },
|
||||
{ name = "pyyaml" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a1/f4/9c89e3e41dc7762cbb005d1baf23381718c7b13607236eacda23b855a288/wsgidav-4.3.3.tar.gz", hash = "sha256:5f0ad71bea72def3018b6ba52da3bcb83f61e0873c27225344582805d6e52b9e", size = 168118, upload-time = "2024-05-04T18:28:01.199Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/8e/04fb92513f4deab0f9bf4bdeeebc74f12d4de75ff00ad213c69983fc6563/WsgiDAV-4.3.3-py3-none-any.whl", hash = "sha256:8d96b0f05ad7f280572e99d1c605962a853d715f8e934298555d0c47ef275e88", size = 164954, upload-time = "2024-05-04T18:27:57.718Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yarl"
|
||||
version = "1.22.0"
|
||||
|
||||
Reference in New Issue
Block a user