修复登录后仍旧提示需要登录
This commit is contained in:
@@ -1,17 +1,15 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
from model.user import UserTypeEnum
|
from model.user import UserTypeEnum
|
||||||
from .user import get_current_user
|
from .user import get_current_user
|
||||||
from pkg import utils
|
from pkg import utils
|
||||||
from model import User
|
from model import User
|
||||||
from model import database
|
from middleware.dependencies import SessionDep
|
||||||
|
|
||||||
# 验证是否为管理员
|
# 验证是否为管理员
|
||||||
async def is_admin(
|
async def is_admin(
|
||||||
token: Annotated[str, Depends(get_current_user)],
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
|
|
||||||
) -> User:
|
) -> User:
|
||||||
'''
|
'''
|
||||||
验证是否为管理员。
|
验证是否为管理员。
|
||||||
@@ -19,16 +17,14 @@ async def is_admin(
|
|||||||
使用方法:
|
使用方法:
|
||||||
>>> APIRouter(dependencies=[Depends(is_admin)])
|
>>> APIRouter(dependencies=[Depends(is_admin)])
|
||||||
'''
|
'''
|
||||||
|
|
||||||
user = await get_current_user(token, session)
|
|
||||||
if user.role == UserTypeEnum.normal_user:
|
if user.role == UserTypeEnum.normal_user:
|
||||||
utils.raise_forbidden("Admin access required")
|
utils.raise_forbidden("Admin access required")
|
||||||
else:
|
else:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def is_super_admin(
|
async def is_super_admin(
|
||||||
token: Annotated[str, Depends(is_admin)],
|
user: Annotated[User, Depends(is_admin)],
|
||||||
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
|
|
||||||
) -> User:
|
) -> User:
|
||||||
'''
|
'''
|
||||||
验证是否为超级管理员。
|
验证是否为超级管理员。
|
||||||
@@ -37,7 +33,6 @@ async def is_super_admin(
|
|||||||
>>> APIRouter(dependencies=[Depends(is_super_admin)])
|
>>> APIRouter(dependencies=[Depends(is_super_admin)])
|
||||||
'''
|
'''
|
||||||
|
|
||||||
user = await get_current_user(token, session)
|
|
||||||
if user.role != UserTypeEnum.super_admin:
|
if user.role != UserTypeEnum.super_admin:
|
||||||
utils.raise_forbidden("Super admin access required")
|
utils.raise_forbidden("Super admin access required")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -3,16 +3,16 @@ from typing import Annotated
|
|||||||
import jwt
|
import jwt
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from jwt import InvalidTokenError
|
from jwt import InvalidTokenError
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from loguru import logger as l
|
||||||
|
|
||||||
import JWT
|
import JWT
|
||||||
from model import User
|
from model import User
|
||||||
from model.database import Database
|
|
||||||
from pkg import utils
|
from pkg import utils
|
||||||
|
from middleware.dependencies import SessionDep
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
token: Annotated[str, Depends(JWT.oauth2_scheme)],
|
token: Annotated[str, Depends(JWT.oauth2_scheme)],
|
||||||
session: Annotated[AsyncSession, Depends(Database.get_session)],
|
session: SessionDep,
|
||||||
) -> User:
|
) -> User:
|
||||||
"""
|
"""
|
||||||
验证用户身份并返回当前用户信息。
|
验证用户身份并返回当前用户信息。
|
||||||
@@ -20,9 +20,13 @@ async def get_current_user(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, await JWT.get_secret_key(), algorithms=[JWT.ALGORITHM])
|
payload = jwt.decode(token, await JWT.get_secret_key(), algorithms=[JWT.ALGORITHM])
|
||||||
username = payload.get("sub")
|
email = payload.get("sub")
|
||||||
stored_account = await User.get(session, User.email == username)
|
stored_account = await User.get(session, User.email == email)
|
||||||
if username is None or stored_account.email != username:
|
if stored_account is None:
|
||||||
|
l.warning("Account not found")
|
||||||
|
utils.raise_unauthorized("Login required")
|
||||||
|
elif stored_account.email != email:
|
||||||
|
l.warning("Email mismatch")
|
||||||
utils.raise_unauthorized("Login required")
|
utils.raise_unauthorized("Login required")
|
||||||
return stored_account
|
return stored_account
|
||||||
except InvalidTokenError:
|
except InvalidTokenError:
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ async def list_firmwares(
|
|||||||
if conditions:
|
if conditions:
|
||||||
results = await Firmware.get(session, and_(*conditions), fetch_mode="all")
|
results = await Firmware.get(session, and_(*conditions), fetch_mode="all")
|
||||||
else:
|
else:
|
||||||
results = await Firmware.get(session, fetch_mode="all")
|
results = await Firmware.get(session, None, fetch_mode="all")
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
return []
|
return []
|
||||||
|
|||||||
Reference in New Issue
Block a user