From cd35c6fbedbf51fe3c167b5ab2d7da68c2c66943 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=8E=E5=B0=8F=E4=B8=98?= Date: Sun, 5 Oct 2025 18:58:46 +0800 Subject: [PATCH] Refactor models and routes for item management MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reorganized model structure by replacing 'object' and 'items' with a unified 'item' model using UUIDs, and moved base model logic into separate files. Updated routes to use the new item model and improved request/response handling. Enhanced user and setting models, added utility functions, and improved error handling throughout the codebase. Also added initial .idea project files and minor admin API improvements. Co-Authored-By: 砂糖橘 <54745033+Foxerine@users.noreply.github.com> --- .idea/.gitignore | 8 + .idea/Findreve.iml | 17 ++ .idea/copilot.data.migration.agent.xml | 6 + .idea/copilot.data.migration.ask.xml | 6 + .idea/copilot.data.migration.edit.xml | 6 + .../inspectionProfiles/profiles_settings.xml | 6 + .idea/material_theme_project_new.xml | 17 ++ .idea/misc.xml | 7 + .idea/modules.xml | 8 + .idea/vcs.xml | 6 + app.py | 31 ++- main.py | 2 +- middleware/user.py | 19 +- model/__init__.py | 9 +- model/base.py | 151 ----------- model/base/__init__.py | 3 + model/base/sqlmodel_base.py | 5 + model/base/table_base.py | 200 ++++++++++++++ model/database.py | 2 +- model/item.py | 77 ++++++ model/items.py | 14 - model/migration.py | 56 ++-- model/object.py | 45 ---- model/response.py | 23 +- model/setting.py | 21 +- model/token.py | 9 - model/user.py | 89 ++++++- pkg/__init__.py | 3 +- pkg/sms/smsbao.py | 2 + pkg/utils.py | 75 ++++++ routes/admin.py | 57 +++- routes/object.py | 247 ++++++------------ routes/session.py | 26 +- routes/site.py | 20 ++ 34 files changed, 782 insertions(+), 491 deletions(-) create mode 100644 .idea/.gitignore create mode 100644 .idea/Findreve.iml create mode 100644 .idea/copilot.data.migration.agent.xml create mode 100644 .idea/copilot.data.migration.ask.xml create mode 100644 .idea/copilot.data.migration.edit.xml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/material_theme_project_new.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml delete mode 100644 model/base.py create mode 100644 model/base/__init__.py create mode 100644 model/base/sqlmodel_base.py create mode 100644 model/base/table_base.py create mode 100644 model/item.py delete mode 100644 model/items.py delete mode 100644 model/object.py delete mode 100644 model/token.py create mode 100644 pkg/sms/smsbao.py create mode 100644 pkg/utils.py create mode 100644 routes/site.py diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..35410ca --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/Findreve.iml b/.idea/Findreve.iml new file mode 100644 index 0000000..916c239 --- /dev/null +++ b/.idea/Findreve.iml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/copilot.data.migration.agent.xml b/.idea/copilot.data.migration.agent.xml new file mode 100644 index 0000000..4ea72a9 --- /dev/null +++ b/.idea/copilot.data.migration.agent.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/.idea/copilot.data.migration.ask.xml b/.idea/copilot.data.migration.ask.xml new file mode 100644 index 0000000..7ef04e2 --- /dev/null +++ b/.idea/copilot.data.migration.ask.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/.idea/copilot.data.migration.edit.xml b/.idea/copilot.data.migration.edit.xml new file mode 100644 index 0000000..8648f94 --- /dev/null +++ b/.idea/copilot.data.migration.edit.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/material_theme_project_new.xml b/.idea/material_theme_project_new.xml new file mode 100644 index 0000000..d508618 --- /dev/null +++ b/.idea/material_theme_project_new.xml @@ -0,0 +1,17 @@ + + + + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..82554e2 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..cd62433 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/app.py b/app.py index 93eddb2..add63b9 100644 --- a/app.py +++ b/app.py @@ -5,18 +5,21 @@ from contextlib import asynccontextmanager from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded + +from pkg.utils import raise_internal_error from routes import (session, admin, object) -import model.database -import os, asyncio +from model.database import Database +import os import pkg.conf -# 初始化数据库 -asyncio.run(model.database.Database().init_db()) +from loguru import logger + +Router = [admin, session, object] # Findreve 的生命周期 @asynccontextmanager async def lifespan(app: FastAPI): - await model.database.Database().init_db() + await Database().init_db() yield # 定义 Findreve 服务器 @@ -28,10 +31,22 @@ app = FastAPI( lifespan=lifespan ) +@app.exception_handler(Exception) +async def handle_unexpected_exceptions(request: Request, exc: Exception): + """ + 捕获所有未经处理的异常,防止敏感信息泄露。 + """ + # 1. 为开发人员记录详细的、包含完整堆栈跟踪的错误日志 + logger.exception( + f"An unhandled exception occurred for request: {request.method} {request.url.path}" + ) + + raise_internal_error() + + # 挂载后端路由 -app.include_router(admin.Router) -app.include_router(session.Router) -app.include_router(object.Router) +for router in Router: + app.include_router(router.Router) # 挂载Slowapi限流中间件 limiter = Limiter(key_func=get_remote_address) diff --git a/main.py b/main.py index 89cf230..9d5fa2a 100644 --- a/main.py +++ b/main.py @@ -32,4 +32,4 @@ if __name__ == '__main__': port=port, log_config=None, # 禁用 uvicorn 默认的日志配置,使用 loguru reload=debug, # 调试模式下启用热重载 - ) \ No newline at end of file + ) diff --git a/middleware/user.py b/middleware/user.py index c9e2486..861c068 100644 --- a/middleware/user.py +++ b/middleware/user.py @@ -1,21 +1,24 @@ -from typing import Annotated, Literal +from typing import Annotated + +import jwt from fastapi import Depends from fastapi import HTTPException -import JWT -import jwt from jwt import InvalidTokenError -from model import database from sqlmodel.ext.asyncio.session import AsyncSession + +import JWT from model import User +from model.database import Database + # 验证是否为管理员 async def get_current_user( token: Annotated[str, Depends(JWT.oauth2_scheme)], - session: Annotated[AsyncSession, Depends(database.Database.get_session)], + session: Annotated[AsyncSession, Depends(Database.get_session)], ) -> User: - ''' + """ 验证用户身份并返回当前用户信息。 - ''' + """ not_login_exception = HTTPException( status_code=401, detail="Login required", @@ -26,7 +29,7 @@ async def get_current_user( payload = jwt.decode(token, await JWT.get_secret_key(), algorithms=[JWT.ALGORITHM]) username = payload.get("sub") stored_account = await User.get(session, User.email == username) - if username is None or not stored_account.email == username: + if username is None or stored_account.email != username: raise not_login_exception return stored_account except InvalidTokenError: diff --git a/model/__init__.py b/model/__init__.py index 5b4e3be..58046c2 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -1,4 +1,5 @@ -from . import token -from .setting import Setting -from .object import Object -from .user import User \ No newline at end of file +from .response import DefaultResponse, TokenResponse, TokenData +from .setting import Setting, SettingResponse +from .item import Item, ItemDataResponse, ItemTypeEnum, ItemStatusEnum +from .user import User, UserTypeEnum +from .database import Database diff --git a/model/base.py b/model/base.py deleted file mode 100644 index 830a432..0000000 --- a/model/base.py +++ /dev/null @@ -1,151 +0,0 @@ -# model/base.py -from datetime import datetime, timezone -from typing import Type, TypeVar, Union, Literal, List - -from sqlalchemy import DateTime, BinaryExpression, ClauseElement -from sqlalchemy.orm import selectinload -from sqlalchemy.ext.asyncio.session import AsyncSession -from sqlalchemy.ext.asyncio import AsyncAttrs -from sqlmodel import SQLModel, Field, select, Relationship -from sqlalchemy.sql._typing import _OnClauseArgument - -B = TypeVar('B', bound='TableBase') -M = TypeVar('M', bound='SQLModel') - -utcnow = lambda: datetime.now(tz=timezone.utc) - -class TableBase(AsyncAttrs, SQLModel): - __abstract__ = True - - created_at: datetime = Field( - default_factory=utcnow, - description="创建时间", - ) - updated_at: datetime = Field( - sa_type=DateTime, - description="更新时间", - sa_column_kwargs={"default": utcnow, "onupdate": utcnow}, - default_factory=utcnow - ) - deleted_at: datetime | None = Field( - default=None, - description="删除时间", - sa_column={"nullable": True} - ) - - @classmethod - async def add( - cls: Type[B], - session: AsyncSession, - instances: B | List[B], - refresh: bool = True - ) -> B | List[B]: - is_list = isinstance(instances, list) - if is_list: - session.add_all(instances) - else: - session.add(instances) - await session.commit() - if refresh: - if is_list: - for i in instances: - await session.refresh(i) - else: - await session.refresh(instances) - return instances - - async def save( - self: B, - session: AsyncSession, - load: Union[Relationship, None] = None, # 设默认值,避免必须传 - ): - session.add(self) - await session.commit() - if load is not None: - cls = type(self) - return await cls.get(session, cls.id == self.id, load=load) # 若该模型没有 id,请别用 load 模式 - else: - await session.refresh(self) - return self - - async def update( - self: B, - session: AsyncSession, - other: M, - extra_data: dict = None, - exclude_unset: bool = True, - ) -> B: - self.sqlmodel_update( - other.model_dump(exclude_unset=exclude_unset), - update=extra_data - ) - session.add(self) - await session.commit() - await session.refresh(self) - return self - - @classmethod - async def delete( - cls: Type[B], - session: AsyncSession, - instance: B | list[B], - ) -> None: - if isinstance(instance, list): - for inst in instance: - await session.delete(inst) - else: - await session.delete(instance) - await session.commit() - - @classmethod - async def get( - cls: Type[B], - session: AsyncSession, - condition: BinaryExpression | ClauseElement | None, - *, - offset: int | None = None, - limit: int | None = None, - fetch_mode: Literal["one", "first", "all"] = "first", - join: Type[B] | tuple[Type[B], _OnClauseArgument] | None = None, - options: list | None = None, - load: Union[Relationship, None] = None, - order_by: list[ClauseElement] | None = None - ) -> B | List[B] | None: - statement = select(cls) - if condition is not None: - statement = statement.where(condition) - if join is not None: - statement = statement.join(*join) - if options: - statement = statement.options(*options) - if load: - statement = statement.options(selectinload(load)) - if order_by is not None: - statement = statement.order_by(*order_by) - if offset: - statement = statement.offset(offset) - if limit: - statement = statement.limit(limit) - - result = await session.exec(statement) - if fetch_mode == "one": - return result.one() - elif fetch_mode == "first": - return result.first() - elif fetch_mode == "all": - return list(result.all()) - else: - raise ValueError(f"无效的 fetch_mode: {fetch_mode}") - - @classmethod - async def get_exist_one(cls: Type[B], session: AsyncSession, id: int, load: Union[Relationship, None] = None) -> B: - instance = await cls.get(session, cls.id == id, load=load) - if not instance: - from fastapi import HTTPException - raise HTTPException(status_code=404, detail="Not found") - return instance - - -# 需要“自增 id 主键”的模型才混入它;Setting 不混入 -class IdMixin(SQLModel): - id: int | None = Field(default=None, primary_key=True, description="主键ID") \ No newline at end of file diff --git a/model/base/__init__.py b/model/base/__init__.py new file mode 100644 index 0000000..5f464b4 --- /dev/null +++ b/model/base/__init__.py @@ -0,0 +1,3 @@ +from .sqlmodel_base import SQLModelBase +from .table_base import TableBase, UUIDTableBase + diff --git a/model/base/sqlmodel_base.py b/model/base/sqlmodel_base.py new file mode 100644 index 0000000..0a42a35 --- /dev/null +++ b/model/base/sqlmodel_base.py @@ -0,0 +1,5 @@ +from pydantic import ConfigDict +from sqlmodel import SQLModel + +class SQLModelBase(SQLModel): + model_config = ConfigDict(use_attribute_docstrings=True) diff --git a/model/base/table_base.py b/model/base/table_base.py new file mode 100644 index 0000000..5478ff0 --- /dev/null +++ b/model/base/table_base.py @@ -0,0 +1,200 @@ +import uuid +from datetime import datetime, timezone +from typing import Union, TypeVar, Type, Literal, override, Optional, Any + +from fastapi import HTTPException +from sqlalchemy import DateTime, BinaryExpression, ClauseElement +from sqlalchemy.orm import selectinload +from sqlmodel import Field, select, Relationship, SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession +from sqlalchemy.sql._typing import _OnClauseArgument +from sqlalchemy.ext.asyncio import AsyncAttrs + +T = TypeVar("T", bound="TableBase") +M = TypeVar("M", bound="SQLModel") + +now = lambda: datetime.now() +now_date = lambda: datetime.now().date() + +class TableBase(AsyncAttrs): + id: int | None = Field(default=None, primary_key=True) + + created_at: datetime = Field(default_factory=now) + updated_at: datetime = Field( + sa_type=DateTime, + sa_column_kwargs={"default": now, "onupdate": now}, + default_factory=now + ) + + @classmethod + async def add(cls: Type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]: + """ + 新增一条记录 + :param session: 数据库会话 + :param instances: + :param refresh: + :return: 新增的实例对象 + + usage: + item1 = Item(...) + item2 = Item(...) + + Item.add(session, [item1, item2]) + + item1_id = item1.id + """ + is_list = False + if isinstance(instances, list): + is_list = True + session.add_all(instances) + else: + session.add(instances) + + await session.commit() + + if refresh: + if is_list: + for instance in instances: + await session.refresh(instance) + else: + await session.refresh(instances) + + return instances + + async def save(self: T, session: AsyncSession, load: Optional[Relationship] = None) -> T: + session.add(self) + await session.commit() + + if load is not None: + cls = type(self) + return await cls.get(session, cls.id == self.id, load=load) + else: + await session.refresh(self) + return self + + async def update( + self: T, + session: AsyncSession, + other: M, + extra_data: dict = None, + exclude_unset: bool = True + ) -> T: + """ + 更新记录 + :param session: 数据库会话 + :param other: + :param extra_data: + :param exclude_unset: + :return: + """ + self.sqlmodel_update(other.model_dump(exclude_unset=exclude_unset), update=extra_data) + + session.add(self) + + await session.commit() + await session.refresh(self) + + return self + + @classmethod + async def delete(cls: Type[T], session: AsyncSession, instances: T | list[T]) -> None: + """ + 删除一些记录 + :param session: 数据库会话 + :param instances: + :return: None + + usage: + item1 = Item.get(...) + item2 = Item.get(...) + + Item.delete(session, [item1, item2]) + + """ + if isinstance(instances, list): + for instance in instances: + await session.delete(instance) + else: + await session.delete(instances) + + await session.commit() + + @classmethod + async def get( + cls: Type[T], + session: AsyncSession, + condition: BinaryExpression | ClauseElement | None, + *, + offset: int | None = None, + limit: int | None = None, + fetch_mode: Literal["one", "first", "all"] = "first", + join: Type[T] | tuple[Type[T], _OnClauseArgument] | None = None, + options: list | None = None, + load: Union[Relationship, None] = None, + order_by: list[ClauseElement] | None = None + ) -> T | list[T] | None: + """ + 异步获取模型实例 + + 参数: + session: 异步数据库会话 + condition: SQLAlchemy查询条件,如Model.id == 1 + offset: 结果偏移量 + limit: 结果数量限制 + options: 查询选项,如selectinload(Model.relation),异步访问关系属性必备,不然会报错 + fetch_mode: 获取模式 - "one"/"all"/"first" + join: 要联接的模型类 + + 返回: + 根据fetch_mode返回相应的查询结果 + """ + statement = select(cls) + + if condition is not None: + statement = statement.where(condition) + + if join is not None: + statement = statement.join(*join) + + if options: + statement = statement.options(*options) + + if load: + statement = statement.options(selectinload(load)) + + if order_by is not None: + statement = statement.order_by(*order_by) + + if offset: + statement = statement.offset(offset) + + if limit: + statement = statement.limit(limit) + + result = await session.exec(statement) + + if fetch_mode == "one": + return result.one() + elif fetch_mode == "first": + return result.first() + elif fetch_mode == "all": + return list(result.all()) + else: + raise ValueError(f"无效的 fetch_mode: {fetch_mode}") + + @classmethod + async def get_exist_one(cls: Type[T], session: AsyncSession, id: int, load: Union[Relationship, None] = None) -> T: + """此方法和 await session.get(cls, 主键)的区别就是当不存在时不返回None, + 而是会抛出fastapi 404 异常""" + instance = await cls.get(session, cls.id == id, load=load) + if not instance: + raise HTTPException(status_code=404, detail="Not found") + return instance + +class UUIDTableBase(TableBase): + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + """override""" + + @override + async def get_exist_one(cls: Type[T], session: AsyncSession, id: uuid.UUID, load: Union[Relationship, None] = None) -> T: + return super().get_exist_one(session, id, load) # type: ignore diff --git a/model/database.py b/model/database.py index fcaf9f9..96f8b6d 100644 --- a/model/database.py +++ b/model/database.py @@ -71,4 +71,4 @@ class Database: # For internal use, create a temporary context manager async with self.session_context() as session: - await migration(session) # 执行迁移脚本 \ No newline at end of file + await migration(session) # 执行迁移脚本 diff --git a/model/item.py b/model/item.py new file mode 100644 index 0000000..d828242 --- /dev/null +++ b/model/item.py @@ -0,0 +1,77 @@ +from datetime import datetime +from enum import StrEnum +from typing import TYPE_CHECKING, Self, Optional +from uuid import UUID + +from sqlmodel import Field, Relationship + +from .base import SQLModelBase, UUIDTableBase + +if TYPE_CHECKING: + from .user import User + +class ItemTypeEnum(StrEnum): + normal = 'normal' + car = 'car' + +class ItemStatusEnum(StrEnum): + ok = 'ok' + lost = 'lost' + +class ItemBase(SQLModelBase): + type: ItemTypeEnum = ItemTypeEnum.normal + """物品的类型""" + + name: str + """物品名称""" + + icon: str | None = None + """物品图标""" + + status: ItemStatusEnum = ItemStatusEnum.ok + """物品状态""" + + phone: str | None = None + """联系电话""" + + description: str | None = None + """物品描述""" + +class Item(ItemBase, UUIDTableBase, table=True): + expires_at: datetime | None = None + """物品过期时间""" + + lost_at: datetime | None = None + """物品丢失的时间""" + + find_ip: str | None = None + """最后一次发现的IP地址""" + + user_id: UUID = Field(foreign_key='user.id', ondelete='CASCADE') + """所属用户ID""" + + user: 'User' = Relationship(back_populates='items') + + parent_item_id: UUID | None = Field(foreign_key='item.id', ondelete='RESTRICT') + parent_item: Optional['Item'] = Relationship(back_populates='sub_items', sa_relationship_kwargs={'remote_side': 'Item.id'}) + sub_items: list['Item'] = Relationship(back_populates='parent_item', passive_deletes='all') + +class ItemDataUpdateRequest(ItemBase): + pass + +class ItemDataResponse(ItemBase): + expires_at: datetime | None = None + """物品过期时间""" + + lost_at: datetime | None = None + """物品丢失的时间""" + +class ItemDataResponseAdmin(ItemBase): + expires_at: datetime | None = None + """物品过期时间""" + + lost_at: datetime | None = None + """物品丢失的时间""" + + user_id: UUID = Field(foreign_key='user.id') + """所属用户ID""" diff --git a/model/items.py b/model/items.py deleted file mode 100644 index 47d9d18..0000000 --- a/model/items.py +++ /dev/null @@ -1,14 +0,0 @@ -from pydantic import BaseModel - -class Item(BaseModel): - id: int - type: str - key: str - name: str - icon: str - status: str - phone: int - lost_description: str | None - find_ip: str | None - create_time: str - lost_time: str | None diff --git a/model/migration.py b/model/migration.py index eb711ad..451a8a8 100644 --- a/model/migration.py +++ b/model/migration.py @@ -1,13 +1,12 @@ from loguru import logger -from sqlmodel import select from .setting import Setting -from .user import User +from .user import User, UserTypeEnum from pkg import Password default_settings: list[Setting] = [ - Setting(type='string', name='version', value='1.0.0'), - Setting(type='int', name='jwt_token_exp', value='30'), - Setting(type='string', name='server_chan_key', value=''), + Setting(type='string', name='version', value='2.0.0'), # 版本号,用于考虑是否需要数据迁移 + Setting(type='int', name='jwt_token_exp', value='30'), # JWT Token 访问令牌 + Setting(type='string', name='server_chan_key', value=''), # Server 酱推送密钥 ] async def migration(session): @@ -24,39 +23,32 @@ async def migration(session): names = [s.name for s in settings] existed_settings = await Setting.get( session, - Setting.name.in_(names), - fetch_mode="all" + Setting.name in names, + fetch_mode='all' ) existed: set[str] = {s.name for s in (existed_settings or [])} to_insert = [s for s in settings if s.name not in existed] if to_insert: - await Setting.add(session, to_insert, refresh=False) - - if await User.get(session, User.id == 1): - # 已有超级管理员用户,说明不是第一次运行 + await Setting.add(session, to_insert) - # 修复数据库id为1的用户不是管理员的问题 - admin_user = await User.get(session, User.id == 1) - if admin_user and not admin_user.is_admin: - admin_user.is_admin = True - await User.update(session, admin_user, refresh=False) - - # 已有用户,直接返回 - return + if not await User.get(session, User.role == UserTypeEnum.super_admin): + # 生成初始密码与密钥 + admin_password = Password.generate() + logger.warning("当前无管理员用户,已自动创建初始管理员用户:") + logger.warning("邮箱: admin@yxqi.cn") + logger.warning(f"密码: {admin_password}") - # 生成初始密码与密钥 - admin_password = Password.generate() - logger.warning("当前无管理员用户,已自动创建初始管理员用户:") - logger.warning("邮箱: admin@yxqi.cn") - logger.warning(f"密码: {admin_password}") + User._initializing = True - admin_user = User( - id=1, - email='admin@yxqi.cn', - username='Admin', - password=Password.hash(admin_password), - is_admin=True - ) + admin_user = User( + email='admin@yxqi.cn', + username='Admin', + password=Password.hash(admin_password), + role=UserTypeEnum.super_admin, + _initializing=True + ) - await User.add(session, admin_user, refresh=False) \ No newline at end of file + await User.add(session, admin_user) + + User._initializing = False diff --git a/model/object.py b/model/object.py deleted file mode 100644 index c836530..0000000 --- a/model/object.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Literal, TYPE_CHECKING -from sqlmodel import Field, Column, String, DateTime, Relationship -from .base import TableBase, IdMixin -from datetime import datetime - -if TYPE_CHECKING: - from .user import User - -class Object(IdMixin, TableBase, table=True): - - user_id: int = Field(foreign_key="user.id", index=True, nullable=False, description="所属用户ID") - key: str = Field(index=True, nullable=False, unique=True, description="物品外部ID") - type: Literal['normal', 'car'] = Field( - default='normal', - description="物品类型", - sa_column=Column( - String, - default='normal', - nullable=False - ) - ) - name: str = Field(nullable=False, description="物品名称") - icon: str | None = Field(default=None, description="物品图标") - status: Literal['ok', 'lost'] = Field( - default='ok', - description="物品状态", - sa_column=Column( - String, - default='ok', - nullable=False - ) - ) - phone: str | None = Field(default=None, description="联系电话") - description: str | None = Field(default=None, description="物品描述") - find_ip: str | None = Field(default=None, description="最后一次发现的IP地址") - lost_at: datetime | None = Field( - default=None, - description="物品标记为丢失的时间", - sa_column=Column( - DateTime, - nullable=True - ) - ) - - user: "User" = Relationship(back_populates="objects") \ No newline at end of file diff --git a/model/response.py b/model/response.py index 2b99bfe..b54e8ac 100644 --- a/model/response.py +++ b/model/response.py @@ -1,20 +1,15 @@ +from datetime import datetime + from pydantic import BaseModel -from typing import Literal class DefaultResponse(BaseModel): code: int = 0 - data: dict | list | bool | None = None + data: dict | list | bool | None msg: str = "" -class ObjectData(BaseModel): - id: int - type: Literal['normal', 'car'] - key: str - name: str - icon: str - status: Literal['ok', 'lost'] - phone: str - context: str | None = None - lost_description: str | None = None - create_time: str - lost_time: str | None = None \ No newline at end of file +# FastAPI 鉴权返回模型 +class TokenResponse(BaseModel): + access_token: str + +class TokenData(BaseModel): + username: str | None = None diff --git a/model/setting.py b/model/setting.py index ea43faf..54a86da 100644 --- a/model/setting.py +++ b/model/setting.py @@ -1,8 +1,19 @@ from sqlmodel import Field -from .base import TableBase +from .base import TableBase, SQLModelBase -class Setting(TableBase, table=True): - type: str = Field(index=True, nullable=False, description="设置类型") - name: str = Field(primary_key=True, nullable=False, description="设置名称") # name 为唯一主键 - value: str | None = Field(description="设置值") +class SettingBase(SQLModelBase): + type: str = Field(index=True) + """设置类型""" + + name: str = Field(index=True, unique=True) # name 为唯一主键 + """设置名称""" + + value: str | None + """设置值""" + +class Setting(SettingBase, TableBase, table=True): + pass + +class SettingResponse(SettingBase): + pass diff --git a/model/token.py b/model/token.py deleted file mode 100644 index 72f1b45..0000000 --- a/model/token.py +++ /dev/null @@ -1,9 +0,0 @@ -from pydantic import BaseModel - -# FastAPI 鉴权模型 -class Token(BaseModel): - access_token: str - token_type: str - -class TokenData(BaseModel): - username: str | None = None \ No newline at end of file diff --git a/model/user.py b/model/user.py index ca54bff..dd8f60c 100644 --- a/model/user.py +++ b/model/user.py @@ -1,16 +1,81 @@ -from typing import TYPE_CHECKING -from sqlmodel import Field, Column, String, Boolean, Relationship -from .base import TableBase, IdMixin +from enum import StrEnum +from typing import ClassVar -if TYPE_CHECKING: - from .object import Object +import sqlalchemy as sa +from pydantic import EmailStr +from sqlalchemy import event +from sqlalchemy.orm.session import Session as SessionClass +from sqlmodel import Field, Relationship -class User(IdMixin, TableBase, table=True): +from .base import SQLModelBase, UUIDTableBase +from .item import Item - email: str = Field(sa_column=Column(String(100), index=True, unique=True)) - username: str = Field(sa_column=Column(String(50), index=True, unique=True)) - password: str = Field(sa_column=Column(String(100))) - is_admin: bool = Field(default=False, sa_column=Column(Boolean, default=False)) - - objects: list["Object"] = Relationship(back_populates="user") \ No newline at end of file +class UserTypeEnum(StrEnum): + normal_user = 'normal_user' + admin = 'admin' + super_admin = 'super_admin' + +class UserBase(SQLModelBase): + pass + +class User(UserBase, UUIDTableBase, table=True): + email: EmailStr = Field(index=True, unique=True) + """邮箱""" + + username: str = Field(index=True, unique=True) + """用户名""" + + password: str + """Argon2算法哈希后的密码""" + + two_factor_secret: str | None = None + """两步验证的密钥""" + + role: UserTypeEnum = Field(default=UserTypeEnum.normal_user, index=True) + """用户的权限等级""" + + items: list[Item] = Relationship(back_populates='user', cascade_delete=True) + """物品关系""" + + _initializing: ClassVar[bool] = False + """标记当前是否处于初始化阶段,初始化阶段允许创建 super_admin""" + +@event.listens_for(SessionClass, "before_flush") +def check_super_admin_immutability(session, flush_context, instances): + """ + 在事务刷新到数据库前,集中检查所有关于 super_admin 的不合法操作。 + 此监听器确保超级管理员的角色和存在性是不可变的。 + """ + # 检查1: 禁止创建新的 super_admin + for obj in session.new: + if isinstance(obj, User) and obj.role == UserTypeEnum.super_admin and not User._initializing: + raise ValueError("业务规则:不允许创建新的超级管理员。") + + # 检查2: 禁止删除已存在的 super_admin + for obj in session.deleted: + if isinstance(obj, User): + state = sa.inspect(obj) + # 直接从对象被删除前的状态获取角色,避免不必要的 lazy load + original_role = state.committed_state.get('role') + if original_role == UserTypeEnum.super_admin: + username = state.committed_state.get('username', f'(ID: {obj.id})') + raise ValueError(f"业务规则:不允许删除超级管理员 '{username}'。") + + # 检查3: 禁止与 super_admin 相关的角色变更 + for obj in session.dirty: + if isinstance(obj, User): + state = sa.inspect(obj) + # 仅在 'role' 字段确实被修改时才进行检查 + if "role" in state.committed_state: + history = state.attrs.role.history + original_role = history.deleted[0] + new_role = history.added[0] + + # 场景 a: 禁止将 super_admin 降级 + if original_role == UserTypeEnum.super_admin: + raise ValueError(f"业务规则:不允许将超级管理员 '{obj.username}' 的角色降级。") + + # 场景 b: 禁止将任何用户提升为 super_admin + if new_role == UserTypeEnum.super_admin: + raise ValueError(f"业务规则:不允许将用户 '{obj.username}' 提升为超级管理员。") diff --git a/pkg/__init__.py b/pkg/__init__.py index 5fd47c0..1ce733e 100644 --- a/pkg/__init__.py +++ b/pkg/__init__.py @@ -1 +1,2 @@ -from .password import Password \ No newline at end of file +from .password import Password + diff --git a/pkg/sms/smsbao.py b/pkg/sms/smsbao.py new file mode 100644 index 0000000..f309297 --- /dev/null +++ b/pkg/sms/smsbao.py @@ -0,0 +1,2 @@ +class SmsBao(): + async def get \ No newline at end of file diff --git a/pkg/utils.py b/pkg/utils.py new file mode 100644 index 0000000..919e123 --- /dev/null +++ b/pkg/utils.py @@ -0,0 +1,75 @@ +from typing import Any, NoReturn, TYPE_CHECKING + +from fastapi import HTTPException + +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_401_UNAUTHORIZED, + HTTP_403_FORBIDDEN, + HTTP_404_NOT_FOUND, + HTTP_409_CONFLICT, + HTTP_429_TOO_MANY_REQUESTS, + HTTP_500_INTERNAL_SERVER_ERROR, + HTTP_501_NOT_IMPLEMENTED, + HTTP_503_SERVICE_UNAVAILABLE, + HTTP_504_GATEWAY_TIMEOUT, HTTP_402_PAYMENT_REQUIRED, +) + +if TYPE_CHECKING: + from sqlmodel.ext.asyncio.session import AsyncSession + + +# --- Request and Response Helpers --- + +def ensure_request_param(to_check: Any, detail: str) -> None: + """ + Ensures a parameter exists. If not, raises a 400 Bad Request. + This function returns None if the check passes. + """ + if not to_check: + raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail) + +def raise_bad_request(detail: str = '') -> NoReturn: + """Raises an HTTP 400 Bad Request exception.""" + raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail) + +def raise_not_found(detail: str) -> NoReturn: + """Raises an HTTP 404 Not Found exception.""" + raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail) + +def raise_internal_error(detail: str = "服务器出现故障,请稍后再试或联系管理员") -> NoReturn: + """Raises an HTTP 500 Internal Server Error exception.""" + raise HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail) + +def raise_forbidden(detail: str) -> NoReturn: + """Raises an HTTP 403 Forbidden exception.""" + raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail) + +def raise_unauthorized(detail: str) -> NoReturn: + """Raises an HTTP 401 Unauthorized exception.""" + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail) + +def raise_conflict(detail: str) -> NoReturn: + """Raises an HTTP 409 Conflict exception.""" + raise HTTPException(status_code=HTTP_409_CONFLICT, detail=detail) + +def raise_too_many_requests(detail: str) -> NoReturn: + """Raises an HTTP 429 Too Many Requests exception.""" + raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail=detail) + +def raise_not_implemented(detail: str = "尚未支持这种方法") -> NoReturn: + """Raises an HTTP 501 Not Implemented exception.""" + raise HTTPException(status_code=HTTP_501_NOT_IMPLEMENTED, detail=detail) + +def raise_service_unavailable(detail: str) -> NoReturn: + """Raises an HTTP 503 Service Unavailable exception.""" + raise HTTPException(status_code=HTTP_503_SERVICE_UNAVAILABLE, detail=detail) + +def raise_gateway_timeout(detail: str) -> NoReturn: + """Raises an HTTP 504 Gateway Timeout exception.""" + raise HTTPException(status_code=HTTP_504_GATEWAY_TIMEOUT, detail=detail) + +def raise_insufficient_quota(detail: str = "积分不足,请充值") -> NoReturn: + raise HTTPException(status_code=HTTP_402_PAYMENT_REQUIRED, detail=detail) + +# --- End of Request and Response Helpers --- diff --git a/routes/admin.py b/routes/admin.py index 60f7cfd..134e6a3 100644 --- a/routes/admin.py +++ b/routes/admin.py @@ -1,9 +1,12 @@ -from fastapi import APIRouter +from typing import Annotated + +from fastapi import APIRouter, HTTPException from fastapi import Depends -from model.response import DefaultResponse +from sqlalchemy.ext.asyncio import AsyncSession from middleware.admin import is_admin - +from model import database, Setting, SettingResponse +from model.response import DefaultResponse Router = APIRouter( prefix='/api/admin', @@ -25,4 +28,52 @@ async def verity_admin() -> DefaultResponse: - 若为管理员,返回 `True` - 若不是管理员,抛出 `401` 错误 ''' + return DefaultResponse(data=True) + +@Router.get( + path='api/admin/settings', + summary='获取设置项', + description='获取设置项, 留空则获取所有', + response_model=DefaultResponse, + response_description='设置项列表' +) +async def get_settings( + session: Annotated[AsyncSession, Depends(database.Database.get_session)], + name: str | None = None +) -> DefaultResponse: + data = [] + + if name: + setting = await Setting.get(session, Setting.name == name) + if setting: + data.append(SettingResponse.model_validate(setting)) + else: + raise HTTPException(404, detail="Setting not found") + else: + settings = await Setting.get(session, fetch_mode="all") + if settings: + data = [SettingResponse.model_validate(s) for s in settings] + + return DefaultResponse(data=data) + + +@Router.put( + path='api/admin/settings', + summary='更新设置项', + description='更新设置项', + response_model=DefaultResponse, + response_description='更新结果' +) +async def update_settings( + session: Annotated[AsyncSession, Depends(database.Database.get_session)], + name: str, + value: str +) -> DefaultResponse: + setting = await Setting.get(session, Setting.name == name) + if not setting: + raise HTTPException(404, detail="Setting not found") + + setting.value = value + await Setting.save(session) + return DefaultResponse(data=True) \ No newline at end of file diff --git a/routes/object.py b/routes/object.py index cbb5961..3168500 100644 --- a/routes/object.py +++ b/routes/object.py @@ -1,16 +1,17 @@ -import random +from typing import Annotated, Literal +from uuid import UUID + from fastapi import APIRouter, Request, Query, HTTPException from fastapi.responses import JSONResponse +from loguru import logger from slowapi import Limiter from slowapi.util import get_remote_address -from model import database, Object, Setting -from model import User -from model.items import Item -from middleware.user import get_current_user -from loguru import logger -from model.response import DefaultResponse, ObjectData from sqlalchemy.ext.asyncio import AsyncSession -from typing import Annotated, Literal + +from middleware.user import get_current_user +from model import DefaultResponse, ItemDataResponse, User, database, Setting, Item +from model.item import ItemDataUpdateRequest +from pkg.utils import raise_not_found, raise_bad_request, raise_internal_error, raise_service_unavailable limiter = Limiter(key_func=get_remote_address) @@ -32,21 +33,21 @@ async def get_items( token: Annotated[User, Depends(get_current_user)], id: int | None = Query(default=None, ge=1, description='物品ID'), key: str | None = Query(default=None, description='物品序列号')): - ''' + """ 获得物品信息。 - + 不传参数返回所有信息,否则可传入 `id` 或 `key` 进行筛选。 - ''' + """ # 根据条件查询物品,只获取当前用户的物品 if id is not None: - results = await Object.get(session, (Object.id == id) & (Object.user_id == token.id)) + results = await Item.get(session, (Item.id == id) & (Item.user_id == token.id)) results = [results] if results else [] elif key is not None: - results = await Object.get(session, (Object.key == key) & (Object.user_id == token.id)) + results = await Item.get(session, (Item.key == key) & (Item.user_id == token.id)) results = [results] if results else [] else: - results = await Object.get(session, Object.user_id == token.id, fetch_mode="all") + results = await Item.get(session, Item.user_id == token.id, fetch_mode="all") if results: items = [] @@ -54,7 +55,7 @@ async def get_items( items.append(Item( id=obj.id, type=obj.type, - key=obj.key, + key=obj.id, name=obj.name, icon=obj.icon or "", status=obj.status or "", @@ -77,35 +78,22 @@ async def get_items( ) async def add_items( session: Annotated[AsyncSession, Depends(database.Database.get_session)], - token: Annotated[User, Depends(get_current_user)], - key: str, - type: Literal['normal', 'car'], - name: str, - icon: str, - phone: str + user: Annotated[User, Depends(get_current_user)], + request: ItemDataUpdateRequest ) -> DefaultResponse: - ''' + """ 添加物品信息。 - + - **key**: 物品的关键字 - **type**: 物品的类型 - **name**: 物品的名称 - **icon**: 物品的图标 - **phone**: 联系电话 - ''' + """ try: # 创建新物品对象,关联当前用户 - new_object = Object( - key=key, - type=type, - name=name, - icon=icon, - phone=phone, - user_id=token.id - ) - # 使用 base.py 中的 add 方法 - await Object.add(session, new_object) + await Item.add(session, Item.model_validate(request)) except Exception as e: logger.error(f"Failed to add item: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -113,7 +101,7 @@ async def add_items( return DefaultResponse(data=True) @Router.patch( - path='/items', + path='/items/{item_id}', summary='更新物品信息', description='更新现有物品的信息', response_model=DefaultResponse, @@ -121,22 +109,15 @@ async def add_items( ) async def update_items( session: Annotated[AsyncSession, Depends(database.Database.get_session)], - token: Annotated[User, Depends(get_current_user)], - id: int = Query(ge=1), - key: str | None = None, - name: str | None = None, - icon: str | None = None, - status: str | None = None, - phone: int | None = None, - lost_description: str | None = None, - find_ip: str | None = None, - lost_time: str | None = None - ) -> DefaultResponse: - ''' + user: Annotated[User, Depends(get_current_user)], + item_id: UUID, + request: ItemDataUpdateRequest +) -> DefaultResponse: + """ 更新物品信息。 - + 只有 `id` 是必填参数,其余参数都是可选的,在不传入任何值的时候将不做任何更改。 - + - **id**: 物品的ID - **key**: 物品的序列号 **不建议修改此项,这样会导致生成的物品二维码直接失效** - **name**: 物品的名称 @@ -146,41 +127,16 @@ async def update_items( - **lost_description**: 物品丢失描述 - **find_ip**: 找到物品的IP - **lost_time**: 物品丢失时间 - - ''' - try: - # 获取现有物品,验证归属权 - obj = await Object.get(session, (Object.id == id) & (Object.user_id == token.id)) - if not obj: - raise HTTPException(status_code=404, detail="Item not found or access denied") - - # 更新字段 - if key is not None: - obj.key = key - if name is not None: - obj.name = name - if icon is not None: - obj.icon = icon - if status is not None: - obj.status = status - if phone is not None: - obj.phone = str(phone) - if lost_description is not None: - obj.context = lost_description - if find_ip is not None: - obj.find_ip = find_ip - if lost_time is not None: - from datetime import datetime - obj.lost_at = datetime.fromisoformat(lost_time) - - # 保存更新 - await obj.save(session) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - else: - return DefaultResponse(data=True) + + """ + # 获取现有物品,验证归属权 + obj = await Item.get(session, (Item.id == item_id) & (Item.user_id == user.id)) + if not obj: + raise_not_found("Item not found or access denied") + + await obj.update(session, request) + + return DefaultResponse(data=True) @Router.delete( path='/items', @@ -191,27 +147,21 @@ async def update_items( ) async def delete_items( session: Annotated[AsyncSession, Depends(database.Database.get_session)], - token: Annotated[User, Depends(get_current_user)], + user: Annotated[User, Depends(get_current_user)], id: int ) -> DefaultResponse: - ''' + """ 删除物品信息。 - + - **id**: 物品的ID - ''' - try: - # 获取现有物品,验证归属权 - obj = await Object.get(session, (Object.id == id) & (Object.user_id == token.id)) - if not obj: - raise HTTPException(status_code=404, detail="Item not found or access denied") - # 使用 base.py 中的 delete 方法 - await Object.delete(session, obj) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - else: - return DefaultResponse(data=True) + """ + # 获取现有物品,验证归属权 + obj = await Item.get(session, (Item.id == id) & (Item.user_id == user.id)) + if not obj: + raise_not_found("Item not found or access denied") + await Item.delete(session, obj) + + return DefaultResponse(data=True) @Router.get( path='/{item_key}', @@ -224,47 +174,21 @@ async def get_object( session: Annotated[AsyncSession, Depends(database.Database.get_session)], item_key: str, request: Request -): +) -> DefaultResponse: """ 获取物品信息 / Get object information """ - - object_data = await Object.get(session, Object.key == item_key) + object_data = await Item.get(session, Item.key == item_key) if object_data: if object_data.status == 'lost': # 物品已标记为丢失,更新IP地址 - await Object.update( - session, - id=object_data.id, - find_ip=str(request.client.host) - ) + object_data.find_ip = str(request.client.host) + object_data = await object_data.save(session) - # 添加一些随机延迟,类似JWT身份验证时根据延迟爆破引发的问题 - await asyncio.sleep(random.uniform(0.10, 0.30)) - - print(object_data) - return DefaultResponse( - data=ObjectData( - id=object_data.id, - type=object_data.type, - key=object_data.key, - name=object_data.name, - icon=object_data.icon, - status=object_data.status, - phone=object_data.phone, - lost_description=object_data.lost_description, - create_time=object_data.create_time, - lost_time=object_data.lost_time - ).model_dump() - ) - else: return JSONResponse( - status_code=404, - content=DefaultResponse( - code=404, - msg='物品不存在或出现异常' - ).model_dump() - ) + return DefaultResponse(data=ItemDataResponse.model_validate(object_data)) + else: + raise_not_found('物品不存在或出现异常') @Router.put( path='/{item_id}', @@ -274,7 +198,7 @@ async def get_object( response_description="挪车通知结果" ) @limiter.limit( - limit_value="1/30minute", # 每30分钟允许1次请求 + limit_value="1/5minute", # 每5分钟允许1次请求 error_message="小主已经通知过车主了,请稍安勿躁~" ) async def notify_move_car( @@ -283,7 +207,8 @@ async def notify_move_car( item_id: int, phone: str = None, ): - """通知车主进行挪车 / Notify car owner to move the car + """ + 通知车主进行挪车 / Notify car owner to move the car Args: item_id (int): 物品ID / Item ID @@ -291,36 +216,18 @@ async def notify_move_car( """ # 检查是否存在该物品 - object_data = await Object.get(session, Object.id == item_id) + object_data = await Item.get(session, Item.id == item_id) if not object_data: - return JSONResponse( - status_code=404, - content=DefaultResponse( - code=404, - msg='物品不存在或出现异常' - ).model_dump() - ) + raise_not_found() # 检查物品类型是否为车辆 if object_data.type != 'car': - return JSONResponse( - status_code=400, - content=DefaultResponse( - code=400, - msg='该物品不是车辆,无法发送挪车通知' - ).model_dump() - ) + raise_bad_request("Item is not car") # 发起挪车通知(目前仅适配Server酱) server_chan_key = await Setting.get(session, Setting.name == 'server_chan_key') if not server_chan_key: - return JSONResponse( - status_code=500, - content=DefaultResponse( - code=500, - msg='未配置Server酱,无法发送挪车通知' - ).model_dump() - ) + raise_internal_error('未配置Server酱,无法发送挪车通知') title = "挪车通知 - Findreve" description = f"您的车辆“{object_data.name}”被请求挪车。\n\n" @@ -342,21 +249,15 @@ async def notify_move_car( return DefaultResponse(msg='挪车通知发送成功') else: error_msg = resp_json.get('message') - logger.error(f"Failed to send notification via Server Chan: error_code={resp_json.get('code')}, error_message={error_msg}, item_id={item_id}, response={resp_json}") - return JSONResponse( - status_code=500, - content=DefaultResponse( - code=500, - msg=f"挪车通知发送失败,Server酱返回错误:{error_msg}" - ).model_dump() + logger.error( + f"Failed to send notification via Server Chan: error_code={resp_json.get('code')}, " + f"error_message={error_msg}, item_id={item_id}, response={resp_json}" ) + raise_service_unavailable('Server酱出现问题,发送失败') else: response_text = await resp.text() - logger.error(f"Failed to send notification via Server Chan: http_status={resp.status}, item_id={item_id}, response_body={response_text}, url={resp.url}") - return JSONResponse( - status_code=500, - content=DefaultResponse( - code=500, - msg=f"挪车通知发送失败,HTTP状态码:{resp.status}" - ).model_dump() - ) \ No newline at end of file + logger.error( + f"Failed to send notification via Server Chan: http_status={resp.status}, item_id={item_id}, " + f"response_body={response_text}, url={resp.url}" + ) + raise_internal_error('挪车通知发送失败') diff --git a/routes/session.py b/routes/session.py index 9e68fbd..9ef8d8d 100644 --- a/routes/session.py +++ b/routes/session.py @@ -9,18 +9,23 @@ from sqlmodel.ext.asyncio.session import AsyncSession from pkg import Password from loguru import logger -from model.token import Token from model import Setting, User, database +from model.response import TokenResponse Router = APIRouter(tags=["令牌 session"]) -# 创建令牌 -async def create_access_token(session: AsyncSession, data: dict, expires_delta: timedelta | None = None): +# 创建访问令牌 +async def create_access_token( + session: AsyncSession, + data: dict, + expires_delta: timedelta | None = None +): to_encode = data.copy() if expires_delta: expire = datetime.now(timezone.utc) + expires_delta else: - expire = datetime.now(timezone.utc) + timedelta(minutes=await Setting.get(session, 'jwt_token_exp')) + jwt_exp_setting = await Setting.get(session, Setting.name == 'jwt_token_exp') + expire = datetime.now(timezone.utc) + timedelta(int(jwt_exp_setting.value)) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, key=await JWT.get_secret_key(), algorithm='HS256') return encoded_jwt @@ -45,13 +50,13 @@ async def authenticate_user(session: AsyncSession, username: str, password: str) path="/api/token", summary="获取访问令牌", description="使用用户名和密码获取访问令牌", - response_model=Token, + response_model=TokenResponse, response_description="访问令牌" ) async def login_for_access_token( form_data: Annotated[OAuth2PasswordRequestForm, Depends()], session: Annotated[AsyncSession, Depends(database.Database.get_session)], -) -> Token: +) -> TokenResponse: user = await authenticate_user( session=session, username=form_data.username, @@ -63,10 +68,11 @@ async def login_for_access_token( detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"}, ) - access_token_expires = timedelta(hours=1) access_token = await create_access_token( session=session, - data={"sub": form_data.username}, - expires_delta=access_token_expires + data={"sub": user.email}, ) - return Token(access_token=access_token, token_type="bearer") \ No newline at end of file + + return TokenResponse( + access_token=access_token, + ) \ No newline at end of file diff --git a/routes/site.py b/routes/site.py new file mode 100644 index 0000000..88e8a86 --- /dev/null +++ b/routes/site.py @@ -0,0 +1,20 @@ +from fastapi import APIRouter +from model.response import DefaultResponse +from pkg import conf + +Router = APIRouter(prefix='/api/site', tags=['站点 Site']) + +@Router.get( + path='/ping', + summary='站点健康检查', + description='检查站点是否在线', + response_model=DefaultResponse, + response_description='站点在线' +) +async def ping(): + """ + 站点健康检查接口。 + + :return: Findreve 版本号 + """ + return DefaultResponse(data=conf.VERSION) \ No newline at end of file