feat: migrate ORM base to sqlmodel-ext, add file viewers and WOPI integration
All checks were successful
Test / test (push) Successful in 1m45s

- Migrate SQLModel base classes, mixins, and database management to
  external sqlmodel-ext package; remove sqlmodels/base/, sqlmodels/mixin/,
  and sqlmodels/database.py
- Add file viewer/editor system with WOPI protocol support for
  collaborative editing (OnlyOffice, Collabora)
- Add enterprise edition license verification module (ee/)
- Add Dockerfile multi-stage build with Cython compilation support
- Add new dependencies: sqlmodel-ext, cryptography, whatthepatch

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-14 14:23:17 +08:00
parent 53b757de7a
commit ccadfe57cd
81 changed files with 5106 additions and 4837 deletions

View File

@@ -0,0 +1,466 @@
"""
文本文件内容 GET/PATCH 集成测试
测试 GET /file/content/{file_id} 和 PATCH /file/content/{file_id} 端点。
"""
import hashlib
from pathlib import Path
from uuid import uuid4
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels import Object, ObjectType, PhysicalFile, Policy, User
@pytest.fixture(autouse=True)
def _register_sqlite_greatest():
"""注册 SQLite 的 greatest 函数以兼容 PostgreSQL 语法"""
def _on_connect(dbapi_connection, connection_record):
if hasattr(dbapi_connection, 'create_function'):
dbapi_connection.create_function("greatest", 2, max)
event.listen(Engine, "connect", _on_connect)
yield
event.remove(Engine, "connect", _on_connect)
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def local_policy(
initialized_db: AsyncSession,
tmp_path: Path,
) -> Policy:
"""创建指向临时目录的本地存储策略"""
from sqlmodels import PolicyType
policy = Policy(
id=uuid4(),
name="测试本地存储",
type=PolicyType.LOCAL,
server=str(tmp_path),
)
initialized_db.add(policy)
await initialized_db.commit()
await initialized_db.refresh(policy)
return policy
@pytest_asyncio.fixture
async def text_file(
initialized_db: AsyncSession,
tmp_path: Path,
local_policy: Policy,
) -> dict[str, str | int]:
"""创建包含 UTF-8 文本内容的测试文件"""
user = await User.get(initialized_db, User.email == "testuser@test.local")
root = await Object.get_root(initialized_db, user.id)
content = "line1\nline2\nline3\n"
content_bytes = content.encode('utf-8')
content_hash = hashlib.sha256(content_bytes).hexdigest()
file_path = tmp_path / "test.txt"
file_path.write_bytes(content_bytes)
physical_file = PhysicalFile(
id=uuid4(),
storage_path=str(file_path),
size=len(content_bytes),
policy_id=local_policy.id,
reference_count=1,
)
initialized_db.add(physical_file)
file_obj = Object(
id=uuid4(),
name="test.txt",
type=ObjectType.FILE,
size=len(content_bytes),
physical_file_id=physical_file.id,
parent_id=root.id,
owner_id=user.id,
policy_id=local_policy.id,
)
initialized_db.add(file_obj)
await initialized_db.commit()
return {
"id": str(file_obj.id),
"content": content,
"hash": content_hash,
"size": len(content_bytes),
"path": str(file_path),
}
@pytest_asyncio.fixture
async def binary_file(
initialized_db: AsyncSession,
tmp_path: Path,
local_policy: Policy,
) -> dict[str, str | int]:
"""创建非 UTF-8 的二进制测试文件"""
user = await User.get(initialized_db, User.email == "testuser@test.local")
root = await Object.get_root(initialized_db, user.id)
# 包含无效 UTF-8 字节序列
content_bytes = b'\x80\x81\x82\xff\xfe\xfd'
file_path = tmp_path / "binary.dat"
file_path.write_bytes(content_bytes)
physical_file = PhysicalFile(
id=uuid4(),
storage_path=str(file_path),
size=len(content_bytes),
policy_id=local_policy.id,
reference_count=1,
)
initialized_db.add(physical_file)
file_obj = Object(
id=uuid4(),
name="binary.dat",
type=ObjectType.FILE,
size=len(content_bytes),
physical_file_id=physical_file.id,
parent_id=root.id,
owner_id=user.id,
policy_id=local_policy.id,
)
initialized_db.add(file_obj)
await initialized_db.commit()
return {
"id": str(file_obj.id),
"path": str(file_path),
}
# ==================== GET /file/content/{file_id} ====================
class TestGetFileContent:
"""GET /file/content/{file_id} 端点测试"""
@pytest.mark.asyncio
async def test_get_content_success(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""成功获取文本文件内容和哈希"""
response = await async_client.get(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["content"] == text_file["content"]
assert data["hash"] == text_file["hash"]
assert data["size"] == text_file["size"]
@pytest.mark.asyncio
async def test_get_content_non_utf8_returns_400(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
binary_file: dict[str, str | int],
) -> None:
"""非 UTF-8 文件返回 400"""
response = await async_client.get(
f"/api/v1/file/content/{binary_file['id']}",
headers=auth_headers,
)
assert response.status_code == 400
assert "UTF-8" in response.json()["detail"]
@pytest.mark.asyncio
async def test_get_content_not_found(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
) -> None:
"""文件不存在返回 404"""
fake_id = uuid4()
response = await async_client.get(
f"/api/v1/file/content/{fake_id}",
headers=auth_headers,
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_get_content_unauthenticated(
self,
async_client: AsyncClient,
text_file: dict[str, str | int],
) -> None:
"""未认证返回 401"""
response = await async_client.get(
f"/api/v1/file/content/{text_file['id']}",
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_get_content_normalizes_crlf(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
initialized_db: AsyncSession,
tmp_path: Path,
local_policy: Policy,
) -> None:
"""CRLF 换行符被规范化为 LF"""
user = await User.get(initialized_db, User.email == "testuser@test.local")
root = await Object.get_root(initialized_db, user.id)
crlf_content = b"line1\r\nline2\r\n"
file_path = tmp_path / "crlf.txt"
file_path.write_bytes(crlf_content)
physical_file = PhysicalFile(
id=uuid4(),
storage_path=str(file_path),
size=len(crlf_content),
policy_id=local_policy.id,
reference_count=1,
)
initialized_db.add(physical_file)
file_obj = Object(
id=uuid4(),
name="crlf.txt",
type=ObjectType.FILE,
size=len(crlf_content),
physical_file_id=physical_file.id,
parent_id=root.id,
owner_id=user.id,
policy_id=local_policy.id,
)
initialized_db.add(file_obj)
await initialized_db.commit()
response = await async_client.get(
f"/api/v1/file/content/{file_obj.id}",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
# 内容应该被规范化为 LF
assert data["content"] == "line1\nline2\n"
# 哈希基于规范化后的内容
expected_hash = hashlib.sha256("line1\nline2\n".encode('utf-8')).hexdigest()
assert data["hash"] == expected_hash
# ==================== PATCH /file/content/{file_id} ====================
class TestPatchFileContent:
"""PATCH /file/content/{file_id} 端点测试"""
@pytest.mark.asyncio
async def test_patch_content_success(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""正常增量保存"""
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" line1\n"
"-line2\n"
"+LINE2_MODIFIED\n"
" line3\n"
)
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": patch_text,
"base_hash": text_file["hash"],
},
)
assert response.status_code == 200
data = response.json()
assert "new_hash" in data
assert "new_size" in data
assert data["new_hash"] != text_file["hash"]
# 验证文件实际被修改
file_path = Path(text_file["path"])
new_content = file_path.read_text(encoding='utf-8')
assert "LINE2_MODIFIED" in new_content
assert "line2" not in new_content
@pytest.mark.asyncio
async def test_patch_content_hash_mismatch_returns_409(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""base_hash 不匹配返回 409"""
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" line1\n"
"-line2\n"
"+changed\n"
" line3\n"
)
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": patch_text,
"base_hash": "0" * 64, # 错误的哈希
},
)
assert response.status_code == 409
@pytest.mark.asyncio
async def test_patch_content_invalid_patch_returns_422(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""无效的 patch 格式返回 422"""
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": "this is not a valid patch",
"base_hash": text_file["hash"],
},
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_patch_content_context_mismatch_returns_422(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""patch 上下文行不匹配返回 422"""
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" WRONG_CONTEXT_LINE\n"
"-line2\n"
"+replaced\n"
" line3\n"
)
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": patch_text,
"base_hash": text_file["hash"],
},
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_patch_content_unauthenticated(
self,
async_client: AsyncClient,
text_file: dict[str, str | int],
) -> None:
"""未认证返回 401"""
response = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
json={
"patch": "--- a\n+++ b\n",
"base_hash": text_file["hash"],
},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_patch_content_not_found(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
) -> None:
"""文件不存在返回 404"""
fake_id = uuid4()
response = await async_client.patch(
f"/api/v1/file/content/{fake_id}",
headers=auth_headers,
json={
"patch": "--- a\n+++ b\n",
"base_hash": "0" * 64,
},
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_patch_then_get_consistency(
self,
async_client: AsyncClient,
auth_headers: dict[str, str],
text_file: dict[str, str | int],
) -> None:
"""PATCH 后 GET 返回一致的内容和哈希"""
patch_text = (
"--- a\n"
"+++ b\n"
"@@ -1,3 +1,3 @@\n"
" line1\n"
"-line2\n"
"+PATCHED\n"
" line3\n"
)
# PATCH
patch_resp = await async_client.patch(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
json={
"patch": patch_text,
"base_hash": text_file["hash"],
},
)
assert patch_resp.status_code == 200
patch_data = patch_resp.json()
# GET
get_resp = await async_client.get(
f"/api/v1/file/content/{text_file['id']}",
headers=auth_headers,
)
assert get_resp.status_code == 200
get_data = get_resp.json()
# 一致性验证
assert get_data["hash"] == patch_data["new_hash"]
assert get_data["size"] == patch_data["new_size"]
assert "PATCHED" in get_data["content"]