feat: 添加两步验证功能,重构相关逻辑,移除冗余代码
This commit is contained in:
@@ -307,5 +307,4 @@ class User(UserBase, TableBase, table=True):
|
|||||||
|
|
||||||
def to_public(self) -> "UserPublic":
|
def to_public(self) -> "UserPublic":
|
||||||
"""转换为公开 DTO,排除敏感字段"""
|
"""转换为公开 DTO,排除敏感字段"""
|
||||||
return UserPublic.model_validate(self)
|
return UserPublic.model_validate(self)
|
||||||
|
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
from .password.pwd import Password, PasswordStatus
|
||||||
@@ -3,6 +3,12 @@ from loguru import logger
|
|||||||
from argon2 import PasswordHasher
|
from argon2 import PasswordHasher
|
||||||
from argon2.exceptions import VerifyMismatchError
|
from argon2.exceptions import VerifyMismatchError
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
import pyotp
|
||||||
|
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from pkg.JWT.JWT import SECRET_KEY
|
||||||
|
from pkg.conf import appmeta
|
||||||
|
|
||||||
_ph = PasswordHasher()
|
_ph = PasswordHasher()
|
||||||
|
|
||||||
@@ -18,6 +24,24 @@ class PasswordStatus(StrEnum):
|
|||||||
EXPIRED = "expired"
|
EXPIRED = "expired"
|
||||||
"""密码哈希已过时,建议重新哈希"""
|
"""密码哈希已过时,建议重新哈希"""
|
||||||
|
|
||||||
|
class TwoFactorBase(BaseModel):
|
||||||
|
"""两步验证请求 DTO"""
|
||||||
|
|
||||||
|
setup_token: str
|
||||||
|
"""用于验证的令牌"""
|
||||||
|
|
||||||
|
class TwoFactorResponse(TwoFactorBase):
|
||||||
|
"""两步验证-请求启用时的响应 DTO"""
|
||||||
|
|
||||||
|
uri: str
|
||||||
|
"""用于生成二维码的 URI"""
|
||||||
|
|
||||||
|
class TwoFactorVerifyRequest(TwoFactorBase):
|
||||||
|
"""两步验证-验证请求 DTO"""
|
||||||
|
|
||||||
|
code: int = Field(..., ge=100000, le=999999)
|
||||||
|
"""6 位验证码"""
|
||||||
|
|
||||||
class Password:
|
class Password:
|
||||||
"""密码处理工具类,包含密码生成、哈希和验证功能"""
|
"""密码处理工具类,包含密码生成、哈希和验证功能"""
|
||||||
|
|
||||||
@@ -76,4 +100,51 @@ class Password:
|
|||||||
except VerifyMismatchError:
|
except VerifyMismatchError:
|
||||||
# 这是预期的异常,当密码不匹配时触发。
|
# 这是预期的异常,当密码不匹配时触发。
|
||||||
return PasswordStatus.INVALID
|
return PasswordStatus.INVALID
|
||||||
# 其他异常(如哈希格式错误)应该传播,让调用方感知系统问题
|
# 其他异常(如哈希格式错误)应该传播,让调用方感知系统问题
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def generate_totp(
|
||||||
|
username: str
|
||||||
|
) -> TwoFactorResponse:
|
||||||
|
"""
|
||||||
|
生成 TOTP 密钥和对应的 URI,用于两步验证。
|
||||||
|
|
||||||
|
:return: 包含 TOTP 密钥和 URI 的元组
|
||||||
|
"""
|
||||||
|
|
||||||
|
serializer = URLSafeTimedSerializer(SECRET_KEY)
|
||||||
|
|
||||||
|
secret = pyotp.random_base32()
|
||||||
|
|
||||||
|
setup_token = serializer.dumps(
|
||||||
|
secret,
|
||||||
|
salt="2fa-setup-salt"
|
||||||
|
)
|
||||||
|
|
||||||
|
otp_uri = pyotp.totp.TOTP(secret).provisioning_uri(
|
||||||
|
name=username,
|
||||||
|
issuer_name=appmeta.APP_NAME
|
||||||
|
)
|
||||||
|
|
||||||
|
return TwoFactorResponse(
|
||||||
|
uri=otp_uri,
|
||||||
|
setup_token=setup_token
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def verify_totp(
|
||||||
|
secret: str,
|
||||||
|
code: str
|
||||||
|
) -> PasswordStatus:
|
||||||
|
"""
|
||||||
|
验证 TOTP 验证码。
|
||||||
|
|
||||||
|
:param secret: TOTP 密钥(Base32 编码)
|
||||||
|
:param code: 用户输入的 6 位验证码
|
||||||
|
:return: 验证是否成功
|
||||||
|
"""
|
||||||
|
totp = pyotp.TOTP(secret)
|
||||||
|
if totp.verify(code):
|
||||||
|
return PasswordStatus.VALID
|
||||||
|
else:
|
||||||
|
return PasswordStatus.INVALID
|
||||||
@@ -13,6 +13,7 @@ import service
|
|||||||
from middleware.auth import AuthRequired
|
from middleware.auth import AuthRequired
|
||||||
from middleware.dependencies import SessionDep
|
from middleware.dependencies import SessionDep
|
||||||
from pkg.JWT.JWT import SECRET_KEY
|
from pkg.JWT.JWT import SECRET_KEY
|
||||||
|
from pkg import Password
|
||||||
|
|
||||||
user_router = APIRouter(
|
user_router = APIRouter(
|
||||||
prefix="/user",
|
prefix="/user",
|
||||||
@@ -445,7 +446,6 @@ def router_user_settings_patch(option: str) -> models.response.ResponseModel:
|
|||||||
dependencies=[Depends(AuthRequired)],
|
dependencies=[Depends(AuthRequired)],
|
||||||
)
|
)
|
||||||
async def router_user_settings_2fa(
|
async def router_user_settings_2fa(
|
||||||
session: SessionDep,
|
|
||||||
user: Annotated[models.user.User, Depends(AuthRequired)],
|
user: Annotated[models.user.User, Depends(AuthRequired)],
|
||||||
) -> models.response.ResponseModel:
|
) -> models.response.ResponseModel:
|
||||||
"""
|
"""
|
||||||
@@ -455,27 +455,8 @@ async def router_user_settings_2fa(
|
|||||||
dict: A dictionary containing two-factor authentication setup information.
|
dict: A dictionary containing two-factor authentication setup information.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
serializer = URLSafeTimedSerializer(SECRET_KEY)
|
|
||||||
|
|
||||||
secret = pyotp.random_base32()
|
|
||||||
|
|
||||||
setup_token = serializer.dumps(
|
|
||||||
secret,
|
|
||||||
salt="2fa-setup-salt"
|
|
||||||
)
|
|
||||||
|
|
||||||
site_Name = await models.Setting.get(session, (models.Setting.type == models.SettingsType.BASIC) & (models.Setting.name == "siteName"))
|
|
||||||
|
|
||||||
otp_uri = pyotp.totp.TOTP(secret).provisioning_uri(
|
|
||||||
name=user.username,
|
|
||||||
issuer_name=site_Name.value
|
|
||||||
)
|
|
||||||
|
|
||||||
return models.response.ResponseModel(
|
return models.response.ResponseModel(
|
||||||
data={
|
data=await Password.generate_totp(user.username)
|
||||||
"setup_token": setup_token,
|
|
||||||
"otp_uri": otp_uri,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@user_settings_router.post(
|
@user_settings_router.post(
|
||||||
@@ -508,7 +489,7 @@ async def router_user_settings_2fa_enable(
|
|||||||
raise HTTPException(status_code=400, detail="Invalid token")
|
raise HTTPException(status_code=400, detail="Invalid token")
|
||||||
|
|
||||||
# 2. 验证用户输入的 6 位验证码
|
# 2. 验证用户输入的 6 位验证码
|
||||||
if not service.user.verify_totp(secret, code):
|
if not Password.verify_totp(secret, code):
|
||||||
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
||||||
|
|
||||||
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
|
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
|
||||||
|
|||||||
@@ -1,2 +1 @@
|
|||||||
from .login import Login
|
from .login import Login
|
||||||
from .totp import verify_totp
|
|
||||||
@@ -5,7 +5,6 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
|
|
||||||
from models import LoginRequest, TokenResponse, User
|
from models import LoginRequest, TokenResponse, User
|
||||||
from pkg.JWT.JWT import create_access_token, create_refresh_token
|
from pkg.JWT.JWT import create_access_token, create_refresh_token
|
||||||
from .totp import verify_totp
|
|
||||||
|
|
||||||
|
|
||||||
async def Login(
|
async def Login(
|
||||||
@@ -61,7 +60,7 @@ async def Login(
|
|||||||
return "2fa_required"
|
return "2fa_required"
|
||||||
|
|
||||||
# 验证 OTP 码
|
# 验证 OTP 码
|
||||||
if not verify_totp(current_user.two_factor, login_request.two_fa_code):
|
if not Password.verify_totp(current_user.two_factor, login_request.two_fa_code):
|
||||||
log.debug(f"Invalid 2FA code for user: {login_request.username}")
|
log.debug(f"Invalid 2FA code for user: {login_request.username}")
|
||||||
return "2fa_invalid"
|
return "2fa_invalid"
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
import pyotp
|
|
||||||
|
|
||||||
|
|
||||||
def verify_totp(secret: str, code: str) -> bool:
|
|
||||||
"""
|
|
||||||
验证 TOTP 验证码。
|
|
||||||
|
|
||||||
:param secret: TOTP 密钥(Base32 编码)
|
|
||||||
:param code: 用户输入的 6 位验证码
|
|
||||||
:return: 验证是否成功
|
|
||||||
"""
|
|
||||||
totp = pyotp.TOTP(secret)
|
|
||||||
return totp.verify(code)
|
|
||||||
Reference in New Issue
Block a user