Files
disknext/tests/integration/api/test_file_source_link.py
于小丘 b5d09009e3 feat: implement source link endpoints and enforce policy rules
- Add POST/GET source link endpoints for file sharing via permanent URLs
- Enforce max_size check in PATCH /file/content to prevent size limit bypass
- Support is_private (proxy) vs public (302 redirect) storage modes
- Replace all ResponseBase(data=...) with proper DTOs or 204 responses
- Add 18 integration tests for source link and policy rule enforcement

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:07:20 +08:00

555 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
文件外链与 Policy 规则集成测试
测试端点:
- POST /file/source/{file_id} 创建/获取文件外链
- GET /file/get/{file_id}/{name} 外链直接输出
- GET /file/source/{file_id}/{name} 外链重定向/输出
测试 Policy 规则:
- max_size 在 PATCH /file/content 中的检查
- is_origin_link_enable 控制外链创建与访问
- is_private + base_url 控制 302 重定向 vs 应用代理
"""
import hashlib
from pathlib import Path
from uuid import UUID, uuid4
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels import Object, ObjectType, PhysicalFile, Policy, PolicyType, SourceLink, User
@pytest.fixture(autouse=True)
def _register_sqlite_greatest():
    """Register a ``greatest`` SQL function on new SQLite connections.

    SQLite has no ``greatest`` function, so queries written with the
    PostgreSQL spelling fail on it. While a test runs, a ``connect`` listener
    on ``Engine`` installs a Python implementation on every newly opened
    DBAPI connection that supports ``create_function``. The listener is
    removed again once the test finishes.

    The shim follows PostgreSQL semantics: it accepts any number of
    arguments, ignores NULLs, and returns NULL only when every argument is
    NULL.
    """

    def _greatest(*values):
        # Python's max() raises TypeError on None, whereas PostgreSQL's
        # GREATEST skips NULL arguments; drop them before comparing.
        present = [value for value in values if value is not None]
        return max(present) if present else None

    def _on_connect(dbapi_connection, connection_record):
        # Only connections exposing create_function (SQLite-style) need the shim.
        if hasattr(dbapi_connection, 'create_function'):
            # narg=-1 lets the SQL function take any number of arguments.
            dbapi_connection.create_function("greatest", -1, _greatest)

    event.listen(Engine, "connect", _on_connect)
    yield
    event.remove(Engine, "connect", _on_connect)
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def source_policy(
    initialized_db: AsyncSession,
    tmp_path: Path,
) -> Policy:
    """Persist and return a local storage policy with source links enabled.

    The policy uses ``tmp_path`` as its server location, is marked private,
    has ``is_origin_link_enable`` switched on and ``max_size`` set to 0.
    """
    policy_fields = {
        "id": uuid4(),
        "name": "测试外链存储",
        "type": PolicyType.LOCAL,
        "server": str(tmp_path),
        "is_origin_link_enable": True,
        "is_private": True,
        "max_size": 0,
    }
    link_policy = Policy(**policy_fields)

    initialized_db.add(link_policy)
    await initialized_db.commit()
    # Refresh after the commit so the returned instance has its attributes loaded.
    await initialized_db.refresh(link_policy)
    return link_policy
@pytest_asyncio.fixture
async def source_file(
    initialized_db: AsyncSession,
    tmp_path: Path,
    source_policy: Policy,
) -> dict[str, str | int]:
    """Create a 50-byte text file attached to the source-link-enabled policy.

    Writes the file into ``tmp_path``, stores a ``PhysicalFile`` row and a
    file ``Object`` row (parented to the test user's root object), and
    returns a dict describing the file: ``id``, ``name``, ``content``,
    SHA-256 ``hash``, ``size`` and on-disk ``path``.
    """
    owner = await User.get(initialized_db, User.email == "testuser@test.local")
    root_dir = await Object.get_root(initialized_db, owner.id)

    file_name = "source_test.txt"
    text = "A" * 50
    raw = text.encode('utf-8')
    byte_count = len(raw)
    digest = hashlib.sha256(raw).hexdigest()

    disk_path = tmp_path / file_name
    disk_path.write_bytes(raw)

    blob = PhysicalFile(
        id=uuid4(),
        storage_path=str(disk_path),
        size=byte_count,
        policy_id=source_policy.id,
        reference_count=1,
    )
    initialized_db.add(blob)

    entry = Object(
        id=uuid4(),
        name=file_name,
        type=ObjectType.FILE,
        size=byte_count,
        physical_file_id=blob.id,
        parent_id=root_dir.id,
        owner_id=owner.id,
        policy_id=source_policy.id,
    )
    initialized_db.add(entry)
    await initialized_db.commit()

    return {
        "id": str(entry.id),
        "name": file_name,
        "content": text,
        "hash": digest,
        "size": byte_count,
        "path": str(disk_path),
    }
@pytest_asyncio.fixture
async def source_file_with_link(
    initialized_db: AsyncSession,
    source_file: dict[str, str | int],
) -> dict[str, str | int]:
    """Extend ``source_file`` with an already-persisted ``SourceLink``.

    The link starts with 5 recorded downloads. The returned dict carries
    every ``source_file`` key plus ``link_id`` and ``link_downloads``.
    """
    initial_downloads = 5
    existing_link = SourceLink(
        name=source_file["name"],
        object_id=UUID(source_file["id"]),
        downloads=initial_downloads,
    )
    initialized_db.add(existing_link)
    await initialized_db.commit()
    # Refresh so the link's id is available on the instance after the commit.
    await initialized_db.refresh(existing_link)

    return source_file | {
        "link_id": existing_link.id,
        "link_downloads": initial_downloads,
    }
# ==================== POST /file/source/{file_id} ====================
class TestCreateSourceLink:
    """Tests for the POST /file/source/{file_id} endpoint."""

    @pytest.mark.asyncio
    async def test_create_source_link_success(
        self,
        async_client: AsyncClient,
        auth_headers: dict[str, str],
        source_file: dict[str, str | int],
    ) -> None:
        """A source link is created for a file that has none yet."""
        endpoint = f"/api/v1/file/source/{source_file['id']}"
        resp = await async_client.post(endpoint, headers=auth_headers)

        assert resp.status_code == 200
        payload = resp.json()
        link_url = payload["url"]
        assert "/api/v1/file/source/" in link_url
        assert source_file["name"] in link_url
        # A freshly created link has not been downloaded yet.
        assert payload["downloads"] == 0

    @pytest.mark.asyncio
    async def test_create_source_link_idempotent(
        self,
        async_client: AsyncClient,
        auth_headers: dict[str, str],
        source_file_with_link: dict[str, str | int],
    ) -> None:
        """Posting again for a file with a link returns the existing link."""
        endpoint = f"/api/v1/file/source/{source_file_with_link['id']}"
        resp = await async_client.post(endpoint, headers=auth_headers)

        assert resp.status_code == 200
        payload = resp.json()
        # The pre-seeded download counter shows the existing link was reused.
        assert payload["downloads"] == source_file_with_link["link_downloads"]

    @pytest.mark.asyncio
    async def test_create_source_link_disabled_returns_403(
        self,
        async_client: AsyncClient,
        auth_headers: dict[str, str],
        source_file: dict[str, str | int],
        source_policy: Policy,
        initialized_db: AsyncSession,
    ) -> None:
        """Creation is rejected with 403 when the policy disables source links."""
        source_policy.is_origin_link_enable = False
        initialized_db.add(source_policy)
        await initialized_db.commit()

        endpoint = f"/api/v1/file/source/{source_file['id']}"
        resp = await async_client.post(endpoint, headers=auth_headers)

        assert resp.status_code == 403

    @pytest.mark.asyncio
    async def test_create_source_link_file_not_found(
        self,
        async_client: AsyncClient,
        auth_headers: dict[str, str],
    ) -> None:
        """An unknown file id yields 404."""
        endpoint = f"/api/v1/file/source/{uuid4()}"
        resp = await async_client.post(endpoint, headers=auth_headers)

        assert resp.status_code == 404

    @pytest.mark.asyncio
    async def test_create_source_link_unauthenticated(
        self,
        async_client: AsyncClient,
        source_file: dict[str, str | int],
    ) -> None:
        """A request without auth headers yields 401."""
        endpoint = f"/api/v1/file/source/{source_file['id']}"
        resp = await async_client.post(endpoint)

        assert resp.status_code == 401
# ==================== GET /file/get/{file_id}/{name} ====================
class TestFileGetDirect:
    """Tests for the GET /file/get/{file_id}/{name} endpoint."""

    @pytest.mark.asyncio
    async def test_get_direct_success(
        self,
        async_client: AsyncClient,
        source_file_with_link: dict[str, str | int],
    ) -> None:
        """The file is served through its source link without authentication."""
        response = await async_client.get(
            f"/api/v1/file/get/{source_file_with_link['id']}/{source_file_with_link['name']}",
        )
        assert response.status_code == 200
        assert source_file_with_link["content"] in response.text

    @pytest.mark.asyncio
    async def test_get_direct_increments_download_count(
        self,
        async_client: AsyncClient,
        source_file_with_link: dict[str, str | int],
        initialized_db: AsyncSession,
    ) -> None:
        """A successful download bumps the link's download counter by one."""
        link_before = await SourceLink.get(
            initialized_db,
            SourceLink.object_id == UUID(source_file_with_link["id"]),
        )
        downloads_before = link_before.downloads

        response = await async_client.get(
            f"/api/v1/file/get/{source_file_with_link['id']}/{source_file_with_link['name']}",
        )
        # Check the request itself first, so a failed download is reported as
        # such instead of surfacing as a confusing counter mismatch below.
        assert response.status_code == 200

        # Reload the row to observe the counter written by the request handler.
        await initialized_db.refresh(link_before)
        assert link_before.downloads == downloads_before + 1

    @pytest.mark.asyncio
    async def test_get_direct_no_link_returns_404(
        self,
        async_client: AsyncClient,
        source_file: dict[str, str | int],
    ) -> None:
        """A file without a source link yields 404."""
        response = await async_client.get(
            f"/api/v1/file/get/{source_file['id']}/{source_file['name']}",
        )
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_get_direct_nonexistent_file_returns_404(
        self,
        async_client: AsyncClient,
    ) -> None:
        """An unknown file id yields 404."""
        response = await async_client.get(
            f"/api/v1/file/get/{uuid4()}/fake.txt",
        )
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_get_direct_disabled_policy_returns_403(
        self,
        async_client: AsyncClient,
        source_file_with_link: dict[str, str | int],
        source_policy: Policy,
        initialized_db: AsyncSession,
    ) -> None:
        """Access is rejected with 403 once the policy disables source links."""
        source_policy.is_origin_link_enable = False
        initialized_db.add(source_policy)
        await initialized_db.commit()

        response = await async_client.get(
            f"/api/v1/file/get/{source_file_with_link['id']}/{source_file_with_link['name']}",
        )
        assert response.status_code == 403
# ==================== GET /file/source/{file_id}/{name} ====================
class TestFileSourceRedirect:
    """Tests for the GET /file/source/{file_id}/{name} endpoint."""

    @pytest.mark.asyncio
    async def test_source_private_returns_file_content(
        self,
        async_client: AsyncClient,
        source_file_with_link: dict[str, str | int],
    ) -> None:
        """With is_private=True the file content is returned directly."""
        response = await async_client.get(
            f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}",
            follow_redirects=False,
        )
        assert response.status_code == 200
        assert source_file_with_link["content"] in response.text

    @pytest.mark.asyncio
    async def test_source_public_redirects_302(
        self,
        async_client: AsyncClient,
        source_file_with_link: dict[str, str | int],
        source_policy: Policy,
        initialized_db: AsyncSession,
    ) -> None:
        """With is_private=False and a base_url the endpoint answers 302."""
        source_policy.is_private = False
        source_policy.base_url = "http://cdn.example.com/storage"
        initialized_db.add(source_policy)
        await initialized_db.commit()

        response = await async_client.get(
            f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}",
            # Do not follow the redirect: the 302 itself is what is under test.
            follow_redirects=False,
        )
        assert response.status_code == 302
        location = response.headers["location"]
        assert "cdn.example.com/storage" in location

    @pytest.mark.asyncio
    async def test_source_public_no_base_url_fallback(
        self,
        async_client: AsyncClient,
        source_file_with_link: dict[str, str | int],
        source_policy: Policy,
        initialized_db: AsyncSession,
    ) -> None:
        """With is_private=False but no base_url the file is served directly."""
        source_policy.is_private = False
        source_policy.base_url = None
        initialized_db.add(source_policy)
        await initialized_db.commit()

        response = await async_client.get(
            f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}",
            follow_redirects=False,
        )
        assert response.status_code == 200
        assert source_file_with_link["content"] in response.text

    @pytest.mark.asyncio
    async def test_source_increments_download_count(
        self,
        async_client: AsyncClient,
        source_file_with_link: dict[str, str | int],
        initialized_db: AsyncSession,
    ) -> None:
        """Visiting the source link bumps its download counter by one."""
        link_before = await SourceLink.get(
            initialized_db,
            SourceLink.object_id == UUID(source_file_with_link["id"]),
        )
        downloads_before = link_before.downloads

        response = await async_client.get(
            f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}",
        )
        # The fixture policy is private, so the file is served directly; check
        # the request succeeded before inspecting the counter.
        assert response.status_code == 200

        # Reload the row to observe the counter written by the request handler.
        await initialized_db.refresh(link_before)
        assert link_before.downloads == downloads_before + 1

    @pytest.mark.asyncio
    async def test_source_no_link_returns_404(
        self,
        async_client: AsyncClient,
        source_file: dict[str, str | int],
    ) -> None:
        """A file without a source link yields 404."""
        response = await async_client.get(
            f"/api/v1/file/source/{source_file['id']}/{source_file['name']}",
        )
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_source_disabled_policy_returns_403(
        self,
        async_client: AsyncClient,
        source_file_with_link: dict[str, str | int],
        source_policy: Policy,
        initialized_db: AsyncSession,
    ) -> None:
        """Access is rejected with 403 once the policy disables source links."""
        source_policy.is_origin_link_enable = False
        initialized_db.add(source_policy)
        await initialized_db.commit()

        response = await async_client.get(
            f"/api/v1/file/source/{source_file_with_link['id']}/{source_file_with_link['name']}",
        )
        assert response.status_code == 403
