feat: 添加两步验证功能,优化用户登录逻辑,更新相关模型和依赖

This commit is contained in:
2025-12-19 14:11:24 +08:00
parent 89e837d91c
commit b7c5d5aec7
13 changed files with 248 additions and 76 deletions

View File

@@ -25,7 +25,7 @@ class GroupOptionsBase(SQLModelBase):
"""是否允许分享下载"""
share_free: bool = False
"""是否免积分分享"""
"""是否免积分获取需要积分的内容"""
relocate: bool = False
"""是否允许文件重定位"""
@@ -136,3 +136,22 @@ class Group(GroupBase, TableBase, table=True):
back_populates="previous_group",
sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"}
)
def to_response(self) -> "GroupResponse":
"""转换为响应 DTO"""
opts = self.options
return GroupResponse(
id=self.id,
name=self.name,
allow_share=self.share_enabled,
webdav=self.web_dav_enabled,
share_download=opts.share_download if opts else False,
share_free=opts.share_free if opts else False,
relocate=opts.relocate if opts else False,
source_batch=opts.source_batch if opts else 0,
select_node=opts.select_node if opts else False,
advance_delete=opts.advance_delete if opts else False,
allow_remote_download=opts.aria2 if opts else False,
allow_archive_download=opts.archive_download if opts else False,
allow_webdav_proxy=opts.webdav_proxy if opts else False,
)

View File

@@ -15,8 +15,8 @@ async def migration() -> None:
log.info('开始进行数据库初始化...')
await init_default_settings()
await init_default_group()
await init_default_policy()
await init_default_group()
await init_default_user()
log.info('数据库初始化结束')
@@ -147,6 +147,7 @@ async def init_default_group() -> None:
if not await Group.get(session, Group.id == 1):
admin_group = await Group(
name="管理员",
policies="1",
max_storage=1 * 1024 * 1024 * 1024, # 1GB
share_enabled=True,
web_dav_enabled=True,
@@ -158,7 +159,10 @@ async def init_default_group() -> None:
archive_download=True,
archive_task=True,
share_download=True,
share_free=True,
aria2=True,
select_node=True,
advance_delete=True,
).save(session)
# 未找到初始注册会员时,则创建

View File

