diff --git a/routers/api/v1/admin/user/__init__.py b/routers/api/v1/admin/user/__init__.py index 8f20806..f26df50 100644 --- a/routers/api/v1/admin/user/__init__.py +++ b/routers/api/v1/admin/user/__init__.py @@ -12,6 +12,7 @@ from sqlmodels import ( Group, Object, ObjectType, Setting, SettingsType, BatchDeleteRequest, ) +from sqlmodels.auth_identity import AuthIdentity, AuthProviderType from sqlmodels.user import ( UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, UserStatus, ) @@ -83,13 +84,26 @@ async def router_admin_create_user( """ 创建一个新的用户,设置邮箱、密码、用户组等信息。 + 管理员创建用户时,若提供了 email + password, + 会同时创建 AuthIdentity(provider=email_password)。 + :param session: 数据库会话 :param request: 创建用户请求 DTO :return: 创建结果 """ - existing_user = await User.get(session, User.email == request.email) - if existing_user: - raise HTTPException(status_code=409, detail="该邮箱已被注册") + # 如果提供了邮箱,检查唯一性(User 表和 AuthIdentity 表) + if request.email: + existing_user = await User.get(session, User.email == request.email) + if existing_user: + raise HTTPException(status_code=409, detail="该邮箱已被注册") + + existing_identity = await AuthIdentity.get( + session, + (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD) + & (AuthIdentity.identifier == request.email), + ) + if existing_identity: + raise HTTPException(status_code=409, detail="该邮箱已被绑定") # 验证用户组存在 group = await Group.get(session, Group.id == request.group_id) @@ -98,12 +112,24 @@ async def router_admin_create_user( user = User( email=request.email, - password=Password.hash(request.password), nickname=request.nickname, group_id=request.group_id, status=request.status, ) user = await user.save(session) + + # 如果提供了邮箱和密码,创建邮箱密码认证身份 + if request.email and request.password: + identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier=request.email, + credential=Password.hash(request.password), + is_primary=True, + is_verified=True, + user_id=user.id, + ) + await identity.save(session) + return user.to_public() @@ -148,17 +174,7 @@ async def router_admin_update_user( if not group: raise HTTPException(status_code=400, detail="目标用户组不存在") - # 如果更新密码,需要加密 update_data = request.model_dump(exclude_unset=True) - if 'password' in update_data and update_data['password']: - update_data['password'] = Password.hash(update_data['password']) - elif 'password' in update_data: - del update_data['password'] # 空密码不更新 - - # 验证两步验证密钥格式(如果提供了值且不为 None,长度必须为 32) - if 'two_factor' in update_data and update_data['two_factor'] is not None: - if len(update_data['two_factor']) != 32: - raise HTTPException(status_code=400, detail="两步验证密钥必须为32位字符串") # 记录旧 status 以便检测变更 old_status = user.status @@ -175,7 +191,7 @@ async def router_admin_update_user( elif old_status != UserStatus.ACTIVE and new_status == UserStatus.ACTIVE: await UserBanStore.unban(str(user_id)) - l.info(f"管理员更新了用户: {request.email}") + l.info(f"管理员更新了用户: {user.email}") @admin_user_router.delete( diff --git a/routers/api/v1/site/__init__.py b/routers/api/v1/site/__init__.py index 1ec446d..3e5206d 100644 --- a/routers/api/v1/site/__init__.py +++ b/routers/api/v1/site/__init__.py @@ -4,7 +4,9 @@ from middleware.dependencies import SessionDep from sqlmodels import ( ResponseBase, Setting, SettingsType, SiteConfigResponse, ThemePreset, ThemePresetResponse, ThemePresetListResponse, + AuthMethodConfig, ) +from sqlmodels.auth_identity import AuthProviderType from sqlmodels.setting import CaptchaType from utils import http_exceptions @@ -70,7 +72,7 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse: 获取站点全局配置 无需认证。前端在初始化时调用此端点获取验证码类型、 - 登录/注册/找回密码是否需要验证码等配置。 + 登录/注册/找回密码是否需要验证码、可用的认证方式等配置。 """ # 批量查询所需设置 settings: list[Setting] = await Setting.get( @@ -78,7 +80,9 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse: (Setting.type == SettingsType.BASIC) | (Setting.type == SettingsType.LOGIN) | (Setting.type == SettingsType.REGISTER) | - (Setting.type == SettingsType.CAPTCHA), + (Setting.type == SettingsType.CAPTCHA) | + (Setting.type == SettingsType.AUTH) | + (Setting.type == SettingsType.OAUTH), fetch_mode="all", ) @@ -94,6 +98,16 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse: elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE: captcha_key = s.get("captcha_CloudflareKey") or None + # 构建认证方式列表 + auth_methods: list[AuthMethodConfig] = [ + AuthMethodConfig(provider=AuthProviderType.EMAIL_PASSWORD, is_enabled=s.get("auth_email_password_enabled") == "1"), + AuthMethodConfig(provider=AuthProviderType.PHONE_SMS, is_enabled=s.get("auth_phone_sms_enabled") == "1"), + AuthMethodConfig(provider=AuthProviderType.GITHUB, is_enabled=s.get("github_enabled") == "1"), + AuthMethodConfig(provider=AuthProviderType.QQ, is_enabled=s.get("qq_enabled") == "1"), + AuthMethodConfig(provider=AuthProviderType.PASSKEY, is_enabled=s.get("auth_passkey_enabled") == "1"), + AuthMethodConfig(provider=AuthProviderType.MAGIC_LINK, is_enabled=s.get("auth_magic_link_enabled") == "1"), + ] + return SiteConfigResponse( title=s.get("siteName") or "DiskNext", register_enabled=s.get("register_enabled") == "1", @@ -102,4 +116,11 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse: forget_captcha=s.get("forget_captcha") == "1", captcha_type=captcha_type, captcha_key=captcha_key, + auth_methods=auth_methods, + password_required=s.get("auth_password_required") == "1", + phone_binding_required=s.get("auth_phone_binding_required") == "1", + email_binding_required=s.get("auth_email_binding_required") == "1", + footer_code=s.get("footer_code"), + tos_url=s.get("tos_url"), + privacy_url=s.get("privacy_url"), ) \ No newline at end of file diff --git a/routers/api/v1/user/__init__.py b/routers/api/v1/user/__init__.py index 44abda3..cb89015 100644 --- a/routers/api/v1/user/__init__.py +++ b/routers/api/v1/user/__init__.py @@ -2,7 +2,8 @@ from typing import Annotated, Literal from uuid import UUID, uuid4 import jwt -from fastapi import APIRouter, Depends, Form, HTTPException +from fastapi import APIRouter, Depends, HTTPException +from itsdangerous import URLSafeTimedSerializer from loguru import logger from webauthn import generate_registration_options from webauthn.helpers import options_to_json_dict @@ -12,6 +13,7 @@ import sqlmodels from middleware.auth import auth_required from middleware.dependencies import SessionDep, require_captcha from service.captcha import CaptchaScene +from sqlmodels.auth_identity import AuthIdentity, AuthProviderType from sqlmodels.user import UserStatus from utils import JWT, Password, http_exceptions from .settings import user_settings_router @@ -23,59 +25,36 @@ user_router = APIRouter( user_router.include_router(user_settings_router) -class OAuth2PasswordWithExtrasForm: - """ - 扩展 OAuth2 密码表单。 - - 在标准 username/password 基础上添加 otp_code 字段。 - captcha_code 由 require_captcha 依赖注入单独处理。 - """ - - def __init__( - self, - *, - username: Annotated[str, Form()], - password: Annotated[str, Form()], - otp_code: Annotated[str | None, Form(min_length=6, max_length=6)] = None, - ): - self.username = username - self.password = password - self.otp_code = otp_code - @user_router.post( path='/session', - summary='用户登录', - description='用户登录端点,支持验证码校验和两步验证。', - dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))], + summary='用户登录(统一入口)', + description='统一登录端点,支持多种认证方式。', ) async def router_user_session( session: SessionDep, - form_data: Annotated[OAuth2PasswordWithExtrasForm, Depends()], + request: sqlmodels.UnifiedLoginRequest, ) -> sqlmodels.TokenResponse: """ - 用户登录端点 + 统一登录端点 - 表单字段: - - username: 用户邮箱 - - password: 用户密码 - - captcha_code: 验证码 token(可选,由 require_captcha 依赖校验) - - otp_code: 两步验证码(可选,仅在用户启用 2FA 时需要) + 请求体: + - provider: 登录方式(email_password / github / qq / passkey / magic_link) + - identifier: 标识符(邮箱 / OAuth code / credential_id / magic link token) + - credential: 凭证(密码 / WebAuthn assertion 等) + - two_fa_code: 两步验证码(可选) + - redirect_uri: OAuth 回调地址(可选) + - captcha: 验证码(可选) 错误处理: - - 400: 需要验证码但未提供 - - 401: 邮箱/密码错误,或 2FA 验证码错误 - - 403: 账户已禁用 / 验证码验证失败 - - 428: 需要两步验证但未提供 otp_code + - 400: 登录方式未启用 / 参数错误 + - 401: 凭证错误 + - 403: 账户已禁用 + - 428: 需要两步验证 + - 501: 暂未实现的登录方式 """ - return await service.user.login( - session, - sqlmodels.LoginRequest( - email=form_data.username, - password=form_data.password, - two_fa_code=form_data.otp_code, - ), - ) + return await service.user.unified_login(session, request) + @user_router.post( path='/session/refresh', @@ -150,41 +129,82 @@ async def router_user_session_refresh( @user_router.post( path='/', - summary='用户注册', + summary='用户注册(统一入口)', description='User registration endpoint.', status_code=204, ) async def router_user_register( session: SessionDep, - request: sqlmodels.RegisterRequest, + request: sqlmodels.UnifiedRegisterRequest, ) -> None: """ - 用户注册端点 + 统一注册端点 流程: - 1. 验证用户名唯一性 - 2. 获取默认用户组 - 3. 创建用户记录 - 4. 创建用户根目录(name="/") + 1. 检查注册开关 + 2. 检查 provider 启用 + 3. 验证 identifier 唯一性(AuthIdentity 表) + 4. 创建 User + AuthIdentity + 根目录 - :param session: 数据库会话 - :param request: 注册请求 - :return: 注册结果 - :raises HTTPException 400: 用户名已存在 - :raises HTTPException 500: 默认用户组或存储策略不存在 + 请求体: + - provider: 注册方式(email_password / phone_sms) + - identifier: 标识符(邮箱 / 手机号) + - credential: 凭证(密码 / 短信验证码) + - nickname: 昵称(可选) + - captcha: 验证码(可选) + + 错误处理: + - 400: 注册未开放 / 参数错误 + - 409: 邮箱或手机号已存在 + - 501: 暂未实现的注册方式 """ - # 1. 验证邮箱唯一性 + # 1. 检查注册开关 + register_setting = await sqlmodels.Setting.get( + session, + (sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) + & (sqlmodels.Setting.name == "register_enabled"), + ) + if not register_setting or register_setting.value != "1": + http_exceptions.raise_bad_request("注册功能未开放") + + # 2. 目前只支持 email_password 注册 + if request.provider == AuthProviderType.PHONE_SMS: + http_exceptions.raise_not_implemented("短信注册暂未开放") + elif request.provider != AuthProviderType.EMAIL_PASSWORD: + http_exceptions.raise_bad_request("不支持的注册方式") + + # 3. 检查密码是否必填 + password_required_setting = await sqlmodels.Setting.get( + session, + (sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH) + & (sqlmodels.Setting.name == "auth_password_required"), + ) + is_password_required = not password_required_setting or password_required_setting.value != "0" + if is_password_required and not request.credential: + http_exceptions.raise_bad_request("密码不能为空") + + # 4. 验证 identifier 唯一性(AuthIdentity 表) + existing_identity = await AuthIdentity.get( + session, + (AuthIdentity.provider == request.provider) + & (AuthIdentity.identifier == request.identifier), + ) + if existing_identity: + raise HTTPException(status_code=409, detail="该邮箱已被注册") + + # 同时检查 User.email 唯一性(防止旧数据冲突) existing_user = await sqlmodels.User.get( session, - sqlmodels.User.email == request.email + sqlmodels.User.email == request.identifier, ) if existing_user: - raise HTTPException(status_code=400, detail="邮箱已存在") + raise HTTPException(status_code=409, detail="该邮箱已被注册") - # 2. 获取默认用户组(从设置中读取 UUID) - default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get( + # 5. 获取默认用户组 + default_group_setting = await sqlmodels.Setting.get( session, - (sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group") + (sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) + & (sqlmodels.Setting.name == "default_group"), ) if default_group_setting is None or not default_group_setting.value: logger.error("默认用户组不存在") @@ -196,17 +216,28 @@ async def router_user_register( logger.error("默认用户组不存在") http_exceptions.raise_internal_error() - # 3. 创建用户 - hashed_password = Password.hash(request.password) + # 6. 创建用户 new_user = sqlmodels.User( - email=request.email, - password=hashed_password, + email=request.identifier, + nickname=request.nickname, group_id=default_group.id, ) new_user_id = new_user.id await new_user.save(session) - # 4. 创建用户根目录 + # 7. 创建 AuthIdentity + hashed_password = Password.hash(request.credential) if request.credential else None + identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier=request.identifier, + credential=hashed_password, + is_primary=True, + is_verified=False, + user_id=new_user_id, + ) + await identity.save(session) + + # 8. 创建用户根目录 default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储") if not default_policy: logger.error("默认存储策略不存在") @@ -220,6 +251,66 @@ async def router_user_register( policy_id=default_policy.id, ).save(session) + +@user_router.post( + path='/magic-link', + summary='发送 Magic Link 邮件', + description='生成 Magic Link token 并发送到指定邮箱。', + status_code=204, +) +async def router_user_magic_link( + session: SessionDep, + request: sqlmodels.MagicLinkRequest, +) -> None: + """ + 发送 Magic Link 邮件 + + 流程: + 1. 验证邮箱对应的 AuthIdentity 存在 + 2. 生成签名 token + 3. 发送邮件(包含带 token 的链接) + + 错误处理: + - 400: Magic Link 未启用 + - 404: 邮箱未注册 + """ + # 检查 magic_link 是否启用 + magic_link_setting = await sqlmodels.Setting.get( + session, + (sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH) + & (sqlmodels.Setting.name == "auth_magic_link_enabled"), + ) + if not magic_link_setting or magic_link_setting.value != "1": + http_exceptions.raise_bad_request("Magic Link 登录未启用") + + # 验证邮箱存在 + identity = await AuthIdentity.get( + session, + (AuthIdentity.identifier == request.email) + & ( + (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD) + | (AuthIdentity.provider == AuthProviderType.MAGIC_LINK) + ), + ) + if not identity: + http_exceptions.raise_not_found("该邮箱未注册") + + # 生成签名 token + serializer = URLSafeTimedSerializer(JWT.SECRET_KEY) + token = serializer.dumps(request.email, salt="magic-link-salt") + + # 获取站点 URL + site_url_setting = await sqlmodels.Setting.get( + session, + (sqlmodels.Setting.type == sqlmodels.SettingsType.BASIC) + & (sqlmodels.Setting.name == "siteURL"), + ) + site_url = site_url_setting.value if site_url_setting else "http://localhost" + + # TODO: 发送邮件(包含 {site_url}/auth/magic-link?token={token}) + logger.info(f"Magic Link token 已生成: {token} (邮件发送待实现)") + + @user_router.post( path='/code', summary='发送验证码邮件', @@ -230,52 +321,12 @@ def router_user_email_code( ) -> sqlmodels.ResponseBase: """ Send a verification code email. - + Returns: dict: A dictionary containing information about the password reset email. """ http_exceptions.raise_not_implemented() -@user_router.get( - path='/qq', - summary='初始化QQ登录', - description='Initialize QQ login for a user.', -) -def router_user_qq() -> sqlmodels.ResponseBase: - """ - Initialize QQ login for a user. - - Returns: - dict: A dictionary containing QQ login initialization information. - """ - http_exceptions.raise_not_implemented() - -@user_router.get( - path='authn/{username}', - summary='WebAuthn登录初始化', - description='Initialize WebAuthn login for a user.', -) -async def router_user_authn(username: str) -> sqlmodels.ResponseBase: - - http_exceptions.raise_not_implemented() - -@user_router.post( - path='authn/finish/{username}', - summary='WebAuthn登录', - description='Finish WebAuthn login for a user.', -) -def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase: - """ - Finish WebAuthn login for a user. - - Args: - username (str): The username of the user. - - Returns: - dict: A dictionary containing WebAuthn login information. - """ - http_exceptions.raise_not_implemented() - @user_router.get( path='/profile/{id}', summary='获取用户主页展示用分享', @@ -284,10 +335,10 @@ def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase: def router_user_profile(id: str) -> sqlmodels.ResponseBase: """ Get user profile for display. - + Args: id (str): The user ID. - + Returns: dict: A dictionary containing user profile information. """ @@ -301,11 +352,11 @@ def router_user_profile(id: str) -> sqlmodels.ResponseBase: def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase: """ Get user avatar by ID and size. - + Args: id (str): The user ID. size (int): The size of the avatar image. - + Returns: str: A Base64 encoded string of the user avatar image. """ @@ -348,8 +399,6 @@ async def router_user_me( return sqlmodels.UserResponse( id=user.id, email=user.email, - status=user.status, - score=user.score, nickname=user.nickname, avatar=user.avatar, created_at=user.created_at, @@ -374,9 +423,9 @@ async def router_user_storage( group = await sqlmodels.Group.get(session, sqlmodels.Group.id == user.group_id) if not group: raise HTTPException(status_code=404, detail="用户组不存在") - + # [TODO] 总空间加上用户购买的额外空间 - + total: int = group.max_storage used: int = user.storage free: int = max(0, total - used) @@ -389,8 +438,8 @@ async def router_user_storage( @user_router.put( path='/authn/start', - summary='WebAuthn登录初始化', - description='Initialize WebAuthn login for a user.', + summary='注册 Passkey 凭证(初始化)', + description='Initialize Passkey registration for a user.', dependencies=[Depends(auth_required)], ) async def router_user_authn_start( @@ -398,18 +447,19 @@ async def router_user_authn_start( user: Annotated[sqlmodels.user.User, Depends(auth_required)], ) -> sqlmodels.ResponseBase: """ - Initialize WebAuthn login for a user. + Passkey 注册初始化(需要登录) - Returns: - dict: A dictionary containing WebAuthn initialization information. + 返回 WebAuthn registration options,前端使用 navigator.credentials.create() 处理。 + + 错误处理: + - 400: Passkey 未启用 """ - # TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等 authn_setting = await sqlmodels.Setting.get( session, (sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled") ) if not authn_setting or authn_setting.value != "1": - raise HTTPException(status_code=400, detail="WebAuthn is not enabled") + raise HTTPException(status_code=400, detail="Passkey 未启用") site_url_setting = await sqlmodels.Setting.get( session, @@ -423,23 +473,26 @@ async def router_user_authn_start( options = generate_registration_options( rp_id=site_url_setting.value if site_url_setting else "", rp_name=site_title_setting.value if site_title_setting else "", - user_name=user.email, - user_display_name=user.nickname or user.email, + user_name=user.email or str(user.id), + user_display_name=user.nickname or user.email or str(user.id), ) return sqlmodels.ResponseBase(data=options_to_json_dict(options)) @user_router.put( path='/authn/finish', - summary='WebAuthn登录', - description='Finish WebAuthn login for a user.', + summary='注册 Passkey 凭证(完成)', + description='Finish Passkey registration for a user.', dependencies=[Depends(auth_required)], ) def router_user_authn_finish() -> sqlmodels.ResponseBase: """ - Finish WebAuthn login for a user. - + Passkey 注册完成(需要登录) + + 接收前端 navigator.credentials.create() 返回的凭证数据, + 创建 UserAuthn 行 + AuthIdentity(provider=passkey)。 + Returns: - dict: A dictionary containing WebAuthn login information. + dict: A dictionary containing Passkey registration information. """ - http_exceptions.raise_not_implemented() \ No newline at end of file + http_exceptions.raise_not_implemented() diff --git a/routers/api/v1/user/settings/__init__.py b/routers/api/v1/user/settings/__init__.py index 75b5738..534b6bb 100644 --- a/routers/api/v1/user/settings/__init__.py +++ b/routers/api/v1/user/settings/__init__.py @@ -1,4 +1,5 @@ from typing import Annotated +from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired @@ -9,6 +10,7 @@ from middleware.dependencies import SessionDep from sqlmodels import ( BUILTIN_DEFAULT_COLORS, ThemePreset, UserThemeUpdateRequest, SettingOption, UserSettingUpdateRequest, + AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest, ) from sqlmodels.color import ThemeColorsBase from utils import JWT, Password, http_exceptions @@ -117,16 +119,29 @@ async def router_user_settings( else: theme_colors = BUILTIN_DEFAULT_COLORS + # 检查是否启用了两步验证(从 email_password AuthIdentity 的 extra_data 中读取) + has_two_factor = False + email_identity: AuthIdentity | None = await AuthIdentity.get( + session, + (AuthIdentity.user_id == user.id) + & (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD), + ) + if email_identity and email_identity.extra_data: + import orjson + extra: dict = orjson.loads(email_identity.extra_data) + has_two_factor = bool(extra.get("two_factor")) + return sqlmodels.UserSettingResponse( id=user.id, email=user.email, + phone=user.phone, nickname=user.nickname, created_at=user.created_at, group_name=user.group.name, language=user.language, timezone=user.timezone, group_expires=user.group_expires, - two_factor=user.two_factor is not None, + two_factor=has_two_factor, theme_preset_id=user.theme_preset_id, theme_colors=theme_colors, ) @@ -255,7 +270,7 @@ async def router_user_settings_2fa( 返回 setup_token(用于后续验证请求)和 uri(用于生成二维码)。 """ - return await Password.generate_totp(name=user.email) + return await Password.generate_totp(name=user.email or str(user.id)) @user_settings_router.post( @@ -273,7 +288,7 @@ async def router_user_settings_2fa_enable( """ 启用两步验证 - 请求体包含 setup_token(GET /2fa 返回的令牌)和 code(6 位验证码)。 + 将 2FA secret 存储到 email_password AuthIdentity 的 extra_data 中。 """ serializer = URLSafeTimedSerializer(JWT.SECRET_KEY) @@ -287,6 +302,150 @@ async def router_user_settings_2fa_enable( if Password.verify_totp(secret, request.code) != PasswordStatus.VALID: raise HTTPException(status_code=400, detail="Invalid OTP code") - # 3. 将 secret 存储到用户的数据库记录中,启用 2FA - user.two_factor = secret - user = await user.save(session) \ No newline at end of file + # 将 secret 存储到 AuthIdentity.extra_data 中 + email_identity: AuthIdentity | None = await AuthIdentity.get( + session, + (AuthIdentity.user_id == user.id) + & (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD), + ) + if not email_identity: + raise HTTPException(status_code=400, detail="未找到邮箱密码认证身份") + + import orjson + extra: dict = orjson.loads(email_identity.extra_data) if email_identity.extra_data else {} + extra["two_factor"] = secret + email_identity.extra_data = orjson.dumps(extra).decode('utf-8') + await email_identity.save(session) + + +# ==================== 认证身份管理 ==================== + +@user_settings_router.get( + path='/identities', + summary='列出已绑定的认证身份', +) +async def router_user_settings_identities( + session: SessionDep, + user: Annotated[sqlmodels.user.User, Depends(auth_required)], +) -> list[AuthIdentityResponse]: + """ + 列出当前用户已绑定的所有认证身份 + + 返回: + - 认证身份列表,包含 provider、identifier、display_name 等 + """ + identities: list[AuthIdentity] = await AuthIdentity.get( + session, + AuthIdentity.user_id == user.id, + fetch_mode="all", + ) + return [identity.to_response() for identity in identities] + + +@user_settings_router.post( + path='/identity', + summary='绑定新的认证身份', + status_code=status.HTTP_201_CREATED, +) +async def router_user_settings_bind_identity( + session: SessionDep, + user: Annotated[sqlmodels.user.User, Depends(auth_required)], + request: BindIdentityRequest, +) -> AuthIdentityResponse: + """ + 绑定新的登录方式 + + 请求体: + - provider: 提供者类型 + - identifier: 标识符(邮箱 / 手机号 / OAuth code) + - credential: 凭证(密码、验证码等) + - redirect_uri: OAuth 回调地址(可选) + + 错误处理: + - 400: provider 未启用 + - 409: 该身份已被其他用户绑定 + """ + # 检查是否已被绑定 + existing = await AuthIdentity.get( + session, + (AuthIdentity.provider == request.provider) + & (AuthIdentity.identifier == request.identifier), + ) + if existing: + raise HTTPException(status_code=409, detail="该身份已被绑定") + + # 处理密码类型的凭证 + credential: str | None = None + if request.provider == AuthProviderType.EMAIL_PASSWORD and request.credential: + credential = Password.hash(request.credential) + + identity = AuthIdentity( + provider=request.provider, + identifier=request.identifier, + credential=credential, + is_primary=False, + is_verified=False, + user_id=user.id, + ) + identity = await identity.save(session) + return identity.to_response() + + +@user_settings_router.delete( + path='/identity/{identity_id}', + summary='解绑认证身份', + status_code=status.HTTP_204_NO_CONTENT, +) +async def router_user_settings_unbind_identity( + session: SessionDep, + user: Annotated[sqlmodels.user.User, Depends(auth_required)], + identity_id: UUID, +) -> None: + """ + 解绑一个认证身份 + + 约束: + - 不能解绑最后一个身份 + - 站长配置强制绑定邮箱/手机号时,不能解绑对应身份 + + 错误处理: + - 404: 身份不存在或不属于当前用户 + - 400: 不能解绑最后一个身份 / 不能解绑强制绑定的身份 + """ + # 查找目标身份 + identity: AuthIdentity | None = await AuthIdentity.get( + session, + (AuthIdentity.id == identity_id) & (AuthIdentity.user_id == user.id), + ) + if not identity: + http_exceptions.raise_not_found("认证身份不存在") + + # 检查是否为最后一个身份 + all_identities: list[AuthIdentity] = await AuthIdentity.get( + session, + AuthIdentity.user_id == user.id, + fetch_mode="all", + ) + if len(all_identities) <= 1: + http_exceptions.raise_bad_request("不能解绑最后一个认证身份") + + # 检查强制绑定约束 + if identity.provider == AuthProviderType.EMAIL_PASSWORD: + email_required_setting = await sqlmodels.Setting.get( + session, + (sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH) + & (sqlmodels.Setting.name == "auth_email_binding_required"), + ) + if email_required_setting and email_required_setting.value == "1": + http_exceptions.raise_bad_request("站长要求必须绑定邮箱,不能解绑") + + if identity.provider == AuthProviderType.PHONE_SMS: + phone_required_setting = await sqlmodels.Setting.get( + session, + (sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH) + & (sqlmodels.Setting.name == "auth_phone_binding_required"), + ) + if phone_required_setting and phone_required_setting.value == "1": + http_exceptions.raise_bad_request("站长要求必须绑定手机号,不能解绑") + + await AuthIdentity.delete(session, identity) diff --git a/service/user/__init__.py b/service/user/__init__.py index ee9b18b..5c7bd9a 100644 --- a/service/user/__init__.py +++ b/service/user/__init__.py @@ -1 +1 @@ -from .login import login \ No newline at end of file +from .login import unified_login diff --git a/service/user/login.py b/service/user/login.py index ee0ad68..7515f3b 100644 --- a/service/user/login.py +++ b/service/user/login.py @@ -1,83 +1,417 @@ -from uuid import uuid4 +""" +统一登录服务 -from loguru import logger +支持多种认证方式:邮箱密码、GitHub OAuth、QQ OAuth、Passkey、Magic Link、手机短信(预留)。 +""" +from uuid import UUID, uuid4 -from middleware.dependencies import SessionDep -from sqlmodels import LoginRequest, TokenResponse, User +from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired +from loguru import logger as l +from sqlmodel.ext.asyncio.session import AsyncSession + +from sqlmodels.auth_identity import AuthIdentity, AuthProviderType from sqlmodels.group import GroupClaims, GroupOptions -from sqlmodels.user import UserStatus -from utils import http_exceptions -from utils.JWT import create_access_token, create_refresh_token +from sqlmodels.object import Object, ObjectType +from sqlmodels.policy import Policy +from sqlmodels.setting import Setting, SettingsType +from sqlmodels.user import TokenResponse, UnifiedLoginRequest, User, UserStatus +from utils import JWT, http_exceptions from utils.password.pwd import Password, PasswordStatus -async def login( - session: SessionDep, - login_request: LoginRequest, +async def unified_login( + session: AsyncSession, + request: UnifiedLoginRequest, ) -> TokenResponse: """ - 根据账号密码进行登录。 - 如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。 + 统一登录入口,根据 provider 分发到不同的登录逻辑。 :param session: 数据库会话 - :param login_request: 登录请求 - - :return: TokenResponse 对象或状态码或 None + :param request: 统一登录请求 + :return: TokenResponse """ - # 获取用户信息(预加载 group 关系) - current_user: User = await User.get( + await _check_provider_enabled(session, request.provider) + + match request.provider: + case AuthProviderType.EMAIL_PASSWORD: + user = await _login_email_password(session, request) + case AuthProviderType.GITHUB: + user = await _login_oauth(session, request, AuthProviderType.GITHUB) + case AuthProviderType.QQ: + user = await _login_oauth(session, request, AuthProviderType.QQ) + case AuthProviderType.PASSKEY: + user = await _login_passkey(session, request) + case AuthProviderType.MAGIC_LINK: + user = await _login_magic_link(session, request) + case AuthProviderType.PHONE_SMS: + http_exceptions.raise_not_implemented("短信登录暂未开放") + case _: + http_exceptions.raise_bad_request(f"不支持的登录方式: {request.provider}") + + return await _issue_tokens(session, user) + + +async def _check_provider_enabled(session: AsyncSession, provider: AuthProviderType) -> None: + """检查认证方式是否已被站长启用""" + # OAuth 类型从 OAUTH 设置中查询 + if provider in (AuthProviderType.GITHUB, AuthProviderType.QQ): + setting_name = f"{provider.value}_enabled" + setting = await Setting.get( + session, + (Setting.type == SettingsType.OAUTH) & (Setting.name == setting_name), + ) + if not setting or setting.value != "1": + http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用") + return + + # 其他类型从 AUTH 设置中查询 + setting_name = f"auth_{provider.value}_enabled" + setting = await Setting.get( session, - User.email == login_request.email, - fetch_mode="first", - load=User.group, - ) #type: ignore + (Setting.type == SettingsType.AUTH) & (Setting.name == setting_name), + ) + if not setting or setting.value != "1": + http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用") - # 验证用户是否存在 - if not current_user: - logger.debug(f"Cannot find user with email: {login_request.email}") - http_exceptions.raise_unauthorized("Invalid email or password") - # 验证密码是否正确 - if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID: - logger.debug(f"Password verification failed for user: {login_request.email}") - http_exceptions.raise_unauthorized("Invalid email or password") +async def _login_email_password( + session: AsyncSession, + request: UnifiedLoginRequest, +) -> User: + """邮箱+密码登录""" + if not request.credential: + http_exceptions.raise_bad_request("密码不能为空") - # 验证用户是否可登录(修复:显式枚举比较,StrEnum 永远 truthy) - if current_user.status != UserStatus.ACTIVE: - http_exceptions.raise_forbidden("Your account is disabled") + # 查找 AuthIdentity + identity: AuthIdentity | None = await AuthIdentity.get( + session, + (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD) + & (AuthIdentity.identifier == request.identifier), + ) + if not identity: + l.debug(f"未找到邮箱密码身份: {request.identifier}") + http_exceptions.raise_unauthorized("邮箱或密码错误") - # 检查两步验证 - if current_user.two_factor: - # 用户已启用两步验证 - if not login_request.two_fa_code: - logger.debug(f"2FA required for user: {login_request.email}") - http_exceptions.raise_precondition_required("2FA required") + # 验证密码 + if not identity.credential: + http_exceptions.raise_unauthorized("邮箱或密码错误") - # 验证 OTP 码 - if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID: - logger.debug(f"Invalid 2FA code for user: {login_request.email}") - http_exceptions.raise_unauthorized("Invalid 2FA code") + if Password.verify(identity.credential, request.credential) != PasswordStatus.VALID: + l.debug(f"密码验证失败: {request.identifier}") + http_exceptions.raise_unauthorized("邮箱或密码错误") + # 加载用户 + user: User = await User.get(session, User.id == identity.user_id, load=User.group) + if not user: + http_exceptions.raise_unauthorized("用户不存在") + + # 验证用户状态 + if user.status != UserStatus.ACTIVE: + http_exceptions.raise_forbidden("账户已被禁用") + + # 检查两步验证(从 AuthIdentity.extra_data 中读取 2FA secret) + if identity.extra_data: + import orjson + extra: dict = orjson.loads(identity.extra_data) + two_factor_secret: str | None = extra.get("two_factor") + if two_factor_secret: + if not request.two_fa_code: + l.debug(f"需要两步验证: {request.identifier}") + http_exceptions.raise_precondition_required("需要两步验证") + if Password.verify_totp(two_factor_secret, request.two_fa_code) != PasswordStatus.VALID: + l.debug(f"两步验证失败: {request.identifier}") + http_exceptions.raise_unauthorized("两步验证码错误") + + return user + + +async def _login_oauth( + session: AsyncSession, + request: UnifiedLoginRequest, + provider: AuthProviderType, +) -> User: + """ + OAuth 登录(GitHub / QQ) + + identifier 为 OAuth authorization code,后端换取 access_token 再获取用户信息。 + """ + # 读取 OAuth 配置 + client_id_setting = await Setting.get( + session, + (Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_id"), + ) + client_secret_setting = await Setting.get( + session, + (Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_secret"), + ) + if not client_id_setting or not client_secret_setting: + http_exceptions.raise_bad_request(f"{provider.value} OAuth 未配置") + + client_id = client_id_setting.value or "" + client_secret = client_secret_setting.value or "" + + # 根据 provider 创建对应的 OAuth 客户端 + if provider == AuthProviderType.GITHUB: + from service.oauth import GithubOAuth + oauth_client = GithubOAuth(client_id, client_secret) + token_resp = await oauth_client.get_access_token(code=request.identifier) + user_info_resp = await oauth_client.get_user_info(token_resp) + openid = str(user_info_resp.user_data.id) + nickname = user_info_resp.user_data.name or user_info_resp.user_data.login + avatar_url = user_info_resp.user_data.avatar_url + email = user_info_resp.user_data.email + elif provider == AuthProviderType.QQ: + from service.oauth import QQOAuth + oauth_client = QQOAuth(client_id, client_secret) + token_resp = await oauth_client.get_access_token( + code=request.identifier, + redirect_uri=request.redirect_uri or "", + ) + openid_resp = await oauth_client.get_openid(token_resp.access_token) + user_info_resp = await oauth_client.get_user_info( + token_resp, + app_id=client_id, + openid=openid_resp.openid, + ) + openid = openid_resp.openid + nickname = user_info_resp.user_data.nickname + avatar_url = user_info_resp.user_data.figureurl_qq_2 or user_info_resp.user_data.figureurl_2 + email = None + else: + http_exceptions.raise_bad_request(f"不支持的 OAuth 提供者: {provider.value}") + + # 查找已有 AuthIdentity + identity: AuthIdentity | None = await AuthIdentity.get( + session, + (AuthIdentity.provider == provider) & (AuthIdentity.identifier == openid), + ) + + if identity: + # 已绑定 → 更新 OAuth 信息并返回关联用户 + identity.display_name = nickname + identity.avatar_url = avatar_url + await identity.save(session) + + user: User = await User.get(session, User.id == identity.user_id, load=User.group) + if not user: + http_exceptions.raise_unauthorized("用户不存在") + if user.status != UserStatus.ACTIVE: + http_exceptions.raise_forbidden("账户已被禁用") + return user + + # 未绑定 → 自动注册 + user = await _auto_register_oauth_user( + session, + provider=provider, + openid=openid, + nickname=nickname, + avatar_url=avatar_url, + email=email, + ) + return user + + +async def _auto_register_oauth_user( + session: AsyncSession, + *, + provider: AuthProviderType, + openid: str, + nickname: str | None, + avatar_url: str | None, + email: str | None, +) -> User: + """OAuth 自动注册用户""" + # 获取默认用户组 + default_group_setting = await Setting.get( + session, + (Setting.type == SettingsType.REGISTER) & (Setting.name == "default_group"), + ) + if not default_group_setting or not default_group_setting.value: + l.error("默认用户组未配置") + http_exceptions.raise_internal_error() + + default_group_id = UUID(default_group_setting.value) + + # 创建用户 + new_user = User( + email=email, + nickname=nickname, + avatar=avatar_url or "default", + group_id=default_group_id, + ) + new_user_id = new_user.id + new_user = await new_user.save(session) + + # 创建 AuthIdentity + identity = AuthIdentity( + provider=provider, + identifier=openid, + display_name=nickname, + avatar_url=avatar_url, + is_primary=True, + is_verified=True, + user_id=new_user_id, + ) + await identity.save(session) + + # 创建用户根目录 + default_policy = await Policy.get(session, Policy.name == "本地存储") + if default_policy: + await Object( + name="/", + type=ObjectType.FOLDER, + owner_id=new_user_id, + parent_id=None, + policy_id=default_policy.id, + ).save(session) + + # 重新加载用户(含 group 关系) + user: User = await User.get(session, User.id == new_user_id, load=User.group) + l.info(f"OAuth 自动注册用户: provider={provider.value}, openid={openid}") + return user + + +async def _login_passkey( + session: AsyncSession, + request: UnifiedLoginRequest, +) -> User: + """ + Passkey/WebAuthn 登录 + + identifier 为 credential_id,credential 为 JSON 格式的 authenticator assertion response。 + """ + from webauthn import verify_authentication_response + from webauthn.helpers.structs import AuthenticationCredential + + if not request.credential: + http_exceptions.raise_bad_request("WebAuthn assertion response 不能为空") + + # 查找 AuthIdentity + identity: AuthIdentity | None = await AuthIdentity.get( + session, + (AuthIdentity.provider == AuthProviderType.PASSKEY) + & (AuthIdentity.identifier == request.identifier), + ) + if not identity: + http_exceptions.raise_unauthorized("Passkey 凭证未注册") + + # 加载对应的 UserAuthn 记录 + from sqlmodels.user_authn import UserAuthn + authn: UserAuthn | None = await UserAuthn.get( + session, + UserAuthn.credential_id == request.identifier, + ) + if not authn: + http_exceptions.raise_unauthorized("Passkey 凭证数据不存在") + + # 获取 RP ID + site_url_setting = await Setting.get( + session, + (Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"), + ) + rp_id = site_url_setting.value if site_url_setting else "localhost" + + # 验证 WebAuthn assertion + import orjson + credential = AuthenticationCredential.model_validate(orjson.loads(request.credential)) + + try: + verification = verify_authentication_response( + credential=credential, + expected_rp_id=rp_id, + expected_origin=f"https://{rp_id}", + expected_challenge=b"", # TODO: 从 session/cache 中获取 challenge + credential_public_key=bytes.fromhex(authn.credential_public_key), + credential_current_sign_count=authn.sign_count, + ) + except Exception as e: + l.warning(f"WebAuthn 验证失败: {e}") + http_exceptions.raise_unauthorized("Passkey 验证失败") + + # 更新签名计数 + authn.sign_count = verification.new_sign_count + await authn.save(session) + + # 加载用户 + user: User = await User.get(session, User.id == identity.user_id, load=User.group) + if not user: + http_exceptions.raise_unauthorized("用户不存在") + if user.status != UserStatus.ACTIVE: + http_exceptions.raise_forbidden("账户已被禁用") + + return user + + +async def _login_magic_link( + session: AsyncSession, + request: UnifiedLoginRequest, +) -> User: + """ + Magic Link 登录 + + identifier 为签名 token,由 itsdangerous 生成。 + """ + serializer = URLSafeTimedSerializer(JWT.SECRET_KEY) + + try: + email = serializer.loads(request.identifier, salt="magic-link-salt", max_age=600) + except SignatureExpired: + http_exceptions.raise_unauthorized("Magic Link 已过期") + except BadSignature: + http_exceptions.raise_unauthorized("Magic Link 无效") + + # 查找绑定了该邮箱的 AuthIdentity(email_password 或 magic_link) + identity: AuthIdentity | None = await AuthIdentity.get( + session, + (AuthIdentity.identifier == email) + & ( + (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD) + | (AuthIdentity.provider == AuthProviderType.MAGIC_LINK) + ), + ) + if not identity: + http_exceptions.raise_unauthorized("该邮箱未注册") + + user: User = await User.get(session, User.id == identity.user_id, load=User.group) + if not user: + http_exceptions.raise_unauthorized("用户不存在") + if user.status != UserStatus.ACTIVE: + http_exceptions.raise_forbidden("账户已被禁用") + + # 标记邮箱已验证 + if not identity.is_verified: + identity.is_verified = True + await identity.save(session) + + return user + + +async def _issue_tokens(session: AsyncSession, user: User) -> TokenResponse: + """ + 签发 JWT 双令牌(access + refresh) + + 提取自原 login.py 的签发逻辑,供所有 provider 共用。 + """ # 加载 GroupOptions group_options: GroupOptions | None = await GroupOptions.get( session, - GroupOptions.group_id == current_user.group_id, + GroupOptions.group_id == user.group_id, ) # 构建权限快照 - current_user.group.options = group_options - group_claims = GroupClaims.from_group(current_user.group) + user.group.options = group_options + group_claims = GroupClaims.from_group(user.group) # 创建令牌 - access_token = create_access_token( - sub=current_user.id, + access_token = JWT.create_access_token( + sub=user.id, jti=uuid4(), - status=current_user.status.value, + status=user.status.value, group=group_claims, ) - refresh_token = create_refresh_token( - sub=current_user.id, - jti=uuid4() + refresh_token = JWT.create_refresh_token( + sub=user.id, + jti=uuid4(), ) return TokenResponse( diff --git a/sqlmodels/__init__.py b/sqlmodels/__init__.py index fd5749d..c0b1246 100644 --- a/sqlmodels/__init__.py +++ b/sqlmodels/__init__.py @@ -1,9 +1,16 @@ +from .auth_identity import ( + AuthIdentity, + AuthIdentityResponse, + AuthProviderType, + BindIdentityRequest, +) from .user import ( BatchDeleteRequest, JWTPayload, - LoginRequest, + MagicLinkRequest, + UnifiedLoginRequest, + UnifiedRegisterRequest, RefreshTokenRequest, - RegisterRequest, AccessTokenBase, RefreshTokenBase, TokenResponse, @@ -89,7 +96,7 @@ from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, Policy from .redeem import Redeem, RedeemType from .report import Report, ReportReason from .setting import ( - Setting, SettingsType, SiteConfigResponse, + Setting, SettingsType, SiteConfigResponse, AuthMethodConfig, # 管理员DTO SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse, ) @@ -120,4 +127,4 @@ from .model_base import ( ) # mixin 中的通用分页模型 -from .mixin import ListResponse \ No newline at end of file +from .mixin import ListResponse diff --git a/sqlmodels/auth_identity.py b/sqlmodels/auth_identity.py new file mode 100644 index 0000000..5649f43 --- /dev/null +++ b/sqlmodels/auth_identity.py @@ -0,0 +1,139 @@ +""" +认证身份模块 + +一个用户可拥有多种登录方式(邮箱密码、OAuth、Passkey、Magic Link 等)。 +AuthIdentity 表存储每种认证方式的凭证信息。 +""" +from enum import StrEnum +from typing import TYPE_CHECKING +from uuid import UUID + +from sqlmodel import Field, Relationship, UniqueConstraint + +from .base import SQLModelBase +from .mixin import UUIDTableBaseMixin + +if TYPE_CHECKING: + from .user import User + + +class AuthProviderType(StrEnum): + """认证提供者类型""" + + EMAIL_PASSWORD = "email_password" + """邮箱+密码""" + + PHONE_SMS = "phone_sms" + """手机号+短信验证码(预留)""" + + GITHUB = "github" + """GitHub OAuth""" + + QQ = "qq" + """QQ OAuth""" + + PASSKEY = "passkey" + """Passkey/WebAuthn""" + + MAGIC_LINK = "magic_link" + """邮箱 Magic Link""" + + +# ==================== DTO 模型 ==================== + +class AuthIdentityResponse(SQLModelBase): + """认证身份响应 DTO(列表展示用)""" + + id: UUID + """身份UUID""" + + provider: AuthProviderType + """提供者类型""" + + identifier: str + """标识符(邮箱/手机号/OAuth openid)""" + + display_name: str | None = None + """显示名称(OAuth 昵称等)""" + + avatar_url: str | None = None + """头像 URL""" + + is_primary: bool = False + """是否主要身份""" + + is_verified: bool = False + """是否已验证""" + + +class BindIdentityRequest(SQLModelBase): + """绑定认证身份请求 DTO""" + + provider: AuthProviderType + """提供者类型""" + + identifier: str + """标识符(邮箱/手机号/OAuth code)""" + + credential: str | None = None + """凭证(密码、验证码等)""" + + redirect_uri: str | None = None + """OAuth 回调地址""" + + +# ==================== 数据库模型 ==================== + +class AuthIdentity(SQLModelBase, UUIDTableBaseMixin): + """用户认证身份 — 一个用户可以有多种登录方式""" + + __table_args__ = ( + UniqueConstraint("provider", "identifier", name="uq_auth_identity_provider_identifier"), + ) + + provider: AuthProviderType = Field(index=True) + """提供者类型""" + + identifier: str = Field(max_length=255, index=True) + """标识符(邮箱/手机号/OAuth openid)""" + + credential: str | None = Field(default=None, max_length=1024) + """凭证(Argon2 哈希密码 / null)""" + + display_name: str | None = Field(default=None, max_length=100) + """OAuth 昵称""" + + avatar_url: str | None = Field(default=None, max_length=512) + """OAuth 头像 URL""" + + extra_data: str | None = None + """JSON 附加数据(2FA secret、OAuth refresh_token 等)""" + + is_primary: bool = False + """是否主要身份""" + + is_verified: bool = False + """是否已验证""" + + # 外键 + user_id: UUID = Field( + foreign_key="user.id", + index=True, + ondelete="CASCADE", + ) + """所属用户UUID""" + + # 关系 + user: "User" = Relationship(back_populates="auth_identities") + + def to_response(self) -> AuthIdentityResponse: + """转换为响应 DTO""" + return AuthIdentityResponse( + id=self.id, + provider=self.provider, + identifier=self.identifier, + display_name=self.display_name, + avatar_url=self.avatar_url, + is_primary=self.is_primary, + is_verified=self.is_verified, + ) diff --git a/sqlmodels/migration.py b/sqlmodels/migration.py index 2a49509..7715265 100644 --- a/sqlmodels/migration.py +++ b/sqlmodels/migration.py @@ -28,7 +28,8 @@ default_settings: list[Setting] = [ Setting(name="siteKeywords", value="网盘,网盘", type=SettingsType.BASIC), Setting(name="siteDes", value="DiskNext", type=SettingsType.BASIC), Setting(name="siteTitle", value="云星启智", type=SettingsType.BASIC), - Setting(name="site_notice", value="", type=SettingsType.BASIC), + Setting(name="site_notice_public", value="", type=SettingsType.BASIC), + Setting(name="site_notice_user", value="", type=SettingsType.BASIC), Setting(name="footer_code", value="", type=SettingsType.BASIC), Setting(name="tos_url", value="", type=SettingsType.BASIC), Setting(name="privacy_url", value="", type=SettingsType.BASIC), @@ -58,7 +59,7 @@ default_settings: list[Setting] = [ Setting(name="login_captcha", value="0", type=SettingsType.LOGIN), Setting(name="reg_captcha", value="0", type=SettingsType.LOGIN), Setting(name="reg_email_captcha", value="0", type=SettingsType.LOGIN), - Setting(name="email_active", value="0", type=SettingsType.REGISTER), + Setting(name="require_active", value="0", type=SettingsType.REGISTER), Setting(name="mail_activation_template", value="""验证码
{% if logo_url %}{{ site_name }} {% else %}{{ site_name }} {% endif %}
 

验证您的邮箱

 
感谢您注册{{ site_name }},您的验证码是:
 
{{ verify_code }}
 

该验证码{{ valid_minutes }} 分钟内有效。

为保障您的账户安全,请勿将验证码告诉他人。

 
 

此邮件由系统自动发送,请勿直接回复。

© {{ current_year }} {{ site_name }}. 保留所有权利。

                                           
""", type=SettingsType.MAIL_TEMPLATE), Setting(name="mail_reset_pwd_template", value="""重置密码
{% if logo_url %}{{ site_name }} {% else %}{{ site_name }} {% endif %}
 

重置密码

 
您正在申请重置{{ site_name }} 的登录密码。若确认是您本人操作,请使用下方验证码:
 
{{ verify_code }}
 

该验证码{{ valid_minutes }} 分钟内有效。

如果您没有请求重置密码,请忽略此邮件,您的账户依然安全。

 
 

此邮件由系统自动发送,请勿直接回复。

© {{ current_year }} {{ site_name }}. 保留所有权利。

                                           
""", type=SettingsType.MAIL_TEMPLATE), Setting(name="forget_captcha", value="0", type=SettingsType.LOGIN), @@ -107,6 +108,25 @@ default_settings: list[Setting] = [ Setting(name="pwa_display", value="standalone", type=SettingsType.PWA), Setting(name="pwa_theme_color", value="#000000", type=SettingsType.PWA), Setting(name="pwa_background_color", value="#ffffff", type=SettingsType.PWA), + # ==================== 认证方式配置 ==================== + Setting(name="auth_email_password_enabled", value="1", type=SettingsType.AUTH), + Setting(name="auth_phone_sms_enabled", value="0", type=SettingsType.AUTH), + Setting(name="auth_passkey_enabled", value="0", type=SettingsType.AUTH), + Setting(name="auth_magic_link_enabled", value="0", type=SettingsType.AUTH), + Setting(name="auth_password_required", value="1", type=SettingsType.AUTH), + Setting(name="auth_phone_binding_required", value="0", type=SettingsType.AUTH), + Setting(name="auth_email_binding_required", value="1", type=SettingsType.AUTH), + # ==================== OAuth 配置 ==================== + Setting(name="github_enabled", value="0", type=SettingsType.OAUTH), + Setting(name="github_client_id", value="", type=SettingsType.OAUTH), + Setting(name="github_client_secret", value="", type=SettingsType.OAUTH), + Setting(name="qq_enabled", value="0", type=SettingsType.OAUTH), + Setting(name="qq_client_id", value="", type=SettingsType.OAUTH), + Setting(name="qq_client_secret", value="", type=SettingsType.OAUTH), + # ==================== 短信服务配置(预留) ==================== + Setting(name="sms_provider", value="", type=SettingsType.MOBILE), + Setting(name="sms_access_key", value="", type=SettingsType.MOBILE), + Setting(name="sms_secret_key", value="", type=SettingsType.MOBILE), ] async def init_default_settings() -> None: @@ -219,6 +239,7 @@ async def init_default_group() -> None: # 游客组不关联存储策略(无法上传) async def init_default_user() -> None: + from .auth_identity import AuthIdentity, AuthProviderType from .user import User from .group import Group from .object import Object, ObjectType @@ -258,11 +279,20 @@ async def init_default_user() -> None: email="admin@disknext.local", nickname="admin", group_id=admin_group.id, - password=hashed_admin_password, ) admin_user_id = admin_user.id # 在 save 前保存 UUID await admin_user.save(session) + # 创建 AuthIdentity(邮箱密码身份) + await AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="admin@disknext.local", + credential=hashed_admin_password, + is_primary=True, + is_verified=True, + user_id=admin_user_id, + ).save(session) + # 记录默认管理员 ID 到 Setting await Setting( name="default_admin_id", @@ -341,4 +371,4 @@ async def init_default_theme_presets() -> None: neutral=NeutralColor.ZINC, ) await default_preset.save(session) - log.info('已创建默认主题预设') \ No newline at end of file + log.info('已创建默认主题预设') diff --git a/sqlmodels/setting.py b/sqlmodels/setting.py index a0d0a1c..b4375c9 100644 --- a/sqlmodels/setting.py +++ b/sqlmodels/setting.py @@ -2,6 +2,7 @@ from enum import StrEnum from sqlmodel import UniqueConstraint +from .auth_identity import AuthProviderType from .base import SQLModelBase from .mixin import TableBaseMixin from .user import UserResponse @@ -12,6 +13,19 @@ class CaptchaType(StrEnum): GCAPTCHA = "gcaptcha" CLOUD_FLARE_TURNSTILE = "cloudflare turnstile" + +# ==================== Auth 配置 DTO ==================== + +class AuthMethodConfig(SQLModelBase): + """认证方式配置 DTO""" + + provider: AuthProviderType + """提供者类型""" + + is_enabled: bool + """是否启用""" + + # ==================== DTO 模型 ==================== class SiteConfigResponse(SQLModelBase): @@ -50,6 +64,27 @@ class SiteConfigResponse(SQLModelBase): captcha_key: str | None = None """验证码 public key(DEFAULT 类型时为 None)""" + auth_methods: list[AuthMethodConfig] = [] + """可用的登录方式列表""" + + password_required: bool = True + """注册时是否必须设置密码""" + + phone_binding_required: bool = False + """是否强制绑定手机号""" + + email_binding_required: bool = True + """是否强制绑定邮箱""" + + footer_code: str | None = None + """自定义页脚代码""" + + tos_url: str | None = None + """服务条款 URL""" + + privacy_url: str | None = None + """隐私政策 URL""" + # ==================== 管理员设置 DTO ==================== @@ -133,4 +168,4 @@ class Setting(SettingItem, TableBaseMixin): __table_args__ = (UniqueConstraint("type", "name", name="uq_setting_type_name"),) type: SettingsType - """设置类型/分组(覆盖基类的 str 类型为枚举类型)""" \ No newline at end of file + """设置类型/分组(覆盖基类的 str 类型为枚举类型)""" diff --git a/sqlmodels/user.py b/sqlmodels/user.py index 437f9b1..9eacde9 100644 --- a/sqlmodels/user.py +++ b/sqlmodels/user.py @@ -9,6 +9,7 @@ from sqlmodel import Field, Relationship from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.main import RelationshipInfo +from .auth_identity import AuthProviderType from .base import SQLModelBase from .color import ChromaticColor, NeutralColor, ThemeColorsBase from .model_base import ResponseBase @@ -17,6 +18,7 @@ from .mixin import UUIDTableBaseMixin, TableViewRequest, ListResponse T = TypeVar("T", bound="User") if TYPE_CHECKING: + from .auth_identity import AuthIdentity from .group import Group from .download import Download from .object import Object @@ -30,7 +32,7 @@ if TYPE_CHECKING: class AvatarType(StrEnum): """头像类型枚举""" - + DEFAULT = "default" GRAVATAR = "gravatar" FILE = "file" @@ -69,8 +71,8 @@ class UserFilterParams(SQLModelBase): class UserBase(SQLModelBase): """用户基础字段,供数据库模型和 DTO 共享""" - email: str - """用户邮箱""" + email: str | None = None + """用户邮箱(社交登录用户可能没有邮箱)""" status: UserStatus = UserStatus.ACTIVE """用户状态""" @@ -81,30 +83,42 @@ class UserBase(SQLModelBase): # ==================== DTO 模型 ==================== -class LoginRequest(SQLModelBase): - """登录请求 DTO""" +class UnifiedLoginRequest(SQLModelBase): + """统一登录请求 DTO""" - email: str - """用户邮箱""" + provider: AuthProviderType + """登录方式""" - password: str - """用户密码""" + identifier: str + """标识符(邮箱 / OAuth code / Magic Link token)""" - captcha: str | None = None - """验证码""" + credential: str | None = None + """凭证(密码,provider=email_password 时必填)""" two_fa_code: str | None = Field(default=None, min_length=6, max_length=6) """两步验证代码""" + redirect_uri: str | None = None + """OAuth 回调地址""" -class RegisterRequest(SQLModelBase): - """注册请求 DTO""" + captcha: str | None = None + """验证码""" - email: str - """用户邮箱,唯一""" - password: str - """用户密码""" +class UnifiedRegisterRequest(SQLModelBase): + """统一注册请求 DTO""" + + provider: AuthProviderType + """注册方式(email_password / phone_sms)""" + + identifier: str + """标识符(邮箱 / 手机号)""" + + credential: str | None = None + """凭证(密码 / 短信验证码)""" + + nickname: str | None = Field(default=None, max_length=50) + """昵称""" captcha: str | None = None """验证码""" @@ -190,7 +204,7 @@ class UserResponse(ResponseBase): id: UUID """用户UUID""" - email: str + email: str | None = None """用户邮箱""" nickname: str | None = None @@ -216,10 +230,10 @@ class UserStorageResponse(SQLModelBase): used: int """已用存储空间(字节)""" - + free: int """剩余存储空间(字节)""" - + total: int """总存储空间(字节)""" @@ -248,9 +262,6 @@ class UserPublic(UserBase): group_name: str | None = None """用户组名称""" - two_factor: str | None = None - """两步验证密钥(32位字符串,null 表示未启用)""" - created_at: datetime | None = None """创建时间""" @@ -264,21 +275,24 @@ class UserSettingResponse(SQLModelBase): id: UUID """用户UUID""" - email: str + email: str | None = None """用户邮箱""" + phone: str | None = None + """手机号""" + nickname: str | None = None """昵称""" - + created_at: datetime """用户注册时间""" group_name: str """用户所属用户组名称""" - + language: str """语言偏好""" - + timezone: int """时区""" @@ -341,16 +355,26 @@ class UserTwoFactorResponse(SQLModelBase): """两步验证密钥""" +class MagicLinkRequest(SQLModelBase): + """Magic Link 请求 DTO""" + + email: str + """接收 Magic Link 的邮箱""" + + captcha: str | None = None + """验证码""" + + # ==================== 管理员用户管理 DTO ==================== class UserAdminCreateRequest(SQLModelBase): """管理员创建用户请求 DTO""" - email: str = Field(max_length=50) + email: str | None = Field(default=None, max_length=50) """用户邮箱""" - password: str - """用户密码(明文,由服务端加密)""" + password: str | None = None + """用户密码(明文,由服务端加密;为空则不创建邮箱密码身份)""" nickname: str | None = Field(default=None, max_length=50) """昵称""" @@ -364,15 +388,15 @@ class UserAdminCreateRequest(SQLModelBase): class UserAdminUpdateRequest(SQLModelBase): """管理员更新用户请求 DTO""" - - email: str = Field(max_length=50) + + email: str | None = Field(default=None, max_length=50) """邮箱""" nickname: str | None = Field(default=None, max_length=50) """昵称""" - password: str | None = None - """新密码(为空则不修改)""" + phone: str | None = None + """手机号""" group_id: UUID | None = None """用户组UUID""" @@ -389,9 +413,6 @@ class UserAdminUpdateRequest(SQLModelBase): group_expires: datetime | None = None """用户组过期时间""" - two_factor: str | None = None - """两步验证密钥(32位字符串,传 null 可清除,不传则不修改)""" - class UserCalibrateResponse(SQLModelBase): """用户存储校准响应 DTO""" @@ -415,9 +436,6 @@ class UserCalibrateResponse(SQLModelBase): class UserAdminDetailResponse(UserPublic): """管理员用户详情响应 DTO""" - two_factor_enabled: bool = False - """是否启用两步验证""" - file_count: int = 0 """文件数量""" @@ -443,14 +461,14 @@ UserSettingResponse.model_rebuild() class User(UserBase, UUIDTableBaseMixin): """用户模型""" - email: str = Field(max_length=50, unique=True, index=True) - """用户邮箱,唯一""" + email: str | None = Field(default=None, max_length=50, unique=True, index=True) + """用户邮箱(社交登录用户可能没有邮箱)""" nickname: str | None = Field(default=None, max_length=50) """用于公开展示的名字,可使用真实姓名或昵称""" - password: str = Field(max_length=255) - """用户密码(加密后)""" + phone: str | None = Field(default=None, max_length=20, unique=True, index=True) + """手机号(预留)""" status: UserStatus = UserStatus.ACTIVE """用户状态""" @@ -458,9 +476,6 @@ class User(UserBase, UUIDTableBaseMixin): storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0) """已用存储空间(字节)""" - two_factor: str | None = Field(default=None, min_length=32, max_length=32) - """两步验证密钥""" - avatar: str = Field(default="default", max_length=255) """头像地址""" @@ -533,6 +548,12 @@ class User(UserBase, UUIDTableBaseMixin): } ) + auth_identities: list["AuthIdentity"] = Relationship( + back_populates="user", + sa_relationship_kwargs={"cascade": "all, delete-orphan"} + ) + """用户的认证身份列表""" + downloads: list["Download"] = Relationship( back_populates="user", sa_relationship_kwargs={"cascade": "all, delete-orphan"} @@ -634,4 +655,3 @@ class User(UserBase, UUIDTableBaseMixin): filter=filter, table_view=table_view, ) - \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 574a4ec..e61c162 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..') from main import app from sqlmodels.database import get_session +from sqlmodels.auth_identity import AuthIdentity, AuthProviderType from sqlmodels.group import Group, GroupClaims, GroupOptions from sqlmodels.migration import migration from sqlmodels.object import Object, ObjectType @@ -192,7 +193,6 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]: user = User( email="testuser@test.local", nickname="测试用户", - password=Password.hash(password), status=UserStatus.ACTIVE, storage=0, score=100, @@ -200,6 +200,17 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]: ) user = await user.save(db_session) + # 创建邮箱密码认证身份 + identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="testuser@test.local", + credential=Password.hash(password), + is_primary=True, + is_verified=True, + user_id=user.id, + ) + await identity.save(db_session) + # 创建用户根目录 root_folder = Object( name="/", @@ -279,7 +290,6 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]: admin = User( email="admin@disknext.local", nickname="管理员", - password=Password.hash(password), status=UserStatus.ACTIVE, storage=0, score=9999, @@ -287,6 +297,17 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]: ) admin = await admin.save(db_session) + # 创建管理员邮箱密码认证身份 + admin_identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="admin@disknext.local", + credential=Password.hash(password), + is_primary=True, + is_verified=True, + user_id=admin.id, + ) + await admin_identity.save(db_session) + # 创建管理员根目录 root_folder = Object( name="/", diff --git a/tests/fixtures/users.py b/tests/fixtures/users.py index e4dfa20..0fefd57 100644 --- a/tests/fixtures/users.py +++ b/tests/fixtures/users.py @@ -2,12 +2,14 @@ 用户测试数据工厂 提供创建测试用户的便捷方法。 +用户密码凭证通过 AuthIdentity 管理,不再存储在 User 表中。 """ from uuid import UUID from sqlmodel.ext.asyncio.session import AsyncSession -from sqlmodels.user import User +from sqlmodels.auth_identity import AuthIdentity, AuthProviderType +from sqlmodels.user import User, UserStatus from utils.password.pwd import Password @@ -20,7 +22,7 @@ class UserFactory: group_id: UUID, email: str | None = None, password: str | None = None, - **kwargs + **kwargs, ) -> User: """ 创建普通用户 @@ -29,7 +31,7 @@ class UserFactory: session: 数据库会话 group_id: 用户组UUID email: 用户邮箱(默认: test_user_{随机}@test.local) - password: 明文密码(默认: password123) + password: 明文密码(默认: password123),若提供则同时创建 AuthIdentity **kwargs: 其他用户字段 返回: @@ -46,12 +48,10 @@ class UserFactory: user = User( email=email, nickname=kwargs.get("nickname", email), - password=Password.hash(password), - status=kwargs.get("status", True), + status=kwargs.get("status", UserStatus.ACTIVE), storage=kwargs.get("storage", 0), score=kwargs.get("score", 100), group_id=group_id, - two_factor=kwargs.get("two_factor"), avatar=kwargs.get("avatar", "default"), group_expires=kwargs.get("group_expires"), theme=kwargs.get("theme", "system"), @@ -61,6 +61,18 @@ class UserFactory: ) user = await user.save(session) + + # 创建邮箱密码认证身份 + identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier=email, + credential=Password.hash(password), + is_primary=True, + is_verified=True, + user_id=user.id, + ) + await identity.save(session) + return user @staticmethod @@ -68,7 +80,7 @@ class UserFactory: session: AsyncSession, admin_group_id: UUID, email: str | None = None, - password: str | None = None + password: str | None = None, ) -> User: """ 创建管理员用户 @@ -93,8 +105,7 @@ class UserFactory: admin = User( email=email, nickname=f"管理员 {email}", - password=Password.hash(password), - status=True, + status=UserStatus.ACTIVE, storage=0, score=9999, group_id=admin_group_id, @@ -102,13 +113,25 @@ class UserFactory: ) admin = await admin.save(session) + + # 创建邮箱密码认证身份 + identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier=email, + credential=Password.hash(password), + is_primary=True, + is_verified=True, + user_id=admin.id, + ) + await identity.save(session) + return admin @staticmethod async def create_banned( session: AsyncSession, group_id: UUID, - email: str | None = None + email: str | None = None, ) -> User: """ 创建被封禁用户 @@ -129,8 +152,7 @@ class UserFactory: banned_user = User( email=email, nickname=f"封禁用户 {email}", - password=Password.hash("banned_password"), - status=False, # 封禁状态 + status=UserStatus.ADMIN_BANNED, storage=0, score=0, group_id=group_id, @@ -138,6 +160,18 @@ class UserFactory: ) banned_user = await banned_user.save(session) + + # 创建邮箱密码认证身份 + identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier=email, + credential=Password.hash("banned_password"), + is_primary=True, + is_verified=True, + user_id=banned_user.id, + ) + await identity.save(session) + return banned_user @staticmethod @@ -145,7 +179,7 @@ class UserFactory: session: AsyncSession, group_id: UUID, storage_bytes: int, - email: str | None = None + email: str | None = None, ) -> User: """ 创建已使用指定存储空间的用户 @@ -167,8 +201,7 @@ class UserFactory: user = User( email=email, nickname=email, - password=Password.hash("password123"), - status=True, + status=UserStatus.ACTIVE, storage=storage_bytes, score=100, group_id=group_id, @@ -176,4 +209,16 @@ class UserFactory: ) user = await user.save(session) + + # 创建邮箱密码认证身份 + identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier=email, + credential=Password.hash("password123"), + is_primary=True, + is_verified=True, + user_id=user.id, + ) + await identity.save(session) + return user diff --git a/tests/integration/api/test_site.py b/tests/integration/api/test_site.py index ccb6b6f..62b10ac 100644 --- a/tests/integration/api/test_site.py +++ b/tests/integration/api/test_site.py @@ -83,6 +83,24 @@ async def test_site_config_captcha_settings(async_client: AsyncClient): assert "forgetCaptcha" in config +@pytest.mark.asyncio +async def test_site_config_auth_methods(async_client: AsyncClient): + """测试配置包含认证方式列表""" + response = await async_client.get("/api/site/config") + assert response.status_code == 200 + + data = response.json() + config = data["data"] + assert "authMethods" in config + assert isinstance(config["authMethods"], list) + assert len(config["authMethods"]) > 0 + + # 每个认证方式应包含 provider 和 isEnabled + for method in config["authMethods"]: + assert "provider" in method + assert "isEnabled" in method + + @pytest.mark.asyncio async def test_site_captcha_endpoint_exists(async_client: AsyncClient): """测试验证码端点存在(即使未实现也应返回有效响应)""" diff --git a/tests/integration/api/test_user.py b/tests/integration/api/test_user.py index c851a8d..7eb234f 100644 --- a/tests/integration/api/test_user.py +++ b/tests/integration/api/test_user.py @@ -15,9 +15,10 @@ async def test_user_login_success( """测试成功登录""" response = await async_client.post( "/api/user/session", - data={ - "username": test_user_info["email"], - "password": test_user_info["password"], + json={ + "provider": "email_password", + "identifier": test_user_info["email"], + "credential": test_user_info["password"], } ) assert response.status_code == 200 @@ -37,9 +38,10 @@ async def test_user_login_wrong_password( """测试密码错误返回 401""" response = await async_client.post( "/api/user/session", - data={ - "username": test_user_info["email"], - "password": "wrongpassword", + json={ + "provider": "email_password", + "identifier": test_user_info["email"], + "credential": "wrongpassword", } ) assert response.status_code == 401 @@ -50,9 +52,10 @@ async def test_user_login_nonexistent_user(async_client: AsyncClient): """测试不存在的用户返回 401""" response = await async_client.post( "/api/user/session", - data={ - "username": "nonexistent@test.local", - "password": "anypassword", + json={ + "provider": "email_password", + "identifier": "nonexistent@test.local", + "credential": "anypassword", } ) assert response.status_code == 401 @@ -66,9 +69,10 @@ async def test_user_login_user_banned( """测试封禁用户返回 403""" response = await async_client.post( "/api/user/session", - data={ - "username": banned_user_info["email"], - "password": banned_user_info["password"], + json={ + "provider": "email_password", + "identifier": banned_user_info["email"], + "credential": banned_user_info["password"], } ) assert response.status_code == 403 @@ -82,8 +86,9 @@ async def test_user_register_success(async_client: AsyncClient): response = await async_client.post( "/api/user/", json={ - "email": "newuser@test.local", - "password": "newpass123", + "provider": "email_password", + "identifier": "newuser@test.local", + "credential": "newpass123", } ) assert response.status_code == 200 @@ -104,8 +109,9 @@ async def test_user_register_duplicate_email( response = await async_client.post( "/api/user/", json={ - "email": test_user_info["email"], - "password": "anypassword", + "provider": "email_password", + "identifier": test_user_info["email"], + "credential": "anypassword", } ) assert response.status_code == 400 diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 360b1e0..bd907f9 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -23,6 +23,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../. from main import app from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User +from sqlmodels.auth_identity import AuthIdentity, AuthProviderType from sqlmodels.user import UserStatus from utils import Password from utils.JWT import create_access_token @@ -98,6 +99,15 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession: Setting(type=SettingsType.CAPTCHA, name="captcha_CloudflareKey", value=""), Setting(type=SettingsType.REGISTER, name="register_enabled", value="1"), Setting(type=SettingsType.AUTH, name="secret_key", value="test_secret_key_for_jwt_token_generation"), + Setting(type=SettingsType.AUTH, name="auth_email_password_enabled", value="1"), + Setting(type=SettingsType.AUTH, name="auth_phone_sms_enabled", value="0"), + Setting(type=SettingsType.AUTH, name="auth_passkey_enabled", value="0"), + Setting(type=SettingsType.AUTH, name="auth_magic_link_enabled", value="0"), + Setting(type=SettingsType.AUTH, name="auth_password_required", value="1"), + Setting(type=SettingsType.AUTH, name="auth_phone_binding_required", value="0"), + Setting(type=SettingsType.AUTH, name="auth_email_binding_required", value="1"), + Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"), + Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"), ] for setting in settings: test_session.add(setting) @@ -183,7 +193,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession: test_user = User( id=uuid4(), email="testuser@test.local", - password=Password.hash("testpass123"), nickname="测试用户", status=UserStatus.ACTIVE, storage=0, @@ -196,7 +205,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession: admin_user = User( id=uuid4(), email="admin@disknext.local", - password=Password.hash("adminpass123"), nickname="管理员", status=UserStatus.ACTIVE, storage=0, @@ -209,7 +217,6 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession: banned_user = User( id=uuid4(), email="banneduser@test.local", - password=Password.hash("banned123"), nickname="封禁用户", status=UserStatus.ADMIN_BANNED, storage=0, @@ -226,7 +233,40 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession: await test_session.refresh(admin_user) await test_session.refresh(banned_user) - # 7. 创建用户根目录 + # 7. 创建认证身份 + test_user_identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="testuser@test.local", + credential=Password.hash("testpass123"), + is_primary=True, + is_verified=True, + user_id=test_user.id, + ) + test_session.add(test_user_identity) + + admin_user_identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="admin@disknext.local", + credential=Password.hash("adminpass123"), + is_primary=True, + is_verified=True, + user_id=admin_user.id, + ) + test_session.add(admin_user_identity) + + banned_user_identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="banneduser@test.local", + credential=Password.hash("banned123"), + is_primary=True, + is_verified=True, + user_id=banned_user.id, + ) + test_session.add(banned_user_identity) + + await test_session.commit() + + # 8. 创建用户根目录 test_user_root = Object( id=uuid4(), name="/", @@ -251,7 +291,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession: await test_session.commit() - # 8. 设置JWT密钥(从数据库加载) + # 9. 设置JWT密钥(从数据库加载) JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation" # 刷新 group options diff --git a/tests/test_db_user.py b/tests/test_db_user.py index 0b6a728..aaefa49 100644 --- a/tests/test_db_user.py +++ b/tests/test_db_user.py @@ -18,7 +18,6 @@ async def test_user_curd(): test_user = User( email='test_user@test.local', - password='test_password', group_id=created_group.id ) @@ -28,7 +27,6 @@ async def test_user_curd(): # 验证用户是否存在 assert created_user.id is not None assert created_user.email == 'test_user@test.local' - assert created_user.password == 'test_password' assert created_user.group_id == created_group.id # 测试查 Read @@ -36,18 +34,16 @@ async def test_user_curd(): assert fetched_user is not None assert fetched_user.email == 'test_user@test.local' - assert fetched_user.password == 'test_password' assert fetched_user.group_id == created_group.id # 测试改 Update updated_user = await fetched_user.update( session, - {"email": "updated_user@test.local", "password": "updated_password"} + {"email": "updated_user@test.local"} ) assert updated_user is not None assert updated_user.email == 'updated_user@test.local' - assert updated_user.password == 'updated_password' # 测试删除 Delete await updated_user.delete(session) diff --git a/tests/unit/models/test_object.py b/tests/unit/models/test_object.py index 928b95f..16a9bca 100644 --- a/tests/unit/models/test_object.py +++ b/tests/unit/models/test_object.py @@ -19,7 +19,7 @@ async def test_object_create_folder(db_session: AsyncSession): group = Group(name="测试组") group = await group.save(db_session) - user = User(email="testuser", password="password", group_id=group.id) + user = User(email="testuser", group_id=group.id) user = await user.save(db_session) policy = Policy( @@ -53,7 +53,7 @@ async def test_object_create_file(db_session: AsyncSession): group = Group(name="测试组") group = await group.save(db_session) - user = User(email="testuser", password="password", group_id=group.id) + user = User(email="testuser", group_id=group.id) user = await user.save(db_session) policy = Policy( @@ -98,7 +98,7 @@ async def test_object_is_file_property(db_session: AsyncSession): group = Group(name="测试组") group = await group.save(db_session) - user = User(email="testuser", password="password", group_id=group.id) + user = User(email="testuser", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -125,7 +125,7 @@ async def test_object_is_folder_property(db_session: AsyncSession): group = Group(name="测试组") group = await group.save(db_session) - user = User(email="testuser", password="password", group_id=group.id) + user = User(email="testuser", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -151,7 +151,7 @@ async def test_object_get_root(db_session: AsyncSession): group = Group(name="测试组") group = await group.save(db_session) - user = User(email="rootuser", password="password", group_id=group.id) + user = User(email="rootuser", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -183,7 +183,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession): group = Group(name="测试组") group = await group.save(db_session) - user = User(email="pathuser", password="password", group_id=group.id) + user = User(email="pathuser", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -214,7 +214,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession): group = Group(name="测试组") group = await group.save(db_session) - user = User(email="nesteduser", password="password", group_id=group.id) + user = User(email="nesteduser", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -277,7 +277,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession): group = Group(name="测试组") group = await group.save(db_session) - user = User(email="notfounduser", password="password", group_id=group.id) + user = User(email="notfounduser", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -311,7 +311,7 @@ async def test_object_get_children(db_session: AsyncSession): group = Group(name="测试组") group = await group.save(db_session) - user = User(email="childrenuser", password="password", group_id=group.id) + user = User(email="childrenuser", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -363,7 +363,7 @@ async def test_object_parent_child_relationship(db_session: AsyncSession): group = Group(name="测试组") group = await group.save(db_session) - user = User(email="reluser", password="password", group_id=group.id) + user = User(email="reluser", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -408,7 +408,7 @@ async def test_object_unique_constraint(db_session: AsyncSession): group = Group(name="测试组") group = await group.save(db_session) - user = User(email="uniqueuser", password="password", group_id=group.id) + user = User(email="uniqueuser", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -456,7 +456,7 @@ async def test_object_get_full_path(db_session: AsyncSession): group = Group(name="测试组") group = await group.save(db_session) - user = User(email="pathuser", password="password", group_id=group.id) + user = User(email="pathuser", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") diff --git a/tests/unit/models/test_user.py b/tests/unit/models/test_user.py index c490fc0..01b419e 100644 --- a/tests/unit/models/test_user.py +++ b/tests/unit/models/test_user.py @@ -20,7 +20,6 @@ async def test_user_create(db_session: AsyncSession): user = User( email="testuser@test.local", nickname="测试用户", - password="hashed_password", group_id=group.id ) user = await user.save(db_session) @@ -43,7 +42,6 @@ async def test_user_unique_email(db_session: AsyncSession): # 创建第一个用户 user1 = User( email="duplicate@test.local", - password="password1", group_id=group.id ) await user1.save(db_session) @@ -51,7 +49,6 @@ async def test_user_unique_email(db_session: AsyncSession): # 尝试创建同名用户 user2 = User( email="duplicate@test.local", - password="password2", group_id=group.id ) @@ -70,7 +67,6 @@ async def test_user_to_public(db_session: AsyncSession): user = User( email="publicuser@test.local", nickname="公开用户", - password="secret_password", storage=1024, avatar="avatar.jpg", group_id=group.id @@ -88,8 +84,6 @@ async def test_user_to_public(db_session: AsyncSession): # 这是已知的设计问题,需要在 UserPublic 中添加别名或重命名字段 assert public_user.nick is None # 实际行为 assert public_user.storage == 1024 - # 密码不应该在公开数据中 - assert not hasattr(public_user, 'password') @pytest.mark.asyncio @@ -102,7 +96,6 @@ async def test_user_group_relationship(db_session: AsyncSession): # 创建用户 user = User( email="vipuser@test.local", - password="password", group_id=group.id ) user = await user.save(db_session) @@ -126,7 +119,6 @@ async def test_user_status_default(db_session: AsyncSession): user = User( email="defaultuser@test.local", - password="password", group_id=group.id ) user = await user.save(db_session) @@ -142,7 +134,6 @@ async def test_user_storage_default(db_session: AsyncSession): user = User( email="storageuser@test.local", - password="password", group_id=group.id ) user = await user.save(db_session) @@ -159,7 +150,6 @@ async def test_user_theme_enum(db_session: AsyncSession): # 测试默认值 user1 = User( email="user1@test.local", - password="password", group_id=group.id ) user1 = await user1.save(db_session) @@ -168,7 +158,6 @@ async def test_user_theme_enum(db_session: AsyncSession): # 测试设置为 LIGHT user2 = User( email="user2@test.local", - password="password", theme=ThemeType.LIGHT, group_id=group.id ) @@ -178,9 +167,40 @@ async def test_user_theme_enum(db_session: AsyncSession): # 测试设置为 DARK user3 = User( email="user3@test.local", - password="password", theme=ThemeType.DARK, group_id=group.id ) user3 = await user3.save(db_session) assert user3.theme == ThemeType.DARK + + +@pytest.mark.asyncio +async def test_user_email_optional(db_session: AsyncSession): + """测试 email 可以为空(支持社交登录用户)""" + group = Group(name="默认组") + group = await group.save(db_session) + + user = User( + nickname="社交用户", + group_id=group.id + ) + user = await user.save(db_session) + + assert user.id is not None + assert user.email is None + + +@pytest.mark.asyncio +async def test_user_phone_field(db_session: AsyncSession): + """测试 phone 字段""" + group = Group(name="默认组") + group = await group.save(db_session) + + user = User( + email="phoneuser@test.local", + phone="13800138000", + group_id=group.id + ) + user = await user.save(db_session) + + assert user.phone == "13800138000" diff --git a/tests/unit/service/test_login.py b/tests/unit/service/test_login.py index 9270f77..821569e 100644 --- a/tests/unit/service/test_login.py +++ b/tests/unit/service/test_login.py @@ -1,78 +1,154 @@ """ Login 服务的单元测试 + +测试 unified_login() 各 provider 路径。 """ import pytest +from fastapi import HTTPException from sqlmodel.ext.asyncio.session import AsyncSession -from sqlmodels.user import User, LoginRequest, TokenResponse, UserStatus -from sqlmodels.group import Group -from service.user.login import login +from sqlmodels.auth_identity import AuthIdentity, AuthProviderType +from sqlmodels.setting import Setting, SettingsType +from sqlmodels.user import User, UnifiedLoginRequest, TokenResponse, UserStatus +from sqlmodels.group import Group, GroupOptions +from service.user.login import unified_login from utils.password.pwd import Password @pytest.fixture -async def setup_user(db_session: AsyncSession): - """创建测试用户""" +async def setup_auth_settings(db_session: AsyncSession): + """创建认证相关的 Setting 配置""" + settings = [ + Setting(type=SettingsType.AUTH, name="auth_email_password_enabled", value="1"), + Setting(type=SettingsType.AUTH, name="auth_phone_sms_enabled", value="0"), + Setting(type=SettingsType.AUTH, name="auth_passkey_enabled", value="0"), + Setting(type=SettingsType.AUTH, name="auth_magic_link_enabled", value="0"), + Setting(type=SettingsType.OAUTH, name="github_enabled", value="0"), + Setting(type=SettingsType.OAUTH, name="qq_enabled", value="0"), + ] + for s in settings: + await s.save(db_session) + + +@pytest.fixture +async def setup_user(db_session: AsyncSession, setup_auth_settings): + """创建测试用户和邮箱密码认证身份""" # 创建用户组 group = Group(name="测试组") group = await group.save(db_session) + # 创建用户组选项 + group_options = GroupOptions( + group_id=group.id, + share_download=True, + share_free=False, + relocate=False, + ) + await group_options.save(db_session) + # 创建正常用户 plain_password = "secure_password_123" user = User( email="loginuser@test.local", - password=Password.hash(plain_password), status=UserStatus.ACTIVE, - group_id=group.id + group_id=group.id, ) user = await user.save(db_session) + # 创建邮箱密码认证身份 + identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="loginuser@test.local", + credential=Password.hash(plain_password), + is_primary=True, + is_verified=True, + user_id=user.id, + ) + await identity.save(db_session) + return { "user": user, "password": plain_password, - "group_id": group.id + "group_id": group.id, } @pytest.fixture -async def setup_banned_user(db_session: AsyncSession): +async def setup_banned_user(db_session: AsyncSession, setup_auth_settings): """创建被封禁的用户""" group = Group(name="测试组2") group = await group.save(db_session) + group_options = GroupOptions( + group_id=group.id, + share_download=True, + share_free=False, + relocate=False, + ) + await group_options.save(db_session) + user = User( email="banneduser@test.local", - password=Password.hash("password"), - status=UserStatus.ADMIN_BANNED, # 封禁状态 - group_id=group.id + status=UserStatus.ADMIN_BANNED, + group_id=group.id, ) user = await user.save(db_session) + identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="banneduser@test.local", + credential=Password.hash("password"), + is_primary=True, + is_verified=True, + user_id=user.id, + ) + await identity.save(db_session) + return user @pytest.fixture -async def setup_2fa_user(db_session: AsyncSession): +async def setup_2fa_user(db_session: AsyncSession, setup_auth_settings): """创建启用了两步验证的用户""" import pyotp group = Group(name="测试组3") group = await group.save(db_session) + group_options = GroupOptions( + group_id=group.id, + share_download=True, + share_free=False, + relocate=False, + ) + await group_options.save(db_session) + secret = pyotp.random_base32() user = User( email="2fauser@test.local", - password=Password.hash("password"), status=UserStatus.ACTIVE, - two_factor=secret, - group_id=group.id + group_id=group.id, ) user = await user.save(db_session) + # 创建带 2FA secret 的邮箱密码认证身份 + import orjson + extra_data = orjson.dumps({"two_factor": secret}).decode('utf-8') + identity = AuthIdentity( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="2fauser@test.local", + credential=Password.hash("password"), + extra_data=extra_data, + is_primary=True, + is_verified=True, + user_id=user.id, + ) + await identity.save(db_session) + return { "user": user, "secret": secret, - "password": "password" + "password": "password", } @@ -81,12 +157,13 @@ async def test_login_success(db_session: AsyncSession, setup_user): """测试正常登录""" user_data = setup_user - login_request = LoginRequest( - email="loginuser@test.local", - password=user_data["password"] + request = UnifiedLoginRequest( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="loginuser@test.local", + credential=user_data["password"], ) - result = await login(db_session, login_request) + result = await unified_login(db_session, request) assert isinstance(result, TokenResponse) assert result.access_token is not None @@ -96,42 +173,48 @@ async def test_login_success(db_session: AsyncSession, setup_user): @pytest.mark.asyncio -async def test_login_user_not_found(db_session: AsyncSession): +async def test_login_user_not_found(db_session: AsyncSession, setup_user): """测试用户不存在""" - login_request = LoginRequest( - email="nonexistent@test.local", - password="any_password" + request = UnifiedLoginRequest( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="nonexistent@test.local", + credential="any_password", ) - result = await login(db_session, login_request) + with pytest.raises(HTTPException) as exc_info: + await unified_login(db_session, request) - assert result is None + assert exc_info.value.status_code == 401 @pytest.mark.asyncio async def test_login_wrong_password(db_session: AsyncSession, setup_user): """测试密码错误""" - login_request = LoginRequest( - email="loginuser@test.local", - password="wrong_password" + request = UnifiedLoginRequest( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="loginuser@test.local", + credential="wrong_password", ) - result = await login(db_session, login_request) + with pytest.raises(HTTPException) as exc_info: + await unified_login(db_session, request) - assert result is None + assert exc_info.value.status_code == 401 @pytest.mark.asyncio async def test_login_user_banned(db_session: AsyncSession, setup_banned_user): """测试用户被封禁""" - login_request = LoginRequest( - email="banneduser@test.local", - password="password" + request = UnifiedLoginRequest( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="banneduser@test.local", + credential="password", ) - result = await login(db_session, login_request) + with pytest.raises(HTTPException) as exc_info: + await unified_login(db_session, request) - assert result is False + assert exc_info.value.status_code == 403 @pytest.mark.asyncio @@ -139,15 +222,17 @@ async def test_login_2fa_required(db_session: AsyncSession, setup_2fa_user): """测试需要 2FA""" user_data = setup_2fa_user - login_request = LoginRequest( - email="2fauser@test.local", - password=user_data["password"] + request = UnifiedLoginRequest( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="2fauser@test.local", + credential=user_data["password"], # 未提供 two_fa_code ) - result = await login(db_session, login_request) + with pytest.raises(HTTPException) as exc_info: + await unified_login(db_session, request) - assert result == "2fa_required" + assert exc_info.value.status_code == 428 @pytest.mark.asyncio @@ -155,15 +240,17 @@ async def test_login_2fa_invalid(db_session: AsyncSession, setup_2fa_user): """测试 2FA 错误""" user_data = setup_2fa_user - login_request = LoginRequest( - email="2fauser@test.local", - password=user_data["password"], - two_fa_code="000000" # 错误的验证码 + request = UnifiedLoginRequest( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="2fauser@test.local", + credential=user_data["password"], + two_fa_code="000000", ) - result = await login(db_session, login_request) + with pytest.raises(HTTPException) as exc_info: + await unified_login(db_session, request) - assert result == "2fa_invalid" + assert exc_info.value.status_code == 401 @pytest.mark.asyncio @@ -178,56 +265,44 @@ async def test_login_2fa_success(db_session: AsyncSession, setup_2fa_user): totp = pyotp.TOTP(secret) valid_code = totp.now() - login_request = LoginRequest( - email="2fauser@test.local", - password=user_data["password"], - two_fa_code=valid_code + request = UnifiedLoginRequest( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="2fauser@test.local", + credential=user_data["password"], + two_fa_code=valid_code, ) - result = await login(db_session, login_request) + result = await unified_login(db_session, request) assert isinstance(result, TokenResponse) assert result.access_token is not None @pytest.mark.asyncio -async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user): - """测试返回的令牌可以被解码""" - import jwt as pyjwt - - user_data = setup_user - - login_request = LoginRequest( - email="loginuser@test.local", - password=user_data["password"] +async def test_login_provider_disabled(db_session: AsyncSession, setup_user): + """测试未启用的 provider""" + request = UnifiedLoginRequest( + provider=AuthProviderType.PHONE_SMS, + identifier="13800138000", + credential="123456", ) - result = await login(db_session, login_request) + with pytest.raises(HTTPException) as exc_info: + await unified_login(db_session, request) - assert isinstance(result, TokenResponse) - - # 注意: 实际项目中需要使用正确的 SECRET_KEY - # 这里假设测试环境已经设置了 SECRET_KEY - # decoded = pyjwt.decode( - # result.access_token, - # SECRET_KEY, - # algorithms=["HS256"] - # ) - # assert decoded["sub"] == "loginuser" + assert exc_info.value.status_code == 400 @pytest.mark.asyncio -async def test_login_case_sensitive_email(db_session: AsyncSession, setup_user): - """测试邮箱大小写敏感""" - user_data = setup_user - - # 使用大写邮箱登录 - login_request = LoginRequest( - email="LOGINUSER@TEST.LOCAL", - password=user_data["password"] +async def test_login_missing_password(db_session: AsyncSession, setup_user): + """测试邮箱密码登录缺少密码""" + request = UnifiedLoginRequest( + provider=AuthProviderType.EMAIL_PASSWORD, + identifier="loginuser@test.local", + # 未提供 credential ) - result = await login(db_session, login_request) + with pytest.raises(HTTPException) as exc_info: + await unified_login(db_session, request) - # 应该失败,因为邮箱大小写不匹配 - assert result is None + assert exc_info.value.status_code == 400