feat: 添加两步验证功能,优化用户登录逻辑,更新相关模型和依赖
This commit is contained in:
@@ -25,7 +25,7 @@ class GroupOptionsBase(SQLModelBase):
|
|||||||
"""是否允许分享下载"""
|
"""是否允许分享下载"""
|
||||||
|
|
||||||
share_free: bool = False
|
share_free: bool = False
|
||||||
"""是否免积分分享"""
|
"""是否免积分获取需要积分的内容"""
|
||||||
|
|
||||||
relocate: bool = False
|
relocate: bool = False
|
||||||
"""是否允许文件重定位"""
|
"""是否允许文件重定位"""
|
||||||
@@ -136,3 +136,22 @@ class Group(GroupBase, TableBase, table=True):
|
|||||||
back_populates="previous_group",
|
back_populates="previous_group",
|
||||||
sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"}
|
sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_response(self) -> "GroupResponse":
|
||||||
|
"""转换为响应 DTO"""
|
||||||
|
opts = self.options
|
||||||
|
return GroupResponse(
|
||||||
|
id=self.id,
|
||||||
|
name=self.name,
|
||||||
|
allow_share=self.share_enabled,
|
||||||
|
webdav=self.web_dav_enabled,
|
||||||
|
share_download=opts.share_download if opts else False,
|
||||||
|
share_free=opts.share_free if opts else False,
|
||||||
|
relocate=opts.relocate if opts else False,
|
||||||
|
source_batch=opts.source_batch if opts else 0,
|
||||||
|
select_node=opts.select_node if opts else False,
|
||||||
|
advance_delete=opts.advance_delete if opts else False,
|
||||||
|
allow_remote_download=opts.aria2 if opts else False,
|
||||||
|
allow_archive_download=opts.archive_download if opts else False,
|
||||||
|
allow_webdav_proxy=opts.webdav_proxy if opts else False,
|
||||||
|
)
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ async def migration() -> None:
|
|||||||
log.info('开始进行数据库初始化...')
|
log.info('开始进行数据库初始化...')
|
||||||
|
|
||||||
await init_default_settings()
|
await init_default_settings()
|
||||||
await init_default_group()
|
|
||||||
await init_default_policy()
|
await init_default_policy()
|
||||||
|
await init_default_group()
|
||||||
await init_default_user()
|
await init_default_user()
|
||||||
|
|
||||||
log.info('数据库初始化结束')
|
log.info('数据库初始化结束')
|
||||||
@@ -147,6 +147,7 @@ async def init_default_group() -> None:
|
|||||||
if not await Group.get(session, Group.id == 1):
|
if not await Group.get(session, Group.id == 1):
|
||||||
admin_group = await Group(
|
admin_group = await Group(
|
||||||
name="管理员",
|
name="管理员",
|
||||||
|
policies="1",
|
||||||
max_storage=1 * 1024 * 1024 * 1024, # 1GB
|
max_storage=1 * 1024 * 1024 * 1024, # 1GB
|
||||||
share_enabled=True,
|
share_enabled=True,
|
||||||
web_dav_enabled=True,
|
web_dav_enabled=True,
|
||||||
@@ -158,7 +159,10 @@ async def init_default_group() -> None:
|
|||||||
archive_download=True,
|
archive_download=True,
|
||||||
archive_task=True,
|
archive_task=True,
|
||||||
share_download=True,
|
share_download=True,
|
||||||
|
share_free=True,
|
||||||
aria2=True,
|
aria2=True,
|
||||||
|
select_node=True,
|
||||||
|
advance_delete=True,
|
||||||
).save(session)
|
).save(session)
|
||||||
|
|
||||||
# 未找到初始注册会员时,则创建
|
# 未找到初始注册会员时,则创建
|
||||||
|
|||||||
@@ -84,8 +84,8 @@ class PolicyResponse(SQLModelBase):
|
|||||||
max_size: int = 0
|
max_size: int = 0
|
||||||
"""单文件最大限制,单位字节,0表示不限制"""
|
"""单文件最大限制,单位字节,0表示不限制"""
|
||||||
|
|
||||||
file_type: list[str] = []
|
file_type: list[str] | None = None
|
||||||
"""允许的文件类型列表,空列表表示不限制"""
|
"""允许的文件类型列表,None 表示不限制"""
|
||||||
|
|
||||||
|
|
||||||
class DirectoryResponse(SQLModelBase):
|
class DirectoryResponse(SQLModelBase):
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
from typing import Literal, Optional, TYPE_CHECKING
|
from typing import Literal, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
from sqlmodel import Field, Relationship
|
from sqlmodel import Field, Relationship
|
||||||
@@ -26,6 +27,13 @@ Option 需求
|
|||||||
- 切换到不同存储策略是否提醒
|
- 切换到不同存储策略是否提醒
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class AvatarType(StrEnum):
|
||||||
|
"""头像类型枚举"""
|
||||||
|
|
||||||
|
DEFAULT = "default"
|
||||||
|
GRAVATAR = "gravatar"
|
||||||
|
FILE = "file"
|
||||||
|
|
||||||
|
|
||||||
# ==================== Base 模型 ====================
|
# ==================== Base 模型 ====================
|
||||||
|
|
||||||
@@ -227,7 +235,7 @@ class User(UserBase, TableBase, table=True):
|
|||||||
username: str = Field(max_length=50, unique=True, index=True)
|
username: str = Field(max_length=50, unique=True, index=True)
|
||||||
"""用户名,唯一,一经注册不可更改"""
|
"""用户名,唯一,一经注册不可更改"""
|
||||||
|
|
||||||
nick: str | None = Field(default=None, max_length=50)
|
nickname: str | None = Field(default=None, max_length=50)
|
||||||
"""用于公开展示的名字,可使用真实姓名或昵称"""
|
"""用于公开展示的名字,可使用真实姓名或昵称"""
|
||||||
|
|
||||||
password: str = Field(max_length=255)
|
password: str = Field(max_length=255)
|
||||||
@@ -242,7 +250,7 @@ class User(UserBase, TableBase, table=True):
|
|||||||
two_factor: str | None = Field(default=None, min_length=32, max_length=32)
|
two_factor: str | None = Field(default=None, min_length=32, max_length=32)
|
||||||
"""两步验证密钥"""
|
"""两步验证密钥"""
|
||||||
|
|
||||||
avatar: str | None = Field(default=None, max_length=255)
|
avatar: str = Field(default="default", max_length=255)
|
||||||
"""头像地址"""
|
"""头像地址"""
|
||||||
|
|
||||||
options: str | None = None
|
options: str | None = None
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
|
|||||||
# 延迟导入以避免循环依赖
|
|
||||||
# JWT 和 lifespan 应在需要时直接从子模块导入
|
|
||||||
# from .JWT import JWT
|
|
||||||
# from .lifespan import lifespan
|
|
||||||
@@ -9,8 +9,10 @@ dependencies = [
|
|||||||
"aiosqlite>=0.21.0",
|
"aiosqlite>=0.21.0",
|
||||||
"argon2-cffi>=25.1.0",
|
"argon2-cffi>=25.1.0",
|
||||||
"fastapi[standard]>=0.122.0",
|
"fastapi[standard]>=0.122.0",
|
||||||
|
"itsdangerous>=2.2.0",
|
||||||
"loguru>=0.7.3",
|
"loguru>=0.7.3",
|
||||||
"pyjwt>=2.10.1",
|
"pyjwt>=2.10.1",
|
||||||
|
"pyotp>=2.9.0",
|
||||||
"python-dotenv>=1.2.1",
|
"python-dotenv>=1.2.1",
|
||||||
"python-multipart>=0.0.20",
|
"python-multipart>=0.0.20",
|
||||||
"sqlalchemy>=2.0.44",
|
"sqlalchemy>=2.0.44",
|
||||||
|
|||||||
@@ -28,13 +28,13 @@ async def router_directory_get(
|
|||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[User, Depends(AuthRequired)],
|
user: Annotated[User, Depends(AuthRequired)],
|
||||||
path: str = ""
|
path: str = ""
|
||||||
) -> response.ResponseModel:
|
) -> DirectoryResponse:
|
||||||
"""
|
"""
|
||||||
获取目录内容
|
获取目录内容
|
||||||
|
|
||||||
:param session: 数据库会话
|
:param session: 数据库会话
|
||||||
:param user: 当前登录用户
|
:param user: 当前登录用户
|
||||||
:param path: 目录路径,空或 "/" 表示根目录
|
:param path: 目录路径, "~" 表示根目录
|
||||||
:return: 目录内容
|
:return: 目录内容
|
||||||
"""
|
"""
|
||||||
folder = await Object.get_by_path(session, user.id, path or "/")
|
folder = await Object.get_by_path(session, user.id, path or "/")
|
||||||
@@ -45,6 +45,9 @@ async def router_directory_get(
|
|||||||
if not folder.is_folder:
|
if not folder.is_folder:
|
||||||
raise HTTPException(status_code=400, detail="指定路径不是目录")
|
raise HTTPException(status_code=400, detail="指定路径不是目录")
|
||||||
|
|
||||||
|
if path != "~":
|
||||||
|
path = path.lstrip("~")
|
||||||
|
|
||||||
children = await Object.get_children(session, user.id, folder.id)
|
children = await Object.get_children(session, user.id, folder.id)
|
||||||
policy = await folder.awaitable_attrs.policy
|
policy = await folder.awaitable_attrs.policy
|
||||||
|
|
||||||
@@ -55,7 +58,7 @@ async def router_directory_get(
|
|||||||
path=f"/{child.name}", # TODO: 完整路径
|
path=f"/{child.name}", # TODO: 完整路径
|
||||||
thumb=False,
|
thumb=False,
|
||||||
size=child.size,
|
size=child.size,
|
||||||
type="folder" if child.is_folder else "file",
|
type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE,
|
||||||
date=child.updated_at,
|
date=child.updated_at,
|
||||||
create_date=child.created_at,
|
create_date=child.created_at,
|
||||||
source_enabled=False,
|
source_enabled=False,
|
||||||
@@ -63,18 +66,17 @@ async def router_directory_get(
|
|||||||
for child in children
|
for child in children
|
||||||
]
|
]
|
||||||
|
|
||||||
return response.ResponseModel(
|
policy=PolicyResponse(
|
||||||
data=DirectoryResponse(
|
id=str(policy.id),
|
||||||
parent=str(folder.parent_id) if folder.parent_id else None,
|
name=policy.name,
|
||||||
objects=objects,
|
type=policy.type.value,
|
||||||
policy=PolicyResponse(
|
max_size=policy.max_size,
|
||||||
id=str(policy.id),
|
)
|
||||||
name=policy.name,
|
|
||||||
type=policy.type.value,
|
return DirectoryResponse(
|
||||||
max_size=policy.max_size,
|
parent=str(folder.parent_id) if folder.parent_id else None,
|
||||||
file_type=[],
|
objects=objects,
|
||||||
),
|
policy=policy,
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,18 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_
|
||||||
from webauthn import generate_registration_options
|
from webauthn import generate_registration_options
|
||||||
from webauthn.helpers import options_to_json_dict
|
from webauthn.helpers import options_to_json_dict
|
||||||
|
import pyotp
|
||||||
|
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||||
|
|
||||||
import models
|
import models
|
||||||
import service
|
import service
|
||||||
from middleware.auth import AuthRequired
|
from middleware.auth import AuthRequired
|
||||||
from middleware.dependencies import SessionDep
|
from middleware.dependencies import SessionDep
|
||||||
|
from pkg.JWT.JWT import SECRET_KEY
|
||||||
|
|
||||||
user_router = APIRouter(
|
user_router = APIRouter(
|
||||||
prefix="/user",
|
prefix="/user",
|
||||||
@@ -25,18 +28,46 @@ user_settings_router = APIRouter(
|
|||||||
@user_router.post(
|
@user_router.post(
|
||||||
path='/session',
|
path='/session',
|
||||||
summary='用户登录',
|
summary='用户登录',
|
||||||
description='User login endpoint.',
|
description='User login endpoint. 当用户启用两步验证时,需要传入 otp 参数。',
|
||||||
)
|
)
|
||||||
async def router_user_session(
|
async def router_user_session(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
) -> models.TokenResponse:
|
) -> models.TokenResponse:
|
||||||
|
"""
|
||||||
|
用户登录端点。
|
||||||
|
|
||||||
|
根据 OAuth2.1 规范,使用 password grant type 进行登录。
|
||||||
|
当用户启用两步验证时,需要在表单中传入 otp 参数(通过 scopes 字段传递)。
|
||||||
|
|
||||||
|
OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码
|
||||||
|
|
||||||
|
:raises HTTPException 401: 用户名或密码错误
|
||||||
|
:raises HTTPException 403: 用户账号被封禁或未完成注册
|
||||||
|
:raises HTTPException 428: 需要两步验证但未提供验证码
|
||||||
|
:raises HTTPException 400: 两步验证码无效
|
||||||
|
"""
|
||||||
username = form_data.username
|
username = form_data.username
|
||||||
password = form_data.password
|
password = form_data.password
|
||||||
|
|
||||||
|
# 从 scopes 中提取 OTP 验证码(OAuth2.1 扩展方式)
|
||||||
|
# scopes 格式可以是 ["otp:123456"] 或 ["123456"]
|
||||||
|
otp_code: str | None = None
|
||||||
|
for scope in form_data.scopes:
|
||||||
|
if scope.startswith("otp:"):
|
||||||
|
otp_code = scope[4:]
|
||||||
|
break
|
||||||
|
elif scope.isdigit() and len(scope) == 6:
|
||||||
|
otp_code = scope
|
||||||
|
break
|
||||||
|
|
||||||
result = await service.user.Login(
|
result = await service.user.Login(
|
||||||
session,
|
session,
|
||||||
models.LoginRequest(username=username, password=password),
|
models.LoginRequest(
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
two_fa_code=otp_code,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(result, models.TokenResponse):
|
if isinstance(result, models.TokenResponse):
|
||||||
@@ -45,6 +76,14 @@ async def router_user_session(
|
|||||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||||
elif result is False:
|
elif result is False:
|
||||||
raise HTTPException(status_code=403, detail="User account is banned or not fully registered")
|
raise HTTPException(status_code=403, detail="User account is banned or not fully registered")
|
||||||
|
elif result == "2fa_required":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=428,
|
||||||
|
detail="Two-factor authentication required",
|
||||||
|
headers={"X-2FA-Required": "true"},
|
||||||
|
)
|
||||||
|
elif result == "2fa_invalid":
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid two-factor authentication code")
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="Internal server error during login")
|
raise HTTPException(status_code=500, detail="Internal server error during login")
|
||||||
|
|
||||||
@@ -62,26 +101,14 @@ def router_user_register() -> models.response.ResponseModel:
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@user_router.post(
|
|
||||||
path='/2fa',
|
|
||||||
summary='用两步验证登录',
|
|
||||||
description='Two-factor authentication login endpoint.',
|
|
||||||
)
|
|
||||||
def router_user_2fa() -> models.response.ResponseModel:
|
|
||||||
"""
|
|
||||||
Two-factor authentication login endpoint.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: A dictionary containing two-factor authentication information.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@user_router.post(
|
@user_router.post(
|
||||||
path='/code',
|
path='/code',
|
||||||
summary='发送验证码邮件',
|
summary='发送验证码邮件',
|
||||||
description='Send a verification code email.',
|
description='Send a verification code email.',
|
||||||
)
|
)
|
||||||
def router_user_email_code() -> models.response.ResponseModel:
|
def router_user_email_code(
|
||||||
|
reason: Literal['register', 'reset'] = 'register',
|
||||||
|
) -> models.response.ResponseModel:
|
||||||
"""
|
"""
|
||||||
Send a verification code email.
|
Send a verification code email.
|
||||||
|
|
||||||
@@ -90,21 +117,6 @@ def router_user_email_code() -> models.response.ResponseModel:
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@user_router.patch(
|
|
||||||
path='/reset',
|
|
||||||
summary='通过邮件里的链接重设密码',
|
|
||||||
description='Reset password via email link.',
|
|
||||||
deprecated=True,
|
|
||||||
)
|
|
||||||
def router_user_reset_patch() -> models.response.ResponseModel:
|
|
||||||
"""
|
|
||||||
Reset password via email link.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: A dictionary containing information about the password reset.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@user_router.get(
|
@user_router.get(
|
||||||
path='/qq',
|
path='/qq',
|
||||||
summary='初始化QQ登录',
|
summary='初始化QQ登录',
|
||||||
@@ -193,7 +205,7 @@ def router_user_avatar(id: str, size: int = 128) -> models.response.ResponseMode
|
|||||||
)
|
)
|
||||||
async def router_user_me(
|
async def router_user_me(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[models.user.User, Depends(AuthRequired)],
|
user: Annotated[models.User, Depends(AuthRequired)],
|
||||||
) -> models.response.ResponseModel:
|
) -> models.response.ResponseModel:
|
||||||
"""
|
"""
|
||||||
获取用户信息.
|
获取用户信息.
|
||||||
@@ -201,25 +213,32 @@ async def router_user_me(
|
|||||||
:return: response.ResponseModel containing user information.
|
:return: response.ResponseModel containing user information.
|
||||||
:rtype: response.ResponseModel
|
:rtype: response.ResponseModel
|
||||||
"""
|
"""
|
||||||
group = await models.Group.get(session, models.Group.id == user.group_id)
|
# 加载 group 及其 options 关系
|
||||||
|
group = await models.Group.get(
|
||||||
user_group = models.GroupResponse(
|
session,
|
||||||
id=group.id,
|
models.Group.id == user.group_id,
|
||||||
name=group.name,
|
load=models.Group.options
|
||||||
allow_share=group.share_enabled,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
users = models.UserResponse(
|
# 构建 GroupResponse
|
||||||
|
group_response = group.to_response() if group else None
|
||||||
|
|
||||||
|
# 异步加载 tags 关系
|
||||||
|
user_tags = await user.awaitable_attrs.tags
|
||||||
|
|
||||||
|
user_response = models.UserResponse(
|
||||||
id=user.id,
|
id=user.id,
|
||||||
username=user.username,
|
username=user.username,
|
||||||
nickname=user.nick,
|
|
||||||
status=user.status,
|
status=user.status,
|
||||||
created_at=user.created_at,
|
|
||||||
score=user.score,
|
score=user.score,
|
||||||
group=user_group,
|
nickname=user.nickname,
|
||||||
).model_dump()
|
avatar=user.avatar,
|
||||||
|
created_at=user.created_at,
|
||||||
|
group=group_response,
|
||||||
|
tags=[tag.name for tag in user_tags] if user_tags else [],
|
||||||
|
)
|
||||||
|
|
||||||
return models.response.ResponseModel(data=users)
|
return models.response.ResponseModel(data=user_response.model_dump())
|
||||||
|
|
||||||
@user_router.get(
|
@user_router.get(
|
||||||
path='/storage',
|
path='/storage',
|
||||||
@@ -425,11 +444,77 @@ def router_user_settings_patch(option: str) -> models.response.ResponseModel:
|
|||||||
description='Get two-factor authentication initialization information.',
|
description='Get two-factor authentication initialization information.',
|
||||||
dependencies=[Depends(AuthRequired)],
|
dependencies=[Depends(AuthRequired)],
|
||||||
)
|
)
|
||||||
def router_user_settings_2fa() -> models.response.ResponseModel:
|
async def router_user_settings_2fa(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[models.user.User, Depends(AuthRequired)],
|
||||||
|
) -> models.response.ResponseModel:
|
||||||
"""
|
"""
|
||||||
Get two-factor authentication initialization information.
|
Get two-factor authentication initialization information.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary containing two-factor authentication setup information.
|
dict: A dictionary containing two-factor authentication setup information.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
serializer = URLSafeTimedSerializer(SECRET_KEY)
|
||||||
|
|
||||||
|
secret = pyotp.random_base32()
|
||||||
|
|
||||||
|
setup_token = serializer.dumps(
|
||||||
|
secret,
|
||||||
|
salt="2fa-setup-salt"
|
||||||
|
)
|
||||||
|
|
||||||
|
site_Name = await models.Setting.get(session, (models.Setting.type == models.SettingsType.BASIC) & (models.Setting.name == "siteName"))
|
||||||
|
|
||||||
|
otp_uri = pyotp.totp.TOTP(secret).provisioning_uri(
|
||||||
|
name=user.username,
|
||||||
|
issuer_name=site_Name.value
|
||||||
|
)
|
||||||
|
|
||||||
|
return models.response.ResponseModel(
|
||||||
|
data={
|
||||||
|
"setup_token": setup_token,
|
||||||
|
"otp_uri": otp_uri,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@user_settings_router.post(
|
||||||
|
path='/2fa',
|
||||||
|
summary='启用两步验证',
|
||||||
|
description='Enable two-factor authentication.',
|
||||||
|
dependencies=[Depends(AuthRequired)],
|
||||||
|
)
|
||||||
|
async def router_user_settings_2fa_enable(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[models.user.User, Depends(AuthRequired)],
|
||||||
|
setup_token: str,
|
||||||
|
code: str,
|
||||||
|
) -> models.response.ResponseModel:
|
||||||
|
"""
|
||||||
|
Enable two-factor authentication for the user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary containing the result of enabling two-factor authentication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
serializer = URLSafeTimedSerializer(SECRET_KEY)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 解包 Token,设置有效期(例如 600秒)
|
||||||
|
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
|
||||||
|
except SignatureExpired:
|
||||||
|
raise HTTPException(status_code=400, detail="Setup session expired")
|
||||||
|
except BadSignature:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid token")
|
||||||
|
|
||||||
|
# 2. 验证用户输入的 6 位验证码
|
||||||
|
if not service.user.verify_totp(secret, code):
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
||||||
|
|
||||||
|
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
|
||||||
|
user.two_factor = secret
|
||||||
|
user = await user.save(session)
|
||||||
|
|
||||||
|
return models.response.ResponseModel(
|
||||||
|
data={"message": "Two-factor authentication enabled successfully"}
|
||||||
|
)
|
||||||
@@ -2,4 +2,4 @@
|
|||||||
服务层
|
服务层
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .user import login
|
from . import user
|
||||||
@@ -1 +1,2 @@
|
|||||||
from .login import Login
|
from .login import Login
|
||||||
|
from .totp import verify_totp
|
||||||
@@ -1,17 +1,25 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
from loguru import logger as log
|
from loguru import logger as log
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from models import LoginRequest, TokenResponse, User
|
from models import LoginRequest, TokenResponse, User
|
||||||
from pkg.JWT.JWT import create_access_token, create_refresh_token
|
from pkg.JWT.JWT import create_access_token, create_refresh_token
|
||||||
|
from .totp import verify_totp
|
||||||
|
|
||||||
|
|
||||||
async def Login(session: AsyncSession, login_request: LoginRequest) -> TokenResponse | bool | None:
|
async def Login(
|
||||||
|
session: AsyncSession,
|
||||||
|
login_request: LoginRequest,
|
||||||
|
) -> TokenResponse | bool | Literal["2fa_required", "2fa_invalid"] | None:
|
||||||
"""
|
"""
|
||||||
根据账号密码进行登录。
|
根据账号密码进行登录。
|
||||||
|
|
||||||
如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。
|
如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。
|
||||||
如果登录异常,返回 `False`(未完成注册或账号被封禁)。
|
如果登录异常,返回 `False`(未完成注册或账号被封禁)。
|
||||||
如果登录失败,返回 `None`。
|
如果登录失败,返回 `None`。
|
||||||
|
如果需要两步验证但未提供验证码,返回 `"2fa_required"`。
|
||||||
|
如果两步验证码无效,返回 `"2fa_invalid"`。
|
||||||
|
|
||||||
:param session: 数据库会话
|
:param session: 数据库会话
|
||||||
:param login_request: 登录请求
|
:param login_request: 登录请求
|
||||||
@@ -45,6 +53,18 @@ async def Login(session: AsyncSession, login_request: LoginRequest) -> TokenResp
|
|||||||
# 未完成注册 or 账号已被封禁
|
# 未完成注册 or 账号已被封禁
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# 检查两步验证
|
||||||
|
if current_user.two_factor:
|
||||||
|
# 用户已启用两步验证
|
||||||
|
if not login_request.two_fa_code:
|
||||||
|
log.debug(f"2FA required for user: {login_request.username}")
|
||||||
|
return "2fa_required"
|
||||||
|
|
||||||
|
# 验证 OTP 码
|
||||||
|
if not verify_totp(current_user.two_factor, login_request.two_fa_code):
|
||||||
|
log.debug(f"Invalid 2FA code for user: {login_request.username}")
|
||||||
|
return "2fa_invalid"
|
||||||
|
|
||||||
# 创建令牌
|
# 创建令牌
|
||||||
access_token, access_expire = create_access_token(data={'sub': current_user.username})
|
access_token, access_expire = create_access_token(data={'sub': current_user.username})
|
||||||
refresh_token, refresh_expire = create_refresh_token(data={'sub': current_user.username})
|
refresh_token, refresh_expire = create_refresh_token(data={'sub': current_user.username})
|
||||||
|
|||||||
13
service/user/totp.py
Normal file
13
service/user/totp.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import pyotp
|
||||||
|
|
||||||
|
|
||||||
|
def verify_totp(secret: str, code: str) -> bool:
|
||||||
|
"""
|
||||||
|
验证 TOTP 验证码。
|
||||||
|
|
||||||
|
:param secret: TOTP 密钥(Base32 编码)
|
||||||
|
:param code: 用户输入的 6 位验证码
|
||||||
|
:return: 验证是否成功
|
||||||
|
"""
|
||||||
|
totp = pyotp.TOTP(secret)
|
||||||
|
return totp.verify(code)
|
||||||
22
uv.lock
generated
22
uv.lock
generated
@@ -364,8 +364,10 @@ dependencies = [
|
|||||||
{ name = "aiosqlite" },
|
{ name = "aiosqlite" },
|
||||||
{ name = "argon2-cffi" },
|
{ name = "argon2-cffi" },
|
||||||
{ name = "fastapi", extra = ["standard"] },
|
{ name = "fastapi", extra = ["standard"] },
|
||||||
|
{ name = "itsdangerous" },
|
||||||
{ name = "loguru" },
|
{ name = "loguru" },
|
||||||
{ name = "pyjwt" },
|
{ name = "pyjwt" },
|
||||||
|
{ name = "pyotp" },
|
||||||
{ name = "python-dotenv" },
|
{ name = "python-dotenv" },
|
||||||
{ name = "python-multipart" },
|
{ name = "python-multipart" },
|
||||||
{ name = "sqlalchemy" },
|
{ name = "sqlalchemy" },
|
||||||
@@ -380,8 +382,10 @@ requires-dist = [
|
|||||||
{ name = "aiosqlite", specifier = ">=0.21.0" },
|
{ name = "aiosqlite", specifier = ">=0.21.0" },
|
||||||
{ name = "argon2-cffi", specifier = ">=25.1.0" },
|
{ name = "argon2-cffi", specifier = ">=25.1.0" },
|
||||||
{ name = "fastapi", extras = ["standard"], specifier = ">=0.122.0" },
|
{ name = "fastapi", extras = ["standard"], specifier = ">=0.122.0" },
|
||||||
|
{ name = "itsdangerous", specifier = ">=2.2.0" },
|
||||||
{ name = "loguru", specifier = ">=0.7.3" },
|
{ name = "loguru", specifier = ">=0.7.3" },
|
||||||
{ name = "pyjwt", specifier = ">=2.10.1" },
|
{ name = "pyjwt", specifier = ">=2.10.1" },
|
||||||
|
{ name = "pyotp", specifier = ">=2.9.0" },
|
||||||
{ name = "python-dotenv", specifier = ">=1.2.1" },
|
{ name = "python-dotenv", specifier = ">=1.2.1" },
|
||||||
{ name = "python-multipart", specifier = ">=0.0.20" },
|
{ name = "python-multipart", specifier = ">=0.0.20" },
|
||||||
{ name = "sqlalchemy", specifier = ">=2.0.44" },
|
{ name = "sqlalchemy", specifier = ">=2.0.44" },
|
||||||
@@ -698,6 +702,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
|
{ url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itsdangerous"
|
||||||
|
version = "2.2.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/9c/cb/8ac0172223afbccb63986cc25049b154ecfb5e85932587206f42317be31d/itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173", size = 54410, upload-time = "2024-04-16T21:28:15.614Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "jinja2"
|
name = "jinja2"
|
||||||
version = "3.1.6"
|
version = "3.1.6"
|
||||||
@@ -1058,6 +1071,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/d1/81/ef2b1dfd1862567d573a4fdbc9f969067621764fbb74338496840a1d2977/pyopenssl-25.3.0-py3-none-any.whl", hash = "sha256:1fda6fc034d5e3d179d39e59c1895c9faeaf40a79de5fc4cbbfbe0d36f4a77b6", size = 57268, upload-time = "2025-09-17T00:32:19.474Z" },
|
{ url = "https://files.pythonhosted.org/packages/d1/81/ef2b1dfd1862567d573a4fdbc9f969067621764fbb74338496840a1d2977/pyopenssl-25.3.0-py3-none-any.whl", hash = "sha256:1fda6fc034d5e3d179d39e59c1895c9faeaf40a79de5fc4cbbfbe0d36f4a77b6", size = 57268, upload-time = "2025-09-17T00:32:19.474Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyotp"
|
||||||
|
version = "2.9.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/f3/b2/1d5994ba2acde054a443bd5e2d384175449c7d2b6d1a0614dbca3a63abfc/pyotp-2.9.0.tar.gz", hash = "sha256:346b6642e0dbdde3b4ff5a930b664ca82abfa116356ed48cc42c7d6590d36f63", size = 17763, upload-time = "2023-07-27T23:41:03.295Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c3/c0/c33c8792c3e50193ef55adb95c1c3c2786fe281123291c2dbf0eaab95a6f/pyotp-2.9.0-py3-none-any.whl", hash = "sha256:81c2e5865b8ac55e825b0358e496e1d9387c811e85bb40e71a3b29b288963612", size = 13376, upload-time = "2023-07-27T23:41:01.685Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "python-dotenv"
|
name = "python-dotenv"
|
||||||
version = "1.2.1"
|
version = "1.2.1"
|
||||||
|
|||||||
Reference in New Issue
Block a user