feat: migrate ORM base to sqlmodel-ext, add file viewers and WOPI integration

- Migrate SQLModel base classes, mixins, and database management to
  external sqlmodel-ext package; remove sqlmodels/base/, sqlmodels/mixin/,
  and sqlmodels/database.py
- Add file viewer/editor system with WOPI protocol support for
  collaborative editing (OnlyOffice, Collabora)
- Add enterprise edition license verification module (ee/)
- Add Dockerfile multi-stage build with Cython compilation support
- Add new dependencies: sqlmodel-ext, cryptography, whatthepatch

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-14 14:23:17 +08:00
parent 53b757de7a
commit eac0766e79
74 changed files with 4819 additions and 4837 deletions

View File

@@ -0,0 +1,253 @@
"""
管理员文件应用管理集成测试
测试管理员 CRUD、扩展名更新、用户组权限更新和权限校验。
"""
from uuid import UUID, uuid4
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.file_app import FileApp, FileAppExtension, FileAppType
from sqlmodels.group import Group
from sqlmodels.user import User
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def setup_admin_app(
initialized_db: AsyncSession,
) -> dict[str, UUID]:
"""创建测试用管理员文件应用"""
app = FileApp(
name="管理员测试应用",
app_key="admin_test_app",
type=FileAppType.BUILTIN,
is_enabled=True,
)
app = await app.save(initialized_db)
ext = FileAppExtension(app_id=app.id, extension="test", priority=0)
await ext.save(initialized_db)
return {"app_id": app.id}
# ==================== Admin CRUD ====================
class TestAdminFileAppCRUD:
"""管理员文件应用 CRUD 测试"""
@pytest.mark.asyncio
async def test_create_file_app(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
) -> None:
"""管理员创建文件应用"""
response = await async_client.post(
"/api/v1/admin/file-app/",
headers=admin_headers,
json={
"name": "新建应用",
"app_key": "new_app",
"type": "builtin",
"description": "测试新建",
"extensions": ["pdf", "txt"],
"allowed_group_ids": [],
},
)
assert response.status_code == 201
data = response.json()
assert data["name"] == "新建应用"
assert data["app_key"] == "new_app"
assert "pdf" in data["extensions"]
assert "txt" in data["extensions"]
@pytest.mark.asyncio
async def test_create_duplicate_app_key(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
setup_admin_app: dict[str, UUID],
) -> None:
"""创建重复 app_key 返回 409"""
response = await async_client.post(
"/api/v1/admin/file-app/",
headers=admin_headers,
json={
"name": "重复应用",
"app_key": "admin_test_app",
"type": "builtin",
},
)
assert response.status_code == 409
@pytest.mark.asyncio
async def test_list_file_apps(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
setup_admin_app: dict[str, UUID],
) -> None:
"""管理员列出文件应用"""
response = await async_client.get(
"/api/v1/admin/file-app/list",
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert "apps" in data
assert data["total"] >= 1
@pytest.mark.asyncio
async def test_get_file_app_detail(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
setup_admin_app: dict[str, UUID],
) -> None:
"""管理员获取应用详情"""
app_id = setup_admin_app["app_id"]
response = await async_client.get(
f"/api/v1/admin/file-app/{app_id}",
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["app_key"] == "admin_test_app"
assert "test" in data["extensions"]
@pytest.mark.asyncio
async def test_get_nonexistent_app(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
) -> None:
"""获取不存在的应用返回 404"""
response = await async_client.get(
f"/api/v1/admin/file-app/{uuid4()}",
headers=admin_headers,
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_update_file_app(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
setup_admin_app: dict[str, UUID],
) -> None:
"""管理员更新应用"""
app_id = setup_admin_app["app_id"]
response = await async_client.patch(
f"/api/v1/admin/file-app/{app_id}",
headers=admin_headers,
json={
"name": "更新后的名称",
"is_enabled": False,
},
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "更新后的名称"
assert data["is_enabled"] is False
@pytest.mark.asyncio
async def test_delete_file_app(
self,
async_client: AsyncClient,
initialized_db: AsyncSession,
admin_headers: dict[str, str],
) -> None:
"""管理员删除应用"""
# 先创建一个应用
app = FileApp(
name="待删除应用", app_key="to_delete_admin", type=FileAppType.BUILTIN
)
app = await app.save(initialized_db)
app_id = app.id
response = await async_client.delete(
f"/api/v1/admin/file-app/{app_id}",
headers=admin_headers,
)
assert response.status_code == 204
# 确认已删除
found = await FileApp.get(initialized_db, FileApp.id == app_id)
assert found is None
# ==================== Extensions Management ====================
class TestAdminExtensionManagement:
"""管理员扩展名管理测试"""
@pytest.mark.asyncio
async def test_update_extensions(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
setup_admin_app: dict[str, UUID],
) -> None:
"""全量替换扩展名列表"""
app_id = setup_admin_app["app_id"]
response = await async_client.put(
f"/api/v1/admin/file-app/{app_id}/extensions",
headers=admin_headers,
json={"extensions": ["doc", "docx", "odt"]},
)
assert response.status_code == 200
data = response.json()
assert sorted(data["extensions"]) == ["doc", "docx", "odt"]
# ==================== Group Access Management ====================
class TestAdminGroupAccessManagement:
"""管理员用户组权限管理测试"""
@pytest.mark.asyncio
async def test_update_group_access(
self,
async_client: AsyncClient,
initialized_db: AsyncSession,
admin_headers: dict[str, str],
setup_admin_app: dict[str, UUID],
) -> None:
"""全量替换用户组权限"""
app_id = setup_admin_app["app_id"]
admin_user = await User.get(initialized_db, User.email == "admin@disknext.local")
group_id = admin_user.group_id
response = await async_client.put(
f"/api/v1/admin/file-app/{app_id}/groups",
headers=admin_headers,
json={"group_ids": [str(group_id)]},
)
assert response.status_code == 200
data = response.json()
assert str(group_id) in data["allowed_group_ids"]
# ==================== Permission Tests ====================
class TestAdminPermission:
"""权限校验测试"""
@pytest.mark.asyncio
async def test_non_admin_forbidden(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
) -> None:
"""普通用户访问管理端点返回 403"""
response = await async_client.get(
"/api/v1/admin/file-app/list",
headers=auth_headers,
)
assert response.status_code == 403

View File

@@ -0,0 +1,466 @@
"""
文本文件内容 GET/PATCH 集成测试
测试 GET /file/content/{file_id} 和 PATCH /file/content/{file_id} 端点。
"""
import hashlib
from pathlib import Path
from uuid import uuid4
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels import Object, ObjectType, PhysicalFile, Policy, User
@pytest.fixture(autouse=True)
def _register_sqlite_greatest():
"""注册 SQLite 的 greatest 函数以兼容 PostgreSQL 语法"""
def _on_connect(dbapi_connection, connection_record):
if hasattr(dbapi_connection, 'create_function'):
dbapi_connection.create_function("greatest", 2, max)
event.listen(Engine, "connect", _on_connect)
yield
event.remove(Engine, "connect", _on_connect)
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def local_policy(
initialized_db: AsyncSession,
tmp_path: Path,
) -> Policy:
"""创建指向临时目录的本地存储策略"""
from sqlmodels import PolicyType
policy = Policy(
id=uuid4(),
name="测试本地存储",
type=PolicyType.LOCAL,
server=str(tmp_path),
)
initialized_db.add(policy)
await initialized_db.commit()
await initialized_db.refresh(policy)
return policy
@pytest_asyncio.fixture
async def text_file(
initialized_db: AsyncSession,
tmp_path: Path,
local_policy: Policy,
) -> dict[str, str | int]:
"""创建包含 UTF-8 文本内容的测试文件"""
user = await User.get(initialized_db, User.email == "testuser@test.local")
root = await Object.get_root(initialized_db, user.id)
content = "line1\nline2\nline3\n"
content_bytes = content.encode('utf-8')
content_hash = hashlib.sha256(content_bytes).hexdigest()
file_path = tmp_path / "test.txt"
file_path.write_bytes(content_bytes)
physical_file = PhysicalFile(
id=uuid4(),
storage_path=str(file_path),
size=len(content_bytes),
policy_id=local_policy.id,
reference_count=1,
)
initialized_db.add(physical_file)
file_obj = Object(
id=uuid4(),
name="test.txt",
type=ObjectType.FILE,
size=len(content_bytes),
physical_file_id=physical_file.id,
parent_id=root.id,
owner_id=user.id,
policy_id=local_policy.id,
)
initialized_db.add(file_obj)
await initialized_db.commit()
return {
"id": str(file_obj.id),
"content": content,
"hash": content_hash,
"size": len(content_bytes),
"path": str(file_path),
}
@pytest_asyncio.fixture
async def binary_file(
initialized_db: AsyncSession,
tmp_path: Path,
local_policy: Policy,
) -> dict[str, str | int]:
"""创建非 UTF-8 的二进制测试文件"""
user = await User.get(initialized_db, User.email == "testuser@test.local")
root = await Object.get_root(initialized_db, user.id)
# 包含无效 UTF-8 字节序列
content_bytes = b'\x80\x81\x82\xff\xfe\xfd'
file_path = tmp_path / "binary.dat"
file_path.write_bytes(content_bytes)
physical_file = PhysicalFile(
id=uuid4(),
storage_path=str(file_path),
size=len(content_bytes),
policy_id=local_policy.id,
reference_count=1,
)
initialized_db.add(physical_file)
file_obj = Object(
id=uuid4(),
name="binary.dat",
type=ObjectType.FILE,
size=len(content_bytes),
physical_file_id=physical_file.id,
parent_id=root.id,
owner_id=user.id,
policy_id=local_policy.id,
)
initialized_db.add(file_obj)
await initialized_db.commit()
return {
"id": str(file_obj.id),
"path": str(file_path),
}
# ==================== GET /file/content/{file_id} ====================
class TestGetFileContent:
"""GET /file/content/{file_id} 端点测试"""
@pytest.mark.asyncio
async def test_get_content_success(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""成功获取文本文件内容和哈希"""
response = await async_client.get(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["content"] == text_file["content"]
assert data["hash"] == text_file["hash"]
assert data["size"] == text_file["size"]
@pytest.mark.asyncio
async def test_get_content_non_utf8_returns_400(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
binary_file: dict[str, str | int],
) -> None:
"""非 UTF-8 文件返回 400"""
response = await async_client.get(
f"/api/v1/file/content/{binary_file['id']}",
headers=auth_headers,
)
assert response.status_code == 400
assert "UTF-8" in response.json()["detail"]
@pytest.mark.asyncio
async def test_get_content_not_found(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
) -> None:
"""文件不存在返回 404"""
fake_id = uuid4()
response = await async_client.get(
f"/api/v1/file/content/{fake_id}",
headers=auth_headers,
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_get_content_unauthenticated(
self,
async_client: AsyncClient,
text_file: dict[str, str | int],
) -> None:
"""未认证返回 401"""
response = await async_client.get(
f"/api/v1/file/content/{text_file['id']}",
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_get_content_normalizes_crlf(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
initialized_db: AsyncSession,
tmp_path: Path,
local_policy: Policy,
) -> None:
"""CRLF 换行符被规范化为 LF"""
user = await User.get(initialized_db, User.email == "testuser@test.local")
root = await Object.get_root(initialized_db, user.id)
crlf_content = b"line1\r\nline2\r\n"
file_path = tmp_path / "crlf.txt"
file_path.write_bytes(crlf_content)
physical_file = PhysicalFile(
id=uuid4(),
storage_path=str(file_path),
size=len(crlf_content),
policy_id=local_policy.id,
reference_count=1,
)
initialized_db.add(physical_file)
file_obj = Object(
id=uuid4(),
name="crlf.txt",
type=ObjectType.FILE,
size=len(crlf_content),
physical_file_id=physical_file.id,
parent_id=root.id,
owner_id=user.id,
policy_id=local_policy.id,
)
initialized_db.add(file_obj)
await initialized_db.commit()
response = await async_client.get(
f"/api/v1/file/content/{file_obj.id}",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
# 内容应该被规范化为 LF
assert data["content"] == "line1\nline2\n"
# 哈希基于规范化后的内容
expected_hash = hashlib.sha256("line1\nline2\n".encode('utf-8')).hexdigest()
assert data["hash"] == expected_hash
# ==================== PATCH /file/content/{file_id} ====================
class TestPatchFileContent:
"""PATCH /file/content/{file_id} 端点测试"""
@pytest.mark.asyncio
async def test_patch_content_success(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""正常增量保存"""
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" line1\n"
"-line2\n"
"+LINE2_MODIFIED\n"
" line3\n"
)
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": patch_text,
"base_hash": text_file["hash"],
},
)
assert response.status_code == 200
data = response.json()
assert "new_hash" in data
assert "new_size" in data
assert data["new_hash"] != text_file["hash"]
# 验证文件实际被修改
file_path = Path(text_file["path"])
new_content = file_path.read_text(encoding='utf-8')
assert "LINE2_MODIFIED" in new_content
assert "line2" not in new_content
@pytest.mark.asyncio
async def test_patch_content_hash_mismatch_returns_409(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""base_hash 不匹配返回 409"""
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" line1\n"
"-line2\n"
"+changed\n"
" line3\n"
)
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": patch_text,
"base_hash": "0" * 64, # 错误的哈希
},
)
assert response.status_code == 409
@pytest.mark.asyncio
async def test_patch_content_invalid_patch_returns_422(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""无效的 patch 格式返回 422"""
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": "this is not a valid patch",
"base_hash": text_file["hash"],
},
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_patch_content_context_mismatch_returns_422(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""patch 上下文行不匹配返回 422"""
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" WRONG_CONTEXT_LINE\n"
"-line2\n"
"+replaced\n"
" line3\n"
)
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": patch_text,
"base_hash": text_file["hash"],
},
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_patch_content_unauthenticated(
self,
async_client: AsyncClient,
text_file: dict[str, str | int],
) -> None:
"""未认证返回 401"""
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
json={
"patch": "--- a\n+++ b\n",
"base_hash": text_file["hash"],
},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_patch_content_not_found(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
) -> None:
"""文件不存在返回 404"""
fake_id = uuid4()
response = await async_client.patch(
f"/api/v1/file/content/{fake_id}",
headers=auth_headers,
json={
"patch": "--- a\n+++ b\n",
"base_hash": "0" * 64,
},
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_patch_then_get_consistency(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""PATCH 后 GET 返回一致的内容和哈希"""
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" line1\n"
"-line2\n"
"+PATCHED\n"
" line3\n"
)
# PATCH
patch_resp = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": patch_text,
"base_hash": text_file["hash"],
},
)
assert patch_resp.status_code == 200
patch_data = patch_resp.json()
# GET
get_resp = await async_client.get(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
)
assert get_resp.status_code == 200
get_data = get_resp.json()
# 一致性验证
assert get_data["hash"] == patch_data["new_hash"]
assert get_data["size"] == patch_data["new_size"]
assert "PATCHED" in get_data["content"]