@@ -84,8 +84,8 @@ class PolicyResponse(SQLModelBase):
max_size: int = 0
"""单文件最大限制单位字节0表示不限制"""
file_type: list[str] = []
"""允许的文件类型列表,空列表表示不限制"""
file_type: list[str] | None = None
"""允许的文件类型列表,None 表示不限制"""
class DirectoryResponse(SQLModelBase):

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from enum import StrEnum
from typing import Literal, Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship
@@ -26,6 +27,13 @@ Option 需求
- 切换到不同存储策略是否提醒
"""
class AvatarType(StrEnum):
"""头像类型枚举"""
DEFAULT = "default"
GRAVATAR = "gravatar"
FILE = "file"
# ==================== Base 模型 ====================
@@ -227,7 +235,7 @@ class User(UserBase, TableBase, table=True):
username: str = Field(max_length=50, unique=True, index=True)
"""用户名,唯一,一经注册不可更改"""
nick: str | None = Field(default=None, max_length=50)
nickname: str | None = Field(default=None, max_length=50)
"""用于公开展示的名字,可使用真实姓名或昵称"""
password: str = Field(max_length=255)
@@ -242,7 +250,7 @@ class User(UserBase, TableBase, table=True):
two_factor: str | None = Field(default=None, min_length=32, max_length=32)
"""两步验证密钥"""
avatar: str | None = Field(default=None, max_length=255)
avatar: str = Field(default="default", max_length=255)
"""头像地址"""
options: str | None = None

View File

@@ -1,4 +0,0 @@
# 延迟导入以避免循环依赖
# JWT 和 lifespan 应在需要时直接从子模块导入
# from .JWT import JWT
# from .lifespan import lifespan

View File

@@ -9,8 +9,10 @@ dependencies = [
"aiosqlite>=0.21.0",
"argon2-cffi>=25.1.0",
"fastapi[standard]>=0.122.0",
"itsdangerous>=2.2.0",
"loguru>=0.7.3",
"pyjwt>=2.10.1",
"pyotp>=2.9.0",
"python-dotenv>=1.2.1",
"python-multipart>=0.0.20",
"sqlalchemy>=2.0.44",

View File

@@ -28,13 +28,13 @@ async def router_directory_get(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
path: str = ""
) -> response.ResponseModel:
) -> DirectoryResponse:
"""
获取目录内容
:param session: 数据库会话
:param user: 当前登录用户
:param path: 目录路径,空或 "/" 表示根目录
:param path: 目录路径, "~" 表示根目录
:return: 目录内容
"""
folder = await Object.get_by_path(session, user.id, path or "/")
@@ -45,6 +45,9 @@ async def router_directory_get(
if not folder.is_folder:
raise HTTPException(status_code=400, detail="指定路径不是目录")
if path != "~":
path = path.lstrip("~")
children = await Object.get_children(session, user.id, folder.id)
policy = await folder.awaitable_attrs.policy
@@ -55,7 +58,7 @@ async def router_directory_get(
path=f"/{child.name}", # TODO: 完整路径
thumb=False,
size=child.size,
type="folder" if child.is_folder else "file",
type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE,
date=child.updated_at,
create_date=child.created_at,
source_enabled=False,
@@ -63,18 +66,17 @@ async def router_directory_get(
for child in children
]
return response.ResponseModel(
data=DirectoryResponse(
parent=str(folder.parent_id) if folder.parent_id else None,
objects=objects,
policy=PolicyResponse(
id=str(policy.id),
name=policy.name,
type=policy.type.value,
max_size=policy.max_size,
file_type=[],
),
)
return DirectoryResponse(
parent=str(folder.parent_id) if folder.parent_id else None,
objects=objects,
policy=policy,
)

View File

@@ -1,15 +1,18 @@
from typing import Annotated
from typing import Annotated, Literal
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy import and_
from webauthn import generate_registration_options
from webauthn.helpers import options_to_json_dict
import pyotp
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
import models
import service
from middleware.auth import AuthRequired
from middleware.dependencies import SessionDep
from pkg.JWT.JWT import SECRET_KEY
user_router = APIRouter(
prefix="/user",
@@ -25,18 +28,46 @@ user_settings_router = APIRouter(
@user_router.post(
path='/session',
summary='用户登录',
description='User login endpoint.',
description='User login endpoint. 当用户启用两步验证时,需要传入 otp 参数。',
)
async def router_user_session(
session: SessionDep,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> models.TokenResponse:
"""
用户登录端点。
根据 OAuth2.1 规范,使用 password grant type 进行登录。
当用户启用两步验证时,需要在表单中传入 otp 参数(通过 scopes 字段传递)。
OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码
:raises HTTPException 401: 用户名或密码错误
:raises HTTPException 403: 用户账号被封禁或未完成注册
:raises HTTPException 428: 需要两步验证但未提供验证码
:raises HTTPException 400: 两步验证码无效
"""
username = form_data.username
password = form_data.password
# 从 scopes 中提取 OTP 验证码OAuth2.1 扩展方式)
# scopes 格式可以是 ["otp:123456"] 或 ["123456"]
otp_code: str | None = None
for scope in form_data.scopes:
if scope.startswith("otp:"):
otp_code = scope[4:]
break
elif scope.isdigit() and len(scope) == 6:
otp_code = scope
break
result = await service.user.Login(
session,
models.LoginRequest(username=username, password=password),
models.LoginRequest(
username=username,
password=password,
two_fa_code=otp_code,
),
)
if isinstance(result, models.TokenResponse):
@@ -45,6 +76,14 @@ async def router_user_session(
raise HTTPException(status_code=401, detail="Invalid username or password")
elif result is False:
raise HTTPException(status_code=403, detail="User account is banned or not fully registered")
elif result == "2fa_required":
raise HTTPException(
status_code=428,
detail="Two-factor authentication required",
headers={"X-2FA-Required": "true"},
)
elif result == "2fa_invalid":
raise HTTPException(status_code=400, detail="Invalid two-factor authentication code")
else:
raise HTTPException(status_code=500, detail="Internal server error during login")
@@ -62,26 +101,14 @@ def router_user_register() -> models.response.ResponseModel:
"""
pass
@user_router.post(
path='/2fa',
summary='用两步验证登录',
description='Two-factor authentication login endpoint.',
)
def router_user_2fa() -> models.response.ResponseModel:
"""
Two-factor authentication login endpoint.
Returns:
dict: A dictionary containing two-factor authentication information.
"""
pass
@user_router.post(
path='/code',
summary='发送验证码邮件',
description='Send a verification code email.',
)
def router_user_email_code() -> models.response.ResponseModel:
def router_user_email_code(
reason: Literal['register', 'reset'] = 'register',
) -> models.response.ResponseModel:
"""
Send a verification code email.
@@ -90,21 +117,6 @@ def router_user_email_code() -> models.response.ResponseModel:
"""
pass
@user_router.patch(
path='/reset',
summary='通过邮件里的链接重设密码',
description='Reset password via email link.',
deprecated=True,
)
def router_user_reset_patch() -> models.response.ResponseModel:
"""
Reset password via email link.
Returns:
dict: A dictionary containing information about the password reset.
"""
pass
@user_router.get(
path='/qq',
summary='初始化QQ登录',
@@ -193,7 +205,7 @@ def router_user_avatar(id: str, size: int = 128) -> models.response.ResponseMode
)
async def router_user_me(
session: SessionDep,
user: Annotated[models.user.User, Depends(AuthRequired)],
user: Annotated[models.User, Depends(AuthRequired)],
) -> models.response.ResponseModel:
"""
获取用户信息.
@@ -201,25 +213,32 @@ async def router_user_me(
:return: response.ResponseModel containing user information.
:rtype: response.ResponseModel
"""
group = await models.Group.get(session, models.Group.id == user.group_id)
user_group = models.GroupResponse(
id=group.id,
name=group.name,
allow_share=group.share_enabled,
# 加载 group 及其 options 关系
group = await models.Group.get(
session,
models.Group.id == user.group_id,
load=models.Group.options
)
users = models.UserResponse(
# 构建 GroupResponse
group_response = group.to_response() if group else None
# 异步加载 tags 关系
user_tags = await user.awaitable_attrs.tags
user_response = models.UserResponse(
id=user.id,
username=user.username,
nickname=user.nick,
status=user.status,
created_at=user.created_at,
score=user.score,
group=user_group,
).model_dump()
nickname=user.nickname,
avatar=user.avatar,
created_at=user.created_at,
group=group_response,
tags=[tag.name for tag in user_tags] if user_tags else [],
)
return models.response.ResponseModel(data=users)
return models.response.ResponseModel(data=user_response.model_dump())
@user_router.get(
path='/storage',
@@ -425,11 +444,77 @@ def router_user_settings_patch(option: str) -> models.response.ResponseModel:
description='Get two-factor authentication initialization information.',
dependencies=[Depends(AuthRequired)],
)
def router_user_settings_2fa() -> models.response.ResponseModel:
async def router_user_settings_2fa(
session: SessionDep,
user: Annotated[models.user.User, Depends(AuthRequired)],
) -> models.response.ResponseModel:
"""
Get two-factor authentication initialization information.
Returns:
dict: A dictionary containing two-factor authentication setup information.
"""
pass
serializer = URLSafeTimedSerializer(SECRET_KEY)
secret = pyotp.random_base32()
setup_token = serializer.dumps(
secret,
salt="2fa-setup-salt"
)
site_Name = await models.Setting.get(session, (models.Setting.type == models.SettingsType.BASIC) & (models.Setting.name == "siteName"))
otp_uri = pyotp.totp.TOTP(secret).provisioning_uri(
name=user.username,
issuer_name=site_Name.value
)
return models.response.ResponseModel(
data={
"setup_token": setup_token,
"otp_uri": otp_uri,
}
)
@user_settings_router.post(
path='/2fa',
summary='启用两步验证',
description='Enable two-factor authentication.',
dependencies=[Depends(AuthRequired)],
)
async def router_user_settings_2fa_enable(
session: SessionDep,
user: Annotated[models.user.User, Depends(AuthRequired)],
setup_token: str,
code: str,
) -> models.response.ResponseModel:
"""
Enable two-factor authentication for the user.
Returns:
dict: A dictionary containing the result of enabling two-factor authentication.
"""
serializer = URLSafeTimedSerializer(SECRET_KEY)
try:
# 1. 解包 Token设置有效期例如 600秒
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
except SignatureExpired:
raise HTTPException(status_code=400, detail="Setup session expired")
except BadSignature:
raise HTTPException(status_code=400, detail="Invalid token")
# 2. 验证用户输入的 6 位验证码
if not service.user.verify_totp(secret, code):
raise HTTPException(status_code=400, detail="Invalid OTP code")
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
user.two_factor = secret
user = await user.save(session)
return models.response.ResponseModel(
data={"message": "Two-factor authentication enabled successfully"}
)

View File

@@ -2,4 +2,4 @@
服务层
"""
from .user import login
from . import user

View File

@@ -1 +1,2 @@
from .login import Login
from .totp import verify_totp

View File

@@ -1,17 +1,25 @@
from typing import Literal
from loguru import logger as log
from sqlmodel.ext.asyncio.session import AsyncSession
from models import LoginRequest, TokenResponse, User
from pkg.JWT.JWT import create_access_token, create_refresh_token
from .totp import verify_totp
async def Login(session: AsyncSession, login_request: LoginRequest) -> TokenResponse | bool | None:
async def Login(
session: AsyncSession,
login_request: LoginRequest,
) -> TokenResponse | bool | Literal["2fa_required", "2fa_invalid"] | None:
"""
根据账号密码进行登录。
如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。
如果登录异常,返回 `False`(未完成注册或账号被封禁)。
如果登录失败,返回 `None`。
如果需要两步验证但未提供验证码,返回 `"2fa_required"`。
如果两步验证码无效,返回 `"2fa_invalid"`。
:param session: 数据库会话
:param login_request: 登录请求
@@ -45,6 +53,18 @@ async def Login(session: AsyncSession, login_request: LoginRequest) -> TokenResp
# 未完成注册 or 账号已被封禁
return False
# 检查两步验证
if current_user.two_factor:
# 用户已启用两步验证
if not login_request.two_fa_code:
log.debug(f"2FA required for user: {login_request.username}")
return "2fa_required"
# 验证 OTP 码
if not verify_totp(current_user.two_factor, login_request.two_fa_code):
log.debug(f"Invalid 2FA code for user: {login_request.username}")
return "2fa_invalid"
# 创建令牌
access_token, access_expire = create_access_token(data={'sub': current_user.username})
refresh_token, refresh_expire = create_refresh_token(data={'sub': current_user.username})

13
service/user/totp.py Normal file
View File

@@ -0,0 +1,13 @@
import pyotp
def verify_totp(secret: str, code: str) -> bool:
"""
验证 TOTP 验证码。
:param secret: TOTP 密钥Base32 编码)
:param code: 用户输入的 6 位验证码
:return: 验证是否成功
"""
totp = pyotp.TOTP(secret)
return totp.verify(code)

22
uv.lock generated
View File

@@ -364,8 +364,10 @@ dependencies = [
{ name = "aiosqlite" },
{ name = "argon2-cffi" },
{ name = "fastapi", extra = ["standard"] },
{ name = "itsdangerous" },
{ name = "loguru" },
{ name = "pyjwt" },
{ name = "pyotp" },
{ name = "python-dotenv" },
{ name = "python-multipart" },
{ name = "sqlalchemy" },
@@ -380,8 +382,10 @@ requires-dist = [
{ name = "aiosqlite", specifier = ">=0.21.0" },
{ name = "argon2-cffi", specifier = ">=25.1.0" },
{ name = "fastapi", extras = ["standard"], specifier = ">=0.122.0" },
{ name = "itsdangerous", specifier = ">=2.2.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "pyjwt", specifier = ">=2.10.1" },
{ name = "pyotp", specifier = ">=2.9.0" },
{ name = "python-dotenv", specifier = ">=1.2.1" },
{ name = "python-multipart", specifier = ">=0.0.20" },
{ name = "sqlalchemy", specifier = ">=2.0.44" },
@@ -698,6 +702,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
]
[[package]]
name = "itsdangerous"
version = "2.2.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/9c/cb/8ac0172223afbccb63986cc25049b154ecfb5e85932587206f42317be31d/itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173", size = 54410, upload-time = "2024-04-16T21:28:15.614Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" },
]
[[package]]
name = "jinja2"
version = "3.1.6"
@@ -1058,6 +1071,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/81/ef2b1dfd1862567d573a4fdbc9f969067621764fbb74338496840a1d2977/pyopenssl-25.3.0-py3-none-any.whl", hash = "sha256:1fda6fc034d5e3d179d39e59c1895c9faeaf40a79de5fc4cbbfbe0d36f4a77b6", size = 57268, upload-time = "2025-09-17T00:32:19.474Z" },
]
[[package]]
name = "pyotp"
version = "2.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f3/b2/1d5994ba2acde054a443bd5e2d384175449c7d2b6d1a0614dbca3a63abfc/pyotp-2.9.0.tar.gz", hash = "sha256:346b6642e0dbdde3b4ff5a930b664ca82abfa116356ed48cc42c7d6590d36f63", size = 17763, upload-time = "2023-07-27T23:41:03.295Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c3/c0/c33c8792c3e50193ef55adb95c1c3c2786fe281123291c2dbf0eaab95a6f/pyotp-2.9.0-py3-none-any.whl", hash = "sha256:81c2e5865b8ac55e825b0358e496e1d9387c811e85bb40e71a3b29b288963612", size = 13376, upload-time = "2023-07-27T23:41:01.685Z" },
]
[[package]]
name = "python-dotenv"
version = "1.2.1"