feat: migrate ORM base to sqlmodel-ext, add file viewers and WOPI integration

- Migrate SQLModel base classes, mixins, and database management to
  external sqlmodel-ext package; remove sqlmodels/base/, sqlmodels/mixin/,
  and sqlmodels/database.py
- Add file viewer/editor system with WOPI protocol support for
  collaborative editing (OnlyOffice, Collabora)
- Add enterprise edition license verification module (ee/)
- Add Dockerfile multi-stage build with Cython compilation support
- Add new dependencies: sqlmodel-ext, cryptography, whatthepatch

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-14 14:23:17 +08:00
parent 53b757de7a
commit eac0766e79
74 changed files with 4819 additions and 4837 deletions

View File

@@ -0,0 +1,253 @@
"""
管理员文件应用管理集成测试
测试管理员 CRUD、扩展名更新、用户组权限更新和权限校验。
"""
from uuid import UUID, uuid4
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.file_app import FileApp, FileAppExtension, FileAppType
from sqlmodels.group import Group
from sqlmodels.user import User
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def setup_admin_app(
initialized_db: AsyncSession,
) -> dict[str, UUID]:
"""创建测试用管理员文件应用"""
app = FileApp(
name="管理员测试应用",
app_key="admin_test_app",
type=FileAppType.BUILTIN,
is_enabled=True,
)
app = await app.save(initialized_db)
ext = FileAppExtension(app_id=app.id, extension="test", priority=0)
await ext.save(initialized_db)
return {"app_id": app.id}
# ==================== Admin CRUD ====================
class TestAdminFileAppCRUD:
"""管理员文件应用 CRUD 测试"""
@pytest.mark.asyncio
async def test_create_file_app(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
) -> None:
"""管理员创建文件应用"""
response = await async_client.post(
"/api/v1/admin/file-app/",
headers=admin_headers,
json={
"name": "新建应用",
"app_key": "new_app",
"type": "builtin",
"description": "测试新建",
"extensions": ["pdf", "txt"],
"allowed_group_ids": [],
},
)
assert response.status_code == 201
data = response.json()
assert data["name"] == "新建应用"
assert data["app_key"] == "new_app"
assert "pdf" in data["extensions"]
assert "txt" in data["extensions"]
@pytest.mark.asyncio
async def test_create_duplicate_app_key(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
setup_admin_app: dict[str, UUID],
) -> None:
"""创建重复 app_key 返回 409"""
response = await async_client.post(
"/api/v1/admin/file-app/",
headers=admin_headers,
json={
"name": "重复应用",
"app_key": "admin_test_app",
"type": "builtin",
},
)
assert response.status_code == 409
@pytest.mark.asyncio
async def test_list_file_apps(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
setup_admin_app: dict[str, UUID],
) -> None:
"""管理员列出文件应用"""
response = await async_client.get(
"/api/v1/admin/file-app/list",
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert "apps" in data
assert data["total"] >= 1
@pytest.mark.asyncio
async def test_get_file_app_detail(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
setup_admin_app: dict[str, UUID],
) -> None:
"""管理员获取应用详情"""
app_id = setup_admin_app["app_id"]
response = await async_client.get(
f"/api/v1/admin/file-app/{app_id}",
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["app_key"] == "admin_test_app"
assert "test" in data["extensions"]
@pytest.mark.asyncio
async def test_get_nonexistent_app(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
) -> None:
"""获取不存在的应用返回 404"""
response = await async_client.get(
f"/api/v1/admin/file-app/{uuid4()}",
headers=admin_headers,
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_update_file_app(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
setup_admin_app: dict[str, UUID],
) -> None:
"""管理员更新应用"""
app_id = setup_admin_app["app_id"]
response = await async_client.patch(
f"/api/v1/admin/file-app/{app_id}",
headers=admin_headers,
json={
"name": "更新后的名称",
"is_enabled": False,
},
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "更新后的名称"
assert data["is_enabled"] is False
@pytest.mark.asyncio
async def test_delete_file_app(
self,
async_client: AsyncClient,
initialized_db: AsyncSession,
admin_headers: dict[str, str],
) -> None:
"""管理员删除应用"""
# 先创建一个应用
app = FileApp(
name="待删除应用", app_key="to_delete_admin", type=FileAppType.BUILTIN
)
app = await app.save(initialized_db)
app_id = app.id
response = await async_client.delete(
f"/api/v1/admin/file-app/{app_id}",
headers=admin_headers,
)
assert response.status_code == 204
# 确认已删除
found = await FileApp.get(initialized_db, FileApp.id == app_id)
assert found is None
# ==================== Extensions Management ====================
class TestAdminExtensionManagement:
"""管理员扩展名管理测试"""
@pytest.mark.asyncio
async def test_update_extensions(
self,
async_client: AsyncClient,
admin_headers: dict[str, str],
setup_admin_app: dict[str, UUID],
) -> None:
"""全量替换扩展名列表"""
app_id = setup_admin_app["app_id"]
response = await async_client.put(
f"/api/v1/admin/file-app/{app_id}/extensions",
headers=admin_headers,
json={"extensions": ["doc", "docx", "odt"]},
)
assert response.status_code == 200
data = response.json()
assert sorted(data["extensions"]) == ["doc", "docx", "odt"]
# ==================== Group Access Management ====================
class TestAdminGroupAccessManagement:
"""管理员用户组权限管理测试"""
@pytest.mark.asyncio
async def test_update_group_access(
self,
async_client: AsyncClient,
initialized_db: AsyncSession,
admin_headers: dict[str, str],
setup_admin_app: dict[str, UUID],
) -> None:
"""全量替换用户组权限"""
app_id = setup_admin_app["app_id"]
admin_user = await User.get(initialized_db, User.email == "admin@disknext.local")
group_id = admin_user.group_id
response = await async_client.put(
f"/api/v1/admin/file-app/{app_id}/groups",
headers=admin_headers,
json={"group_ids": [str(group_id)]},
)
assert response.status_code == 200
data = response.json()
assert str(group_id) in data["allowed_group_ids"]
# ==================== Permission Tests ====================
class TestAdminPermission:
"""权限校验测试"""
@pytest.mark.asyncio
async def test_non_admin_forbidden(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
) -> None:
"""普通用户访问管理端点返回 403"""
response = await async_client.get(
"/api/v1/admin/file-app/list",
headers=auth_headers,
)
assert response.status_code == 403

View File

@@ -0,0 +1,466 @@
"""
文本文件内容 GET/PATCH 集成测试
测试 GET /file/content/{file_id} 和 PATCH /file/content/{file_id} 端点。
"""
import hashlib
from pathlib import Path
from uuid import uuid4
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels import Object, ObjectType, PhysicalFile, Policy, User
@pytest.fixture(autouse=True)
def _register_sqlite_greatest():
"""注册 SQLite 的 greatest 函数以兼容 PostgreSQL 语法"""
def _on_connect(dbapi_connection, connection_record):
if hasattr(dbapi_connection, 'create_function'):
dbapi_connection.create_function("greatest", 2, max)
event.listen(Engine, "connect", _on_connect)
yield
event.remove(Engine, "connect", _on_connect)
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def local_policy(
initialized_db: AsyncSession,
tmp_path: Path,
) -> Policy:
"""创建指向临时目录的本地存储策略"""
from sqlmodels import PolicyType
policy = Policy(
id=uuid4(),
name="测试本地存储",
type=PolicyType.LOCAL,
server=str(tmp_path),
)
initialized_db.add(policy)
await initialized_db.commit()
await initialized_db.refresh(policy)
return policy
@pytest_asyncio.fixture
async def text_file(
initialized_db: AsyncSession,
tmp_path: Path,
local_policy: Policy,
) -> dict[str, str | int]:
"""创建包含 UTF-8 文本内容的测试文件"""
user = await User.get(initialized_db, User.email == "testuser@test.local")
root = await Object.get_root(initialized_db, user.id)
content = "line1\nline2\nline3\n"
content_bytes = content.encode('utf-8')
content_hash = hashlib.sha256(content_bytes).hexdigest()
file_path = tmp_path / "test.txt"
file_path.write_bytes(content_bytes)
physical_file = PhysicalFile(
id=uuid4(),
storage_path=str(file_path),
size=len(content_bytes),
policy_id=local_policy.id,
reference_count=1,
)
initialized_db.add(physical_file)
file_obj = Object(
id=uuid4(),
name="test.txt",
type=ObjectType.FILE,
size=len(content_bytes),
physical_file_id=physical_file.id,
parent_id=root.id,
owner_id=user.id,
policy_id=local_policy.id,
)
initialized_db.add(file_obj)
await initialized_db.commit()
return {
"id": str(file_obj.id),
"content": content,
"hash": content_hash,
"size": len(content_bytes),
"path": str(file_path),
}
@pytest_asyncio.fixture
async def binary_file(
initialized_db: AsyncSession,
tmp_path: Path,
local_policy: Policy,
) -> dict[str, str | int]:
"""创建非 UTF-8 的二进制测试文件"""
user = await User.get(initialized_db, User.email == "testuser@test.local")
root = await Object.get_root(initialized_db, user.id)
# 包含无效 UTF-8 字节序列
content_bytes = b'\x80\x81\x82\xff\xfe\xfd'
file_path = tmp_path / "binary.dat"
file_path.write_bytes(content_bytes)
physical_file = PhysicalFile(
id=uuid4(),
storage_path=str(file_path),
size=len(content_bytes),
policy_id=local_policy.id,
reference_count=1,
)
initialized_db.add(physical_file)
file_obj = Object(
id=uuid4(),
name="binary.dat",
type=ObjectType.FILE,
size=len(content_bytes),
physical_file_id=physical_file.id,
parent_id=root.id,
owner_id=user.id,
policy_id=local_policy.id,
)
initialized_db.add(file_obj)
await initialized_db.commit()
return {
"id": str(file_obj.id),
"path": str(file_path),
}
# ==================== GET /file/content/{file_id} ====================
class TestGetFileContent:
"""GET /file/content/{file_id} 端点测试"""
@pytest.mark.asyncio
async def test_get_content_success(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""成功获取文本文件内容和哈希"""
response = await async_client.get(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["content"] == text_file["content"]
assert data["hash"] == text_file["hash"]
assert data["size"] == text_file["size"]
@pytest.mark.asyncio
async def test_get_content_non_utf8_returns_400(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
binary_file: dict[str, str | int],
) -> None:
"""非 UTF-8 文件返回 400"""
response = await async_client.get(
f"/api/v1/file/content/{binary_file['id']}",
headers=auth_headers,
)
assert response.status_code == 400
assert "UTF-8" in response.json()["detail"]
@pytest.mark.asyncio
async def test_get_content_not_found(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
) -> None:
"""文件不存在返回 404"""
fake_id = uuid4()
response = await async_client.get(
f"/api/v1/file/content/{fake_id}",
headers=auth_headers,
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_get_content_unauthenticated(
self,
async_client: AsyncClient,
text_file: dict[str, str | int],
) -> None:
"""未认证返回 401"""
response = await async_client.get(
f"/api/v1/file/content/{text_file['id']}",
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_get_content_normalizes_crlf(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
initialized_db: AsyncSession,
tmp_path: Path,
local_policy: Policy,
) -> None:
"""CRLF 换行符被规范化为 LF"""
user = await User.get(initialized_db, User.email == "testuser@test.local")
root = await Object.get_root(initialized_db, user.id)
crlf_content = b"line1\r\nline2\r\n"
file_path = tmp_path / "crlf.txt"
file_path.write_bytes(crlf_content)
physical_file = PhysicalFile(
id=uuid4(),
storage_path=str(file_path),
size=len(crlf_content),
policy_id=local_policy.id,
reference_count=1,
)
initialized_db.add(physical_file)
file_obj = Object(
id=uuid4(),
name="crlf.txt",
type=ObjectType.FILE,
size=len(crlf_content),
physical_file_id=physical_file.id,
parent_id=root.id,
owner_id=user.id,
policy_id=local_policy.id,
)
initialized_db.add(file_obj)
await initialized_db.commit()
response = await async_client.get(
f"/api/v1/file/content/{file_obj.id}",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
# 内容应该被规范化为 LF
assert data["content"] == "line1\nline2\n"
# 哈希基于规范化后的内容
expected_hash = hashlib.sha256("line1\nline2\n".encode('utf-8')).hexdigest()
assert data["hash"] == expected_hash
# ==================== PATCH /file/content/{file_id} ====================
class TestPatchFileContent:
"""PATCH /file/content/{file_id} 端点测试"""
@pytest.mark.asyncio
async def test_patch_content_success(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""正常增量保存"""
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" line1\n"
"-line2\n"
"+LINE2_MODIFIED\n"
" line3\n"
)
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": patch_text,
"base_hash": text_file["hash"],
},
)
assert response.status_code == 200
data = response.json()
assert "new_hash" in data
assert "new_size" in data
assert data["new_hash"] != text_file["hash"]
# 验证文件实际被修改
file_path = Path(text_file["path"])
new_content = file_path.read_text(encoding='utf-8')
assert "LINE2_MODIFIED" in new_content
assert "line2" not in new_content
@pytest.mark.asyncio
async def test_patch_content_hash_mismatch_returns_409(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""base_hash 不匹配返回 409"""
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" line1\n"
"-line2\n"
"+changed\n"
" line3\n"
)
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": patch_text,
"base_hash": "0" * 64, # 错误的哈希
},
)
assert response.status_code == 409
@pytest.mark.asyncio
async def test_patch_content_invalid_patch_returns_422(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""无效的 patch 格式返回 422"""
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": "this is not a valid patch",
"base_hash": text_file["hash"],
},
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_patch_content_context_mismatch_returns_422(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""patch 上下文行不匹配返回 422"""
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" WRONG_CONTEXT_LINE\n"
"-line2\n"
"+replaced\n"
" line3\n"
)
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": patch_text,
"base_hash": text_file["hash"],
},
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_patch_content_unauthenticated(
self,
async_client: AsyncClient,
text_file: dict[str, str | int],
) -> None:
"""未认证返回 401"""
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
json={
"patch": "--- a\n+++ b\n",
"base_hash": text_file["hash"],
},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_patch_content_not_found(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
) -> None:
"""文件不存在返回 404"""
fake_id = uuid4()
response = await async_client.patch(
f"/api/v1/file/content/{fake_id}",
headers=auth_headers,
json={
"patch": "--- a\n+++ b\n",
"base_hash": "0" * 64,
},
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_patch_then_get_consistency(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""PATCH 后 GET 返回一致的内容和哈希"""
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" line1\n"
"-line2\n"
"+PATCHED\n"
" line3\n"
)
# PATCH
patch_resp = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": patch_text,
"base_hash": text_file["hash"],
},
)
assert patch_resp.status_code == 200
patch_data = patch_resp.json()
# GET
get_resp = await async_client.get(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
)
assert get_resp.status_code == 200
get_data = get_resp.json()
# 一致性验证
assert get_data["hash"] == patch_data["new_hash"]
assert get_data["size"] == patch_data["new_size"]
assert "PATCHED" in get_data["content"]

