All checks were successful
Test / test (push) Successful in 2m14s
Add complete WebDAV support: management REST API (CRUD accounts at /api/v1/webdav/accounts) and DAV protocol endpoint (/dav) using WsgiDAV + a2wsgi bridge for client access via HTTP Basic Auth. Includes Redis+TTLCache auth caching and integration tests (24 cases). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
595 lines
19 KiB
Python
595 lines
19 KiB
Python
"""
|
||
DiskNext WebDAV 存储 Provider
|
||
|
||
将 WsgiDAV 的文件操作映射到 DiskNext 的 Object 模型。
|
||
所有异步数据库/文件操作通过 asyncio.run_coroutine_threadsafe() 桥接。
|
||
"""
|
||
import asyncio
|
||
import io
|
||
import mimetypes
|
||
from pathlib import Path
|
||
from typing import ClassVar
|
||
from uuid import UUID
|
||
|
||
from loguru import logger as l
|
||
from wsgidav.dav_error import (
|
||
DAVError,
|
||
HTTP_FORBIDDEN,
|
||
HTTP_INSUFFICIENT_STORAGE,
|
||
HTTP_NOT_FOUND,
|
||
)
|
||
from wsgidav.dav_provider import DAVCollection, DAVNonCollection, DAVProvider
|
||
|
||
from service.storage import LocalStorageService, adjust_user_storage
|
||
from sqlmodels.database_connection import DatabaseManager
|
||
from sqlmodels.object import Object, ObjectType
|
||
from sqlmodels.physical_file import PhysicalFile
|
||
from sqlmodels.policy import Policy
|
||
from sqlmodels.user import User
|
||
from sqlmodels.webdav import WebDAV
|
||
|
||
|
||
class EventLoopRef:
|
||
"""持有主线程事件循环引用,供 WSGI 线程使用"""
|
||
_loop: ClassVar[asyncio.AbstractEventLoop | None] = None
|
||
|
||
@classmethod
|
||
async def capture(cls) -> None:
|
||
"""在 async 上下文中调用,捕获当前事件循环"""
|
||
cls._loop = asyncio.get_running_loop()
|
||
|
||
@classmethod
|
||
def get(cls) -> asyncio.AbstractEventLoop:
|
||
if cls._loop is None:
|
||
raise RuntimeError("事件循环尚未捕获,请先调用 EventLoopRef.capture()")
|
||
return cls._loop
|
||
|
||
|
||
def _run_async(coro): # type: ignore[no-untyped-def]
|
||
"""在 WSGI 线程中通过 run_coroutine_threadsafe 运行协程"""
|
||
future = asyncio.run_coroutine_threadsafe(coro, EventLoopRef.get())
|
||
return future.result()
|
||
|
||
|
||
def _get_session(): # type: ignore[no-untyped-def]
|
||
"""获取数据库会话上下文管理器"""
|
||
return DatabaseManager._async_session_factory()
|
||
|
||
|
||
# ==================== 异步辅助函数 ====================
|
||
|
||
async def _get_webdav_account(webdav_id: int) -> WebDAV | None:
|
||
"""获取 WebDAV 账户"""
|
||
async with _get_session() as session:
|
||
return await WebDAV.get(session, WebDAV.id == webdav_id)
|
||
|
||
|
||
async def _get_object_by_path(user_id: UUID, path: str) -> Object | None:
|
||
"""根据路径获取对象"""
|
||
async with _get_session() as session:
|
||
return await Object.get_by_path(session, user_id, path)
|
||
|
||
|
||
async def _get_children(user_id: UUID, parent_id: UUID) -> list[Object]:
|
||
"""获取目录子对象"""
|
||
async with _get_session() as session:
|
||
return await Object.get_children(session, user_id, parent_id)
|
||
|
||
|
||
async def _get_object_by_id(object_id: UUID) -> Object | None:
|
||
"""根据ID获取对象"""
|
||
async with _get_session() as session:
|
||
return await Object.get(session, Object.id == object_id, load=Object.physical_file)
|
||
|
||
|
||
async def _get_user(user_id: UUID) -> User | None:
|
||
"""获取用户(含 group 关系)"""
|
||
async with _get_session() as session:
|
||
return await User.get(session, User.id == user_id, load=User.group)
|
||
|
||
|
||
async def _get_policy(policy_id: UUID) -> Policy | None:
|
||
"""获取存储策略"""
|
||
async with _get_session() as session:
|
||
return await Policy.get(session, Policy.id == policy_id)
|
||
|
||
|
||
async def _create_folder(
|
||
name: str,
|
||
parent_id: UUID,
|
||
owner_id: UUID,
|
||
policy_id: UUID,
|
||
) -> Object:
|
||
"""创建目录对象"""
|
||
async with _get_session() as session:
|
||
obj = Object(
|
||
name=name,
|
||
type=ObjectType.FOLDER,
|
||
size=0,
|
||
parent_id=parent_id,
|
||
owner_id=owner_id,
|
||
policy_id=policy_id,
|
||
)
|
||
obj = await obj.save(session)
|
||
return obj
|
||
|
||
|
||
async def _create_file(
|
||
name: str,
|
||
parent_id: UUID,
|
||
owner_id: UUID,
|
||
policy_id: UUID,
|
||
) -> Object:
|
||
"""创建空文件对象"""
|
||
async with _get_session() as session:
|
||
obj = Object(
|
||
name=name,
|
||
type=ObjectType.FILE,
|
||
size=0,
|
||
parent_id=parent_id,
|
||
owner_id=owner_id,
|
||
policy_id=policy_id,
|
||
)
|
||
obj = await obj.save(session)
|
||
return obj
|
||
|
||
|
||
async def _soft_delete_object(object_id: UUID) -> None:
|
||
"""软删除对象(移入回收站)"""
|
||
from service.storage import soft_delete_objects
|
||
|
||
async with _get_session() as session:
|
||
obj = await Object.get(session, Object.id == object_id)
|
||
if obj:
|
||
await soft_delete_objects(session, [obj])
|
||
|
||
|
||
async def _finalize_upload(
|
||
object_id: UUID,
|
||
physical_path: str,
|
||
size: int,
|
||
owner_id: UUID,
|
||
policy_id: UUID,
|
||
) -> None:
|
||
"""上传完成后更新对象元数据和物理文件记录"""
|
||
async with _get_session() as session:
|
||
# 获取存储路径(相对路径)
|
||
policy = await Policy.get(session, Policy.id == policy_id)
|
||
if not policy or not policy.server:
|
||
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
|
||
|
||
base_path = Path(policy.server).resolve()
|
||
full_path = Path(physical_path).resolve()
|
||
storage_path = str(full_path.relative_to(base_path))
|
||
|
||
# 创建 PhysicalFile 记录
|
||
pf = PhysicalFile(
|
||
storage_path=storage_path,
|
||
size=size,
|
||
policy_id=policy_id,
|
||
reference_count=1,
|
||
)
|
||
pf = await pf.save(session)
|
||
|
||
# 更新 Object
|
||
obj = await Object.get(session, Object.id == object_id)
|
||
if obj:
|
||
obj.sqlmodel_update({'size': size, 'physical_file_id': pf.id})
|
||
session.add(obj)
|
||
await session.commit()
|
||
|
||
# 更新用户存储用量
|
||
if size > 0:
|
||
await adjust_user_storage(session, owner_id, size)
|
||
|
||
|
||
async def _move_object(
|
||
object_id: UUID,
|
||
new_parent_id: UUID,
|
||
new_name: str,
|
||
) -> None:
|
||
"""移动/重命名对象"""
|
||
async with _get_session() as session:
|
||
obj = await Object.get(session, Object.id == object_id)
|
||
if obj:
|
||
obj.sqlmodel_update({'parent_id': new_parent_id, 'name': new_name})
|
||
session.add(obj)
|
||
await session.commit()
|
||
|
||
|
||
async def _copy_object_recursive(
|
||
src_id: UUID,
|
||
dst_parent_id: UUID,
|
||
dst_name: str,
|
||
owner_id: UUID,
|
||
) -> None:
|
||
"""递归复制对象"""
|
||
from service.storage import copy_object_recursive
|
||
|
||
async with _get_session() as session:
|
||
src = await Object.get(session, Object.id == src_id)
|
||
if not src:
|
||
return
|
||
await copy_object_recursive(session, src, dst_parent_id, owner_id, new_name=dst_name)
|
||
|
||
|
||
# ==================== 辅助工具 ====================
|
||
|
||
def _get_environ_info(environ: dict[str, object]) -> tuple[UUID, int]:
|
||
"""从 environ 中提取认证信息"""
|
||
user_id: UUID = environ["disknext.user_id"] # type: ignore[assignment]
|
||
webdav_id: int = environ["disknext.webdav_id"] # type: ignore[assignment]
|
||
return user_id, webdav_id
|
||
|
||
|
||
def _resolve_dav_path(account_root: str, dav_path: str) -> str:
|
||
"""
|
||
将 DAV 相对路径映射到 DiskNext 绝对路径。
|
||
|
||
:param account_root: 账户挂载根路径,如 "/" 或 "/docs"
|
||
:param dav_path: DAV 请求路径,如 "/" 或 "/photos/cat.jpg"
|
||
:return: DiskNext 内部路径,如 "/docs/photos/cat.jpg"
|
||
"""
|
||
# 规范化根路径
|
||
root = account_root.rstrip("/")
|
||
if not root:
|
||
root = ""
|
||
|
||
# 规范化 DAV 路径
|
||
if not dav_path or dav_path == "/":
|
||
return root + "/" if root else "/"
|
||
|
||
if not dav_path.startswith("/"):
|
||
dav_path = "/" + dav_path
|
||
|
||
full = root + dav_path
|
||
return full if full else "/"
|
||
|
||
|
||
def _check_readonly(environ: dict[str, object]) -> None:
|
||
"""检查账户是否只读,只读则抛出 403"""
|
||
account = environ.get("disknext.webdav_account")
|
||
if account and getattr(account, 'readonly', False):
|
||
raise DAVError(HTTP_FORBIDDEN, "WebDAV 账户为只读模式")
|
||
|
||
|
||
def _check_storage_quota(user: User, additional_bytes: int) -> None:
|
||
"""检查存储配额"""
|
||
max_storage = user.group.max_storage
|
||
if max_storage > 0 and user.storage + additional_bytes > max_storage:
|
||
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
||
|
||
|
||
# ==================== Provider ====================
|
||
|
||
class DiskNextDAVProvider(DAVProvider):
|
||
"""DiskNext WebDAV 存储 Provider"""
|
||
|
||
def __init__(self) -> None:
|
||
super().__init__()
|
||
|
||
def get_resource_inst(
|
||
self,
|
||
path: str,
|
||
environ: dict[str, object],
|
||
) -> 'DiskNextCollection | DiskNextFile | None':
|
||
"""
|
||
将 WebDAV 路径映射到资源对象。
|
||
|
||
首次调用时加载 WebDAV 账户信息并缓存到 environ。
|
||
"""
|
||
user_id, webdav_id = _get_environ_info(environ)
|
||
|
||
# 首次请求时加载账户信息
|
||
if "disknext.webdav_account" not in environ:
|
||
account = _run_async(_get_webdav_account(webdav_id))
|
||
if not account:
|
||
return None
|
||
environ["disknext.webdav_account"] = account
|
||
|
||
account: WebDAV = environ["disknext.webdav_account"] # type: ignore[no-redef]
|
||
disknext_path = _resolve_dav_path(account.root, path)
|
||
|
||
obj = _run_async(_get_object_by_path(user_id, disknext_path))
|
||
if not obj:
|
||
return None
|
||
|
||
if obj.is_folder:
|
||
return DiskNextCollection(path, environ, obj, user_id, account)
|
||
else:
|
||
return DiskNextFile(path, environ, obj, user_id, account)
|
||
|
||
def is_readonly(self) -> bool:
|
||
"""只读由账户级别控制,不在 provider 级别限制"""
|
||
return False
|
||
|
||
|
||
# ==================== Collection(目录) ====================
|
||
|
||
class DiskNextCollection(DAVCollection):
|
||
"""DiskNext 目录资源"""
|
||
|
||
def __init__(
|
||
self,
|
||
path: str,
|
||
environ: dict[str, object],
|
||
obj: Object,
|
||
user_id: UUID,
|
||
account: WebDAV,
|
||
) -> None:
|
||
super().__init__(path, environ)
|
||
self._obj = obj
|
||
self._user_id = user_id
|
||
self._account = account
|
||
|
||
def get_display_info(self) -> dict[str, str]:
|
||
return {"type": "Directory"}
|
||
|
||
def get_member_names(self) -> list[str]:
|
||
"""获取子对象名称列表"""
|
||
children = _run_async(_get_children(self._user_id, self._obj.id))
|
||
return [c.name for c in children]
|
||
|
||
def get_member(self, name: str) -> 'DiskNextCollection | DiskNextFile | None':
|
||
"""获取指定名称的子资源"""
|
||
member_path = self.path.rstrip("/") + "/" + name
|
||
account_root = self._account.root
|
||
disknext_path = _resolve_dav_path(account_root, member_path)
|
||
|
||
obj = _run_async(_get_object_by_path(self._user_id, disknext_path))
|
||
if not obj:
|
||
return None
|
||
|
||
if obj.is_folder:
|
||
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
|
||
else:
|
||
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
|
||
|
||
def get_creation_date(self) -> float | None:
|
||
if self._obj.created_at:
|
||
return self._obj.created_at.timestamp()
|
||
return None
|
||
|
||
def get_last_modified(self) -> float | None:
|
||
if self._obj.updated_at:
|
||
return self._obj.updated_at.timestamp()
|
||
return None
|
||
|
||
def create_empty_resource(self, name: str) -> 'DiskNextFile':
|
||
"""创建空文件(PUT 操作的第一步)"""
|
||
_check_readonly(self.environ)
|
||
|
||
obj = _run_async(_create_file(
|
||
name=name,
|
||
parent_id=self._obj.id,
|
||
owner_id=self._user_id,
|
||
policy_id=self._obj.policy_id,
|
||
))
|
||
|
||
member_path = self.path.rstrip("/") + "/" + name
|
||
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
|
||
|
||
def create_collection(self, name: str) -> 'DiskNextCollection':
|
||
"""创建子目录(MKCOL)"""
|
||
_check_readonly(self.environ)
|
||
|
||
obj = _run_async(_create_folder(
|
||
name=name,
|
||
parent_id=self._obj.id,
|
||
owner_id=self._user_id,
|
||
policy_id=self._obj.policy_id,
|
||
))
|
||
|
||
member_path = self.path.rstrip("/") + "/" + name
|
||
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
|
||
|
||
def delete(self) -> None:
|
||
"""软删除目录"""
|
||
_check_readonly(self.environ)
|
||
_run_async(_soft_delete_object(self._obj.id))
|
||
|
||
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
|
||
"""复制或移动目录"""
|
||
_check_readonly(self.environ)
|
||
|
||
account_root = self._account.root
|
||
dest_disknext = _resolve_dav_path(account_root, dest_path)
|
||
|
||
# 解析目标父路径和新名称
|
||
if "/" in dest_disknext.rstrip("/"):
|
||
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
|
||
new_name = dest_disknext.rsplit("/", 1)[1]
|
||
else:
|
||
parent_path = "/"
|
||
new_name = dest_disknext.lstrip("/")
|
||
|
||
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
|
||
if not dest_parent:
|
||
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
|
||
|
||
if is_move:
|
||
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
|
||
else:
|
||
_run_async(_copy_object_recursive(
|
||
self._obj.id, dest_parent.id, new_name, self._user_id,
|
||
))
|
||
|
||
return True
|
||
|
||
def support_recursive_delete(self) -> bool:
|
||
return True
|
||
|
||
def support_recursive_move(self, dest_path: str) -> bool:
|
||
return True
|
||
|
||
|
||
# ==================== NonCollection(文件) ====================
|
||
|
||
class DiskNextFile(DAVNonCollection):
|
||
"""DiskNext 文件资源"""
|
||
|
||
def __init__(
|
||
self,
|
||
path: str,
|
||
environ: dict[str, object],
|
||
obj: Object,
|
||
user_id: UUID,
|
||
account: WebDAV,
|
||
) -> None:
|
||
super().__init__(path, environ)
|
||
self._obj = obj
|
||
self._user_id = user_id
|
||
self._account = account
|
||
self._write_path: str | None = None
|
||
self._write_stream: io.BufferedWriter | None = None
|
||
|
||
def get_content_length(self) -> int | None:
|
||
return self._obj.size if self._obj.size else 0
|
||
|
||
def get_content_type(self) -> str | None:
|
||
# 尝试从文件名推断 MIME 类型
|
||
mime, _ = mimetypes.guess_type(self._obj.name)
|
||
return mime or "application/octet-stream"
|
||
|
||
def get_creation_date(self) -> float | None:
|
||
if self._obj.created_at:
|
||
return self._obj.created_at.timestamp()
|
||
return None
|
||
|
||
def get_last_modified(self) -> float | None:
|
||
if self._obj.updated_at:
|
||
return self._obj.updated_at.timestamp()
|
||
return None
|
||
|
||
def get_display_info(self) -> dict[str, str]:
|
||
return {"type": "File"}
|
||
|
||
def get_content(self) -> io.BufferedReader | None:
|
||
"""
|
||
返回文件内容的可读流。
|
||
|
||
WsgiDAV 在线程中运行,可安全使用同步 open()。
|
||
"""
|
||
obj_with_file = _run_async(_get_object_by_id(self._obj.id))
|
||
if not obj_with_file or not obj_with_file.physical_file:
|
||
return None
|
||
|
||
pf = obj_with_file.physical_file
|
||
policy = _run_async(_get_policy(obj_with_file.policy_id))
|
||
if not policy or not policy.server:
|
||
return None
|
||
|
||
full_path = Path(policy.server).resolve() / pf.storage_path
|
||
if not full_path.is_file():
|
||
l.warning(f"WebDAV: 物理文件不存在: {full_path}")
|
||
return None
|
||
|
||
return open(full_path, "rb") # noqa: SIM115
|
||
|
||
def begin_write(self, *, content_type: str | None = None) -> io.BufferedWriter:
|
||
"""
|
||
开始写入文件(PUT 操作)。
|
||
|
||
返回一个可写的文件流,WsgiDAV 将向其中写入请求体数据。
|
||
"""
|
||
_check_readonly(self.environ)
|
||
|
||
# 检查配额
|
||
user = _run_async(_get_user(self._user_id))
|
||
if user:
|
||
content_length = self.environ.get("CONTENT_LENGTH")
|
||
if content_length:
|
||
_check_storage_quota(user, int(content_length))
|
||
|
||
# 获取策略以确定存储路径
|
||
policy = _run_async(_get_policy(self._obj.policy_id))
|
||
if not policy or not policy.server:
|
||
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
|
||
|
||
storage_service = LocalStorageService(policy)
|
||
dir_path, storage_name, full_path = _run_async(
|
||
storage_service.generate_file_path(
|
||
user_id=self._user_id,
|
||
original_filename=self._obj.name,
|
||
)
|
||
)
|
||
|
||
self._write_path = full_path
|
||
self._write_stream = open(full_path, "wb") # noqa: SIM115
|
||
return self._write_stream
|
||
|
||
def end_write(self, *, with_errors: bool) -> None:
|
||
"""写入完成后的收尾工作"""
|
||
if self._write_stream:
|
||
self._write_stream.close()
|
||
self._write_stream = None
|
||
|
||
if with_errors or not self._write_path:
|
||
return
|
||
|
||
# 获取文件大小
|
||
file_path = Path(self._write_path)
|
||
if not file_path.exists():
|
||
return
|
||
|
||
size = file_path.stat().st_size
|
||
|
||
# 更新数据库记录
|
||
_run_async(_finalize_upload(
|
||
object_id=self._obj.id,
|
||
physical_path=self._write_path,
|
||
size=size,
|
||
owner_id=self._user_id,
|
||
policy_id=self._obj.policy_id,
|
||
))
|
||
|
||
l.debug(f"WebDAV 文件写入完成: {self._obj.name}, size={size}")
|
||
|
||
def delete(self) -> None:
|
||
"""软删除文件"""
|
||
_check_readonly(self.environ)
|
||
_run_async(_soft_delete_object(self._obj.id))
|
||
|
||
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
|
||
"""复制或移动文件"""
|
||
_check_readonly(self.environ)
|
||
|
||
account_root = self._account.root
|
||
dest_disknext = _resolve_dav_path(account_root, dest_path)
|
||
|
||
# 解析目标父路径和新名称
|
||
if "/" in dest_disknext.rstrip("/"):
|
||
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
|
||
new_name = dest_disknext.rsplit("/", 1)[1]
|
||
else:
|
||
parent_path = "/"
|
||
new_name = dest_disknext.lstrip("/")
|
||
|
||
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
|
||
if not dest_parent:
|
||
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
|
||
|
||
if is_move:
|
||
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
|
||
else:
|
||
_run_async(_copy_object_recursive(
|
||
self._obj.id, dest_parent.id, new_name, self._user_id,
|
||
))
|
||
|
||
return True
|
||
|
||
def support_content_length(self) -> bool:
|
||
return True
|
||
|
||
def get_etag(self) -> str | None:
|
||
"""返回 ETag(基于ID和更新时间),WsgiDAV 会自动加双引号"""
|
||
if self._obj.updated_at:
|
||
return f"{self._obj.id}-{int(self._obj.updated_at.timestamp())}"
|
||
return None
|
||
|
||
def support_etag(self) -> bool:
|
||
return True
|
||
|
||
def support_ranges(self) -> bool:
|
||
return True
|