From 40b6a31c98a16f9de461abd1ea2c8c23818b75d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=8E=E5=B0=8F=E4=B8=98?= Date: Tue, 17 Feb 2026 15:19:29 +0800 Subject: [PATCH] feat: implement WebDAV protocol support with WsgiDAV + account management API Add complete WebDAV support: management REST API (CRUD accounts at /api/v1/webdav/accounts) and DAV protocol endpoint (/dav) using WsgiDAV + a2wsgi bridge for client access via HTTP Basic Auth. Includes Redis+TTLCache auth caching and integration tests (24 cases). Co-Authored-By: Claude Opus 4.6 --- main.py | 8 + pyproject.toml | 2 + routers/api/v1/webdav/__init__.py | 259 ++++++++---- routers/dav/README.md | 1 - routers/dav/__init__.py | 35 ++ routers/dav/domain_controller.py | 148 +++++++ routers/dav/provider.py | 594 +++++++++++++++++++++++++++ service/redis/webdav_auth_cache.py | 128 ++++++ sqlmodels/__init__.py | 7 +- sqlmodels/object.py | 20 + sqlmodels/webdav.py | 107 ++++- tests/integration/api/test_webdav.py | 591 ++++++++++++++++++++++++++ uv.lock | 46 +++ 13 files changed, 1852 insertions(+), 94 deletions(-) delete mode 100644 routers/dav/README.md create mode 100644 routers/dav/__init__.py create mode 100644 routers/dav/domain_controller.py create mode 100644 routers/dav/provider.py create mode 100644 service/redis/webdav_auth_cache.py create mode 100644 tests/integration/api/test_webdav.py diff --git a/main.py b/main.py index ef9941c..83a6a58 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,8 @@ from fastapi import FastAPI, Request from loguru import logger as l from routers import router +from routers.dav import dav_app +from routers.dav.provider import EventLoopRef from service.redis import RedisManager from sqlmodels.database_connection import DatabaseManager from sqlmodels.migration import migration @@ -40,6 +42,9 @@ async def _init_db() -> None: """初始化数据库连接引擎""" await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug) +# 捕获事件循环引用(供 WSGI 线程桥接使用) +lifespan.add_startup(EventLoopRef.capture) + # 添加初始化数据库启动项 lifespan.add_startup(_init_db) lifespan.add_startup(migration) @@ -88,6 +93,9 @@ async def handle_unexpected_exceptions( # 挂载路由 app.include_router(router) +# 挂载 WebDAV 协议端点(优先于 SPA catch-all) +app.mount("/dav", dav_app) + # 挂载前端静态文件(仅当 statics/ 目录存在时,即 Docker 部署环境) if STATICS_DIR.is_dir(): from starlette.staticfiles import StaticFiles diff --git a/pyproject.toml b/pyproject.toml index 1a26fa4..c35dd86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,8 @@ dependencies = [ "uvicorn>=0.38.0", "webauthn>=2.7.0", "whatthepatch>=1.0.6", + "wsgidav>=4.3.0", + "a2wsgi>=1.10.0", ] [project.optional-dependencies] diff --git a/routers/api/v1/webdav/__init__.py b/routers/api/v1/webdav/__init__.py index 270e2ce..5f9cbe5 100644 --- a/routers/api/v1/webdav/__init__.py +++ b/routers/api/v1/webdav/__init__.py @@ -1,110 +1,207 @@ +from typing import Annotated +from uuid import UUID + from fastapi import APIRouter, Depends +from loguru import logger as l from middleware.auth import auth_required -from sqlmodels import ResponseBase +from middleware.dependencies import SessionDep +from sqlmodels import ( + Object, + User, + WebDAV, + WebDAVAccountResponse, + WebDAVCreateRequest, + WebDAVUpdateRequest, +) +from service.redis.webdav_auth_cache import WebDAVAuthCache from utils import http_exceptions +from utils.password.pwd import Password -# WebDAV 管理路由 webdav_router = APIRouter( prefix='/webdav', tags=["webdav"], ) + +def _check_webdav_enabled(user: User) -> None: + """检查用户组是否启用了 WebDAV 功能""" + if not user.group.web_dav_enabled: + http_exceptions.raise_forbidden("WebDAV 功能未启用") + + +def _to_response(account: WebDAV) -> WebDAVAccountResponse: + """将 WebDAV 数据库模型转换为响应 DTO""" + return WebDAVAccountResponse( + id=account.id, + name=account.name, + root=account.root, + readonly=account.readonly, + use_proxy=account.use_proxy, + created_at=str(account.created_at), + updated_at=str(account.updated_at), + ) + + @webdav_router.get( path='/accounts', - summary='获取账号信息', - description='Get account information for WebDAV.', - dependencies=[Depends(auth_required)], + summary='获取账号列表', ) -def router_webdav_accounts() -> ResponseBase: +async def list_accounts( + session: SessionDep, + user: Annotated[User, Depends(auth_required)], +) -> list[WebDAVAccountResponse]: """ - Get account information for WebDAV. - - Returns: - ResponseBase: A model containing the response data for the account information. + 列出当前用户所有 WebDAV 账户 + + 认证:JWT Bearer Token """ - http_exceptions.raise_not_implemented() + _check_webdav_enabled(user) + user_id: UUID = user.id + + accounts: list[WebDAV] = await WebDAV.get( + session, + WebDAV.user_id == user_id, + fetch_mode="all", + ) + return [_to_response(a) for a in accounts] + @webdav_router.post( path='/accounts', - summary='新建账号', - description='Create a new WebDAV account.', - dependencies=[Depends(auth_required)], + summary='创建账号', + status_code=201, ) -def router_webdav_create_account() -> ResponseBase: +async def create_account( + session: SessionDep, + user: Annotated[User, Depends(auth_required)], + request: WebDAVCreateRequest, +) -> WebDAVAccountResponse: """ - Create a new WebDAV account. - - Returns: - ResponseBase: A model containing the response data for the created account. - """ - http_exceptions.raise_not_implemented() + 创建 WebDAV 账户 -@webdav_router.delete( - path='/accounts/{id}', - summary='删除账号', - description='Delete a WebDAV account by its ID.', - dependencies=[Depends(auth_required)], -) -def router_webdav_delete_account(id: str) -> ResponseBase: - """ - Delete a WebDAV account by its ID. - - Args: - id (str): The ID of the account to be deleted. - - Returns: - ResponseBase: A model containing the response data for the deletion operation. - """ - http_exceptions.raise_not_implemented() + 认证:JWT Bearer Token -@webdav_router.post( - path='/mount', - summary='新建目录挂载', - description='Create a new WebDAV mount point.', - dependencies=[Depends(auth_required)], -) -def router_webdav_create_mount() -> ResponseBase: + 错误处理: + - 403: WebDAV 功能未启用 + - 400: 根目录路径不存在或不是目录 + - 409: 账户名已存在 """ - Create a new WebDAV mount point. - - Returns: - ResponseBase: A model containing the response data for the created mount point. - """ - http_exceptions.raise_not_implemented() + _check_webdav_enabled(user) + user_id: UUID = user.id + + # 验证账户名唯一 + existing = await WebDAV.get( + session, + (WebDAV.name == request.name) & (WebDAV.user_id == user_id), + ) + if existing: + http_exceptions.raise_conflict("账户名已存在") + + # 验证 root 路径存在且为目录 + root_obj = await Object.get_by_path(session, user_id, request.root) + if not root_obj or not root_obj.is_folder: + http_exceptions.raise_bad_request("根目录路径不存在或不是目录") + + # 创建账户 + account = WebDAV( + name=request.name, + password=Password.hash(request.password), + root=request.root, + readonly=request.readonly, + use_proxy=request.use_proxy, + user_id=user_id, + ) + account = await account.save(session) + + l.info(f"用户 {user_id} 创建 WebDAV 账户: {account.name}") + return _to_response(account) -@webdav_router.delete( - path='/mount/{id}', - summary='删除目录挂载', - description='Delete a WebDAV mount point by its ID.', - dependencies=[Depends(auth_required)], -) -def router_webdav_delete_mount(id: str) -> ResponseBase: - """ - Delete a WebDAV mount point by its ID. - - Args: - id (str): The ID of the mount point to be deleted. - - Returns: - ResponseBase: A model containing the response data for the deletion operation. - """ - http_exceptions.raise_not_implemented() @webdav_router.patch( - path='accounts/{id}', - summary='更新账号信息', - description='Update WebDAV account information by ID.', - dependencies=[Depends(auth_required)], + path='/accounts/{account_id}', + summary='更新账号', ) -def router_webdav_update_account(id: str) -> ResponseBase: +async def update_account( + session: SessionDep, + user: Annotated[User, Depends(auth_required)], + account_id: int, + request: WebDAVUpdateRequest, +) -> WebDAVAccountResponse: """ - Update WebDAV account information by ID. - - Args: - id (str): The ID of the account to be updated. - - Returns: - ResponseBase: A model containing the response data for the updated account. + 更新 WebDAV 账户 + + 认证:JWT Bearer Token + + 错误处理: + - 403: WebDAV 功能未启用 + - 404: 账户不存在 + - 400: 根目录路径不存在或不是目录 """ - http_exceptions.raise_not_implemented() \ No newline at end of file + _check_webdav_enabled(user) + user_id: UUID = user.id + + account = await WebDAV.get( + session, + (WebDAV.id == account_id) & (WebDAV.user_id == user_id), + ) + if not account: + http_exceptions.raise_not_found("WebDAV 账户不存在") + + # 验证 root 路径 + if request.root is not None: + root_obj = await Object.get_by_path(session, user_id, request.root) + if not root_obj or not root_obj.is_folder: + http_exceptions.raise_bad_request("根目录路径不存在或不是目录") + + # 密码哈希后原地替换,update() 会通过 model_dump(exclude_unset=True) 只取已设置字段 + is_password_changed = request.password is not None + if is_password_changed: + request.password = Password.hash(request.password) + + account = await account.update(session, request) + + # 密码变更时清除认证缓存 + if is_password_changed: + await WebDAVAuthCache.invalidate_account(user_id, account.name) + + l.info(f"用户 {user_id} 更新 WebDAV 账户: {account.name}") + return _to_response(account) + + +@webdav_router.delete( + path='/accounts/{account_id}', + summary='删除账号', + status_code=204, +) +async def delete_account( + session: SessionDep, + user: Annotated[User, Depends(auth_required)], + account_id: int, +) -> None: + """ + 删除 WebDAV 账户 + + 认证:JWT Bearer Token + + 错误处理: + - 403: WebDAV 功能未启用 + - 404: 账户不存在 + """ + _check_webdav_enabled(user) + user_id: UUID = user.id + + account = await WebDAV.get( + session, + (WebDAV.id == account_id) & (WebDAV.user_id == user_id), + ) + if not account: + http_exceptions.raise_not_found("WebDAV 账户不存在") + + account_name = account.name + await WebDAV.delete(session, account) + + # 清除认证缓存 + await WebDAVAuthCache.invalidate_account(user_id, account_name) + + l.info(f"用户 {user_id} 删除 WebDAV 账户: {account_name}") diff --git a/routers/dav/README.md b/routers/dav/README.md deleted file mode 100644 index 33142d9..0000000 --- a/routers/dav/README.md +++ /dev/null @@ -1 +0,0 @@ -# WebDAV 操作路由 \ No newline at end of file diff --git a/routers/dav/__init__.py b/routers/dav/__init__.py new file mode 100644 index 0000000..a852638 --- /dev/null +++ b/routers/dav/__init__.py @@ -0,0 +1,35 @@ +""" +WebDAV 协议入口 + +使用 WsgiDAV + a2wsgi 提供 WebDAV 协议支持。 +WsgiDAV 在 a2wsgi 的线程池中运行,不阻塞 FastAPI 事件循环。 +""" +from a2wsgi import WSGIMiddleware +from wsgidav.wsgidav_app import WsgiDAVApp + +from .domain_controller import DiskNextDomainController +from .provider import DiskNextDAVProvider + +_wsgidav_config: dict[str, object] = { + "provider_mapping": { + "/": DiskNextDAVProvider(), + }, + "http_authenticator": { + "domain_controller": DiskNextDomainController, + "accept_basic": True, + "accept_digest": False, + "default_to_digest": False, + }, + "verbose": 1, + # 使用 WsgiDAV 内置的内存锁管理器 + "lock_storage": True, + # 禁用 WsgiDAV 的目录浏览器(纯 DAV 协议) + "dir_browser": { + "enable": False, + }, +} + +_wsgidav_app = WsgiDAVApp(_wsgidav_config) + +dav_app = WSGIMiddleware(_wsgidav_app, workers=10) +"""ASGI 应用,挂载到 /dav 路径""" diff --git a/routers/dav/domain_controller.py b/routers/dav/domain_controller.py new file mode 100644 index 0000000..2afdbab --- /dev/null +++ b/routers/dav/domain_controller.py @@ -0,0 +1,148 @@ +""" +WebDAV 认证控制器 + +实现 WsgiDAV 的 BaseDomainController 接口,使用 HTTP Basic Auth +通过 DiskNext 的 WebDAV 账户模型进行认证。 + +用户名格式: {email}/{webdav_account_name} +""" +import asyncio +from uuid import UUID + +from loguru import logger as l +from wsgidav.dc.base_dc import BaseDomainController + +from routers.dav.provider import EventLoopRef, _get_session +from service.redis.webdav_auth_cache import WebDAVAuthCache +from sqlmodels.user import User, UserStatus +from sqlmodels.webdav import WebDAV +from utils.password.pwd import Password, PasswordStatus + + +async def _authenticate( + email: str, + account_name: str, + password: str, +) -> tuple[UUID, int] | None: + """ + 异步认证 WebDAV 用户。 + + :param email: 用户邮箱 + :param account_name: WebDAV 账户名 + :param password: 明文密码 + :return: (user_id, webdav_id) 或 None + """ + # 1. 查缓存 + cached = await WebDAVAuthCache.get(email, account_name, password) + if cached is not None: + return cached + + # 2. 缓存未命中,查库验证 + async with _get_session() as session: + user = await User.get(session, User.email == email, load=User.group) + if not user: + return None + if user.status != UserStatus.ACTIVE: + return None + if not user.group.web_dav_enabled: + return None + + account = await WebDAV.get( + session, + (WebDAV.name == account_name) & (WebDAV.user_id == user.id), + ) + if not account: + return None + + status = Password.verify(account.password, password) + if status == PasswordStatus.INVALID: + return None + + user_id: UUID = user.id + webdav_id: int = account.id + + # 3. 写入缓存 + await WebDAVAuthCache.set(email, account_name, password, user_id, webdav_id) + + return user_id, webdav_id + + +class DiskNextDomainController(BaseDomainController): + """ + DiskNext WebDAV 认证控制器 + + 用户名格式: {email}/{webdav_account_name} + 密码: WebDAV 账户密码(创建账户时设置) + """ + + def __init__(self, wsgidav_app: object, config: dict[str, object]) -> None: + super().__init__(wsgidav_app, config) + + def get_domain_realm(self, path_info: str, environ: dict[str, object]) -> str: + """返回 realm 名称""" + return "DiskNext WebDAV" + + def require_authentication(self, realm: str, environ: dict[str, object]) -> bool: + """所有请求都需要认证""" + return True + + def is_share_anonymous(self, path_info: str) -> bool: + """不支持匿名访问""" + return False + + def supports_http_digest_auth(self) -> bool: + """不支持 Digest 认证(密码存的是 Argon2 哈希,无法反推)""" + return False + + def basic_auth_user( + self, + realm: str, + user_name: str, + password: str, + environ: dict[str, object], + ) -> bool: + """ + HTTP Basic Auth 认证。 + + 用户名格式: {email}/{webdav_account_name} + 在 WSGI 线程中通过 anyio.from_thread.run 调用异步认证逻辑。 + """ + # 解析用户名 + if "/" not in user_name: + l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'") + return False + + email, account_name = user_name.split("/", 1) + if not email or not account_name: + l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'") + return False + + # 在 WSGI 线程中调用异步认证 + future = asyncio.run_coroutine_threadsafe( + _authenticate(email, account_name, password), + EventLoopRef.get(), + ) + result = future.result() + + if result is None: + l.debug(f"WebDAV 认证失败: {email}/{account_name}") + return False + + user_id, webdav_id = result + + # 将认证信息存入 environ,供 Provider 使用 + environ["disknext.user_id"] = user_id + environ["disknext.webdav_id"] = webdav_id + environ["disknext.email"] = email + environ["disknext.account_name"] = account_name + + return True + + def digest_auth_user( + self, + realm: str, + user_name: str, + environ: dict[str, object], + ) -> bool: + """不支持 Digest 认证""" + return False diff --git a/routers/dav/provider.py b/routers/dav/provider.py new file mode 100644 index 0000000..802db22 --- /dev/null +++ b/routers/dav/provider.py @@ -0,0 +1,594 @@ +""" +DiskNext WebDAV 存储 Provider + +将 WsgiDAV 的文件操作映射到 DiskNext 的 Object 模型。 +所有异步数据库/文件操作通过 asyncio.run_coroutine_threadsafe() 桥接。 +""" +import asyncio +import io +import mimetypes +from pathlib import Path +from typing import ClassVar +from uuid import UUID + +from loguru import logger as l +from wsgidav.dav_error import ( + DAVError, + HTTP_FORBIDDEN, + HTTP_INSUFFICIENT_STORAGE, + HTTP_NOT_FOUND, +) +from wsgidav.dav_provider import DAVCollection, DAVNonCollection, DAVProvider + +from service.storage import LocalStorageService, adjust_user_storage +from sqlmodels.database_connection import DatabaseManager +from sqlmodels.object import Object, ObjectType +from sqlmodels.physical_file import PhysicalFile +from sqlmodels.policy import Policy +from sqlmodels.user import User +from sqlmodels.webdav import WebDAV + + +class EventLoopRef: + """持有主线程事件循环引用,供 WSGI 线程使用""" + _loop: ClassVar[asyncio.AbstractEventLoop | None] = None + + @classmethod + async def capture(cls) -> None: + """在 async 上下文中调用,捕获当前事件循环""" + cls._loop = asyncio.get_running_loop() + + @classmethod + def get(cls) -> asyncio.AbstractEventLoop: + if cls._loop is None: + raise RuntimeError("事件循环尚未捕获,请先调用 EventLoopRef.capture()") + return cls._loop + + +def _run_async(coro): # type: ignore[no-untyped-def] + """在 WSGI 线程中通过 run_coroutine_threadsafe 运行协程""" + future = asyncio.run_coroutine_threadsafe(coro, EventLoopRef.get()) + return future.result() + + +def _get_session(): # type: ignore[no-untyped-def] + """获取数据库会话上下文管理器""" + return DatabaseManager._async_session_factory() + + +# ==================== 异步辅助函数 ==================== + +async def _get_webdav_account(webdav_id: int) -> WebDAV | None: + """获取 WebDAV 账户""" + async with _get_session() as session: + return await WebDAV.get(session, WebDAV.id == webdav_id) + + +async def _get_object_by_path(user_id: UUID, path: str) -> Object | None: + """根据路径获取对象""" + async with _get_session() as session: + return await Object.get_by_path(session, user_id, path) + + +async def _get_children(user_id: UUID, parent_id: UUID) -> list[Object]: + """获取目录子对象""" + async with _get_session() as session: + return await Object.get_children(session, user_id, parent_id) + + +async def _get_object_by_id(object_id: UUID) -> Object | None: + """根据ID获取对象""" + async with _get_session() as session: + return await Object.get(session, Object.id == object_id, load=Object.physical_file) + + +async def _get_user(user_id: UUID) -> User | None: + """获取用户(含 group 关系)""" + async with _get_session() as session: + return await User.get(session, User.id == user_id, load=User.group) + + +async def _get_policy(policy_id: UUID) -> Policy | None: + """获取存储策略""" + async with _get_session() as session: + return await Policy.get(session, Policy.id == policy_id) + + +async def _create_folder( + name: str, + parent_id: UUID, + owner_id: UUID, + policy_id: UUID, +) -> Object: + """创建目录对象""" + async with _get_session() as session: + obj = Object( + name=name, + type=ObjectType.FOLDER, + size=0, + parent_id=parent_id, + owner_id=owner_id, + policy_id=policy_id, + ) + obj = await obj.save(session) + return obj + + +async def _create_file( + name: str, + parent_id: UUID, + owner_id: UUID, + policy_id: UUID, +) -> Object: + """创建空文件对象""" + async with _get_session() as session: + obj = Object( + name=name, + type=ObjectType.FILE, + size=0, + parent_id=parent_id, + owner_id=owner_id, + policy_id=policy_id, + ) + obj = await obj.save(session) + return obj + + +async def _soft_delete_object(object_id: UUID) -> None: + """软删除对象(移入回收站)""" + from service.storage import soft_delete_objects + + async with _get_session() as session: + obj = await Object.get(session, Object.id == object_id) + if obj: + await soft_delete_objects(session, [obj]) + + +async def _finalize_upload( + object_id: UUID, + physical_path: str, + size: int, + owner_id: UUID, + policy_id: UUID, +) -> None: + """上传完成后更新对象元数据和物理文件记录""" + async with _get_session() as session: + # 获取存储路径(相对路径) + policy = await Policy.get(session, Policy.id == policy_id) + if not policy or not policy.server: + raise DAVError(HTTP_NOT_FOUND, "存储策略不存在") + + base_path = Path(policy.server).resolve() + full_path = Path(physical_path).resolve() + storage_path = str(full_path.relative_to(base_path)) + + # 创建 PhysicalFile 记录 + pf = PhysicalFile( + storage_path=storage_path, + size=size, + policy_id=policy_id, + reference_count=1, + ) + pf = await pf.save(session) + + # 更新 Object + obj = await Object.get(session, Object.id == object_id) + if obj: + obj.sqlmodel_update({'size': size, 'physical_file_id': pf.id}) + session.add(obj) + await session.commit() + + # 更新用户存储用量 + if size > 0: + await adjust_user_storage(session, owner_id, size) + + +async def _move_object( + object_id: UUID, + new_parent_id: UUID, + new_name: str, +) -> None: + """移动/重命名对象""" + async with _get_session() as session: + obj = await Object.get(session, Object.id == object_id) + if obj: + obj.sqlmodel_update({'parent_id': new_parent_id, 'name': new_name}) + session.add(obj) + await session.commit() + + +async def _copy_object_recursive( + src_id: UUID, + dst_parent_id: UUID, + dst_name: str, + owner_id: UUID, +) -> None: + """递归复制对象""" + from service.storage import copy_object_recursive + + async with _get_session() as session: + src = await Object.get(session, Object.id == src_id) + if not src: + return + await copy_object_recursive(session, src, dst_parent_id, owner_id, new_name=dst_name) + + +# ==================== 辅助工具 ==================== + +def _get_environ_info(environ: dict[str, object]) -> tuple[UUID, int]: + """从 environ 中提取认证信息""" + user_id: UUID = environ["disknext.user_id"] # type: ignore[assignment] + webdav_id: int = environ["disknext.webdav_id"] # type: ignore[assignment] + return user_id, webdav_id + + +def _resolve_dav_path(account_root: str, dav_path: str) -> str: + """ + 将 DAV 相对路径映射到 DiskNext 绝对路径。 + + :param account_root: 账户挂载根路径,如 "/" 或 "/docs" + :param dav_path: DAV 请求路径,如 "/" 或 "/photos/cat.jpg" + :return: DiskNext 内部路径,如 "/docs/photos/cat.jpg" + """ + # 规范化根路径 + root = account_root.rstrip("/") + if not root: + root = "" + + # 规范化 DAV 路径 + if not dav_path or dav_path == "/": + return root + "/" if root else "/" + + if not dav_path.startswith("/"): + dav_path = "/" + dav_path + + full = root + dav_path + return full if full else "/" + + +def _check_readonly(environ: dict[str, object]) -> None: + """检查账户是否只读,只读则抛出 403""" + account = environ.get("disknext.webdav_account") + if account and getattr(account, 'readonly', False): + raise DAVError(HTTP_FORBIDDEN, "WebDAV 账户为只读模式") + + +def _check_storage_quota(user: User, additional_bytes: int) -> None: + """检查存储配额""" + max_storage = user.group.max_storage + if max_storage > 0 and user.storage + additional_bytes > max_storage: + raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足") + + +# ==================== Provider ==================== + +class DiskNextDAVProvider(DAVProvider): + """DiskNext WebDAV 存储 Provider""" + + def __init__(self) -> None: + super().__init__() + + def get_resource_inst( + self, + path: str, + environ: dict[str, object], + ) -> 'DiskNextCollection | DiskNextFile | None': + """ + 将 WebDAV 路径映射到资源对象。 + + 首次调用时加载 WebDAV 账户信息并缓存到 environ。 + """ + user_id, webdav_id = _get_environ_info(environ) + + # 首次请求时加载账户信息 + if "disknext.webdav_account" not in environ: + account = _run_async(_get_webdav_account(webdav_id)) + if not account: + return None + environ["disknext.webdav_account"] = account + + account: WebDAV = environ["disknext.webdav_account"] # type: ignore[no-redef] + disknext_path = _resolve_dav_path(account.root, path) + + obj = _run_async(_get_object_by_path(user_id, disknext_path)) + if not obj: + return None + + if obj.is_folder: + return DiskNextCollection(path, environ, obj, user_id, account) + else: + return DiskNextFile(path, environ, obj, user_id, account) + + def is_readonly(self) -> bool: + """只读由账户级别控制,不在 provider 级别限制""" + return False + + +# ==================== Collection(目录) ==================== + +class DiskNextCollection(DAVCollection): + """DiskNext 目录资源""" + + def __init__( + self, + path: str, + environ: dict[str, object], + obj: Object, + user_id: UUID, + account: WebDAV, + ) -> None: + super().__init__(path, environ) + self._obj = obj + self._user_id = user_id + self._account = account + + def get_display_info(self) -> dict[str, str]: + return {"type": "Directory"} + + def get_member_names(self) -> list[str]: + """获取子对象名称列表""" + children = _run_async(_get_children(self._user_id, self._obj.id)) + return [c.name for c in children] + + def get_member(self, name: str) -> 'DiskNextCollection | DiskNextFile | None': + """获取指定名称的子资源""" + member_path = self.path.rstrip("/") + "/" + name + account_root = self._account.root + disknext_path = _resolve_dav_path(account_root, member_path) + + obj = _run_async(_get_object_by_path(self._user_id, disknext_path)) + if not obj: + return None + + if obj.is_folder: + return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account) + else: + return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account) + + def get_creation_date(self) -> float | None: + if self._obj.created_at: + return self._obj.created_at.timestamp() + return None + + def get_last_modified(self) -> float | None: + if self._obj.updated_at: + return self._obj.updated_at.timestamp() + return None + + def create_empty_resource(self, name: str) -> 'DiskNextFile': + """创建空文件(PUT 操作的第一步)""" + _check_readonly(self.environ) + + obj = _run_async(_create_file( + name=name, + parent_id=self._obj.id, + owner_id=self._user_id, + policy_id=self._obj.policy_id, + )) + + member_path = self.path.rstrip("/") + "/" + name + return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account) + + def create_collection(self, name: str) -> 'DiskNextCollection': + """创建子目录(MKCOL)""" + _check_readonly(self.environ) + + obj = _run_async(_create_folder( + name=name, + parent_id=self._obj.id, + owner_id=self._user_id, + policy_id=self._obj.policy_id, + )) + + member_path = self.path.rstrip("/") + "/" + name + return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account) + + def delete(self) -> None: + """软删除目录""" + _check_readonly(self.environ) + _run_async(_soft_delete_object(self._obj.id)) + + def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool: + """复制或移动目录""" + _check_readonly(self.environ) + + account_root = self._account.root + dest_disknext = _resolve_dav_path(account_root, dest_path) + + # 解析目标父路径和新名称 + if "/" in dest_disknext.rstrip("/"): + parent_path = dest_disknext.rsplit("/", 1)[0] or "/" + new_name = dest_disknext.rsplit("/", 1)[1] + else: + parent_path = "/" + new_name = dest_disknext.lstrip("/") + + dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path)) + if not dest_parent: + raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在") + + if is_move: + _run_async(_move_object(self._obj.id, dest_parent.id, new_name)) + else: + _run_async(_copy_object_recursive( + self._obj.id, dest_parent.id, new_name, self._user_id, + )) + + return True + + def support_recursive_delete(self) -> bool: + return True + + def support_recursive_move(self, dest_path: str) -> bool: + return True + + +# ==================== NonCollection(文件) ==================== + +class DiskNextFile(DAVNonCollection): + """DiskNext 文件资源""" + + def __init__( + self, + path: str, + environ: dict[str, object], + obj: Object, + user_id: UUID, + account: WebDAV, + ) -> None: + super().__init__(path, environ) + self._obj = obj + self._user_id = user_id + self._account = account + self._write_path: str | None = None + self._write_stream: io.BufferedWriter | None = None + + def get_content_length(self) -> int | None: + return self._obj.size if self._obj.size else 0 + + def get_content_type(self) -> str | None: + # 尝试从文件名推断 MIME 类型 + mime, _ = mimetypes.guess_type(self._obj.name) + return mime or "application/octet-stream" + + def get_creation_date(self) -> float | None: + if self._obj.created_at: + return self._obj.created_at.timestamp() + return None + + def get_last_modified(self) -> float | None: + if self._obj.updated_at: + return self._obj.updated_at.timestamp() + return None + + def get_display_info(self) -> dict[str, str]: + return {"type": "File"} + + def get_content(self) -> io.BufferedReader | None: + """ + 返回文件内容的可读流。 + + WsgiDAV 在线程中运行,可安全使用同步 open()。 + """ + obj_with_file = _run_async(_get_object_by_id(self._obj.id)) + if not obj_with_file or not obj_with_file.physical_file: + return None + + pf = obj_with_file.physical_file + policy = _run_async(_get_policy(obj_with_file.policy_id)) + if not policy or not policy.server: + return None + + full_path = Path(policy.server).resolve() / pf.storage_path + if not full_path.is_file(): + l.warning(f"WebDAV: 物理文件不存在: {full_path}") + return None + + return open(full_path, "rb") # noqa: SIM115 + + def begin_write(self, *, content_type: str | None = None) -> io.BufferedWriter: + """ + 开始写入文件(PUT 操作)。 + + 返回一个可写的文件流,WsgiDAV 将向其中写入请求体数据。 + """ + _check_readonly(self.environ) + + # 检查配额 + user = _run_async(_get_user(self._user_id)) + if user: + content_length = self.environ.get("CONTENT_LENGTH") + if content_length: + _check_storage_quota(user, int(content_length)) + + # 获取策略以确定存储路径 + policy = _run_async(_get_policy(self._obj.policy_id)) + if not policy or not policy.server: + raise DAVError(HTTP_NOT_FOUND, "存储策略不存在") + + storage_service = LocalStorageService(policy) + dir_path, storage_name, full_path = _run_async( + storage_service.generate_file_path( + user_id=self._user_id, + original_filename=self._obj.name, + ) + ) + + self._write_path = full_path + self._write_stream = open(full_path, "wb") # noqa: SIM115 + return self._write_stream + + def end_write(self, *, with_errors: bool) -> None: + """写入完成后的收尾工作""" + if self._write_stream: + self._write_stream.close() + self._write_stream = None + + if with_errors or not self._write_path: + return + + # 获取文件大小 + file_path = Path(self._write_path) + if not file_path.exists(): + return + + size = file_path.stat().st_size + + # 更新数据库记录 + _run_async(_finalize_upload( + object_id=self._obj.id, + physical_path=self._write_path, + size=size, + owner_id=self._user_id, + policy_id=self._obj.policy_id, + )) + + l.debug(f"WebDAV 文件写入完成: {self._obj.name}, size={size}") + + def delete(self) -> None: + """软删除文件""" + _check_readonly(self.environ) + _run_async(_soft_delete_object(self._obj.id)) + + def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool: + """复制或移动文件""" + _check_readonly(self.environ) + + account_root = self._account.root + dest_disknext = _resolve_dav_path(account_root, dest_path) + + # 解析目标父路径和新名称 + if "/" in dest_disknext.rstrip("/"): + parent_path = dest_disknext.rsplit("/", 1)[0] or "/" + new_name = dest_disknext.rsplit("/", 1)[1] + else: + parent_path = "/" + new_name = dest_disknext.lstrip("/") + + dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path)) + if not dest_parent: + raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在") + + if is_move: + _run_async(_move_object(self._obj.id, dest_parent.id, new_name)) + else: + _run_async(_copy_object_recursive( + self._obj.id, dest_parent.id, new_name, self._user_id, + )) + + return True + + def support_content_length(self) -> bool: + return True + + def get_etag(self) -> str | None: + """返回 ETag(基于ID和更新时间),WsgiDAV 会自动加双引号""" + if self._obj.updated_at: + return f"{self._obj.id}-{int(self._obj.updated_at.timestamp())}" + return None + + def support_etag(self) -> bool: + return True + + def support_ranges(self) -> bool: + return True diff --git a/service/redis/webdav_auth_cache.py b/service/redis/webdav_auth_cache.py new file mode 100644 index 0000000..a4a99d4 --- /dev/null +++ b/service/redis/webdav_auth_cache.py @@ -0,0 +1,128 @@ +""" +WebDAV 认证缓存 + +缓存 HTTP Basic Auth 的认证结果,避免每次请求都查库 + Argon2 验证。 +支持 Redis(首选)和内存缓存(降级)两种存储后端。 +""" +import hashlib +from typing import ClassVar +from uuid import UUID + +from cachetools import TTLCache +from loguru import logger as l + +from . import RedisManager + +_AUTH_TTL: int = 300 +"""认证缓存 TTL(秒),5 分钟""" + + +class WebDAVAuthCache: + """ + WebDAV 认证结果缓存 + + 缓存键格式: webdav_auth:{email}/{account_name}:{sha256(password)} + 缓存值格式: {user_id}:{webdav_id} + + 密码的 SHA256 作为缓存键的一部分,密码变更后旧缓存自然 miss。 + """ + + _memory_cache: ClassVar[TTLCache[str, str]] = TTLCache(maxsize=10000, ttl=_AUTH_TTL) + """内存缓存降级方案""" + + @classmethod + def _build_key(cls, email: str, account_name: str, password: str) -> str: + """构建缓存键""" + pwd_hash = hashlib.sha256(password.encode()).hexdigest()[:16] + return f"webdav_auth:{email}/{account_name}:{pwd_hash}" + + @classmethod + async def get( + cls, + email: str, + account_name: str, + password: str, + ) -> tuple[UUID, int] | None: + """ + 查询缓存中的认证结果。 + + :param email: 用户邮箱 + :param account_name: WebDAV 账户名 + :param password: 用户提供的明文密码 + :return: (user_id, webdav_id) 或 None(缓存未命中) + """ + key = cls._build_key(email, account_name, password) + + client = RedisManager.get_client() + if client is not None: + value = await client.get(key) + if value is not None: + raw = value.decode() if isinstance(value, bytes) else value + user_id_str, webdav_id_str = raw.split(":", 1) + return UUID(user_id_str), int(webdav_id_str) + else: + raw = cls._memory_cache.get(key) + if raw is not None: + user_id_str, webdav_id_str = raw.split(":", 1) + return UUID(user_id_str), int(webdav_id_str) + + return None + + @classmethod + async def set( + cls, + email: str, + account_name: str, + password: str, + user_id: UUID, + webdav_id: int, + ) -> None: + """ + 写入认证结果到缓存。 + + :param email: 用户邮箱 + :param account_name: WebDAV 账户名 + :param password: 用户提供的明文密码 + :param user_id: 用户UUID + :param webdav_id: WebDAV 账户ID + """ + key = cls._build_key(email, account_name, password) + value = f"{user_id}:{webdav_id}" + + client = RedisManager.get_client() + if client is not None: + await client.set(key, value, ex=_AUTH_TTL) + else: + cls._memory_cache[key] = value + + @classmethod + async def invalidate_account(cls, user_id: UUID, account_name: str) -> None: + """ + 失效指定账户的所有缓存。 + + 由于缓存键包含 password hash,无法精确删除, + Redis 端使用 pattern scan 删除,内存端清空全部。 + + :param user_id: 用户UUID + :param account_name: WebDAV 账户名 + """ + client = RedisManager.get_client() + if client is not None: + pattern = f"webdav_auth:*/{account_name}:*" + cursor: int = 0 + while True: + cursor, keys = await client.scan(cursor, match=pattern, count=100) + if keys: + await client.delete(*keys) + if cursor == 0: + break + else: + # 内存缓存无法按 pattern 删除,清除所有含该账户名的条目 + keys_to_delete = [ + k for k in cls._memory_cache + if f"/{account_name}:" in k + ] + for k in keys_to_delete: + cls._memory_cache.pop(k, None) + + l.debug(f"已清除 WebDAV 认证缓存: user={user_id}, account={account_name}") diff --git a/sqlmodels/__init__.py b/sqlmodels/__init__.py index 4f6124c..97898f5 100644 --- a/sqlmodels/__init__.py +++ b/sqlmodels/__init__.py @@ -75,7 +75,9 @@ from .object import ( ObjectBase, ObjectCopyRequest, ObjectDeleteRequest, + ObjectFileFinalize, ObjectMoveRequest, + ObjectMoveUpdate, ObjectPropertyDetailResponse, ObjectPropertyResponse, ObjectRenameRequest, @@ -115,7 +117,10 @@ from .source_link import SourceLink from .storage_pack import StoragePack from .tag import Tag, TagType from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary -from .webdav import WebDAV +from .webdav import ( + WebDAV, WebDAVBase, + WebDAVCreateRequest, WebDAVUpdateRequest, WebDAVAccountResponse, +) from .file_app import ( FileApp, FileAppType, FileAppExtension, FileAppGroupLink, UserFileAppDefault, # DTO diff --git a/sqlmodels/object.py b/sqlmodels/object.py index ada4f8a..e642e7c 100644 --- a/sqlmodels/object.py +++ b/sqlmodels/object.py @@ -78,6 +78,26 @@ class ObjectBase(SQLModelBase): # ==================== DTO 模型 ==================== +class ObjectFileFinalize(SQLModelBase): + """文件上传完成后更新 Object 的 DTO""" + + size: int + """文件大小(字节)""" + + physical_file_id: UUID + """关联的物理文件UUID""" + + +class ObjectMoveUpdate(SQLModelBase): + """移动/重命名 Object 的 DTO""" + + parent_id: UUID + """新的父目录UUID""" + + name: str + """新名称""" + + class DirectoryCreateRequest(SQLModelBase): """创建目录请求 DTO""" diff --git a/sqlmodels/webdav.py b/sqlmodels/webdav.py index b73bce2..7bdabcb 100644 --- a/sqlmodels/webdav.py +++ b/sqlmodels/webdav.py @@ -1,4 +1,9 @@ +""" +WebDAV 账户模型 +管理用户的 WebDAV 连接账户,每个账户对应一个挂载根路径。 +通过 HTTP Basic Auth 认证访问 DAV 协议端点。 +""" from typing import TYPE_CHECKING from uuid import UUID @@ -9,24 +14,104 @@ from sqlmodel_ext import SQLModelBase, TableBaseMixin if TYPE_CHECKING: from .user import User -class WebDAV(SQLModelBase, TableBaseMixin): - """WebDAV账户模型""" + +# ==================== Base 模型 ==================== + +class WebDAVBase(SQLModelBase): + """WebDAV 账户基础字段""" + + name: str = Field(max_length=255) + """账户名称(同一用户下唯一)""" + + root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"}) + """挂载根目录路径""" + + readonly: bool = Field(default=False, sa_column_kwargs={"server_default": "false"}) + """是否只读""" + + use_proxy: bool = Field(default=False, sa_column_kwargs={"server_default": "false"}) + """是否使用代理下载""" + + +# ==================== 数据库模型 ==================== + +class WebDAV(WebDAVBase, TableBaseMixin): + """WebDAV 账户模型""" __table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),) - name: str = Field(max_length=255, description="WebDAV账户名") - password: str = Field(max_length=255, description="WebDAV密码") - root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"}, description="根目录路径") - readonly: bool = Field(default=False, description="是否只读") - use_proxy: bool = Field(default=False, description="是否使用代理下载") - + password: str = Field(max_length=255) + """密码(Argon2 哈希)""" + # 外键 user_id: UUID = Field( foreign_key="user.id", index=True, - ondelete="CASCADE" + ondelete="CASCADE", ) """所属用户UUID""" - + # 关系 - user: "User" = Relationship(back_populates="webdavs") \ No newline at end of file + user: "User" = Relationship(back_populates="webdavs") + + +# ==================== DTO 模型 ==================== + +class WebDAVCreateRequest(SQLModelBase): + """创建 WebDAV 账户请求""" + + name: str = Field(max_length=255) + """账户名称""" + + password: str = Field(min_length=1, max_length=255) + """账户密码(明文,服务端哈希后存储)""" + + root: str = "/" + """挂载根目录路径""" + + readonly: bool = False + """是否只读""" + + use_proxy: bool = False + """是否使用代理下载""" + + +class WebDAVUpdateRequest(SQLModelBase): + """更新 WebDAV 账户请求""" + + password: str | None = Field(default=None, min_length=1, max_length=255) + """新密码(为 None 时不修改)""" + + root: str | None = None + """新挂载根目录路径(为 None 时不修改)""" + + readonly: bool | None = None + """是否只读(为 None 时不修改)""" + + use_proxy: bool | None = None + """是否使用代理下载(为 None 时不修改)""" + + +class WebDAVAccountResponse(SQLModelBase): + """WebDAV 账户响应""" + + id: int + """账户ID""" + + name: str + """账户名称""" + + root: str + """挂载根目录路径""" + + readonly: bool + """是否只读""" + + use_proxy: bool + """是否使用代理下载""" + + created_at: str + """创建时间""" + + updated_at: str + """更新时间""" diff --git a/tests/integration/api/test_webdav.py b/tests/integration/api/test_webdav.py new file mode 100644 index 0000000..597778e --- /dev/null +++ b/tests/integration/api/test_webdav.py @@ -0,0 +1,591 @@ +""" +WebDAV 账户管理端点集成测试 +""" +from uuid import UUID, uuid4 + +import pytest +import pytest_asyncio +from httpx import AsyncClient +from sqlmodel.ext.asyncio.session import AsyncSession + +from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, User +from sqlmodels.auth_identity import AuthIdentity, AuthProviderType +from sqlmodels.user import UserStatus +from utils import Password +from utils.JWT import create_access_token + +API_PREFIX = "/api/v1/webdav" + + +# ==================== Fixtures ==================== + +@pytest_asyncio.fixture +async def no_webdav_headers(initialized_db: AsyncSession) -> dict[str, str]: + """创建一个 WebDAV 被禁用的用户,返回其认证头""" + group = Group( + id=uuid4(), + name="无WebDAV用户组", + max_storage=1024 * 1024 * 1024, + share_enabled=True, + web_dav_enabled=False, + admin=False, + speed_limit=0, + ) + initialized_db.add(group) + await initialized_db.commit() + await initialized_db.refresh(group) + + group_options = GroupOptions( + group_id=group.id, + share_download=True, + share_free=False, + relocate=False, + source_batch=0, + select_node=False, + advance_delete=False, + ) + initialized_db.add(group_options) + await initialized_db.commit() + await initialized_db.refresh(group_options) + + user = User( + id=uuid4(), + email="nowebdav@test.local", + nickname="无WebDAV用户", + status=UserStatus.ACTIVE, + storage=0, + score=0, + group_id=group.id, + avatar="default", + ) + initialized_db.add(user) + await initialized_db.commit() + await initialized_db.refresh(user) + + identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="nowebdav@test.local", + credential=Password.hash("nowebdav123"), + is_primary=True, + is_verified=True, + user_id=user.id, + ) + initialized_db.add(identity) + + from sqlmodels import Policy + policy = await Policy.get(initialized_db, Policy.name == "本地存储") + + root = Object( + id=uuid4(), + name="/", + type=ObjectType.FOLDER, + owner_id=user.id, + parent_id=None, + policy_id=policy.id, + size=0, + ) + initialized_db.add(root) + await initialized_db.commit() + + group.options = group_options + group_claims = GroupClaims.from_group(group) + result = create_access_token( + sub=user.id, + jti=uuid4(), + status=user.status.value, + group=group_claims, + ) + return {"Authorization": f"Bearer {result.access_token}"} + + +# ==================== 认证测试 ==================== + +@pytest.mark.asyncio +async def test_list_accounts_requires_auth(async_client: AsyncClient): + """测试获取账户列表需要认证""" + response = await async_client.get(f"{API_PREFIX}/accounts") + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_create_account_requires_auth(async_client: AsyncClient): + """测试创建账户需要认证""" + response = await async_client.post( + f"{API_PREFIX}/accounts", + json={"name": "test", "password": "testpass"}, + ) + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_update_account_requires_auth(async_client: AsyncClient): + """测试更新账户需要认证""" + response = await async_client.patch( + f"{API_PREFIX}/accounts/1", + json={"readonly": True}, + ) + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_delete_account_requires_auth(async_client: AsyncClient): + """测试删除账户需要认证""" + response = await async_client.delete(f"{API_PREFIX}/accounts/1") + assert response.status_code == 401 + + +# ==================== WebDAV 禁用测试 ==================== + +@pytest.mark.asyncio +async def test_list_accounts_webdav_disabled( + async_client: AsyncClient, + no_webdav_headers: dict[str, str], +): + """测试 WebDAV 被禁用时返回 403""" + response = await async_client.get( + f"{API_PREFIX}/accounts", + headers=no_webdav_headers, + ) + assert response.status_code == 403 + + +@pytest.mark.asyncio +async def test_create_account_webdav_disabled( + async_client: AsyncClient, + no_webdav_headers: dict[str, str], +): + """测试 WebDAV 被禁用时创建账户返回 403""" + response = await async_client.post( + f"{API_PREFIX}/accounts", + headers=no_webdav_headers, + json={"name": "test", "password": "testpass"}, + ) + assert response.status_code == 403 + + +# ==================== 获取账户列表测试 ==================== + +@pytest.mark.asyncio +async def test_list_accounts_empty( + async_client: AsyncClient, + auth_headers: dict[str, str], +): + """测试初始状态账户列表为空""" + response = await async_client.get( + f"{API_PREFIX}/accounts", + headers=auth_headers, + ) + assert response.status_code == 200 + assert response.json() == [] + + +# ==================== 创建账户测试 ==================== + +@pytest.mark.asyncio +async def test_create_account_success( + async_client: AsyncClient, + auth_headers: dict[str, str], +): + """测试成功创建 WebDAV 账户""" + response = await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={"name": "my-nas", "password": "secretpass"}, + ) + assert response.status_code == 201 + + data = response.json() + assert data["name"] == "my-nas" + assert data["root"] == "/" + assert data["readonly"] is False + assert data["use_proxy"] is False + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + + +@pytest.mark.asyncio +async def test_create_account_with_options( + async_client: AsyncClient, + auth_headers: dict[str, str], +): + """测试创建带选项的 WebDAV 账户""" + response = await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={ + "name": "readonly-nas", + "password": "secretpass", + "readonly": True, + "use_proxy": True, + }, + ) + assert response.status_code == 201 + + data = response.json() + assert data["name"] == "readonly-nas" + assert data["readonly"] is True + assert data["use_proxy"] is True + + +@pytest.mark.asyncio +async def test_create_account_duplicate_name( + async_client: AsyncClient, + auth_headers: dict[str, str], +): + """测试重名账户返回 409""" + # 先创建一个 + response = await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={"name": "dup-test", "password": "pass1"}, + ) + assert response.status_code == 201 + + # 再创建同名的 + response = await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={"name": "dup-test", "password": "pass2"}, + ) + assert response.status_code == 409 + + +@pytest.mark.asyncio +async def test_create_account_invalid_root( + async_client: AsyncClient, + auth_headers: dict[str, str], +): + """测试无效根目录路径返回 400""" + response = await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={ + "name": "bad-root", + "password": "secretpass", + "root": "/nonexistent/path", + }, + ) + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_create_account_with_valid_subdir( + async_client: AsyncClient, + auth_headers: dict[str, str], + test_directory_structure: dict[str, UUID], +): + """测试使用有效的子目录作为根路径""" + response = await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={ + "name": "docs-only", + "password": "secretpass", + "root": "/docs", + }, + ) + assert response.status_code == 201 + assert response.json()["root"] == "/docs" + + +# ==================== 列表包含已创建账户测试 ==================== + +@pytest.mark.asyncio +async def test_list_accounts_after_create( + async_client: AsyncClient, + auth_headers: dict[str, str], +): + """测试创建后列表中包含该账户""" + # 创建 + await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={"name": "list-test", "password": "pass"}, + ) + + # 列表 + response = await async_client.get( + f"{API_PREFIX}/accounts", + headers=auth_headers, + ) + assert response.status_code == 200 + accounts = response.json() + assert len(accounts) == 1 + assert accounts[0]["name"] == "list-test" + + +# ==================== 更新账户测试 ==================== + +@pytest.mark.asyncio +async def test_update_account_success( + async_client: AsyncClient, + auth_headers: dict[str, str], +): + """测试成功更新 WebDAV 账户""" + # 创建 + create_resp = await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={"name": "update-test", "password": "oldpass"}, + ) + account_id = create_resp.json()["id"] + + # 更新 + response = await async_client.patch( + f"{API_PREFIX}/accounts/{account_id}", + headers=auth_headers, + json={"readonly": True}, + ) + assert response.status_code == 200 + + data = response.json() + assert data["readonly"] is True + assert data["name"] == "update-test" + + +@pytest.mark.asyncio +async def test_update_account_password( + async_client: AsyncClient, + auth_headers: dict[str, str], +): + """测试更新密码""" + # 创建 + create_resp = await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={"name": "pwd-test", "password": "oldpass"}, + ) + account_id = create_resp.json()["id"] + + # 更新密码 + response = await async_client.patch( + f"{API_PREFIX}/accounts/{account_id}", + headers=auth_headers, + json={"password": "newpass123"}, + ) + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_update_account_root( + async_client: AsyncClient, + auth_headers: dict[str, str], + test_directory_structure: dict[str, UUID], +): + """测试更新根目录路径""" + # 创建 + create_resp = await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={"name": "root-update", "password": "pass"}, + ) + account_id = create_resp.json()["id"] + + # 更新 root 到有效子目录 + response = await async_client.patch( + f"{API_PREFIX}/accounts/{account_id}", + headers=auth_headers, + json={"root": "/docs"}, + ) + assert response.status_code == 200 + assert response.json()["root"] == "/docs" + + +@pytest.mark.asyncio +async def test_update_account_invalid_root( + async_client: AsyncClient, + auth_headers: dict[str, str], +): + """测试更新为无效根目录返回 400""" + # 创建 + create_resp = await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={"name": "bad-root-update", "password": "pass"}, + ) + account_id = create_resp.json()["id"] + + # 更新到无效路径 + response = await async_client.patch( + f"{API_PREFIX}/accounts/{account_id}", + headers=auth_headers, + json={"root": "/nonexistent"}, + ) + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_update_account_not_found( + async_client: AsyncClient, + auth_headers: dict[str, str], +): + """测试更新不存在的账户返回 404""" + response = await async_client.patch( + f"{API_PREFIX}/accounts/99999", + headers=auth_headers, + json={"readonly": True}, + ) + assert response.status_code == 404 + + +@pytest.mark.asyncio +async def test_update_other_user_account( + async_client: AsyncClient, + auth_headers: dict[str, str], + admin_headers: dict[str, str], +): + """测试更新其他用户的账户返回 404""" + # 管理员创建账户 + create_resp = await async_client.post( + f"{API_PREFIX}/accounts", + headers=admin_headers, + json={"name": "admin-account", "password": "pass"}, + ) + account_id = create_resp.json()["id"] + + # 普通用户尝试更新 + response = await async_client.patch( + f"{API_PREFIX}/accounts/{account_id}", + headers=auth_headers, + json={"readonly": True}, + ) + assert response.status_code == 404 + + +# ==================== 删除账户测试 ==================== + +@pytest.mark.asyncio +async def test_delete_account_success( + async_client: AsyncClient, + auth_headers: dict[str, str], +): + """测试成功删除 WebDAV 账户""" + # 创建 + create_resp = await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={"name": "delete-test", "password": "pass"}, + ) + account_id = create_resp.json()["id"] + + # 删除 + response = await async_client.delete( + f"{API_PREFIX}/accounts/{account_id}", + headers=auth_headers, + ) + assert response.status_code == 204 + + # 确认列表中已不存在 + list_resp = await async_client.get( + f"{API_PREFIX}/accounts", + headers=auth_headers, + ) + assert list_resp.status_code == 200 + names = [a["name"] for a in list_resp.json()] + assert "delete-test" not in names + + +@pytest.mark.asyncio +async def test_delete_account_not_found( + async_client: AsyncClient, + auth_headers: dict[str, str], +): + """测试删除不存在的账户返回 404""" + response = await async_client.delete( + f"{API_PREFIX}/accounts/99999", + headers=auth_headers, + ) + assert response.status_code == 404 + + +@pytest.mark.asyncio +async def test_delete_other_user_account( + async_client: AsyncClient, + auth_headers: dict[str, str], + admin_headers: dict[str, str], +): + """测试删除其他用户的账户返回 404""" + # 管理员创建账户 + create_resp = await async_client.post( + f"{API_PREFIX}/accounts", + headers=admin_headers, + json={"name": "admin-del-test", "password": "pass"}, + ) + account_id = create_resp.json()["id"] + + # 普通用户尝试删除 + response = await async_client.delete( + f"{API_PREFIX}/accounts/{account_id}", + headers=auth_headers, + ) + assert response.status_code == 404 + + +# ==================== 多账户测试 ==================== + +@pytest.mark.asyncio +async def test_multiple_accounts( + async_client: AsyncClient, + auth_headers: dict[str, str], +): + """测试同一用户可以创建多个账户""" + for name in ["account-1", "account-2", "account-3"]: + response = await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={"name": name, "password": "pass"}, + ) + assert response.status_code == 201 + + # 列表应有3个 + response = await async_client.get( + f"{API_PREFIX}/accounts", + headers=auth_headers, + ) + assert response.status_code == 200 + assert len(response.json()) == 3 + + +# ==================== 用户隔离测试 ==================== + +@pytest.mark.asyncio +async def test_accounts_user_isolation( + async_client: AsyncClient, + auth_headers: dict[str, str], + admin_headers: dict[str, str], +): + """测试不同用户的账户相互隔离""" + # 普通用户创建 + await async_client.post( + f"{API_PREFIX}/accounts", + headers=auth_headers, + json={"name": "user-account", "password": "pass"}, + ) + + # 管理员创建 + await async_client.post( + f"{API_PREFIX}/accounts", + headers=admin_headers, + json={"name": "admin-account", "password": "pass"}, + ) + + # 普通用户只看到自己的 + response = await async_client.get( + f"{API_PREFIX}/accounts", + headers=auth_headers, + ) + assert response.status_code == 200 + accounts = response.json() + assert len(accounts) == 1 + assert accounts[0]["name"] == "user-account" + + # 管理员只看到自己的 + response = await async_client.get( + f"{API_PREFIX}/accounts", + headers=admin_headers, + ) + assert response.status_code == 200 + accounts = response.json() + assert len(accounts) == 1 + assert accounts[0]["name"] == "admin-account" diff --git a/uv.lock b/uv.lock index d6d2623..ca0fcb9 100644 --- a/uv.lock +++ b/uv.lock @@ -6,6 +6,15 @@ resolution-markers = [ "python_full_version < '3.14'", ] +[[package]] +name = "a2wsgi" +version = "1.10.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/cb/822c56fbea97e9eee201a2e434a80437f6750ebcb1ed307ee3a0a7505b14/a2wsgi-1.10.10.tar.gz", hash = "sha256:a5bcffb52081ba39df0d5e9a884fc6f819d92e3a42389343ba77cbf809fe1f45", size = 18799, upload-time = "2025-06-18T09:00:10.843Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/d5/349aba3dc421e73cbd4958c0ce0a4f1aa3a738bc0d7de75d2f40ed43a535/a2wsgi-1.10.10-py3-none-any.whl", hash = "sha256:d2b21379479718539dc15fce53b876251a0efe7615352dfe49f6ad1bc507848d", size = 17389, upload-time = "2025-06-18T09:00:09.676Z" }, +] + [[package]] name = "aiofiles" version = "25.1.0" @@ -500,11 +509,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ff/fa/d3c15189f7c52aaefbaea76fb012119b04b9013f4bf446cb4eb4c26c4e6b/cython-3.2.4-py3-none-any.whl", hash = "sha256:732fc93bc33ae4b14f6afaca663b916c2fdd5dcbfad7114e17fb2434eeaea45c", size = 1257078, upload-time = "2026-01-04T14:14:12.373Z" }, ] +[[package]] +name = "defusedxml" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/d5/c66da9b79e5bdb124974bfe172b4daf3c984ebd9c2a06e2b8a4dc7331c72/defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69", size = 75520, upload-time = "2021-03-08T10:59:26.269Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, +] + [[package]] name = "disknext-server" version = "0.0.1" source = { virtual = "." } dependencies = [ + { name = "a2wsgi" }, { name = "aiofiles" }, { name = "aiohttp" }, { name = "aiosqlite" }, @@ -533,6 +552,7 @@ dependencies = [ { name = "uvicorn" }, { name = "webauthn" }, { name = "whatthepatch" }, + { name = "wsgidav" }, ] [package.optional-dependencies] @@ -543,6 +563,7 @@ build = [ [package.metadata] requires-dist = [ + { name = "a2wsgi", specifier = ">=1.10.0" }, { name = "aiofiles", specifier = ">=25.1.0" }, { name = "aiohttp", specifier = ">=3.13.2" }, { name = "aiosqlite", specifier = "==0.22.1" }, @@ -573,6 +594,7 @@ requires-dist = [ { name = "uvicorn", specifier = ">=0.38.0" }, { name = "webauthn", specifier = ">=2.7.0" }, { name = "whatthepatch", specifier = ">=1.0.6" }, + { name = "wsgidav", specifier = ">=4.3.0" }, ] provides-extras = ["build"] @@ -975,6 +997,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] +[[package]] +name = "json5" +version = "0.13.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/77/e8/a3f261a66e4663f22700bc8a17c08cb83e91fbf086726e7a228398968981/json5-0.13.0.tar.gz", hash = "sha256:b1edf8d487721c0bf64d83c28e91280781f6e21f4a797d3261c7c828d4c165bf", size = 52441, upload-time = "2026-01-01T19:42:14.99Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/9e/038522f50ceb7e74f1f991bf1b699f24b0c2bbe7c390dd36ad69f4582258/json5-0.13.0-py3-none-any.whl", hash = "sha256:9a08e1dd65f6a4d4c6fa82d216cf2477349ec2346a38fd70cc11d2557499fbcc", size = 36163, upload-time = "2026-01-01T19:42:13.962Z" }, +] + [[package]] name = "loguru" version = "0.7.3" @@ -2049,6 +2080,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/07/c6fe3ad3e685340704d314d765b7912993bcb8dc198f0e7a89382d37974b/win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390", size = 4083, upload-time = "2024-12-07T15:28:26.465Z" }, ] +[[package]] +name = "wsgidav" +version = "4.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "defusedxml" }, + { name = "jinja2" }, + { name = "json5" }, + { name = "pyyaml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/f4/9c89e3e41dc7762cbb005d1baf23381718c7b13607236eacda23b855a288/wsgidav-4.3.3.tar.gz", hash = "sha256:5f0ad71bea72def3018b6ba52da3bcb83f61e0873c27225344582805d6e52b9e", size = 168118, upload-time = "2024-05-04T18:28:01.199Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/8e/04fb92513f4deab0f9bf4bdeeebc74f12d4de75ff00ad213c69983fc6563/WsgiDAV-4.3.3-py3-none-any.whl", hash = "sha256:8d96b0f05ad7f280572e99d1c605962a853d715f8e934298555d0c47ef275e88", size = 164954, upload-time = "2024-05-04T18:27:57.718Z" }, +] + [[package]] name = "yarl" version = "1.22.0"