feat: add multi-provider auth via AuthIdentity and extend site config
- Extract AuthIdentity model for multi-provider authentication (email_password, OAuth, Passkey, Magic Link) - Remove password field from User model, credentials now stored in AuthIdentity - Refactor unified login/register to use AuthIdentity-based provider checking - Add site config fields: footer_code, tos_url, privacy_url, auth_methods - Add auth settings defaults in migration (email_password enabled by default) - Update admin user creation to create AuthIdentity records - Update all tests to use AuthIdentity model Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1 +1 @@
|
||||
from .login import login
|
||||
from .login import unified_login
|
||||
|
||||
@@ -1,83 +1,417 @@
|
||||
from uuid import uuid4
|
||||
"""
|
||||
统一登录服务
|
||||
|
||||
from loguru import logger
|
||||
支持多种认证方式:邮箱密码、GitHub OAuth、QQ OAuth、Passkey、Magic Link、手机短信(预留)。
|
||||
"""
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import LoginRequest, TokenResponse, User
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
from loguru import logger as l
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.group import GroupClaims, GroupOptions
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import http_exceptions
|
||||
from utils.JWT import create_access_token, create_refresh_token
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.policy import Policy
|
||||
from sqlmodels.setting import Setting, SettingsType
|
||||
from sqlmodels.user import TokenResponse, UnifiedLoginRequest, User, UserStatus
|
||||
from utils import JWT, http_exceptions
|
||||
from utils.password.pwd import Password, PasswordStatus
|
||||
|
||||
|
||||
async def login(
|
||||
session: SessionDep,
|
||||
login_request: LoginRequest,
|
||||
async def unified_login(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
根据账号密码进行登录。
|
||||
如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。
|
||||
统一登录入口,根据 provider 分发到不同的登录逻辑。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param login_request: 登录请求
|
||||
|
||||
:return: TokenResponse 对象或状态码或 None
|
||||
:param request: 统一登录请求
|
||||
:return: TokenResponse
|
||||
"""
|
||||
# 获取用户信息(预加载 group 关系)
|
||||
current_user: User = await User.get(
|
||||
await _check_provider_enabled(session, request.provider)
|
||||
|
||||
match request.provider:
|
||||
case AuthProviderType.EMAIL_PASSWORD:
|
||||
user = await _login_email_password(session, request)
|
||||
case AuthProviderType.GITHUB:
|
||||
user = await _login_oauth(session, request, AuthProviderType.GITHUB)
|
||||
case AuthProviderType.QQ:
|
||||
user = await _login_oauth(session, request, AuthProviderType.QQ)
|
||||
case AuthProviderType.PASSKEY:
|
||||
user = await _login_passkey(session, request)
|
||||
case AuthProviderType.MAGIC_LINK:
|
||||
user = await _login_magic_link(session, request)
|
||||
case AuthProviderType.PHONE_SMS:
|
||||
http_exceptions.raise_not_implemented("短信登录暂未开放")
|
||||
case _:
|
||||
http_exceptions.raise_bad_request(f"不支持的登录方式: {request.provider}")
|
||||
|
||||
return await _issue_tokens(session, user)
|
||||
|
||||
|
||||
async def _check_provider_enabled(session: AsyncSession, provider: AuthProviderType) -> None:
|
||||
"""检查认证方式是否已被站长启用"""
|
||||
# OAuth 类型从 OAUTH 设置中查询
|
||||
if provider in (AuthProviderType.GITHUB, AuthProviderType.QQ):
|
||||
setting_name = f"{provider.value}_enabled"
|
||||
setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.OAUTH) & (Setting.name == setting_name),
|
||||
)
|
||||
if not setting or setting.value != "1":
|
||||
http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用")
|
||||
return
|
||||
|
||||
# 其他类型从 AUTH 设置中查询
|
||||
setting_name = f"auth_{provider.value}_enabled"
|
||||
setting = await Setting.get(
|
||||
session,
|
||||
User.email == login_request.email,
|
||||
fetch_mode="first",
|
||||
load=User.group,
|
||||
) #type: ignore
|
||||
(Setting.type == SettingsType.AUTH) & (Setting.name == setting_name),
|
||||
)
|
||||
if not setting or setting.value != "1":
|
||||
http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用")
|
||||
|
||||
# 验证用户是否存在
|
||||
if not current_user:
|
||||
logger.debug(f"Cannot find user with email: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid email or password")
|
||||
|
||||
# 验证密码是否正确
|
||||
if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID:
|
||||
logger.debug(f"Password verification failed for user: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid email or password")
|
||||
async def _login_email_password(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
) -> User:
|
||||
"""邮箱+密码登录"""
|
||||
if not request.credential:
|
||||
http_exceptions.raise_bad_request("密码不能为空")
|
||||
|
||||
# 验证用户是否可登录(修复:显式枚举比较,StrEnum 永远 truthy)
|
||||
if current_user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("Your account is disabled")
|
||||
# 查找 AuthIdentity
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||
& (AuthIdentity.identifier == request.identifier),
|
||||
)
|
||||
if not identity:
|
||||
l.debug(f"未找到邮箱密码身份: {request.identifier}")
|
||||
http_exceptions.raise_unauthorized("邮箱或密码错误")
|
||||
|
||||
# 检查两步验证
|
||||
if current_user.two_factor:
|
||||
# 用户已启用两步验证
|
||||
if not login_request.two_fa_code:
|
||||
logger.debug(f"2FA required for user: {login_request.email}")
|
||||
http_exceptions.raise_precondition_required("2FA required")
|
||||
# 验证密码
|
||||
if not identity.credential:
|
||||
http_exceptions.raise_unauthorized("邮箱或密码错误")
|
||||
|
||||
# 验证 OTP 码
|
||||
if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID:
|
||||
logger.debug(f"Invalid 2FA code for user: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid 2FA code")
|
||||
if Password.verify(identity.credential, request.credential) != PasswordStatus.VALID:
|
||||
l.debug(f"密码验证失败: {request.identifier}")
|
||||
http_exceptions.raise_unauthorized("邮箱或密码错误")
|
||||
|
||||
# 加载用户
|
||||
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
|
||||
# 验证用户状态
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
# 检查两步验证(从 AuthIdentity.extra_data 中读取 2FA secret)
|
||||
if identity.extra_data:
|
||||
import orjson
|
||||
extra: dict = orjson.loads(identity.extra_data)
|
||||
two_factor_secret: str | None = extra.get("two_factor")
|
||||
if two_factor_secret:
|
||||
if not request.two_fa_code:
|
||||
l.debug(f"需要两步验证: {request.identifier}")
|
||||
http_exceptions.raise_precondition_required("需要两步验证")
|
||||
if Password.verify_totp(two_factor_secret, request.two_fa_code) != PasswordStatus.VALID:
|
||||
l.debug(f"两步验证失败: {request.identifier}")
|
||||
http_exceptions.raise_unauthorized("两步验证码错误")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def _login_oauth(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
provider: AuthProviderType,
|
||||
) -> User:
|
||||
"""
|
||||
OAuth 登录(GitHub / QQ)
|
||||
|
||||
identifier 为 OAuth authorization code,后端换取 access_token 再获取用户信息。
|
||||
"""
|
||||
# 读取 OAuth 配置
|
||||
client_id_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_id"),
|
||||
)
|
||||
client_secret_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_secret"),
|
||||
)
|
||||
if not client_id_setting or not client_secret_setting:
|
||||
http_exceptions.raise_bad_request(f"{provider.value} OAuth 未配置")
|
||||
|
||||
client_id = client_id_setting.value or ""
|
||||
client_secret = client_secret_setting.value or ""
|
||||
|
||||
# 根据 provider 创建对应的 OAuth 客户端
|
||||
if provider == AuthProviderType.GITHUB:
|
||||
from service.oauth import GithubOAuth
|
||||
oauth_client = GithubOAuth(client_id, client_secret)
|
||||
token_resp = await oauth_client.get_access_token(code=request.identifier)
|
||||
user_info_resp = await oauth_client.get_user_info(token_resp)
|
||||
openid = str(user_info_resp.user_data.id)
|
||||
nickname = user_info_resp.user_data.name or user_info_resp.user_data.login
|
||||
avatar_url = user_info_resp.user_data.avatar_url
|
||||
email = user_info_resp.user_data.email
|
||||
elif provider == AuthProviderType.QQ:
|
||||
from service.oauth import QQOAuth
|
||||
oauth_client = QQOAuth(client_id, client_secret)
|
||||
token_resp = await oauth_client.get_access_token(
|
||||
code=request.identifier,
|
||||
redirect_uri=request.redirect_uri or "",
|
||||
)
|
||||
openid_resp = await oauth_client.get_openid(token_resp.access_token)
|
||||
user_info_resp = await oauth_client.get_user_info(
|
||||
token_resp,
|
||||
app_id=client_id,
|
||||
openid=openid_resp.openid,
|
||||
)
|
||||
openid = openid_resp.openid
|
||||
nickname = user_info_resp.user_data.nickname
|
||||
avatar_url = user_info_resp.user_data.figureurl_qq_2 or user_info_resp.user_data.figureurl_2
|
||||
email = None
|
||||
else:
|
||||
http_exceptions.raise_bad_request(f"不支持的 OAuth 提供者: {provider.value}")
|
||||
|
||||
# 查找已有 AuthIdentity
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == provider) & (AuthIdentity.identifier == openid),
|
||||
)
|
||||
|
||||
if identity:
|
||||
# 已绑定 → 更新 OAuth 信息并返回关联用户
|
||||
identity.display_name = nickname
|
||||
identity.avatar_url = avatar_url
|
||||
await identity.save(session)
|
||||
|
||||
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
return user
|
||||
|
||||
# 未绑定 → 自动注册
|
||||
user = await _auto_register_oauth_user(
|
||||
session,
|
||||
provider=provider,
|
||||
openid=openid,
|
||||
nickname=nickname,
|
||||
avatar_url=avatar_url,
|
||||
email=email,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
async def _auto_register_oauth_user(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
provider: AuthProviderType,
|
||||
openid: str,
|
||||
nickname: str | None,
|
||||
avatar_url: str | None,
|
||||
email: str | None,
|
||||
) -> User:
|
||||
"""OAuth 自动注册用户"""
|
||||
# 获取默认用户组
|
||||
default_group_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.REGISTER) & (Setting.name == "default_group"),
|
||||
)
|
||||
if not default_group_setting or not default_group_setting.value:
|
||||
l.error("默认用户组未配置")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
default_group_id = UUID(default_group_setting.value)
|
||||
|
||||
# 创建用户
|
||||
new_user = User(
|
||||
email=email,
|
||||
nickname=nickname,
|
||||
avatar=avatar_url or "default",
|
||||
group_id=default_group_id,
|
||||
)
|
||||
new_user_id = new_user.id
|
||||
new_user = await new_user.save(session)
|
||||
|
||||
# 创建 AuthIdentity
|
||||
identity = AuthIdentity(
|
||||
provider=provider,
|
||||
identifier=openid,
|
||||
display_name=nickname,
|
||||
avatar_url=avatar_url,
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=new_user_id,
|
||||
)
|
||||
await identity.save(session)
|
||||
|
||||
# 创建用户根目录
|
||||
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||
if default_policy:
|
||||
await Object(
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=new_user_id,
|
||||
parent_id=None,
|
||||
policy_id=default_policy.id,
|
||||
).save(session)
|
||||
|
||||
# 重新加载用户(含 group 关系)
|
||||
user: User = await User.get(session, User.id == new_user_id, load=User.group)
|
||||
l.info(f"OAuth 自动注册用户: provider={provider.value}, openid={openid}")
|
||||
return user
|
||||
|
||||
|
||||
async def _login_passkey(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
) -> User:
|
||||
"""
|
||||
Passkey/WebAuthn 登录
|
||||
|
||||
identifier 为 credential_id,credential 为 JSON 格式的 authenticator assertion response。
|
||||
"""
|
||||
from webauthn import verify_authentication_response
|
||||
from webauthn.helpers.structs import AuthenticationCredential
|
||||
|
||||
if not request.credential:
|
||||
http_exceptions.raise_bad_request("WebAuthn assertion response 不能为空")
|
||||
|
||||
# 查找 AuthIdentity
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == AuthProviderType.PASSKEY)
|
||||
& (AuthIdentity.identifier == request.identifier),
|
||||
)
|
||||
if not identity:
|
||||
http_exceptions.raise_unauthorized("Passkey 凭证未注册")
|
||||
|
||||
# 加载对应的 UserAuthn 记录
|
||||
from sqlmodels.user_authn import UserAuthn
|
||||
authn: UserAuthn | None = await UserAuthn.get(
|
||||
session,
|
||||
UserAuthn.credential_id == request.identifier,
|
||||
)
|
||||
if not authn:
|
||||
http_exceptions.raise_unauthorized("Passkey 凭证数据不存在")
|
||||
|
||||
# 获取 RP ID
|
||||
site_url_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"),
|
||||
)
|
||||
rp_id = site_url_setting.value if site_url_setting else "localhost"
|
||||
|
||||
# 验证 WebAuthn assertion
|
||||
import orjson
|
||||
credential = AuthenticationCredential.model_validate(orjson.loads(request.credential))
|
||||
|
||||
try:
|
||||
verification = verify_authentication_response(
|
||||
credential=credential,
|
||||
expected_rp_id=rp_id,
|
||||
expected_origin=f"https://{rp_id}",
|
||||
expected_challenge=b"", # TODO: 从 session/cache 中获取 challenge
|
||||
credential_public_key=bytes.fromhex(authn.credential_public_key),
|
||||
credential_current_sign_count=authn.sign_count,
|
||||
)
|
||||
except Exception as e:
|
||||
l.warning(f"WebAuthn 验证失败: {e}")
|
||||
http_exceptions.raise_unauthorized("Passkey 验证失败")
|
||||
|
||||
# 更新签名计数
|
||||
authn.sign_count = verification.new_sign_count
|
||||
await authn.save(session)
|
||||
|
||||
# 加载用户
|
||||
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def _login_magic_link(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
) -> User:
|
||||
"""
|
||||
Magic Link 登录
|
||||
|
||||
identifier 为签名 token,由 itsdangerous 生成。
|
||||
"""
|
||||
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||
|
||||
try:
|
||||
email = serializer.loads(request.identifier, salt="magic-link-salt", max_age=600)
|
||||
except SignatureExpired:
|
||||
http_exceptions.raise_unauthorized("Magic Link 已过期")
|
||||
except BadSignature:
|
||||
http_exceptions.raise_unauthorized("Magic Link 无效")
|
||||
|
||||
# 查找绑定了该邮箱的 AuthIdentity(email_password 或 magic_link)
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.identifier == email)
|
||||
& (
|
||||
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||
| (AuthIdentity.provider == AuthProviderType.MAGIC_LINK)
|
||||
),
|
||||
)
|
||||
if not identity:
|
||||
http_exceptions.raise_unauthorized("该邮箱未注册")
|
||||
|
||||
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
# 标记邮箱已验证
|
||||
if not identity.is_verified:
|
||||
identity.is_verified = True
|
||||
await identity.save(session)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def _issue_tokens(session: AsyncSession, user: User) -> TokenResponse:
|
||||
"""
|
||||
签发 JWT 双令牌(access + refresh)
|
||||
|
||||
提取自原 login.py 的签发逻辑,供所有 provider 共用。
|
||||
"""
|
||||
# 加载 GroupOptions
|
||||
group_options: GroupOptions | None = await GroupOptions.get(
|
||||
session,
|
||||
GroupOptions.group_id == current_user.group_id,
|
||||
GroupOptions.group_id == user.group_id,
|
||||
)
|
||||
|
||||
# 构建权限快照
|
||||
current_user.group.options = group_options
|
||||
group_claims = GroupClaims.from_group(current_user.group)
|
||||
user.group.options = group_options
|
||||
group_claims = GroupClaims.from_group(user.group)
|
||||
|
||||
# 创建令牌
|
||||
access_token = create_access_token(
|
||||
sub=current_user.id,
|
||||
access_token = JWT.create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=current_user.status.value,
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
)
|
||||
refresh_token = create_refresh_token(
|
||||
sub=current_user.id,
|
||||
jti=uuid4()
|
||||
refresh_token = JWT.create_refresh_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
|
||||
Reference in New Issue
Block a user