feat: 添加两步验证功能,优化用户登录逻辑,更新相关模型和依赖
This commit is contained in:
@@ -28,13 +28,13 @@ async def router_directory_get(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
path: str = ""
|
||||
) -> response.ResponseModel:
|
||||
) -> DirectoryResponse:
|
||||
"""
|
||||
获取目录内容
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param path: 目录路径,空或 "/" 表示根目录
|
||||
:param path: 目录路径, "~" 表示根目录
|
||||
:return: 目录内容
|
||||
"""
|
||||
folder = await Object.get_by_path(session, user.id, path or "/")
|
||||
@@ -44,6 +44,9 @@ async def router_directory_get(
|
||||
|
||||
if not folder.is_folder:
|
||||
raise HTTPException(status_code=400, detail="指定路径不是目录")
|
||||
|
||||
if path != "~":
|
||||
path = path.lstrip("~")
|
||||
|
||||
children = await Object.get_children(session, user.id, folder.id)
|
||||
policy = await folder.awaitable_attrs.policy
|
||||
@@ -55,7 +58,7 @@ async def router_directory_get(
|
||||
path=f"/{child.name}", # TODO: 完整路径
|
||||
thumb=False,
|
||||
size=child.size,
|
||||
type="folder" if child.is_folder else "file",
|
||||
type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE,
|
||||
date=child.updated_at,
|
||||
create_date=child.created_at,
|
||||
source_enabled=False,
|
||||
@@ -63,18 +66,17 @@ async def router_directory_get(
|
||||
for child in children
|
||||
]
|
||||
|
||||
return response.ResponseModel(
|
||||
data=DirectoryResponse(
|
||||
parent=str(folder.parent_id) if folder.parent_id else None,
|
||||
objects=objects,
|
||||
policy=PolicyResponse(
|
||||
id=str(policy.id),
|
||||
name=policy.name,
|
||||
type=policy.type.value,
|
||||
max_size=policy.max_size,
|
||||
file_type=[],
|
||||
),
|
||||
)
|
||||
policy=PolicyResponse(
|
||||
id=str(policy.id),
|
||||
name=policy.name,
|
||||
type=policy.type.value,
|
||||
max_size=policy.max_size,
|
||||
)
|
||||
|
||||
return DirectoryResponse(
|
||||
parent=str(folder.parent_id) if folder.parent_id else None,
|
||||
objects=objects,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy import and_
|
||||
from webauthn import generate_registration_options
|
||||
from webauthn.helpers import options_to_json_dict
|
||||
import pyotp
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
|
||||
import models
|
||||
import service
|
||||
from middleware.auth import AuthRequired
|
||||
from middleware.dependencies import SessionDep
|
||||
from pkg.JWT.JWT import SECRET_KEY
|
||||
|
||||
user_router = APIRouter(
|
||||
prefix="/user",
|
||||
@@ -25,18 +28,46 @@ user_settings_router = APIRouter(
|
||||
@user_router.post(
|
||||
path='/session',
|
||||
summary='用户登录',
|
||||
description='User login endpoint.',
|
||||
description='User login endpoint. 当用户启用两步验证时,需要传入 otp 参数。',
|
||||
)
|
||||
async def router_user_session(
|
||||
session: SessionDep,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> models.TokenResponse:
|
||||
"""
|
||||
用户登录端点。
|
||||
|
||||
根据 OAuth2.1 规范,使用 password grant type 进行登录。
|
||||
当用户启用两步验证时,需要在表单中传入 otp 参数(通过 scopes 字段传递)。
|
||||
|
||||
OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码
|
||||
|
||||
:raises HTTPException 401: 用户名或密码错误
|
||||
:raises HTTPException 403: 用户账号被封禁或未完成注册
|
||||
:raises HTTPException 428: 需要两步验证但未提供验证码
|
||||
:raises HTTPException 400: 两步验证码无效
|
||||
"""
|
||||
username = form_data.username
|
||||
password = form_data.password
|
||||
|
||||
# 从 scopes 中提取 OTP 验证码(OAuth2.1 扩展方式)
|
||||
# scopes 格式可以是 ["otp:123456"] 或 ["123456"]
|
||||
otp_code: str | None = None
|
||||
for scope in form_data.scopes:
|
||||
if scope.startswith("otp:"):
|
||||
otp_code = scope[4:]
|
||||
break
|
||||
elif scope.isdigit() and len(scope) == 6:
|
||||
otp_code = scope
|
||||
break
|
||||
|
||||
result = await service.user.Login(
|
||||
session,
|
||||
models.LoginRequest(username=username, password=password),
|
||||
models.LoginRequest(
|
||||
username=username,
|
||||
password=password,
|
||||
two_fa_code=otp_code,
|
||||
),
|
||||
)
|
||||
|
||||
if isinstance(result, models.TokenResponse):
|
||||
@@ -45,6 +76,14 @@ async def router_user_session(
|
||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||
elif result is False:
|
||||
raise HTTPException(status_code=403, detail="User account is banned or not fully registered")
|
||||
elif result == "2fa_required":
|
||||
raise HTTPException(
|
||||
status_code=428,
|
||||
detail="Two-factor authentication required",
|
||||
headers={"X-2FA-Required": "true"},
|
||||
)
|
||||
elif result == "2fa_invalid":
|
||||
raise HTTPException(status_code=400, detail="Invalid two-factor authentication code")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Internal server error during login")
|
||||
|
||||
@@ -62,26 +101,14 @@ def router_user_register() -> models.response.ResponseModel:
|
||||
"""
|
||||
pass
|
||||
|
||||
@user_router.post(
|
||||
path='/2fa',
|
||||
summary='用两步验证登录',
|
||||
description='Two-factor authentication login endpoint.',
|
||||
)
|
||||
def router_user_2fa() -> models.response.ResponseModel:
|
||||
"""
|
||||
Two-factor authentication login endpoint.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing two-factor authentication information.
|
||||
"""
|
||||
pass
|
||||
|
||||
@user_router.post(
|
||||
path='/code',
|
||||
summary='发送验证码邮件',
|
||||
description='Send a verification code email.',
|
||||
)
|
||||
def router_user_email_code() -> models.response.ResponseModel:
|
||||
def router_user_email_code(
|
||||
reason: Literal['register', 'reset'] = 'register',
|
||||
) -> models.response.ResponseModel:
|
||||
"""
|
||||
Send a verification code email.
|
||||
|
||||
@@ -90,21 +117,6 @@ def router_user_email_code() -> models.response.ResponseModel:
|
||||
"""
|
||||
pass
|
||||
|
||||
@user_router.patch(
|
||||
path='/reset',
|
||||
summary='通过邮件里的链接重设密码',
|
||||
description='Reset password via email link.',
|
||||
deprecated=True,
|
||||
)
|
||||
def router_user_reset_patch() -> models.response.ResponseModel:
|
||||
"""
|
||||
Reset password via email link.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing information about the password reset.
|
||||
"""
|
||||
pass
|
||||
|
||||
@user_router.get(
|
||||
path='/qq',
|
||||
summary='初始化QQ登录',
|
||||
@@ -193,7 +205,7 @@ def router_user_avatar(id: str, size: int = 128) -> models.response.ResponseMode
|
||||
)
|
||||
async def router_user_me(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(AuthRequired)],
|
||||
user: Annotated[models.User, Depends(AuthRequired)],
|
||||
) -> models.response.ResponseModel:
|
||||
"""
|
||||
获取用户信息.
|
||||
@@ -201,25 +213,32 @@ async def router_user_me(
|
||||
:return: response.ResponseModel containing user information.
|
||||
:rtype: response.ResponseModel
|
||||
"""
|
||||
group = await models.Group.get(session, models.Group.id == user.group_id)
|
||||
|
||||
user_group = models.GroupResponse(
|
||||
id=group.id,
|
||||
name=group.name,
|
||||
allow_share=group.share_enabled,
|
||||
# 加载 group 及其 options 关系
|
||||
group = await models.Group.get(
|
||||
session,
|
||||
models.Group.id == user.group_id,
|
||||
load=models.Group.options
|
||||
)
|
||||
|
||||
users = models.UserResponse(
|
||||
# 构建 GroupResponse
|
||||
group_response = group.to_response() if group else None
|
||||
|
||||
# 异步加载 tags 关系
|
||||
user_tags = await user.awaitable_attrs.tags
|
||||
|
||||
user_response = models.UserResponse(
|
||||
id=user.id,
|
||||
username=user.username,
|
||||
nickname=user.nick,
|
||||
status=user.status,
|
||||
created_at=user.created_at,
|
||||
score=user.score,
|
||||
group=user_group,
|
||||
).model_dump()
|
||||
nickname=user.nickname,
|
||||
avatar=user.avatar,
|
||||
created_at=user.created_at,
|
||||
group=group_response,
|
||||
tags=[tag.name for tag in user_tags] if user_tags else [],
|
||||
)
|
||||
|
||||
return models.response.ResponseModel(data=users)
|
||||
return models.response.ResponseModel(data=user_response.model_dump())
|
||||
|
||||
@user_router.get(
|
||||
path='/storage',
|
||||
@@ -425,11 +444,77 @@ def router_user_settings_patch(option: str) -> models.response.ResponseModel:
|
||||
description='Get two-factor authentication initialization information.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
)
|
||||
def router_user_settings_2fa() -> models.response.ResponseModel:
|
||||
async def router_user_settings_2fa(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(AuthRequired)],
|
||||
) -> models.response.ResponseModel:
|
||||
"""
|
||||
Get two-factor authentication initialization information.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing two-factor authentication setup information.
|
||||
"""
|
||||
pass
|
||||
|
||||
serializer = URLSafeTimedSerializer(SECRET_KEY)
|
||||
|
||||
secret = pyotp.random_base32()
|
||||
|
||||
setup_token = serializer.dumps(
|
||||
secret,
|
||||
salt="2fa-setup-salt"
|
||||
)
|
||||
|
||||
site_Name = await models.Setting.get(session, (models.Setting.type == models.SettingsType.BASIC) & (models.Setting.name == "siteName"))
|
||||
|
||||
otp_uri = pyotp.totp.TOTP(secret).provisioning_uri(
|
||||
name=user.username,
|
||||
issuer_name=site_Name.value
|
||||
)
|
||||
|
||||
return models.response.ResponseModel(
|
||||
data={
|
||||
"setup_token": setup_token,
|
||||
"otp_uri": otp_uri,
|
||||
}
|
||||
)
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/2fa',
|
||||
summary='启用两步验证',
|
||||
description='Enable two-factor authentication.',
|
||||
dependencies=[Depends(AuthRequired)],
|
||||
)
|
||||
async def router_user_settings_2fa_enable(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(AuthRequired)],
|
||||
setup_token: str,
|
||||
code: str,
|
||||
) -> models.response.ResponseModel:
|
||||
"""
|
||||
Enable two-factor authentication for the user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of enabling two-factor authentication.
|
||||
"""
|
||||
|
||||
serializer = URLSafeTimedSerializer(SECRET_KEY)
|
||||
|
||||
try:
|
||||
# 1. 解包 Token,设置有效期(例如 600秒)
|
||||
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
|
||||
except SignatureExpired:
|
||||
raise HTTPException(status_code=400, detail="Setup session expired")
|
||||
except BadSignature:
|
||||
raise HTTPException(status_code=400, detail="Invalid token")
|
||||
|
||||
# 2. 验证用户输入的 6 位验证码
|
||||
if not service.user.verify_totp(secret, code):
|
||||
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
||||
|
||||
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
|
||||
user.two_factor = secret
|
||||
user = await user.save(session)
|
||||
|
||||
return models.response.ResponseModel(
|
||||
data={"message": "Two-factor authentication enabled successfully"}
|
||||
)
|
||||
Reference in New Issue
Block a user