From 3088a9d5481cb98439dbed6e38e39eee0fd43e32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=8E=E5=B0=8F=E4=B8=98?= Date: Fri, 26 Dec 2025 17:47:51 +0800 Subject: [PATCH] Refactor JWT utilities and download token logic Merged JWT utility functions into utils/JWT/__init__.py and removed utils/JWT/JWT.py. Refactored download token creation and verification to use new functions, replacing DownloadTokenManager with create_download_token and verify_download_token. Updated imports across the codebase to reflect the new JWT utility structure. Improved download file logic to use physical file storage path and added a dedicated response model for download tokens. --- main.py | 2 +- middleware/auth.py | 20 ++++- routers/api/v1/file/__init__.py | 68 ++++++----------- routers/api/v1/user/__init__.py | 2 +- service/user/login.py | 2 +- utils/JWT/JWT.py | 94 ----------------------- utils/JWT/__init__.py | 129 ++++++++++++++++++++++++++++++-- utils/password/pwd.py | 2 +- 8 files changed, 168 insertions(+), 151 deletions(-) delete mode 100644 utils/JWT/JWT.py diff --git a/main.py b/main.py index 781fd0c..61c05a9 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,7 @@ from utils.http.http_exceptions import raise_internal_error from utils.lifespan import lifespan from models.database import init_db from models.migration import migration -from utils.JWT import JWT +from utils import JWT from routers import router from loguru import logger as l diff --git a/middleware/auth.py b/middleware/auth.py index 128f377..5ce7f42 100644 --- a/middleware/auth.py +++ b/middleware/auth.py @@ -5,7 +5,7 @@ from fastapi import Depends import jwt from models.user import User -from utils.JWT import JWT +from utils import JWT from .dependencies import SessionDep from utils import http_exceptions @@ -47,4 +47,20 @@ async def admin_required( group = await user.awaitable_attrs.group if group.admin: return user - raise http_exceptions.raise_forbidden("Admin Required") \ No newline at end of file + raise http_exceptions.raise_forbidden("Admin Required") + + +def verify_download_token(token: str) -> tuple[UUID, UUID] | None: + """ + 验证下载令牌并返回 (file_id, owner_id)。 + + :param token: JWT 令牌字符串 + :return: (file_id, owner_id) 或 None(验证失败) + """ + try: + payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"]) + if payload.get("type") != "download": + return None + return UUID(payload["file_id"]), UUID(payload["owner_id"]) + except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): + return None \ No newline at end of file diff --git a/routers/api/v1/file/__init__.py b/routers/api/v1/file/__init__.py index 57acc5e..380a90a 100644 --- a/routers/api/v1/file/__init__.py +++ b/routers/api/v1/file/__init__.py @@ -8,16 +8,15 @@ - /file/upload - 上传相关操作 - /file/download - 下载相关操作 """ -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta from typing import Annotated from uuid import UUID -import jwt from fastapi import APIRouter, Depends, File, HTTPException, UploadFile from fastapi.responses import FileResponse from loguru import logger as l -from middleware.auth import auth_required +from middleware.auth import auth_required, verify_download_token from middleware.dependencies import SessionDep from models import ( CreateFileRequest, @@ -34,43 +33,20 @@ from models import ( User, ) from service.storage import LocalStorageService -from utils.JWT import SECRET_KEY +from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL from utils import http_exceptions -# ==================== 下载令牌管理 ==================== - -class DownloadTokenManager: - """下载令牌管理器(JWT 无状态)""" - - _ttl: timedelta = timedelta(hours=1) - - @classmethod - def create(cls, file_id: UUID, owner_id: int) -> str: - """创建下载令牌""" - payload = { - "file_id": str(file_id), - "owner_id": owner_id, - "exp": datetime.now(timezone.utc) + cls._ttl, - "type": "download", - } - return jwt.encode(payload, SECRET_KEY, algorithm="HS256") - - @classmethod - def verify(cls, token: str) -> tuple[UUID, int] | None: - """ - 验证令牌并返回 (file_id, owner_id) - - :return: (file_id, owner_id) 或 None(验证失败) - """ - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) - if payload.get("type") != "download": - return None - return UUID(payload["file_id"]), payload["owner_id"] - except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): - return None +# DTO +class DownloadTokenModel(ResponseBase): + """下载Token响应模型""" + + access_token: str + """JWT 令牌""" + + expires_in: int + """过期时间(秒)""" # ==================== 主路由 ==================== @@ -367,11 +343,11 @@ _download_router = APIRouter(prefix="/download") summary='创建下载令牌', description='为指定文件创建下载令牌(JWT),有效期1小时。', ) -async def create_download_token( +async def create_download_token_endpoint( session: SessionDep, user: Annotated[User, Depends(auth_required)], file_id: UUID, -) -> ResponseBase: +) -> DownloadTokenModel: """ 创建下载令牌端点 @@ -384,11 +360,11 @@ async def create_download_token( if not file_obj.is_file: raise HTTPException(status_code=400, detail="对象不是文件") - token = DownloadTokenManager.create(file_id, user.id) + token = create_download_token(file_id, user.id) l.debug(f"创建下载令牌: file_id={file_id}, user_id={user.id}") - return ResponseBase(data={"token": token, "expires_in": 3600}) + return DownloadTokenModel(access_token=token, expires_in=int(DOWNLOAD_TOKEN_TTL.total_seconds())) @_download_router.get( @@ -406,7 +382,7 @@ async def download_file( 验证 JWT 令牌后返回文件内容。 """ # 验证令牌 - result = DownloadTokenManager.verify(token) + result = verify_download_token(token) if not result: raise HTTPException(status_code=401, detail="下载令牌无效或已过期") @@ -420,9 +396,13 @@ async def download_file( if not file_obj.is_file: raise HTTPException(status_code=400, detail="对象不是文件") - if not file_obj.source_name: + # 预加载 physical_file 关系以获取存储路径 + physical_file = await file_obj.awaitable_attrs.physical_file + if not physical_file or not physical_file.storage_path: raise HTTPException(status_code=500, detail="文件存储路径丢失") + storage_path = physical_file.storage_path + # 获取策略 policy = await Policy.get(session, Policy.id == file_obj.policy_id) if not policy: @@ -430,11 +410,11 @@ async def download_file( if policy.type == PolicyType.LOCAL: storage_service = LocalStorageService(policy) - if not await storage_service.file_exists(file_obj.source_name): + if not await storage_service.file_exists(storage_path): raise HTTPException(status_code=404, detail="物理文件不存在") return FileResponse( - path=file_obj.source_name, + path=storage_path, filename=file_obj.name, media_type="application/octet-stream", ) diff --git a/routers/api/v1/user/__init__.py b/routers/api/v1/user/__init__.py index 0892ae6..5b07530 100644 --- a/routers/api/v1/user/__init__.py +++ b/routers/api/v1/user/__init__.py @@ -13,7 +13,7 @@ import models import service from middleware.auth import auth_required from middleware.dependencies import SessionDep -from utils.JWT.JWT import SECRET_KEY +from utils.JWT import SECRET_KEY from utils import Password, http_exceptions user_router = APIRouter( diff --git a/service/user/login.py b/service/user/login.py index 1d57d8f..88f808f 100644 --- a/service/user/login.py +++ b/service/user/login.py @@ -5,7 +5,7 @@ from loguru import logger from middleware.dependencies import SessionDep from models import LoginRequest, TokenResponse, User from utils import http_exceptions -from utils.JWT.JWT import create_access_token, create_refresh_token +from utils.JWT import create_access_token, create_refresh_token from utils.password.pwd import Password, PasswordStatus diff --git a/utils/JWT/JWT.py b/utils/JWT/JWT.py deleted file mode 100644 index b5fb5ca..0000000 --- a/utils/JWT/JWT.py +++ /dev/null @@ -1,94 +0,0 @@ -from datetime import datetime, timedelta, timezone - -import jwt -from fastapi.security import OAuth2PasswordBearer - -from models import AccessTokenBase, RefreshTokenBase - -oauth2_scheme = OAuth2PasswordBearer( - scheme_name='获取 JWT Bearer 令牌', - description='用于获取 JWT Bearer 令牌,需要以表单的形式提交', - tokenUrl="/api/v1/user/session", - refreshUrl="/api/v1/user/session/refresh", -) - -SECRET_KEY = '' - - -async def load_secret_key() -> None: - """ - 从数据库读取 JWT 的密钥。 - """ - # 延迟导入以避免循环依赖 - from models.database import get_session - from models.setting import Setting - - global SECRET_KEY - async for session in get_session(): - setting = await Setting.get( - session, - (Setting.type == "auth") & (Setting.name == "secret_key") - ) - if setting: - SECRET_KEY = setting.value - -def build_token_payload( - data: dict, - is_refresh: bool, - algorithm: str, - expires_delta: timedelta | None = None, -) -> tuple[str, datetime]: - """构建令牌""" - - to_encode = data.copy() - - if is_refresh: - to_encode.update({"token_type": "refresh"}) - - if expires_delta: - expire = datetime.now(timezone.utc) + expires_delta - elif is_refresh: - expire = datetime.now(timezone.utc) + timedelta(days=30) - else: - expire = datetime.now(timezone.utc) + timedelta(hours=3) - to_encode.update({ - "iat": int(datetime.now(timezone.utc).timestamp()), - "exp": int(expire.timestamp()) - }) - return jwt.encode(to_encode, SECRET_KEY, algorithm=algorithm), expire - -# 访问令牌 -def create_access_token(data: dict, expires_delta: timedelta | None = None, algorithm: str = "HS256") -> AccessTokenBase: - """ - 生成访问令牌,默认有效期 3 小时。 - - :param data: 需要放进 JWT Payload 的字段。 - :param expires_delta: 过期时间, 缺省时为 3 小时。 - :param algorithm: JWT 密钥强度,缺省时为 HS256 - - :return: 包含密钥本身和过期时间的 `AccessTokenBase` - """ - - access_token, expire_at = build_token_payload(data, False, algorithm, expires_delta) - return AccessTokenBase( - access_token=access_token, - access_expires=expire_at, - ) - -# 刷新令牌 -def create_refresh_token(data: dict, expires_delta: timedelta | None = None, algorithm: str = "HS256") -> RefreshTokenBase: - """ - 生成刷新令牌,默认有效期 30 天。 - - :param data: 需要放进 JWT Payload 的字段。 - :param expires_delta: 过期时间, 缺省时为 30 天。 - :param algorithm: JWT 密钥强度,缺省时为 HS256 - - :return: 包含密钥本身和过期时间的 `RefreshTokenBase` - """ - - refresh_token, expire_at = build_token_payload(data, True, algorithm, expires_delta) - return RefreshTokenBase( - refresh_token=refresh_token, - refresh_expires=expire_at, - ) \ No newline at end of file diff --git a/utils/JWT/__init__.py b/utils/JWT/__init__.py index 2e58c01..41c2d6b 100644 --- a/utils/JWT/__init__.py +++ b/utils/JWT/__init__.py @@ -1,8 +1,123 @@ -from . import JWT -from .JWT import ( - create_access_token, - create_refresh_token, - load_secret_key, - oauth2_scheme, - SECRET_KEY, +from datetime import datetime, timedelta, timezone +from uuid import UUID + +import jwt +from fastapi.security import OAuth2PasswordBearer + +from models import AccessTokenBase, RefreshTokenBase + +oauth2_scheme = OAuth2PasswordBearer( + scheme_name='获取 JWT Bearer 令牌', + description='用于获取 JWT Bearer 令牌,需要以表单的形式提交', + tokenUrl="/api/v1/user/session", + refreshUrl="/api/v1/user/session/refresh", ) + +SECRET_KEY = '' + + +async def load_secret_key() -> None: + """ + 从数据库读取 JWT 的密钥。 + """ + # 延迟导入以避免循环依赖 + from models.database import get_session + from models.setting import Setting + + global SECRET_KEY + async for session in get_session(): + setting = await Setting.get( + session, + (Setting.type == "auth") & (Setting.name == "secret_key") + ) + if setting: + SECRET_KEY = setting.value + + +def build_token_payload( + data: dict, + is_refresh: bool, + algorithm: str, + expires_delta: timedelta | None = None, +) -> tuple[str, datetime]: + """构建令牌""" + + to_encode = data.copy() + + if is_refresh: + to_encode.update({"token_type": "refresh"}) + + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + elif is_refresh: + expire = datetime.now(timezone.utc) + timedelta(days=30) + else: + expire = datetime.now(timezone.utc) + timedelta(hours=3) + to_encode.update({ + "iat": int(datetime.now(timezone.utc).timestamp()), + "exp": int(expire.timestamp()) + }) + return jwt.encode(to_encode, SECRET_KEY, algorithm=algorithm), expire + + +# 访问令牌 +def create_access_token(data: dict, expires_delta: timedelta | None = None, + algorithm: str = "HS256") -> AccessTokenBase: + """ + 生成访问令牌,默认有效期 3 小时。 + + :param data: 需要放进 JWT Payload 的字段。 + :param expires_delta: 过期时间, 缺省时为 3 小时。 + :param algorithm: JWT 密钥强度,缺省时为 HS256 + + :return: 包含密钥本身和过期时间的 `AccessTokenBase` + """ + + access_token, expire_at = build_token_payload(data, False, algorithm, expires_delta) + return AccessTokenBase( + access_token=access_token, + access_expires=expire_at, + ) + + +# 刷新令牌 +def create_refresh_token(data: dict, expires_delta: timedelta | None = None, + algorithm: str = "HS256") -> RefreshTokenBase: + """ + 生成刷新令牌,默认有效期 30 天。 + + :param data: 需要放进 JWT Payload 的字段。 + :param expires_delta: 过期时间, 缺省时为 30 天。 + :param algorithm: JWT 密钥强度,缺省时为 HS256 + + :return: 包含密钥本身和过期时间的 `RefreshTokenBase` + """ + + refresh_token, expire_at = build_token_payload(data, True, algorithm, expires_delta) + return RefreshTokenBase( + refresh_token=refresh_token, + refresh_expires=expire_at, + ) + + +# ==================== 下载令牌 ==================== + +DOWNLOAD_TOKEN_TTL = timedelta(hours=1) +"""下载令牌有效期""" + + +def create_download_token(file_id: UUID, owner_id: UUID) -> str: + """ + 创建文件下载令牌。 + + :param file_id: 文件 ID + :param owner_id: 文件所有者 ID + :return: JWT 令牌字符串 + """ + payload = { + "file_id": str(file_id), + "owner_id": str(owner_id), + "exp": datetime.now(timezone.utc) + DOWNLOAD_TOKEN_TTL, + "type": "download", + } + return jwt.encode(payload, SECRET_KEY, algorithm="HS256") \ No newline at end of file diff --git a/utils/password/pwd.py b/utils/password/pwd.py index 282bb30..76397e1 100644 --- a/utils/password/pwd.py +++ b/utils/password/pwd.py @@ -8,7 +8,7 @@ import pyotp from itsdangerous import URLSafeTimedSerializer from pydantic import BaseModel, Field -from utils.JWT.JWT import SECRET_KEY +from utils.JWT import SECRET_KEY from utils.conf import appmeta _ph = PasswordHasher()