From a716b2b0dbbbce594c1fb20c7a49f6d31bf30dec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=8E=E5=B0=8F=E4=B8=98?= Date: Fri, 26 Dec 2025 10:58:20 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A4=E7=89=8C=E7=9A=84?= =?UTF-8?q?=E7=94=9F=E6=88=90=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/__init__.py | 2 ++ models/user.py | 11 +++++-- service/user/login.py | 20 ++++++++---- utils/JWT/JWT.py | 71 ++++++++++++++++++++++++++++++++++--------- 4 files changed, 81 insertions(+), 23 deletions(-) diff --git a/models/__init__.py b/models/__init__.py index 790334d..a1b5367 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,6 +1,8 @@ from .user import ( LoginRequest, RegisterRequest, + AccessTokenBase, + RefreshTokenBase, TokenResponse, User, UserBase, diff --git a/models/user.py b/models/user.py index d21183b..4146717 100644 --- a/models/user.py +++ b/models/user.py @@ -4,6 +4,7 @@ from typing import Literal, TYPE_CHECKING from uuid import UUID from sqlmodel import Field, Relationship +from pydantic import BaseModel from .base import SQLModelBase from .model_base import ResponseBase @@ -110,8 +111,7 @@ class WebAuthnInfo(SQLModelBase): transports: list[str] """支持的传输方式""" - -class TokenResponse(ResponseBase): +class AccessTokenBase(BaseModel): """访问令牌响应 DTO""" access_expires: datetime @@ -120,6 +120,9 @@ class TokenResponse(ResponseBase): access_token: str """访问令牌""" +class RefreshTokenBase(BaseModel): + """刷新令牌响应DTO""" + refresh_expires: datetime """刷新令牌过期时间""" @@ -127,6 +130,10 @@ class TokenResponse(ResponseBase): """刷新令牌""" +class TokenResponse(ResponseBase, AccessTokenBase, RefreshTokenBase): + """令牌响应 DTO""" + + class UserResponse(ResponseBase): """用户响应 DTO""" diff --git a/service/user/login.py b/service/user/login.py index 3f9e7e5..1d57d8f 100644 --- a/service/user/login.py +++ b/service/user/login.py @@ -1,3 +1,5 @@ +from uuid import uuid4 + from loguru import logger from middleware.dependencies import SessionDep @@ -57,12 +59,18 @@ async def login( http_exceptions.raise_unauthorized("Invalid 2FA code") # 创建令牌 - access_token, access_expire = create_access_token(data={'sub': current_user.username}) - refresh_token, refresh_expire = create_refresh_token(data={'sub': current_user.username}) + access_token = create_access_token(data={ + 'sub': str(current_user.id), + 'jti': str(uuid4()) + }) + refresh_token = create_refresh_token(data={ + 'sub': str(current_user.id), + 'jti': str(uuid4()) + }) return TokenResponse( - access_token=access_token, - access_expires=access_expire, - refresh_token=refresh_token, - refresh_expires=refresh_expire, + access_token=access_token.access_token, + access_expires=access_token.access_expires, + refresh_token=refresh_token.refresh_token, + refresh_expires=refresh_token.refresh_expires, ) \ No newline at end of file diff --git a/utils/JWT/JWT.py b/utils/JWT/JWT.py index 67a7c9c..b5fb5ca 100644 --- a/utils/JWT/JWT.py +++ b/utils/JWT/JWT.py @@ -3,6 +3,8 @@ from datetime import datetime, timedelta, timezone import jwt from fastapi.security import OAuth2PasswordBearer +from models import AccessTokenBase, RefreshTokenBase + oauth2_scheme = OAuth2PasswordBearer( scheme_name='获取 JWT Bearer 令牌', description='用于获取 JWT Bearer 令牌,需要以表单的形式提交', @@ -29,25 +31,64 @@ async def load_secret_key() -> None: ) if setting: SECRET_KEY = setting.value - -# 访问令牌 -def create_access_token(data: dict, expires_delta: timedelta | None = None) -> tuple[str, datetime]: + +def build_token_payload( + data: dict, + is_refresh: bool, + algorithm: str, + expires_delta: timedelta | None = None, +) -> tuple[str, datetime]: + """构建令牌""" + to_encode = data.copy() + + if is_refresh: + to_encode.update({"token_type": "refresh"}) + if expires_delta: expire = datetime.now(timezone.utc) + expires_delta + elif is_refresh: + expire = datetime.now(timezone.utc) + timedelta(days=30) else: expire = datetime.now(timezone.utc) + timedelta(hours=3) - to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm='HS256') - return encoded_jwt, expire + to_encode.update({ + "iat": int(datetime.now(timezone.utc).timestamp()), + "exp": int(expire.timestamp()) + }) + return jwt.encode(to_encode, SECRET_KEY, algorithm=algorithm), expire + +# 访问令牌 +def create_access_token(data: dict, expires_delta: timedelta | None = None, algorithm: str = "HS256") -> AccessTokenBase: + """ + 生成访问令牌,默认有效期 3 小时。 + + :param data: 需要放进 JWT Payload 的字段。 + :param expires_delta: 过期时间, 缺省时为 3 小时。 + :param algorithm: JWT 密钥强度,缺省时为 HS256 + + :return: 包含密钥本身和过期时间的 `AccessTokenBase` + """ + + access_token, expire_at = build_token_payload(data, False, algorithm, expires_delta) + return AccessTokenBase( + access_token=access_token, + access_expires=expire_at, + ) # 刷新令牌 -def create_refresh_token(data: dict, expires_delta: timedelta | None = None) -> tuple[str, datetime]: - to_encode = data.copy() - if expires_delta: - expire = datetime.now(timezone.utc) + expires_delta - else: - expire = datetime.now(timezone.utc) + timedelta(days=30) - to_encode.update({"exp": expire, "token_type": "refresh"}) - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm='HS256') - return encoded_jwt, expire \ No newline at end of file +def create_refresh_token(data: dict, expires_delta: timedelta | None = None, algorithm: str = "HS256") -> RefreshTokenBase: + """ + 生成刷新令牌,默认有效期 30 天。 + + :param data: 需要放进 JWT Payload 的字段。 + :param expires_delta: 过期时间, 缺省时为 30 天。 + :param algorithm: JWT 密钥强度,缺省时为 HS256 + + :return: 包含密钥本身和过期时间的 `RefreshTokenBase` + """ + + refresh_token, expire_at = build_token_payload(data, True, algorithm, expires_delta) + return RefreshTokenBase( + refresh_token=refresh_token, + refresh_expires=expire_at, + ) \ No newline at end of file