View File

@@ -0,0 +1,305 @@
"""
文件查看器集成测试
测试查看器查询、用户默认设置、用户组过滤等端点。
"""
from uuid import UUID
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.file_app import (
FileApp,
FileAppExtension,
FileAppGroupLink,
FileAppType,
UserFileAppDefault,
)
from sqlmodels.user import User
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def setup_file_apps(
initialized_db: AsyncSession,
) -> dict[str, UUID]:
"""创建测试用文件查看器应用"""
# PDF 阅读器(不限制用户组)
pdf_app = FileApp(
name="PDF 阅读器",
app_key="pdfjs",
type=FileAppType.BUILTIN,
is_enabled=True,
is_restricted=False,
)
pdf_app = await pdf_app.save(initialized_db)
# Monaco 编辑器(不限制用户组)
monaco_app = FileApp(
name="代码编辑器",
app_key="monaco",
type=FileAppType.BUILTIN,
is_enabled=True,
is_restricted=False,
)
monaco_app = await monaco_app.save(initialized_db)
# Collabora限制用户组
collabora_app = FileApp(
name="Collabora",
app_key="collabora",
type=FileAppType.WOPI,
is_enabled=True,
is_restricted=True,
)
collabora_app = await collabora_app.save(initialized_db)
# 已禁用的应用
disabled_app = FileApp(
name="禁用的应用",
app_key="disabled_app",
type=FileAppType.BUILTIN,
is_enabled=False,
is_restricted=False,
)
disabled_app = await disabled_app.save(initialized_db)
# 创建扩展名
for ext in ["pdf"]:
await FileAppExtension(app_id=pdf_app.id, extension=ext, priority=0).save(initialized_db)
for ext in ["txt", "md", "json"]:
await FileAppExtension(app_id=monaco_app.id, extension=ext, priority=0).save(initialized_db)
for ext in ["docx", "xlsx", "pptx"]:
await FileAppExtension(app_id=collabora_app.id, extension=ext, priority=0).save(initialized_db)
for ext in ["pdf"]:
await FileAppExtension(app_id=disabled_app.id, extension=ext, priority=10).save(initialized_db)
return {
"pdf_app_id": pdf_app.id,
"monaco_app_id": monaco_app.id,
"collabora_app_id": collabora_app.id,
"disabled_app_id": disabled_app.id,
}
# ==================== GET /file/viewers ====================
class TestGetViewers:
"""查询可用查看器测试"""
@pytest.mark.asyncio
async def test_get_viewers_for_pdf(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""查询 PDF 查看器:返回已启用的,排除已禁用的"""
response = await async_client.get(
"/api/v1/file/viewers?ext=pdf",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert "viewers" in data
viewer_keys = [v["app_key"] for v in data["viewers"]]
# pdfjs 应该在列表中
assert "pdfjs" in viewer_keys
# 禁用的应用不应出现
assert "disabled_app" not in viewer_keys
# 默认值应为 None
assert data["default_viewer_id"] is None
@pytest.mark.asyncio
async def test_get_viewers_normalizes_extension(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""扩展名规范化:.PDF → pdf"""
response = await async_client.get(
"/api/v1/file/viewers?ext=.PDF",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert len(data["viewers"]) >= 1
@pytest.mark.asyncio
async def test_get_viewers_empty_for_unknown_ext(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""未知扩展名返回空列表"""
response = await async_client.get(
"/api/v1/file/viewers?ext=xyz_unknown",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["viewers"] == []
@pytest.mark.asyncio
async def test_group_restriction_filters_app(
self,
async_client: AsyncClient,
initialized_db: AsyncSession,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""用户组限制collabora 限制了用户组,用户不在白名单内则不可见"""
# collabora 是受限的,用户组不在白名单中
response = await async_client.get(
"/api/v1/file/viewers?ext=docx",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
viewer_keys = [v["app_key"] for v in data["viewers"]]
assert "collabora" not in viewer_keys
# 将用户组加入白名单
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
link = FileAppGroupLink(
app_id=setup_file_apps["collabora_app_id"],
group_id=test_user.group_id,
)
initialized_db.add(link)
await initialized_db.commit()
# 再次查询
response = await async_client.get(
"/api/v1/file/viewers?ext=docx",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
viewer_keys = [v["app_key"] for v in data["viewers"]]
assert "collabora" in viewer_keys
@pytest.mark.asyncio
async def test_unauthorized_without_token(
self,
async_client: AsyncClient,
) -> None:
"""未认证请求返回 401"""
response = await async_client.get("/api/v1/file/viewers?ext=pdf")
assert response.status_code in (401, 403)
# ==================== User File Viewer Defaults ====================
class TestUserFileViewerDefaults:
"""用户默认查看器设置测试"""
@pytest.mark.asyncio
async def test_set_default_viewer(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""设置默认查看器"""
response = await async_client.put(
"/api/v1/user/settings/file-viewers/default",
headers=auth_headers,
json={
"extension": "pdf",
"app_id": str(setup_file_apps["pdf_app_id"]),
},
)
assert response.status_code == 200
data = response.json()
assert data["extension"] == "pdf"
assert data["app"]["app_key"] == "pdfjs"
@pytest.mark.asyncio
async def test_list_default_viewers(
self,
async_client: AsyncClient,
initialized_db: AsyncSession,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""列出默认查看器"""
# 先创建一个默认
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
await UserFileAppDefault(
user_id=test_user.id,
extension="pdf",
app_id=setup_file_apps["pdf_app_id"],
).save(initialized_db)
response = await async_client.get(
"/api/v1/user/settings/file-viewers/defaults",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) >= 1
@pytest.mark.asyncio
async def test_delete_default_viewer(
self,
async_client: AsyncClient,
initialized_db: AsyncSession,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""撤销默认查看器"""
# 创建一个默认
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
default = await UserFileAppDefault(
user_id=test_user.id,
extension="txt",
app_id=setup_file_apps["monaco_app_id"],
).save(initialized_db)
response = await async_client.delete(
f"/api/v1/user/settings/file-viewers/default/{default.id}",
headers=auth_headers,
)
assert response.status_code == 204
# 验证已删除
found = await UserFileAppDefault.get(
initialized_db, UserFileAppDefault.id == default.id
)
assert found is None
@pytest.mark.asyncio
async def test_get_viewers_includes_default(
self,
async_client: AsyncClient,
initialized_db: AsyncSession,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""查看器查询应包含用户默认选择"""
# 设置默认
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
await UserFileAppDefault(
user_id=test_user.id,
extension="pdf",
app_id=setup_file_apps["pdf_app_id"],
).save(initialized_db)
response = await async_client.get(
"/api/v1/file/viewers?ext=pdf",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["default_viewer_id"] == str(setup_file_apps["pdf_app_id"])

View File

@@ -0,0 +1,386 @@
"""
FileApp 模型单元测试
测试 FileApp、FileAppExtension、UserFileAppDefault 的 CRUD 和约束。
"""
from uuid import UUID
import pytest
import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.file_app import (
FileApp,
FileAppExtension,
FileAppGroupLink,
FileAppType,
UserFileAppDefault,
)
from sqlmodels.group import Group
from sqlmodels.user import User, UserStatus
from sqlmodels.policy import Policy, PolicyType
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def sample_group(db_session: AsyncSession) -> Group:
"""创建测试用户组"""
group = Group(name="测试组", max_storage=0, admin=False)
return await group.save(db_session)
@pytest_asyncio.fixture
async def sample_user(db_session: AsyncSession, sample_group: Group) -> User:
"""创建测试用户"""
user = User(
email="fileapp_test@test.local",
nickname="文件应用测试用户",
status=UserStatus.ACTIVE,
group_id=sample_group.id,
)
return await user.save(db_session)
@pytest_asyncio.fixture
async def sample_app(db_session: AsyncSession) -> FileApp:
"""创建测试文件应用"""
app = FileApp(
name="测试PDF阅读器",
app_key="test_pdfjs",
type=FileAppType.BUILTIN,
icon="file-pdf",
description="测试用 PDF 阅读器",
is_enabled=True,
is_restricted=False,
)
return await app.save(db_session)
@pytest_asyncio.fixture
async def sample_app_with_extensions(db_session: AsyncSession, sample_app: FileApp) -> FileApp:
"""创建带扩展名的文件应用"""
ext1 = FileAppExtension(app_id=sample_app.id, extension="pdf", priority=0)
ext2 = FileAppExtension(app_id=sample_app.id, extension="djvu", priority=1)
await ext1.save(db_session)
await ext2.save(db_session)
return sample_app
# ==================== FileApp CRUD ====================
class TestFileAppCRUD:
"""FileApp 基础 CRUD 测试"""
async def test_create_file_app(self, db_session: AsyncSession) -> None:
"""测试创建文件应用"""
app = FileApp(
name="Monaco 编辑器",
app_key="monaco",
type=FileAppType.BUILTIN,
description="代码编辑器",
is_enabled=True,
)
app = await app.save(db_session)
assert app.id is not None
assert app.name == "Monaco 编辑器"
assert app.app_key == "monaco"
assert app.type == FileAppType.BUILTIN
assert app.is_enabled is True
assert app.is_restricted is False
async def test_get_file_app_by_key(self, db_session: AsyncSession, sample_app: FileApp) -> None:
"""测试按 app_key 查询"""
found = await FileApp.get(db_session, FileApp.app_key == "test_pdfjs")
assert found is not None
assert found.id == sample_app.id
async def test_unique_app_key(self, db_session: AsyncSession, sample_app: FileApp) -> None:
"""测试 app_key 唯一约束"""
dup = FileApp(
name="重复应用",
app_key="test_pdfjs",
type=FileAppType.BUILTIN,
)
with pytest.raises(IntegrityError):
await dup.save(db_session)
async def test_update_file_app(self, db_session: AsyncSession, sample_app: FileApp) -> None:
"""测试更新文件应用"""
sample_app.name = "更新后的名称"
sample_app.is_enabled = False
sample_app = await sample_app.save(db_session)
found = await FileApp.get(db_session, FileApp.id == sample_app.id)
assert found.name == "更新后的名称"
assert found.is_enabled is False
async def test_delete_file_app(self, db_session: AsyncSession) -> None:
"""测试删除文件应用"""
app = FileApp(
name="待删除应用",
app_key="to_delete",
type=FileAppType.IFRAME,
)
app = await app.save(db_session)
app_id = app.id
await FileApp.delete(db_session, app)
found = await FileApp.get(db_session, FileApp.id == app_id)
assert found is None
async def test_create_wopi_app(self, db_session: AsyncSession) -> None:
"""测试创建 WOPI 类型应用"""
app = FileApp(
name="Collabora",
app_key="collabora",
type=FileAppType.WOPI,
wopi_discovery_url="http://collabora:9980/hosting/discovery",
wopi_editor_url_template="http://collabora:9980/loleaflet/dist/loleaflet.html?WOPISrc={wopi_src}&access_token={access_token}",
is_enabled=True,
)
app = await app.save(db_session)
assert app.type == FileAppType.WOPI
assert app.wopi_discovery_url is not None
assert app.wopi_editor_url_template is not None
async def test_create_iframe_app(self, db_session: AsyncSession) -> None:
"""测试创建 iframe 类型应用"""
app = FileApp(
name="Office 在线预览",
app_key="office_viewer",
type=FileAppType.IFRAME,
iframe_url_template="https://view.officeapps.live.com/op/embed.aspx?src={file_url}",
is_enabled=False,
)
app = await app.save(db_session)
assert app.type == FileAppType.IFRAME
assert "{file_url}" in app.iframe_url_template
async def test_to_summary(self, db_session: AsyncSession, sample_app: FileApp) -> None:
"""测试转换为摘要 DTO"""
summary = sample_app.to_summary()
assert summary.id == sample_app.id
assert summary.name == sample_app.name
assert summary.app_key == sample_app.app_key
assert summary.type == sample_app.type
# ==================== FileAppExtension ====================
class TestFileAppExtension:
"""FileAppExtension 测试"""
async def test_create_extension(self, db_session: AsyncSession, sample_app: FileApp) -> None:
"""测试创建扩展名关联"""
ext = FileAppExtension(
app_id=sample_app.id,
extension="pdf",
priority=0,
)
ext = await ext.save(db_session)
assert ext.id is not None
assert ext.extension == "pdf"
assert ext.priority == 0
async def test_query_by_extension(
self, db_session: AsyncSession, sample_app_with_extensions: FileApp
) -> None:
"""测试按扩展名查询"""
results: list[FileAppExtension] = await FileAppExtension.get(
db_session,
FileAppExtension.extension == "pdf",
fetch_mode="all",
)
assert len(results) >= 1
assert any(r.app_id == sample_app_with_extensions.id for r in results)
async def test_unique_app_extension(self, db_session: AsyncSession, sample_app: FileApp) -> None:
"""测试 (app_id, extension) 唯一约束"""
ext1 = FileAppExtension(app_id=sample_app.id, extension="txt", priority=0)
await ext1.save(db_session)
ext2 = FileAppExtension(app_id=sample_app.id, extension="txt", priority=1)
with pytest.raises(IntegrityError):
await ext2.save(db_session)
async def test_cascade_delete(
self, db_session: AsyncSession, sample_app_with_extensions: FileApp
) -> None:
"""测试级联删除:删除应用时扩展名也被删除"""
app_id = sample_app_with_extensions.id
# 确认扩展名存在
exts = await FileAppExtension.get(
db_session,
FileAppExtension.app_id == app_id,
fetch_mode="all",
)
assert len(exts) == 2
# 删除应用
await FileApp.delete(db_session, sample_app_with_extensions)
# 确认扩展名也被删除
exts = await FileAppExtension.get(
db_session,
FileAppExtension.app_id == app_id,
fetch_mode="all",
)
assert len(exts) == 0
# ==================== FileAppGroupLink ====================
class TestFileAppGroupLink:
"""FileAppGroupLink 用户组访问控制测试"""
async def test_create_group_link(
self, db_session: AsyncSession, sample_app: FileApp, sample_group: Group
) -> None:
"""测试创建用户组关联"""
link = FileAppGroupLink(app_id=sample_app.id, group_id=sample_group.id)
db_session.add(link)
await db_session.commit()
result = await db_session.exec(
select(FileAppGroupLink).where(
FileAppGroupLink.app_id == sample_app.id,
FileAppGroupLink.group_id == sample_group.id,
)
)
found = result.first()
assert found is not None
async def test_multiple_groups(self, db_session: AsyncSession, sample_app: FileApp) -> None:
"""测试一个应用关联多个用户组"""
group1 = Group(name="组A", admin=False)
group1 = await group1.save(db_session)
group2 = Group(name="组B", admin=False)
group2 = await group2.save(db_session)
db_session.add(FileAppGroupLink(app_id=sample_app.id, group_id=group1.id))
db_session.add(FileAppGroupLink(app_id=sample_app.id, group_id=group2.id))
await db_session.commit()
result = await db_session.exec(
select(FileAppGroupLink).where(FileAppGroupLink.app_id == sample_app.id)
)
links = result.all()
assert len(links) == 2
# ==================== UserFileAppDefault ====================
class TestUserFileAppDefault:
"""UserFileAppDefault 用户偏好测试"""
async def test_create_default(
self, db_session: AsyncSession, sample_app: FileApp, sample_user: User
) -> None:
"""测试创建用户默认偏好"""
default = UserFileAppDefault(
user_id=sample_user.id,
extension="pdf",
app_id=sample_app.id,
)
default = await default.save(db_session)
assert default.id is not None
assert default.extension == "pdf"
async def test_unique_user_extension(
self, db_session: AsyncSession, sample_app: FileApp, sample_user: User
) -> None:
"""测试 (user_id, extension) 唯一约束"""
default1 = UserFileAppDefault(
user_id=sample_user.id, extension="pdf", app_id=sample_app.id
)
await default1.save(db_session)
# 创建另一个应用
app2 = FileApp(
name="另一个阅读器",
app_key="pdf_alt",
type=FileAppType.BUILTIN,
)
app2 = await app2.save(db_session)
default2 = UserFileAppDefault(
user_id=sample_user.id, extension="pdf", app_id=app2.id
)
with pytest.raises(IntegrityError):
await default2.save(db_session)
async def test_cascade_delete_on_app(
self, db_session: AsyncSession, sample_user: User
) -> None:
"""测试级联删除:删除应用时用户偏好也被删除"""
app = FileApp(
name="待删除应用2",
app_key="to_delete_2",
type=FileAppType.BUILTIN,
)
app = await app.save(db_session)
app_id = app.id
default = UserFileAppDefault(
user_id=sample_user.id, extension="xyz", app_id=app_id
)
await default.save(db_session)
# 确认存在
found = await UserFileAppDefault.get(
db_session, UserFileAppDefault.app_id == app_id
)
assert found is not None
# 删除应用
await FileApp.delete(db_session, app)
# 确认用户偏好也被删除
found = await UserFileAppDefault.get(
db_session, UserFileAppDefault.app_id == app_id
)
assert found is None
# ==================== DTO ====================
class TestFileAppDTO:
"""DTO 模型测试"""
async def test_file_app_response_from_app(
self, db_session: AsyncSession, sample_app_with_extensions: FileApp, sample_group: Group
) -> None:
"""测试 FileAppResponse.from_app()"""
from sqlmodels.file_app import FileAppResponse
extensions = await FileAppExtension.get(
db_session,
FileAppExtension.app_id == sample_app_with_extensions.id,
fetch_mode="all",
)
# 直接构造 link 对象用于 DTO 测试,无需持久化
link = FileAppGroupLink(
app_id=sample_app_with_extensions.id,
group_id=sample_group.id,
)
response = FileAppResponse.from_app(
sample_app_with_extensions, extensions, [link]
)
assert response.id == sample_app_with_extensions.id
assert response.app_key == "test_pdfjs"
assert "pdf" in response.extensions
assert "djvu" in response.extensions
assert sample_group.id in response.allowed_group_ids

View File

@@ -113,7 +113,7 @@ async def test_setting_update_value(db_session: AsyncSession):
setting = await setting.save(db_session)
# 更新值
from sqlmodels.base import SQLModelBase
from sqlmodel_ext import SQLModelBase
class SettingUpdate(SQLModelBase):
value: str | None = None

View File

@@ -0,0 +1,178 @@
"""
文本文件 patch 逻辑单元测试
测试 whatthepatch 库的 patch 解析与应用,
以及换行符规范化和 SHA-256 哈希计算。
"""
import hashlib
import pytest
import whatthepatch
from whatthepatch.exceptions import HunkApplyException
class TestPatchApply:
"""测试 patch 解析与应用"""
def test_normal_patch(self) -> None:
"""正常 patch 应用"""
original = "line1\nline2\nline3"
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" line1\n"
"-line2\n"
"+LINE2\n"
" line3\n"
)
diffs = list(whatthepatch.parse_patch(patch_text))
assert len(diffs) == 1
result = whatthepatch.apply_diff(diffs[0], original)
new_text = '\n'.join(result)
assert "LINE2" in new_text
assert "line2" not in new_text
def test_add_lines_patch(self) -> None:
"""添加行的 patch"""
original = "line1\nline2"
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,2 +1,3 @@\n"
" line1\n"
" line2\n"
"+line3\n"
)
diffs = list(whatthepatch.parse_patch(patch_text))
result = whatthepatch.apply_diff(diffs[0], original)
new_text = '\n'.join(result)
assert "line3" in new_text
def test_delete_lines_patch(self) -> None:
"""删除行的 patch"""
original = "line1\nline2\nline3"
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,2 @@\n"
" line1\n"
"-line2\n"
" line3\n"
)
diffs = list(whatthepatch.parse_patch(patch_text))
result = whatthepatch.apply_diff(diffs[0], original)
new_text = '\n'.join(result)
assert "line2" not in new_text
assert "line1" in new_text
assert "line3" in new_text
def test_invalid_patch_format(self) -> None:
"""无效的 patch 格式返回空列表"""
diffs = list(whatthepatch.parse_patch("this is not a patch"))
assert len(diffs) == 0
def test_patch_context_mismatch(self) -> None:
"""patch 上下文不匹配时抛出异常"""
original = "line1\nline2\nline3\n"
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" line1\n"
"-WRONG\n"
"+REPLACED\n"
" line3\n"
)
diffs = list(whatthepatch.parse_patch(patch_text))
with pytest.raises(HunkApplyException):
whatthepatch.apply_diff(diffs[0], original)
def test_empty_file_patch(self) -> None:
"""空文件应用 patch"""
original = ""
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -0,0 +1,2 @@\n"
"+line1\n"
"+line2\n"
)
diffs = list(whatthepatch.parse_patch(patch_text))
result = whatthepatch.apply_diff(diffs[0], original)
new_text = '\n'.join(result)
assert "line1" in new_text
assert "line2" in new_text
class TestHashComputation:
"""测试 SHA-256 哈希计算"""
def test_hash_consistency(self) -> None:
"""相同内容产生相同哈希"""
content = "hello world\n"
content_bytes = content.encode('utf-8')
hash1 = hashlib.sha256(content_bytes).hexdigest()
hash2 = hashlib.sha256(content_bytes).hexdigest()
assert hash1 == hash2
assert len(hash1) == 64
def test_hash_differs_for_different_content(self) -> None:
"""不同内容产生不同哈希"""
hash1 = hashlib.sha256(b"content A").hexdigest()
hash2 = hashlib.sha256(b"content B").hexdigest()
assert hash1 != hash2
def test_hash_after_normalization(self) -> None:
"""换行符规范化后的哈希一致性"""
content_crlf = "line1\r\nline2\r\n"
content_lf = "line1\nline2\n"
# 规范化后应相同
normalized = content_crlf.replace('\r\n', '\n').replace('\r', '\n')
assert normalized == content_lf
hash_normalized = hashlib.sha256(normalized.encode('utf-8')).hexdigest()
hash_lf = hashlib.sha256(content_lf.encode('utf-8')).hexdigest()
assert hash_normalized == hash_lf
class TestLineEndingNormalization:
"""测试换行符规范化"""
def test_crlf_to_lf(self) -> None:
"""CRLF 转换为 LF"""
content = "line1\r\nline2\r\n"
normalized = content.replace('\r\n', '\n').replace('\r', '\n')
assert normalized == "line1\nline2\n"
def test_cr_to_lf(self) -> None:
"""CR 转换为 LF"""
content = "line1\rline2\r"
normalized = content.replace('\r\n', '\n').replace('\r', '\n')
assert normalized == "line1\nline2\n"
def test_lf_unchanged(self) -> None:
"""LF 保持不变"""
content = "line1\nline2\n"
normalized = content.replace('\r\n', '\n').replace('\r', '\n')
assert normalized == content
def test_mixed_line_endings(self) -> None:
"""混合换行符统一为 LF"""
content = "line1\r\nline2\rline3\n"
normalized = content.replace('\r\n', '\n').replace('\r', '\n')
assert normalized == "line1\nline2\nline3\n"

View File

@@ -0,0 +1,77 @@
"""
WOPI Token 单元测试
测试 WOPI 访问令牌的生成和验证。
"""
from uuid import uuid4
import pytest
import utils.JWT as JWT
from utils.JWT.wopi_token import create_wopi_token, verify_wopi_token
# 确保测试 secret key
JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation"
class TestWopiToken:
"""WOPI Token 测试"""
def test_create_and_verify_token(self) -> None:
"""创建和验证令牌"""
file_id = uuid4()
user_id = uuid4()
token, ttl = create_wopi_token(file_id, user_id, can_write=True)
assert isinstance(token, str)
assert isinstance(ttl, int)
assert ttl > 0
payload = verify_wopi_token(token)
assert payload is not None
assert payload.file_id == file_id
assert payload.user_id == user_id
assert payload.can_write is True
def test_verify_read_only_token(self) -> None:
"""验证只读令牌"""
file_id = uuid4()
user_id = uuid4()
token, ttl = create_wopi_token(file_id, user_id, can_write=False)
payload = verify_wopi_token(token)
assert payload is not None
assert payload.can_write is False
def test_verify_invalid_token(self) -> None:
"""验证无效令牌返回 None"""
payload = verify_wopi_token("invalid_token_string")
assert payload is None
def test_verify_non_wopi_token(self) -> None:
"""验证非 WOPI 类型令牌返回 None"""
import jwt as pyjwt
# 创建一个不含 type=wopi 的令牌
token = pyjwt.encode(
{"file_id": str(uuid4()), "user_id": str(uuid4()), "type": "download"},
JWT.SECRET_KEY,
algorithm="HS256",
)
payload = verify_wopi_token(token)
assert payload is None
def test_ttl_is_future_milliseconds(self) -> None:
"""TTL 应为未来的毫秒时间戳"""
import time
file_id = uuid4()
user_id = uuid4()
token, ttl = create_wopi_token(file_id, user_id)
current_ms = int(time.time() * 1000)
# TTL 应大于当前时间
assert ttl > current_ms
# TTL 不应超过 11 小时后10h + 余量)
assert ttl < current_ms + 11 * 3600 * 1000