Files
disknext/tests/conftest.py
于小丘 729773cae3 feat: add multi-provider auth via AuthIdentity and extend site config
- Extract AuthIdentity model for multi-provider authentication (email_password, OAuth, Passkey, Magic Link)
- Remove password field from User model, credentials now stored in AuthIdentity
- Refactor unified login/register to use AuthIdentity-based provider checking
- Add site config fields: footer_code, tos_url, privacy_url, auth_methods
- Add auth settings defaults in migration (email_password enabled by default)
- Update admin user creation to create AuthIdentity records
- Update all tests to use AuthIdentity model

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 22:49:12 +08:00

455 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Pytest configuration.

Provides the shared fixtures for the test suite: in-memory database engine
and session, authenticated users (regular and admin), synchronous and
asynchronous HTTP test clients, auth headers and a sample directory tree.
"""
import asyncio
import os
import sys
from typing import AsyncGenerator
from uuid import UUID
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from httpx import AsyncClient, ASGITransport
from loguru import logger as l
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import sessionmaker
# Put the project root (the parent of this tests/ directory) at the front of
# sys.path so project modules such as `main` and `sqlmodels` are importable.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# Project imports: these must stay below the sys.path tweak above.
from main import app
from sqlmodels.database import get_session
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.group import Group, GroupClaims, GroupOptions
from sqlmodels.migration import migration
from sqlmodels.object import Object, ObjectType
from sqlmodels.policy import Policy, PolicyType
from sqlmodels.user import User, UserStatus
from utils.JWT import create_access_token
from utils.password.pwd import Password
# ==================== 事件循环 ====================
@pytest.fixture(scope="session")
def event_loop():
    """
    Provide a single event loop shared by the whole test session.

    Every async fixture and test in the session runs on this loop. The loop is
    closed when the session ends.

    NOTE(review): redefining the ``event_loop`` fixture is deprecated in newer
    pytest-asyncio releases, which configure loop scope via ``loop_scope`` /
    ``asyncio_default_fixture_loop_scope`` instead. Confirm against the pinned
    pytest-asyncio version.
    """
    # asyncio.new_event_loop() creates the loop without touching the
    # event-loop-policy API, which is deprecated since Python 3.14.
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        loop.close()
# ==================== 数据库 ====================
@pytest_asyncio.fixture(scope="function")
async def test_engine() -> AsyncGenerator[AsyncEngine, None]:
    """
    Create an in-memory SQLite async engine with all SQLModel tables created.

    Function scope: every test gets a brand-new database, which keeps tests
    isolated. The engine is disposed on teardown, and also when table
    creation itself fails.

    NOTE(review): an in-memory SQLite database lives per connection. Sharing it
    across sessions relies on the aiosqlite dialect's default pool reusing one
    connection. Confirm this if tests ever report missing tables.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        echo=False,
        connect_args={"check_same_thread": False},
        future=True,
    )
    try:
        # Create every table registered on SQLModel.metadata.
        async with engine.begin() as conn:
            await conn.run_sync(SQLModel.metadata.create_all)
        yield engine
    finally:
        # Runs on normal teardown and when setup raised before the yield.
        await engine.dispose()
