Refactor JWT utilities and download token logic

Merged JWT utility functions into utils/JWT/__init__.py and removed utils/JWT/JWT.py. Refactored download token creation and verification to use new functions, replacing DownloadTokenManager with create_download_token and verify_download_token. Updated imports across the codebase to reflect the new JWT utility structure. Improved download file logic to use physical file storage path and added a dedicated response model for download tokens.
This commit is contained in:
2025-12-26 17:47:51 +08:00
parent 54784eea3b
commit 3088a9d548
8 changed files with 168 additions and 151 deletions

View File

@@ -7,7 +7,7 @@ from utils.http.http_exceptions import raise_internal_error
from utils.lifespan import lifespan from utils.lifespan import lifespan
from models.database import init_db from models.database import init_db
from models.migration import migration from models.migration import migration
from utils.JWT import JWT from utils import JWT
from routers import router from routers import router
from loguru import logger as l from loguru import logger as l

View File

@@ -5,7 +5,7 @@ from fastapi import Depends
import jwt import jwt
from models.user import User from models.user import User
from utils.JWT import JWT from utils import JWT
from .dependencies import SessionDep from .dependencies import SessionDep
from utils import http_exceptions from utils import http_exceptions
@@ -48,3 +48,19 @@ async def admin_required(
if group.admin: if group.admin:
return user return user
raise http_exceptions.raise_forbidden("Admin Required") raise http_exceptions.raise_forbidden("Admin Required")
def verify_download_token(token: str) -> tuple[UUID, UUID] | None:
"""
验证下载令牌并返回 (file_id, owner_id)。
:param token: JWT 令牌字符串
:return: (file_id, owner_id) 或 None验证失败
"""
try:
payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"])
if payload.get("type") != "download":
return None
return UUID(payload["file_id"]), UUID(payload["owner_id"])
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
return None

View File

@@ -8,16 +8,15 @@
- /file/upload - 上传相关操作 - /file/upload - 上传相关操作
- /file/download - 下载相关操作 - /file/download - 下载相关操作
""" """
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta
from typing import Annotated from typing import Annotated
from uuid import UUID from uuid import UUID
import jwt
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from loguru import logger as l from loguru import logger as l
from middleware.auth import auth_required from middleware.auth import auth_required, verify_download_token
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import ( from models import (
CreateFileRequest, CreateFileRequest,
@@ -34,43 +33,20 @@ from models import (
User, User,
) )
from service.storage import LocalStorageService from service.storage import LocalStorageService
from utils.JWT import SECRET_KEY from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL
from utils import http_exceptions from utils import http_exceptions
# ==================== 下载令牌管理 ==================== # DTO
class DownloadTokenManager: class DownloadTokenModel(ResponseBase):
"""下载令牌管理器JWT 无状态)""" """下载Token响应模型"""
_ttl: timedelta = timedelta(hours=1) access_token: str
"""JWT 令牌"""
@classmethod
def create(cls, file_id: UUID, owner_id: int) -> str:
"""创建下载令牌"""
payload = {
"file_id": str(file_id),
"owner_id": owner_id,
"exp": datetime.now(timezone.utc) + cls._ttl,
"type": "download",
}
return jwt.encode(payload, SECRET_KEY, algorithm="HS256")
@classmethod
def verify(cls, token: str) -> tuple[UUID, int] | None:
"""
验证令牌并返回 (file_id, owner_id)
:return: (file_id, owner_id) 或 None验证失败
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
if payload.get("type") != "download":
return None
return UUID(payload["file_id"]), payload["owner_id"]
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
return None
expires_in: int
"""过期时间(秒)"""
# ==================== 主路由 ==================== # ==================== 主路由 ====================
@@ -367,11 +343,11 @@ _download_router = APIRouter(prefix="/download")
summary='创建下载令牌', summary='创建下载令牌',
description='为指定文件创建下载令牌JWT有效期1小时。', description='为指定文件创建下载令牌JWT有效期1小时。',
) )
async def create_download_token( async def create_download_token_endpoint(
session: SessionDep, session: SessionDep,
user: Annotated[User, Depends(auth_required)], user: Annotated[User, Depends(auth_required)],
file_id: UUID, file_id: UUID,
) -> ResponseBase: ) -> DownloadTokenModel:
""" """
创建下载令牌端点 创建下载令牌端点
@@ -384,11 +360,11 @@ async def create_download_token(
if not file_obj.is_file: if not file_obj.is_file:
raise HTTPException(status_code=400, detail="对象不是文件") raise HTTPException(status_code=400, detail="对象不是文件")
token = DownloadTokenManager.create(file_id, user.id) token = create_download_token(file_id, user.id)
l.debug(f"创建下载令牌: file_id={file_id}, user_id={user.id}") l.debug(f"创建下载令牌: file_id={file_id}, user_id={user.id}")
return ResponseBase(data={"token": token, "expires_in": 3600}) return DownloadTokenModel(access_token=token, expires_in=int(DOWNLOAD_TOKEN_TTL.total_seconds()))
@_download_router.get( @_download_router.get(
@@ -406,7 +382,7 @@ async def download_file(
验证 JWT 令牌后返回文件内容。 验证 JWT 令牌后返回文件内容。
""" """
# 验证令牌 # 验证令牌
result = DownloadTokenManager.verify(token) result = verify_download_token(token)
if not result: if not result:
raise HTTPException(status_code=401, detail="下载令牌无效或已过期") raise HTTPException(status_code=401, detail="下载令牌无效或已过期")
@@ -420,9 +396,13 @@ async def download_file(
if not file_obj.is_file: if not file_obj.is_file:
raise HTTPException(status_code=400, detail="对象不是文件") raise HTTPException(status_code=400, detail="对象不是文件")
if not file_obj.source_name: # 预加载 physical_file 关系以获取存储路径
physical_file = await file_obj.awaitable_attrs.physical_file
if not physical_file or not physical_file.storage_path:
raise HTTPException(status_code=500, detail="文件存储路径丢失") raise HTTPException(status_code=500, detail="文件存储路径丢失")
storage_path = physical_file.storage_path
# 获取策略 # 获取策略
policy = await Policy.get(session, Policy.id == file_obj.policy_id) policy = await Policy.get(session, Policy.id == file_obj.policy_id)
if not policy: if not policy:
@@ -430,11 +410,11 @@ async def download_file(
if policy.type == PolicyType.LOCAL: if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy) storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(file_obj.source_name): if not await storage_service.file_exists(storage_path):
raise HTTPException(status_code=404, detail="物理文件不存在") raise HTTPException(status_code=404, detail="物理文件不存在")
return FileResponse( return FileResponse(
path=file_obj.source_name, path=storage_path,
filename=file_obj.name, filename=file_obj.name,
media_type="application/octet-stream", media_type="application/octet-stream",
) )

View File

@@ -13,7 +13,7 @@ import models
import service import service
from middleware.auth import auth_required from middleware.auth import auth_required
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from utils.JWT.JWT import SECRET_KEY from utils.JWT import SECRET_KEY
from utils import Password, http_exceptions from utils import Password, http_exceptions
user_router = APIRouter( user_router = APIRouter(

View File

@@ -5,7 +5,7 @@ from loguru import logger
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import LoginRequest, TokenResponse, User from models import LoginRequest, TokenResponse, User
from utils import http_exceptions from utils import http_exceptions
from utils.JWT.JWT import create_access_token, create_refresh_token from utils.JWT import create_access_token, create_refresh_token
from utils.password.pwd import Password, PasswordStatus from utils.password.pwd import Password, PasswordStatus

View File

@@ -1,94 +0,0 @@
from datetime import datetime, timedelta, timezone
import jwt
from fastapi.security import OAuth2PasswordBearer
from models import AccessTokenBase, RefreshTokenBase
oauth2_scheme = OAuth2PasswordBearer(
scheme_name='获取 JWT Bearer 令牌',
description='用于获取 JWT Bearer 令牌,需要以表单的形式提交',
tokenUrl="/api/v1/user/session",
refreshUrl="/api/v1/user/session/refresh",
)
SECRET_KEY = ''
async def load_secret_key() -> None:
"""
从数据库读取 JWT 的密钥。
"""
# 延迟导入以避免循环依赖
from models.database import get_session
from models.setting import Setting
global SECRET_KEY
async for session in get_session():
setting = await Setting.get(
session,
(Setting.type == "auth") & (Setting.name == "secret_key")
)
if setting:
SECRET_KEY = setting.value
def build_token_payload(
data: dict,
is_refresh: bool,
algorithm: str,
expires_delta: timedelta | None = None,
) -> tuple[str, datetime]:
"""构建令牌"""
to_encode = data.copy()
if is_refresh:
to_encode.update({"token_type": "refresh"})
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
elif is_refresh:
expire = datetime.now(timezone.utc) + timedelta(days=30)
else:
expire = datetime.now(timezone.utc) + timedelta(hours=3)
to_encode.update({
"iat": int(datetime.now(timezone.utc).timestamp()),
"exp": int(expire.timestamp())
})
return jwt.encode(to_encode, SECRET_KEY, algorithm=algorithm), expire
# 访问令牌
def create_access_token(data: dict, expires_delta: timedelta | None = None, algorithm: str = "HS256") -> AccessTokenBase:
"""
生成访问令牌,默认有效期 3 小时。
:param data: 需要放进 JWT Payload 的字段。
:param expires_delta: 过期时间, 缺省时为 3 小时。
:param algorithm: JWT 密钥强度,缺省时为 HS256
:return: 包含密钥本身和过期时间的 `AccessTokenBase`
"""
access_token, expire_at = build_token_payload(data, False, algorithm, expires_delta)
return AccessTokenBase(
access_token=access_token,
access_expires=expire_at,
)
# 刷新令牌
def create_refresh_token(data: dict, expires_delta: timedelta | None = None, algorithm: str = "HS256") -> RefreshTokenBase:
"""
生成刷新令牌,默认有效期 30 天。
:param data: 需要放进 JWT Payload 的字段。
:param expires_delta: 过期时间, 缺省时为 30 天。
:param algorithm: JWT 密钥强度,缺省时为 HS256
:return: 包含密钥本身和过期时间的 `RefreshTokenBase`
"""
refresh_token, expire_at = build_token_payload(data, True, algorithm, expires_delta)
return RefreshTokenBase(
refresh_token=refresh_token,
refresh_expires=expire_at,
)

View File

@@ -1,8 +1,123 @@
from . import JWT from datetime import datetime, timedelta, timezone
from .JWT import ( from uuid import UUID
create_access_token,
create_refresh_token, import jwt
load_secret_key, from fastapi.security import OAuth2PasswordBearer
oauth2_scheme,
SECRET_KEY, from models import AccessTokenBase, RefreshTokenBase
oauth2_scheme = OAuth2PasswordBearer(
scheme_name='获取 JWT Bearer 令牌',
description='用于获取 JWT Bearer 令牌,需要以表单的形式提交',
tokenUrl="/api/v1/user/session",
refreshUrl="/api/v1/user/session/refresh",
) )
SECRET_KEY = ''
async def load_secret_key() -> None:
"""
从数据库读取 JWT 的密钥。
"""
# 延迟导入以避免循环依赖
from models.database import get_session
from models.setting import Setting
global SECRET_KEY
async for session in get_session():
setting = await Setting.get(
session,
(Setting.type == "auth") & (Setting.name == "secret_key")
)
if setting:
SECRET_KEY = setting.value
def build_token_payload(
data: dict,
is_refresh: bool,
algorithm: str,
expires_delta: timedelta | None = None,
) -> tuple[str, datetime]:
"""构建令牌"""
to_encode = data.copy()
if is_refresh:
to_encode.update({"token_type": "refresh"})
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
elif is_refresh:
expire = datetime.now(timezone.utc) + timedelta(days=30)
else:
expire = datetime.now(timezone.utc) + timedelta(hours=3)
to_encode.update({
"iat": int(datetime.now(timezone.utc).timestamp()),
"exp": int(expire.timestamp())
})
return jwt.encode(to_encode, SECRET_KEY, algorithm=algorithm), expire
# 访问令牌
def create_access_token(data: dict, expires_delta: timedelta | None = None,
algorithm: str = "HS256") -> AccessTokenBase:
"""
生成访问令牌,默认有效期 3 小时。
:param data: 需要放进 JWT Payload 的字段。
:param expires_delta: 过期时间, 缺省时为 3 小时。
:param algorithm: JWT 密钥强度,缺省时为 HS256
:return: 包含密钥本身和过期时间的 `AccessTokenBase`
"""
access_token, expire_at = build_token_payload(data, False, algorithm, expires_delta)
return AccessTokenBase(
access_token=access_token,
access_expires=expire_at,
)
# 刷新令牌
def create_refresh_token(data: dict, expires_delta: timedelta | None = None,
algorithm: str = "HS256") -> RefreshTokenBase:
"""
生成刷新令牌,默认有效期 30 天。
:param data: 需要放进 JWT Payload 的字段。
:param expires_delta: 过期时间, 缺省时为 30 天。
:param algorithm: JWT 密钥强度,缺省时为 HS256
:return: 包含密钥本身和过期时间的 `RefreshTokenBase`
"""
refresh_token, expire_at = build_token_payload(data, True, algorithm, expires_delta)
return RefreshTokenBase(
refresh_token=refresh_token,
refresh_expires=expire_at,
)
# ==================== 下载令牌 ====================
DOWNLOAD_TOKEN_TTL = timedelta(hours=1)
"""下载令牌有效期"""
def create_download_token(file_id: UUID, owner_id: UUID) -> str:
"""
创建文件下载令牌。
:param file_id: 文件 ID
:param owner_id: 文件所有者 ID
:return: JWT 令牌字符串
"""
payload = {
"file_id": str(file_id),
"owner_id": str(owner_id),
"exp": datetime.now(timezone.utc) + DOWNLOAD_TOKEN_TTL,
"type": "download",
}
return jwt.encode(payload, SECRET_KEY, algorithm="HS256")

View File

@@ -8,7 +8,7 @@ import pyotp
from itsdangerous import URLSafeTimedSerializer from itsdangerous import URLSafeTimedSerializer
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from utils.JWT.JWT import SECRET_KEY from utils.JWT import SECRET_KEY
from utils.conf import appmeta from utils.conf import appmeta
_ph = PasswordHasher() _ph = PasswordHasher()