feat: add multi-provider auth via AuthIdentity and extend site config

- Extract AuthIdentity model for multi-provider authentication (email_password, OAuth, Passkey, Magic Link)
- Remove password field from User model, credentials now stored in AuthIdentity
- Refactor unified login/register to use AuthIdentity-based provider checking
- Add site config fields: footer_code, tos_url, privacy_url, auth_methods
- Add auth settings defaults in migration (email_password enabled by default)
- Update admin user creation to create AuthIdentity records
- Update all tests to use AuthIdentity model

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-12 22:49:12 +08:00
parent d831c9c0d6
commit 729773cae3
20 changed files with 1447 additions and 412 deletions

View File

@@ -12,6 +12,7 @@ from sqlmodels import (
Group, Object, ObjectType, Setting, SettingsType, Group, Object, ObjectType, Setting, SettingsType,
BatchDeleteRequest, BatchDeleteRequest,
) )
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.user import ( from sqlmodels.user import (
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, UserStatus, UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, UserStatus,
) )
@@ -83,13 +84,26 @@ async def router_admin_create_user(
""" """
创建一个新的用户,设置邮箱、密码、用户组等信息。 创建一个新的用户,设置邮箱、密码、用户组等信息。
管理员创建用户时,若提供了 email + password
会同时创建 AuthIdentity(provider=email_password)。
:param session: 数据库会话 :param session: 数据库会话
:param request: 创建用户请求 DTO :param request: 创建用户请求 DTO
:return: 创建结果 :return: 创建结果
""" """
existing_user = await User.get(session, User.email == request.email) # 如果提供了邮箱检查唯一性User 表和 AuthIdentity 表)
if existing_user: if request.email:
raise HTTPException(status_code=409, detail="该邮箱已被注册") existing_user = await User.get(session, User.email == request.email)
if existing_user:
raise HTTPException(status_code=409, detail="该邮箱已被注册")
existing_identity = await AuthIdentity.get(
session,
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
& (AuthIdentity.identifier == request.email),
)
if existing_identity:
raise HTTPException(status_code=409, detail="该邮箱已被绑定")
# 验证用户组存在 # 验证用户组存在
group = await Group.get(session, Group.id == request.group_id) group = await Group.get(session, Group.id == request.group_id)
@@ -98,12 +112,24 @@ async def router_admin_create_user(
user = User( user = User(
email=request.email, email=request.email,
password=Password.hash(request.password),
nickname=request.nickname, nickname=request.nickname,
group_id=request.group_id, group_id=request.group_id,
status=request.status, status=request.status,
) )
user = await user.save(session) user = await user.save(session)
# 如果提供了邮箱和密码,创建邮箱密码认证身份
if request.email and request.password:
identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier=request.email,
credential=Password.hash(request.password),
is_primary=True,
is_verified=True,
user_id=user.id,
)
await identity.save(session)
return user.to_public() return user.to_public()
@@ -148,17 +174,7 @@ async def router_admin_update_user(
if not group: if not group:
raise HTTPException(status_code=400, detail="目标用户组不存在") raise HTTPException(status_code=400, detail="目标用户组不存在")
# 如果更新密码,需要加密
update_data = request.model_dump(exclude_unset=True) update_data = request.model_dump(exclude_unset=True)
if 'password' in update_data and update_data['password']:
update_data['password'] = Password.hash(update_data['password'])
elif 'password' in update_data:
del update_data['password'] # 空密码不更新
# 验证两步验证密钥格式(如果提供了值且不为 None长度必须为 32
if 'two_factor' in update_data and update_data['two_factor'] is not None:
if len(update_data['two_factor']) != 32:
raise HTTPException(status_code=400, detail="两步验证密钥必须为32位字符串")
# 记录旧 status 以便检测变更 # 记录旧 status 以便检测变更
old_status = user.status old_status = user.status
@@ -175,7 +191,7 @@ async def router_admin_update_user(
elif old_status != UserStatus.ACTIVE and new_status == UserStatus.ACTIVE: elif old_status != UserStatus.ACTIVE and new_status == UserStatus.ACTIVE:
await UserBanStore.unban(str(user_id)) await UserBanStore.unban(str(user_id))
l.info(f"管理员更新了用户: {request.email}") l.info(f"管理员更新了用户: {user.email}")
@admin_user_router.delete( @admin_user_router.delete(

View File

@@ -4,7 +4,9 @@ from middleware.dependencies import SessionDep
from sqlmodels import ( from sqlmodels import (
ResponseBase, Setting, SettingsType, SiteConfigResponse, ResponseBase, Setting, SettingsType, SiteConfigResponse,
ThemePreset, ThemePresetResponse, ThemePresetListResponse, ThemePreset, ThemePresetResponse, ThemePresetListResponse,
AuthMethodConfig,
) )
from sqlmodels.auth_identity import AuthProviderType
from sqlmodels.setting import CaptchaType from sqlmodels.setting import CaptchaType
from utils import http_exceptions from utils import http_exceptions
@@ -70,7 +72,7 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
获取站点全局配置 获取站点全局配置
无需认证。前端在初始化时调用此端点获取验证码类型、 无需认证。前端在初始化时调用此端点获取验证码类型、
登录/注册/找回密码是否需要验证码等配置。 登录/注册/找回密码是否需要验证码、可用的认证方式等配置。
""" """
# 批量查询所需设置 # 批量查询所需设置
settings: list[Setting] = await Setting.get( settings: list[Setting] = await Setting.get(
@@ -78,7 +80,9 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
(Setting.type == SettingsType.BASIC) | (Setting.type == SettingsType.BASIC) |
(Setting.type == SettingsType.LOGIN) | (Setting.type == SettingsType.LOGIN) |
(Setting.type == SettingsType.REGISTER) | (Setting.type == SettingsType.REGISTER) |
(Setting.type == SettingsType.CAPTCHA), (Setting.type == SettingsType.CAPTCHA) |
(Setting.type == SettingsType.AUTH) |
(Setting.type == SettingsType.OAUTH),
fetch_mode="all", fetch_mode="all",
) )
@@ -94,6 +98,16 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE: elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
captcha_key = s.get("captcha_CloudflareKey") or None captcha_key = s.get("captcha_CloudflareKey") or None
# 构建认证方式列表
auth_methods: list[AuthMethodConfig] = [
AuthMethodConfig(provider=AuthProviderType.EMAIL_PASSWORD, is_enabled=s.get("auth_email_password_enabled") == "1"),
AuthMethodConfig(provider=AuthProviderType.PHONE_SMS, is_enabled=s.get("auth_phone_sms_enabled") == "1"),
AuthMethodConfig(provider=AuthProviderType.GITHUB, is_enabled=s.get("github_enabled") == "1"),
AuthMethodConfig(provider=AuthProviderType.QQ, is_enabled=s.get("qq_enabled") == "1"),
AuthMethodConfig(provider=AuthProviderType.PASSKEY, is_enabled=s.get("auth_passkey_enabled") == "1"),
AuthMethodConfig(provider=AuthProviderType.MAGIC_LINK, is_enabled=s.get("auth_magic_link_enabled") == "1"),
]
return SiteConfigResponse( return SiteConfigResponse(
title=s.get("siteName") or "DiskNext", title=s.get("siteName") or "DiskNext",
register_enabled=s.get("register_enabled") == "1", register_enabled=s.get("register_enabled") == "1",
@@ -102,4 +116,11 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
forget_captcha=s.get("forget_captcha") == "1", forget_captcha=s.get("forget_captcha") == "1",
captcha_type=captcha_type, captcha_type=captcha_type,
captcha_key=captcha_key, captcha_key=captcha_key,
auth_methods=auth_methods,
password_required=s.get("auth_password_required") == "1",
phone_binding_required=s.get("auth_phone_binding_required") == "1",
email_binding_required=s.get("auth_email_binding_required") == "1",
footer_code=s.get("footer_code"),
tos_url=s.get("tos_url"),
privacy_url=s.get("privacy_url"),
) )

View File

@@ -2,7 +2,8 @@ from typing import Annotated, Literal
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import jwt import jwt
from fastapi import APIRouter, Depends, Form, HTTPException from fastapi import APIRouter, Depends, HTTPException
from itsdangerous import URLSafeTimedSerializer
from loguru import logger from loguru import logger
from webauthn import generate_registration_options from webauthn import generate_registration_options
from webauthn.helpers import options_to_json_dict from webauthn.helpers import options_to_json_dict
@@ -12,6 +13,7 @@ import sqlmodels
from middleware.auth import auth_required from middleware.auth import auth_required
from middleware.dependencies import SessionDep, require_captcha from middleware.dependencies import SessionDep, require_captcha
from service.captcha import CaptchaScene from service.captcha import CaptchaScene
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.user import UserStatus from sqlmodels.user import UserStatus
from utils import JWT, Password, http_exceptions from utils import JWT, Password, http_exceptions
from .settings import user_settings_router from .settings import user_settings_router
@@ -23,59 +25,36 @@ user_router = APIRouter(
user_router.include_router(user_settings_router) user_router.include_router(user_settings_router)
class OAuth2PasswordWithExtrasForm:
"""
扩展 OAuth2 密码表单。
在标准 username/password 基础上添加 otp_code 字段。
captcha_code 由 require_captcha 依赖注入单独处理。
"""
def __init__(
self,
*,
username: Annotated[str, Form()],
password: Annotated[str, Form()],
otp_code: Annotated[str | None, Form(min_length=6, max_length=6)] = None,
):
self.username = username
self.password = password
self.otp_code = otp_code
@user_router.post( @user_router.post(
path='/session', path='/session',
summary='用户登录', summary='用户登录(统一入口)',
description='用户登录端点,支持验证码校验和两步验证', description='统一登录端点,支持多种认证方式',
dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))],
) )
async def router_user_session( async def router_user_session(
session: SessionDep, session: SessionDep,
form_data: Annotated[OAuth2PasswordWithExtrasForm, Depends()], request: sqlmodels.UnifiedLoginRequest,
) -> sqlmodels.TokenResponse: ) -> sqlmodels.TokenResponse:
""" """
用户登录端点 统一登录端点
表单字段 请求体
- username: 用户邮箱 - provider: 登录方式email_password / github / qq / passkey / magic_link
- password: 用户密码 - identifier: 标识符(邮箱 / OAuth code / credential_id / magic link token
- captcha_code: 验证码 token可选由 require_captcha 依赖校验 - credential: 凭证(密码 / WebAuthn assertion 等
- otp_code: 两步验证码(可选,仅在用户启用 2FA 时需要 - two_fa_code: 两步验证码(可选)
- redirect_uri: OAuth 回调地址(可选)
- captcha: 验证码(可选)
错误处理: 错误处理:
- 400: 需要验证码但未提供 - 400: 登录方式未启用 / 参数错误
- 401: 邮箱/密码错误,或 2FA 验证码错误 - 401: 凭证错误
- 403: 账户已禁用 / 验证码验证失败 - 403: 账户已禁用
- 428: 需要两步验证但未提供 otp_code - 428: 需要两步验证
- 501: 暂未实现的登录方式
""" """
return await service.user.login( return await service.user.unified_login(session, request)
session,
sqlmodels.LoginRequest(
email=form_data.username,
password=form_data.password,
two_fa_code=form_data.otp_code,
),
)
@user_router.post( @user_router.post(
path='/session/refresh', path='/session/refresh',
@@ -150,41 +129,82 @@ async def router_user_session_refresh(
@user_router.post( @user_router.post(
path='/', path='/',
summary='用户注册', summary='用户注册(统一入口)',
description='User registration endpoint.', description='User registration endpoint.',
status_code=204, status_code=204,
) )
async def router_user_register( async def router_user_register(
session: SessionDep, session: SessionDep,
request: sqlmodels.RegisterRequest, request: sqlmodels.UnifiedRegisterRequest,
) -> None: ) -> None:
""" """
用户注册端点 统一注册端点
流程: 流程:
1. 验证用户名唯一性 1. 检查注册开关
2. 获取默认用户组 2. 检查 provider 启用
3. 创建用户记录 3. 验证 identifier 唯一性AuthIdentity 表)
4. 创建用户根目录name="/" 4. 创建 User + AuthIdentity + 根目录
:param session: 数据库会话 请求体:
:param request: 注册请求 - provider: 注册方式email_password / phone_sms
:return: 注册结果 - identifier: 标识符(邮箱 / 手机号)
:raises HTTPException 400: 用户名已存在 - credential: 凭证(密码 / 短信验证码)
:raises HTTPException 500: 默认用户组或存储策略不存在 - nickname: 昵称(可选)
- captcha: 验证码(可选)
错误处理:
- 400: 注册未开放 / 参数错误
- 409: 邮箱或手机号已存在
- 501: 暂未实现的注册方式
""" """
# 1. 验证邮箱唯一性 # 1. 检查注册开关
register_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER)
& (sqlmodels.Setting.name == "register_enabled"),
)
if not register_setting or register_setting.value != "1":
http_exceptions.raise_bad_request("注册功能未开放")
# 2. 目前只支持 email_password 注册
if request.provider == AuthProviderType.PHONE_SMS:
http_exceptions.raise_not_implemented("短信注册暂未开放")
elif request.provider != AuthProviderType.EMAIL_PASSWORD:
http_exceptions.raise_bad_request("不支持的注册方式")
# 3. 检查密码是否必填
password_required_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
& (sqlmodels.Setting.name == "auth_password_required"),
)
is_password_required = not password_required_setting or password_required_setting.value != "0"
if is_password_required and not request.credential:
http_exceptions.raise_bad_request("密码不能为空")
# 4. 验证 identifier 唯一性AuthIdentity 表)
existing_identity = await AuthIdentity.get(
session,
(AuthIdentity.provider == request.provider)
& (AuthIdentity.identifier == request.identifier),
)
if existing_identity:
raise HTTPException(status_code=409, detail="该邮箱已被注册")
# 同时检查 User.email 唯一性(防止旧数据冲突)
existing_user = await sqlmodels.User.get( existing_user = await sqlmodels.User.get(
session, session,
sqlmodels.User.email == request.email sqlmodels.User.email == request.identifier,
) )
if existing_user: if existing_user:
raise HTTPException(status_code=400, detail="邮箱已存在") raise HTTPException(status_code=409, detail="邮箱已被注册")
# 2. 获取默认用户组(从设置中读取 UUID # 5. 获取默认用户组
default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get( default_group_setting = await sqlmodels.Setting.get(
session, session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group") (sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER)
& (sqlmodels.Setting.name == "default_group"),
) )
if default_group_setting is None or not default_group_setting.value: if default_group_setting is None or not default_group_setting.value:
logger.error("默认用户组不存在") logger.error("默认用户组不存在")
@@ -196,17 +216,28 @@ async def router_user_register(
logger.error("默认用户组不存在") logger.error("默认用户组不存在")
http_exceptions.raise_internal_error() http_exceptions.raise_internal_error()
# 3. 创建用户 # 6. 创建用户
hashed_password = Password.hash(request.password)
new_user = sqlmodels.User( new_user = sqlmodels.User(
email=request.email, email=request.identifier,
password=hashed_password, nickname=request.nickname,
group_id=default_group.id, group_id=default_group.id,
) )
new_user_id = new_user.id new_user_id = new_user.id
await new_user.save(session) await new_user.save(session)
# 4. 创建用户根目录 # 7. 创建 AuthIdentity
hashed_password = Password.hash(request.credential) if request.credential else None
identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier=request.identifier,
credential=hashed_password,
is_primary=True,
is_verified=False,
user_id=new_user_id,
)
await identity.save(session)
# 8. 创建用户根目录
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储") default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
if not default_policy: if not default_policy:
logger.error("默认存储策略不存在") logger.error("默认存储策略不存在")
@@ -220,6 +251,66 @@ async def router_user_register(
policy_id=default_policy.id, policy_id=default_policy.id,
).save(session) ).save(session)
@user_router.post(
path='/magic-link',
summary='发送 Magic Link 邮件',
description='生成 Magic Link token 并发送到指定邮箱。',
status_code=204,
)
async def router_user_magic_link(
session: SessionDep,
request: sqlmodels.MagicLinkRequest,
) -> None:
"""
发送 Magic Link 邮件
流程:
1. 验证邮箱对应的 AuthIdentity 存在
2. 生成签名 token
3. 发送邮件(包含带 token 的链接)
错误处理:
- 400: Magic Link 未启用
- 404: 邮箱未注册
"""
# 检查 magic_link 是否启用
magic_link_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
& (sqlmodels.Setting.name == "auth_magic_link_enabled"),
)
if not magic_link_setting or magic_link_setting.value != "1":
http_exceptions.raise_bad_request("Magic Link 登录未启用")
# 验证邮箱存在
identity = await AuthIdentity.get(
session,
(AuthIdentity.identifier == request.email)
& (
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
| (AuthIdentity.provider == AuthProviderType.MAGIC_LINK)
),
)
if not identity:
http_exceptions.raise_not_found("该邮箱未注册")
# 生成签名 token
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
token = serializer.dumps(request.email, salt="magic-link-salt")
# 获取站点 URL
site_url_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.BASIC)
& (sqlmodels.Setting.name == "siteURL"),
)
site_url = site_url_setting.value if site_url_setting else "http://localhost"
# TODO: 发送邮件(包含 {site_url}/auth/magic-link?token={token}
logger.info(f"Magic Link token 已生成: {token} (邮件发送待实现)")
@user_router.post( @user_router.post(
path='/code', path='/code',
summary='发送验证码邮件', summary='发送验证码邮件',
@@ -236,46 +327,6 @@ def router_user_email_code(
""" """
http_exceptions.raise_not_implemented() http_exceptions.raise_not_implemented()
@user_router.get(
path='/qq',
summary='初始化QQ登录',
description='Initialize QQ login for a user.',
)
def router_user_qq() -> sqlmodels.ResponseBase:
"""
Initialize QQ login for a user.
Returns:
dict: A dictionary containing QQ login initialization information.
"""
http_exceptions.raise_not_implemented()
@user_router.get(
path='authn/{username}',
summary='WebAuthn登录初始化',
description='Initialize WebAuthn login for a user.',
)
async def router_user_authn(username: str) -> sqlmodels.ResponseBase:
http_exceptions.raise_not_implemented()
@user_router.post(
path='authn/finish/{username}',
summary='WebAuthn登录',
description='Finish WebAuthn login for a user.',
)
def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
"""
Finish WebAuthn login for a user.
Args:
username (str): The username of the user.
Returns:
dict: A dictionary containing WebAuthn login information.
"""
http_exceptions.raise_not_implemented()
@user_router.get( @user_router.get(
path='/profile/{id}', path='/profile/{id}',
summary='获取用户主页展示用分享', summary='获取用户主页展示用分享',
@@ -348,8 +399,6 @@ async def router_user_me(
return sqlmodels.UserResponse( return sqlmodels.UserResponse(
id=user.id, id=user.id,
email=user.email, email=user.email,
status=user.status,
score=user.score,
nickname=user.nickname, nickname=user.nickname,
avatar=user.avatar, avatar=user.avatar,
created_at=user.created_at, created_at=user.created_at,
@@ -389,8 +438,8 @@ async def router_user_storage(
@user_router.put( @user_router.put(
path='/authn/start', path='/authn/start',
summary='WebAuthn登录初始化', summary='注册 Passkey 凭证(初始化',
description='Initialize WebAuthn login for a user.', description='Initialize Passkey registration for a user.',
dependencies=[Depends(auth_required)], dependencies=[Depends(auth_required)],
) )
async def router_user_authn_start( async def router_user_authn_start(
@@ -398,18 +447,19 @@ async def router_user_authn_start(
user: Annotated[sqlmodels.user.User, Depends(auth_required)], user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> sqlmodels.ResponseBase: ) -> sqlmodels.ResponseBase:
""" """
Initialize WebAuthn login for a user. Passkey 注册初始化(需要登录)
Returns: 返回 WebAuthn registration options前端使用 navigator.credentials.create() 处理。
dict: A dictionary containing WebAuthn initialization information.
错误处理:
- 400: Passkey 未启用
""" """
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
authn_setting = await sqlmodels.Setting.get( authn_setting = await sqlmodels.Setting.get(
session, session,
(sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled") (sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled")
) )
if not authn_setting or authn_setting.value != "1": if not authn_setting or authn_setting.value != "1":
raise HTTPException(status_code=400, detail="WebAuthn is not enabled") raise HTTPException(status_code=400, detail="Passkey 未启用")
site_url_setting = await sqlmodels.Setting.get( site_url_setting = await sqlmodels.Setting.get(
session, session,
@@ -423,23 +473,26 @@ async def router_user_authn_start(
options = generate_registration_options( options = generate_registration_options(
rp_id=site_url_setting.value if site_url_setting else "", rp_id=site_url_setting.value if site_url_setting else "",
rp_name=site_title_setting.value if site_title_setting else "", rp_name=site_title_setting.value if site_title_setting else "",
user_name=user.email, user_name=user.email or str(user.id),
user_display_name=user.nickname or user.email, user_display_name=user.nickname or user.email or str(user.id),
) )
return sqlmodels.ResponseBase(data=options_to_json_dict(options)) return sqlmodels.ResponseBase(data=options_to_json_dict(options))
@user_router.put( @user_router.put(
path='/authn/finish', path='/authn/finish',
summary='WebAuthn登录', summary='注册 Passkey 凭证(完成)',
description='Finish WebAuthn login for a user.', description='Finish Passkey registration for a user.',
dependencies=[Depends(auth_required)], dependencies=[Depends(auth_required)],
) )
def router_user_authn_finish() -> sqlmodels.ResponseBase: def router_user_authn_finish() -> sqlmodels.ResponseBase:
""" """
Finish WebAuthn login for a user. Passkey 注册完成(需要登录)
接收前端 navigator.credentials.create() 返回的凭证数据,
创建 UserAuthn 行 + AuthIdentity(provider=passkey)。
Returns: Returns:
dict: A dictionary containing WebAuthn login information. dict: A dictionary containing Passkey registration information.
""" """
http_exceptions.raise_not_implemented() http_exceptions.raise_not_implemented()

View File

@@ -1,4 +1,5 @@
from typing import Annotated from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
@@ -9,6 +10,7 @@ from middleware.dependencies import SessionDep
from sqlmodels import ( from sqlmodels import (
BUILTIN_DEFAULT_COLORS, ThemePreset, UserThemeUpdateRequest, BUILTIN_DEFAULT_COLORS, ThemePreset, UserThemeUpdateRequest,
SettingOption, UserSettingUpdateRequest, SettingOption, UserSettingUpdateRequest,
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
) )
from sqlmodels.color import ThemeColorsBase from sqlmodels.color import ThemeColorsBase
from utils import JWT, Password, http_exceptions from utils import JWT, Password, http_exceptions
@@ -117,16 +119,29 @@ async def router_user_settings(
else: else:
theme_colors = BUILTIN_DEFAULT_COLORS theme_colors = BUILTIN_DEFAULT_COLORS
# 检查是否启用了两步验证(从 email_password AuthIdentity 的 extra_data 中读取)
has_two_factor = False
email_identity: AuthIdentity | None = await AuthIdentity.get(
session,
(AuthIdentity.user_id == user.id)
& (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
)
if email_identity and email_identity.extra_data:
import orjson
extra: dict = orjson.loads(email_identity.extra_data)
has_two_factor = bool(extra.get("two_factor"))
return sqlmodels.UserSettingResponse( return sqlmodels.UserSettingResponse(
id=user.id, id=user.id,
email=user.email, email=user.email,
phone=user.phone,
nickname=user.nickname, nickname=user.nickname,
created_at=user.created_at, created_at=user.created_at,
group_name=user.group.name, group_name=user.group.name,
language=user.language, language=user.language,
timezone=user.timezone, timezone=user.timezone,
group_expires=user.group_expires, group_expires=user.group_expires,
two_factor=user.two_factor is not None, two_factor=has_two_factor,
theme_preset_id=user.theme_preset_id, theme_preset_id=user.theme_preset_id,
theme_colors=theme_colors, theme_colors=theme_colors,
) )
@@ -255,7 +270,7 @@ async def router_user_settings_2fa(
返回 setup_token用于后续验证请求和 uri用于生成二维码 返回 setup_token用于后续验证请求和 uri用于生成二维码
""" """
return await Password.generate_totp(name=user.email) return await Password.generate_totp(name=user.email or str(user.id))
@user_settings_router.post( @user_settings_router.post(
@@ -273,7 +288,7 @@ async def router_user_settings_2fa_enable(
""" """
启用两步验证 启用两步验证
请求体包含 setup_tokenGET /2fa 返回的令牌)和 code6 位验证码) 将 2FA secret 存储到 email_password AuthIdentity 的 extra_data 中
""" """
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY) serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
@@ -287,6 +302,150 @@ async def router_user_settings_2fa_enable(
if Password.verify_totp(secret, request.code) != PasswordStatus.VALID: if Password.verify_totp(secret, request.code) != PasswordStatus.VALID:
raise HTTPException(status_code=400, detail="Invalid OTP code") raise HTTPException(status_code=400, detail="Invalid OTP code")
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA # 将 secret 存储到 AuthIdentity.extra_data 中
user.two_factor = secret email_identity: AuthIdentity | None = await AuthIdentity.get(
user = await user.save(session) session,
(AuthIdentity.user_id == user.id)
& (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
)
if not email_identity:
raise HTTPException(status_code=400, detail="未找到邮箱密码认证身份")
import orjson
extra: dict = orjson.loads(email_identity.extra_data) if email_identity.extra_data else {}
extra["two_factor"] = secret
email_identity.extra_data = orjson.dumps(extra).decode('utf-8')
await email_identity.save(session)
# ==================== 认证身份管理 ====================
@user_settings_router.get(
path='/identities',
summary='列出已绑定的认证身份',
)
async def router_user_settings_identities(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> list[AuthIdentityResponse]:
"""
列出当前用户已绑定的所有认证身份
返回:
- 认证身份列表,包含 provider、identifier、display_name 等
"""
identities: list[AuthIdentity] = await AuthIdentity.get(
session,
AuthIdentity.user_id == user.id,
fetch_mode="all",
)
return [identity.to_response() for identity in identities]
@user_settings_router.post(
path='/identity',
summary='绑定新的认证身份',
status_code=status.HTTP_201_CREATED,
)
async def router_user_settings_bind_identity(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
request: BindIdentityRequest,
) -> AuthIdentityResponse:
"""
绑定新的登录方式
请求体:
- provider: 提供者类型
- identifier: 标识符(邮箱 / 手机号 / OAuth code
- credential: 凭证(密码、验证码等)
- redirect_uri: OAuth 回调地址(可选)
错误处理:
- 400: provider 未启用
- 409: 该身份已被其他用户绑定
"""
# 检查是否已被绑定
existing = await AuthIdentity.get(
session,
(AuthIdentity.provider == request.provider)
& (AuthIdentity.identifier == request.identifier),
)
if existing:
raise HTTPException(status_code=409, detail="该身份已被绑定")
# 处理密码类型的凭证
credential: str | None = None
if request.provider == AuthProviderType.EMAIL_PASSWORD and request.credential:
credential = Password.hash(request.credential)
identity = AuthIdentity(
provider=request.provider,
identifier=request.identifier,
credential=credential,
is_primary=False,
is_verified=False,
user_id=user.id,
)
identity = await identity.save(session)
return identity.to_response()
@user_settings_router.delete(
path='/identity/{identity_id}',
summary='解绑认证身份',
status_code=status.HTTP_204_NO_CONTENT,
)
async def router_user_settings_unbind_identity(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
identity_id: UUID,
) -> None:
"""
解绑一个认证身份
约束:
- 不能解绑最后一个身份
- 站长配置强制绑定邮箱/手机号时,不能解绑对应身份
错误处理:
- 404: 身份不存在或不属于当前用户
- 400: 不能解绑最后一个身份 / 不能解绑强制绑定的身份
"""
# 查找目标身份
identity: AuthIdentity | None = await AuthIdentity.get(
session,
(AuthIdentity.id == identity_id) & (AuthIdentity.user_id == user.id),
)
if not identity:
http_exceptions.raise_not_found("认证身份不存在")
# 检查是否为最后一个身份
all_identities: list[AuthIdentity] = await AuthIdentity.get(
session,
AuthIdentity.user_id == user.id,
fetch_mode="all",
)
if len(all_identities) <= 1:
http_exceptions.raise_bad_request("不能解绑最后一个认证身份")
# 检查强制绑定约束
if identity.provider == AuthProviderType.EMAIL_PASSWORD:
email_required_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
& (sqlmodels.Setting.name == "auth_email_binding_required"),
)
if email_required_setting and email_required_setting.value == "1":
http_exceptions.raise_bad_request("站长要求必须绑定邮箱,不能解绑")
if identity.provider == AuthProviderType.PHONE_SMS:
phone_required_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
& (sqlmodels.Setting.name == "auth_phone_binding_required"),
)
if phone_required_setting and phone_required_setting.value == "1":
http_exceptions.raise_bad_request("站长要求必须绑定手机号,不能解绑")
await AuthIdentity.delete(session, identity)

View File

@@ -1 +1 @@
from .login import login from .login import unified_login

View File

@@ -1,83 +1,417 @@
from uuid import uuid4 """
统一登录服务
from loguru import logger 支持多种认证方式邮箱密码、GitHub OAuth、QQ OAuth、Passkey、Magic Link、手机短信预留
"""
from uuid import UUID, uuid4
from middleware.dependencies import SessionDep from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from sqlmodels import LoginRequest, TokenResponse, User from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.group import GroupClaims, GroupOptions from sqlmodels.group import GroupClaims, GroupOptions
from sqlmodels.user import UserStatus from sqlmodels.object import Object, ObjectType
from utils import http_exceptions from sqlmodels.policy import Policy
from utils.JWT import create_access_token, create_refresh_token from sqlmodels.setting import Setting, SettingsType
from sqlmodels.user import TokenResponse, UnifiedLoginRequest, User, UserStatus
from utils import JWT, http_exceptions
from utils.password.pwd import Password, PasswordStatus from utils.password.pwd import Password, PasswordStatus
async def login( async def unified_login(
session: SessionDep, session: AsyncSession,
login_request: LoginRequest, request: UnifiedLoginRequest,
) -> TokenResponse: ) -> TokenResponse:
""" """
根据账号密码进行登录 统一登录入口,根据 provider 分发到不同的登录逻辑
如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。
:param session: 数据库会话 :param session: 数据库会话
:param login_request: 登录请求 :param request: 统一登录请求
:return: TokenResponse
:return: TokenResponse 对象或状态码或 None
""" """
# 获取用户信息(预加载 group 关系) await _check_provider_enabled(session, request.provider)
current_user: User = await User.get(
match request.provider:
case AuthProviderType.EMAIL_PASSWORD:
user = await _login_email_password(session, request)
case AuthProviderType.GITHUB:
user = await _login_oauth(session, request, AuthProviderType.GITHUB)
case AuthProviderType.QQ:
user = await _login_oauth(session, request, AuthProviderType.QQ)
case AuthProviderType.PASSKEY:
user = await _login_passkey(session, request)
case AuthProviderType.MAGIC_LINK:
user = await _login_magic_link(session, request)
case AuthProviderType.PHONE_SMS:
http_exceptions.raise_not_implemented("短信登录暂未开放")
case _:
http_exceptions.raise_bad_request(f"不支持的登录方式: {request.provider}")
return await _issue_tokens(session, user)
async def _check_provider_enabled(session: AsyncSession, provider: AuthProviderType) -> None:
"""检查认证方式是否已被站长启用"""
# OAuth 类型从 OAUTH 设置中查询
if provider in (AuthProviderType.GITHUB, AuthProviderType.QQ):
setting_name = f"{provider.value}_enabled"
setting = await Setting.get(
session,
(Setting.type == SettingsType.OAUTH) & (Setting.name == setting_name),
)
if not setting or setting.value != "1":
http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用")
return
# 其他类型从 AUTH 设置中查询
setting_name = f"auth_{provider.value}_enabled"
setting = await Setting.get(
session, session,
User.email == login_request.email, (Setting.type == SettingsType.AUTH) & (Setting.name == setting_name),
fetch_mode="first", )
load=User.group, if not setting or setting.value != "1":
) #type: ignore http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用")
# 验证用户是否存在
if not current_user:
logger.debug(f"Cannot find user with email: {login_request.email}")
http_exceptions.raise_unauthorized("Invalid email or password")
# 验证密码是否正确 async def _login_email_password(
if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID: session: AsyncSession,
logger.debug(f"Password verification failed for user: {login_request.email}") request: UnifiedLoginRequest,
http_exceptions.raise_unauthorized("Invalid email or password") ) -> User:
"""邮箱+密码登录"""
if not request.credential:
http_exceptions.raise_bad_request("密码不能为空")
# 验证用户是否可登录修复显式枚举比较StrEnum 永远 truthy # 查找 AuthIdentity
if current_user.status != UserStatus.ACTIVE: identity: AuthIdentity | None = await AuthIdentity.get(
http_exceptions.raise_forbidden("Your account is disabled") session,
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
& (AuthIdentity.identifier == request.identifier),
)
if not identity:
l.debug(f"未找到邮箱密码身份: {request.identifier}")
http_exceptions.raise_unauthorized("邮箱或密码错误")
# 检查两步验证 # 验证密码
if current_user.two_factor: if not identity.credential:
# 用户已启用两步验证 http_exceptions.raise_unauthorized("邮箱或密码错误")
if not login_request.two_fa_code:
logger.debug(f"2FA required for user: {login_request.email}")
http_exceptions.raise_precondition_required("2FA required")
# 验证 OTP 码 if Password.verify(identity.credential, request.credential) != PasswordStatus.VALID:
if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID: l.debug(f"密码验证失败: {request.identifier}")
logger.debug(f"Invalid 2FA code for user: {login_request.email}") http_exceptions.raise_unauthorized("邮箱或密码错误")
http_exceptions.raise_unauthorized("Invalid 2FA code")
# 加载用户
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
if not user:
http_exceptions.raise_unauthorized("用户不存在")
# 验证用户状态
if user.status != UserStatus.ACTIVE:
http_exceptions.raise_forbidden("账户已被禁用")
# 检查两步验证(从 AuthIdentity.extra_data 中读取 2FA secret
if identity.extra_data:
import orjson
extra: dict = orjson.loads(identity.extra_data)
two_factor_secret: str | None = extra.get("two_factor")
if two_factor_secret:
if not request.two_fa_code:
l.debug(f"需要两步验证: {request.identifier}")
http_exceptions.raise_precondition_required("需要两步验证")
if Password.verify_totp(two_factor_secret, request.two_fa_code) != PasswordStatus.VALID:
l.debug(f"两步验证失败: {request.identifier}")
http_exceptions.raise_unauthorized("两步验证码错误")
return user
async def _login_oauth(
session: AsyncSession,
request: UnifiedLoginRequest,
provider: AuthProviderType,
) -> User:
"""
OAuth 登录GitHub / QQ
identifier 为 OAuth authorization code后端换取 access_token 再获取用户信息。
"""
# 读取 OAuth 配置
client_id_setting = await Setting.get(
session,
(Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_id"),
)
client_secret_setting = await Setting.get(
session,
(Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_secret"),
)
if not client_id_setting or not client_secret_setting:
http_exceptions.raise_bad_request(f"{provider.value} OAuth 未配置")
client_id = client_id_setting.value or ""
client_secret = client_secret_setting.value or ""
# 根据 provider 创建对应的 OAuth 客户端
if provider == AuthProviderType.GITHUB:
from service.oauth import GithubOAuth
oauth_client = GithubOAuth(client_id, client_secret)
token_resp = await oauth_client.get_access_token(code=request.identifier)
user_info_resp = await oauth_client.get_user_info(token_resp)
openid = str(user_info_resp.user_data.id)
nickname = user_info_resp.user_data.name or user_info_resp.user_data.login
avatar_url = user_info_resp.user_data.avatar_url
email = user_info_resp.user_data.email
elif provider == AuthProviderType.QQ:
from service.oauth import QQOAuth
oauth_client = QQOAuth(client_id, client_secret)
token_resp = await oauth_client.get_access_token(
code=request.identifier,
redirect_uri=request.redirect_uri or "",
)
openid_resp = await oauth_client.get_openid(token_resp.access_token)
user_info_resp = await oauth_client.get_user_info(
token_resp,
app_id=client_id,
openid=openid_resp.openid,
)
openid = openid_resp.openid
nickname = user_info_resp.user_data.nickname
avatar_url = user_info_resp.user_data.figureurl_qq_2 or user_info_resp.user_data.figureurl_2
email = None
else:
http_exceptions.raise_bad_request(f"不支持的 OAuth 提供者: {provider.value}")
# 查找已有 AuthIdentity
identity: AuthIdentity | None = await AuthIdentity.get(
session,
(AuthIdentity.provider == provider) & (AuthIdentity.identifier == openid),
)
if identity:
# 已绑定 → 更新 OAuth 信息并返回关联用户
identity.display_name = nickname
identity.avatar_url = avatar_url
await identity.save(session)
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
if not user:
http_exceptions.raise_unauthorized("用户不存在")
if user.status != UserStatus.ACTIVE:
http_exceptions.raise_forbidden("账户已被禁用")
return user
# 未绑定 → 自动注册
user = await _auto_register_oauth_user(
session,
provider=provider,
openid=openid,
nickname=nickname,
avatar_url=avatar_url,
email=email,
)
return user
async def _auto_register_oauth_user(
session: AsyncSession,
*,
provider: AuthProviderType,
openid: str,
nickname: str | None,
avatar_url: str | None,
email: str | None,
) -> User:
"""OAuth 自动注册用户"""
# 获取默认用户组
default_group_setting = await Setting.get(
session,
(Setting.type == SettingsType.REGISTER) & (Setting.name == "default_group"),
)
if not default_group_setting or not default_group_setting.value:
l.error("默认用户组未配置")
http_exceptions.raise_internal_error()
default_group_id = UUID(default_group_setting.value)
# 创建用户
new_user = User(
email=email,
nickname=nickname,
avatar=avatar_url or "default",
group_id=default_group_id,
)
new_user_id = new_user.id
new_user = await new_user.save(session)
# 创建 AuthIdentity
identity = AuthIdentity(
provider=provider,
identifier=openid,
display_name=nickname,
avatar_url=avatar_url,
is_primary=True,
is_verified=True,
user_id=new_user_id,
)
await identity.save(session)
# 创建用户根目录
default_policy = await Policy.get(session, Policy.name == "本地存储")
if default_policy:
await Object(
name="/",
type=ObjectType.FOLDER,
owner_id=new_user_id,
parent_id=None,
policy_id=default_policy.id,
).save(session)
# 重新加载用户(含 group 关系)
user: User = await User.get(session, User.id == new_user_id, load=User.group)
l.info(f"OAuth 自动注册用户: provider={provider.value}, openid={openid}")
return user
async def _login_passkey(
session: AsyncSession,
request: UnifiedLoginRequest,
) -> User:
"""
Passkey/WebAuthn 登录
identifier 为 credential_idcredential 为 JSON 格式的 authenticator assertion response。
"""
from webauthn import verify_authentication_response
from webauthn.helpers.structs import AuthenticationCredential
if not request.credential:
http_exceptions.raise_bad_request("WebAuthn assertion response 不能为空")
# 查找 AuthIdentity
identity: AuthIdentity | None = await AuthIdentity.get(
session,
(AuthIdentity.provider == AuthProviderType.PASSKEY)
& (AuthIdentity.identifier == request.identifier),
)
if not identity:
http_exceptions.raise_unauthorized("Passkey 凭证未注册")
# 加载对应的 UserAuthn 记录
from sqlmodels.user_authn import UserAuthn
authn: UserAuthn | None = await UserAuthn.get(
session,
UserAuthn.credential_id == request.identifier,
)
if not authn:
http_exceptions.raise_unauthorized("Passkey 凭证数据不存在")
# 获取 RP ID
site_url_setting = await Setting.get(
session,
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"),
)
rp_id = site_url_setting.value if site_url_setting else "localhost"
# 验证 WebAuthn assertion
import orjson
credential = AuthenticationCredential.model_validate(orjson.loads(request.credential))
try:
verification = verify_authentication_response(
credential=credential,
expected_rp_id=rp_id,
expected_origin=f"https://{rp_id}",
expected_challenge=b"", # TODO: 从 session/cache 中获取 challenge
credential_public_key=bytes.fromhex(authn.credential_public_key),
credential_current_sign_count=authn.sign_count,
)
except Exception as e:
l.warning(f"WebAuthn 验证失败: {e}")
http_exceptions.raise_unauthorized("Passkey 验证失败")
# 更新签名计数
authn.sign_count = verification.new_sign_count
await authn.save(session)
# 加载用户
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
if not user:
http_exceptions.raise_unauthorized("用户不存在")
if user.status != UserStatus.ACTIVE:
http_exceptions.raise_forbidden("账户已被禁用")
return user
async def _login_magic_link(
session: AsyncSession,
request: UnifiedLoginRequest,
) -> User:
"""
Magic Link 登录
identifier 为签名 token由 itsdangerous 生成。
"""
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
try:
email = serializer.loads(request.identifier, salt="magic-link-salt", max_age=600)
except SignatureExpired:
http_exceptions.raise_unauthorized("Magic Link 已过期")
except BadSignature:
http_exceptions.raise_unauthorized("Magic Link 无效")
# 查找绑定了该邮箱的 AuthIdentityemail_password 或 magic_link
identity: AuthIdentity | None = await AuthIdentity.get(
session,
(AuthIdentity.identifier == email)
& (
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
| (AuthIdentity.provider == AuthProviderType.MAGIC_LINK)
),
)
if not identity:
http_exceptions.raise_unauthorized("该邮箱未注册")
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
if not user:
http_exceptions.raise_unauthorized("用户不存在")
if user.status != UserStatus.ACTIVE:
http_exceptions.raise_forbidden("账户已被禁用")
# 标记邮箱已验证
if not identity.is_verified:
identity.is_verified = True
await identity.save(session)
return user
async def _issue_tokens(session: AsyncSession, user: User) -> TokenResponse:
"""
签发 JWT 双令牌access + refresh
提取自原 login.py 的签发逻辑,供所有 provider 共用。
"""
# 加载 GroupOptions # 加载 GroupOptions
group_options: GroupOptions | None = await GroupOptions.get( group_options: GroupOptions | None = await GroupOptions.get(
session, session,
GroupOptions.group_id == current_user.group_id, GroupOptions.group_id == user.group_id,
) )
# 构建权限快照 # 构建权限快照
current_user.group.options = group_options user.group.options = group_options
group_claims = GroupClaims.from_group(current_user.group) group_claims = GroupClaims.from_group(user.group)
# 创建令牌 # 创建令牌
access_token = create_access_token( access_token = JWT.create_access_token(
sub=current_user.id, sub=user.id,
jti=uuid4(), jti=uuid4(),
status=current_user.status.value, status=user.status.value,
group=group_claims, group=group_claims,
) )
refresh_token = create_refresh_token( refresh_token = JWT.create_refresh_token(
sub=current_user.id, sub=user.id,
jti=uuid4() jti=uuid4(),
) )
return TokenResponse( return TokenResponse(

View File

@@ -1,9 +1,16 @@
from .auth_identity import (
AuthIdentity,
AuthIdentityResponse,
AuthProviderType,
BindIdentityRequest,
)
from .user import ( from .user import (
BatchDeleteRequest, BatchDeleteRequest,
JWTPayload, JWTPayload,
LoginRequest, MagicLinkRequest,
UnifiedLoginRequest,
UnifiedRegisterRequest,
RefreshTokenRequest, RefreshTokenRequest,
RegisterRequest,
AccessTokenBase, AccessTokenBase,
RefreshTokenBase, RefreshTokenBase,
TokenResponse, TokenResponse,
@@ -89,7 +96,7 @@ from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, Policy
from .redeem import Redeem, RedeemType from .redeem import Redeem, RedeemType
from .report import Report, ReportReason from .report import Report, ReportReason
from .setting import ( from .setting import (
Setting, SettingsType, SiteConfigResponse, Setting, SettingsType, SiteConfigResponse, AuthMethodConfig,
# 管理员DTO # 管理员DTO
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse, SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
) )

139
sqlmodels/auth_identity.py Normal file
View File

@@ -0,0 +1,139 @@
"""
认证身份模块
一个用户可拥有多种登录方式邮箱密码、OAuth、Passkey、Magic Link 等)。
AuthIdentity 表存储每种认证方式的凭证信息。
"""
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint
from .base import SQLModelBase
from .mixin import UUIDTableBaseMixin
if TYPE_CHECKING:
from .user import User
class AuthProviderType(StrEnum):
"""认证提供者类型"""
EMAIL_PASSWORD = "email_password"
"""邮箱+密码"""
PHONE_SMS = "phone_sms"
"""手机号+短信验证码(预留)"""
GITHUB = "github"
"""GitHub OAuth"""
QQ = "qq"
"""QQ OAuth"""
PASSKEY = "passkey"
"""Passkey/WebAuthn"""
MAGIC_LINK = "magic_link"
"""邮箱 Magic Link"""
# ==================== DTO 模型 ====================
class AuthIdentityResponse(SQLModelBase):
"""认证身份响应 DTO列表展示用"""
id: UUID
"""身份UUID"""
provider: AuthProviderType
"""提供者类型"""
identifier: str
"""标识符(邮箱/手机号/OAuth openid"""
display_name: str | None = None
"""显示名称OAuth 昵称等)"""
avatar_url: str | None = None
"""头像 URL"""
is_primary: bool = False
"""是否主要身份"""
is_verified: bool = False
"""是否已验证"""
class BindIdentityRequest(SQLModelBase):
"""绑定认证身份请求 DTO"""
provider: AuthProviderType
"""提供者类型"""
identifier: str
"""标识符(邮箱/手机号/OAuth code"""
credential: str | None = None
"""凭证(密码、验证码等)"""
redirect_uri: str | None = None
"""OAuth 回调地址"""
# ==================== 数据库模型 ====================
class AuthIdentity(SQLModelBase, UUIDTableBaseMixin):
"""用户认证身份 — 一个用户可以有多种登录方式"""
__table_args__ = (
UniqueConstraint("provider", "identifier", name="uq_auth_identity_provider_identifier"),
)
provider: AuthProviderType = Field(index=True)
"""提供者类型"""
identifier: str = Field(max_length=255, index=True)
"""标识符(邮箱/手机号/OAuth openid"""
credential: str | None = Field(default=None, max_length=1024)
"""凭证Argon2 哈希密码 / null"""
display_name: str | None = Field(default=None, max_length=100)
"""OAuth 昵称"""
avatar_url: str | None = Field(default=None, max_length=512)
"""OAuth 头像 URL"""
extra_data: str | None = None
"""JSON 附加数据2FA secret、OAuth refresh_token 等)"""
is_primary: bool = False
"""是否主要身份"""
is_verified: bool = False
"""是否已验证"""
# 外键
user_id: UUID = Field(
foreign_key="user.id",
index=True,
ondelete="CASCADE",
)
"""所属用户UUID"""
# 关系
user: "User" = Relationship(back_populates="auth_identities")
def to_response(self) -> AuthIdentityResponse:
"""转换为响应 DTO"""
return AuthIdentityResponse(
id=self.id,
provider=self.provider,
identifier=self.identifier,
display_name=self.display_name,
avatar_url=self.avatar_url,
is_primary=self.is_primary,
is_verified=self.is_verified,
)

File diff suppressed because one or more lines are too long

View File

@@ -2,6 +2,7 @@ from enum import StrEnum
from sqlmodel import UniqueConstraint from sqlmodel import UniqueConstraint
from .auth_identity import AuthProviderType
from .base import SQLModelBase from .base import SQLModelBase
from .mixin import TableBaseMixin from .mixin import TableBaseMixin
from .user import UserResponse from .user import UserResponse
@@ -12,6 +13,19 @@ class CaptchaType(StrEnum):
GCAPTCHA = "gcaptcha" GCAPTCHA = "gcaptcha"
CLOUD_FLARE_TURNSTILE = "cloudflare turnstile" CLOUD_FLARE_TURNSTILE = "cloudflare turnstile"
# ==================== Auth 配置 DTO ====================
class AuthMethodConfig(SQLModelBase):
"""认证方式配置 DTO"""
provider: AuthProviderType
"""提供者类型"""
is_enabled: bool
"""是否启用"""
# ==================== DTO 模型 ==================== # ==================== DTO 模型 ====================
class SiteConfigResponse(SQLModelBase): class SiteConfigResponse(SQLModelBase):
@@ -50,6 +64,27 @@ class SiteConfigResponse(SQLModelBase):
captcha_key: str | None = None captcha_key: str | None = None
"""验证码 public keyDEFAULT 类型时为 None""" """验证码 public keyDEFAULT 类型时为 None"""
auth_methods: list[AuthMethodConfig] = []
"""可用的登录方式列表"""
password_required: bool = True
"""注册时是否必须设置密码"""
phone_binding_required: bool = False
"""是否强制绑定手机号"""
email_binding_required: bool = True
"""是否强制绑定邮箱"""
footer_code: str | None = None
"""自定义页脚代码"""
tos_url: str | None = None
"""服务条款 URL"""
privacy_url: str | None = None
"""隐私政策 URL"""
# ==================== 管理员设置 DTO ==================== # ==================== 管理员设置 DTO ====================

View File

@@ -9,6 +9,7 @@ from sqlmodel import Field, Relationship
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import RelationshipInfo from sqlmodel.main import RelationshipInfo
from .auth_identity import AuthProviderType
from .base import SQLModelBase from .base import SQLModelBase
from .color import ChromaticColor, NeutralColor, ThemeColorsBase from .color import ChromaticColor, NeutralColor, ThemeColorsBase
from .model_base import ResponseBase from .model_base import ResponseBase
@@ -17,6 +18,7 @@ from .mixin import UUIDTableBaseMixin, TableViewRequest, ListResponse
T = TypeVar("T", bound="User") T = TypeVar("T", bound="User")
if TYPE_CHECKING: if TYPE_CHECKING:
from .auth_identity import AuthIdentity
from .group import Group from .group import Group
from .download import Download from .download import Download
from .object import Object from .object import Object
@@ -69,8 +71,8 @@ class UserFilterParams(SQLModelBase):
class UserBase(SQLModelBase): class UserBase(SQLModelBase):
"""用户基础字段,供数据库模型和 DTO 共享""" """用户基础字段,供数据库模型和 DTO 共享"""
email: str email: str | None = None
"""用户邮箱""" """用户邮箱(社交登录用户可能没有邮箱)"""
status: UserStatus = UserStatus.ACTIVE status: UserStatus = UserStatus.ACTIVE
"""用户状态""" """用户状态"""
@@ -81,30 +83,42 @@ class UserBase(SQLModelBase):
# ==================== DTO 模型 ==================== # ==================== DTO 模型 ====================
class LoginRequest(SQLModelBase): class UnifiedLoginRequest(SQLModelBase):
"""登录请求 DTO""" """统一登录请求 DTO"""
email: str provider: AuthProviderType
"""用户邮箱""" """登录方式"""
password: str identifier: str
"""用户密码""" """标识符(邮箱 / OAuth code / Magic Link token"""
captcha: str | None = None credential: str | None = None
"""验证码""" """凭证密码provider=email_password 时必填)"""
two_fa_code: str | None = Field(default=None, min_length=6, max_length=6) two_fa_code: str | None = Field(default=None, min_length=6, max_length=6)
"""两步验证代码""" """两步验证代码"""
redirect_uri: str | None = None
"""OAuth 回调地址"""
class RegisterRequest(SQLModelBase): captcha: str | None = None
"""注册请求 DTO""" """验证码"""
email: str
"""用户邮箱,唯一"""
password: str class UnifiedRegisterRequest(SQLModelBase):
"""用户密码""" """统一注册请求 DTO"""
provider: AuthProviderType
"""注册方式email_password / phone_sms"""
identifier: str
"""标识符(邮箱 / 手机号)"""
credential: str | None = None
"""凭证(密码 / 短信验证码)"""
nickname: str | None = Field(default=None, max_length=50)
"""昵称"""
captcha: str | None = None captcha: str | None = None
"""验证码""" """验证码"""
@@ -190,7 +204,7 @@ class UserResponse(ResponseBase):
id: UUID id: UUID
"""用户UUID""" """用户UUID"""
email: str email: str | None = None
"""用户邮箱""" """用户邮箱"""
nickname: str | None = None nickname: str | None = None
@@ -248,9 +262,6 @@ class UserPublic(UserBase):
group_name: str | None = None group_name: str | None = None
"""用户组名称""" """用户组名称"""
two_factor: str | None = None
"""两步验证密钥32位字符串null 表示未启用)"""
created_at: datetime | None = None created_at: datetime | None = None
"""创建时间""" """创建时间"""
@@ -264,9 +275,12 @@ class UserSettingResponse(SQLModelBase):
id: UUID id: UUID
"""用户UUID""" """用户UUID"""
email: str email: str | None = None
"""用户邮箱""" """用户邮箱"""
phone: str | None = None
"""手机号"""
nickname: str | None = None nickname: str | None = None
"""昵称""" """昵称"""
@@ -341,16 +355,26 @@ class UserTwoFactorResponse(SQLModelBase):
"""两步验证密钥""" """两步验证密钥"""
class MagicLinkRequest(SQLModelBase):
"""Magic Link 请求 DTO"""
email: str
"""接收 Magic Link 的邮箱"""
captcha: str | None = None
"""验证码"""
# ==================== 管理员用户管理 DTO ==================== # ==================== 管理员用户管理 DTO ====================
class UserAdminCreateRequest(SQLModelBase): class UserAdminCreateRequest(SQLModelBase):
"""管理员创建用户请求 DTO""" """管理员创建用户请求 DTO"""
email: str = Field(max_length=50) email: str | None = Field(default=None, max_length=50)
"""用户邮箱""" """用户邮箱"""
password: str password: str | None = None
"""用户密码(明文,由服务端加密)""" """用户密码(明文,由服务端加密;为空则不创建邮箱密码身份"""
nickname: str | None = Field(default=None, max_length=50) nickname: str | None = Field(default=None, max_length=50)
"""昵称""" """昵称"""
@@ -365,14 +389,14 @@ class UserAdminCreateRequest(SQLModelBase):
class UserAdminUpdateRequest(SQLModelBase): class UserAdminUpdateRequest(SQLModelBase):
"""管理员更新用户请求 DTO""" """管理员更新用户请求 DTO"""
email: str = Field(max_length=50) email: str | None = Field(default=None, max_length=50)
"""邮箱""" """邮箱"""
nickname: str | None = Field(default=None, max_length=50) nickname: str | None = Field(default=None, max_length=50)
"""昵称""" """昵称"""
password: str | None = None phone: str | None = None
"""新密码(为空则不修改)""" """手机号"""
group_id: UUID | None = None group_id: UUID | None = None
"""用户组UUID""" """用户组UUID"""
@@ -389,9 +413,6 @@ class UserAdminUpdateRequest(SQLModelBase):
group_expires: datetime | None = None group_expires: datetime | None = None
"""用户组过期时间""" """用户组过期时间"""
two_factor: str | None = None
"""两步验证密钥32位字符串传 null 可清除,不传则不修改)"""
class UserCalibrateResponse(SQLModelBase): class UserCalibrateResponse(SQLModelBase):
"""用户存储校准响应 DTO""" """用户存储校准响应 DTO"""
@@ -415,9 +436,6 @@ class UserCalibrateResponse(SQLModelBase):
class UserAdminDetailResponse(UserPublic): class UserAdminDetailResponse(UserPublic):
"""管理员用户详情响应 DTO""" """管理员用户详情响应 DTO"""
two_factor_enabled: bool = False
"""是否启用两步验证"""
file_count: int = 0 file_count: int = 0
"""文件数量""" """文件数量"""
@@ -443,14 +461,14 @@ UserSettingResponse.model_rebuild()
class User(UserBase, UUIDTableBaseMixin): class User(UserBase, UUIDTableBaseMixin):
"""用户模型""" """用户模型"""
email: str = Field(max_length=50, unique=True, index=True) email: str | None = Field(default=None, max_length=50, unique=True, index=True)
"""用户邮箱,唯一""" """用户邮箱(社交登录用户可能没有邮箱)"""
nickname: str | None = Field(default=None, max_length=50) nickname: str | None = Field(default=None, max_length=50)
"""用于公开展示的名字,可使用真实姓名或昵称""" """用于公开展示的名字,可使用真实姓名或昵称"""
password: str = Field(max_length=255) phone: str | None = Field(default=None, max_length=20, unique=True, index=True)
"""用户密码(加密后""" """手机号(预留"""
status: UserStatus = UserStatus.ACTIVE status: UserStatus = UserStatus.ACTIVE
"""用户状态""" """用户状态"""
@@ -458,9 +476,6 @@ class User(UserBase, UUIDTableBaseMixin):
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0) storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
"""已用存储空间(字节)""" """已用存储空间(字节)"""
two_factor: str | None = Field(default=None, min_length=32, max_length=32)
"""两步验证密钥"""
avatar: str = Field(default="default", max_length=255) avatar: str = Field(default="default", max_length=255)
"""头像地址""" """头像地址"""
@@ -533,6 +548,12 @@ class User(UserBase, UUIDTableBaseMixin):
} }
) )
auth_identities: list["AuthIdentity"] = Relationship(
back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
"""用户的认证身份列表"""
downloads: list["Download"] = Relationship( downloads: list["Download"] = Relationship(
back_populates="user", back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"} sa_relationship_kwargs={"cascade": "all, delete-orphan"}
@@ -634,4 +655,3 @@ class User(UserBase, UUIDTableBaseMixin):
filter=filter, filter=filter,
table_view=table_view, table_view=table_view,
) )

View File

@@ -24,6 +24,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')
from main import app from main import app
from sqlmodels.database import get_session from sqlmodels.database import get_session
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.group import Group, GroupClaims, GroupOptions from sqlmodels.group import Group, GroupClaims, GroupOptions
from sqlmodels.migration import migration from sqlmodels.migration import migration
from sqlmodels.object import Object, ObjectType from sqlmodels.object import Object, ObjectType
@@ -192,7 +193,6 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
user = User( user = User(
email="testuser@test.local", email="testuser@test.local",
nickname="测试用户", nickname="测试用户",
password=Password.hash(password),
status=UserStatus.ACTIVE, status=UserStatus.ACTIVE,
storage=0, storage=0,
score=100, score=100,
@@ -200,6 +200,17 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
) )
user = await user.save(db_session) user = await user.save(db_session)
# 创建邮箱密码认证身份
identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier="testuser@test.local",
credential=Password.hash(password),
is_primary=True,
is_verified=True,
user_id=user.id,
)
await identity.save(db_session)
# 创建用户根目录 # 创建用户根目录
root_folder = Object( root_folder = Object(
name="/", name="/",
@@ -279,7 +290,6 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
admin = User( admin = User(
email="admin@disknext.local", email="admin@disknext.local",
nickname="管理员", nickname="管理员",
password=Password.hash(password),
status=UserStatus.ACTIVE, status=UserStatus.ACTIVE,
storage=0, storage=0,
score=9999, score=9999,
@@ -287,6 +297,17 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
) )
admin = await admin.save(db_session) admin = await admin.save(db_session)
# 创建管理员邮箱密码认证身份
admin_identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier="admin@disknext.local",
credential=Password.hash(password),
is_primary=True,
is_verified=True,
user_id=admin.id,
)
await admin_identity.save(db_session)
# 创建管理员根目录 # 创建管理员根目录
root_folder = Object( root_folder = Object(
name="/", name="/",

View File

@@ -2,12 +2,14 @@
用户测试数据工厂 用户测试数据工厂
提供创建测试用户的便捷方法。 提供创建测试用户的便捷方法。
用户密码凭证通过 AuthIdentity 管理,不再存储在 User 表中。
""" """
from uuid import UUID from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.user import User from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.user import User, UserStatus
from utils.password.pwd import Password from utils.password.pwd import Password
@@ -20,7 +22,7 @@ class UserFactory:
group_id: UUID, group_id: UUID,
email: str | None = None, email: str | None = None,
password: str | None = None, password: str | None = None,
**kwargs **kwargs,
) -> User: ) -> User:
""" """
创建普通用户 创建普通用户
@@ -29,7 +31,7 @@ class UserFactory:
session: 数据库会话 session: 数据库会话
group_id: 用户组UUID group_id: 用户组UUID
email: 用户邮箱(默认: test_user_{随机}@test.local email: 用户邮箱(默认: test_user_{随机}@test.local
password: 明文密码(默认: password123 password: 明文密码(默认: password123,若提供则同时创建 AuthIdentity
**kwargs: 其他用户字段 **kwargs: 其他用户字段
返回: 返回:
@@ -46,12 +48,10 @@ class UserFactory:
user = User( user = User(
email=email, email=email,
nickname=kwargs.get("nickname", email), nickname=kwargs.get("nickname", email),
password=Password.hash(password), status=kwargs.get("status", UserStatus.ACTIVE),
status=kwargs.get("status", True),
storage=kwargs.get("storage", 0), storage=kwargs.get("storage", 0),
score=kwargs.get("score", 100), score=kwargs.get("score", 100),
group_id=group_id, group_id=group_id,
two_factor=kwargs.get("two_factor"),
avatar=kwargs.get("avatar", "default"), avatar=kwargs.get("avatar", "default"),
group_expires=kwargs.get("group_expires"), group_expires=kwargs.get("group_expires"),
theme=kwargs.get("theme", "system"), theme=kwargs.get("theme", "system"),
@@ -61,6 +61,18 @@ class UserFactory:
) )
user = await user.save(session) user = await user.save(session)
# 创建邮箱密码认证身份
identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier=email,
credential=Password.hash(password),
is_primary=True,
is_verified=True,
user_id=user.id,
)
await identity.save(session)
return user return user
@staticmethod @staticmethod
@@ -68,7 +80,7 @@ class UserFactory:
session: AsyncSession, session: AsyncSession,
admin_group_id: UUID, admin_group_id: UUID,
email: str | None = None, email: str | None = None,
password: str | None = None password: str | None = None,
) -> User: ) -> User:
""" """
创建管理员用户 创建管理员用户
@@ -93,8 +105,7 @@ class UserFactory:
admin = User( admin = User(
email=email, email=email,
nickname=f"管理员 {email}", nickname=f"管理员 {email}",
password=Password.hash(password), status=UserStatus.ACTIVE,
status=True,
storage=0, storage=0,
score=9999, score=9999,
group_id=admin_group_id, group_id=admin_group_id,
@@ -102,13 +113,25 @@ class UserFactory:
) )
admin = await admin.save(session) admin = await admin.save(session)
# 创建邮箱密码认证身份
identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier=email,
credential=Password.hash(password),
is_primary=True,
is_verified=True,
user_id=admin.id,
)
await identity.save(session)
return admin return admin
@staticmethod @staticmethod
async def create_banned( async def create_banned(
session: AsyncSession, session: AsyncSession,
group_id: UUID, group_id: UUID,
email: str | None = None email: str | None = None,
) -> User: ) -> User:
""" """
创建被封禁用户 创建被封禁用户
@@ -129,8 +152,7 @@ class UserFactory:
banned_user = User( banned_user = User(
email=email, email=email,
nickname=f"封禁用户 {email}", nickname=f"封禁用户 {email}",
password=Password.hash("banned_password"), status=UserStatus.ADMIN_BANNED,
status=False, # 封禁状态
storage=0, storage=0,
score=0, score=0,
group_id=group_id, group_id=group_id,
@@ -138,6 +160,18 @@ class UserFactory:
) )
banned_user = await banned_user.save(session) banned_user = await banned_user.save(session)
# 创建邮箱密码认证身份
identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier=email,
credential=Password.hash("banned_password"),
is_primary=True,
is_verified=True,
user_id=banned_user.id,
)
await identity.save(session)
return banned_user return banned_user
@staticmethod @staticmethod
@@ -145,7 +179,7 @@ class UserFactory:
session: AsyncSession, session: AsyncSession,
group_id: UUID, group_id: UUID,
storage_bytes: int, storage_bytes: int,
email: str | None = None email: str | None = None,
) -> User: ) -> User:
""" """
创建已使用指定存储空间的用户 创建已使用指定存储空间的用户
@@ -167,8 +201,7 @@ class UserFactory:
user = User( user = User(
email=email, email=email,
nickname=email, nickname=email,
password=Password.hash("password123"), status=UserStatus.ACTIVE,
status=True,
storage=storage_bytes, storage=storage_bytes,
score=100, score=100,
group_id=group_id, group_id=group_id,
@@ -176,4 +209,16 @@ class UserFactory:
) )
user = await user.save(session) user = await user.save(session)
# 创建邮箱密码认证身份
identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier=email,
credential=Password.hash("password123"),
is_primary=True,
is_verified=True,
user_id=user.id,
)
await identity.save(session)
return user return user

View File

@@ -83,6 +83,24 @@ async def test_site_config_captcha_settings(async_client: AsyncClient):
assert "forgetCaptcha" in config assert "forgetCaptcha" in config
@pytest.mark.asyncio
async def test_site_config_auth_methods(async_client: AsyncClient):
"""测试配置包含认证方式列表"""
response = await async_client.get("/api/site/config")
assert response.status_code == 200
data = response.json()
config = data["data"]
assert "authMethods" in config
assert isinstance(config["authMethods"], list)
assert len(config["authMethods"]) > 0
# 每个认证方式应包含 provider 和 isEnabled
for method in config["authMethods"]:
assert "provider" in method
assert "isEnabled" in method
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_site_captcha_endpoint_exists(async_client: AsyncClient): async def test_site_captcha_endpoint_exists(async_client: AsyncClient):
"""测试验证码端点存在(即使未实现也应返回有效响应)""" """测试验证码端点存在(即使未实现也应返回有效响应)"""

View File

@@ -15,9 +15,10 @@ async def test_user_login_success(
"""测试成功登录""" """测试成功登录"""
response = await async_client.post( response = await async_client.post(
"/api/user/session", "/api/user/session",
data={ json={
"username": test_user_info["email"], "provider": "email_password",
"password": test_user_info["password"], "identifier": test_user_info["email"],
"credential": test_user_info["password"],
} }
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -37,9 +38,10 @@ async def test_user_login_wrong_password(
"""测试密码错误返回 401""" """测试密码错误返回 401"""
response = await async_client.post( response = await async_client.post(
"/api/user/session", "/api/user/session",
data={ json={
"username": test_user_info["email"], "provider": "email_password",
"password": "wrongpassword", "identifier": test_user_info["email"],
"credential": "wrongpassword",
} }
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -50,9 +52,10 @@ async def test_user_login_nonexistent_user(async_client: AsyncClient):
"""测试不存在的用户返回 401""" """测试不存在的用户返回 401"""
response = await async_client.post( response = await async_client.post(
"/api/user/session", "/api/user/session",
data={ json={
"username": "nonexistent@test.local", "provider": "email_password",
"password": "anypassword", "identifier": "nonexistent@test.local",
"credential": "anypassword",
} }
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -66,9 +69,10 @@ async def test_user_login_user_banned(
"""测试封禁用户返回 403""" """测试封禁用户返回 403"""
response = await async_client.post( response = await async_client.post(
"/api/user/session", "/api/user/session",
data={ json={
"username": banned_user_info["email"], "provider": "email_password",
"password": banned_user_info["password"], "identifier": banned_user_info["email"],
"credential": banned_user_info["password"],
} }
) )
assert response.status_code == 403 assert response.status_code == 403
@@ -82,8 +86,9 @@ async def test_user_register_success(async_client: AsyncClient):
response = await async_client.post( response = await async_client.post(
"/api/user/", "/api/user/",
json={ json={
"email": "newuser@test.local", "provider": "email_password",
"password": "newpass123", "identifier": "newuser@test.local",
"credential": "newpass123",
} }
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -104,8 +109,9 @@ async def test_user_register_duplicate_email(
response = await async_client.post( response = await async_client.post(
"/api/user/", "/api/user/",
json={ json={
"email": test_user_info["email"], "provider": "email_password",
"password": "anypassword", "identifier": test_user_info["email"],
"credential": "anypassword",
} }
) )
assert response.status_code == 400 assert response.status_code == 400

View File

@@ -23,6 +23,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../.
from main import app from main import app
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.user import UserStatus from sqlmodels.user import UserStatus
from utils import Password from utils import Password
from utils.JWT import create_access_token from utils.JWT import create_access_token
@@ -98,6 +99,15 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
Setting(type=SettingsType.CAPTCHA, name="captcha_CloudflareKey", value=""), Setting(type=SettingsType.CAPTCHA, name="captcha_CloudflareKey", value=""),
Setting(type=SettingsType.REGISTER, name="register_enabled", value="1"), Setting(type=SettingsType.REGISTER, name="register_enabled", value="1"),
Setting(type=SettingsType.AUTH, name="secret_key", value="test_secret_key_for_jwt_token_generation"), Setting(type=SettingsType.AUTH, name="secret_key", value="test_secret_key_for_jwt_token_generation"),
Setting(type=SettingsType.AUTH, name="auth_email_password_enabled", value="1"),
Setting(type=SettingsType.AUTH, name="auth_phone_sms_enabled", value="0"),
Setting(type=SettingsType.AUTH, name="auth_passkey_enabled", value="0"),
Setting(type=SettingsType.AUTH, name="auth_magic_link_enabled", value="0"),
Setting(type=SettingsType.AUTH, name="auth_password_required", value="1"),
Setting(type=SettingsType.AUTH, name="auth_phone_binding_required", value="0"),
Setting(type=SettingsType.AUTH, name="auth_email_binding_required", value="1"),
Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"),
Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"),
] ]
for setting in settings: for setting in settings:
test_session.add(setting) test_session.add(setting)
@@ -183,7 +193,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
test_user = User( test_user = User(
id=uuid4(), id=uuid4(),
email="testuser@test.local", email="testuser@test.local",
password=Password.hash("testpass123"),
nickname="测试用户", nickname="测试用户",
status=UserStatus.ACTIVE, status=UserStatus.ACTIVE,
storage=0, storage=0,
@@ -196,7 +205,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
admin_user = User( admin_user = User(
id=uuid4(), id=uuid4(),
email="admin@disknext.local", email="admin@disknext.local",
password=Password.hash("adminpass123"),
nickname="管理员", nickname="管理员",
status=UserStatus.ACTIVE, status=UserStatus.ACTIVE,
storage=0, storage=0,
@@ -209,7 +217,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
banned_user = User( banned_user = User(
id=uuid4(), id=uuid4(),
email="banneduser@test.local", email="banneduser@test.local",
password=Password.hash("banned123"),
nickname="封禁用户", nickname="封禁用户",
status=UserStatus.ADMIN_BANNED, status=UserStatus.ADMIN_BANNED,
storage=0, storage=0,
@@ -226,7 +233,40 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
await test_session.refresh(admin_user) await test_session.refresh(admin_user)
await test_session.refresh(banned_user) await test_session.refresh(banned_user)
# 7. 创建用户根目录 # 7. 创建认证身份
test_user_identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier="testuser@test.local",
credential=Password.hash("testpass123"),
is_primary=True,
is_verified=True,
user_id=test_user.id,
)
test_session.add(test_user_identity)
admin_user_identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier="admin@disknext.local",
credential=Password.hash("adminpass123"),
is_primary=True,
is_verified=True,
user_id=admin_user.id,
)
test_session.add(admin_user_identity)
banned_user_identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier="banneduser@test.local",
credential=Password.hash("banned123"),
is_primary=True,
is_verified=True,
user_id=banned_user.id,
)
test_session.add(banned_user_identity)
await test_session.commit()
# 8. 创建用户根目录
test_user_root = Object( test_user_root = Object(
id=uuid4(), id=uuid4(),
name="/", name="/",
@@ -251,7 +291,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
await test_session.commit() await test_session.commit()
# 8. 设置JWT密钥从数据库加载 # 9. 设置JWT密钥从数据库加载
JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation" JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation"
# 刷新 group options # 刷新 group options

View File

@@ -18,7 +18,6 @@ async def test_user_curd():
test_user = User( test_user = User(
email='test_user@test.local', email='test_user@test.local',
password='test_password',
group_id=created_group.id group_id=created_group.id
) )
@@ -28,7 +27,6 @@ async def test_user_curd():
# 验证用户是否存在 # 验证用户是否存在
assert created_user.id is not None assert created_user.id is not None
assert created_user.email == 'test_user@test.local' assert created_user.email == 'test_user@test.local'
assert created_user.password == 'test_password'
assert created_user.group_id == created_group.id assert created_user.group_id == created_group.id
# 测试查 Read # 测试查 Read
@@ -36,18 +34,16 @@ async def test_user_curd():
assert fetched_user is not None assert fetched_user is not None
assert fetched_user.email == 'test_user@test.local' assert fetched_user.email == 'test_user@test.local'
assert fetched_user.password == 'test_password'
assert fetched_user.group_id == created_group.id assert fetched_user.group_id == created_group.id
# 测试改 Update # 测试改 Update
updated_user = await fetched_user.update( updated_user = await fetched_user.update(
session, session,
{"email": "updated_user@test.local", "password": "updated_password"} {"email": "updated_user@test.local"}
) )
assert updated_user is not None assert updated_user is not None
assert updated_user.email == 'updated_user@test.local' assert updated_user.email == 'updated_user@test.local'
assert updated_user.password == 'updated_password'
# 测试删除 Delete # 测试删除 Delete
await updated_user.delete(session) await updated_user.delete(session)

View File

@@ -19,7 +19,7 @@ async def test_object_create_folder(db_session: AsyncSession):
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(email="testuser", password="password", group_id=group.id) user = User(email="testuser", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy( policy = Policy(
@@ -53,7 +53,7 @@ async def test_object_create_file(db_session: AsyncSession):
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(email="testuser", password="password", group_id=group.id) user = User(email="testuser", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy( policy = Policy(
@@ -98,7 +98,7 @@ async def test_object_is_file_property(db_session: AsyncSession):
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(email="testuser", password="password", group_id=group.id) user = User(email="testuser", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -125,7 +125,7 @@ async def test_object_is_folder_property(db_session: AsyncSession):
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(email="testuser", password="password", group_id=group.id) user = User(email="testuser", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -151,7 +151,7 @@ async def test_object_get_root(db_session: AsyncSession):
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(email="rootuser", password="password", group_id=group.id) user = User(email="rootuser", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -183,7 +183,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(email="pathuser", password="password", group_id=group.id) user = User(email="pathuser", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -214,7 +214,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(email="nesteduser", password="password", group_id=group.id) user = User(email="nesteduser", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -277,7 +277,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(email="notfounduser", password="password", group_id=group.id) user = User(email="notfounduser", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -311,7 +311,7 @@ async def test_object_get_children(db_session: AsyncSession):
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(email="childrenuser", password="password", group_id=group.id) user = User(email="childrenuser", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -363,7 +363,7 @@ async def test_object_parent_child_relationship(db_session: AsyncSession):
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(email="reluser", password="password", group_id=group.id) user = User(email="reluser", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -408,7 +408,7 @@ async def test_object_unique_constraint(db_session: AsyncSession):
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(email="uniqueuser", password="password", group_id=group.id) user = User(email="uniqueuser", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -456,7 +456,7 @@ async def test_object_get_full_path(db_session: AsyncSession):
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(email="pathuser", password="password", group_id=group.id) user = User(email="pathuser", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")

View File

@@ -20,7 +20,6 @@ async def test_user_create(db_session: AsyncSession):
user = User( user = User(
email="testuser@test.local", email="testuser@test.local",
nickname="测试用户", nickname="测试用户",
password="hashed_password",
group_id=group.id group_id=group.id
) )
user = await user.save(db_session) user = await user.save(db_session)
@@ -43,7 +42,6 @@ async def test_user_unique_email(db_session: AsyncSession):
# 创建第一个用户 # 创建第一个用户
user1 = User( user1 = User(
email="duplicate@test.local", email="duplicate@test.local",
password="password1",
group_id=group.id group_id=group.id
) )
await user1.save(db_session) await user1.save(db_session)
@@ -51,7 +49,6 @@ async def test_user_unique_email(db_session: AsyncSession):
# 尝试创建同名用户 # 尝试创建同名用户
user2 = User( user2 = User(
email="duplicate@test.local", email="duplicate@test.local",
password="password2",
group_id=group.id group_id=group.id
) )
@@ -70,7 +67,6 @@ async def test_user_to_public(db_session: AsyncSession):
user = User( user = User(
email="publicuser@test.local", email="publicuser@test.local",
nickname="公开用户", nickname="公开用户",
password="secret_password",
storage=1024, storage=1024,
avatar="avatar.jpg", avatar="avatar.jpg",
group_id=group.id group_id=group.id
@@ -88,8 +84,6 @@ async def test_user_to_public(db_session: AsyncSession):
# 这是已知的设计问题,需要在 UserPublic 中添加别名或重命名字段 # 这是已知的设计问题,需要在 UserPublic 中添加别名或重命名字段
assert public_user.nick is None # 实际行为 assert public_user.nick is None # 实际行为
assert public_user.storage == 1024 assert public_user.storage == 1024
# 密码不应该在公开数据中
assert not hasattr(public_user, 'password')
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -102,7 +96,6 @@ async def test_user_group_relationship(db_session: AsyncSession):
# 创建用户 # 创建用户
user = User( user = User(
email="vipuser@test.local", email="vipuser@test.local",
password="password",
group_id=group.id group_id=group.id
) )
user = await user.save(db_session) user = await user.save(db_session)
@@ -126,7 +119,6 @@ async def test_user_status_default(db_session: AsyncSession):
user = User( user = User(
email="defaultuser@test.local", email="defaultuser@test.local",
password="password",
group_id=group.id group_id=group.id
) )
user = await user.save(db_session) user = await user.save(db_session)
@@ -142,7 +134,6 @@ async def test_user_storage_default(db_session: AsyncSession):
user = User( user = User(
email="storageuser@test.local", email="storageuser@test.local",
password="password",
group_id=group.id group_id=group.id
) )
user = await user.save(db_session) user = await user.save(db_session)
@@ -159,7 +150,6 @@ async def test_user_theme_enum(db_session: AsyncSession):
# 测试默认值 # 测试默认值
user1 = User( user1 = User(
email="user1@test.local", email="user1@test.local",
password="password",
group_id=group.id group_id=group.id
) )
user1 = await user1.save(db_session) user1 = await user1.save(db_session)
@@ -168,7 +158,6 @@ async def test_user_theme_enum(db_session: AsyncSession):
# 测试设置为 LIGHT # 测试设置为 LIGHT
user2 = User( user2 = User(
email="user2@test.local", email="user2@test.local",
password="password",
theme=ThemeType.LIGHT, theme=ThemeType.LIGHT,
group_id=group.id group_id=group.id
) )
@@ -178,9 +167,40 @@ async def test_user_theme_enum(db_session: AsyncSession):
# 测试设置为 DARK # 测试设置为 DARK
user3 = User( user3 = User(
email="user3@test.local", email="user3@test.local",
password="password",
theme=ThemeType.DARK, theme=ThemeType.DARK,
group_id=group.id group_id=group.id
) )
user3 = await user3.save(db_session) user3 = await user3.save(db_session)
assert user3.theme == ThemeType.DARK assert user3.theme == ThemeType.DARK
@pytest.mark.asyncio
async def test_user_email_optional(db_session: AsyncSession):
"""测试 email 可以为空(支持社交登录用户)"""
group = Group(name="默认组")
group = await group.save(db_session)
user = User(
nickname="社交用户",
group_id=group.id
)
user = await user.save(db_session)
assert user.id is not None
assert user.email is None
@pytest.mark.asyncio
async def test_user_phone_field(db_session: AsyncSession):
"""测试 phone 字段"""
group = Group(name="默认组")
group = await group.save(db_session)
user = User(
email="phoneuser@test.local",
phone="13800138000",
group_id=group.id
)
user = await user.save(db_session)
assert user.phone == "13800138000"

View File

@@ -1,78 +1,154 @@
""" """
Login 服务的单元测试 Login 服务的单元测试
测试 unified_login() 各 provider 路径。
""" """
import pytest import pytest
from fastapi import HTTPException
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.user import User, LoginRequest, TokenResponse, UserStatus from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.group import Group from sqlmodels.setting import Setting, SettingsType
from service.user.login import login from sqlmodels.user import User, UnifiedLoginRequest, TokenResponse, UserStatus
from sqlmodels.group import Group, GroupOptions
from service.user.login import unified_login
from utils.password.pwd import Password from utils.password.pwd import Password
@pytest.fixture @pytest.fixture
async def setup_user(db_session: AsyncSession): async def setup_auth_settings(db_session: AsyncSession):
"""创建测试用户""" """创建认证相关的 Setting 配置"""
settings = [
Setting(type=SettingsType.AUTH, name="auth_email_password_enabled", value="1"),
Setting(type=SettingsType.AUTH, name="auth_phone_sms_enabled", value="0"),
Setting(type=SettingsType.AUTH, name="auth_passkey_enabled", value="0"),
Setting(type=SettingsType.AUTH, name="auth_magic_link_enabled", value="0"),
Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"),
Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"),
]
for s in settings:
await s.save(db_session)
@pytest.fixture
async def setup_user(db_session: AsyncSession, setup_auth_settings):
"""创建测试用户和邮箱密码认证身份"""
# 创建用户组 # 创建用户组
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
# 创建用户组选项
group_options = GroupOptions(
group_id=group.id,
share_download=True,
share_free=False,
relocate=False,
)
await group_options.save(db_session)
# 创建正常用户 # 创建正常用户
plain_password = "secure_password_123" plain_password = "secure_password_123"
user = User( user = User(
email="loginuser@test.local", email="loginuser@test.local",
password=Password.hash(plain_password),
status=UserStatus.ACTIVE, status=UserStatus.ACTIVE,
group_id=group.id group_id=group.id,
) )
user = await user.save(db_session) user = await user.save(db_session)
# 创建邮箱密码认证身份
identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier="loginuser@test.local",
credential=Password.hash(plain_password),
is_primary=True,
is_verified=True,
user_id=user.id,
)
await identity.save(db_session)
return { return {
"user": user, "user": user,
"password": plain_password, "password": plain_password,
"group_id": group.id "group_id": group.id,
} }
@pytest.fixture @pytest.fixture
async def setup_banned_user(db_session: AsyncSession): async def setup_banned_user(db_session: AsyncSession, setup_auth_settings):
"""创建被封禁的用户""" """创建被封禁的用户"""
group = Group(name="测试组2") group = Group(name="测试组2")
group = await group.save(db_session) group = await group.save(db_session)
group_options = GroupOptions(
group_id=group.id,
share_download=True,
share_free=False,
relocate=False,
)
await group_options.save(db_session)
user = User( user = User(
email="banneduser@test.local", email="banneduser@test.local",
password=Password.hash("password"), status=UserStatus.ADMIN_BANNED,
status=UserStatus.ADMIN_BANNED, # 封禁状态 group_id=group.id,
group_id=group.id
) )
user = await user.save(db_session) user = await user.save(db_session)
identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier="banneduser@test.local",
credential=Password.hash("password"),
is_primary=True,
is_verified=True,
user_id=user.id,
)
await identity.save(db_session)
return user return user
@pytest.fixture @pytest.fixture
async def setup_2fa_user(db_session: AsyncSession): async def setup_2fa_user(db_session: AsyncSession, setup_auth_settings):
"""创建启用了两步验证的用户""" """创建启用了两步验证的用户"""
import pyotp import pyotp
group = Group(name="测试组3") group = Group(name="测试组3")
group = await group.save(db_session) group = await group.save(db_session)
group_options = GroupOptions(
group_id=group.id,
share_download=True,
share_free=False,
relocate=False,
)
await group_options.save(db_session)
secret = pyotp.random_base32() secret = pyotp.random_base32()
user = User( user = User(
email="2fauser@test.local", email="2fauser@test.local",
password=Password.hash("password"),
status=UserStatus.ACTIVE, status=UserStatus.ACTIVE,
two_factor=secret, group_id=group.id,
group_id=group.id
) )
user = await user.save(db_session) user = await user.save(db_session)
# 创建带 2FA secret 的邮箱密码认证身份
import orjson
extra_data = orjson.dumps({"two_factor": secret}).decode('utf-8')
identity = AuthIdentity(
provider=AuthProviderType.EMAIL_PASSWORD,
identifier="2fauser@test.local",
credential=Password.hash("password"),
extra_data=extra_data,
is_primary=True,
is_verified=True,
user_id=user.id,
)
await identity.save(db_session)
return { return {
"user": user, "user": user,
"secret": secret, "secret": secret,
"password": "password" "password": "password",
} }
@@ -81,12 +157,13 @@ async def test_login_success(db_session: AsyncSession, setup_user):
"""测试正常登录""" """测试正常登录"""
user_data = setup_user user_data = setup_user
login_request = LoginRequest( request = UnifiedLoginRequest(
email="loginuser@test.local", provider=AuthProviderType.EMAIL_PASSWORD,
password=user_data["password"] identifier="loginuser@test.local",
credential=user_data["password"],
) )
result = await login(db_session, login_request) result = await unified_login(db_session, request)
assert isinstance(result, TokenResponse) assert isinstance(result, TokenResponse)
assert result.access_token is not None assert result.access_token is not None
@@ -96,42 +173,48 @@ async def test_login_success(db_session: AsyncSession, setup_user):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_user_not_found(db_session: AsyncSession): async def test_login_user_not_found(db_session: AsyncSession, setup_user):
"""测试用户不存在""" """测试用户不存在"""
login_request = LoginRequest( request = UnifiedLoginRequest(
email="nonexistent@test.local", provider=AuthProviderType.EMAIL_PASSWORD,
password="any_password" identifier="nonexistent@test.local",
credential="any_password",
) )
result = await login(db_session, login_request) with pytest.raises(HTTPException) as exc_info:
await unified_login(db_session, request)
assert result is None assert exc_info.value.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_wrong_password(db_session: AsyncSession, setup_user): async def test_login_wrong_password(db_session: AsyncSession, setup_user):
"""测试密码错误""" """测试密码错误"""
login_request = LoginRequest( request = UnifiedLoginRequest(
email="loginuser@test.local", provider=AuthProviderType.EMAIL_PASSWORD,
password="wrong_password" identifier="loginuser@test.local",
credential="wrong_password",
) )
result = await login(db_session, login_request) with pytest.raises(HTTPException) as exc_info:
await unified_login(db_session, request)
assert result is None assert exc_info.value.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_user_banned(db_session: AsyncSession, setup_banned_user): async def test_login_user_banned(db_session: AsyncSession, setup_banned_user):
"""测试用户被封禁""" """测试用户被封禁"""
login_request = LoginRequest( request = UnifiedLoginRequest(
email="banneduser@test.local", provider=AuthProviderType.EMAIL_PASSWORD,
password="password" identifier="banneduser@test.local",
credential="password",
) )
result = await login(db_session, login_request) with pytest.raises(HTTPException) as exc_info:
await unified_login(db_session, request)
assert result is False assert exc_info.value.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -139,15 +222,17 @@ async def test_login_2fa_required(db_session: AsyncSession, setup_2fa_user):
"""测试需要 2FA""" """测试需要 2FA"""
user_data = setup_2fa_user user_data = setup_2fa_user
login_request = LoginRequest( request = UnifiedLoginRequest(
email="2fauser@test.local", provider=AuthProviderType.EMAIL_PASSWORD,
password=user_data["password"] identifier="2fauser@test.local",
credential=user_data["password"],
# 未提供 two_fa_code # 未提供 two_fa_code
) )
result = await login(db_session, login_request) with pytest.raises(HTTPException) as exc_info:
await unified_login(db_session, request)
assert result == "2fa_required" assert exc_info.value.status_code == 428
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -155,15 +240,17 @@ async def test_login_2fa_invalid(db_session: AsyncSession, setup_2fa_user):
"""测试 2FA 错误""" """测试 2FA 错误"""
user_data = setup_2fa_user user_data = setup_2fa_user
login_request = LoginRequest( request = UnifiedLoginRequest(
email="2fauser@test.local", provider=AuthProviderType.EMAIL_PASSWORD,
password=user_data["password"], identifier="2fauser@test.local",
two_fa_code="000000" # 错误的验证码 credential=user_data["password"],
two_fa_code="000000",
) )
result = await login(db_session, login_request) with pytest.raises(HTTPException) as exc_info:
await unified_login(db_session, request)
assert result == "2fa_invalid" assert exc_info.value.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -178,56 +265,44 @@ async def test_login_2fa_success(db_session: AsyncSession, setup_2fa_user):
totp = pyotp.TOTP(secret) totp = pyotp.TOTP(secret)
valid_code = totp.now() valid_code = totp.now()
login_request = LoginRequest( request = UnifiedLoginRequest(
email="2fauser@test.local", provider=AuthProviderType.EMAIL_PASSWORD,
password=user_data["password"], identifier="2fauser@test.local",
two_fa_code=valid_code credential=user_data["password"],
two_fa_code=valid_code,
) )
result = await login(db_session, login_request) result = await unified_login(db_session, request)
assert isinstance(result, TokenResponse) assert isinstance(result, TokenResponse)
assert result.access_token is not None assert result.access_token is not None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user): async def test_login_provider_disabled(db_session: AsyncSession, setup_user):
"""测试返回的令牌可以被解码""" """测试未启用的 provider"""
import jwt as pyjwt request = UnifiedLoginRequest(
provider=AuthProviderType.PHONE_SMS,
user_data = setup_user identifier="13800138000",
credential="123456",
login_request = LoginRequest(
email="loginuser@test.local",
password=user_data["password"]
) )
result = await login(db_session, login_request) with pytest.raises(HTTPException) as exc_info:
await unified_login(db_session, request)
assert isinstance(result, TokenResponse) assert exc_info.value.status_code == 400
# 注意: 实际项目中需要使用正确的 SECRET_KEY
# 这里假设测试环境已经设置了 SECRET_KEY
# decoded = pyjwt.decode(
# result.access_token,
# SECRET_KEY,
# algorithms=["HS256"]
# )
# assert decoded["sub"] == "loginuser"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_case_sensitive_email(db_session: AsyncSession, setup_user): async def test_login_missing_password(db_session: AsyncSession, setup_user):
"""测试邮箱大小写敏感""" """测试邮箱密码登录缺少密码"""
user_data = setup_user request = UnifiedLoginRequest(
provider=AuthProviderType.EMAIL_PASSWORD,
# 使用大写邮箱登录 identifier="loginuser@test.local",
login_request = LoginRequest( # 未提供 credential
email="LOGINUSER@TEST.LOCAL",
password=user_data["password"]
) )
result = await login(db_session, login_request) with pytest.raises(HTTPException) as exc_info:
await unified_login(db_session, request)
# 应该失败,因为邮箱大小写不匹配 assert exc_info.value.status_code == 400
assert result is None