- Extract AuthIdentity model for multi-provider authentication (email_password, OAuth, Passkey, Magic Link) - Remove password field from User model, credentials now stored in AuthIdentity - Refactor unified login/register to use AuthIdentity-based provider checking - Add site config fields: footer_code, tos_url, privacy_url, auth_methods - Add auth settings defaults in migration (email_password enabled by default) - Update admin user creation to create AuthIdentity records - Update all tests to use AuthIdentity model Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
455 lines
12 KiB
Python
455 lines
12 KiB
Python
"""
|
||
Pytest 配置文件
|
||
|
||
提供测试所需的 fixtures,包括数据库会话、认证用户、测试客户端等。
|
||
"""
|
||
import asyncio
|
||
import os
|
||
import sys
|
||
from typing import AsyncGenerator
|
||
from uuid import UUID
|
||
|
||
import pytest
|
||
import pytest_asyncio
|
||
from fastapi.testclient import TestClient
|
||
from httpx import AsyncClient, ASGITransport
|
||
from loguru import logger as l
|
||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||
from sqlmodel import SQLModel
|
||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||
from sqlalchemy.orm import sessionmaker
|
||
|
||
# 添加项目根目录到Python路径,确保可以导入项目模块
|
||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||
|
||
from main import app
|
||
from sqlmodels.database import get_session
|
||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||
from sqlmodels.group import Group, GroupClaims, GroupOptions
|
||
from sqlmodels.migration import migration
|
||
from sqlmodels.object import Object, ObjectType
|
||
from sqlmodels.policy import Policy, PolicyType
|
||
from sqlmodels.user import User, UserStatus
|
||
from utils.JWT import create_access_token
|
||
from utils.password.pwd import Password
|
||
|
||
|
||
# ==================== 事件循环 ====================
|
||
|
||
@pytest.fixture(scope="session")
def event_loop():
    """
    Provide a single event loop for the whole test session.

    pytest-asyncio handles event-loop management differently depending on
    its version; this session-scoped fixture creates one loop, yields it
    to the tests, and closes it when the fixture is torn down.
    """
    session_loop = asyncio.new_event_loop()
    yield session_loop
    session_loop.close()
|
||
|
||
|
||
# ==================== 数据库 ====================
|
||
|
||
@pytest_asyncio.fixture(scope="function")
async def test_engine() -> AsyncGenerator[AsyncEngine, None]:
    """
    Yield an async engine backed by an in-memory SQLite database.

    The fixture is function-scoped, so each test function gets its own
    engine. All tables registered on ``SQLModel.metadata`` are created
    before the engine is yielded, and the engine is disposed afterwards.
    """
    engine_options = {
        "echo": False,
        "connect_args": {"check_same_thread": False},
        "future": True,
    }
    memory_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        **engine_options,
    )

    # Create every table known to SQLModel's metadata.
    async with memory_engine.begin() as connection:
        await connection.run_sync(SQLModel.metadata.create_all)

    yield memory_engine

    # Teardown: release the engine's connections.
    await memory_engine.dispose()
