feat: add multi-provider auth via AuthIdentity and extend site config
- Extract AuthIdentity model for multi-provider authentication (email_password, OAuth, Passkey, Magic Link) - Remove password field from User model, credentials now stored in AuthIdentity - Refactor unified login/register to use AuthIdentity-based provider checking - Add site config fields: footer_code, tos_url, privacy_url, auth_methods - Add auth settings defaults in migration (email_password enabled by default) - Update admin user creation to create AuthIdentity records - Update all tests to use AuthIdentity model Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -12,6 +12,7 @@ from sqlmodels import (
|
||||
Group, Object, ObjectType, Setting, SettingsType,
|
||||
BatchDeleteRequest,
|
||||
)
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.user import (
|
||||
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, UserStatus,
|
||||
)
|
||||
@@ -83,13 +84,26 @@ async def router_admin_create_user(
|
||||
"""
|
||||
创建一个新的用户,设置邮箱、密码、用户组等信息。
|
||||
|
||||
管理员创建用户时,若提供了 email + password,
|
||||
会同时创建 AuthIdentity(provider=email_password)。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param request: 创建用户请求 DTO
|
||||
:return: 创建结果
|
||||
"""
|
||||
existing_user = await User.get(session, User.email == request.email)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
# 如果提供了邮箱,检查唯一性(User 表和 AuthIdentity 表)
|
||||
if request.email:
|
||||
existing_user = await User.get(session, User.email == request.email)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
|
||||
existing_identity = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||
& (AuthIdentity.identifier == request.email),
|
||||
)
|
||||
if existing_identity:
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被绑定")
|
||||
|
||||
# 验证用户组存在
|
||||
group = await Group.get(session, Group.id == request.group_id)
|
||||
@@ -98,12 +112,24 @@ async def router_admin_create_user(
|
||||
|
||||
user = User(
|
||||
email=request.email,
|
||||
password=Password.hash(request.password),
|
||||
nickname=request.nickname,
|
||||
group_id=request.group_id,
|
||||
status=request.status,
|
||||
)
|
||||
user = await user.save(session)
|
||||
|
||||
# 如果提供了邮箱和密码,创建邮箱密码认证身份
|
||||
if request.email and request.password:
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=request.email,
|
||||
credential=Password.hash(request.password),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
return user.to_public()
|
||||
|
||||
|
||||
@@ -148,17 +174,7 @@ async def router_admin_update_user(
|
||||
if not group:
|
||||
raise HTTPException(status_code=400, detail="目标用户组不存在")
|
||||
|
||||
# 如果更新密码,需要加密
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
if 'password' in update_data and update_data['password']:
|
||||
update_data['password'] = Password.hash(update_data['password'])
|
||||
elif 'password' in update_data:
|
||||
del update_data['password'] # 空密码不更新
|
||||
|
||||
# 验证两步验证密钥格式(如果提供了值且不为 None,长度必须为 32)
|
||||
if 'two_factor' in update_data and update_data['two_factor'] is not None:
|
||||
if len(update_data['two_factor']) != 32:
|
||||
raise HTTPException(status_code=400, detail="两步验证密钥必须为32位字符串")
|
||||
|
||||
# 记录旧 status 以便检测变更
|
||||
old_status = user.status
|
||||
@@ -175,7 +191,7 @@ async def router_admin_update_user(
|
||||
elif old_status != UserStatus.ACTIVE and new_status == UserStatus.ACTIVE:
|
||||
await UserBanStore.unban(str(user_id))
|
||||
|
||||
l.info(f"管理员更新了用户: {request.email}")
|
||||
l.info(f"管理员更新了用户: {user.email}")
|
||||
|
||||
|
||||
@admin_user_router.delete(
|
||||
|
||||
@@ -4,7 +4,9 @@ from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
ResponseBase, Setting, SettingsType, SiteConfigResponse,
|
||||
ThemePreset, ThemePresetResponse, ThemePresetListResponse,
|
||||
AuthMethodConfig,
|
||||
)
|
||||
from sqlmodels.auth_identity import AuthProviderType
|
||||
from sqlmodels.setting import CaptchaType
|
||||
from utils import http_exceptions
|
||||
|
||||
@@ -70,7 +72,7 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
获取站点全局配置
|
||||
|
||||
无需认证。前端在初始化时调用此端点获取验证码类型、
|
||||
登录/注册/找回密码是否需要验证码等配置。
|
||||
登录/注册/找回密码是否需要验证码、可用的认证方式等配置。
|
||||
"""
|
||||
# 批量查询所需设置
|
||||
settings: list[Setting] = await Setting.get(
|
||||
@@ -78,7 +80,9 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
(Setting.type == SettingsType.BASIC) |
|
||||
(Setting.type == SettingsType.LOGIN) |
|
||||
(Setting.type == SettingsType.REGISTER) |
|
||||
(Setting.type == SettingsType.CAPTCHA),
|
||||
(Setting.type == SettingsType.CAPTCHA) |
|
||||
(Setting.type == SettingsType.AUTH) |
|
||||
(Setting.type == SettingsType.OAUTH),
|
||||
fetch_mode="all",
|
||||
)
|
||||
|
||||
@@ -94,6 +98,16 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
|
||||
captcha_key = s.get("captcha_CloudflareKey") or None
|
||||
|
||||
# 构建认证方式列表
|
||||
auth_methods: list[AuthMethodConfig] = [
|
||||
AuthMethodConfig(provider=AuthProviderType.EMAIL_PASSWORD, is_enabled=s.get("auth_email_password_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.PHONE_SMS, is_enabled=s.get("auth_phone_sms_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.GITHUB, is_enabled=s.get("github_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.QQ, is_enabled=s.get("qq_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.PASSKEY, is_enabled=s.get("auth_passkey_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.MAGIC_LINK, is_enabled=s.get("auth_magic_link_enabled") == "1"),
|
||||
]
|
||||
|
||||
return SiteConfigResponse(
|
||||
title=s.get("siteName") or "DiskNext",
|
||||
register_enabled=s.get("register_enabled") == "1",
|
||||
@@ -102,4 +116,11 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
forget_captcha=s.get("forget_captcha") == "1",
|
||||
captcha_type=captcha_type,
|
||||
captcha_key=captcha_key,
|
||||
auth_methods=auth_methods,
|
||||
password_required=s.get("auth_password_required") == "1",
|
||||
phone_binding_required=s.get("auth_phone_binding_required") == "1",
|
||||
email_binding_required=s.get("auth_email_binding_required") == "1",
|
||||
footer_code=s.get("footer_code"),
|
||||
tos_url=s.get("tos_url"),
|
||||
privacy_url=s.get("privacy_url"),
|
||||
)
|
||||
@@ -2,7 +2,8 @@ from typing import Annotated, Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from itsdangerous import URLSafeTimedSerializer
|
||||
from loguru import logger
|
||||
from webauthn import generate_registration_options
|
||||
from webauthn.helpers import options_to_json_dict
|
||||
@@ -12,6 +13,7 @@ import sqlmodels
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep, require_captcha
|
||||
from service.captcha import CaptchaScene
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import JWT, Password, http_exceptions
|
||||
from .settings import user_settings_router
|
||||
@@ -23,59 +25,36 @@ user_router = APIRouter(
|
||||
|
||||
user_router.include_router(user_settings_router)
|
||||
|
||||
class OAuth2PasswordWithExtrasForm:
|
||||
"""
|
||||
扩展 OAuth2 密码表单。
|
||||
|
||||
在标准 username/password 基础上添加 otp_code 字段。
|
||||
captcha_code 由 require_captcha 依赖注入单独处理。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
username: Annotated[str, Form()],
|
||||
password: Annotated[str, Form()],
|
||||
otp_code: Annotated[str | None, Form(min_length=6, max_length=6)] = None,
|
||||
):
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.otp_code = otp_code
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/session',
|
||||
summary='用户登录',
|
||||
description='用户登录端点,支持验证码校验和两步验证。',
|
||||
dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))],
|
||||
summary='用户登录(统一入口)',
|
||||
description='统一登录端点,支持多种认证方式。',
|
||||
)
|
||||
async def router_user_session(
|
||||
session: SessionDep,
|
||||
form_data: Annotated[OAuth2PasswordWithExtrasForm, Depends()],
|
||||
request: sqlmodels.UnifiedLoginRequest,
|
||||
) -> sqlmodels.TokenResponse:
|
||||
"""
|
||||
用户登录端点
|
||||
统一登录端点
|
||||
|
||||
表单字段:
|
||||
- username: 用户邮箱
|
||||
- password: 用户密码
|
||||
- captcha_code: 验证码 token(可选,由 require_captcha 依赖校验)
|
||||
- otp_code: 两步验证码(可选,仅在用户启用 2FA 时需要)
|
||||
请求体:
|
||||
- provider: 登录方式(email_password / github / qq / passkey / magic_link)
|
||||
- identifier: 标识符(邮箱 / OAuth code / credential_id / magic link token)
|
||||
- credential: 凭证(密码 / WebAuthn assertion 等)
|
||||
- two_fa_code: 两步验证码(可选)
|
||||
- redirect_uri: OAuth 回调地址(可选)
|
||||
- captcha: 验证码(可选)
|
||||
|
||||
错误处理:
|
||||
- 400: 需要验证码但未提供
|
||||
- 401: 邮箱/密码错误,或 2FA 验证码错误
|
||||
- 403: 账户已禁用 / 验证码验证失败
|
||||
- 428: 需要两步验证但未提供 otp_code
|
||||
- 400: 登录方式未启用 / 参数错误
|
||||
- 401: 凭证错误
|
||||
- 403: 账户已禁用
|
||||
- 428: 需要两步验证
|
||||
- 501: 暂未实现的登录方式
|
||||
"""
|
||||
return await service.user.login(
|
||||
session,
|
||||
sqlmodels.LoginRequest(
|
||||
email=form_data.username,
|
||||
password=form_data.password,
|
||||
two_fa_code=form_data.otp_code,
|
||||
),
|
||||
)
|
||||
return await service.user.unified_login(session, request)
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/session/refresh',
|
||||
@@ -150,41 +129,82 @@ async def router_user_session_refresh(
|
||||
|
||||
@user_router.post(
|
||||
path='/',
|
||||
summary='用户注册',
|
||||
summary='用户注册(统一入口)',
|
||||
description='User registration endpoint.',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_user_register(
|
||||
session: SessionDep,
|
||||
request: sqlmodels.RegisterRequest,
|
||||
request: sqlmodels.UnifiedRegisterRequest,
|
||||
) -> None:
|
||||
"""
|
||||
用户注册端点
|
||||
统一注册端点
|
||||
|
||||
流程:
|
||||
1. 验证用户名唯一性
|
||||
2. 获取默认用户组
|
||||
3. 创建用户记录
|
||||
4. 创建用户根目录(name="/")
|
||||
1. 检查注册开关
|
||||
2. 检查 provider 启用
|
||||
3. 验证 identifier 唯一性(AuthIdentity 表)
|
||||
4. 创建 User + AuthIdentity + 根目录
|
||||
|
||||
:param session: 数据库会话
|
||||
:param request: 注册请求
|
||||
:return: 注册结果
|
||||
:raises HTTPException 400: 用户名已存在
|
||||
:raises HTTPException 500: 默认用户组或存储策略不存在
|
||||
请求体:
|
||||
- provider: 注册方式(email_password / phone_sms)
|
||||
- identifier: 标识符(邮箱 / 手机号)
|
||||
- credential: 凭证(密码 / 短信验证码)
|
||||
- nickname: 昵称(可选)
|
||||
- captcha: 验证码(可选)
|
||||
|
||||
错误处理:
|
||||
- 400: 注册未开放 / 参数错误
|
||||
- 409: 邮箱或手机号已存在
|
||||
- 501: 暂未实现的注册方式
|
||||
"""
|
||||
# 1. 验证邮箱唯一性
|
||||
# 1. 检查注册开关
|
||||
register_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER)
|
||||
& (sqlmodels.Setting.name == "register_enabled"),
|
||||
)
|
||||
if not register_setting or register_setting.value != "1":
|
||||
http_exceptions.raise_bad_request("注册功能未开放")
|
||||
|
||||
# 2. 目前只支持 email_password 注册
|
||||
if request.provider == AuthProviderType.PHONE_SMS:
|
||||
http_exceptions.raise_not_implemented("短信注册暂未开放")
|
||||
elif request.provider != AuthProviderType.EMAIL_PASSWORD:
|
||||
http_exceptions.raise_bad_request("不支持的注册方式")
|
||||
|
||||
# 3. 检查密码是否必填
|
||||
password_required_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||
& (sqlmodels.Setting.name == "auth_password_required"),
|
||||
)
|
||||
is_password_required = not password_required_setting or password_required_setting.value != "0"
|
||||
if is_password_required and not request.credential:
|
||||
http_exceptions.raise_bad_request("密码不能为空")
|
||||
|
||||
# 4. 验证 identifier 唯一性(AuthIdentity 表)
|
||||
existing_identity = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == request.provider)
|
||||
& (AuthIdentity.identifier == request.identifier),
|
||||
)
|
||||
if existing_identity:
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
|
||||
# 同时检查 User.email 唯一性(防止旧数据冲突)
|
||||
existing_user = await sqlmodels.User.get(
|
||||
session,
|
||||
sqlmodels.User.email == request.email
|
||||
sqlmodels.User.email == request.identifier,
|
||||
)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=400, detail="邮箱已存在")
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
|
||||
# 2. 获取默认用户组(从设置中读取 UUID)
|
||||
default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get(
|
||||
# 5. 获取默认用户组
|
||||
default_group_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group")
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER)
|
||||
& (sqlmodels.Setting.name == "default_group"),
|
||||
)
|
||||
if default_group_setting is None or not default_group_setting.value:
|
||||
logger.error("默认用户组不存在")
|
||||
@@ -196,17 +216,28 @@ async def router_user_register(
|
||||
logger.error("默认用户组不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
# 3. 创建用户
|
||||
hashed_password = Password.hash(request.password)
|
||||
# 6. 创建用户
|
||||
new_user = sqlmodels.User(
|
||||
email=request.email,
|
||||
password=hashed_password,
|
||||
email=request.identifier,
|
||||
nickname=request.nickname,
|
||||
group_id=default_group.id,
|
||||
)
|
||||
new_user_id = new_user.id
|
||||
await new_user.save(session)
|
||||
|
||||
# 4. 创建用户根目录
|
||||
# 7. 创建 AuthIdentity
|
||||
hashed_password = Password.hash(request.credential) if request.credential else None
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=request.identifier,
|
||||
credential=hashed_password,
|
||||
is_primary=True,
|
||||
is_verified=False,
|
||||
user_id=new_user_id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
# 8. 创建用户根目录
|
||||
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
|
||||
if not default_policy:
|
||||
logger.error("默认存储策略不存在")
|
||||
@@ -220,6 +251,66 @@ async def router_user_register(
|
||||
policy_id=default_policy.id,
|
||||
).save(session)
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/magic-link',
|
||||
summary='发送 Magic Link 邮件',
|
||||
description='生成 Magic Link token 并发送到指定邮箱。',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_user_magic_link(
|
||||
session: SessionDep,
|
||||
request: sqlmodels.MagicLinkRequest,
|
||||
) -> None:
|
||||
"""
|
||||
发送 Magic Link 邮件
|
||||
|
||||
流程:
|
||||
1. 验证邮箱对应的 AuthIdentity 存在
|
||||
2. 生成签名 token
|
||||
3. 发送邮件(包含带 token 的链接)
|
||||
|
||||
错误处理:
|
||||
- 400: Magic Link 未启用
|
||||
- 404: 邮箱未注册
|
||||
"""
|
||||
# 检查 magic_link 是否启用
|
||||
magic_link_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||
& (sqlmodels.Setting.name == "auth_magic_link_enabled"),
|
||||
)
|
||||
if not magic_link_setting or magic_link_setting.value != "1":
|
||||
http_exceptions.raise_bad_request("Magic Link 登录未启用")
|
||||
|
||||
# 验证邮箱存在
|
||||
identity = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.identifier == request.email)
|
||||
& (
|
||||
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||
| (AuthIdentity.provider == AuthProviderType.MAGIC_LINK)
|
||||
),
|
||||
)
|
||||
if not identity:
|
||||
http_exceptions.raise_not_found("该邮箱未注册")
|
||||
|
||||
# 生成签名 token
|
||||
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||
token = serializer.dumps(request.email, salt="magic-link-salt")
|
||||
|
||||
# 获取站点 URL
|
||||
site_url_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.BASIC)
|
||||
& (sqlmodels.Setting.name == "siteURL"),
|
||||
)
|
||||
site_url = site_url_setting.value if site_url_setting else "http://localhost"
|
||||
|
||||
# TODO: 发送邮件(包含 {site_url}/auth/magic-link?token={token})
|
||||
logger.info(f"Magic Link token 已生成: {token} (邮件发送待实现)")
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/code',
|
||||
summary='发送验证码邮件',
|
||||
@@ -230,52 +321,12 @@ def router_user_email_code(
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Send a verification code email.
|
||||
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing information about the password reset email.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.get(
|
||||
path='/qq',
|
||||
summary='初始化QQ登录',
|
||||
description='Initialize QQ login for a user.',
|
||||
)
|
||||
def router_user_qq() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Initialize QQ login for a user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing QQ login initialization information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.get(
|
||||
path='authn/{username}',
|
||||
summary='WebAuthn登录初始化',
|
||||
description='Initialize WebAuthn login for a user.',
|
||||
)
|
||||
async def router_user_authn(username: str) -> sqlmodels.ResponseBase:
|
||||
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.post(
|
||||
path='authn/finish/{username}',
|
||||
summary='WebAuthn登录',
|
||||
description='Finish WebAuthn login for a user.',
|
||||
)
|
||||
def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Finish WebAuthn login for a user.
|
||||
|
||||
Args:
|
||||
username (str): The username of the user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn login information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.get(
|
||||
path='/profile/{id}',
|
||||
summary='获取用户主页展示用分享',
|
||||
@@ -284,10 +335,10 @@ def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
|
||||
def router_user_profile(id: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user profile for display.
|
||||
|
||||
|
||||
Args:
|
||||
id (str): The user ID.
|
||||
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing user profile information.
|
||||
"""
|
||||
@@ -301,11 +352,11 @@ def router_user_profile(id: str) -> sqlmodels.ResponseBase:
|
||||
def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user avatar by ID and size.
|
||||
|
||||
|
||||
Args:
|
||||
id (str): The user ID.
|
||||
size (int): The size of the avatar image.
|
||||
|
||||
|
||||
Returns:
|
||||
str: A Base64 encoded string of the user avatar image.
|
||||
"""
|
||||
@@ -348,8 +399,6 @@ async def router_user_me(
|
||||
return sqlmodels.UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
status=user.status,
|
||||
score=user.score,
|
||||
nickname=user.nickname,
|
||||
avatar=user.avatar,
|
||||
created_at=user.created_at,
|
||||
@@ -374,9 +423,9 @@ async def router_user_storage(
|
||||
group = await sqlmodels.Group.get(session, sqlmodels.Group.id == user.group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
|
||||
|
||||
# [TODO] 总空间加上用户购买的额外空间
|
||||
|
||||
|
||||
total: int = group.max_storage
|
||||
used: int = user.storage
|
||||
free: int = max(0, total - used)
|
||||
@@ -389,8 +438,8 @@ async def router_user_storage(
|
||||
|
||||
@user_router.put(
|
||||
path='/authn/start',
|
||||
summary='WebAuthn登录初始化',
|
||||
description='Initialize WebAuthn login for a user.',
|
||||
summary='注册 Passkey 凭证(初始化)',
|
||||
description='Initialize Passkey registration for a user.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_authn_start(
|
||||
@@ -398,18 +447,19 @@ async def router_user_authn_start(
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Initialize WebAuthn login for a user.
|
||||
Passkey 注册初始化(需要登录)
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn initialization information.
|
||||
返回 WebAuthn registration options,前端使用 navigator.credentials.create() 处理。
|
||||
|
||||
错误处理:
|
||||
- 400: Passkey 未启用
|
||||
"""
|
||||
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
|
||||
authn_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled")
|
||||
)
|
||||
if not authn_setting or authn_setting.value != "1":
|
||||
raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
|
||||
raise HTTPException(status_code=400, detail="Passkey 未启用")
|
||||
|
||||
site_url_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
@@ -423,23 +473,26 @@ async def router_user_authn_start(
|
||||
options = generate_registration_options(
|
||||
rp_id=site_url_setting.value if site_url_setting else "",
|
||||
rp_name=site_title_setting.value if site_title_setting else "",
|
||||
user_name=user.email,
|
||||
user_display_name=user.nickname or user.email,
|
||||
user_name=user.email or str(user.id),
|
||||
user_display_name=user.nickname or user.email or str(user.id),
|
||||
)
|
||||
|
||||
return sqlmodels.ResponseBase(data=options_to_json_dict(options))
|
||||
|
||||
@user_router.put(
|
||||
path='/authn/finish',
|
||||
summary='WebAuthn登录',
|
||||
description='Finish WebAuthn login for a user.',
|
||||
summary='注册 Passkey 凭证(完成)',
|
||||
description='Finish Passkey registration for a user.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_authn_finish() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Finish WebAuthn login for a user.
|
||||
|
||||
Passkey 注册完成(需要登录)
|
||||
|
||||
接收前端 navigator.credentials.create() 返回的凭证数据,
|
||||
创建 UserAuthn 行 + AuthIdentity(provider=passkey)。
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn login information.
|
||||
dict: A dictionary containing Passkey registration information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
@@ -9,6 +10,7 @@ from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
BUILTIN_DEFAULT_COLORS, ThemePreset, UserThemeUpdateRequest,
|
||||
SettingOption, UserSettingUpdateRequest,
|
||||
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
|
||||
)
|
||||
from sqlmodels.color import ThemeColorsBase
|
||||
from utils import JWT, Password, http_exceptions
|
||||
@@ -117,16 +119,29 @@ async def router_user_settings(
|
||||
else:
|
||||
theme_colors = BUILTIN_DEFAULT_COLORS
|
||||
|
||||
# 检查是否启用了两步验证(从 email_password AuthIdentity 的 extra_data 中读取)
|
||||
has_two_factor = False
|
||||
email_identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.user_id == user.id)
|
||||
& (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
|
||||
)
|
||||
if email_identity and email_identity.extra_data:
|
||||
import orjson
|
||||
extra: dict = orjson.loads(email_identity.extra_data)
|
||||
has_two_factor = bool(extra.get("two_factor"))
|
||||
|
||||
return sqlmodels.UserSettingResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
phone=user.phone,
|
||||
nickname=user.nickname,
|
||||
created_at=user.created_at,
|
||||
group_name=user.group.name,
|
||||
language=user.language,
|
||||
timezone=user.timezone,
|
||||
group_expires=user.group_expires,
|
||||
two_factor=user.two_factor is not None,
|
||||
two_factor=has_two_factor,
|
||||
theme_preset_id=user.theme_preset_id,
|
||||
theme_colors=theme_colors,
|
||||
)
|
||||
@@ -255,7 +270,7 @@ async def router_user_settings_2fa(
|
||||
|
||||
返回 setup_token(用于后续验证请求)和 uri(用于生成二维码)。
|
||||
"""
|
||||
return await Password.generate_totp(name=user.email)
|
||||
return await Password.generate_totp(name=user.email or str(user.id))
|
||||
|
||||
|
||||
@user_settings_router.post(
|
||||
@@ -273,7 +288,7 @@ async def router_user_settings_2fa_enable(
|
||||
"""
|
||||
启用两步验证
|
||||
|
||||
请求体包含 setup_token(GET /2fa 返回的令牌)和 code(6 位验证码)。
|
||||
将 2FA secret 存储到 email_password AuthIdentity 的 extra_data 中。
|
||||
"""
|
||||
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||
|
||||
@@ -287,6 +302,150 @@ async def router_user_settings_2fa_enable(
|
||||
if Password.verify_totp(secret, request.code) != PasswordStatus.VALID:
|
||||
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
||||
|
||||
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
|
||||
user.two_factor = secret
|
||||
user = await user.save(session)
|
||||
# 将 secret 存储到 AuthIdentity.extra_data 中
|
||||
email_identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.user_id == user.id)
|
||||
& (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
|
||||
)
|
||||
if not email_identity:
|
||||
raise HTTPException(status_code=400, detail="未找到邮箱密码认证身份")
|
||||
|
||||
import orjson
|
||||
extra: dict = orjson.loads(email_identity.extra_data) if email_identity.extra_data else {}
|
||||
extra["two_factor"] = secret
|
||||
email_identity.extra_data = orjson.dumps(extra).decode('utf-8')
|
||||
await email_identity.save(session)
|
||||
|
||||
|
||||
# ==================== 认证身份管理 ====================
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/identities',
|
||||
summary='列出已绑定的认证身份',
|
||||
)
|
||||
async def router_user_settings_identities(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> list[AuthIdentityResponse]:
|
||||
"""
|
||||
列出当前用户已绑定的所有认证身份
|
||||
|
||||
返回:
|
||||
- 认证身份列表,包含 provider、identifier、display_name 等
|
||||
"""
|
||||
identities: list[AuthIdentity] = await AuthIdentity.get(
|
||||
session,
|
||||
AuthIdentity.user_id == user.id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
return [identity.to_response() for identity in identities]
|
||||
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/identity',
|
||||
summary='绑定新的认证身份',
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def router_user_settings_bind_identity(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
request: BindIdentityRequest,
|
||||
) -> AuthIdentityResponse:
|
||||
"""
|
||||
绑定新的登录方式
|
||||
|
||||
请求体:
|
||||
- provider: 提供者类型
|
||||
- identifier: 标识符(邮箱 / 手机号 / OAuth code)
|
||||
- credential: 凭证(密码、验证码等)
|
||||
- redirect_uri: OAuth 回调地址(可选)
|
||||
|
||||
错误处理:
|
||||
- 400: provider 未启用
|
||||
- 409: 该身份已被其他用户绑定
|
||||
"""
|
||||
# 检查是否已被绑定
|
||||
existing = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == request.provider)
|
||||
& (AuthIdentity.identifier == request.identifier),
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="该身份已被绑定")
|
||||
|
||||
# 处理密码类型的凭证
|
||||
credential: str | None = None
|
||||
if request.provider == AuthProviderType.EMAIL_PASSWORD and request.credential:
|
||||
credential = Password.hash(request.credential)
|
||||
|
||||
identity = AuthIdentity(
|
||||
provider=request.provider,
|
||||
identifier=request.identifier,
|
||||
credential=credential,
|
||||
is_primary=False,
|
||||
is_verified=False,
|
||||
user_id=user.id,
|
||||
)
|
||||
identity = await identity.save(session)
|
||||
return identity.to_response()
|
||||
|
||||
|
||||
@user_settings_router.delete(
|
||||
path='/identity/{identity_id}',
|
||||
summary='解绑认证身份',
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def router_user_settings_unbind_identity(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
identity_id: UUID,
|
||||
) -> None:
|
||||
"""
|
||||
解绑一个认证身份
|
||||
|
||||
约束:
|
||||
- 不能解绑最后一个身份
|
||||
- 站长配置强制绑定邮箱/手机号时,不能解绑对应身份
|
||||
|
||||
错误处理:
|
||||
- 404: 身份不存在或不属于当前用户
|
||||
- 400: 不能解绑最后一个身份 / 不能解绑强制绑定的身份
|
||||
"""
|
||||
# 查找目标身份
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.id == identity_id) & (AuthIdentity.user_id == user.id),
|
||||
)
|
||||
if not identity:
|
||||
http_exceptions.raise_not_found("认证身份不存在")
|
||||
|
||||
# 检查是否为最后一个身份
|
||||
all_identities: list[AuthIdentity] = await AuthIdentity.get(
|
||||
session,
|
||||
AuthIdentity.user_id == user.id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
if len(all_identities) <= 1:
|
||||
http_exceptions.raise_bad_request("不能解绑最后一个认证身份")
|
||||
|
||||
# 检查强制绑定约束
|
||||
if identity.provider == AuthProviderType.EMAIL_PASSWORD:
|
||||
email_required_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||
& (sqlmodels.Setting.name == "auth_email_binding_required"),
|
||||
)
|
||||
if email_required_setting and email_required_setting.value == "1":
|
||||
http_exceptions.raise_bad_request("站长要求必须绑定邮箱,不能解绑")
|
||||
|
||||
if identity.provider == AuthProviderType.PHONE_SMS:
|
||||
phone_required_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||
& (sqlmodels.Setting.name == "auth_phone_binding_required"),
|
||||
)
|
||||
if phone_required_setting and phone_required_setting.value == "1":
|
||||
http_exceptions.raise_bad_request("站长要求必须绑定手机号,不能解绑")
|
||||
|
||||
await AuthIdentity.delete(session, identity)
|
||||
|
||||
Reference in New Issue
Block a user