@pytest_asyncio.fixture(scope="function")
async def db_session(test_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
    """
    Open an AsyncSession bound to the per-test in-memory engine.

    ``expire_on_commit=False`` keeps already-loaded attributes available after
    a commit instead of expiring them. The session is closed when the test
    finishes.
    """
    async with AsyncSession(test_engine, expire_on_commit=False) as session:
        yield session
@pytest_asyncio.fixture(scope="function")
async def initialized_db(db_session: AsyncSession) -> AsyncSession:
    """
    Database session intended to be pre-populated by the migration.

    NOTE: the migration is currently NOT run. This fixture returns
    ``db_session`` unchanged, so its tables are empty. Running the migration
    here requires ``migration`` to accept an externally supplied session.

    TODO: call ``migration`` once it supports an injected session, so that
    default data (e.g. the admin user group and default storage policy) exists
    for tests that depend on it.
    """
    return db_session
# ==================== HTTP 客户端 ====================
@pytest.fixture(scope="function")
def client() -> TestClient:
    """
    Synchronous FastAPI TestClient for ``app`` (function scope).

    NOTE(review): the client is created without a ``with`` block, so the app's
    lifespan/startup handlers are not run and the client is never explicitly
    closed. No dependency overrides are installed here either.
    """
    sync_client = TestClient(app)
    return sync_client
@pytest_asyncio.fixture(scope="function")
async def async_client() -> AsyncGenerator[AsyncClient, None]:
    """
    Asynchronous httpx client that calls ``app`` in-process via ASGITransport.

    Requests use the base URL ``http://test``. The client is closed on
    teardown.
    """
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as http_client:
        yield http_client
# ==================== 覆盖依赖 ====================
def override_get_session(db_session: AsyncSession):
    """
    Point the app's ``get_session`` dependency at the given test session.

    After this call, handlers that depend on ``get_session`` receive
    ``db_session``.

    NOTE: this mutates the global ``app.dependency_overrides`` and never removes
    the entry. Callers should clear it after the test, e.g. with
    ``app.dependency_overrides.pop(get_session, None)``.
    """
    async def _yield_test_session():
        # Async-generator dependency that hands out the shared test session.
        yield db_session

    app.dependency_overrides.update({get_session: _yield_test_session})
# ==================== 测试用户 ====================
async def _create_user_with_workspace(
    db_session: AsyncSession,
    *,
    group_name: str,
    max_storage: int,
    is_admin: bool,
    group_options: dict[str, bool | int],
    policy_name: str,
    policy_server: str,
    policy_max_size: int,
    email: str,
    nickname: str,
    password: str,
    score: int,
) -> dict[str, str | UUID]:
    """
    Create a fully set-up active user and mint an access token for it.

    Creates, in order:
    1. a user group and its GroupOptions;
    2. a local storage policy;
    3. the user;
    4. the user's email/password AuthIdentity;
    5. the user's root folder ``/``.

    Shared by ``test_user`` and ``admin_user``, which differ only in the values
    passed here.

    Returns:
        ``{"id", "email", "password" (plaintext), "token", "group_id",
        "policy_id"}``.
    """
    # Local import, matching the original fixtures.
    from uuid import uuid4

    group = Group(
        name=group_name,
        max_storage=max_storage,
        share_enabled=True,
        web_dav_enabled=True,
        admin=is_admin,
        speed_limit=0,
    )
    group = await group.save(db_session)

    options = GroupOptions(group_id=group.id, **group_options)
    await options.save(db_session)

    policy = Policy(
        name=policy_name,
        type=PolicyType.LOCAL,
        server=policy_server,
        is_private=True,
        max_size=policy_max_size,
    )
    policy = await policy.save(db_session)

    user = User(
        email=email,
        nickname=nickname,
        status=UserStatus.ACTIVE,
        storage=0,
        score=score,
        group_id=group.id,
    )
    user = await user.save(db_session)

    # Credentials live on AuthIdentity (email/password provider), not on User.
    identity = AuthIdentity(
        provider=AuthProviderType.EMAIL_PASSWORD,
        identifier=email,
        credential=Password.hash(password),
        is_primary=True,
        is_verified=True,
        user_id=user.id,
    )
    await identity.save(db_session)

    root_folder = Object(
        name="/",
        type=ObjectType.FOLDER,
        parent_id=None,
        owner_id=user.id,
        policy_id=policy.id,
        size=0,
    )
    await root_folder.save(db_session)

    # Build the permission snapshot embedded in the token. Options are attached
    # first; presumably GroupClaims.from_group reads group.options (verify).
    group.options = options
    group_claims = GroupClaims.from_group(group)

    token = create_access_token(
        sub=user.id,
        jti=uuid4(),
        status=user.status.value,
        group=group_claims,
    )
    return {
        "id": user.id,
        "email": user.email,
        "password": password,
        "token": token.access_token,
        "group_id": group.id,
        "policy_id": policy.id,
    }


@pytest_asyncio.fixture(scope="function")
async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
    """
    Create a regular (non-admin) test user.

    Sets up a user group with options, a local storage policy, the user, its
    email/password AuthIdentity and its root folder.

    Returns:
        ``{"id", "email", "password", "token", "group_id", "policy_id"}``.
    """
    return await _create_user_with_workspace(
        db_session,
        group_name="测试用户组",
        max_storage=1024 * 1024 * 1024 * 10,  # 10GB
        is_admin=False,
        group_options={
            "share_download": True,
            "share_free": False,
            "relocate": True,
        },
        policy_name="测试本地策略",
        policy_server="/tmp/disknext_test",
        policy_max_size=1024 * 1024 * 100,  # 100MB
        email="testuser@test.local",
        nickname="测试用户",
        password="test_password_123",
        score=100,
    )


@pytest_asyncio.fixture(scope="function")
async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
    """
    Create a test user in an admin group (``admin=True``) with extended options.

    Sets up the same kinds of records as ``test_user``, but with administrator
    group options and its own storage policy.

    Returns:
        ``{"id", "email", "password", "token", "group_id", "policy_id"}``.
    """
    return await _create_user_with_workspace(
        db_session,
        group_name="管理员组",
        max_storage=0,  # unlimited (per the original author's note)
        is_admin=True,
        group_options={
            "share_download": True,
            "share_free": True,
            "relocate": True,
            "source_batch": 100,
            "select_node": True,
            "advance_delete": True,
        },
        policy_name="管理员本地策略",
        policy_server="/tmp/disknext_admin",
        policy_max_size=0,  # unlimited (per the original author's note)
        email="admin@disknext.local",
        nickname="管理员",
        password="admin_password_456",
        score=9999,
    )
# ==================== 认证请求头 ====================
@pytest.fixture(scope="function")
def auth_headers(test_user: dict[str, str | UUID]) -> dict[str, str]:
    """
    Bearer Authorization header carrying the regular test user's access token.

    Returns:
        ``{"Authorization": "Bearer <token>"}``.
    """
    token = test_user["token"]
    return {"Authorization": f"Bearer {token}"}
@pytest.fixture(scope="function")
def admin_headers(admin_user: dict[str, str | UUID]) -> dict[str, str]:
    """
    Bearer Authorization header carrying the admin user's access token.

    Returns:
        ``{"Authorization": "Bearer <token>"}``.
    """
    token = admin_user["token"]
    return {"Authorization": f"Bearer {token}"}
# ==================== 测试数据 ====================
@pytest_asyncio.fixture(scope="function")
async def test_directory(
    db_session: AsyncSession,
    test_user: dict[str, str | UUID]
) -> dict[str, UUID]:
    """
    Build a small folder tree under the test user's root folder.

    Layout::

        <root>
        ├── documents
        │   ├── work
        │   └── personal
        ├── images
        └── videos

    Returns:
        Folder ids keyed by name: ``{"root", "documents", "images", "videos",
        "work", "personal"}``.

    Raises:
        ValueError: if the test user has no root folder.
    """
    owner: UUID = test_user["id"]
    storage_policy: UUID = test_user["policy_id"]

    root = await Object.get_root(db_session, owner)
    if not root:
        raise ValueError("测试用户的根目录不存在")

    async def add_folder(folder_name: str, parent: UUID) -> Object:
        # Create and persist one empty folder owned by the test user.
        folder = Object(
            name=folder_name,
            type=ObjectType.FOLDER,
            parent_id=parent,
            owner_id=owner,
            policy_id=storage_policy,
            size=0,
        )
        return await folder.save(db_session)

    ids: dict[str, UUID] = {"root": root.id}
    # Top-level folders first, then the children of "documents"; this keeps
    # the same creation order and result key order.
    for name in ("documents", "images", "videos"):
        ids[name] = (await add_folder(name, root.id)).id
    for name in ("work", "personal"):
        ids[name] = (await add_folder(name, ids["documents"])).id
    return ids