feat: add multi-provider auth via AuthIdentity and extend site config
- Extract AuthIdentity model for multi-provider authentication (email_password, OAuth, Passkey, Magic Link) - Remove password field from User model, credentials now stored in AuthIdentity - Refactor unified login/register to use AuthIdentity-based provider checking - Add site config fields: footer_code, tos_url, privacy_url, auth_methods - Add auth settings defaults in migration (email_password enabled by default) - Update admin user creation to create AuthIdentity records - Update all tests to use AuthIdentity model Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -24,6 +24,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')
|
||||
|
||||
from main import app
|
||||
from sqlmodels.database import get_session
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.group import Group, GroupClaims, GroupOptions
|
||||
from sqlmodels.migration import migration
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
@@ -192,7 +193,6 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
user = User(
|
||||
email="testuser@test.local",
|
||||
nickname="测试用户",
|
||||
password=Password.hash(password),
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=100,
|
||||
@@ -200,6 +200,17 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
# 创建邮箱密码认证身份
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="testuser@test.local",
|
||||
credential=Password.hash(password),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(db_session)
|
||||
|
||||
# 创建用户根目录
|
||||
root_folder = Object(
|
||||
name="/",
|
||||
@@ -279,7 +290,6 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
admin = User(
|
||||
email="admin@disknext.local",
|
||||
nickname="管理员",
|
||||
password=Password.hash(password),
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=9999,
|
||||
@@ -287,6 +297,17 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
)
|
||||
admin = await admin.save(db_session)
|
||||
|
||||
# 创建管理员邮箱密码认证身份
|
||||
admin_identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="admin@disknext.local",
|
||||
credential=Password.hash(password),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=admin.id,
|
||||
)
|
||||
await admin_identity.save(db_session)
|
||||
|
||||
# 创建管理员根目录
|
||||
root_folder = Object(
|
||||
name="/",
|
||||
|
||||
75
tests/fixtures/users.py
vendored
75
tests/fixtures/users.py
vendored
@@ -2,12 +2,14 @@
|
||||
用户测试数据工厂
|
||||
|
||||
提供创建测试用户的便捷方法。
|
||||
用户密码凭证通过 AuthIdentity 管理,不再存储在 User 表中。
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.user import User, UserStatus
|
||||
from utils.password.pwd import Password
|
||||
|
||||
|
||||
@@ -20,7 +22,7 @@ class UserFactory:
|
||||
group_id: UUID,
|
||||
email: str | None = None,
|
||||
password: str | None = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> User:
|
||||
"""
|
||||
创建普通用户
|
||||
@@ -29,7 +31,7 @@ class UserFactory:
|
||||
session: 数据库会话
|
||||
group_id: 用户组UUID
|
||||
email: 用户邮箱(默认: test_user_{随机}@test.local)
|
||||
password: 明文密码(默认: password123)
|
||||
password: 明文密码(默认: password123),若提供则同时创建 AuthIdentity
|
||||
**kwargs: 其他用户字段
|
||||
|
||||
返回:
|
||||
@@ -46,12 +48,10 @@ class UserFactory:
|
||||
user = User(
|
||||
email=email,
|
||||
nickname=kwargs.get("nickname", email),
|
||||
password=Password.hash(password),
|
||||
status=kwargs.get("status", True),
|
||||
status=kwargs.get("status", UserStatus.ACTIVE),
|
||||
storage=kwargs.get("storage", 0),
|
||||
score=kwargs.get("score", 100),
|
||||
group_id=group_id,
|
||||
two_factor=kwargs.get("two_factor"),
|
||||
avatar=kwargs.get("avatar", "default"),
|
||||
group_expires=kwargs.get("group_expires"),
|
||||
theme=kwargs.get("theme", "system"),
|
||||
@@ -61,6 +61,18 @@ class UserFactory:
|
||||
)
|
||||
|
||||
user = await user.save(session)
|
||||
|
||||
# 创建邮箱密码认证身份
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=email,
|
||||
credential=Password.hash(password),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
@@ -68,7 +80,7 @@ class UserFactory:
|
||||
session: AsyncSession,
|
||||
admin_group_id: UUID,
|
||||
email: str | None = None,
|
||||
password: str | None = None
|
||||
password: str | None = None,
|
||||
) -> User:
|
||||
"""
|
||||
创建管理员用户
|
||||
@@ -93,8 +105,7 @@ class UserFactory:
|
||||
admin = User(
|
||||
email=email,
|
||||
nickname=f"管理员 {email}",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=9999,
|
||||
group_id=admin_group_id,
|
||||
@@ -102,13 +113,25 @@ class UserFactory:
|
||||
)
|
||||
|
||||
admin = await admin.save(session)
|
||||
|
||||
# 创建邮箱密码认证身份
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=email,
|
||||
credential=Password.hash(password),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=admin.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
return admin
|
||||
|
||||
@staticmethod
|
||||
async def create_banned(
|
||||
session: AsyncSession,
|
||||
group_id: UUID,
|
||||
email: str | None = None
|
||||
email: str | None = None,
|
||||
) -> User:
|
||||
"""
|
||||
创建被封禁用户
|
||||
@@ -129,8 +152,7 @@ class UserFactory:
|
||||
banned_user = User(
|
||||
email=email,
|
||||
nickname=f"封禁用户 {email}",
|
||||
password=Password.hash("banned_password"),
|
||||
status=False, # 封禁状态
|
||||
status=UserStatus.ADMIN_BANNED,
|
||||
storage=0,
|
||||
score=0,
|
||||
group_id=group_id,
|
||||
@@ -138,6 +160,18 @@ class UserFactory:
|
||||
)
|
||||
|
||||
banned_user = await banned_user.save(session)
|
||||
|
||||
# 创建邮箱密码认证身份
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=email,
|
||||
credential=Password.hash("banned_password"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=banned_user.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
return banned_user
|
||||
|
||||
@staticmethod
|
||||
@@ -145,7 +179,7 @@ class UserFactory:
|
||||
session: AsyncSession,
|
||||
group_id: UUID,
|
||||
storage_bytes: int,
|
||||
email: str | None = None
|
||||
email: str | None = None,
|
||||
) -> User:
|
||||
"""
|
||||
创建已使用指定存储空间的用户
|
||||
@@ -167,8 +201,7 @@ class UserFactory:
|
||||
user = User(
|
||||
email=email,
|
||||
nickname=email,
|
||||
password=Password.hash("password123"),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=storage_bytes,
|
||||
score=100,
|
||||
group_id=group_id,
|
||||
@@ -176,4 +209,16 @@ class UserFactory:
|
||||
)
|
||||
|
||||
user = await user.save(session)
|
||||
|
||||
# 创建邮箱密码认证身份
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=email,
|
||||
credential=Password.hash("password123"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
return user
|
||||
|
||||
@@ -83,6 +83,24 @@ async def test_site_config_captcha_settings(async_client: AsyncClient):
|
||||
assert "forgetCaptcha" in config
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_site_config_auth_methods(async_client: AsyncClient):
|
||||
"""测试配置包含认证方式列表"""
|
||||
response = await async_client.get("/api/site/config")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
config = data["data"]
|
||||
assert "authMethods" in config
|
||||
assert isinstance(config["authMethods"], list)
|
||||
assert len(config["authMethods"]) > 0
|
||||
|
||||
# 每个认证方式应包含 provider 和 isEnabled
|
||||
for method in config["authMethods"]:
|
||||
assert "provider" in method
|
||||
assert "isEnabled" in method
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_site_captcha_endpoint_exists(async_client: AsyncClient):
|
||||
"""测试验证码端点存在(即使未实现也应返回有效响应)"""
|
||||
|
||||
@@ -15,9 +15,10 @@ async def test_user_login_success(
|
||||
"""测试成功登录"""
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": test_user_info["email"],
|
||||
"password": test_user_info["password"],
|
||||
json={
|
||||
"provider": "email_password",
|
||||
"identifier": test_user_info["email"],
|
||||
"credential": test_user_info["password"],
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -37,9 +38,10 @@ async def test_user_login_wrong_password(
|
||||
"""测试密码错误返回 401"""
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": test_user_info["email"],
|
||||
"password": "wrongpassword",
|
||||
json={
|
||||
"provider": "email_password",
|
||||
"identifier": test_user_info["email"],
|
||||
"credential": "wrongpassword",
|
||||
}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
@@ -50,9 +52,10 @@ async def test_user_login_nonexistent_user(async_client: AsyncClient):
|
||||
"""测试不存在的用户返回 401"""
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": "nonexistent@test.local",
|
||||
"password": "anypassword",
|
||||
json={
|
||||
"provider": "email_password",
|
||||
"identifier": "nonexistent@test.local",
|
||||
"credential": "anypassword",
|
||||
}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
@@ -66,9 +69,10 @@ async def test_user_login_user_banned(
|
||||
"""测试封禁用户返回 403"""
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": banned_user_info["email"],
|
||||
"password": banned_user_info["password"],
|
||||
json={
|
||||
"provider": "email_password",
|
||||
"identifier": banned_user_info["email"],
|
||||
"credential": banned_user_info["password"],
|
||||
}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
@@ -82,8 +86,9 @@ async def test_user_register_success(async_client: AsyncClient):
|
||||
response = await async_client.post(
|
||||
"/api/user/",
|
||||
json={
|
||||
"email": "newuser@test.local",
|
||||
"password": "newpass123",
|
||||
"provider": "email_password",
|
||||
"identifier": "newuser@test.local",
|
||||
"credential": "newpass123",
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -104,8 +109,9 @@ async def test_user_register_duplicate_email(
|
||||
response = await async_client.post(
|
||||
"/api/user/",
|
||||
json={
|
||||
"email": test_user_info["email"],
|
||||
"password": "anypassword",
|
||||
"provider": "email_password",
|
||||
"identifier": test_user_info["email"],
|
||||
"credential": "anypassword",
|
||||
}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
@@ -23,6 +23,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../.
|
||||
|
||||
from main import app
|
||||
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import Password
|
||||
from utils.JWT import create_access_token
|
||||
@@ -98,6 +99,15 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
Setting(type=SettingsType.CAPTCHA, name="captcha_CloudflareKey", value=""),
|
||||
Setting(type=SettingsType.REGISTER, name="register_enabled", value="1"),
|
||||
Setting(type=SettingsType.AUTH, name="secret_key", value="test_secret_key_for_jwt_token_generation"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_email_password_enabled", value="1"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_phone_sms_enabled", value="0"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_passkey_enabled", value="0"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_magic_link_enabled", value="0"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_password_required", value="1"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_phone_binding_required", value="0"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_email_binding_required", value="1"),
|
||||
Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"),
|
||||
Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"),
|
||||
]
|
||||
for setting in settings:
|
||||
test_session.add(setting)
|
||||
@@ -183,7 +193,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
test_user = User(
|
||||
id=uuid4(),
|
||||
email="testuser@test.local",
|
||||
password=Password.hash("testpass123"),
|
||||
nickname="测试用户",
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
@@ -196,7 +205,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
admin_user = User(
|
||||
id=uuid4(),
|
||||
email="admin@disknext.local",
|
||||
password=Password.hash("adminpass123"),
|
||||
nickname="管理员",
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
@@ -209,7 +217,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
banned_user = User(
|
||||
id=uuid4(),
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("banned123"),
|
||||
nickname="封禁用户",
|
||||
status=UserStatus.ADMIN_BANNED,
|
||||
storage=0,
|
||||
@@ -226,7 +233,40 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
await test_session.refresh(admin_user)
|
||||
await test_session.refresh(banned_user)
|
||||
|
||||
# 7. 创建用户根目录
|
||||
# 7. 创建认证身份
|
||||
test_user_identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="testuser@test.local",
|
||||
credential=Password.hash("testpass123"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=test_user.id,
|
||||
)
|
||||
test_session.add(test_user_identity)
|
||||
|
||||
admin_user_identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="admin@disknext.local",
|
||||
credential=Password.hash("adminpass123"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=admin_user.id,
|
||||
)
|
||||
test_session.add(admin_user_identity)
|
||||
|
||||
banned_user_identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="banneduser@test.local",
|
||||
credential=Password.hash("banned123"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=banned_user.id,
|
||||
)
|
||||
test_session.add(banned_user_identity)
|
||||
|
||||
await test_session.commit()
|
||||
|
||||
# 8. 创建用户根目录
|
||||
test_user_root = Object(
|
||||
id=uuid4(),
|
||||
name="/",
|
||||
@@ -251,7 +291,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
|
||||
await test_session.commit()
|
||||
|
||||
# 8. 设置JWT密钥(从数据库加载)
|
||||
# 9. 设置JWT密钥(从数据库加载)
|
||||
JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation"
|
||||
|
||||
# 刷新 group options
|
||||
|
||||
@@ -18,7 +18,6 @@ async def test_user_curd():
|
||||
|
||||
test_user = User(
|
||||
email='test_user@test.local',
|
||||
password='test_password',
|
||||
group_id=created_group.id
|
||||
)
|
||||
|
||||
@@ -28,7 +27,6 @@ async def test_user_curd():
|
||||
# 验证用户是否存在
|
||||
assert created_user.id is not None
|
||||
assert created_user.email == 'test_user@test.local'
|
||||
assert created_user.password == 'test_password'
|
||||
assert created_user.group_id == created_group.id
|
||||
|
||||
# 测试查 Read
|
||||
@@ -36,18 +34,16 @@ async def test_user_curd():
|
||||
|
||||
assert fetched_user is not None
|
||||
assert fetched_user.email == 'test_user@test.local'
|
||||
assert fetched_user.password == 'test_password'
|
||||
assert fetched_user.group_id == created_group.id
|
||||
|
||||
# 测试改 Update
|
||||
updated_user = await fetched_user.update(
|
||||
session,
|
||||
{"email": "updated_user@test.local", "password": "updated_password"}
|
||||
{"email": "updated_user@test.local"}
|
||||
)
|
||||
|
||||
assert updated_user is not None
|
||||
assert updated_user.email == 'updated_user@test.local'
|
||||
assert updated_user.password == 'updated_password'
|
||||
|
||||
# 测试删除 Delete
|
||||
await updated_user.delete(session)
|
||||
|
||||
@@ -19,7 +19,7 @@ async def test_object_create_folder(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(
|
||||
@@ -53,7 +53,7 @@ async def test_object_create_file(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(
|
||||
@@ -98,7 +98,7 @@ async def test_object_is_file_property(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -125,7 +125,7 @@ async def test_object_is_folder_property(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -151,7 +151,7 @@ async def test_object_get_root(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="rootuser", password="password", group_id=group.id)
|
||||
user = User(email="rootuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -183,7 +183,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="pathuser", password="password", group_id=group.id)
|
||||
user = User(email="pathuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -214,7 +214,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="nesteduser", password="password", group_id=group.id)
|
||||
user = User(email="nesteduser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -277,7 +277,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="notfounduser", password="password", group_id=group.id)
|
||||
user = User(email="notfounduser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -311,7 +311,7 @@ async def test_object_get_children(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="childrenuser", password="password", group_id=group.id)
|
||||
user = User(email="childrenuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -363,7 +363,7 @@ async def test_object_parent_child_relationship(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="reluser", password="password", group_id=group.id)
|
||||
user = User(email="reluser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -408,7 +408,7 @@ async def test_object_unique_constraint(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="uniqueuser", password="password", group_id=group.id)
|
||||
user = User(email="uniqueuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -456,7 +456,7 @@ async def test_object_get_full_path(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="pathuser", password="password", group_id=group.id)
|
||||
user = User(email="pathuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
|
||||
@@ -20,7 +20,6 @@ async def test_user_create(db_session: AsyncSession):
|
||||
user = User(
|
||||
email="testuser@test.local",
|
||||
nickname="测试用户",
|
||||
password="hashed_password",
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -43,7 +42,6 @@ async def test_user_unique_email(db_session: AsyncSession):
|
||||
# 创建第一个用户
|
||||
user1 = User(
|
||||
email="duplicate@test.local",
|
||||
password="password1",
|
||||
group_id=group.id
|
||||
)
|
||||
await user1.save(db_session)
|
||||
@@ -51,7 +49,6 @@ async def test_user_unique_email(db_session: AsyncSession):
|
||||
# 尝试创建同名用户
|
||||
user2 = User(
|
||||
email="duplicate@test.local",
|
||||
password="password2",
|
||||
group_id=group.id
|
||||
)
|
||||
|
||||
@@ -70,7 +67,6 @@ async def test_user_to_public(db_session: AsyncSession):
|
||||
user = User(
|
||||
email="publicuser@test.local",
|
||||
nickname="公开用户",
|
||||
password="secret_password",
|
||||
storage=1024,
|
||||
avatar="avatar.jpg",
|
||||
group_id=group.id
|
||||
@@ -88,8 +84,6 @@ async def test_user_to_public(db_session: AsyncSession):
|
||||
# 这是已知的设计问题,需要在 UserPublic 中添加别名或重命名字段
|
||||
assert public_user.nick is None # 实际行为
|
||||
assert public_user.storage == 1024
|
||||
# 密码不应该在公开数据中
|
||||
assert not hasattr(public_user, 'password')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -102,7 +96,6 @@ async def test_user_group_relationship(db_session: AsyncSession):
|
||||
# 创建用户
|
||||
user = User(
|
||||
email="vipuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -126,7 +119,6 @@ async def test_user_status_default(db_session: AsyncSession):
|
||||
|
||||
user = User(
|
||||
email="defaultuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -142,7 +134,6 @@ async def test_user_storage_default(db_session: AsyncSession):
|
||||
|
||||
user = User(
|
||||
email="storageuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -159,7 +150,6 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
# 测试默认值
|
||||
user1 = User(
|
||||
email="user1@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
user1 = await user1.save(db_session)
|
||||
@@ -168,7 +158,6 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
# 测试设置为 LIGHT
|
||||
user2 = User(
|
||||
email="user2@test.local",
|
||||
password="password",
|
||||
theme=ThemeType.LIGHT,
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -178,9 +167,40 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
# 测试设置为 DARK
|
||||
user3 = User(
|
||||
email="user3@test.local",
|
||||
password="password",
|
||||
theme=ThemeType.DARK,
|
||||
group_id=group.id
|
||||
)
|
||||
user3 = await user3.save(db_session)
|
||||
assert user3.theme == ThemeType.DARK
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_email_optional(db_session: AsyncSession):
|
||||
"""测试 email 可以为空(支持社交登录用户)"""
|
||||
group = Group(name="默认组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(
|
||||
nickname="社交用户",
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
assert user.id is not None
|
||||
assert user.email is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_phone_field(db_session: AsyncSession):
|
||||
"""测试 phone 字段"""
|
||||
group = Group(name="默认组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(
|
||||
email="phoneuser@test.local",
|
||||
phone="13800138000",
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
assert user.phone == "13800138000"
|
||||
|
||||
@@ -1,78 +1,154 @@
|
||||
"""
|
||||
Login 服务的单元测试
|
||||
|
||||
测试 unified_login() 各 provider 路径。
|
||||
"""
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.user import User, LoginRequest, TokenResponse, UserStatus
|
||||
from sqlmodels.group import Group
|
||||
from service.user.login import login
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.setting import Setting, SettingsType
|
||||
from sqlmodels.user import User, UnifiedLoginRequest, TokenResponse, UserStatus
|
||||
from sqlmodels.group import Group, GroupOptions
|
||||
from service.user.login import unified_login
|
||||
from utils.password.pwd import Password
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_user(db_session: AsyncSession):
|
||||
"""创建测试用户"""
|
||||
async def setup_auth_settings(db_session: AsyncSession):
|
||||
"""创建认证相关的 Setting 配置"""
|
||||
settings = [
|
||||
Setting(type=SettingsType.AUTH, name="auth_email_password_enabled", value="1"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_phone_sms_enabled", value="0"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_passkey_enabled", value="0"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_magic_link_enabled", value="0"),
|
||||
Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"),
|
||||
Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"),
|
||||
]
|
||||
for s in settings:
|
||||
await s.save(db_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_user(db_session: AsyncSession, setup_auth_settings):
|
||||
"""创建测试用户和邮箱密码认证身份"""
|
||||
# 创建用户组
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
# 创建用户组选项
|
||||
group_options = GroupOptions(
|
||||
group_id=group.id,
|
||||
share_download=True,
|
||||
share_free=False,
|
||||
relocate=False,
|
||||
)
|
||||
await group_options.save(db_session)
|
||||
|
||||
# 创建正常用户
|
||||
plain_password = "secure_password_123"
|
||||
user = User(
|
||||
email="loginuser@test.local",
|
||||
password=Password.hash(plain_password),
|
||||
status=UserStatus.ACTIVE,
|
||||
group_id=group.id
|
||||
group_id=group.id,
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
# 创建邮箱密码认证身份
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="loginuser@test.local",
|
||||
credential=Password.hash(plain_password),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(db_session)
|
||||
|
||||
return {
|
||||
"user": user,
|
||||
"password": plain_password,
|
||||
"group_id": group.id
|
||||
"group_id": group.id,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_banned_user(db_session: AsyncSession):
|
||||
async def setup_banned_user(db_session: AsyncSession, setup_auth_settings):
|
||||
"""创建被封禁的用户"""
|
||||
group = Group(name="测试组2")
|
||||
group = await group.save(db_session)
|
||||
|
||||
group_options = GroupOptions(
|
||||
group_id=group.id,
|
||||
share_download=True,
|
||||
share_free=False,
|
||||
relocate=False,
|
||||
)
|
||||
await group_options.save(db_session)
|
||||
|
||||
user = User(
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=UserStatus.ADMIN_BANNED, # 封禁状态
|
||||
group_id=group.id
|
||||
status=UserStatus.ADMIN_BANNED,
|
||||
group_id=group.id,
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="banneduser@test.local",
|
||||
credential=Password.hash("password"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(db_session)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_2fa_user(db_session: AsyncSession):
|
||||
async def setup_2fa_user(db_session: AsyncSession, setup_auth_settings):
|
||||
"""创建启用了两步验证的用户"""
|
||||
import pyotp
|
||||
|
||||
group = Group(name="测试组3")
|
||||
group = await group.save(db_session)
|
||||
|
||||
group_options = GroupOptions(
|
||||
group_id=group.id,
|
||||
share_download=True,
|
||||
share_free=False,
|
||||
relocate=False,
|
||||
)
|
||||
await group_options.save(db_session)
|
||||
|
||||
secret = pyotp.random_base32()
|
||||
user = User(
|
||||
email="2fauser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=UserStatus.ACTIVE,
|
||||
two_factor=secret,
|
||||
group_id=group.id
|
||||
group_id=group.id,
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
# 创建带 2FA secret 的邮箱密码认证身份
|
||||
import orjson
|
||||
extra_data = orjson.dumps({"two_factor": secret}).decode('utf-8')
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="2fauser@test.local",
|
||||
credential=Password.hash("password"),
|
||||
extra_data=extra_data,
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(db_session)
|
||||
|
||||
return {
|
||||
"user": user,
|
||||
"secret": secret,
|
||||
"password": "password"
|
||||
"password": "password",
|
||||
}
|
||||
|
||||
|
||||
@@ -81,12 +157,13 @@ async def test_login_success(db_session: AsyncSession, setup_user):
|
||||
"""测试正常登录"""
|
||||
user_data = setup_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
email="loginuser@test.local",
|
||||
password=user_data["password"]
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="loginuser@test.local",
|
||||
credential=user_data["password"],
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
result = await unified_login(db_session, request)
|
||||
|
||||
assert isinstance(result, TokenResponse)
|
||||
assert result.access_token is not None
|
||||
@@ -96,42 +173,48 @@ async def test_login_success(db_session: AsyncSession, setup_user):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_user_not_found(db_session: AsyncSession):
|
||||
async def test_login_user_not_found(db_session: AsyncSession, setup_user):
|
||||
"""测试用户不存在"""
|
||||
login_request = LoginRequest(
|
||||
email="nonexistent@test.local",
|
||||
password="any_password"
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="nonexistent@test.local",
|
||||
credential="any_password",
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
assert result is None
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_wrong_password(db_session: AsyncSession, setup_user):
|
||||
"""测试密码错误"""
|
||||
login_request = LoginRequest(
|
||||
email="loginuser@test.local",
|
||||
password="wrong_password"
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="loginuser@test.local",
|
||||
credential="wrong_password",
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
assert result is None
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_user_banned(db_session: AsyncSession, setup_banned_user):
|
||||
"""测试用户被封禁"""
|
||||
login_request = LoginRequest(
|
||||
email="banneduser@test.local",
|
||||
password="password"
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="banneduser@test.local",
|
||||
credential="password",
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
assert result is False
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -139,15 +222,17 @@ async def test_login_2fa_required(db_session: AsyncSession, setup_2fa_user):
|
||||
"""测试需要 2FA"""
|
||||
user_data = setup_2fa_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"]
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="2fauser@test.local",
|
||||
credential=user_data["password"],
|
||||
# 未提供 two_fa_code
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
assert result == "2fa_required"
|
||||
assert exc_info.value.status_code == 428
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -155,15 +240,17 @@ async def test_login_2fa_invalid(db_session: AsyncSession, setup_2fa_user):
|
||||
"""测试 2FA 错误"""
|
||||
user_data = setup_2fa_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"],
|
||||
two_fa_code="000000" # 错误的验证码
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="2fauser@test.local",
|
||||
credential=user_data["password"],
|
||||
two_fa_code="000000",
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
assert result == "2fa_invalid"
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -178,56 +265,44 @@ async def test_login_2fa_success(db_session: AsyncSession, setup_2fa_user):
|
||||
totp = pyotp.TOTP(secret)
|
||||
valid_code = totp.now()
|
||||
|
||||
login_request = LoginRequest(
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"],
|
||||
two_fa_code=valid_code
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="2fauser@test.local",
|
||||
credential=user_data["password"],
|
||||
two_fa_code=valid_code,
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
result = await unified_login(db_session, request)
|
||||
|
||||
assert isinstance(result, TokenResponse)
|
||||
assert result.access_token is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
|
||||
"""测试返回的令牌可以被解码"""
|
||||
import jwt as pyjwt
|
||||
|
||||
user_data = setup_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
email="loginuser@test.local",
|
||||
password=user_data["password"]
|
||||
async def test_login_provider_disabled(db_session: AsyncSession, setup_user):
|
||||
"""测试未启用的 provider"""
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.PHONE_SMS,
|
||||
identifier="13800138000",
|
||||
credential="123456",
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
assert isinstance(result, TokenResponse)
|
||||
|
||||
# 注意: 实际项目中需要使用正确的 SECRET_KEY
|
||||
# 这里假设测试环境已经设置了 SECRET_KEY
|
||||
# decoded = pyjwt.decode(
|
||||
# result.access_token,
|
||||
# SECRET_KEY,
|
||||
# algorithms=["HS256"]
|
||||
# )
|
||||
# assert decoded["sub"] == "loginuser"
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_case_sensitive_email(db_session: AsyncSession, setup_user):
|
||||
"""测试邮箱大小写敏感"""
|
||||
user_data = setup_user
|
||||
|
||||
# 使用大写邮箱登录
|
||||
login_request = LoginRequest(
|
||||
email="LOGINUSER@TEST.LOCAL",
|
||||
password=user_data["password"]
|
||||
async def test_login_missing_password(db_session: AsyncSession, setup_user):
|
||||
"""测试邮箱密码登录缺少密码"""
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="loginuser@test.local",
|
||||
# 未提供 credential
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
# 应该失败,因为邮箱大小写不匹配
|
||||
assert result is None
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
Reference in New Issue
Block a user