View File

@@ -0,0 +1,305 @@
"""
文件查看器集成测试
测试查看器查询、用户默认设置、用户组过滤等端点。
"""
from uuid import UUID
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.file_app import (
FileApp,
FileAppExtension,
FileAppGroupLink,
FileAppType,
UserFileAppDefault,
)
from sqlmodels.user import User
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def setup_file_apps(
initialized_db: AsyncSession,
) -> dict[str, UUID]:
"""创建测试用文件查看器应用"""
# PDF 阅读器(不限制用户组)
pdf_app = FileApp(
name="PDF 阅读器",
app_key="pdfjs",
type=FileAppType.BUILTIN,
is_enabled=True,
is_restricted=False,
)
pdf_app = await pdf_app.save(initialized_db)
# Monaco 编辑器(不限制用户组)
monaco_app = FileApp(
name="代码编辑器",
app_key="monaco",
type=FileAppType.BUILTIN,
is_enabled=True,
is_restricted=False,
)
monaco_app = await monaco_app.save(initialized_db)
# Collabora限制用户组
collabora_app = FileApp(
name="Collabora",
app_key="collabora",
type=FileAppType.WOPI,
is_enabled=True,
is_restricted=True,
)
collabora_app = await collabora_app.save(initialized_db)
# 已禁用的应用
disabled_app = FileApp(
name="禁用的应用",
app_key="disabled_app",
type=FileAppType.BUILTIN,
is_enabled=False,
is_restricted=False,
)
disabled_app = await disabled_app.save(initialized_db)
# 创建扩展名
for ext in ["pdf"]:
await FileAppExtension(app_id=pdf_app.id, extension=ext, priority=0).save(initialized_db)
for ext in ["txt", "md", "json"]:
await FileAppExtension(app_id=monaco_app.id, extension=ext, priority=0).save(initialized_db)
for ext in ["docx", "xlsx", "pptx"]:
await FileAppExtension(app_id=collabora_app.id, extension=ext, priority=0).save(initialized_db)
for ext in ["pdf"]:
await FileAppExtension(app_id=disabled_app.id, extension=ext, priority=10).save(initialized_db)
return {
"pdf_app_id": pdf_app.id,
"monaco_app_id": monaco_app.id,
"collabora_app_id": collabora_app.id,
"disabled_app_id": disabled_app.id,
}
# ==================== GET /file/viewers ====================
class TestGetViewers:
"""查询可用查看器测试"""
@pytest.mark.asyncio
async def test_get_viewers_for_pdf(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""查询 PDF 查看器:返回已启用的,排除已禁用的"""
response = await async_client.get(
"/api/v1/file/viewers?ext=pdf",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert "viewers" in data
viewer_keys = [v["app_key"] for v in data["viewers"]]
# pdfjs 应该在列表中
assert "pdfjs" in viewer_keys
# 禁用的应用不应出现
assert "disabled_app" not in viewer_keys
# 默认值应为 None
assert data["default_viewer_id"] is None
@pytest.mark.asyncio
async def test_get_viewers_normalizes_extension(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""扩展名规范化:.PDF → pdf"""
response = await async_client.get(
"/api/v1/file/viewers?ext=.PDF",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert len(data["viewers"]) >= 1
@pytest.mark.asyncio
async def test_get_viewers_empty_for_unknown_ext(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""未知扩展名返回空列表"""
response = await async_client.get(
"/api/v1/file/viewers?ext=xyz_unknown",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["viewers"] == []
@pytest.mark.asyncio
async def test_group_restriction_filters_app(
self,
async_client: AsyncClient,
initialized_db: AsyncSession,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""用户组限制collabora 限制了用户组,用户不在白名单内则不可见"""
# collabora 是受限的,用户组不在白名单中
response = await async_client.get(
"/api/v1/file/viewers?ext=docx",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
viewer_keys = [v["app_key"] for v in data["viewers"]]
assert "collabora" not in viewer_keys
# 将用户组加入白名单
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
link = FileAppGroupLink(
app_id=setup_file_apps["collabora_app_id"],
group_id=test_user.group_id,
)
initialized_db.add(link)
await initialized_db.commit()
# 再次查询
response = await async_client.get(
"/api/v1/file/viewers?ext=docx",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
viewer_keys = [v["app_key"] for v in data["viewers"]]
assert "collabora" in viewer_keys
@pytest.mark.asyncio
async def test_unauthorized_without_token(
self,
async_client: AsyncClient,
) -> None:
"""未认证请求返回 401"""
response = await async_client.get("/api/v1/file/viewers?ext=pdf")
assert response.status_code in (401, 403)
# ==================== User File Viewer Defaults ====================
class TestUserFileViewerDefaults:
"""用户默认查看器设置测试"""
@pytest.mark.asyncio
async def test_set_default_viewer(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""设置默认查看器"""
response = await async_client.put(
"/api/v1/user/settings/file-viewers/default",
headers=auth_headers,
json={
"extension": "pdf",
"app_id": str(setup_file_apps["pdf_app_id"]),
},
)
assert response.status_code == 200
data = response.json()
assert data["extension"] == "pdf"
assert data["app"]["app_key"] == "pdfjs"
@pytest.mark.asyncio
async def test_list_default_viewers(
self,
async_client: AsyncClient,
initialized_db: AsyncSession,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""列出默认查看器"""
# 先创建一个默认
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
await UserFileAppDefault(
user_id=test_user.id,
extension="pdf",
app_id=setup_file_apps["pdf_app_id"],
).save(initialized_db)
response = await async_client.get(
"/api/v1/user/settings/file-viewers/defaults",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) >= 1
@pytest.mark.asyncio
async def test_delete_default_viewer(
self,
async_client: AsyncClient,
initialized_db: AsyncSession,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""撤销默认查看器"""
# 创建一个默认
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
default = await UserFileAppDefault(
user_id=test_user.id,
extension="txt",
app_id=setup_file_apps["monaco_app_id"],
).save(initialized_db)
response = await async_client.delete(
f"/api/v1/user/settings/file-viewers/default/{default.id}",
headers=auth_headers,
)
assert response.status_code == 204
# 验证已删除
found = await UserFileAppDefault.get(
initialized_db, UserFileAppDefault.id == default.id
)
assert found is None
@pytest.mark.asyncio
async def test_get_viewers_includes_default(
self,
async_client: AsyncClient,
initialized_db: AsyncSession,
auth_headers: dict[str, str],
setup_file_apps: dict[str, UUID],
) -> None:
"""查看器查询应包含用户默认选择"""
# 设置默认
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
await UserFileAppDefault(
user_id=test_user.id,
extension="pdf",
app_id=setup_file_apps["pdf_app_id"],
).save(initialized_db)
response = await async_client.get(
"/api/v1/file/viewers?ext=pdf",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["default_viewer_id"] == str(setup_file_apps["pdf_app_id"])