diff --git a/middleware/admin.py b/middleware/admin.py index df04c8b..125efd5 100644 --- a/middleware/admin.py +++ b/middleware/admin.py @@ -1,17 +1,15 @@ from typing import Annotated from fastapi import Depends -from sqlmodel.ext.asyncio.session import AsyncSession from model.user import UserTypeEnum from .user import get_current_user from pkg import utils from model import User -from model import database +from middleware.dependencies import SessionDep # 验证是否为管理员 async def is_admin( - token: Annotated[str, Depends(get_current_user)], - session: Annotated[AsyncSession, Depends(database.Database.get_session)], + user: Annotated[User, Depends(get_current_user)], ) -> User: ''' 验证是否为管理员。 @@ -19,16 +17,14 @@ async def is_admin( 使用方法: >>> APIRouter(dependencies=[Depends(is_admin)]) ''' - - user = await get_current_user(token, session) + if user.role == UserTypeEnum.normal_user: utils.raise_forbidden("Admin access required") else: return user async def is_super_admin( - token: Annotated[str, Depends(is_admin)], - session: Annotated[AsyncSession, Depends(database.Database.get_session)], + user: Annotated[User, Depends(is_admin)], ) -> User: ''' 验证是否为超级管理员。 @@ -37,7 +33,6 @@ async def is_super_admin( >>> APIRouter(dependencies=[Depends(is_super_admin)]) ''' - user = await get_current_user(token, session) if user.role != UserTypeEnum.super_admin: utils.raise_forbidden("Super admin access required") else: diff --git a/middleware/user.py b/middleware/user.py index 3ed2dc5..1c962e8 100644 --- a/middleware/user.py +++ b/middleware/user.py @@ -3,16 +3,16 @@ from typing import Annotated import jwt from fastapi import Depends from jwt import InvalidTokenError -from sqlmodel.ext.asyncio.session import AsyncSession +from loguru import logger as l import JWT from model import User -from model.database import Database from pkg import utils +from middleware.dependencies import SessionDep async def get_current_user( token: Annotated[str, Depends(JWT.oauth2_scheme)], - session: Annotated[AsyncSession, Depends(Database.get_session)], + session: SessionDep, ) -> User: """ 验证用户身份并返回当前用户信息。 @@ -20,9 +20,13 @@ async def get_current_user( try: payload = jwt.decode(token, await JWT.get_secret_key(), algorithms=[JWT.ALGORITHM]) - username = payload.get("sub") - stored_account = await User.get(session, User.email == username) - if username is None or stored_account.email != username: + email = payload.get("sub") + stored_account = await User.get(session, User.email == email) + if stored_account is None: + l.warning("Account not found") + utils.raise_unauthorized("Login required") + elif stored_account.email != email: + l.warning("Email mismatch") utils.raise_unauthorized("Login required") return stored_account except InvalidTokenError: diff --git a/services/admin.py b/services/admin.py index 0618377..3ad143a 100644 --- a/services/admin.py +++ b/services/admin.py @@ -180,7 +180,7 @@ async def list_firmwares( if conditions: results = await Firmware.get(session, and_(*conditions), fetch_mode="all") else: - results = await Firmware.get(session, fetch_mode="all") + results = await Firmware.get(session, None, fetch_mode="all") if not results: return []