优化令牌的生成逻辑

This commit is contained in:
2025-12-26 10:58:20 +08:00
parent abd85e2290
commit a716b2b0db
4 changed files with 81 additions and 23 deletions

View File

@@ -1,6 +1,8 @@
from .user import ( from .user import (
LoginRequest, LoginRequest,
RegisterRequest, RegisterRequest,
AccessTokenBase,
RefreshTokenBase,
TokenResponse, TokenResponse,
User, User,
UserBase, UserBase,

View File

@@ -4,6 +4,7 @@ from typing import Literal, TYPE_CHECKING
from uuid import UUID from uuid import UUID
from sqlmodel import Field, Relationship from sqlmodel import Field, Relationship
from pydantic import BaseModel
from .base import SQLModelBase from .base import SQLModelBase
from .model_base import ResponseBase from .model_base import ResponseBase
@@ -110,8 +111,7 @@ class WebAuthnInfo(SQLModelBase):
transports: list[str] transports: list[str]
"""支持的传输方式""" """支持的传输方式"""
class AccessTokenBase(BaseModel):
class TokenResponse(ResponseBase):
"""访问令牌响应 DTO""" """访问令牌响应 DTO"""
access_expires: datetime access_expires: datetime
@@ -120,6 +120,9 @@ class TokenResponse(ResponseBase):
access_token: str access_token: str
"""访问令牌""" """访问令牌"""
class RefreshTokenBase(BaseModel):
"""刷新令牌响应DTO"""
refresh_expires: datetime refresh_expires: datetime
"""刷新令牌过期时间""" """刷新令牌过期时间"""
@@ -127,6 +130,10 @@ class TokenResponse(ResponseBase):
"""刷新令牌""" """刷新令牌"""
class TokenResponse(ResponseBase, AccessTokenBase, RefreshTokenBase):
"""令牌响应 DTO"""
class UserResponse(ResponseBase): class UserResponse(ResponseBase):
"""用户响应 DTO""" """用户响应 DTO"""

View File

@@ -1,3 +1,5 @@
from uuid import uuid4
from loguru import logger from loguru import logger
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
@@ -57,12 +59,18 @@ async def login(
http_exceptions.raise_unauthorized("Invalid 2FA code") http_exceptions.raise_unauthorized("Invalid 2FA code")
# 创建令牌 # 创建令牌
access_token, access_expire = create_access_token(data={'sub': current_user.username}) access_token = create_access_token(data={
refresh_token, refresh_expire = create_refresh_token(data={'sub': current_user.username}) 'sub': str(current_user.id),
'jti': str(uuid4())
})
refresh_token = create_refresh_token(data={
'sub': str(current_user.id),
'jti': str(uuid4())
})
return TokenResponse( return TokenResponse(
access_token=access_token, access_token=access_token.access_token,
access_expires=access_expire, access_expires=access_token.access_expires,
refresh_token=refresh_token, refresh_token=refresh_token.refresh_token,
refresh_expires=refresh_expire, refresh_expires=refresh_token.refresh_expires,
) )

View File

@@ -3,6 +3,8 @@ from datetime import datetime, timedelta, timezone
import jwt import jwt
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from models import AccessTokenBase, RefreshTokenBase
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(
scheme_name='获取 JWT Bearer 令牌', scheme_name='获取 JWT Bearer 令牌',
description='用于获取 JWT Bearer 令牌,需要以表单的形式提交', description='用于获取 JWT Bearer 令牌,需要以表单的形式提交',
@@ -30,24 +32,63 @@ async def load_secret_key() -> None:
if setting: if setting:
SECRET_KEY = setting.value SECRET_KEY = setting.value
# 访问令牌 def build_token_payload(
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> tuple[str, datetime]: data: dict,
is_refresh: bool,
algorithm: str,
expires_delta: timedelta | None = None,
) -> tuple[str, datetime]:
"""构建令牌"""
to_encode = data.copy() to_encode = data.copy()
if is_refresh:
to_encode.update({"token_type": "refresh"})
if expires_delta: if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta expire = datetime.now(timezone.utc) + expires_delta
elif is_refresh:
expire = datetime.now(timezone.utc) + timedelta(days=30)
else: else:
expire = datetime.now(timezone.utc) + timedelta(hours=3) expire = datetime.now(timezone.utc) + timedelta(hours=3)
to_encode.update({"exp": expire}) to_encode.update({
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm='HS256') "iat": int(datetime.now(timezone.utc).timestamp()),
return encoded_jwt, expire "exp": int(expire.timestamp())
})
return jwt.encode(to_encode, SECRET_KEY, algorithm=algorithm), expire
# 访问令牌
def create_access_token(data: dict, expires_delta: timedelta | None = None, algorithm: str = "HS256") -> AccessTokenBase:
"""
生成访问令牌,默认有效期 3 小时。
:param data: 需要放进 JWT Payload 的字段。
:param expires_delta: 过期时间, 缺省时为 3 小时。
:param algorithm: JWT 密钥强度,缺省时为 HS256
:return: 包含密钥本身和过期时间的 `AccessTokenBase`
"""
access_token, expire_at = build_token_payload(data, False, algorithm, expires_delta)
return AccessTokenBase(
access_token=access_token,
access_expires=expire_at,
)
# 刷新令牌 # 刷新令牌
def create_refresh_token(data: dict, expires_delta: timedelta | None = None) -> tuple[str, datetime]: def create_refresh_token(data: dict, expires_delta: timedelta | None = None, algorithm: str = "HS256") -> RefreshTokenBase:
to_encode = data.copy() """
if expires_delta: 生成刷新令牌,默认有效期 30 天。
expire = datetime.now(timezone.utc) + expires_delta
else: :param data: 需要放进 JWT Payload 的字段。
expire = datetime.now(timezone.utc) + timedelta(days=30) :param expires_delta: 过期时间, 缺省时为 30 天。
to_encode.update({"exp": expire, "token_type": "refresh"}) :param algorithm: JWT 密钥强度,缺省时为 HS256
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm='HS256')
return encoded_jwt, expire :return: 包含密钥本身和过期时间的 `RefreshTokenBase`
"""
refresh_token, expire_at = build_token_payload(data, True, algorithm, expires_delta)
return RefreshTokenBase(
refresh_token=refresh_token,
refresh_expires=expire_at,
)