feat: add multi-provider auth via AuthIdentity and extend site config
- Extract AuthIdentity model for multi-provider authentication (email_password, OAuth, Passkey, Magic Link) - Remove password field from User model, credentials now stored in AuthIdentity - Refactor unified login/register to use AuthIdentity-based provider checking - Add site config fields: footer_code, tos_url, privacy_url, auth_methods - Add auth settings defaults in migration (email_password enabled by default) - Update admin user creation to create AuthIdentity records - Update all tests to use AuthIdentity model Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -12,6 +12,7 @@ from sqlmodels import (
|
||||
Group, Object, ObjectType, Setting, SettingsType,
|
||||
BatchDeleteRequest,
|
||||
)
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.user import (
|
||||
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, UserStatus,
|
||||
)
|
||||
@@ -83,13 +84,26 @@ async def router_admin_create_user(
|
||||
"""
|
||||
创建一个新的用户,设置邮箱、密码、用户组等信息。
|
||||
|
||||
管理员创建用户时,若提供了 email + password,
|
||||
会同时创建 AuthIdentity(provider=email_password)。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param request: 创建用户请求 DTO
|
||||
:return: 创建结果
|
||||
"""
|
||||
existing_user = await User.get(session, User.email == request.email)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
# 如果提供了邮箱,检查唯一性(User 表和 AuthIdentity 表)
|
||||
if request.email:
|
||||
existing_user = await User.get(session, User.email == request.email)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
|
||||
existing_identity = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||
& (AuthIdentity.identifier == request.email),
|
||||
)
|
||||
if existing_identity:
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被绑定")
|
||||
|
||||
# 验证用户组存在
|
||||
group = await Group.get(session, Group.id == request.group_id)
|
||||
@@ -98,12 +112,24 @@ async def router_admin_create_user(
|
||||
|
||||
user = User(
|
||||
email=request.email,
|
||||
password=Password.hash(request.password),
|
||||
nickname=request.nickname,
|
||||
group_id=request.group_id,
|
||||
status=request.status,
|
||||
)
|
||||
user = await user.save(session)
|
||||
|
||||
# 如果提供了邮箱和密码,创建邮箱密码认证身份
|
||||
if request.email and request.password:
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=request.email,
|
||||
credential=Password.hash(request.password),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
return user.to_public()
|
||||
|
||||
|
||||
@@ -148,17 +174,7 @@ async def router_admin_update_user(
|
||||
if not group:
|
||||
raise HTTPException(status_code=400, detail="目标用户组不存在")
|
||||
|
||||
# 如果更新密码,需要加密
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
if 'password' in update_data and update_data['password']:
|
||||
update_data['password'] = Password.hash(update_data['password'])
|
||||
elif 'password' in update_data:
|
||||
del update_data['password'] # 空密码不更新
|
||||
|
||||
# 验证两步验证密钥格式(如果提供了值且不为 None,长度必须为 32)
|
||||
if 'two_factor' in update_data and update_data['two_factor'] is not None:
|
||||
if len(update_data['two_factor']) != 32:
|
||||
raise HTTPException(status_code=400, detail="两步验证密钥必须为32位字符串")
|
||||
|
||||
# 记录旧 status 以便检测变更
|
||||
old_status = user.status
|
||||
@@ -175,7 +191,7 @@ async def router_admin_update_user(
|
||||
elif old_status != UserStatus.ACTIVE and new_status == UserStatus.ACTIVE:
|
||||
await UserBanStore.unban(str(user_id))
|
||||
|
||||
l.info(f"管理员更新了用户: {request.email}")
|
||||
l.info(f"管理员更新了用户: {user.email}")
|
||||
|
||||
|
||||
@admin_user_router.delete(
|
||||
|
||||
@@ -4,7 +4,9 @@ from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
ResponseBase, Setting, SettingsType, SiteConfigResponse,
|
||||
ThemePreset, ThemePresetResponse, ThemePresetListResponse,
|
||||
AuthMethodConfig,
|
||||
)
|
||||
from sqlmodels.auth_identity import AuthProviderType
|
||||
from sqlmodels.setting import CaptchaType
|
||||
from utils import http_exceptions
|
||||
|
||||
@@ -70,7 +72,7 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
获取站点全局配置
|
||||
|
||||
无需认证。前端在初始化时调用此端点获取验证码类型、
|
||||
登录/注册/找回密码是否需要验证码等配置。
|
||||
登录/注册/找回密码是否需要验证码、可用的认证方式等配置。
|
||||
"""
|
||||
# 批量查询所需设置
|
||||
settings: list[Setting] = await Setting.get(
|
||||
@@ -78,7 +80,9 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
(Setting.type == SettingsType.BASIC) |
|
||||
(Setting.type == SettingsType.LOGIN) |
|
||||
(Setting.type == SettingsType.REGISTER) |
|
||||
(Setting.type == SettingsType.CAPTCHA),
|
||||
(Setting.type == SettingsType.CAPTCHA) |
|
||||
(Setting.type == SettingsType.AUTH) |
|
||||
(Setting.type == SettingsType.OAUTH),
|
||||
fetch_mode="all",
|
||||
)
|
||||
|
||||
@@ -94,6 +98,16 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
|
||||
captcha_key = s.get("captcha_CloudflareKey") or None
|
||||
|
||||
# 构建认证方式列表
|
||||
auth_methods: list[AuthMethodConfig] = [
|
||||
AuthMethodConfig(provider=AuthProviderType.EMAIL_PASSWORD, is_enabled=s.get("auth_email_password_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.PHONE_SMS, is_enabled=s.get("auth_phone_sms_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.GITHUB, is_enabled=s.get("github_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.QQ, is_enabled=s.get("qq_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.PASSKEY, is_enabled=s.get("auth_passkey_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.MAGIC_LINK, is_enabled=s.get("auth_magic_link_enabled") == "1"),
|
||||
]
|
||||
|
||||
return SiteConfigResponse(
|
||||
title=s.get("siteName") or "DiskNext",
|
||||
register_enabled=s.get("register_enabled") == "1",
|
||||
@@ -102,4 +116,11 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
forget_captcha=s.get("forget_captcha") == "1",
|
||||
captcha_type=captcha_type,
|
||||
captcha_key=captcha_key,
|
||||
auth_methods=auth_methods,
|
||||
password_required=s.get("auth_password_required") == "1",
|
||||
phone_binding_required=s.get("auth_phone_binding_required") == "1",
|
||||
email_binding_required=s.get("auth_email_binding_required") == "1",
|
||||
footer_code=s.get("footer_code"),
|
||||
tos_url=s.get("tos_url"),
|
||||
privacy_url=s.get("privacy_url"),
|
||||
)
|
||||
@@ -2,7 +2,8 @@ from typing import Annotated, Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from itsdangerous import URLSafeTimedSerializer
|
||||
from loguru import logger
|
||||
from webauthn import generate_registration_options
|
||||
from webauthn.helpers import options_to_json_dict
|
||||
@@ -12,6 +13,7 @@ import sqlmodels
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep, require_captcha
|
||||
from service.captcha import CaptchaScene
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import JWT, Password, http_exceptions
|
||||
from .settings import user_settings_router
|
||||
@@ -23,59 +25,36 @@ user_router = APIRouter(
|
||||
|
||||
user_router.include_router(user_settings_router)
|
||||
|
||||
class OAuth2PasswordWithExtrasForm:
|
||||
"""
|
||||
扩展 OAuth2 密码表单。
|
||||
|
||||
在标准 username/password 基础上添加 otp_code 字段。
|
||||
captcha_code 由 require_captcha 依赖注入单独处理。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
username: Annotated[str, Form()],
|
||||
password: Annotated[str, Form()],
|
||||
otp_code: Annotated[str | None, Form(min_length=6, max_length=6)] = None,
|
||||
):
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.otp_code = otp_code
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/session',
|
||||
summary='用户登录',
|
||||
description='用户登录端点,支持验证码校验和两步验证。',
|
||||
dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))],
|
||||
summary='用户登录(统一入口)',
|
||||
description='统一登录端点,支持多种认证方式。',
|
||||
)
|
||||
async def router_user_session(
|
||||
session: SessionDep,
|
||||
form_data: Annotated[OAuth2PasswordWithExtrasForm, Depends()],
|
||||
request: sqlmodels.UnifiedLoginRequest,
|
||||
) -> sqlmodels.TokenResponse:
|
||||
"""
|
||||
用户登录端点
|
||||
统一登录端点
|
||||
|
||||
表单字段:
|
||||
- username: 用户邮箱
|
||||
- password: 用户密码
|
||||
- captcha_code: 验证码 token(可选,由 require_captcha 依赖校验)
|
||||
- otp_code: 两步验证码(可选,仅在用户启用 2FA 时需要)
|
||||
请求体:
|
||||
- provider: 登录方式(email_password / github / qq / passkey / magic_link)
|
||||
- identifier: 标识符(邮箱 / OAuth code / credential_id / magic link token)
|
||||
- credential: 凭证(密码 / WebAuthn assertion 等)
|
||||
- two_fa_code: 两步验证码(可选)
|
||||
- redirect_uri: OAuth 回调地址(可选)
|
||||
- captcha: 验证码(可选)
|
||||
|
||||
错误处理:
|
||||
- 400: 需要验证码但未提供
|
||||
- 401: 邮箱/密码错误,或 2FA 验证码错误
|
||||
- 403: 账户已禁用 / 验证码验证失败
|
||||
- 428: 需要两步验证但未提供 otp_code
|
||||
- 400: 登录方式未启用 / 参数错误
|
||||
- 401: 凭证错误
|
||||
- 403: 账户已禁用
|
||||
- 428: 需要两步验证
|
||||
- 501: 暂未实现的登录方式
|
||||
"""
|
||||
return await service.user.login(
|
||||
session,
|
||||
sqlmodels.LoginRequest(
|
||||
email=form_data.username,
|
||||
password=form_data.password,
|
||||
two_fa_code=form_data.otp_code,
|
||||
),
|
||||
)
|
||||
return await service.user.unified_login(session, request)
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/session/refresh',
|
||||
@@ -150,41 +129,82 @@ async def router_user_session_refresh(
|
||||
|
||||
@user_router.post(
|
||||
path='/',
|
||||
summary='用户注册',
|
||||
summary='用户注册(统一入口)',
|
||||
description='User registration endpoint.',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_user_register(
|
||||
session: SessionDep,
|
||||
request: sqlmodels.RegisterRequest,
|
||||
request: sqlmodels.UnifiedRegisterRequest,
|
||||
) -> None:
|
||||
"""
|
||||
用户注册端点
|
||||
统一注册端点
|
||||
|
||||
流程:
|
||||
1. 验证用户名唯一性
|
||||
2. 获取默认用户组
|
||||
3. 创建用户记录
|
||||
4. 创建用户根目录(name="/")
|
||||
1. 检查注册开关
|
||||
2. 检查 provider 启用
|
||||
3. 验证 identifier 唯一性(AuthIdentity 表)
|
||||
4. 创建 User + AuthIdentity + 根目录
|
||||
|
||||
:param session: 数据库会话
|
||||
:param request: 注册请求
|
||||
:return: 注册结果
|
||||
:raises HTTPException 400: 用户名已存在
|
||||
:raises HTTPException 500: 默认用户组或存储策略不存在
|
||||
请求体:
|
||||
- provider: 注册方式(email_password / phone_sms)
|
||||
- identifier: 标识符(邮箱 / 手机号)
|
||||
- credential: 凭证(密码 / 短信验证码)
|
||||
- nickname: 昵称(可选)
|
||||
- captcha: 验证码(可选)
|
||||
|
||||
错误处理:
|
||||
- 400: 注册未开放 / 参数错误
|
||||
- 409: 邮箱或手机号已存在
|
||||
- 501: 暂未实现的注册方式
|
||||
"""
|
||||
# 1. 验证邮箱唯一性
|
||||
# 1. 检查注册开关
|
||||
register_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER)
|
||||
& (sqlmodels.Setting.name == "register_enabled"),
|
||||
)
|
||||
if not register_setting or register_setting.value != "1":
|
||||
http_exceptions.raise_bad_request("注册功能未开放")
|
||||
|
||||
# 2. 目前只支持 email_password 注册
|
||||
if request.provider == AuthProviderType.PHONE_SMS:
|
||||
http_exceptions.raise_not_implemented("短信注册暂未开放")
|
||||
elif request.provider != AuthProviderType.EMAIL_PASSWORD:
|
||||
http_exceptions.raise_bad_request("不支持的注册方式")
|
||||
|
||||
# 3. 检查密码是否必填
|
||||
password_required_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||
& (sqlmodels.Setting.name == "auth_password_required"),
|
||||
)
|
||||
is_password_required = not password_required_setting or password_required_setting.value != "0"
|
||||
if is_password_required and not request.credential:
|
||||
http_exceptions.raise_bad_request("密码不能为空")
|
||||
|
||||
# 4. 验证 identifier 唯一性(AuthIdentity 表)
|
||||
existing_identity = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == request.provider)
|
||||
& (AuthIdentity.identifier == request.identifier),
|
||||
)
|
||||
if existing_identity:
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
|
||||
# 同时检查 User.email 唯一性(防止旧数据冲突)
|
||||
existing_user = await sqlmodels.User.get(
|
||||
session,
|
||||
sqlmodels.User.email == request.email
|
||||
sqlmodels.User.email == request.identifier,
|
||||
)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=400, detail="邮箱已存在")
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
|
||||
# 2. 获取默认用户组(从设置中读取 UUID)
|
||||
default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get(
|
||||
# 5. 获取默认用户组
|
||||
default_group_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group")
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER)
|
||||
& (sqlmodels.Setting.name == "default_group"),
|
||||
)
|
||||
if default_group_setting is None or not default_group_setting.value:
|
||||
logger.error("默认用户组不存在")
|
||||
@@ -196,17 +216,28 @@ async def router_user_register(
|
||||
logger.error("默认用户组不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
# 3. 创建用户
|
||||
hashed_password = Password.hash(request.password)
|
||||
# 6. 创建用户
|
||||
new_user = sqlmodels.User(
|
||||
email=request.email,
|
||||
password=hashed_password,
|
||||
email=request.identifier,
|
||||
nickname=request.nickname,
|
||||
group_id=default_group.id,
|
||||
)
|
||||
new_user_id = new_user.id
|
||||
await new_user.save(session)
|
||||
|
||||
# 4. 创建用户根目录
|
||||
# 7. 创建 AuthIdentity
|
||||
hashed_password = Password.hash(request.credential) if request.credential else None
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=request.identifier,
|
||||
credential=hashed_password,
|
||||
is_primary=True,
|
||||
is_verified=False,
|
||||
user_id=new_user_id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
# 8. 创建用户根目录
|
||||
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
|
||||
if not default_policy:
|
||||
logger.error("默认存储策略不存在")
|
||||
@@ -220,6 +251,66 @@ async def router_user_register(
|
||||
policy_id=default_policy.id,
|
||||
).save(session)
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/magic-link',
|
||||
summary='发送 Magic Link 邮件',
|
||||
description='生成 Magic Link token 并发送到指定邮箱。',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_user_magic_link(
|
||||
session: SessionDep,
|
||||
request: sqlmodels.MagicLinkRequest,
|
||||
) -> None:
|
||||
"""
|
||||
发送 Magic Link 邮件
|
||||
|
||||
流程:
|
||||
1. 验证邮箱对应的 AuthIdentity 存在
|
||||
2. 生成签名 token
|
||||
3. 发送邮件(包含带 token 的链接)
|
||||
|
||||
错误处理:
|
||||
- 400: Magic Link 未启用
|
||||
- 404: 邮箱未注册
|
||||
"""
|
||||
# 检查 magic_link 是否启用
|
||||
magic_link_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||
& (sqlmodels.Setting.name == "auth_magic_link_enabled"),
|
||||
)
|
||||
if not magic_link_setting or magic_link_setting.value != "1":
|
||||
http_exceptions.raise_bad_request("Magic Link 登录未启用")
|
||||
|
||||
# 验证邮箱存在
|
||||
identity = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.identifier == request.email)
|
||||
& (
|
||||
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||
| (AuthIdentity.provider == AuthProviderType.MAGIC_LINK)
|
||||
),
|
||||
)
|
||||
if not identity:
|
||||
http_exceptions.raise_not_found("该邮箱未注册")
|
||||
|
||||
# 生成签名 token
|
||||
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||
token = serializer.dumps(request.email, salt="magic-link-salt")
|
||||
|
||||
# 获取站点 URL
|
||||
site_url_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.BASIC)
|
||||
& (sqlmodels.Setting.name == "siteURL"),
|
||||
)
|
||||
site_url = site_url_setting.value if site_url_setting else "http://localhost"
|
||||
|
||||
# TODO: 发送邮件(包含 {site_url}/auth/magic-link?token={token})
|
||||
logger.info(f"Magic Link token 已生成: {token} (邮件发送待实现)")
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/code',
|
||||
summary='发送验证码邮件',
|
||||
@@ -230,52 +321,12 @@ def router_user_email_code(
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Send a verification code email.
|
||||
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing information about the password reset email.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.get(
|
||||
path='/qq',
|
||||
summary='初始化QQ登录',
|
||||
description='Initialize QQ login for a user.',
|
||||
)
|
||||
def router_user_qq() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Initialize QQ login for a user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing QQ login initialization information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.get(
|
||||
path='authn/{username}',
|
||||
summary='WebAuthn登录初始化',
|
||||
description='Initialize WebAuthn login for a user.',
|
||||
)
|
||||
async def router_user_authn(username: str) -> sqlmodels.ResponseBase:
|
||||
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.post(
|
||||
path='authn/finish/{username}',
|
||||
summary='WebAuthn登录',
|
||||
description='Finish WebAuthn login for a user.',
|
||||
)
|
||||
def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Finish WebAuthn login for a user.
|
||||
|
||||
Args:
|
||||
username (str): The username of the user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn login information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.get(
|
||||
path='/profile/{id}',
|
||||
summary='获取用户主页展示用分享',
|
||||
@@ -284,10 +335,10 @@ def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
|
||||
def router_user_profile(id: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user profile for display.
|
||||
|
||||
|
||||
Args:
|
||||
id (str): The user ID.
|
||||
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing user profile information.
|
||||
"""
|
||||
@@ -301,11 +352,11 @@ def router_user_profile(id: str) -> sqlmodels.ResponseBase:
|
||||
def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user avatar by ID and size.
|
||||
|
||||
|
||||
Args:
|
||||
id (str): The user ID.
|
||||
size (int): The size of the avatar image.
|
||||
|
||||
|
||||
Returns:
|
||||
str: A Base64 encoded string of the user avatar image.
|
||||
"""
|
||||
@@ -348,8 +399,6 @@ async def router_user_me(
|
||||
return sqlmodels.UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
status=user.status,
|
||||
score=user.score,
|
||||
nickname=user.nickname,
|
||||
avatar=user.avatar,
|
||||
created_at=user.created_at,
|
||||
@@ -374,9 +423,9 @@ async def router_user_storage(
|
||||
group = await sqlmodels.Group.get(session, sqlmodels.Group.id == user.group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
|
||||
|
||||
# [TODO] 总空间加上用户购买的额外空间
|
||||
|
||||
|
||||
total: int = group.max_storage
|
||||
used: int = user.storage
|
||||
free: int = max(0, total - used)
|
||||
@@ -389,8 +438,8 @@ async def router_user_storage(
|
||||
|
||||
@user_router.put(
|
||||
path='/authn/start',
|
||||
summary='WebAuthn登录初始化',
|
||||
description='Initialize WebAuthn login for a user.',
|
||||
summary='注册 Passkey 凭证(初始化)',
|
||||
description='Initialize Passkey registration for a user.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_authn_start(
|
||||
@@ -398,18 +447,19 @@ async def router_user_authn_start(
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Initialize WebAuthn login for a user.
|
||||
Passkey 注册初始化(需要登录)
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn initialization information.
|
||||
返回 WebAuthn registration options,前端使用 navigator.credentials.create() 处理。
|
||||
|
||||
错误处理:
|
||||
- 400: Passkey 未启用
|
||||
"""
|
||||
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
|
||||
authn_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled")
|
||||
)
|
||||
if not authn_setting or authn_setting.value != "1":
|
||||
raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
|
||||
raise HTTPException(status_code=400, detail="Passkey 未启用")
|
||||
|
||||
site_url_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
@@ -423,23 +473,26 @@ async def router_user_authn_start(
|
||||
options = generate_registration_options(
|
||||
rp_id=site_url_setting.value if site_url_setting else "",
|
||||
rp_name=site_title_setting.value if site_title_setting else "",
|
||||
user_name=user.email,
|
||||
user_display_name=user.nickname or user.email,
|
||||
user_name=user.email or str(user.id),
|
||||
user_display_name=user.nickname or user.email or str(user.id),
|
||||
)
|
||||
|
||||
return sqlmodels.ResponseBase(data=options_to_json_dict(options))
|
||||
|
||||
@user_router.put(
|
||||
path='/authn/finish',
|
||||
summary='WebAuthn登录',
|
||||
description='Finish WebAuthn login for a user.',
|
||||
summary='注册 Passkey 凭证(完成)',
|
||||
description='Finish Passkey registration for a user.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_authn_finish() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Finish WebAuthn login for a user.
|
||||
|
||||
Passkey 注册完成(需要登录)
|
||||
|
||||
接收前端 navigator.credentials.create() 返回的凭证数据,
|
||||
创建 UserAuthn 行 + AuthIdentity(provider=passkey)。
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn login information.
|
||||
dict: A dictionary containing Passkey registration information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
@@ -9,6 +10,7 @@ from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
BUILTIN_DEFAULT_COLORS, ThemePreset, UserThemeUpdateRequest,
|
||||
SettingOption, UserSettingUpdateRequest,
|
||||
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
|
||||
)
|
||||
from sqlmodels.color import ThemeColorsBase
|
||||
from utils import JWT, Password, http_exceptions
|
||||
@@ -117,16 +119,29 @@ async def router_user_settings(
|
||||
else:
|
||||
theme_colors = BUILTIN_DEFAULT_COLORS
|
||||
|
||||
# 检查是否启用了两步验证(从 email_password AuthIdentity 的 extra_data 中读取)
|
||||
has_two_factor = False
|
||||
email_identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.user_id == user.id)
|
||||
& (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
|
||||
)
|
||||
if email_identity and email_identity.extra_data:
|
||||
import orjson
|
||||
extra: dict = orjson.loads(email_identity.extra_data)
|
||||
has_two_factor = bool(extra.get("two_factor"))
|
||||
|
||||
return sqlmodels.UserSettingResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
phone=user.phone,
|
||||
nickname=user.nickname,
|
||||
created_at=user.created_at,
|
||||
group_name=user.group.name,
|
||||
language=user.language,
|
||||
timezone=user.timezone,
|
||||
group_expires=user.group_expires,
|
||||
two_factor=user.two_factor is not None,
|
||||
two_factor=has_two_factor,
|
||||
theme_preset_id=user.theme_preset_id,
|
||||
theme_colors=theme_colors,
|
||||
)
|
||||
@@ -255,7 +270,7 @@ async def router_user_settings_2fa(
|
||||
|
||||
返回 setup_token(用于后续验证请求)和 uri(用于生成二维码)。
|
||||
"""
|
||||
return await Password.generate_totp(name=user.email)
|
||||
return await Password.generate_totp(name=user.email or str(user.id))
|
||||
|
||||
|
||||
@user_settings_router.post(
|
||||
@@ -273,7 +288,7 @@ async def router_user_settings_2fa_enable(
|
||||
"""
|
||||
启用两步验证
|
||||
|
||||
请求体包含 setup_token(GET /2fa 返回的令牌)和 code(6 位验证码)。
|
||||
将 2FA secret 存储到 email_password AuthIdentity 的 extra_data 中。
|
||||
"""
|
||||
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||
|
||||
@@ -287,6 +302,150 @@ async def router_user_settings_2fa_enable(
|
||||
if Password.verify_totp(secret, request.code) != PasswordStatus.VALID:
|
||||
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
||||
|
||||
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
|
||||
user.two_factor = secret
|
||||
user = await user.save(session)
|
||||
# 将 secret 存储到 AuthIdentity.extra_data 中
|
||||
email_identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.user_id == user.id)
|
||||
& (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
|
||||
)
|
||||
if not email_identity:
|
||||
raise HTTPException(status_code=400, detail="未找到邮箱密码认证身份")
|
||||
|
||||
import orjson
|
||||
extra: dict = orjson.loads(email_identity.extra_data) if email_identity.extra_data else {}
|
||||
extra["two_factor"] = secret
|
||||
email_identity.extra_data = orjson.dumps(extra).decode('utf-8')
|
||||
await email_identity.save(session)
|
||||
|
||||
|
||||
# ==================== 认证身份管理 ====================
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/identities',
|
||||
summary='列出已绑定的认证身份',
|
||||
)
|
||||
async def router_user_settings_identities(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> list[AuthIdentityResponse]:
|
||||
"""
|
||||
列出当前用户已绑定的所有认证身份
|
||||
|
||||
返回:
|
||||
- 认证身份列表,包含 provider、identifier、display_name 等
|
||||
"""
|
||||
identities: list[AuthIdentity] = await AuthIdentity.get(
|
||||
session,
|
||||
AuthIdentity.user_id == user.id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
return [identity.to_response() for identity in identities]
|
||||
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/identity',
|
||||
summary='绑定新的认证身份',
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def router_user_settings_bind_identity(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
request: BindIdentityRequest,
|
||||
) -> AuthIdentityResponse:
|
||||
"""
|
||||
绑定新的登录方式
|
||||
|
||||
请求体:
|
||||
- provider: 提供者类型
|
||||
- identifier: 标识符(邮箱 / 手机号 / OAuth code)
|
||||
- credential: 凭证(密码、验证码等)
|
||||
- redirect_uri: OAuth 回调地址(可选)
|
||||
|
||||
错误处理:
|
||||
- 400: provider 未启用
|
||||
- 409: 该身份已被其他用户绑定
|
||||
"""
|
||||
# 检查是否已被绑定
|
||||
existing = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == request.provider)
|
||||
& (AuthIdentity.identifier == request.identifier),
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="该身份已被绑定")
|
||||
|
||||
# 处理密码类型的凭证
|
||||
credential: str | None = None
|
||||
if request.provider == AuthProviderType.EMAIL_PASSWORD and request.credential:
|
||||
credential = Password.hash(request.credential)
|
||||
|
||||
identity = AuthIdentity(
|
||||
provider=request.provider,
|
||||
identifier=request.identifier,
|
||||
credential=credential,
|
||||
is_primary=False,
|
||||
is_verified=False,
|
||||
user_id=user.id,
|
||||
)
|
||||
identity = await identity.save(session)
|
||||
return identity.to_response()
|
||||
|
||||
|
||||
@user_settings_router.delete(
|
||||
path='/identity/{identity_id}',
|
||||
summary='解绑认证身份',
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def router_user_settings_unbind_identity(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
identity_id: UUID,
|
||||
) -> None:
|
||||
"""
|
||||
解绑一个认证身份
|
||||
|
||||
约束:
|
||||
- 不能解绑最后一个身份
|
||||
- 站长配置强制绑定邮箱/手机号时,不能解绑对应身份
|
||||
|
||||
错误处理:
|
||||
- 404: 身份不存在或不属于当前用户
|
||||
- 400: 不能解绑最后一个身份 / 不能解绑强制绑定的身份
|
||||
"""
|
||||
# 查找目标身份
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.id == identity_id) & (AuthIdentity.user_id == user.id),
|
||||
)
|
||||
if not identity:
|
||||
http_exceptions.raise_not_found("认证身份不存在")
|
||||
|
||||
# 检查是否为最后一个身份
|
||||
all_identities: list[AuthIdentity] = await AuthIdentity.get(
|
||||
session,
|
||||
AuthIdentity.user_id == user.id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
if len(all_identities) <= 1:
|
||||
http_exceptions.raise_bad_request("不能解绑最后一个认证身份")
|
||||
|
||||
# 检查强制绑定约束
|
||||
if identity.provider == AuthProviderType.EMAIL_PASSWORD:
|
||||
email_required_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||
& (sqlmodels.Setting.name == "auth_email_binding_required"),
|
||||
)
|
||||
if email_required_setting and email_required_setting.value == "1":
|
||||
http_exceptions.raise_bad_request("站长要求必须绑定邮箱,不能解绑")
|
||||
|
||||
if identity.provider == AuthProviderType.PHONE_SMS:
|
||||
phone_required_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||
& (sqlmodels.Setting.name == "auth_phone_binding_required"),
|
||||
)
|
||||
if phone_required_setting and phone_required_setting.value == "1":
|
||||
http_exceptions.raise_bad_request("站长要求必须绑定手机号,不能解绑")
|
||||
|
||||
await AuthIdentity.delete(session, identity)
|
||||
|
||||
@@ -1 +1 @@
|
||||
from .login import login
|
||||
from .login import unified_login
|
||||
|
||||
@@ -1,83 +1,417 @@
|
||||
from uuid import uuid4
|
||||
"""
|
||||
统一登录服务
|
||||
|
||||
from loguru import logger
|
||||
支持多种认证方式:邮箱密码、GitHub OAuth、QQ OAuth、Passkey、Magic Link、手机短信(预留)。
|
||||
"""
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import LoginRequest, TokenResponse, User
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
from loguru import logger as l
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.group import GroupClaims, GroupOptions
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import http_exceptions
|
||||
from utils.JWT import create_access_token, create_refresh_token
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.policy import Policy
|
||||
from sqlmodels.setting import Setting, SettingsType
|
||||
from sqlmodels.user import TokenResponse, UnifiedLoginRequest, User, UserStatus
|
||||
from utils import JWT, http_exceptions
|
||||
from utils.password.pwd import Password, PasswordStatus
|
||||
|
||||
|
||||
async def login(
|
||||
session: SessionDep,
|
||||
login_request: LoginRequest,
|
||||
async def unified_login(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
根据账号密码进行登录。
|
||||
如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。
|
||||
统一登录入口,根据 provider 分发到不同的登录逻辑。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param login_request: 登录请求
|
||||
|
||||
:return: TokenResponse 对象或状态码或 None
|
||||
:param request: 统一登录请求
|
||||
:return: TokenResponse
|
||||
"""
|
||||
# 获取用户信息(预加载 group 关系)
|
||||
current_user: User = await User.get(
|
||||
await _check_provider_enabled(session, request.provider)
|
||||
|
||||
match request.provider:
|
||||
case AuthProviderType.EMAIL_PASSWORD:
|
||||
user = await _login_email_password(session, request)
|
||||
case AuthProviderType.GITHUB:
|
||||
user = await _login_oauth(session, request, AuthProviderType.GITHUB)
|
||||
case AuthProviderType.QQ:
|
||||
user = await _login_oauth(session, request, AuthProviderType.QQ)
|
||||
case AuthProviderType.PASSKEY:
|
||||
user = await _login_passkey(session, request)
|
||||
case AuthProviderType.MAGIC_LINK:
|
||||
user = await _login_magic_link(session, request)
|
||||
case AuthProviderType.PHONE_SMS:
|
||||
http_exceptions.raise_not_implemented("短信登录暂未开放")
|
||||
case _:
|
||||
http_exceptions.raise_bad_request(f"不支持的登录方式: {request.provider}")
|
||||
|
||||
return await _issue_tokens(session, user)
|
||||
|
||||
|
||||
async def _check_provider_enabled(session: AsyncSession, provider: AuthProviderType) -> None:
|
||||
"""检查认证方式是否已被站长启用"""
|
||||
# OAuth 类型从 OAUTH 设置中查询
|
||||
if provider in (AuthProviderType.GITHUB, AuthProviderType.QQ):
|
||||
setting_name = f"{provider.value}_enabled"
|
||||
setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.OAUTH) & (Setting.name == setting_name),
|
||||
)
|
||||
if not setting or setting.value != "1":
|
||||
http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用")
|
||||
return
|
||||
|
||||
# 其他类型从 AUTH 设置中查询
|
||||
setting_name = f"auth_{provider.value}_enabled"
|
||||
setting = await Setting.get(
|
||||
session,
|
||||
User.email == login_request.email,
|
||||
fetch_mode="first",
|
||||
load=User.group,
|
||||
) #type: ignore
|
||||
(Setting.type == SettingsType.AUTH) & (Setting.name == setting_name),
|
||||
)
|
||||
if not setting or setting.value != "1":
|
||||
http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用")
|
||||
|
||||
# 验证用户是否存在
|
||||
if not current_user:
|
||||
logger.debug(f"Cannot find user with email: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid email or password")
|
||||
|
||||
# 验证密码是否正确
|
||||
if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID:
|
||||
logger.debug(f"Password verification failed for user: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid email or password")
|
||||
async def _login_email_password(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
) -> User:
|
||||
"""邮箱+密码登录"""
|
||||
if not request.credential:
|
||||
http_exceptions.raise_bad_request("密码不能为空")
|
||||
|
||||
# 验证用户是否可登录(修复:显式枚举比较,StrEnum 永远 truthy)
|
||||
if current_user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("Your account is disabled")
|
||||
# 查找 AuthIdentity
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||
& (AuthIdentity.identifier == request.identifier),
|
||||
)
|
||||
if not identity:
|
||||
l.debug(f"未找到邮箱密码身份: {request.identifier}")
|
||||
http_exceptions.raise_unauthorized("邮箱或密码错误")
|
||||
|
||||
# 检查两步验证
|
||||
if current_user.two_factor:
|
||||
# 用户已启用两步验证
|
||||
if not login_request.two_fa_code:
|
||||
logger.debug(f"2FA required for user: {login_request.email}")
|
||||
http_exceptions.raise_precondition_required("2FA required")
|
||||
# 验证密码
|
||||
if not identity.credential:
|
||||
http_exceptions.raise_unauthorized("邮箱或密码错误")
|
||||
|
||||
# 验证 OTP 码
|
||||
if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID:
|
||||
logger.debug(f"Invalid 2FA code for user: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid 2FA code")
|
||||
if Password.verify(identity.credential, request.credential) != PasswordStatus.VALID:
|
||||
l.debug(f"密码验证失败: {request.identifier}")
|
||||
http_exceptions.raise_unauthorized("邮箱或密码错误")
|
||||
|
||||
# 加载用户
|
||||
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
|
||||
# 验证用户状态
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
# 检查两步验证(从 AuthIdentity.extra_data 中读取 2FA secret)
|
||||
if identity.extra_data:
|
||||
import orjson
|
||||
extra: dict = orjson.loads(identity.extra_data)
|
||||
two_factor_secret: str | None = extra.get("two_factor")
|
||||
if two_factor_secret:
|
||||
if not request.two_fa_code:
|
||||
l.debug(f"需要两步验证: {request.identifier}")
|
||||
http_exceptions.raise_precondition_required("需要两步验证")
|
||||
if Password.verify_totp(two_factor_secret, request.two_fa_code) != PasswordStatus.VALID:
|
||||
l.debug(f"两步验证失败: {request.identifier}")
|
||||
http_exceptions.raise_unauthorized("两步验证码错误")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def _login_oauth(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
provider: AuthProviderType,
|
||||
) -> User:
|
||||
"""
|
||||
OAuth 登录(GitHub / QQ)
|
||||
|
||||
identifier 为 OAuth authorization code,后端换取 access_token 再获取用户信息。
|
||||
"""
|
||||
# 读取 OAuth 配置
|
||||
client_id_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_id"),
|
||||
)
|
||||
client_secret_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_secret"),
|
||||
)
|
||||
if not client_id_setting or not client_secret_setting:
|
||||
http_exceptions.raise_bad_request(f"{provider.value} OAuth 未配置")
|
||||
|
||||
client_id = client_id_setting.value or ""
|
||||
client_secret = client_secret_setting.value or ""
|
||||
|
||||
# 根据 provider 创建对应的 OAuth 客户端
|
||||
if provider == AuthProviderType.GITHUB:
|
||||
from service.oauth import GithubOAuth
|
||||
oauth_client = GithubOAuth(client_id, client_secret)
|
||||
token_resp = await oauth_client.get_access_token(code=request.identifier)
|
||||
user_info_resp = await oauth_client.get_user_info(token_resp)
|
||||
openid = str(user_info_resp.user_data.id)
|
||||
nickname = user_info_resp.user_data.name or user_info_resp.user_data.login
|
||||
avatar_url = user_info_resp.user_data.avatar_url
|
||||
email = user_info_resp.user_data.email
|
||||
elif provider == AuthProviderType.QQ:
|
||||
from service.oauth import QQOAuth
|
||||
oauth_client = QQOAuth(client_id, client_secret)
|
||||
token_resp = await oauth_client.get_access_token(
|
||||
code=request.identifier,
|
||||
redirect_uri=request.redirect_uri or "",
|
||||
)
|
||||
openid_resp = await oauth_client.get_openid(token_resp.access_token)
|
||||
user_info_resp = await oauth_client.get_user_info(
|
||||
token_resp,
|
||||
app_id=client_id,
|
||||
openid=openid_resp.openid,
|
||||
)
|
||||
openid = openid_resp.openid
|
||||
nickname = user_info_resp.user_data.nickname
|
||||
avatar_url = user_info_resp.user_data.figureurl_qq_2 or user_info_resp.user_data.figureurl_2
|
||||
email = None
|
||||
else:
|
||||
http_exceptions.raise_bad_request(f"不支持的 OAuth 提供者: {provider.value}")
|
||||
|
||||
# 查找已有 AuthIdentity
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == provider) & (AuthIdentity.identifier == openid),
|
||||
)
|
||||
|
||||
if identity:
|
||||
# 已绑定 → 更新 OAuth 信息并返回关联用户
|
||||
identity.display_name = nickname
|
||||
identity.avatar_url = avatar_url
|
||||
await identity.save(session)
|
||||
|
||||
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
return user
|
||||
|
||||
# 未绑定 → 自动注册
|
||||
user = await _auto_register_oauth_user(
|
||||
session,
|
||||
provider=provider,
|
||||
openid=openid,
|
||||
nickname=nickname,
|
||||
avatar_url=avatar_url,
|
||||
email=email,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
async def _auto_register_oauth_user(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
provider: AuthProviderType,
|
||||
openid: str,
|
||||
nickname: str | None,
|
||||
avatar_url: str | None,
|
||||
email: str | None,
|
||||
) -> User:
|
||||
"""OAuth 自动注册用户"""
|
||||
# 获取默认用户组
|
||||
default_group_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.REGISTER) & (Setting.name == "default_group"),
|
||||
)
|
||||
if not default_group_setting or not default_group_setting.value:
|
||||
l.error("默认用户组未配置")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
default_group_id = UUID(default_group_setting.value)
|
||||
|
||||
# 创建用户
|
||||
new_user = User(
|
||||
email=email,
|
||||
nickname=nickname,
|
||||
avatar=avatar_url or "default",
|
||||
group_id=default_group_id,
|
||||
)
|
||||
new_user_id = new_user.id
|
||||
new_user = await new_user.save(session)
|
||||
|
||||
# 创建 AuthIdentity
|
||||
identity = AuthIdentity(
|
||||
provider=provider,
|
||||
identifier=openid,
|
||||
display_name=nickname,
|
||||
avatar_url=avatar_url,
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=new_user_id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
# 创建用户根目录
|
||||
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||
if default_policy:
|
||||
await Object(
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=new_user_id,
|
||||
parent_id=None,
|
||||
policy_id=default_policy.id,
|
||||
).save(session)
|
||||
|
||||
# 重新加载用户(含 group 关系)
|
||||
user: User = await User.get(session, User.id == new_user_id, load=User.group)
|
||||
l.info(f"OAuth 自动注册用户: provider={provider.value}, openid={openid}")
|
||||
return user
|
||||
|
||||
|
||||
async def _login_passkey(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
) -> User:
|
||||
"""
|
||||
Passkey/WebAuthn 登录
|
||||
|
||||
identifier 为 credential_id,credential 为 JSON 格式的 authenticator assertion response。
|
||||
"""
|
||||
from webauthn import verify_authentication_response
|
||||
from webauthn.helpers.structs import AuthenticationCredential
|
||||
|
||||
if not request.credential:
|
||||
http_exceptions.raise_bad_request("WebAuthn assertion response 不能为空")
|
||||
|
||||
# 查找 AuthIdentity
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == AuthProviderType.PASSKEY)
|
||||
& (AuthIdentity.identifier == request.identifier),
|
||||
)
|
||||
if not identity:
|
||||
http_exceptions.raise_unauthorized("Passkey 凭证未注册")
|
||||
|
||||
# 加载对应的 UserAuthn 记录
|
||||
from sqlmodels.user_authn import UserAuthn
|
||||
authn: UserAuthn | None = await UserAuthn.get(
|
||||
session,
|
||||
UserAuthn.credential_id == request.identifier,
|
||||
)
|
||||
if not authn:
|
||||
http_exceptions.raise_unauthorized("Passkey 凭证数据不存在")
|
||||
|
||||
# 获取 RP ID
|
||||
site_url_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"),
|
||||
)
|
||||
rp_id = site_url_setting.value if site_url_setting else "localhost"
|
||||
|
||||
# 验证 WebAuthn assertion
|
||||
import orjson
|
||||
credential = AuthenticationCredential.model_validate(orjson.loads(request.credential))
|
||||
|
||||
try:
|
||||
verification = verify_authentication_response(
|
||||
credential=credential,
|
||||
expected_rp_id=rp_id,
|
||||
expected_origin=f"https://{rp_id}",
|
||||
expected_challenge=b"", # TODO: 从 session/cache 中获取 challenge
|
||||
credential_public_key=bytes.fromhex(authn.credential_public_key),
|
||||
credential_current_sign_count=authn.sign_count,
|
||||
)
|
||||
except Exception as e:
|
||||
l.warning(f"WebAuthn 验证失败: {e}")
|
||||
http_exceptions.raise_unauthorized("Passkey 验证失败")
|
||||
|
||||
# 更新签名计数
|
||||
authn.sign_count = verification.new_sign_count
|
||||
await authn.save(session)
|
||||
|
||||
# 加载用户
|
||||
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def _login_magic_link(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
) -> User:
|
||||
"""
|
||||
Magic Link 登录
|
||||
|
||||
identifier 为签名 token,由 itsdangerous 生成。
|
||||
"""
|
||||
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||
|
||||
try:
|
||||
email = serializer.loads(request.identifier, salt="magic-link-salt", max_age=600)
|
||||
except SignatureExpired:
|
||||
http_exceptions.raise_unauthorized("Magic Link 已过期")
|
||||
except BadSignature:
|
||||
http_exceptions.raise_unauthorized("Magic Link 无效")
|
||||
|
||||
# 查找绑定了该邮箱的 AuthIdentity(email_password 或 magic_link)
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.identifier == email)
|
||||
& (
|
||||
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||
| (AuthIdentity.provider == AuthProviderType.MAGIC_LINK)
|
||||
),
|
||||
)
|
||||
if not identity:
|
||||
http_exceptions.raise_unauthorized("该邮箱未注册")
|
||||
|
||||
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
# 标记邮箱已验证
|
||||
if not identity.is_verified:
|
||||
identity.is_verified = True
|
||||
await identity.save(session)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def _issue_tokens(session: AsyncSession, user: User) -> TokenResponse:
|
||||
"""
|
||||
签发 JWT 双令牌(access + refresh)
|
||||
|
||||
提取自原 login.py 的签发逻辑,供所有 provider 共用。
|
||||
"""
|
||||
# 加载 GroupOptions
|
||||
group_options: GroupOptions | None = await GroupOptions.get(
|
||||
session,
|
||||
GroupOptions.group_id == current_user.group_id,
|
||||
GroupOptions.group_id == user.group_id,
|
||||
)
|
||||
|
||||
# 构建权限快照
|
||||
current_user.group.options = group_options
|
||||
group_claims = GroupClaims.from_group(current_user.group)
|
||||
user.group.options = group_options
|
||||
group_claims = GroupClaims.from_group(user.group)
|
||||
|
||||
# 创建令牌
|
||||
access_token = create_access_token(
|
||||
sub=current_user.id,
|
||||
access_token = JWT.create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=current_user.status.value,
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
)
|
||||
refresh_token = create_refresh_token(
|
||||
sub=current_user.id,
|
||||
jti=uuid4()
|
||||
refresh_token = JWT.create_refresh_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
from .auth_identity import (
|
||||
AuthIdentity,
|
||||
AuthIdentityResponse,
|
||||
AuthProviderType,
|
||||
BindIdentityRequest,
|
||||
)
|
||||
from .user import (
|
||||
BatchDeleteRequest,
|
||||
JWTPayload,
|
||||
LoginRequest,
|
||||
MagicLinkRequest,
|
||||
UnifiedLoginRequest,
|
||||
UnifiedRegisterRequest,
|
||||
RefreshTokenRequest,
|
||||
RegisterRequest,
|
||||
AccessTokenBase,
|
||||
RefreshTokenBase,
|
||||
TokenResponse,
|
||||
@@ -89,7 +96,7 @@ from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, Policy
|
||||
from .redeem import Redeem, RedeemType
|
||||
from .report import Report, ReportReason
|
||||
from .setting import (
|
||||
Setting, SettingsType, SiteConfigResponse,
|
||||
Setting, SettingsType, SiteConfigResponse, AuthMethodConfig,
|
||||
# 管理员DTO
|
||||
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
|
||||
)
|
||||
@@ -120,4 +127,4 @@ from .model_base import (
|
||||
)
|
||||
|
||||
# mixin 中的通用分页模型
|
||||
from .mixin import ListResponse
|
||||
from .mixin import ListResponse
|
||||
|
||||
139
sqlmodels/auth_identity.py
Normal file
139
sqlmodels/auth_identity.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
认证身份模块
|
||||
|
||||
一个用户可拥有多种登录方式(邮箱密码、OAuth、Passkey、Magic Link 等)。
|
||||
AuthIdentity 表存储每种认证方式的凭证信息。
|
||||
"""
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import UUIDTableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
class AuthProviderType(StrEnum):
|
||||
"""认证提供者类型"""
|
||||
|
||||
EMAIL_PASSWORD = "email_password"
|
||||
"""邮箱+密码"""
|
||||
|
||||
PHONE_SMS = "phone_sms"
|
||||
"""手机号+短信验证码(预留)"""
|
||||
|
||||
GITHUB = "github"
|
||||
"""GitHub OAuth"""
|
||||
|
||||
QQ = "qq"
|
||||
"""QQ OAuth"""
|
||||
|
||||
PASSKEY = "passkey"
|
||||
"""Passkey/WebAuthn"""
|
||||
|
||||
MAGIC_LINK = "magic_link"
|
||||
"""邮箱 Magic Link"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class AuthIdentityResponse(SQLModelBase):
|
||||
"""认证身份响应 DTO(列表展示用)"""
|
||||
|
||||
id: UUID
|
||||
"""身份UUID"""
|
||||
|
||||
provider: AuthProviderType
|
||||
"""提供者类型"""
|
||||
|
||||
identifier: str
|
||||
"""标识符(邮箱/手机号/OAuth openid)"""
|
||||
|
||||
display_name: str | None = None
|
||||
"""显示名称(OAuth 昵称等)"""
|
||||
|
||||
avatar_url: str | None = None
|
||||
"""头像 URL"""
|
||||
|
||||
is_primary: bool = False
|
||||
"""是否主要身份"""
|
||||
|
||||
is_verified: bool = False
|
||||
"""是否已验证"""
|
||||
|
||||
|
||||
class BindIdentityRequest(SQLModelBase):
|
||||
"""绑定认证身份请求 DTO"""
|
||||
|
||||
provider: AuthProviderType
|
||||
"""提供者类型"""
|
||||
|
||||
identifier: str
|
||||
"""标识符(邮箱/手机号/OAuth code)"""
|
||||
|
||||
credential: str | None = None
|
||||
"""凭证(密码、验证码等)"""
|
||||
|
||||
redirect_uri: str | None = None
|
||||
"""OAuth 回调地址"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class AuthIdentity(SQLModelBase, UUIDTableBaseMixin):
|
||||
"""用户认证身份 — 一个用户可以有多种登录方式"""
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("provider", "identifier", name="uq_auth_identity_provider_identifier"),
|
||||
)
|
||||
|
||||
provider: AuthProviderType = Field(index=True)
|
||||
"""提供者类型"""
|
||||
|
||||
identifier: str = Field(max_length=255, index=True)
|
||||
"""标识符(邮箱/手机号/OAuth openid)"""
|
||||
|
||||
credential: str | None = Field(default=None, max_length=1024)
|
||||
"""凭证(Argon2 哈希密码 / null)"""
|
||||
|
||||
display_name: str | None = Field(default=None, max_length=100)
|
||||
"""OAuth 昵称"""
|
||||
|
||||
avatar_url: str | None = Field(default=None, max_length=512)
|
||||
"""OAuth 头像 URL"""
|
||||
|
||||
extra_data: str | None = None
|
||||
"""JSON 附加数据(2FA secret、OAuth refresh_token 等)"""
|
||||
|
||||
is_primary: bool = False
|
||||
"""是否主要身份"""
|
||||
|
||||
is_verified: bool = False
|
||||
"""是否已验证"""
|
||||
|
||||
# 外键
|
||||
user_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
index=True,
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
"""所属用户UUID"""
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="auth_identities")
|
||||
|
||||
def to_response(self) -> AuthIdentityResponse:
|
||||
"""转换为响应 DTO"""
|
||||
return AuthIdentityResponse(
|
||||
id=self.id,
|
||||
provider=self.provider,
|
||||
identifier=self.identifier,
|
||||
display_name=self.display_name,
|
||||
avatar_url=self.avatar_url,
|
||||
is_primary=self.is_primary,
|
||||
is_verified=self.is_verified,
|
||||
)
|
||||
File diff suppressed because one or more lines are too long
@@ -2,6 +2,7 @@ from enum import StrEnum
|
||||
|
||||
from sqlmodel import UniqueConstraint
|
||||
|
||||
from .auth_identity import AuthProviderType
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
from .user import UserResponse
|
||||
@@ -12,6 +13,19 @@ class CaptchaType(StrEnum):
|
||||
GCAPTCHA = "gcaptcha"
|
||||
CLOUD_FLARE_TURNSTILE = "cloudflare turnstile"
|
||||
|
||||
|
||||
# ==================== Auth 配置 DTO ====================
|
||||
|
||||
class AuthMethodConfig(SQLModelBase):
|
||||
"""认证方式配置 DTO"""
|
||||
|
||||
provider: AuthProviderType
|
||||
"""提供者类型"""
|
||||
|
||||
is_enabled: bool
|
||||
"""是否启用"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class SiteConfigResponse(SQLModelBase):
|
||||
@@ -50,6 +64,27 @@ class SiteConfigResponse(SQLModelBase):
|
||||
captcha_key: str | None = None
|
||||
"""验证码 public key(DEFAULT 类型时为 None)"""
|
||||
|
||||
auth_methods: list[AuthMethodConfig] = []
|
||||
"""可用的登录方式列表"""
|
||||
|
||||
password_required: bool = True
|
||||
"""注册时是否必须设置密码"""
|
||||
|
||||
phone_binding_required: bool = False
|
||||
"""是否强制绑定手机号"""
|
||||
|
||||
email_binding_required: bool = True
|
||||
"""是否强制绑定邮箱"""
|
||||
|
||||
footer_code: str | None = None
|
||||
"""自定义页脚代码"""
|
||||
|
||||
tos_url: str | None = None
|
||||
"""服务条款 URL"""
|
||||
|
||||
privacy_url: str | None = None
|
||||
"""隐私政策 URL"""
|
||||
|
||||
|
||||
# ==================== 管理员设置 DTO ====================
|
||||
|
||||
@@ -133,4 +168,4 @@ class Setting(SettingItem, TableBaseMixin):
|
||||
__table_args__ = (UniqueConstraint("type", "name", name="uq_setting_type_name"),)
|
||||
|
||||
type: SettingsType
|
||||
"""设置类型/分组(覆盖基类的 str 类型为枚举类型)"""
|
||||
"""设置类型/分组(覆盖基类的 str 类型为枚举类型)"""
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlmodel import Field, Relationship
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
from .auth_identity import AuthProviderType
|
||||
from .base import SQLModelBase
|
||||
from .color import ChromaticColor, NeutralColor, ThemeColorsBase
|
||||
from .model_base import ResponseBase
|
||||
@@ -17,6 +18,7 @@ from .mixin import UUIDTableBaseMixin, TableViewRequest, ListResponse
|
||||
T = TypeVar("T", bound="User")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .auth_identity import AuthIdentity
|
||||
from .group import Group
|
||||
from .download import Download
|
||||
from .object import Object
|
||||
@@ -30,7 +32,7 @@ if TYPE_CHECKING:
|
||||
|
||||
class AvatarType(StrEnum):
|
||||
"""头像类型枚举"""
|
||||
|
||||
|
||||
DEFAULT = "default"
|
||||
GRAVATAR = "gravatar"
|
||||
FILE = "file"
|
||||
@@ -69,8 +71,8 @@ class UserFilterParams(SQLModelBase):
|
||||
class UserBase(SQLModelBase):
|
||||
"""用户基础字段,供数据库模型和 DTO 共享"""
|
||||
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
email: str | None = None
|
||||
"""用户邮箱(社交登录用户可能没有邮箱)"""
|
||||
|
||||
status: UserStatus = UserStatus.ACTIVE
|
||||
"""用户状态"""
|
||||
@@ -81,30 +83,42 @@ class UserBase(SQLModelBase):
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class LoginRequest(SQLModelBase):
|
||||
"""登录请求 DTO"""
|
||||
class UnifiedLoginRequest(SQLModelBase):
|
||||
"""统一登录请求 DTO"""
|
||||
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
provider: AuthProviderType
|
||||
"""登录方式"""
|
||||
|
||||
password: str
|
||||
"""用户密码"""
|
||||
identifier: str
|
||||
"""标识符(邮箱 / OAuth code / Magic Link token)"""
|
||||
|
||||
captcha: str | None = None
|
||||
"""验证码"""
|
||||
credential: str | None = None
|
||||
"""凭证(密码,provider=email_password 时必填)"""
|
||||
|
||||
two_fa_code: str | None = Field(default=None, min_length=6, max_length=6)
|
||||
"""两步验证代码"""
|
||||
|
||||
redirect_uri: str | None = None
|
||||
"""OAuth 回调地址"""
|
||||
|
||||
class RegisterRequest(SQLModelBase):
|
||||
"""注册请求 DTO"""
|
||||
captcha: str | None = None
|
||||
"""验证码"""
|
||||
|
||||
email: str
|
||||
"""用户邮箱,唯一"""
|
||||
|
||||
password: str
|
||||
"""用户密码"""
|
||||
class UnifiedRegisterRequest(SQLModelBase):
|
||||
"""统一注册请求 DTO"""
|
||||
|
||||
provider: AuthProviderType
|
||||
"""注册方式(email_password / phone_sms)"""
|
||||
|
||||
identifier: str
|
||||
"""标识符(邮箱 / 手机号)"""
|
||||
|
||||
credential: str | None = None
|
||||
"""凭证(密码 / 短信验证码)"""
|
||||
|
||||
nickname: str | None = Field(default=None, max_length=50)
|
||||
"""昵称"""
|
||||
|
||||
captcha: str | None = None
|
||||
"""验证码"""
|
||||
@@ -190,7 +204,7 @@ class UserResponse(ResponseBase):
|
||||
id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
email: str
|
||||
email: str | None = None
|
||||
"""用户邮箱"""
|
||||
|
||||
nickname: str | None = None
|
||||
@@ -216,10 +230,10 @@ class UserStorageResponse(SQLModelBase):
|
||||
|
||||
used: int
|
||||
"""已用存储空间(字节)"""
|
||||
|
||||
|
||||
free: int
|
||||
"""剩余存储空间(字节)"""
|
||||
|
||||
|
||||
total: int
|
||||
"""总存储空间(字节)"""
|
||||
|
||||
@@ -248,9 +262,6 @@ class UserPublic(UserBase):
|
||||
group_name: str | None = None
|
||||
"""用户组名称"""
|
||||
|
||||
two_factor: str | None = None
|
||||
"""两步验证密钥(32位字符串,null 表示未启用)"""
|
||||
|
||||
created_at: datetime | None = None
|
||||
"""创建时间"""
|
||||
|
||||
@@ -264,21 +275,24 @@ class UserSettingResponse(SQLModelBase):
|
||||
id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
email: str
|
||||
email: str | None = None
|
||||
"""用户邮箱"""
|
||||
|
||||
phone: str | None = None
|
||||
"""手机号"""
|
||||
|
||||
nickname: str | None = None
|
||||
"""昵称"""
|
||||
|
||||
|
||||
created_at: datetime
|
||||
"""用户注册时间"""
|
||||
|
||||
group_name: str
|
||||
"""用户所属用户组名称"""
|
||||
|
||||
|
||||
language: str
|
||||
"""语言偏好"""
|
||||
|
||||
|
||||
timezone: int
|
||||
"""时区"""
|
||||
|
||||
@@ -341,16 +355,26 @@ class UserTwoFactorResponse(SQLModelBase):
|
||||
"""两步验证密钥"""
|
||||
|
||||
|
||||
class MagicLinkRequest(SQLModelBase):
|
||||
"""Magic Link 请求 DTO"""
|
||||
|
||||
email: str
|
||||
"""接收 Magic Link 的邮箱"""
|
||||
|
||||
captcha: str | None = None
|
||||
"""验证码"""
|
||||
|
||||
|
||||
# ==================== 管理员用户管理 DTO ====================
|
||||
|
||||
class UserAdminCreateRequest(SQLModelBase):
|
||||
"""管理员创建用户请求 DTO"""
|
||||
|
||||
email: str = Field(max_length=50)
|
||||
email: str | None = Field(default=None, max_length=50)
|
||||
"""用户邮箱"""
|
||||
|
||||
password: str
|
||||
"""用户密码(明文,由服务端加密)"""
|
||||
password: str | None = None
|
||||
"""用户密码(明文,由服务端加密;为空则不创建邮箱密码身份)"""
|
||||
|
||||
nickname: str | None = Field(default=None, max_length=50)
|
||||
"""昵称"""
|
||||
@@ -364,15 +388,15 @@ class UserAdminCreateRequest(SQLModelBase):
|
||||
|
||||
class UserAdminUpdateRequest(SQLModelBase):
|
||||
"""管理员更新用户请求 DTO"""
|
||||
|
||||
email: str = Field(max_length=50)
|
||||
|
||||
email: str | None = Field(default=None, max_length=50)
|
||||
"""邮箱"""
|
||||
|
||||
nickname: str | None = Field(default=None, max_length=50)
|
||||
"""昵称"""
|
||||
|
||||
password: str | None = None
|
||||
"""新密码(为空则不修改)"""
|
||||
phone: str | None = None
|
||||
"""手机号"""
|
||||
|
||||
group_id: UUID | None = None
|
||||
"""用户组UUID"""
|
||||
@@ -389,9 +413,6 @@ class UserAdminUpdateRequest(SQLModelBase):
|
||||
group_expires: datetime | None = None
|
||||
"""用户组过期时间"""
|
||||
|
||||
two_factor: str | None = None
|
||||
"""两步验证密钥(32位字符串,传 null 可清除,不传则不修改)"""
|
||||
|
||||
|
||||
class UserCalibrateResponse(SQLModelBase):
|
||||
"""用户存储校准响应 DTO"""
|
||||
@@ -415,9 +436,6 @@ class UserCalibrateResponse(SQLModelBase):
|
||||
class UserAdminDetailResponse(UserPublic):
|
||||
"""管理员用户详情响应 DTO"""
|
||||
|
||||
two_factor_enabled: bool = False
|
||||
"""是否启用两步验证"""
|
||||
|
||||
file_count: int = 0
|
||||
"""文件数量"""
|
||||
|
||||
@@ -443,14 +461,14 @@ UserSettingResponse.model_rebuild()
|
||||
class User(UserBase, UUIDTableBaseMixin):
|
||||
"""用户模型"""
|
||||
|
||||
email: str = Field(max_length=50, unique=True, index=True)
|
||||
"""用户邮箱,唯一"""
|
||||
email: str | None = Field(default=None, max_length=50, unique=True, index=True)
|
||||
"""用户邮箱(社交登录用户可能没有邮箱)"""
|
||||
|
||||
nickname: str | None = Field(default=None, max_length=50)
|
||||
"""用于公开展示的名字,可使用真实姓名或昵称"""
|
||||
|
||||
password: str = Field(max_length=255)
|
||||
"""用户密码(加密后)"""
|
||||
phone: str | None = Field(default=None, max_length=20, unique=True, index=True)
|
||||
"""手机号(预留)"""
|
||||
|
||||
status: UserStatus = UserStatus.ACTIVE
|
||||
"""用户状态"""
|
||||
@@ -458,9 +476,6 @@ class User(UserBase, UUIDTableBaseMixin):
|
||||
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
|
||||
"""已用存储空间(字节)"""
|
||||
|
||||
two_factor: str | None = Field(default=None, min_length=32, max_length=32)
|
||||
"""两步验证密钥"""
|
||||
|
||||
avatar: str = Field(default="default", max_length=255)
|
||||
"""头像地址"""
|
||||
|
||||
@@ -533,6 +548,12 @@ class User(UserBase, UUIDTableBaseMixin):
|
||||
}
|
||||
)
|
||||
|
||||
auth_identities: list["AuthIdentity"] = Relationship(
|
||||
back_populates="user",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
"""用户的认证身份列表"""
|
||||
|
||||
downloads: list["Download"] = Relationship(
|
||||
back_populates="user",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
@@ -634,4 +655,3 @@ class User(UserBase, UUIDTableBaseMixin):
|
||||
filter=filter,
|
||||
table_view=table_view,
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')
|
||||
|
||||
from main import app
|
||||
from sqlmodels.database import get_session
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.group import Group, GroupClaims, GroupOptions
|
||||
from sqlmodels.migration import migration
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
@@ -192,7 +193,6 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
user = User(
|
||||
email="testuser@test.local",
|
||||
nickname="测试用户",
|
||||
password=Password.hash(password),
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=100,
|
||||
@@ -200,6 +200,17 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
# 创建邮箱密码认证身份
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="testuser@test.local",
|
||||
credential=Password.hash(password),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(db_session)
|
||||
|
||||
# 创建用户根目录
|
||||
root_folder = Object(
|
||||
name="/",
|
||||
@@ -279,7 +290,6 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
admin = User(
|
||||
email="admin@disknext.local",
|
||||
nickname="管理员",
|
||||
password=Password.hash(password),
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=9999,
|
||||
@@ -287,6 +297,17 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
)
|
||||
admin = await admin.save(db_session)
|
||||
|
||||
# 创建管理员邮箱密码认证身份
|
||||
admin_identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="admin@disknext.local",
|
||||
credential=Password.hash(password),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=admin.id,
|
||||
)
|
||||
await admin_identity.save(db_session)
|
||||
|
||||
# 创建管理员根目录
|
||||
root_folder = Object(
|
||||
name="/",
|
||||
|
||||
75
tests/fixtures/users.py
vendored
75
tests/fixtures/users.py
vendored
@@ -2,12 +2,14 @@
|
||||
用户测试数据工厂
|
||||
|
||||
提供创建测试用户的便捷方法。
|
||||
用户密码凭证通过 AuthIdentity 管理,不再存储在 User 表中。
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.user import User, UserStatus
|
||||
from utils.password.pwd import Password
|
||||
|
||||
|
||||
@@ -20,7 +22,7 @@ class UserFactory:
|
||||
group_id: UUID,
|
||||
email: str | None = None,
|
||||
password: str | None = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> User:
|
||||
"""
|
||||
创建普通用户
|
||||
@@ -29,7 +31,7 @@ class UserFactory:
|
||||
session: 数据库会话
|
||||
group_id: 用户组UUID
|
||||
email: 用户邮箱(默认: test_user_{随机}@test.local)
|
||||
password: 明文密码(默认: password123)
|
||||
password: 明文密码(默认: password123),若提供则同时创建 AuthIdentity
|
||||
**kwargs: 其他用户字段
|
||||
|
||||
返回:
|
||||
@@ -46,12 +48,10 @@ class UserFactory:
|
||||
user = User(
|
||||
email=email,
|
||||
nickname=kwargs.get("nickname", email),
|
||||
password=Password.hash(password),
|
||||
status=kwargs.get("status", True),
|
||||
status=kwargs.get("status", UserStatus.ACTIVE),
|
||||
storage=kwargs.get("storage", 0),
|
||||
score=kwargs.get("score", 100),
|
||||
group_id=group_id,
|
||||
two_factor=kwargs.get("two_factor"),
|
||||
avatar=kwargs.get("avatar", "default"),
|
||||
group_expires=kwargs.get("group_expires"),
|
||||
theme=kwargs.get("theme", "system"),
|
||||
@@ -61,6 +61,18 @@ class UserFactory:
|
||||
)
|
||||
|
||||
user = await user.save(session)
|
||||
|
||||
# 创建邮箱密码认证身份
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=email,
|
||||
credential=Password.hash(password),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
@@ -68,7 +80,7 @@ class UserFactory:
|
||||
session: AsyncSession,
|
||||
admin_group_id: UUID,
|
||||
email: str | None = None,
|
||||
password: str | None = None
|
||||
password: str | None = None,
|
||||
) -> User:
|
||||
"""
|
||||
创建管理员用户
|
||||
@@ -93,8 +105,7 @@ class UserFactory:
|
||||
admin = User(
|
||||
email=email,
|
||||
nickname=f"管理员 {email}",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=9999,
|
||||
group_id=admin_group_id,
|
||||
@@ -102,13 +113,25 @@ class UserFactory:
|
||||
)
|
||||
|
||||
admin = await admin.save(session)
|
||||
|
||||
# 创建邮箱密码认证身份
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=email,
|
||||
credential=Password.hash(password),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=admin.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
return admin
|
||||
|
||||
@staticmethod
|
||||
async def create_banned(
|
||||
session: AsyncSession,
|
||||
group_id: UUID,
|
||||
email: str | None = None
|
||||
email: str | None = None,
|
||||
) -> User:
|
||||
"""
|
||||
创建被封禁用户
|
||||
@@ -129,8 +152,7 @@ class UserFactory:
|
||||
banned_user = User(
|
||||
email=email,
|
||||
nickname=f"封禁用户 {email}",
|
||||
password=Password.hash("banned_password"),
|
||||
status=False, # 封禁状态
|
||||
status=UserStatus.ADMIN_BANNED,
|
||||
storage=0,
|
||||
score=0,
|
||||
group_id=group_id,
|
||||
@@ -138,6 +160,18 @@ class UserFactory:
|
||||
)
|
||||
|
||||
banned_user = await banned_user.save(session)
|
||||
|
||||
# 创建邮箱密码认证身份
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=email,
|
||||
credential=Password.hash("banned_password"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=banned_user.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
return banned_user
|
||||
|
||||
@staticmethod
|
||||
@@ -145,7 +179,7 @@ class UserFactory:
|
||||
session: AsyncSession,
|
||||
group_id: UUID,
|
||||
storage_bytes: int,
|
||||
email: str | None = None
|
||||
email: str | None = None,
|
||||
) -> User:
|
||||
"""
|
||||
创建已使用指定存储空间的用户
|
||||
@@ -167,8 +201,7 @@ class UserFactory:
|
||||
user = User(
|
||||
email=email,
|
||||
nickname=email,
|
||||
password=Password.hash("password123"),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=storage_bytes,
|
||||
score=100,
|
||||
group_id=group_id,
|
||||
@@ -176,4 +209,16 @@ class UserFactory:
|
||||
)
|
||||
|
||||
user = await user.save(session)
|
||||
|
||||
# 创建邮箱密码认证身份
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=email,
|
||||
credential=Password.hash("password123"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
return user
|
||||
|
||||
@@ -83,6 +83,24 @@ async def test_site_config_captcha_settings(async_client: AsyncClient):
|
||||
assert "forgetCaptcha" in config
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_site_config_auth_methods(async_client: AsyncClient):
|
||||
"""测试配置包含认证方式列表"""
|
||||
response = await async_client.get("/api/site/config")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
config = data["data"]
|
||||
assert "authMethods" in config
|
||||
assert isinstance(config["authMethods"], list)
|
||||
assert len(config["authMethods"]) > 0
|
||||
|
||||
# 每个认证方式应包含 provider 和 isEnabled
|
||||
for method in config["authMethods"]:
|
||||
assert "provider" in method
|
||||
assert "isEnabled" in method
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_site_captcha_endpoint_exists(async_client: AsyncClient):
|
||||
"""测试验证码端点存在(即使未实现也应返回有效响应)"""
|
||||
|
||||
@@ -15,9 +15,10 @@ async def test_user_login_success(
|
||||
"""测试成功登录"""
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": test_user_info["email"],
|
||||
"password": test_user_info["password"],
|
||||
json={
|
||||
"provider": "email_password",
|
||||
"identifier": test_user_info["email"],
|
||||
"credential": test_user_info["password"],
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -37,9 +38,10 @@ async def test_user_login_wrong_password(
|
||||
"""测试密码错误返回 401"""
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": test_user_info["email"],
|
||||
"password": "wrongpassword",
|
||||
json={
|
||||
"provider": "email_password",
|
||||
"identifier": test_user_info["email"],
|
||||
"credential": "wrongpassword",
|
||||
}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
@@ -50,9 +52,10 @@ async def test_user_login_nonexistent_user(async_client: AsyncClient):
|
||||
"""测试不存在的用户返回 401"""
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": "nonexistent@test.local",
|
||||
"password": "anypassword",
|
||||
json={
|
||||
"provider": "email_password",
|
||||
"identifier": "nonexistent@test.local",
|
||||
"credential": "anypassword",
|
||||
}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
@@ -66,9 +69,10 @@ async def test_user_login_user_banned(
|
||||
"""测试封禁用户返回 403"""
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": banned_user_info["email"],
|
||||
"password": banned_user_info["password"],
|
||||
json={
|
||||
"provider": "email_password",
|
||||
"identifier": banned_user_info["email"],
|
||||
"credential": banned_user_info["password"],
|
||||
}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
@@ -82,8 +86,9 @@ async def test_user_register_success(async_client: AsyncClient):
|
||||
response = await async_client.post(
|
||||
"/api/user/",
|
||||
json={
|
||||
"email": "newuser@test.local",
|
||||
"password": "newpass123",
|
||||
"provider": "email_password",
|
||||
"identifier": "newuser@test.local",
|
||||
"credential": "newpass123",
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -104,8 +109,9 @@ async def test_user_register_duplicate_email(
|
||||
response = await async_client.post(
|
||||
"/api/user/",
|
||||
json={
|
||||
"email": test_user_info["email"],
|
||||
"password": "anypassword",
|
||||
"provider": "email_password",
|
||||
"identifier": test_user_info["email"],
|
||||
"credential": "anypassword",
|
||||
}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
@@ -23,6 +23,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../.
|
||||
|
||||
from main import app
|
||||
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import Password
|
||||
from utils.JWT import create_access_token
|
||||
@@ -98,6 +99,15 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
Setting(type=SettingsType.CAPTCHA, name="captcha_CloudflareKey", value=""),
|
||||
Setting(type=SettingsType.REGISTER, name="register_enabled", value="1"),
|
||||
Setting(type=SettingsType.AUTH, name="secret_key", value="test_secret_key_for_jwt_token_generation"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_email_password_enabled", value="1"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_phone_sms_enabled", value="0"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_passkey_enabled", value="0"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_magic_link_enabled", value="0"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_password_required", value="1"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_phone_binding_required", value="0"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_email_binding_required", value="1"),
|
||||
Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"),
|
||||
Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"),
|
||||
]
|
||||
for setting in settings:
|
||||
test_session.add(setting)
|
||||
@@ -183,7 +193,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
test_user = User(
|
||||
id=uuid4(),
|
||||
email="testuser@test.local",
|
||||
password=Password.hash("testpass123"),
|
||||
nickname="测试用户",
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
@@ -196,7 +205,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
admin_user = User(
|
||||
id=uuid4(),
|
||||
email="admin@disknext.local",
|
||||
password=Password.hash("adminpass123"),
|
||||
nickname="管理员",
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
@@ -209,7 +217,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
banned_user = User(
|
||||
id=uuid4(),
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("banned123"),
|
||||
nickname="封禁用户",
|
||||
status=UserStatus.ADMIN_BANNED,
|
||||
storage=0,
|
||||
@@ -226,7 +233,40 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
await test_session.refresh(admin_user)
|
||||
await test_session.refresh(banned_user)
|
||||
|
||||
# 7. 创建用户根目录
|
||||
# 7. 创建认证身份
|
||||
test_user_identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="testuser@test.local",
|
||||
credential=Password.hash("testpass123"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=test_user.id,
|
||||
)
|
||||
test_session.add(test_user_identity)
|
||||
|
||||
admin_user_identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="admin@disknext.local",
|
||||
credential=Password.hash("adminpass123"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=admin_user.id,
|
||||
)
|
||||
test_session.add(admin_user_identity)
|
||||
|
||||
banned_user_identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="banneduser@test.local",
|
||||
credential=Password.hash("banned123"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=banned_user.id,
|
||||
)
|
||||
test_session.add(banned_user_identity)
|
||||
|
||||
await test_session.commit()
|
||||
|
||||
# 8. 创建用户根目录
|
||||
test_user_root = Object(
|
||||
id=uuid4(),
|
||||
name="/",
|
||||
@@ -251,7 +291,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
|
||||
await test_session.commit()
|
||||
|
||||
# 8. 设置JWT密钥(从数据库加载)
|
||||
# 9. 设置JWT密钥(从数据库加载)
|
||||
JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation"
|
||||
|
||||
# 刷新 group options
|
||||
|
||||
@@ -18,7 +18,6 @@ async def test_user_curd():
|
||||
|
||||
test_user = User(
|
||||
email='test_user@test.local',
|
||||
password='test_password',
|
||||
group_id=created_group.id
|
||||
)
|
||||
|
||||
@@ -28,7 +27,6 @@ async def test_user_curd():
|
||||
# 验证用户是否存在
|
||||
assert created_user.id is not None
|
||||
assert created_user.email == 'test_user@test.local'
|
||||
assert created_user.password == 'test_password'
|
||||
assert created_user.group_id == created_group.id
|
||||
|
||||
# 测试查 Read
|
||||
@@ -36,18 +34,16 @@ async def test_user_curd():
|
||||
|
||||
assert fetched_user is not None
|
||||
assert fetched_user.email == 'test_user@test.local'
|
||||
assert fetched_user.password == 'test_password'
|
||||
assert fetched_user.group_id == created_group.id
|
||||
|
||||
# 测试改 Update
|
||||
updated_user = await fetched_user.update(
|
||||
session,
|
||||
{"email": "updated_user@test.local", "password": "updated_password"}
|
||||
{"email": "updated_user@test.local"}
|
||||
)
|
||||
|
||||
assert updated_user is not None
|
||||
assert updated_user.email == 'updated_user@test.local'
|
||||
assert updated_user.password == 'updated_password'
|
||||
|
||||
# 测试删除 Delete
|
||||
await updated_user.delete(session)
|
||||
|
||||
@@ -19,7 +19,7 @@ async def test_object_create_folder(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(
|
||||
@@ -53,7 +53,7 @@ async def test_object_create_file(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(
|
||||
@@ -98,7 +98,7 @@ async def test_object_is_file_property(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -125,7 +125,7 @@ async def test_object_is_folder_property(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -151,7 +151,7 @@ async def test_object_get_root(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="rootuser", password="password", group_id=group.id)
|
||||
user = User(email="rootuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -183,7 +183,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="pathuser", password="password", group_id=group.id)
|
||||
user = User(email="pathuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -214,7 +214,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="nesteduser", password="password", group_id=group.id)
|
||||
user = User(email="nesteduser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -277,7 +277,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="notfounduser", password="password", group_id=group.id)
|
||||
user = User(email="notfounduser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -311,7 +311,7 @@ async def test_object_get_children(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="childrenuser", password="password", group_id=group.id)
|
||||
user = User(email="childrenuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -363,7 +363,7 @@ async def test_object_parent_child_relationship(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="reluser", password="password", group_id=group.id)
|
||||
user = User(email="reluser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -408,7 +408,7 @@ async def test_object_unique_constraint(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="uniqueuser", password="password", group_id=group.id)
|
||||
user = User(email="uniqueuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -456,7 +456,7 @@ async def test_object_get_full_path(db_session: AsyncSession):
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="pathuser", password="password", group_id=group.id)
|
||||
user = User(email="pathuser", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
|
||||
@@ -20,7 +20,6 @@ async def test_user_create(db_session: AsyncSession):
|
||||
user = User(
|
||||
email="testuser@test.local",
|
||||
nickname="测试用户",
|
||||
password="hashed_password",
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -43,7 +42,6 @@ async def test_user_unique_email(db_session: AsyncSession):
|
||||
# 创建第一个用户
|
||||
user1 = User(
|
||||
email="duplicate@test.local",
|
||||
password="password1",
|
||||
group_id=group.id
|
||||
)
|
||||
await user1.save(db_session)
|
||||
@@ -51,7 +49,6 @@ async def test_user_unique_email(db_session: AsyncSession):
|
||||
# 尝试创建同名用户
|
||||
user2 = User(
|
||||
email="duplicate@test.local",
|
||||
password="password2",
|
||||
group_id=group.id
|
||||
)
|
||||
|
||||
@@ -70,7 +67,6 @@ async def test_user_to_public(db_session: AsyncSession):
|
||||
user = User(
|
||||
email="publicuser@test.local",
|
||||
nickname="公开用户",
|
||||
password="secret_password",
|
||||
storage=1024,
|
||||
avatar="avatar.jpg",
|
||||
group_id=group.id
|
||||
@@ -88,8 +84,6 @@ async def test_user_to_public(db_session: AsyncSession):
|
||||
# 这是已知的设计问题,需要在 UserPublic 中添加别名或重命名字段
|
||||
assert public_user.nick is None # 实际行为
|
||||
assert public_user.storage == 1024
|
||||
# 密码不应该在公开数据中
|
||||
assert not hasattr(public_user, 'password')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -102,7 +96,6 @@ async def test_user_group_relationship(db_session: AsyncSession):
|
||||
# 创建用户
|
||||
user = User(
|
||||
email="vipuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -126,7 +119,6 @@ async def test_user_status_default(db_session: AsyncSession):
|
||||
|
||||
user = User(
|
||||
email="defaultuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -142,7 +134,6 @@ async def test_user_storage_default(db_session: AsyncSession):
|
||||
|
||||
user = User(
|
||||
email="storageuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -159,7 +150,6 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
# 测试默认值
|
||||
user1 = User(
|
||||
email="user1@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
user1 = await user1.save(db_session)
|
||||
@@ -168,7 +158,6 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
# 测试设置为 LIGHT
|
||||
user2 = User(
|
||||
email="user2@test.local",
|
||||
password="password",
|
||||
theme=ThemeType.LIGHT,
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -178,9 +167,40 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
# 测试设置为 DARK
|
||||
user3 = User(
|
||||
email="user3@test.local",
|
||||
password="password",
|
||||
theme=ThemeType.DARK,
|
||||
group_id=group.id
|
||||
)
|
||||
user3 = await user3.save(db_session)
|
||||
assert user3.theme == ThemeType.DARK
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_email_optional(db_session: AsyncSession):
|
||||
"""测试 email 可以为空(支持社交登录用户)"""
|
||||
group = Group(name="默认组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(
|
||||
nickname="社交用户",
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
assert user.id is not None
|
||||
assert user.email is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_phone_field(db_session: AsyncSession):
|
||||
"""测试 phone 字段"""
|
||||
group = Group(name="默认组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(
|
||||
email="phoneuser@test.local",
|
||||
phone="13800138000",
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
assert user.phone == "13800138000"
|
||||
|
||||
@@ -1,78 +1,154 @@
|
||||
"""
|
||||
Login 服务的单元测试
|
||||
|
||||
测试 unified_login() 各 provider 路径。
|
||||
"""
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.user import User, LoginRequest, TokenResponse, UserStatus
|
||||
from sqlmodels.group import Group
|
||||
from service.user.login import login
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.setting import Setting, SettingsType
|
||||
from sqlmodels.user import User, UnifiedLoginRequest, TokenResponse, UserStatus
|
||||
from sqlmodels.group import Group, GroupOptions
|
||||
from service.user.login import unified_login
|
||||
from utils.password.pwd import Password
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_user(db_session: AsyncSession):
|
||||
"""创建测试用户"""
|
||||
async def setup_auth_settings(db_session: AsyncSession):
|
||||
"""创建认证相关的 Setting 配置"""
|
||||
settings = [
|
||||
Setting(type=SettingsType.AUTH, name="auth_email_password_enabled", value="1"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_phone_sms_enabled", value="0"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_passkey_enabled", value="0"),
|
||||
Setting(type=SettingsType.AUTH, name="auth_magic_link_enabled", value="0"),
|
||||
Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"),
|
||||
Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"),
|
||||
]
|
||||
for s in settings:
|
||||
await s.save(db_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_user(db_session: AsyncSession, setup_auth_settings):
|
||||
"""创建测试用户和邮箱密码认证身份"""
|
||||
# 创建用户组
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
# 创建用户组选项
|
||||
group_options = GroupOptions(
|
||||
group_id=group.id,
|
||||
share_download=True,
|
||||
share_free=False,
|
||||
relocate=False,
|
||||
)
|
||||
await group_options.save(db_session)
|
||||
|
||||
# 创建正常用户
|
||||
plain_password = "secure_password_123"
|
||||
user = User(
|
||||
email="loginuser@test.local",
|
||||
password=Password.hash(plain_password),
|
||||
status=UserStatus.ACTIVE,
|
||||
group_id=group.id
|
||||
group_id=group.id,
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
# 创建邮箱密码认证身份
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="loginuser@test.local",
|
||||
credential=Password.hash(plain_password),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(db_session)
|
||||
|
||||
return {
|
||||
"user": user,
|
||||
"password": plain_password,
|
||||
"group_id": group.id
|
||||
"group_id": group.id,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_banned_user(db_session: AsyncSession):
|
||||
async def setup_banned_user(db_session: AsyncSession, setup_auth_settings):
|
||||
"""创建被封禁的用户"""
|
||||
group = Group(name="测试组2")
|
||||
group = await group.save(db_session)
|
||||
|
||||
group_options = GroupOptions(
|
||||
group_id=group.id,
|
||||
share_download=True,
|
||||
share_free=False,
|
||||
relocate=False,
|
||||
)
|
||||
await group_options.save(db_session)
|
||||
|
||||
user = User(
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=UserStatus.ADMIN_BANNED, # 封禁状态
|
||||
group_id=group.id
|
||||
status=UserStatus.ADMIN_BANNED,
|
||||
group_id=group.id,
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="banneduser@test.local",
|
||||
credential=Password.hash("password"),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(db_session)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_2fa_user(db_session: AsyncSession):
|
||||
async def setup_2fa_user(db_session: AsyncSession, setup_auth_settings):
|
||||
"""创建启用了两步验证的用户"""
|
||||
import pyotp
|
||||
|
||||
group = Group(name="测试组3")
|
||||
group = await group.save(db_session)
|
||||
|
||||
group_options = GroupOptions(
|
||||
group_id=group.id,
|
||||
share_download=True,
|
||||
share_free=False,
|
||||
relocate=False,
|
||||
)
|
||||
await group_options.save(db_session)
|
||||
|
||||
secret = pyotp.random_base32()
|
||||
user = User(
|
||||
email="2fauser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=UserStatus.ACTIVE,
|
||||
two_factor=secret,
|
||||
group_id=group.id
|
||||
group_id=group.id,
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
# 创建带 2FA secret 的邮箱密码认证身份
|
||||
import orjson
|
||||
extra_data = orjson.dumps({"two_factor": secret}).decode('utf-8')
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="2fauser@test.local",
|
||||
credential=Password.hash("password"),
|
||||
extra_data=extra_data,
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(db_session)
|
||||
|
||||
return {
|
||||
"user": user,
|
||||
"secret": secret,
|
||||
"password": "password"
|
||||
"password": "password",
|
||||
}
|
||||
|
||||
|
||||
@@ -81,12 +157,13 @@ async def test_login_success(db_session: AsyncSession, setup_user):
|
||||
"""测试正常登录"""
|
||||
user_data = setup_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
email="loginuser@test.local",
|
||||
password=user_data["password"]
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="loginuser@test.local",
|
||||
credential=user_data["password"],
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
result = await unified_login(db_session, request)
|
||||
|
||||
assert isinstance(result, TokenResponse)
|
||||
assert result.access_token is not None
|
||||
@@ -96,42 +173,48 @@ async def test_login_success(db_session: AsyncSession, setup_user):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_user_not_found(db_session: AsyncSession):
|
||||
async def test_login_user_not_found(db_session: AsyncSession, setup_user):
|
||||
"""测试用户不存在"""
|
||||
login_request = LoginRequest(
|
||||
email="nonexistent@test.local",
|
||||
password="any_password"
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="nonexistent@test.local",
|
||||
credential="any_password",
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
assert result is None
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_wrong_password(db_session: AsyncSession, setup_user):
|
||||
"""测试密码错误"""
|
||||
login_request = LoginRequest(
|
||||
email="loginuser@test.local",
|
||||
password="wrong_password"
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="loginuser@test.local",
|
||||
credential="wrong_password",
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
assert result is None
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_user_banned(db_session: AsyncSession, setup_banned_user):
|
||||
"""测试用户被封禁"""
|
||||
login_request = LoginRequest(
|
||||
email="banneduser@test.local",
|
||||
password="password"
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="banneduser@test.local",
|
||||
credential="password",
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
assert result is False
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -139,15 +222,17 @@ async def test_login_2fa_required(db_session: AsyncSession, setup_2fa_user):
|
||||
"""测试需要 2FA"""
|
||||
user_data = setup_2fa_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"]
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="2fauser@test.local",
|
||||
credential=user_data["password"],
|
||||
# 未提供 two_fa_code
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
assert result == "2fa_required"
|
||||
assert exc_info.value.status_code == 428
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -155,15 +240,17 @@ async def test_login_2fa_invalid(db_session: AsyncSession, setup_2fa_user):
|
||||
"""测试 2FA 错误"""
|
||||
user_data = setup_2fa_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"],
|
||||
two_fa_code="000000" # 错误的验证码
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="2fauser@test.local",
|
||||
credential=user_data["password"],
|
||||
two_fa_code="000000",
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
assert result == "2fa_invalid"
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -178,56 +265,44 @@ async def test_login_2fa_success(db_session: AsyncSession, setup_2fa_user):
|
||||
totp = pyotp.TOTP(secret)
|
||||
valid_code = totp.now()
|
||||
|
||||
login_request = LoginRequest(
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"],
|
||||
two_fa_code=valid_code
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="2fauser@test.local",
|
||||
credential=user_data["password"],
|
||||
two_fa_code=valid_code,
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
result = await unified_login(db_session, request)
|
||||
|
||||
assert isinstance(result, TokenResponse)
|
||||
assert result.access_token is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
|
||||
"""测试返回的令牌可以被解码"""
|
||||
import jwt as pyjwt
|
||||
|
||||
user_data = setup_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
email="loginuser@test.local",
|
||||
password=user_data["password"]
|
||||
async def test_login_provider_disabled(db_session: AsyncSession, setup_user):
|
||||
"""测试未启用的 provider"""
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.PHONE_SMS,
|
||||
identifier="13800138000",
|
||||
credential="123456",
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
assert isinstance(result, TokenResponse)
|
||||
|
||||
# 注意: 实际项目中需要使用正确的 SECRET_KEY
|
||||
# 这里假设测试环境已经设置了 SECRET_KEY
|
||||
# decoded = pyjwt.decode(
|
||||
# result.access_token,
|
||||
# SECRET_KEY,
|
||||
# algorithms=["HS256"]
|
||||
# )
|
||||
# assert decoded["sub"] == "loginuser"
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_case_sensitive_email(db_session: AsyncSession, setup_user):
|
||||
"""测试邮箱大小写敏感"""
|
||||
user_data = setup_user
|
||||
|
||||
# 使用大写邮箱登录
|
||||
login_request = LoginRequest(
|
||||
email="LOGINUSER@TEST.LOCAL",
|
||||
password=user_data["password"]
|
||||
async def test_login_missing_password(db_session: AsyncSession, setup_user):
|
||||
"""测试邮箱密码登录缺少密码"""
|
||||
request = UnifiedLoginRequest(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier="loginuser@test.local",
|
||||
# 未提供 credential
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await unified_login(db_session, request)
|
||||
|
||||
# 应该失败,因为邮箱大小写不匹配
|
||||
assert result is None
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
Reference in New Issue
Block a user