|
||
|
||
|
||
@pytest_asyncio.fixture(scope="function")
async def db_session(test_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
    """
    Yield an ``AsyncSession`` bound to the per-test engine.

    The session factory is configured with ``expire_on_commit=False``, so
    loaded attributes stay usable after a commit. The session is closed by
    its context manager when the fixture is torn down.
    """
    session_factory = sessionmaker(
        bind=test_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )

    async with session_factory() as test_session:
        yield test_session
|
||
|
||
|
||
@pytest_asyncio.fixture(scope="function")
async def initialized_db(db_session: AsyncSession) -> AsyncSession:
    """
    Return the test session reserved for tests that expect default data.

    NOTE: no migration is executed here. The project's ``migration``
    routine is not invoked because it has not been adapted to accept a
    caller-supplied session, so this fixture currently hands back
    ``db_session`` unchanged (tables exist, but no default rows are
    seeded by this fixture).

    Args:
        db_session: The per-test database session.

    Returns:
        The same ``db_session`` object that was passed in.
    """
    # The previous ``try: pass / except Exception`` wrapper could never
    # reach its handler, so it was removed.
    # TODO: seed default data (admin group, default policy, ...) once
    # ``migration`` can run against an injected session, or add a
    # dedicated test-data initializer.
    return db_session
|
||
|
||
|
||
# ==================== HTTP 客户端 ====================
|
||
|
||
@pytest.fixture(scope="function")
def client() -> TestClient:
    """
    Build a synchronous ``TestClient`` wrapping the FastAPI ``app``.

    A new client instance is created for every test function.
    """
    sync_client = TestClient(app)
    return sync_client
|
||
|
||
|
||
@pytest_asyncio.fixture(scope="function")
async def async_client() -> AsyncGenerator[AsyncClient, None]:
    """
    Yield an ``httpx.AsyncClient`` that talks to ``app`` in-process.

    Requests go through an ``ASGITransport`` built around the FastAPI
    application, using ``http://test`` as the base URL. The client is
    closed by its context manager when the fixture is torn down.
    """
    asgi_transport = ASGITransport(app=app)
    http_client = AsyncClient(transport=asgi_transport, base_url="http://test")
    async with http_client as entered_client:
        yield entered_client
|
||
|
||
|
||
# ==================== 覆盖依赖 ====================
|
||
|
||
def override_get_session(db_session: AsyncSession):
    """
    Point FastAPI's ``get_session`` dependency at the given test session.

    An async generator yielding ``db_session`` is registered in
    ``app.dependency_overrides`` under the ``get_session`` key. This
    function does not remove the override again.
    """
    async def _provide_test_session():
        yield db_session

    app.dependency_overrides[get_session] = _provide_test_session
|
||
|
||
|
||
# ==================== 测试用户 ====================
|
||
|
||
@pytest_asyncio.fixture(scope="function")
async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
    """
    Create a regular (non-admin) test user together with its supporting rows.

    Persisted in order: a user group, that group's options, a local storage
    policy, the user, an email/password auth identity for the user, and the
    user's root folder. An access token is then issued for the user.

    Returns:
        A dict with the keys ``id``, ``email``, ``password`` (plain text),
        ``token``, ``group_id`` and ``policy_id``.
    """
    from uuid import uuid4

    plain_password = "test_password_123"
    user_email = "testuser@test.local"
    group_quota = 10 * 1024 ** 3  # 10 GiB
    policy_max_size = 100 * 1024 ** 2  # 100 MiB

    # Non-admin group the user belongs to.
    user_group = await Group(
        name="测试用户组",
        max_storage=group_quota,
        share_enabled=True,
        web_dav_enabled=True,
        admin=False,
        speed_limit=0,
    ).save(db_session)

    # Options row attached to the group.
    user_group_options = GroupOptions(
        group_id=user_group.id,
        share_download=True,
        share_free=False,
        relocate=True,
    )
    await user_group_options.save(db_session)

    # Local storage policy used by the user's root folder.
    storage_policy = await Policy(
        name="测试本地策略",
        type=PolicyType.LOCAL,
        server="/tmp/disknext_test",
        is_private=True,
        max_size=policy_max_size,
    ).save(db_session)

    # The user record itself; credentials live in AuthIdentity, not here.
    created_user = await User(
        email=user_email,
        nickname="测试用户",
        status=UserStatus.ACTIVE,
        storage=0,
        score=100,
        group_id=user_group.id,
    ).save(db_session)

    # Email/password identity holding the hashed credential.
    email_identity = AuthIdentity(
        provider=AuthProviderType.EMAIL_PASSWORD,
        identifier=user_email,
        credential=Password.hash(plain_password),
        is_primary=True,
        is_verified=True,
        user_id=created_user.id,
    )
    await email_identity.save(db_session)

    # Root folder ("/") owned by the user.
    user_root = Object(
        name="/",
        type=ObjectType.FOLDER,
        parent_id=None,
        owner_id=created_user.id,
        policy_id=storage_policy.id,
        size=0,
    )
    await user_root.save(db_session)

    # Attach the options to the group before building the claims snapshot.
    user_group.options = user_group_options
    claims = GroupClaims.from_group(user_group)

    # Issue an access token for the user.
    issued_token = create_access_token(
        sub=created_user.id,
        jti=uuid4(),
        status=created_user.status.value,
        group=claims,
    )

    return {
        "id": created_user.id,
        "email": created_user.email,
        "password": plain_password,
        "token": issued_token.access_token,
        "group_id": user_group.id,
        "policy_id": storage_policy.id,
    }
|
||
|
||
|
||
@pytest_asyncio.fixture(scope="function")
async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
    """
    Create an administrator test user together with its supporting rows.

    Persisted in order: an admin group (``admin=True``), that group's
    options, a local storage policy, the admin user, an email/password
    auth identity, and the admin's root folder. An access token is then
    issued for the admin.

    Returns:
        A dict with the keys ``id``, ``email``, ``password`` (plain text),
        ``token``, ``group_id`` and ``policy_id``.
    """
    from uuid import uuid4

    plain_password = "admin_password_456"
    admin_email = "admin@disknext.local"

    # Admin group; the original comment marks a max_storage of 0 as "unlimited".
    administrators = await Group(
        name="管理员组",
        max_storage=0,
        share_enabled=True,
        web_dav_enabled=True,
        admin=True,
        speed_limit=0,
    ).save(db_session)

    # Options row attached to the admin group.
    administrators_options = GroupOptions(
        group_id=administrators.id,
        share_download=True,
        share_free=True,
        relocate=True,
        source_batch=100,
        select_node=True,
        advance_delete=True,
    )
    await administrators_options.save(db_session)

    # Local storage policy; the original comment marks a max_size of 0 as "unlimited".
    storage_policy = await Policy(
        name="管理员本地策略",
        type=PolicyType.LOCAL,
        server="/tmp/disknext_admin",
        is_private=True,
        max_size=0,
    ).save(db_session)

    # The admin user record; credentials live in AuthIdentity, not here.
    created_admin = await User(
        email=admin_email,
        nickname="管理员",
        status=UserStatus.ACTIVE,
        storage=0,
        score=9999,
        group_id=administrators.id,
    ).save(db_session)

    # Email/password identity holding the hashed credential.
    email_identity = AuthIdentity(
        provider=AuthProviderType.EMAIL_PASSWORD,
        identifier=admin_email,
        credential=Password.hash(plain_password),
        is_primary=True,
        is_verified=True,
        user_id=created_admin.id,
    )
    await email_identity.save(db_session)

    # Root folder ("/") owned by the admin.
    admin_root = Object(
        name="/",
        type=ObjectType.FOLDER,
        parent_id=None,
        owner_id=created_admin.id,
        policy_id=storage_policy.id,
        size=0,
    )
    await admin_root.save(db_session)

    # Attach the options to the group before building the claims snapshot.
    administrators.options = administrators_options
    claims = GroupClaims.from_group(administrators)

    # Issue an access token for the admin.
    issued_token = create_access_token(
        sub=created_admin.id,
        jti=uuid4(),
        status=created_admin.status.value,
        group=claims,
    )

    return {
        "id": created_admin.id,
        "email": created_admin.email,
        "password": plain_password,
        "token": issued_token.access_token,
        "group_id": administrators.id,
        "policy_id": storage_policy.id,
    }
|
||
|
||
|
||
# ==================== 认证请求头 ====================
|
||
|
||
@pytest.fixture(scope="function")
def auth_headers(test_user: dict[str, str | UUID]) -> dict[str, str]:
    """
    Build the ``Authorization`` header for the regular test user.

    Returns:
        ``{"Authorization": "Bearer <token>"}`` using the token stored in
        the ``test_user`` fixture result.
    """
    bearer_token = test_user["token"]
    return {"Authorization": f"Bearer {bearer_token}"}
|
||
|
||
|
||
@pytest.fixture(scope="function")
def admin_headers(admin_user: dict[str, str | UUID]) -> dict[str, str]:
    """
    Build the ``Authorization`` header for the administrator test user.

    Returns:
        ``{"Authorization": "Bearer <token>"}`` using the token stored in
        the ``admin_user`` fixture result.
    """
    bearer_token = admin_user["token"]
    return {"Authorization": f"Bearer {bearer_token}"}
|
||
|
||
|
||
# ==================== 测试数据 ====================
|
||
|
||
@pytest_asyncio.fixture(scope="function")
async def test_directory(
    db_session: AsyncSession,
    test_user: dict[str, str | UUID]
) -> dict[str, UUID]:
    """
    Create a small folder tree under the test user's root folder.

    Layout:

        / (root)
        ├── documents
        │   ├── work
        │   └── personal
        ├── images
        └── videos

    Returns:
        A dict mapping ``"root"``, ``"documents"``, ``"images"``,
        ``"videos"``, ``"work"`` and ``"personal"`` to the id of the
        corresponding folder.

    Raises:
        ValueError: If the test user has no root folder.
    """
    owner_id: UUID = test_user["id"]
    storage_policy_id: UUID = test_user["policy_id"]

    # The root folder must already exist for this user.
    root = await Object.get_root(db_session, owner_id)
    if not root:
        raise ValueError("测试用户的根目录不存在")

    async def _create_folder(folder_name: str, parent: UUID) -> Object:
        # Persist one empty folder owned by the test user and return the saved row.
        folder = Object(
            name=folder_name,
            type=ObjectType.FOLDER,
            parent_id=parent,
            owner_id=owner_id,
            policy_id=storage_policy_id,
            size=0,
        )
        return await folder.save(db_session)

    # Top-level folders directly under the root.
    documents = await _create_folder("documents", root.id)
    images = await _create_folder("images", root.id)
    videos = await _create_folder("videos", root.id)

    # Sub-folders of "documents".
    work = await _create_folder("work", documents.id)
    personal = await _create_folder("personal", documents.id)

    return {
        "root": root.id,
        "documents": documents.id,
        "images": images.id,
        "videos": videos.id,
        "work": work.id,
        "personal": personal.id,
    }
|