# ==================== max_size check in PATCH ====================
class TestPatchMaxSizePolicy:
    """Tests for the max_size policy check in PATCH /file/content/{file_id}."""

    @pytest_asyncio.fixture
    async def size_limited_policy(
        self,
        initialized_db: AsyncSession,
        tmp_path: Path,
    ) -> Policy:
        """Persist and return a local storage policy with max_size=100."""
        capped_policy = Policy(
            id=uuid4(),
            name="限制大小存储",
            type=PolicyType.LOCAL,
            server=str(tmp_path),
            max_size=100,
        )
        initialized_db.add(capped_policy)
        await initialized_db.commit()
        # Refresh after the commit so the returned instance has its attributes loaded.
        await initialized_db.refresh(capped_policy)
        return capped_policy

    @pytest_asyncio.fixture
    async def small_file(
        self,
        initialized_db: AsyncSession,
        tmp_path: Path,
        size_limited_policy: Policy,
    ) -> dict[str, str | int]:
        """Create a 50-byte text file under the 100-byte size-limited policy.

        Writes the file into ``tmp_path``, stores ``PhysicalFile`` and
        ``Object`` rows for it, and returns a dict with the object's ``id``,
        ``content``, SHA-256 ``hash``, ``size`` and on-disk ``path``.
        """
        owner = await User.get(initialized_db, User.email == "testuser@test.local")
        root_dir = await Object.get_root(initialized_db, owner.id)

        file_name = "small.txt"
        text = "A" * 50
        raw = text.encode('utf-8')
        byte_count = len(raw)
        digest = hashlib.sha256(raw).hexdigest()

        disk_path = tmp_path / file_name
        disk_path.write_bytes(raw)

        blob = PhysicalFile(
            id=uuid4(),
            storage_path=str(disk_path),
            size=byte_count,
            policy_id=size_limited_policy.id,
            reference_count=1,
        )
        initialized_db.add(blob)

        entry = Object(
            id=uuid4(),
            name=file_name,
            type=ObjectType.FILE,
            size=byte_count,
            physical_file_id=blob.id,
            parent_id=root_dir.id,
            owner_id=owner.id,
            policy_id=size_limited_policy.id,
        )
        initialized_db.add(entry)
        await initialized_db.commit()

        return {
            "id": str(entry.id),
            "content": text,
            "hash": digest,
            "size": byte_count,
            "path": str(disk_path),
        }

    @pytest.mark.asyncio
    async def test_patch_exceeds_max_size_returns_413(
        self,
        async_client: AsyncClient,
        auth_headers: dict[str, str],
        small_file: dict[str, str | int],
    ) -> None:
        """A patch that grows the file beyond max_size is rejected with 413."""
        oversized = "B" * 200
        # Unified diff replacing the single 50-character line with 200 characters.
        patch_text = "".join((
            "--- a\n",
            "+++ b\n",
            "@@ -1 +1 @@\n",
            f"-{'A' * 50}\n",
            f"+{oversized}\n",
        ))
        request_body = {
            "patch": patch_text,
            "base_hash": small_file["hash"],
        }

        resp = await async_client.patch(
            f"/api/v1/file/content/{small_file['id']}",
            headers=auth_headers,
            json=request_body,
        )

        assert resp.status_code == 413

    @pytest.mark.asyncio
    async def test_patch_within_max_size_succeeds(
        self,
        async_client: AsyncClient,
        auth_headers: dict[str, str],
        small_file: dict[str, str | int],
    ) -> None:
        """A patch that keeps the file under max_size is saved normally."""
        replacement = "C" * 80  # 80 bytes < 100 bytes limit
        # Unified diff replacing the single 50-character line with 80 characters.
        patch_text = "".join((
            "--- a\n",
            "+++ b\n",
            "@@ -1 +1 @@\n",
            f"-{'A' * 50}\n",
            f"+{replacement}\n",
        ))
        request_body = {
            "patch": patch_text,
            "base_hash": small_file["hash"],
        }

        resp = await async_client.patch(
            f"/api/v1/file/content/{small_file['id']}",
            headers=auth_headers,
            json=request_body,
        )

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["new_size"] == 80