diff --git a/.idea/.gitignore b/.idea/.gitignore
deleted file mode 100644
index 35410ca..0000000
--- a/.idea/.gitignore
+++ /dev/null
@@ -1,8 +0,0 @@
-# 默认忽略的文件
-/shelf/
-/workspace.xml
-# 基于编辑器的 HTTP 客户端请求
-/httpRequests/
-# Datasource local storage ignored files
-/dataSources/
-/dataSources.local.xml
diff --git a/.idea/.name b/.idea/.name
deleted file mode 100644
index 233deed..0000000
--- a/.idea/.name
+++ /dev/null
@@ -1 +0,0 @@
-password.py
\ No newline at end of file
diff --git a/.idea/Findreve.iml b/.idea/Findreve.iml
deleted file mode 100644
index 916c239..0000000
--- a/.idea/Findreve.iml
+++ /dev/null
@@ -1,17 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/copilot.data.migration.agent.xml b/.idea/copilot.data.migration.agent.xml
deleted file mode 100644
index 4ea72a9..0000000
--- a/.idea/copilot.data.migration.agent.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/copilot.data.migration.ask.xml b/.idea/copilot.data.migration.ask.xml
deleted file mode 100644
index 7ef04e2..0000000
--- a/.idea/copilot.data.migration.ask.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/copilot.data.migration.ask2agent.xml b/.idea/copilot.data.migration.ask2agent.xml
deleted file mode 100644
index 1f2ea11..0000000
--- a/.idea/copilot.data.migration.ask2agent.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/copilot.data.migration.edit.xml b/.idea/copilot.data.migration.edit.xml
deleted file mode 100644
index 8648f94..0000000
--- a/.idea/copilot.data.migration.edit.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
deleted file mode 100644
index 105ce2d..0000000
--- a/.idea/inspectionProfiles/profiles_settings.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/material_theme_project_new.xml b/.idea/material_theme_project_new.xml
deleted file mode 100644
index d508618..0000000
--- a/.idea/material_theme_project_new.xml
+++ /dev/null
@@ -1,17 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
deleted file mode 100644
index 82554e2..0000000
--- a/.idea/misc.xml
+++ /dev/null
@@ -1,7 +0,0 @@
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
deleted file mode 100644
index cd62433..0000000
--- a/.idea/modules.xml
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
deleted file mode 100644
index 35eb1dd..0000000
--- a/.idea/vcs.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/AGENTS.md b/AGENTS.md
index 4a23f1a..f087fea 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -34,7 +34,7 @@ flowchart TD
```
## 编码风格与命名
-- 统一使用 Python 3.8+、四空格缩进,并在公共接口添加类型注解;仅对复杂逻辑补充文档字符串。
+- 统一使用 Python 3.13+、四空格缩进,并在公共接口添加类型注解;仅对复杂逻辑补充文档字符串。
- 函数使用 `snake_case`,数据模型使用 `PascalCase`,配置与日志归于 `pkg/`(`pkg/logger.py` 封装`loguru`)。
- 所有代码、注释、提交信息与评审讨论均使用简体中文。
diff --git a/README.md b/README.md
index 1fba819..adbaa69 100644
--- a/README.md
+++ b/README.md
@@ -61,18 +61,16 @@ chmod +x ./findreve
启动后, Findreve 会在程序的根目录自动创建 SQLite 数据库,并在
终端显示管理员账号密码。请注意,账号密码仅显示一次,请注意保管。
-账号默认为 `admin@yuxiaoqiu.cn`
+账号默认为 `admin@yxqi.cn`
Upon launch, Findreve will create a SQLite database in the project's root directory and
display the administrator's account and password in the console.
## 构建
-> 当前版本的 Findreve Core 无法正常工作,因为我们正在尝试[重构数据库组件以使用ORM](https://github.com/Findreve/Findreve/issues/8)
+你需要安装Python 3.13 以上的版本。然后,clone 本仓库到您的服务器并解压,然后安装下面的依赖:
-你需要安装Python 3.8 以上的版本。然后,clone 本仓库到您的服务器并解压,然后安装下面的依赖:
-
-You need to have Python 3.8 or higher installed on your server. Then, clone this repository
+You need to have Python 3.13 or higher installed on your server. Then, clone this repository
to your server and install the required dependencies:
> `pip install -r requirements.txt`
diff --git a/app.py b/app.py
index add63b9..614920a 100644
--- a/app.py
+++ b/app.py
@@ -1,6 +1,6 @@
from fastapi import FastAPI
from fastapi.responses import FileResponse
-from fastapi import Request, HTTPException
+from fastapi import Request
from contextlib import asynccontextmanager
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
@@ -11,6 +11,7 @@ from routes import (session, admin, object)
from model.database import Database
import os
import pkg.conf
+from pkg import utils
from loguru import logger
@@ -54,21 +55,21 @@ app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
@app.get("/")
-def read_root():
+async def frontend_index():
if not os.path.exists("dist/index.html"):
- raise HTTPException(status_code=404)
+ utils.raise_not_found("Index not found")
return FileResponse("dist/index.html")
# 回退路由
@app.get("/{path:path}")
-async def serve_spa(request: Request, path: str):
+async def frontend_path(path: str):
if not os.path.exists("dist/index.html"):
- raise HTTPException(status_code=404)
-
+ utils.raise_not_found("Index not found, please build frontend first.")
+
# 排除API路由
if path.startswith("api/"):
- raise HTTPException(status_code=404)
-
+ utils.raise_not_found("API route not found")
+
# 检查是否是静态资源请求
if path.startswith("assets/") and os.path.exists(f"dist/{path}"):
return FileResponse(f"dist/{path}")
diff --git a/middleware/admin.py b/middleware/admin.py
index 7c344ac..df04c8b 100644
--- a/middleware/admin.py
+++ b/middleware/admin.py
@@ -1,19 +1,18 @@
-from typing import Annotated, Literal
+from typing import Annotated
from fastapi import Depends
-from fastapi import HTTPException
-import JWT
-import jwt
-from jwt import InvalidTokenError
-from model import database
from sqlmodel.ext.asyncio.session import AsyncSession
-from model import User
+
+from model.user import UserTypeEnum
from .user import get_current_user
+from pkg import utils
+from model import User
+from model import database
# 验证是否为管理员
async def is_admin(
token: Annotated[str, Depends(get_current_user)],
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
-) -> Literal[True]:
+) -> User:
'''
验证是否为管理员。
@@ -21,14 +20,25 @@ async def is_admin(
>>> APIRouter(dependencies=[Depends(is_admin)])
'''
- not_admin_exception = HTTPException(
- status_code=403,
- detail="Admin access required",
- headers={"WWW-Authenticate": "Bearer"},
- )
+ user = await get_current_user(token, session)
+ if user.role == UserTypeEnum.normal_user:
+ utils.raise_forbidden("Admin access required")
+ else:
+ return user
+
+async def is_super_admin(
+ token: Annotated[str, Depends(is_admin)],
+ session: Annotated[AsyncSession, Depends(database.Database.get_session)],
+) -> User:
+ '''
+ 验证是否为超级管理员。
+
+ 使用方法:
+ >>> APIRouter(dependencies=[Depends(is_super_admin)])
+ '''
user = await get_current_user(token, session)
- if not user.is_admin:
- raise not_admin_exception
+ if user.role != UserTypeEnum.super_admin:
+ utils.raise_forbidden("Super admin access required")
else:
- return True
\ No newline at end of file
+ return user
\ No newline at end of file
diff --git a/middleware/user.py b/middleware/user.py
index 861c068..3ed2dc5 100644
--- a/middleware/user.py
+++ b/middleware/user.py
@@ -2,16 +2,14 @@ from typing import Annotated
import jwt
from fastapi import Depends
-from fastapi import HTTPException
from jwt import InvalidTokenError
from sqlmodel.ext.asyncio.session import AsyncSession
import JWT
from model import User
from model.database import Database
+from pkg import utils
-
-# 验证是否为管理员
async def get_current_user(
token: Annotated[str, Depends(JWT.oauth2_scheme)],
session: Annotated[AsyncSession, Depends(Database.get_session)],
@@ -19,18 +17,13 @@ async def get_current_user(
"""
验证用户身份并返回当前用户信息。
"""
- not_login_exception = HTTPException(
- status_code=401,
- detail="Login required",
- headers={"WWW-Authenticate": "Bearer"},
- )
try:
payload = jwt.decode(token, await JWT.get_secret_key(), algorithms=[JWT.ALGORITHM])
username = payload.get("sub")
stored_account = await User.get(session, User.email == username)
if username is None or stored_account.email != username:
- raise not_login_exception
+ utils.raise_unauthorized("Login required")
return stored_account
except InvalidTokenError:
- raise not_login_exception
\ No newline at end of file
+ utils.raise_unauthorized("Login required")
\ No newline at end of file
diff --git a/model/database.py b/model/database.py
index 96f8b6d..ec22950 100644
--- a/model/database.py
+++ b/model/database.py
@@ -3,7 +3,6 @@ from contextlib import asynccontextmanager
from typing import AsyncGenerator
import os
from dotenv import load_dotenv
-
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel
diff --git a/model/item.py b/model/item.py
index f09d1af..fe4e061 100644
--- a/model/item.py
+++ b/model/item.py
@@ -2,7 +2,6 @@ from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Optional
from uuid import UUID
-
from sqlmodel import Field, Relationship
from .base import SQLModelBase, UUIDTableBase
diff --git a/model/response.py b/model/response.py
index f4ea8e3..48aee67 100644
--- a/model/response.py
+++ b/model/response.py
@@ -2,7 +2,16 @@ from pydantic import BaseModel
from model.base import SQLModelBase
+"""
+[TODO] 弃用,改成 ResponseBase:
+class ResponseBase(BaseModel):
+ code: int = 0
+ msg: str = ""
+ request_id: UUID
+
+再根据需要继承
+"""
class DefaultResponse(BaseModel):
code: int = 0
data: dict | list | bool | SQLModelBase | None = None
diff --git a/model/user.py b/model/user.py
index dd8f60c..95beb17 100644
--- a/model/user.py
+++ b/model/user.py
@@ -23,8 +23,8 @@ class User(UserBase, UUIDTableBase, table=True):
email: EmailStr = Field(index=True, unique=True)
"""邮箱"""
- username: str = Field(index=True, unique=True)
- """用户名"""
+ nickname: str
+ """昵称"""
password: str
"""Argon2算法哈希后的密码"""
diff --git a/pkg/password.py b/pkg/password.py
index 4f302e2..e4bdc9e 100644
--- a/pkg/password.py
+++ b/pkg/password.py
@@ -2,18 +2,32 @@ import secrets
from loguru import logger
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
+from enum import StrEnum
_ph = PasswordHasher()
-class Password():
+class PasswordStatus(StrEnum):
+ """密码校验状态枚举"""
+
+ VALID = "valid"
+ """密码校验通过"""
+
+ INVALID = "invalid"
+ """密码校验失败"""
+
+ EXPIRED = "expired"
+ """密码哈希已过时,建议重新哈希"""
+
+class Password:
+ """密码处理工具类,包含密码生成、哈希和验证功能"""
@staticmethod
def generate(
- length: int = 8
+ length: int = 8
) -> str:
"""
生成指定长度的随机密码。
-
+
:param length: 密码长度
:type length: int
:return: 随机密码
@@ -23,7 +37,7 @@ class Password():
@staticmethod
def hash(
- password: str
+ password: str
) -> str:
"""
使用 Argon2 生成密码的哈希值。
@@ -37,38 +51,29 @@ class Password():
@staticmethod
def verify(
- stored_password: str,
- provided_password: str,
- debug: bool = False
- ) -> bool:
+ hash: str,
+ password: str
+ ) -> PasswordStatus:
"""
验证存储的 Argon2 哈希值与用户提供的密码是否匹配。
- :param stored_password: 数据库中存储的 Argon2 哈希字符串
- :param provided_password: 用户本次提供的密码
- :param debug: 是否输出调试信息
+ :param hash: 数据库中存储的 Argon2 哈希字符串
+ :param password: 用户本次提供的密码
:return: 如果密码匹配返回 True, 否则返回 False
"""
- if debug:
- logger.info(f"验证密码: (哈希) {stored_password}")
-
try:
# verify 函数会自动解析 stored_password 中的盐和参数
- _ph.verify(stored_password, provided_password)
+ _ph.verify(hash, password)
# 检查哈希参数是否已过时。如果返回True,
# 意味着你应该使用新的参数重新哈希密码并更新存储。
# 这是一个很好的实践,可以随着时间推移增强安全性。
- if _ph.check_needs_rehash(stored_password):
+ if _ph.check_needs_rehash(hash):
logger.warning("密码哈希参数已过时,建议重新哈希并更新。")
+ return PasswordStatus.EXPIRED
- return True
+ return PasswordStatus.VALID
except VerifyMismatchError:
# 这是预期的异常,当密码不匹配时触发。
- if debug:
- logger.info("密码不匹配")
- return False
- except Exception as e:
- # 捕获其他可能的错误
- logger.error(f"密码验证过程中发生未知错误: {e}")
- return False
\ No newline at end of file
+ return PasswordStatus.INVALID
+ # 其他异常(如哈希格式错误)应该传播,让调用方感知系统问题
diff --git a/pkg/utils.py b/pkg/utils.py
index 47a5571..8d42137 100644
--- a/pkg/utils.py
+++ b/pkg/utils.py
@@ -16,7 +16,7 @@ from starlette.status import (
HTTP_504_GATEWAY_TIMEOUT,
)
-# --- Request and Response Helpers ---
+# --- 400 ---
def ensure_request_param(to_check: Any, detail: str) -> None:
"""
@@ -30,21 +30,21 @@ def raise_bad_request(detail: str = '') -> NoReturn:
"""Raises an HTTP 400 Bad Request exception."""
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
-def raise_not_found(detail: str) -> NoReturn:
- """Raises an HTTP 404 Not Found exception."""
- raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail)
+def raise_unauthorized(detail: str) -> NoReturn:
+ """Raises an HTTP 401 Unauthorized exception."""
+ raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail)
-def raise_internal_error(detail: str = "服务器出现故障,请稍后再试或联系管理员") -> NoReturn:
- """Raises an HTTP 500 Internal Server Error exception."""
- raise HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
+def raise_insufficient_quota(detail: str = "积分不足,请充值") -> NoReturn:
+ """Raises an HTTP 402 Payment Required exception."""
+ raise HTTPException(status_code=HTTP_402_PAYMENT_REQUIRED, detail=detail)
def raise_forbidden(detail: str) -> NoReturn:
"""Raises an HTTP 403 Forbidden exception."""
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail)
-def raise_unauthorized(detail: str) -> NoReturn:
- """Raises an HTTP 401 Unauthorized exception."""
- raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail)
+def raise_not_found(detail: str) -> NoReturn:
+ """Raises an HTTP 404 Not Found exception."""
+ raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail)
def raise_conflict(detail: str) -> NoReturn:
"""Raises an HTTP 409 Conflict exception."""
@@ -54,6 +54,12 @@ def raise_too_many_requests(detail: str) -> NoReturn:
"""Raises an HTTP 429 Too Many Requests exception."""
raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail=detail)
+# --- 500 ---
+
+def raise_internal_error(detail: str = "服务器出现故障,请稍后再试或联系管理员") -> NoReturn:
+ """Raises an HTTP 500 Internal Server Error exception."""
+ raise HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
+
def raise_not_implemented(detail: str = "尚未支持这种方法") -> NoReturn:
"""Raises an HTTP 501 Not Implemented exception."""
raise HTTPException(status_code=HTTP_501_NOT_IMPLEMENTED, detail=detail)
@@ -65,8 +71,3 @@ def raise_service_unavailable(detail: str) -> NoReturn:
def raise_gateway_timeout(detail: str) -> NoReturn:
"""Raises an HTTP 504 Gateway Timeout exception."""
raise HTTPException(status_code=HTTP_504_GATEWAY_TIMEOUT, detail=detail)
-
-def raise_insufficient_quota(detail: str = "积分不足,请充值") -> NoReturn:
- raise HTTPException(status_code=HTTP_402_PAYMENT_REQUIRED, detail=detail)
-
-# --- End of Request and Response Helpers ---
diff --git a/routes/session.py b/routes/session.py
index 63bfb70..2db1e6c 100644
--- a/routes/session.py
+++ b/routes/session.py
@@ -1,13 +1,13 @@
# 导入库
from typing import Annotated
-
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel.ext.asyncio.session import AsyncSession
from model import database
from model.response import TokenResponse
from services import session as session_service
+from pkg import utils
Router = APIRouter(tags=["令牌 session"])
@@ -29,10 +29,6 @@ async def login_for_access_token(
password=form_data.password,
)
if not token_response:
- raise HTTPException(
- status_code=401,
- detail="Incorrect username or password",
- headers={"WWW-Authenticate": "Bearer"},
- )
+ utils.raise_unauthorized("Incorrect username or password")
return token_response
diff --git a/services/admin.py b/services/admin.py
index c71c9bc..4115964 100644
--- a/services/admin.py
+++ b/services/admin.py
@@ -4,11 +4,11 @@
from typing import Iterable, List
-from fastapi import HTTPException
from sqlmodel.ext.asyncio.session import AsyncSession
from model import Setting
from model import SettingResponse
+from pkg import utils
async def fetch_settings(
@@ -25,7 +25,7 @@ async def fetch_settings(
if setting:
data.append(SettingResponse.model_validate(setting))
else:
- raise HTTPException(404, detail="Setting not found")
+ utils.raise_not_found("Setting not found")
else:
settings: Iterable[Setting] | None = await Setting.get(session, fetch_mode="all")
if settings:
@@ -44,7 +44,7 @@ async def update_setting_value(
"""
setting = await Setting.get(session, Setting.name == name)
if not setting:
- raise HTTPException(404, detail="Setting not found")
+ utils.raise_not_found("Setting not found")
setting.value = value
await Setting.save(session)
diff --git a/services/object.py b/services/object.py
index 1ff7bdb..935fbec 100644
--- a/services/object.py
+++ b/services/object.py
@@ -5,15 +5,14 @@
from typing import List
from uuid import UUID
-from fastapi import HTTPException
+from fastapi import status
from loguru import logger
from sqlmodel.ext.asyncio.session import AsyncSession
from model import Item, ItemDataResponse, Setting, User
from model.item import ItemDataUpdateRequest, ItemTypeEnum
from pkg.sender import ServerChatBot, WeChatBot
-from pkg.utils import raise_bad_request, raise_internal_error, raise_not_found
-from starlette.status import HTTP_204_NO_CONTENT
+from pkg import utils
async def list_items(
@@ -72,7 +71,7 @@ async def create_item(
await Item.add(session, Item.model_validate(request_dict))
except Exception as exc: # noqa: BLE001
logger.error(f"Failed to add item: {exc}")
- raise HTTPException(status_code=500, detail=str(exc)) from exc
+ utils.raise_internal_error(str(exc))
async def update_item(
@@ -86,7 +85,7 @@ async def update_item(
"""
obj = await Item.get(session, (Item.id == item_id) & (Item.user_id == user.id))
if not obj:
- raise_not_found("Item not found or access denied")
+ utils.raise_not_found("Item not found or access denied")
await obj.update(session, request, exclude_unset=True)
@@ -101,7 +100,7 @@ async def delete_item(
"""
obj = await Item.get(session, (Item.id == item_id) & (Item.user_id == user.id))
if not obj:
- raise_not_found("Item not found or access denied")
+ utils.raise_not_found("Item not found or access denied")
await Item.delete(session, obj)
@@ -116,7 +115,7 @@ async def retrieve_object(
object_data = await Item.get(session, Item.id == item_id)
if not object_data:
- raise_not_found("物品不存在或出现异常")
+ utils.raise_not_found("物品不存在或出现异常")
if object_data.status == "lost":
object_data.find_ip = client_host
@@ -136,12 +135,12 @@ async def notify_move_car(
item_data = await Item.get_exist_one(session=session, id=item_id)
if item_data.type != ItemTypeEnum.car:
- raise_bad_request("Item is not car")
+ utils.raise_bad_request("Item is not car")
server_chan_key = await Setting.get(session, Setting.name == "server_chan_key")
wechat_bot_key = await Setting.get(session, Setting.name == "wechat_bot_key")
if not (server_chan_key.value or wechat_bot_key.value):
- raise_internal_error("未配置Server酱,无法发送挪车通知")
+ utils.raise_internal_error("未配置Server酱,无法发送挪车通知")
title = "挪车通知 - Findreve"
description = (
@@ -161,4 +160,4 @@ async def notify_move_car(
version="v1",
)
- return HTTP_204_NO_CONTENT
+ return status.HTTP_204_NO_CONTENT
diff --git a/services/session.py b/services/session.py
index ce98c23..1fb6fe4 100644
--- a/services/session.py
+++ b/services/session.py
@@ -3,32 +3,25 @@
"""
from datetime import datetime, timedelta, timezone
-from typing import Any
-
-import JWT
-import jwt
-from loguru import logger
from sqlmodel.ext.asyncio.session import AsyncSession
+from typing import Any
+import jwt
from model import Setting, User
from model.response import TokenResponse
-from pkg import Password
-
+from pkg import Password, utils
+import JWT
async def create_access_token(
session: AsyncSession,
data: dict[str, Any],
- expires_delta: timedelta | None = None,
) -> str:
"""
创建访问令牌。
"""
to_encode = data.copy()
- if expires_delta:
- expire = datetime.now(timezone.utc) + expires_delta
- else:
- jwt_exp_setting = await Setting.get(session, Setting.name == "jwt_token_exp")
- expire = datetime.now(timezone.utc) + timedelta(int(jwt_exp_setting.value))
+ jwt_exp_setting = await Setting.get(session, Setting.name == "jwt_token_exp")
+ expire = datetime.now(timezone.utc) + timedelta(int(jwt_exp_setting.value))
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, key=await JWT.get_secret_key(), algorithm="HS256")
return encoded_jwt
@@ -38,19 +31,14 @@ async def authenticate_user(
session: AsyncSession,
username: str,
password: str,
-) -> User | None:
+) -> User:
"""
验证用户名和密码,返回认证后的用户。
"""
account = await User.get(session, User.email == username)
- if not account:
- logger.error("Account or password not set in settings.")
- return None
-
- if account.email != username or not Password.verify(account.password, password):
- logger.error("Invalid username or password.")
- return None
+ if not account or account.email != username or not Password.verify(account.password, password):
+ utils.raise_unauthorized("Account or password is incorrect")
return account
@@ -59,13 +47,11 @@ async def login_for_access_token(
session: AsyncSession,
username: str,
password: str,
-) -> TokenResponse | None:
+) -> TokenResponse:
"""
登录并生成访问令牌。
"""
user = await authenticate_user(session=session, username=username, password=password)
- if not user:
- return None
access_token = await create_access_token(
session=session,