feat: add multi-provider auth via AuthIdentity and extend site config
- Extract AuthIdentity model for multi-provider authentication (email_password, OAuth, Passkey, Magic Link) - Remove password field from User model, credentials now stored in AuthIdentity - Refactor unified login/register to use AuthIdentity-based provider checking - Add site config fields: footer_code, tos_url, privacy_url, auth_methods - Add auth settings defaults in migration (email_password enabled by default) - Update admin user creation to create AuthIdentity records - Update all tests to use AuthIdentity model Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -2,7 +2,8 @@ from typing import Annotated, Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from itsdangerous import URLSafeTimedSerializer
|
||||
from loguru import logger
|
||||
from webauthn import generate_registration_options
|
||||
from webauthn.helpers import options_to_json_dict
|
||||
@@ -12,6 +13,7 @@ import sqlmodels
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep, require_captcha
|
||||
from service.captcha import CaptchaScene
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import JWT, Password, http_exceptions
|
||||
from .settings import user_settings_router
|
||||
@@ -23,59 +25,36 @@ user_router = APIRouter(
|
||||
|
||||
user_router.include_router(user_settings_router)
|
||||
|
||||
class OAuth2PasswordWithExtrasForm:
|
||||
"""
|
||||
扩展 OAuth2 密码表单。
|
||||
|
||||
在标准 username/password 基础上添加 otp_code 字段。
|
||||
captcha_code 由 require_captcha 依赖注入单独处理。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
username: Annotated[str, Form()],
|
||||
password: Annotated[str, Form()],
|
||||
otp_code: Annotated[str | None, Form(min_length=6, max_length=6)] = None,
|
||||
):
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.otp_code = otp_code
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/session',
|
||||
summary='用户登录',
|
||||
description='用户登录端点,支持验证码校验和两步验证。',
|
||||
dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))],
|
||||
summary='用户登录(统一入口)',
|
||||
description='统一登录端点,支持多种认证方式。',
|
||||
)
|
||||
async def router_user_session(
|
||||
session: SessionDep,
|
||||
form_data: Annotated[OAuth2PasswordWithExtrasForm, Depends()],
|
||||
request: sqlmodels.UnifiedLoginRequest,
|
||||
) -> sqlmodels.TokenResponse:
|
||||
"""
|
||||
用户登录端点
|
||||
统一登录端点
|
||||
|
||||
表单字段:
|
||||
- username: 用户邮箱
|
||||
- password: 用户密码
|
||||
- captcha_code: 验证码 token(可选,由 require_captcha 依赖校验)
|
||||
- otp_code: 两步验证码(可选,仅在用户启用 2FA 时需要)
|
||||
请求体:
|
||||
- provider: 登录方式(email_password / github / qq / passkey / magic_link)
|
||||
- identifier: 标识符(邮箱 / OAuth code / credential_id / magic link token)
|
||||
- credential: 凭证(密码 / WebAuthn assertion 等)
|
||||
- two_fa_code: 两步验证码(可选)
|
||||
- redirect_uri: OAuth 回调地址(可选)
|
||||
- captcha: 验证码(可选)
|
||||
|
||||
错误处理:
|
||||
- 400: 需要验证码但未提供
|
||||
- 401: 邮箱/密码错误,或 2FA 验证码错误
|
||||
- 403: 账户已禁用 / 验证码验证失败
|
||||
- 428: 需要两步验证但未提供 otp_code
|
||||
- 400: 登录方式未启用 / 参数错误
|
||||
- 401: 凭证错误
|
||||
- 403: 账户已禁用
|
||||
- 428: 需要两步验证
|
||||
- 501: 暂未实现的登录方式
|
||||
"""
|
||||
return await service.user.login(
|
||||
session,
|
||||
sqlmodels.LoginRequest(
|
||||
email=form_data.username,
|
||||
password=form_data.password,
|
||||
two_fa_code=form_data.otp_code,
|
||||
),
|
||||
)
|
||||
return await service.user.unified_login(session, request)
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/session/refresh',
|
||||
@@ -150,41 +129,82 @@ async def router_user_session_refresh(
|
||||
|
||||
@user_router.post(
|
||||
path='/',
|
||||
summary='用户注册',
|
||||
summary='用户注册(统一入口)',
|
||||
description='User registration endpoint.',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_user_register(
|
||||
session: SessionDep,
|
||||
request: sqlmodels.RegisterRequest,
|
||||
request: sqlmodels.UnifiedRegisterRequest,
|
||||
) -> None:
|
||||
"""
|
||||
用户注册端点
|
||||
统一注册端点
|
||||
|
||||
流程:
|
||||
1. 验证用户名唯一性
|
||||
2. 获取默认用户组
|
||||
3. 创建用户记录
|
||||
4. 创建用户根目录(name="/")
|
||||
1. 检查注册开关
|
||||
2. 检查 provider 启用
|
||||
3. 验证 identifier 唯一性(AuthIdentity 表)
|
||||
4. 创建 User + AuthIdentity + 根目录
|
||||
|
||||
:param session: 数据库会话
|
||||
:param request: 注册请求
|
||||
:return: 注册结果
|
||||
:raises HTTPException 400: 用户名已存在
|
||||
:raises HTTPException 500: 默认用户组或存储策略不存在
|
||||
请求体:
|
||||
- provider: 注册方式(email_password / phone_sms)
|
||||
- identifier: 标识符(邮箱 / 手机号)
|
||||
- credential: 凭证(密码 / 短信验证码)
|
||||
- nickname: 昵称(可选)
|
||||
- captcha: 验证码(可选)
|
||||
|
||||
错误处理:
|
||||
- 400: 注册未开放 / 参数错误
|
||||
- 409: 邮箱或手机号已存在
|
||||
- 501: 暂未实现的注册方式
|
||||
"""
|
||||
# 1. 验证邮箱唯一性
|
||||
# 1. 检查注册开关
|
||||
register_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER)
|
||||
& (sqlmodels.Setting.name == "register_enabled"),
|
||||
)
|
||||
if not register_setting or register_setting.value != "1":
|
||||
http_exceptions.raise_bad_request("注册功能未开放")
|
||||
|
||||
# 2. 目前只支持 email_password 注册
|
||||
if request.provider == AuthProviderType.PHONE_SMS:
|
||||
http_exceptions.raise_not_implemented("短信注册暂未开放")
|
||||
elif request.provider != AuthProviderType.EMAIL_PASSWORD:
|
||||
http_exceptions.raise_bad_request("不支持的注册方式")
|
||||
|
||||
# 3. 检查密码是否必填
|
||||
password_required_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||
& (sqlmodels.Setting.name == "auth_password_required"),
|
||||
)
|
||||
is_password_required = not password_required_setting or password_required_setting.value != "0"
|
||||
if is_password_required and not request.credential:
|
||||
http_exceptions.raise_bad_request("密码不能为空")
|
||||
|
||||
# 4. 验证 identifier 唯一性(AuthIdentity 表)
|
||||
existing_identity = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == request.provider)
|
||||
& (AuthIdentity.identifier == request.identifier),
|
||||
)
|
||||
if existing_identity:
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
|
||||
# 同时检查 User.email 唯一性(防止旧数据冲突)
|
||||
existing_user = await sqlmodels.User.get(
|
||||
session,
|
||||
sqlmodels.User.email == request.email
|
||||
sqlmodels.User.email == request.identifier,
|
||||
)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=400, detail="邮箱已存在")
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
|
||||
# 2. 获取默认用户组(从设置中读取 UUID)
|
||||
default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get(
|
||||
# 5. 获取默认用户组
|
||||
default_group_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group")
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER)
|
||||
& (sqlmodels.Setting.name == "default_group"),
|
||||
)
|
||||
if default_group_setting is None or not default_group_setting.value:
|
||||
logger.error("默认用户组不存在")
|
||||
@@ -196,17 +216,28 @@ async def router_user_register(
|
||||
logger.error("默认用户组不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
# 3. 创建用户
|
||||
hashed_password = Password.hash(request.password)
|
||||
# 6. 创建用户
|
||||
new_user = sqlmodels.User(
|
||||
email=request.email,
|
||||
password=hashed_password,
|
||||
email=request.identifier,
|
||||
nickname=request.nickname,
|
||||
group_id=default_group.id,
|
||||
)
|
||||
new_user_id = new_user.id
|
||||
await new_user.save(session)
|
||||
|
||||
# 4. 创建用户根目录
|
||||
# 7. 创建 AuthIdentity
|
||||
hashed_password = Password.hash(request.credential) if request.credential else None
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=request.identifier,
|
||||
credential=hashed_password,
|
||||
is_primary=True,
|
||||
is_verified=False,
|
||||
user_id=new_user_id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
# 8. 创建用户根目录
|
||||
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
|
||||
if not default_policy:
|
||||
logger.error("默认存储策略不存在")
|
||||
@@ -220,6 +251,66 @@ async def router_user_register(
|
||||
policy_id=default_policy.id,
|
||||
).save(session)
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/magic-link',
|
||||
summary='发送 Magic Link 邮件',
|
||||
description='生成 Magic Link token 并发送到指定邮箱。',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_user_magic_link(
|
||||
session: SessionDep,
|
||||
request: sqlmodels.MagicLinkRequest,
|
||||
) -> None:
|
||||
"""
|
||||
发送 Magic Link 邮件
|
||||
|
||||
流程:
|
||||
1. 验证邮箱对应的 AuthIdentity 存在
|
||||
2. 生成签名 token
|
||||
3. 发送邮件(包含带 token 的链接)
|
||||
|
||||
错误处理:
|
||||
- 400: Magic Link 未启用
|
||||
- 404: 邮箱未注册
|
||||
"""
|
||||
# 检查 magic_link 是否启用
|
||||
magic_link_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||
& (sqlmodels.Setting.name == "auth_magic_link_enabled"),
|
||||
)
|
||||
if not magic_link_setting or magic_link_setting.value != "1":
|
||||
http_exceptions.raise_bad_request("Magic Link 登录未启用")
|
||||
|
||||
# 验证邮箱存在
|
||||
identity = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.identifier == request.email)
|
||||
& (
|
||||
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||
| (AuthIdentity.provider == AuthProviderType.MAGIC_LINK)
|
||||
),
|
||||
)
|
||||
if not identity:
|
||||
http_exceptions.raise_not_found("该邮箱未注册")
|
||||
|
||||
# 生成签名 token
|
||||
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||
token = serializer.dumps(request.email, salt="magic-link-salt")
|
||||
|
||||
# 获取站点 URL
|
||||
site_url_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.BASIC)
|
||||
& (sqlmodels.Setting.name == "siteURL"),
|
||||
)
|
||||
site_url = site_url_setting.value if site_url_setting else "http://localhost"
|
||||
|
||||
# TODO: 发送邮件(包含 {site_url}/auth/magic-link?token={token})
|
||||
logger.info(f"Magic Link token 已生成: {token} (邮件发送待实现)")
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/code',
|
||||
summary='发送验证码邮件',
|
||||
@@ -230,52 +321,12 @@ def router_user_email_code(
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Send a verification code email.
|
||||
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing information about the password reset email.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.get(
|
||||
path='/qq',
|
||||
summary='初始化QQ登录',
|
||||
description='Initialize QQ login for a user.',
|
||||
)
|
||||
def router_user_qq() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Initialize QQ login for a user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing QQ login initialization information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.get(
|
||||
path='authn/{username}',
|
||||
summary='WebAuthn登录初始化',
|
||||
description='Initialize WebAuthn login for a user.',
|
||||
)
|
||||
async def router_user_authn(username: str) -> sqlmodels.ResponseBase:
|
||||
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.post(
|
||||
path='authn/finish/{username}',
|
||||
summary='WebAuthn登录',
|
||||
description='Finish WebAuthn login for a user.',
|
||||
)
|
||||
def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Finish WebAuthn login for a user.
|
||||
|
||||
Args:
|
||||
username (str): The username of the user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn login information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_router.get(
|
||||
path='/profile/{id}',
|
||||
summary='获取用户主页展示用分享',
|
||||
@@ -284,10 +335,10 @@ def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
|
||||
def router_user_profile(id: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user profile for display.
|
||||
|
||||
|
||||
Args:
|
||||
id (str): The user ID.
|
||||
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing user profile information.
|
||||
"""
|
||||
@@ -301,11 +352,11 @@ def router_user_profile(id: str) -> sqlmodels.ResponseBase:
|
||||
def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user avatar by ID and size.
|
||||
|
||||
|
||||
Args:
|
||||
id (str): The user ID.
|
||||
size (int): The size of the avatar image.
|
||||
|
||||
|
||||
Returns:
|
||||
str: A Base64 encoded string of the user avatar image.
|
||||
"""
|
||||
@@ -348,8 +399,6 @@ async def router_user_me(
|
||||
return sqlmodels.UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
status=user.status,
|
||||
score=user.score,
|
||||
nickname=user.nickname,
|
||||
avatar=user.avatar,
|
||||
created_at=user.created_at,
|
||||
@@ -374,9 +423,9 @@ async def router_user_storage(
|
||||
group = await sqlmodels.Group.get(session, sqlmodels.Group.id == user.group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
|
||||
|
||||
# [TODO] 总空间加上用户购买的额外空间
|
||||
|
||||
|
||||
total: int = group.max_storage
|
||||
used: int = user.storage
|
||||
free: int = max(0, total - used)
|
||||
@@ -389,8 +438,8 @@ async def router_user_storage(
|
||||
|
||||
@user_router.put(
|
||||
path='/authn/start',
|
||||
summary='WebAuthn登录初始化',
|
||||
description='Initialize WebAuthn login for a user.',
|
||||
summary='注册 Passkey 凭证(初始化)',
|
||||
description='Initialize Passkey registration for a user.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_authn_start(
|
||||
@@ -398,18 +447,19 @@ async def router_user_authn_start(
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Initialize WebAuthn login for a user.
|
||||
Passkey 注册初始化(需要登录)
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn initialization information.
|
||||
返回 WebAuthn registration options,前端使用 navigator.credentials.create() 处理。
|
||||
|
||||
错误处理:
|
||||
- 400: Passkey 未启用
|
||||
"""
|
||||
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
|
||||
authn_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled")
|
||||
)
|
||||
if not authn_setting or authn_setting.value != "1":
|
||||
raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
|
||||
raise HTTPException(status_code=400, detail="Passkey 未启用")
|
||||
|
||||
site_url_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
@@ -423,23 +473,26 @@ async def router_user_authn_start(
|
||||
options = generate_registration_options(
|
||||
rp_id=site_url_setting.value if site_url_setting else "",
|
||||
rp_name=site_title_setting.value if site_title_setting else "",
|
||||
user_name=user.email,
|
||||
user_display_name=user.nickname or user.email,
|
||||
user_name=user.email or str(user.id),
|
||||
user_display_name=user.nickname or user.email or str(user.id),
|
||||
)
|
||||
|
||||
return sqlmodels.ResponseBase(data=options_to_json_dict(options))
|
||||
|
||||
@user_router.put(
|
||||
path='/authn/finish',
|
||||
summary='WebAuthn登录',
|
||||
description='Finish WebAuthn login for a user.',
|
||||
summary='注册 Passkey 凭证(完成)',
|
||||
description='Finish Passkey registration for a user.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_authn_finish() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Finish WebAuthn login for a user.
|
||||
|
||||
Passkey 注册完成(需要登录)
|
||||
|
||||
接收前端 navigator.credentials.create() 返回的凭证数据,
|
||||
创建 UserAuthn 行 + AuthIdentity(provider=passkey)。
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn login information.
|
||||
dict: A dictionary containing Passkey registration information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
Reference in New Issue
Block a user