Compare commits
10 Commits
c1c36c606f
...
67b9aa2bd6
| Author | SHA1 | Date | |
|---|---|---|---|
| 67b9aa2bd6 | |||
| 3f1bd0731b | |||
| 3580717087 | |||
| a0afcbaa90 | |||
| 35efbdf000 | |||
| 8ce34440d8 | |||
| 93830c3d03 | |||
| a71cde7b82 | |||
| cd35c6fbed | |||
| ee684d67cf |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -14,3 +14,4 @@ __pycache__/
|
||||
.VSCodeCounter/
|
||||
|
||||
.env
|
||||
uv.lock
|
||||
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.13
|
||||
51
AGENTS.md
Normal file
51
AGENTS.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# 仓库指南
|
||||
|
||||
## 项目结构与模块
|
||||
- `main.py` 负责使用 `pkg/*` 配置与日志启动 FastAPI,并加载 `app.py` 中定义的应用。
|
||||
- `routes/` 存放核心路由(`admin.py`、`session.py`、`object.py`、`site.py`),新增模块后请在 `app.py` 注册。
|
||||
- `model/` 汇集 SQLModel 表、数据库工具与响应模型;共享字段请复用 `model/base/` 混入。
|
||||
- `middleware/` 管理认证与限流;公共工具位于 `pkg/`;Vue 构建产物保存在 `dist/`,视觉资产位于 `docs/`。
|
||||
|
||||
## 构建、测试与开发
|
||||
- 创建虚拟环境(`python -m venv .venv`)并激活,随后执行 `pip install -r requirements.txt`。
|
||||
- 通过 `python main.py` 启动后端;流程会生成 `.env`、初始化 SQLite `data.db`,`DEBUG=true` 时开启热重载。
|
||||
- 在前端仓库运行 `yarn install && yarn build`,将生成的 `dist/` 拷贝回项目根目录并刷新服务。
|
||||
- 使用 `pytest` 执行自动化检查;重构期间可通过 `-k` 聚焦相关用例。
|
||||
|
||||
## 入职流程
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
F1[前端开发者入职] --> F2[克隆仓库]
|
||||
F2 --> F3[安装后端依赖<br/>pip install -r requirements.txt]
|
||||
F3 --> F4[构建前端<br/>yarn install && yarn build]
|
||||
F4 --> F5[复制 dist/ 到仓库根目录]
|
||||
F5 --> F6[在浏览器完成冒烟测试]
|
||||
```
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
B1[后端开发者入职] --> B2[克隆仓库]
|
||||
B2 --> B3[创建 .venv 并安装依赖]
|
||||
B3 --> B4[运行 python main.py]
|
||||
B4 --> B5[使用 curl/httpie 访问 /api]
|
||||
B5 --> B6[补充测试并执行 pytest]
|
||||
B6 --> B7[整理发现并提交 PR]
|
||||
```
|
||||
|
||||
## 编码风格与命名
|
||||
- 统一使用 Python 3.13+、四空格缩进,并在公共接口添加类型注解;仅对复杂逻辑补充文档字符串。
|
||||
- 函数使用 `snake_case`,数据模型使用 `PascalCase`,配置与日志归于 `pkg/`(`pkg/logger.py` 封装`loguru`)。
|
||||
- 所有代码、注释、提交信息与评审讨论均使用简体中文。
|
||||
|
||||
## 测试规范
|
||||
- 在 `tests/` 中镜像业务目录(如 `tests/test_session.py`),以 `test_<行为>()` 命名测试函数。
|
||||
- 通过 `pytest` fixture 启动临时 SQLite 数据库,并在 PR 中说明手工验证或覆盖率缺口。
|
||||
|
||||
## 提交与 Pull Request
|
||||
- 提交信息保持简洁的祈使句(例如 `新增通知发送器`),仅在必要时补充作用域。
|
||||
- PR 需关联议题、突出模型或接口变更、列出迁移与测试结果,并附上影响界面的截图或 `curl` 示例。
|
||||
|
||||
## 安全与配置提示
|
||||
- 机密数据仅存放于 `.env`;依赖 `pkg/env.ensure_env_file()` 生成默认值,勿直接修改源代码。
|
||||
- 调整 `middleware/` 或 `routes/` 后须复验认证流程与 SlowAPI 限流,确保防护完整。
|
||||
@@ -61,18 +61,16 @@ chmod +x ./findreve
|
||||
|
||||
启动后, Findreve 会在程序的根目录自动创建 SQLite 数据库,并在
|
||||
终端显示管理员账号密码。请注意,账号密码仅显示一次,请注意保管。
|
||||
账号默认为 `admin@yuxiaoqiu.cn`
|
||||
账号默认为 `admin@yxqi.cn`
|
||||
|
||||
Upon launch, Findreve will create a SQLite database in the project's root directory and
|
||||
display the administrator's account and password in the console.
|
||||
|
||||
## 构建
|
||||
|
||||
> 当前版本的 Findreve Core 无法正常工作,因为我们正在尝试[重构数据库组件以使用ORM](https://github.com/Findreve/Findreve/issues/8)
|
||||
你需要安装Python 3.13 以上的版本。然后,clone 本仓库到您的服务器并解压,然后安装下面的依赖:
|
||||
|
||||
你需要安装Python 3.8 以上的版本。然后,clone 本仓库到您的服务器并解压,然后安装下面的依赖:
|
||||
|
||||
You need to have Python 3.8 or higher installed on your server. Then, clone this repository
|
||||
You need to have Python 3.13 or higher installed on your server. Then, clone this repository
|
||||
to your server and install the required dependencies:
|
||||
|
||||
> `pip install -r requirements.txt`
|
||||
|
||||
56
app.py
56
app.py
@@ -1,19 +1,26 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi import Request, HTTPException
|
||||
from fastapi import Request
|
||||
from contextlib import asynccontextmanager
|
||||
from routes import (session, admin, object)
|
||||
import model.database
|
||||
import os, asyncio
|
||||
import pkg.conf
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
# 初始化数据库
|
||||
asyncio.run(model.database.Database().init_db())
|
||||
from pkg.utils import raise_internal_error
|
||||
from routes import (session, admin, object, ota)
|
||||
from model.database import Database
|
||||
import os
|
||||
import pkg.conf
|
||||
from pkg import utils
|
||||
|
||||
from loguru import logger
|
||||
|
||||
Router = [admin, session, object, ota]
|
||||
|
||||
# Findreve 的生命周期
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await model.database.Database().init_db()
|
||||
await Database().init_db()
|
||||
yield
|
||||
|
||||
# 定义 Findreve 服务器
|
||||
@@ -25,26 +32,43 @@ app = FastAPI(
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def handle_unexpected_exceptions(request: Request, exc: Exception):
|
||||
"""
|
||||
捕获所有未经处理的异常,防止敏感信息泄露。
|
||||
"""
|
||||
# 1. 为开发人员记录详细的、包含完整堆栈跟踪的错误日志
|
||||
logger.exception(
|
||||
f"An unhandled exception occurred for request: {request.method} {request.url.path}"
|
||||
)
|
||||
|
||||
raise_internal_error()
|
||||
|
||||
|
||||
# 挂载后端路由
|
||||
app.include_router(admin.Router)
|
||||
app.include_router(session.Router)
|
||||
app.include_router(object.Router)
|
||||
for router in Router:
|
||||
app.include_router(router.Router)
|
||||
|
||||
# 挂载Slowapi限流中间件
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
async def frontend_index():
|
||||
if not os.path.exists("dist/index.html"):
|
||||
raise HTTPException(status_code=404)
|
||||
utils.raise_not_found("Index not found")
|
||||
return FileResponse("dist/index.html")
|
||||
|
||||
# 回退路由
|
||||
@app.get("/{path:path}")
|
||||
async def serve_spa(request: Request, path: str):
|
||||
async def frontend_path(path: str):
|
||||
if not os.path.exists("dist/index.html"):
|
||||
raise HTTPException(status_code=404)
|
||||
utils.raise_not_found("Index not found, please build frontend first.")
|
||||
|
||||
# 排除API路由
|
||||
if path.startswith("api/"):
|
||||
raise HTTPException(status_code=404)
|
||||
utils.raise_not_found("API route not found")
|
||||
|
||||
# 检查是否是静态资源请求
|
||||
if path.startswith("assets/") and os.path.exists(f"dist/{path}"):
|
||||
|
||||
6
dependencies.py
Normal file
6
dependencies.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from typing import Annotated
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from fastapi import Depends
|
||||
from model import database
|
||||
|
||||
SessionDep = Annotated[AsyncSession, Depends(database.Database.get_session)]
|
||||
39
middleware/admin.py
Normal file
39
middleware/admin.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Annotated
|
||||
from fastapi import Depends
|
||||
|
||||
from model.user import UserTypeEnum
|
||||
from .user import get_current_user
|
||||
from pkg import utils
|
||||
from model import User
|
||||
from middleware.dependencies import SessionDep
|
||||
|
||||
# 验证是否为管理员
|
||||
async def is_admin(
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
) -> User:
|
||||
'''
|
||||
验证是否为管理员。
|
||||
|
||||
使用方法:
|
||||
>>> APIRouter(dependencies=[Depends(is_admin)])
|
||||
'''
|
||||
|
||||
if user.role == UserTypeEnum.normal_user:
|
||||
utils.raise_forbidden("Admin access required")
|
||||
else:
|
||||
return user
|
||||
|
||||
async def is_super_admin(
|
||||
user: Annotated[User, Depends(is_admin)],
|
||||
) -> User:
|
||||
'''
|
||||
验证是否为超级管理员。
|
||||
|
||||
使用方法:
|
||||
>>> APIRouter(dependencies=[Depends(is_super_admin)])
|
||||
'''
|
||||
|
||||
if user.role != UserTypeEnum.super_admin:
|
||||
utils.raise_forbidden("Super admin access required")
|
||||
else:
|
||||
return user
|
||||
62
middleware/dependencies.py
Normal file
62
middleware/dependencies.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Annotated, TypeAlias
|
||||
|
||||
from fastapi import Depends, Request
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from model.database import Database
|
||||
from model.mixin.table import TableViewRequest
|
||||
from model import Item
|
||||
from model.item import ItemTypeEnum
|
||||
from pkg import utils
|
||||
|
||||
SessionDep: TypeAlias = Annotated[AsyncSession, Depends(Database.get_session)]
|
||||
"""数据库会话依赖,用于路由函数中获取数据库会话"""
|
||||
|
||||
# 新增:表格视图请求依赖(用于分页排序)
|
||||
TableViewRequestDep: TypeAlias = Annotated[TableViewRequest, Depends()]
|
||||
"""分页排序请求依赖,用于 LIST 端点"""
|
||||
|
||||
|
||||
async def get_device_from_cert(
|
||||
request: Request,
|
||||
session: SessionDep,
|
||||
) -> Item:
|
||||
"""
|
||||
从 mTLS 客户端证书中提取设备序列号并验证设备。
|
||||
|
||||
客户端证书的 CN (Common Name) 字段应存储设备序列号 (UUID)。
|
||||
反向代理(Nginx/Apache)验证证书后,通过 HTTP Header 将 CN 传递给 FastAPI。
|
||||
|
||||
Nginx 配置示例:
|
||||
proxy_set_header X-Client-CN $ssl_client_s_dn_cn;
|
||||
|
||||
Apache 配置示例:
|
||||
RequestHeader set X-Client-CN "%{SSL_CLIENT_S_DN_CN}s"
|
||||
"""
|
||||
# 从 Header 获取设备序列号(由反向代理注入)
|
||||
serial_number = request.headers.get("X-Client-CN")
|
||||
|
||||
if not serial_number:
|
||||
utils.raise_unauthorized("Device certificate required")
|
||||
|
||||
# 验证 UUID 格式
|
||||
try:
|
||||
from uuid import UUID
|
||||
serial_uuid = UUID(serial_number)
|
||||
except ValueError:
|
||||
utils.raise_unauthorized("Invalid device serial number format")
|
||||
|
||||
# 查找设备
|
||||
device = await Item.get(session, Item.id == serial_uuid)
|
||||
|
||||
if not device:
|
||||
utils.raise_not_found("Device not found")
|
||||
|
||||
if device.type != ItemTypeEnum.esp32:
|
||||
utils.raise_forbidden("Not an ESP device")
|
||||
|
||||
return device
|
||||
|
||||
|
||||
DeviceDep: TypeAlias = Annotated[Item, Depends(get_device_from_cert)]
|
||||
"""设备认证依赖,通过 mTLS 证书验证 ESP 设备"""
|
||||
33
middleware/user.py
Normal file
33
middleware/user.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Annotated
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends
|
||||
from jwt import InvalidTokenError
|
||||
from loguru import logger as l
|
||||
|
||||
import JWT
|
||||
from model import User
|
||||
from pkg import utils
|
||||
from middleware.dependencies import SessionDep
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(JWT.oauth2_scheme)],
|
||||
session: SessionDep,
|
||||
) -> User:
|
||||
"""
|
||||
验证用户身份并返回当前用户信息。
|
||||
"""
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, await JWT.get_secret_key(), algorithms=[JWT.ALGORITHM])
|
||||
email = payload.get("sub")
|
||||
stored_account = await User.get(session, User.email == email)
|
||||
if stored_account is None:
|
||||
l.warning("Account not found")
|
||||
utils.raise_unauthorized("Login required")
|
||||
elif stored_account.email != email:
|
||||
l.warning("Email mismatch")
|
||||
utils.raise_unauthorized("Login required")
|
||||
return stored_account
|
||||
except InvalidTokenError:
|
||||
utils.raise_unauthorized("Login required")
|
||||
@@ -1,3 +1,54 @@
|
||||
from . import token
|
||||
from .setting import Setting
|
||||
from .object import Object
|
||||
from .response import DefaultResponse, TokenResponse, TokenData
|
||||
from .setting import Setting, SettingResponse
|
||||
from .item import Item, ItemDataResponse, ItemTypeEnum, ItemStatusEnum
|
||||
from .user import User, UserTypeEnum
|
||||
from .database import Database
|
||||
from .firmware import (
|
||||
Firmware,
|
||||
FirmwareDataResponse,
|
||||
FirmwareDataResponseAdmin,
|
||||
FirmwareUploadRequest,
|
||||
FirmwareCheckUpdateRequest,
|
||||
FirmwareCheckUpdateResponse,
|
||||
ChipTypeEnum,
|
||||
)
|
||||
|
||||
# 新增:从 foxline 项目移植的 Mixin 组件
|
||||
from .mixin.table import (
|
||||
TableBaseMixin,
|
||||
UUIDTableBaseMixin,
|
||||
ListResponse,
|
||||
TableViewRequest,
|
||||
TimeFilterRequest,
|
||||
PaginationRequest,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DefaultResponse",
|
||||
"TokenResponse",
|
||||
"TokenData",
|
||||
"Setting",
|
||||
"SettingResponse",
|
||||
"Item",
|
||||
"ItemDataResponse",
|
||||
"ItemTypeEnum",
|
||||
"ItemStatusEnum",
|
||||
"User",
|
||||
"UserTypeEnum",
|
||||
"Database",
|
||||
# 固件相关
|
||||
"Firmware",
|
||||
"FirmwareDataResponse",
|
||||
"FirmwareDataResponseAdmin",
|
||||
"FirmwareUploadRequest",
|
||||
"FirmwareCheckUpdateRequest",
|
||||
"FirmwareCheckUpdateResponse",
|
||||
"ChipTypeEnum",
|
||||
# 新增的 Mixin 组件
|
||||
"TableBaseMixin",
|
||||
"UUIDTableBaseMixin",
|
||||
"ListResponse",
|
||||
"TableViewRequest",
|
||||
"TimeFilterRequest",
|
||||
"PaginationRequest",
|
||||
]
|
||||
|
||||
151
model/base.py
151
model/base.py
@@ -1,151 +0,0 @@
|
||||
# model/base.py
|
||||
from datetime import datetime, timezone
|
||||
from typing import Type, TypeVar, Union, Literal, List
|
||||
|
||||
from sqlalchemy import DateTime, BinaryExpression, ClauseElement
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import SQLModel, Field, select, Relationship
|
||||
from sqlalchemy.sql._typing import _OnClauseArgument
|
||||
|
||||
B = TypeVar('B', bound='TableBase')
|
||||
M = TypeVar('M', bound='SQLModel')
|
||||
|
||||
utcnow = lambda: datetime.now(tz=timezone.utc)
|
||||
|
||||
class TableBase(AsyncAttrs, SQLModel):
|
||||
__abstract__ = True
|
||||
|
||||
created_at: datetime = Field(
|
||||
default_factory=utcnow,
|
||||
description="创建时间",
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
sa_type=DateTime,
|
||||
description="更新时间",
|
||||
sa_column_kwargs={"default": utcnow, "onupdate": utcnow},
|
||||
default_factory=utcnow
|
||||
)
|
||||
deleted_at: datetime | None = Field(
|
||||
default=None,
|
||||
description="删除时间",
|
||||
sa_column={"nullable": True}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def add(
|
||||
cls: Type[B],
|
||||
session: AsyncSession,
|
||||
instances: B | List[B],
|
||||
refresh: bool = True
|
||||
) -> B | List[B]:
|
||||
is_list = isinstance(instances, list)
|
||||
if is_list:
|
||||
session.add_all(instances)
|
||||
else:
|
||||
session.add(instances)
|
||||
await session.commit()
|
||||
if refresh:
|
||||
if is_list:
|
||||
for i in instances:
|
||||
await session.refresh(i)
|
||||
else:
|
||||
await session.refresh(instances)
|
||||
return instances
|
||||
|
||||
async def save(
|
||||
self: B,
|
||||
session: AsyncSession,
|
||||
load: Union[Relationship, None] = None, # 设默认值,避免必须传
|
||||
):
|
||||
session.add(self)
|
||||
await session.commit()
|
||||
if load is not None:
|
||||
cls = type(self)
|
||||
return await cls.get(session, cls.id == self.id, load=load) # 若该模型没有 id,请别用 load 模式
|
||||
else:
|
||||
await session.refresh(self)
|
||||
return self
|
||||
|
||||
async def update(
|
||||
self: B,
|
||||
session: AsyncSession,
|
||||
other: M,
|
||||
extra_data: dict = None,
|
||||
exclude_unset: bool = True,
|
||||
) -> B:
|
||||
self.sqlmodel_update(
|
||||
other.model_dump(exclude_unset=exclude_unset),
|
||||
update=extra_data
|
||||
)
|
||||
session.add(self)
|
||||
await session.commit()
|
||||
await session.refresh(self)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
async def delete(
|
||||
cls: Type[B],
|
||||
session: AsyncSession,
|
||||
instance: B | list[B],
|
||||
) -> None:
|
||||
if isinstance(instance, list):
|
||||
for inst in instance:
|
||||
await session.delete(inst)
|
||||
else:
|
||||
await session.delete(instance)
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls: Type[B],
|
||||
session: AsyncSession,
|
||||
condition: BinaryExpression | ClauseElement | None,
|
||||
*,
|
||||
offset: int | None = None,
|
||||
limit: int | None = None,
|
||||
fetch_mode: Literal["one", "first", "all"] = "first",
|
||||
join: Type[B] | tuple[Type[B], _OnClauseArgument] | None = None,
|
||||
options: list | None = None,
|
||||
load: Union[Relationship, None] = None,
|
||||
order_by: list[ClauseElement] | None = None
|
||||
) -> B | List[B] | None:
|
||||
statement = select(cls)
|
||||
if condition is not None:
|
||||
statement = statement.where(condition)
|
||||
if join is not None:
|
||||
statement = statement.join(*join)
|
||||
if options:
|
||||
statement = statement.options(*options)
|
||||
if load:
|
||||
statement = statement.options(selectinload(load))
|
||||
if order_by is not None:
|
||||
statement = statement.order_by(*order_by)
|
||||
if offset:
|
||||
statement = statement.offset(offset)
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
result = await session.exec(statement)
|
||||
if fetch_mode == "one":
|
||||
return result.one()
|
||||
elif fetch_mode == "first":
|
||||
return result.first()
|
||||
elif fetch_mode == "all":
|
||||
return list(result.all())
|
||||
else:
|
||||
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
|
||||
|
||||
@classmethod
|
||||
async def get_exist_one(cls: Type[B], session: AsyncSession, id: int, load: Union[Relationship, None] = None) -> B:
|
||||
instance = await cls.get(session, cls.id == id, load=load)
|
||||
if not instance:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
return instance
|
||||
|
||||
|
||||
# 需要“自增 id 主键”的模型才混入它;Setting 不混入
|
||||
class IdMixin(SQLModel):
|
||||
id: int | None = Field(default=None, primary_key=True, description="主键ID")
|
||||
28
model/base/__init__.py
Normal file
28
model/base/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from .sqlmodel_base import SQLModelBase
|
||||
from .table_base import TableBase, UUIDTableBase
|
||||
|
||||
# 新的 Mixin 类(从 foxline 项目移植)
|
||||
from ..mixin.table import (
|
||||
TableBaseMixin,
|
||||
UUIDTableBaseMixin,
|
||||
ListResponse,
|
||||
TableViewRequest,
|
||||
TimeFilterRequest,
|
||||
PaginationRequest,
|
||||
)
|
||||
|
||||
# 保持向后兼容:TableBase/UUIDTableBase 作为旧名称继续可用
|
||||
# 新代码推荐使用 TableBaseMixin/UUIDTableBaseMixin
|
||||
|
||||
__all__ = [
|
||||
"SQLModelBase",
|
||||
"TableBase",
|
||||
"UUIDTableBase",
|
||||
# 新的 Mixin 类
|
||||
"TableBaseMixin",
|
||||
"UUIDTableBaseMixin",
|
||||
"ListResponse",
|
||||
"TableViewRequest",
|
||||
"TimeFilterRequest",
|
||||
"PaginationRequest",
|
||||
]
|
||||
5
model/base/sqlmodel_base.py
Normal file
5
model/base/sqlmodel_base.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from pydantic import ConfigDict
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
class SQLModelBase(SQLModel):
|
||||
model_config = ConfigDict(use_attribute_docstrings=True)
|
||||
201
model/base/table_base.py
Normal file
201
model/base/table_base.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Union, TypeVar, Type, Literal, override, Optional, Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import DateTime, BinaryExpression, ClauseElement
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import Field, select, Relationship, SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.sql._typing import _OnClauseArgument
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
|
||||
T = TypeVar("T", bound="TableBase")
|
||||
M = TypeVar("M", bound="SQLModel")
|
||||
|
||||
now = lambda: datetime.now()
|
||||
now_date = lambda: datetime.now().date()
|
||||
|
||||
class TableBase(AsyncAttrs):
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
|
||||
created_at: datetime = Field(default_factory=now)
|
||||
updated_at: datetime = Field(
|
||||
sa_type=DateTime,
|
||||
sa_column_kwargs={"default": now, "onupdate": now},
|
||||
default_factory=now
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def add(cls: Type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
|
||||
"""
|
||||
新增一条记录
|
||||
:param session: 数据库会话
|
||||
:param instances:
|
||||
:param refresh:
|
||||
:return: 新增的实例对象
|
||||
|
||||
usage:
|
||||
item1 = Item(...)
|
||||
item2 = Item(...)
|
||||
|
||||
Item.add(session, [item1, item2])
|
||||
|
||||
item1_id = item1.id
|
||||
"""
|
||||
is_list = False
|
||||
if isinstance(instances, list):
|
||||
is_list = True
|
||||
session.add_all(instances)
|
||||
else:
|
||||
session.add(instances)
|
||||
|
||||
await session.commit()
|
||||
|
||||
if refresh:
|
||||
if is_list:
|
||||
for instance in instances:
|
||||
await session.refresh(instance)
|
||||
else:
|
||||
await session.refresh(instances)
|
||||
|
||||
return instances
|
||||
|
||||
async def save(self: T, session: AsyncSession, load: Optional[Relationship] = None) -> T:
|
||||
session.add(self)
|
||||
await session.commit()
|
||||
|
||||
if load is not None:
|
||||
cls = type(self)
|
||||
return await cls.get(session, cls.id == self.id, load=load)
|
||||
else:
|
||||
await session.refresh(self)
|
||||
return self
|
||||
|
||||
async def update(
|
||||
self: T,
|
||||
session: AsyncSession,
|
||||
other: M,
|
||||
extra_data: dict = None,
|
||||
exclude_unset: bool = True
|
||||
) -> T:
|
||||
"""
|
||||
更新记录
|
||||
:param session: 数据库会话
|
||||
:param other:
|
||||
:param extra_data:
|
||||
:param exclude_unset:
|
||||
:return:
|
||||
"""
|
||||
self.sqlmodel_update(other.model_dump(exclude_unset=exclude_unset), update=extra_data)
|
||||
|
||||
session.add(self)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(self)
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
async def delete(cls: Type[T], session: AsyncSession, instances: T | list[T]) -> None:
|
||||
"""
|
||||
删除一些记录
|
||||
:param session: 数据库会话
|
||||
:param instances:
|
||||
:return: None
|
||||
|
||||
usage:
|
||||
item1 = Item.get(...)
|
||||
item2 = Item.get(...)
|
||||
|
||||
Item.delete(session, [item1, item2])
|
||||
|
||||
"""
|
||||
if isinstance(instances, list):
|
||||
for instance in instances:
|
||||
await session.delete(instance)
|
||||
else:
|
||||
await session.delete(instances)
|
||||
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls: Type[T],
|
||||
session: AsyncSession,
|
||||
condition: BinaryExpression | ClauseElement | None,
|
||||
*,
|
||||
offset: int | None = None,
|
||||
limit: int | None = None,
|
||||
fetch_mode: Literal["one", "first", "all"] = "first",
|
||||
join: Type[T] | tuple[Type[T], _OnClauseArgument] | None = None,
|
||||
options: list | None = None,
|
||||
load: Union[Relationship, None] = None,
|
||||
order_by: list[ClauseElement] | None = None
|
||||
) -> T | list[T] | None:
|
||||
"""
|
||||
异步获取模型实例
|
||||
|
||||
参数:
|
||||
session: 异步数据库会话
|
||||
condition: SQLAlchemy查询条件,如Model.id == 1
|
||||
offset: 结果偏移量
|
||||
limit: 结果数量限制
|
||||
options: 查询选项,如selectinload(Model.relation),异步访问关系属性必备,不然会报错
|
||||
fetch_mode: 获取模式 - "one"/"all"/"first"
|
||||
join: 要联接的模型类
|
||||
|
||||
返回:
|
||||
根据fetch_mode返回相应的查询结果
|
||||
"""
|
||||
statement = select(cls)
|
||||
|
||||
if condition is not None:
|
||||
statement = statement.where(condition)
|
||||
|
||||
if join is not None:
|
||||
statement = statement.join(*join)
|
||||
|
||||
if options:
|
||||
statement = statement.options(*options)
|
||||
|
||||
if load:
|
||||
statement = statement.options(selectinload(load))
|
||||
|
||||
if order_by is not None:
|
||||
statement = statement.order_by(*order_by)
|
||||
|
||||
if offset:
|
||||
statement = statement.offset(offset)
|
||||
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
result = await session.exec(statement)
|
||||
|
||||
if fetch_mode == "one":
|
||||
return result.one()
|
||||
elif fetch_mode == "first":
|
||||
return result.first()
|
||||
elif fetch_mode == "all":
|
||||
return list(result.all())
|
||||
else:
|
||||
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
|
||||
|
||||
@classmethod
|
||||
async def get_exist_one(cls: Type[T], session: AsyncSession, id: int, load: Union[Relationship, None] = None) -> T:
|
||||
"""此方法和 await session.get(cls, 主键)的区别就是当不存在时不返回None,
|
||||
而是会抛出fastapi 404 异常"""
|
||||
instance = await cls.get(session, cls.id == id, load=load)
|
||||
if not instance:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
return instance
|
||||
|
||||
class UUIDTableBase(TableBase):
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
"""override"""
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
async def get_exist_one(cls: Type[T], session: AsyncSession, id: uuid.UUID, load: Union[Relationship, None] = None) -> T:
|
||||
return await super().get_exist_one(session, id, load)
|
||||
@@ -1,9 +1,8 @@
|
||||
# ~/models/database.py
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, ClassVar
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel
|
||||
@@ -30,20 +29,34 @@ engine: AsyncEngine = create_async_engine(
|
||||
# max_overflow=64,
|
||||
)
|
||||
|
||||
_async_session_factory = sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
_async_session_factory = sessionmaker(engine, class_=AsyncSession)
|
||||
|
||||
|
||||
# 数据库类
|
||||
class Database:
|
||||
# Database 初始化方法
|
||||
"""
|
||||
数据库管理类(单例模式)
|
||||
|
||||
从 foxline 项目移植的改进版本,支持:
|
||||
- ClassVar 单例模式
|
||||
- 触发器 SQL 支持
|
||||
- 优雅关闭
|
||||
"""
|
||||
engine: ClassVar[AsyncEngine | None] = None
|
||||
_async_session_factory: ClassVar[sessionmaker | None] = None
|
||||
|
||||
def __init__(
|
||||
self, # self 用于引用类的实例
|
||||
self,
|
||||
db_path: str = "data.db", # db_path 数据库文件路径,默认为 data.db
|
||||
):
|
||||
# 保持向后兼容:实例化时使用全局 engine
|
||||
self.db_path = db_path
|
||||
|
||||
@classmethod
|
||||
def get_engine(cls) -> AsyncEngine:
|
||||
"""获取数据库引擎"""
|
||||
return engine
|
||||
|
||||
@staticmethod
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""FastAPI dependency to get a database session."""
|
||||
@@ -57,18 +70,46 @@ class Database:
|
||||
提供异步上下文管理器用于直接获取数据库会话
|
||||
|
||||
使用示例:
|
||||
async with Database.session_context() as session:
|
||||
>>> async with Database.session_context() as session:
|
||||
# 执行数据库操作
|
||||
pass
|
||||
"""
|
||||
async with _async_session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def init_db(self, url: str = ASYNC_DATABASE_URL):
|
||||
"""创建数据库结构"""
|
||||
async def init_db(
|
||||
self,
|
||||
trigger_sqls: list[tuple[str, str, str]] | None = None,
|
||||
):
|
||||
"""
|
||||
创建数据库结构
|
||||
|
||||
Args:
|
||||
trigger_sqls: 触发器SQL语句列表,每个元素为 (function_sql, drop_trigger_sql, create_trigger_sql)
|
||||
"""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
# 创建触发器(如果提供)
|
||||
if trigger_sqls:
|
||||
for function_sql, drop_trigger_sql, create_trigger_sql in trigger_sqls:
|
||||
await conn.exec_driver_sql(function_sql)
|
||||
await conn.exec_driver_sql(drop_trigger_sql)
|
||||
await conn.exec_driver_sql(create_trigger_sql)
|
||||
|
||||
# For internal use, create a temporary context manager
|
||||
async with self.session_context() as session:
|
||||
await migration(session) # 执行迁移脚本
|
||||
|
||||
@classmethod
|
||||
async def close(cls):
|
||||
"""
|
||||
优雅地关闭数据库连接引擎。
|
||||
|
||||
仅应在应用结束时调用。
|
||||
|
||||
这会释放引擎维护的所有数据库连接池资源。
|
||||
在应用程序关闭时调用此方法是一个好习惯。
|
||||
"""
|
||||
if engine:
|
||||
await engine.dispose()
|
||||
|
||||
125
model/firmware.py
Normal file
125
model/firmware.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""固件包数据模型,用于 ESP32/8266 OTA 在线升级功能。"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, String, Text
|
||||
|
||||
from .base import SQLModelBase, UUIDTableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
class ChipTypeEnum(StrEnum):
|
||||
"""ESP 芯片类型枚举"""
|
||||
esp32 = 'esp32'
|
||||
esp8266 = 'esp8266'
|
||||
esp32s2 = 'esp32s2'
|
||||
esp32s3 = 'esp32s3'
|
||||
esp32c3 = 'esp32c3'
|
||||
|
||||
|
||||
class FirmwareBase(SQLModelBase):
|
||||
chip_type: ChipTypeEnum = Field(index=True)
|
||||
"""芯片类型"""
|
||||
|
||||
version: str = Field(sa_type=String(64), index=True)
|
||||
"""固件版本号,遵循语义化版本规范"""
|
||||
|
||||
file_path: str
|
||||
"""固件文件存储路径"""
|
||||
|
||||
file_size: int
|
||||
"""固件文件大小(字节)"""
|
||||
|
||||
file_md5: str = Field(max_length=32)
|
||||
"""固件文件 MD5 校验值"""
|
||||
|
||||
description: str | None = Field(default=None, sa_type=Text)
|
||||
"""固件更新说明"""
|
||||
|
||||
is_active: bool = Field(default=True, index=True)
|
||||
"""是否启用该固件版本"""
|
||||
|
||||
|
||||
class Firmware(FirmwareBase, UUIDTableBase, table=True):
|
||||
"""固件包表"""
|
||||
|
||||
uploaded_by_id: UUID = Field(foreign_key='user.id', ondelete='RESTRICT')
|
||||
"""上传者用户ID"""
|
||||
|
||||
downloaded_count: int = Field(default=0)
|
||||
"""下载次数统计"""
|
||||
|
||||
uploaded_at: datetime = Field(default_factory=datetime.now)
|
||||
"""上传时间"""
|
||||
|
||||
uploaded_by: 'User' = Relationship(back_populates='firmwares')
|
||||
|
||||
|
||||
# DTO 定义
|
||||
|
||||
class FirmwareDataResponse(FirmwareBase):
|
||||
"""固件信息响应"""
|
||||
id: UUID
|
||||
"""固件ID"""
|
||||
|
||||
downloaded_count: int
|
||||
"""下载次数"""
|
||||
|
||||
uploaded_at: datetime
|
||||
"""上传时间"""
|
||||
|
||||
download_url: str | None = None
|
||||
"""下载地址"""
|
||||
|
||||
|
||||
class FirmwareDataResponseAdmin(FirmwareDataResponse):
|
||||
"""固件信息响应(管理员)"""
|
||||
uploaded_by_id: UUID
|
||||
"""上传者ID"""
|
||||
|
||||
|
||||
class FirmwareUploadRequest(SQLModelBase):
|
||||
"""固件上传请求"""
|
||||
chip_type: ChipTypeEnum
|
||||
"""芯片类型"""
|
||||
|
||||
version: str
|
||||
"""版本号字符串"""
|
||||
|
||||
description: str | None = None
|
||||
"""更新说明"""
|
||||
|
||||
|
||||
class FirmwareCheckUpdateRequest(SQLModelBase):
|
||||
"""设备检查更新请求"""
|
||||
chip_type: ChipTypeEnum
|
||||
"""芯片类型"""
|
||||
|
||||
current_version: str
|
||||
"""当前版本号"""
|
||||
|
||||
|
||||
class FirmwareCheckUpdateResponse(SQLModelBase):
|
||||
"""检查更新响应"""
|
||||
has_update: bool
|
||||
"""是否有可用更新"""
|
||||
|
||||
latest_version: str | None = None
|
||||
"""最新版本号"""
|
||||
|
||||
download_url: str | None = None
|
||||
"""下载地址"""
|
||||
|
||||
file_size: int | None = None
|
||||
"""文件大小"""
|
||||
|
||||
file_md5: str | None = None
|
||||
"""文件MD5"""
|
||||
|
||||
description: str | None = None
|
||||
"""更新说明"""
|
||||
85
model/item.py
Normal file
85
model/item.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from uuid import UUID
|
||||
from sqlmodel import Field, Relationship, String
|
||||
from pydantic_extra_types.semantic_version import SemanticVersion
|
||||
|
||||
from .base import SQLModelBase, UUIDTableBase
|
||||
from .firmware import ChipTypeEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
class ItemTypeEnum(StrEnum):
|
||||
normal = 'normal'
|
||||
car = 'car'
|
||||
esp32 = 'esp32'
|
||||
|
||||
class ItemStatusEnum(StrEnum):
|
||||
ok = 'ok'
|
||||
lost = 'lost'
|
||||
|
||||
class ItemBase(SQLModelBase):
|
||||
type: ItemTypeEnum = ItemTypeEnum.normal
|
||||
"""物品的类型"""
|
||||
|
||||
name: str
|
||||
"""物品名称"""
|
||||
|
||||
icon: str | None = None
|
||||
"""物品图标"""
|
||||
|
||||
status: ItemStatusEnum = ItemStatusEnum.ok
|
||||
"""物品状态"""
|
||||
|
||||
phone: str | None = None
|
||||
"""联系电话"""
|
||||
|
||||
description: str | None = None
|
||||
"""物品描述"""
|
||||
|
||||
version: SemanticVersion = Field(sa_type=String(64))
|
||||
"""版本号"""
|
||||
|
||||
chip_type: ChipTypeEnum | None = Field(default=None, index=True)
|
||||
"""ESP设备芯片类型,仅当type=esp32时有值"""
|
||||
|
||||
class Item(ItemBase, UUIDTableBase, table=True):
|
||||
expires_at: datetime | None = None
|
||||
"""物品过期时间"""
|
||||
|
||||
lost_at: datetime | None = None
|
||||
"""物品丢失的时间"""
|
||||
|
||||
find_ip: str | None = None
|
||||
"""最后一次发现的IP地址"""
|
||||
|
||||
user_id: UUID = Field(foreign_key='user.id', ondelete='CASCADE')
|
||||
"""所属用户ID"""
|
||||
|
||||
user: 'User' = Relationship(back_populates='items')
|
||||
|
||||
parent_item_id: UUID | None = Field(foreign_key='item.id', ondelete='RESTRICT', default=None)
|
||||
parent_item: Optional['Item'] = Relationship(back_populates='sub_items', sa_relationship_kwargs={'remote_side': 'Item.id'})
|
||||
sub_items: list['Item'] = Relationship(back_populates='parent_item', passive_deletes='all')
|
||||
|
||||
class ItemDataUpdateRequest(ItemBase):
|
||||
pass
|
||||
|
||||
class ItemDataResponse(ItemBase):
|
||||
expires_at: datetime | None = None
|
||||
"""物品过期时间"""
|
||||
|
||||
lost_at: datetime | None = None
|
||||
"""物品丢失的时间"""
|
||||
|
||||
class ItemDataResponseAdmin(ItemBase):
|
||||
expires_at: datetime | None = None
|
||||
"""物品过期时间"""
|
||||
|
||||
lost_at: datetime | None = None
|
||||
"""物品丢失的时间"""
|
||||
|
||||
user_id: UUID = Field(foreign_key='user.id')
|
||||
"""所属用户ID"""
|
||||
@@ -1,14 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Item(BaseModel):
|
||||
id: int
|
||||
type: str
|
||||
key: str
|
||||
name: str
|
||||
icon: str
|
||||
status: str
|
||||
phone: int
|
||||
lost_description: str | None
|
||||
find_ip: str | None
|
||||
create_time: str
|
||||
lost_time: str | None
|
||||
@@ -1,12 +1,14 @@
|
||||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
from .setting import Setting
|
||||
from .user import User, UserTypeEnum
|
||||
from pkg import Password
|
||||
|
||||
default_settings: list[Setting] = [
|
||||
Setting(type='string', name='version', value='1.0.0'),
|
||||
Setting(type='int', name='ver', value='1'),
|
||||
Setting(type='string', name='account', value='admin@yuxiaoqiu.cn'),
|
||||
Setting(type='string', name='version', value='2.0.0'), # 版本号,用于考虑是否需要数据迁移
|
||||
Setting(type='int', name='jwt_token_exp', value='30'), # JWT Token 访问令牌
|
||||
Setting(type='int', name='mentioned_channel', value='wechat_bot'), # 通知推送通道
|
||||
Setting(type='string', name='server_chan_key', value=''), # Server 酱推送密钥
|
||||
Setting(type='string', name='wechat_bot_key', value=''), # 企业微信机器人推送密钥
|
||||
]
|
||||
|
||||
async def migration(session):
|
||||
@@ -17,20 +19,38 @@ async def migration(session):
|
||||
# 已有数据,说明不是第一次运行,直接返回
|
||||
return
|
||||
|
||||
# 生成初始密码与密钥
|
||||
admin_password = Password.generate()
|
||||
logger.warning(f"密码(请牢记,后续不再显示): {admin_password}")
|
||||
|
||||
settings.append(Setting(type='string', name='password', value=Password.hash(admin_password)))
|
||||
settings.append(Setting(type='string', name='SECRET_KEY', value=Password.generate(64)))
|
||||
|
||||
# 读取库里已存在的 name,避免主键冲突
|
||||
names = [s.name for s in settings]
|
||||
exist_stmt = select(Setting.name).where(Setting.name.in_(names))
|
||||
exist_rs = await session.exec(exist_stmt)
|
||||
existed: set[str] = set(exist_rs.all())
|
||||
existed_settings = await Setting.get(
|
||||
session,
|
||||
Setting.name in names,
|
||||
fetch_mode='all'
|
||||
)
|
||||
existed: set[str] = {s.name for s in (existed_settings or [])}
|
||||
|
||||
to_insert = [s for s in settings if s.name not in existed]
|
||||
if to_insert:
|
||||
# 使用你写好的通用新增方法(是类方法),并传入会话
|
||||
await Setting.add(session, to_insert, refresh=False)
|
||||
await Setting.add(session, to_insert)
|
||||
|
||||
if not await User.get(session, User.role == UserTypeEnum.super_admin):
|
||||
# 生成初始密码与密钥
|
||||
admin_password = Password.generate()
|
||||
logger.warning("当前无管理员用户,已自动创建初始管理员用户:")
|
||||
logger.warning("邮箱: admin@yxqi.cn")
|
||||
logger.warning(f"密码: {admin_password}")
|
||||
|
||||
User._initializing = True
|
||||
|
||||
admin_user = User(
|
||||
email='admin@yxqi.cn',
|
||||
nickname='Admin',
|
||||
password=Password.hash(admin_password),
|
||||
role=UserTypeEnum.super_admin,
|
||||
_initializing=True
|
||||
)
|
||||
|
||||
await User.add(session, admin_user)
|
||||
|
||||
User._initializing = False
|
||||
|
||||
17
model/mixin/__init__.py
Normal file
17
model/mixin/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .table import (
|
||||
TableBaseMixin,
|
||||
UUIDTableBaseMixin,
|
||||
ListResponse,
|
||||
TableViewRequest,
|
||||
TimeFilterRequest,
|
||||
PaginationRequest,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TableBaseMixin",
|
||||
"UUIDTableBaseMixin",
|
||||
"ListResponse",
|
||||
"TableViewRequest",
|
||||
"TimeFilterRequest",
|
||||
"PaginationRequest",
|
||||
]
|
||||
852
model/mixin/table.py
Normal file
852
model/mixin/table.py
Normal file
@@ -0,0 +1,852 @@
|
||||
"""
|
||||
表基类 Mixin
|
||||
|
||||
提供 TableBaseMixin、UUIDTableBaseMixin 和 TableViewRequest。
|
||||
这些类实际上是 Mixin,为 SQLModel 模型提供 CRUD 操作和时间戳字段。
|
||||
|
||||
版本历史:
|
||||
0.1.0 - delete() 方法支持条件删除(condition 参数)
|
||||
"""
|
||||
__version__ = "0.1.0"
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TypeVar, Literal, override, Any, ClassVar, Generic
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, delete as sql_delete
|
||||
from sqlalchemy.orm import selectinload, Relationship
|
||||
from sqlmodel import Field, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.sql._typing import _OnClauseArgument
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
from ..base.sqlmodel_base import SQLModelBase
|
||||
|
||||
# Type variables for generic type hints, improving code completion and analysis.
|
||||
T = TypeVar("T", bound="TableBaseMixin")
|
||||
M = TypeVar("M", bound="SQLModelBase")
|
||||
ItemT = TypeVar("ItemT")
|
||||
|
||||
|
||||
class ListResponse(BaseModel, Generic[ItemT]):
|
||||
"""
|
||||
泛型分页响应
|
||||
|
||||
用于所有LIST端点的标准化响应格式,包含记录总数和项目列表。
|
||||
与 TableBaseMixin.get_with_count() 配合使用。
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
@router.get("", response_model=ListResponse[CharacterInfoResponse])
|
||||
async def list_characters(...) -> ListResponse[Character]:
|
||||
return await Character.get_with_count(session, table_view=table_view)
|
||||
```
|
||||
|
||||
Attributes:
|
||||
count: 符合条件的记录总数(用于分页计算)
|
||||
items: 当前页的记录列表
|
||||
|
||||
Note:
|
||||
继承BaseModel而非SQLModelBase,因为SQLModel的metaclass与Generic冲突。
|
||||
"""
|
||||
model_config = ConfigDict(use_attribute_docstrings=True)
|
||||
|
||||
count: int
|
||||
"""符合条件的记录总数"""
|
||||
|
||||
items: list[ItemT]
|
||||
"""当前页的记录列表"""
|
||||
|
||||
|
||||
# Lambda functions to get the current time, used as default factories in model fields.
|
||||
now = lambda: datetime.now()
|
||||
now_date = lambda: datetime.now().date()
|
||||
|
||||
|
||||
# ==================== 查询参数请求类 ====================
|
||||
|
||||
class TimeFilterRequest(SQLModelBase):
|
||||
"""
|
||||
时间筛选请求参数
|
||||
|
||||
用于 count() 等只需要时间筛选的场景。
|
||||
纯数据类,只负责参数校验和携带,SQL子句构建由 TableBaseMixin 负责。
|
||||
|
||||
Raises:
|
||||
ValueError: 时间范围无效
|
||||
"""
|
||||
created_after_datetime: datetime | None = None
|
||||
"""创建时间起始筛选(created_at >= datetime),如果为None则不限制"""
|
||||
|
||||
created_before_datetime: datetime | None = None
|
||||
"""创建时间结束筛选(created_at < datetime),如果为None则不限制"""
|
||||
|
||||
updated_after_datetime: datetime | None = None
|
||||
"""更新时间起始筛选(updated_at >= datetime),如果为None则不限制"""
|
||||
|
||||
updated_before_datetime: datetime | None = None
|
||||
"""更新时间结束筛选(updated_at < datetime),如果为None则不限制"""
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""
|
||||
验证时间范围有效性
|
||||
|
||||
验证规则:
|
||||
1. 同类型:after 必须小于 before
|
||||
2. 跨类型:created_after 不能大于 updated_before(记录不可能在创建前被更新)
|
||||
"""
|
||||
# 同类型矛盾验证
|
||||
if self.created_after_datetime and self.created_before_datetime:
|
||||
if self.created_after_datetime >= self.created_before_datetime:
|
||||
raise ValueError("created_after_datetime 必须小于 created_before_datetime")
|
||||
if self.updated_after_datetime and self.updated_before_datetime:
|
||||
if self.updated_after_datetime >= self.updated_before_datetime:
|
||||
raise ValueError("updated_after_datetime 必须小于 updated_before_datetime")
|
||||
|
||||
# 跨类型矛盾验证:created_after >= updated_before 意味着要求创建时间晚于或等于更新时间上界,逻辑矛盾
|
||||
if self.created_after_datetime and self.updated_before_datetime:
|
||||
if self.created_after_datetime >= self.updated_before_datetime:
|
||||
raise ValueError(
|
||||
"created_after_datetime 不能大于或等于 updated_before_datetime"
|
||||
"(记录的更新时间不可能早于或等于创建时间)"
|
||||
)
|
||||
|
||||
|
||||
class PaginationRequest(SQLModelBase):
|
||||
"""
|
||||
分页排序请求参数
|
||||
|
||||
用于需要分页和排序的场景。
|
||||
纯数据类,只负责携带参数,SQL子句构建由 TableBaseMixin 负责。
|
||||
"""
|
||||
offset: int | None = Field(default=0, ge=0)
|
||||
"""偏移量(跳过前N条记录),必须为非负整数"""
|
||||
|
||||
limit: int | None = Field(default=50, le=100)
|
||||
"""每页数量(返回最多N条记录),默认50,最大100"""
|
||||
|
||||
desc: bool | None = True
|
||||
"""是否降序排序(True: 降序, False: 升序)"""
|
||||
|
||||
order: Literal["created_at", "updated_at"] | None = "created_at"
|
||||
"""排序字段(created_at: 创建时间, updated_at: 更新时间)"""
|
||||
|
||||
|
||||
class TableViewRequest(TimeFilterRequest, PaginationRequest):
|
||||
"""
|
||||
表格视图请求参数(分页、排序和时间筛选)
|
||||
|
||||
组合继承 TimeFilterRequest 和 PaginationRequest,用于 get() 等需要完整查询参数的场景。
|
||||
纯数据类,SQL子句构建由 TableBaseMixin 负责。
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
# 在端点中使用依赖注入
|
||||
@router.get("/list")
|
||||
async def list_items(
|
||||
session: SessionDep,
|
||||
table_view: TableViewRequestDep
|
||||
):
|
||||
items = await Item.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
table_view=table_view
|
||||
)
|
||||
return items
|
||||
|
||||
# 直接使用
|
||||
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
|
||||
items = await Item.get(session, fetch_mode="all", table_view=table_view)
|
||||
```
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# ==================== TableBaseMixin ====================
|
||||
|
||||
class TableBaseMixin(AsyncAttrs):
|
||||
"""
|
||||
一个异步 CRUD 操作的基础模型类 Mixin.
|
||||
|
||||
此类必须搭配SQLModelBase使用
|
||||
|
||||
此类为所有继承它的 SQLModel 模型提供了通用的数据库操作方法,
|
||||
例如 add, save, update, delete, 和 get. 它还包括自动管理
|
||||
的 `created_at` 和 `updated_at` 时间戳字段.
|
||||
|
||||
Attributes:
|
||||
id (int | None): 整数主键, 自动递增.
|
||||
created_at (datetime): 记录创建时的时间戳, 自动设置.
|
||||
updated_at (datetime): 记录每次更新时的时间戳, 自动更新.
|
||||
"""
|
||||
_is_table_mixin: ClassVar[bool] = True
|
||||
"""标记此类为表混入类的内部属性"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
接受并传递子类定义时的关键字参数
|
||||
|
||||
这允许元类 __DeclarativeMeta 处理的参数(如 table_args)
|
||||
能够正确传递,而不会在 __init_subclass__ 阶段报错。
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
|
||||
created_at: datetime = Field(default_factory=now)
|
||||
updated_at: datetime = Field(
|
||||
sa_type=DateTime,
|
||||
sa_column_kwargs={'default': now, 'onupdate': now},
|
||||
default_factory=now
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
|
||||
"""
|
||||
向数据库中添加一个新的或多个新的记录.
|
||||
|
||||
这个类方法可以接受单个模型实例或一个实例列表,并将它们
|
||||
一次性提交到数据库中。执行后,可以选择性地刷新这些实例以获取
|
||||
数据库生成的值(例如,自动递增的 ID).
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
|
||||
refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
|
||||
|
||||
Returns:
|
||||
T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例.
|
||||
|
||||
Usage:
|
||||
item1 = Item(name="Apple")
|
||||
item2 = Item(name="Banana")
|
||||
|
||||
# 添加多个实例
|
||||
added_items = await Item.add(session, [item1, item2])
|
||||
|
||||
# 添加单个实例
|
||||
item3 = Item(name="Cherry")
|
||||
added_item = await Item.add(session, item3)
|
||||
"""
|
||||
is_list = False
|
||||
if isinstance(instances, list):
|
||||
is_list = True
|
||||
session.add_all(instances)
|
||||
else:
|
||||
session.add(instances)
|
||||
|
||||
await session.commit()
|
||||
|
||||
if refresh:
|
||||
if is_list:
|
||||
for instance in instances:
|
||||
await session.refresh(instance)
|
||||
else:
|
||||
await session.refresh(instances)
|
||||
|
||||
return instances
|
||||
|
||||
async def save(
|
||||
self: T,
|
||||
session: AsyncSession,
|
||||
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||
refresh: bool = True,
|
||||
commit: bool = True
|
||||
) -> T:
|
||||
"""
|
||||
保存(插入或更新)当前模型实例到数据库.
|
||||
|
||||
这是一个实例方法,它将当前对象添加到会话中并提交更改。
|
||||
可以用于创建新记录或更新现有记录。还可以选择在保存后
|
||||
预加载(eager load)一个关联关系.
|
||||
|
||||
**重要**:调用此方法后,session中的所有对象都会过期(expired)。
|
||||
如果需要继续使用该对象,必须使用返回值:
|
||||
|
||||
```python
|
||||
# 正确:需要返回值时
|
||||
client = await client.save(session)
|
||||
return client
|
||||
|
||||
# 正确:不需要返回值时,指定 refresh=False 节省性能
|
||||
await client.save(session, refresh=False)
|
||||
|
||||
# 正确:批量操作时延迟提交
|
||||
for item in items:
|
||||
item = await item.save(session, commit=False)
|
||||
await session.commit()
|
||||
|
||||
# 错误:需要返回值但未使用
|
||||
await client.save(session)
|
||||
return client # client 对象已过期
|
||||
```
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性.
|
||||
例如 `User.posts`.
|
||||
refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值,
|
||||
设为 False 可节省一次数据库查询。默认为 True.
|
||||
commit (bool): 是否在保存后提交事务。如果为 False,只会 flush 获取 ID
|
||||
但不提交,适用于批量操作场景。默认为 True.
|
||||
|
||||
Returns:
|
||||
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
||||
"""
|
||||
session.add(self)
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
|
||||
if not refresh:
|
||||
return self
|
||||
|
||||
if load is not None:
|
||||
cls = type(self)
|
||||
await session.refresh(self)
|
||||
# 如果指定了 load, 重新获取实例并加载关联关系
|
||||
return await cls.get(session, cls.id == self.id, load=load)
|
||||
else:
|
||||
await session.refresh(self)
|
||||
return self
|
||||
|
||||
async def update(
|
||||
self: T,
|
||||
session: AsyncSession,
|
||||
other: M,
|
||||
extra_data: dict[str, Any] | None = None,
|
||||
exclude_unset: bool = True,
|
||||
exclude: set[str] | None = None,
|
||||
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||
refresh: bool = True,
|
||||
commit: bool = True
|
||||
) -> T:
|
||||
"""
|
||||
使用另一个模型实例或字典中的数据来更新当前实例.
|
||||
|
||||
此方法将 `other` 对象中的数据合并到当前实例中。默认情况下,
|
||||
它只会更新 `other` 中被显式设置的字段.
|
||||
|
||||
**重要**:调用此方法后,session中的所有对象都会过期(expired)。
|
||||
如果需要继续使用该对象,必须使用返回值:
|
||||
|
||||
```python
|
||||
# 正确:需要返回值时
|
||||
client = await client.update(session, update_data)
|
||||
return client
|
||||
|
||||
# 正确:需要返回值且需要加载关系时
|
||||
user = await user.update(session, update_data, load=User.permission)
|
||||
return user
|
||||
|
||||
# 正确:不需要返回值时,指定 refresh=False 节省性能
|
||||
await client.update(session, update_data, refresh=False)
|
||||
|
||||
# 正确:批量操作时延迟提交
|
||||
for item in items:
|
||||
item = await item.update(session, data, commit=False)
|
||||
await session.commit()
|
||||
|
||||
# 错误:需要返回值但未使用
|
||||
await client.update(session, update_data)
|
||||
return client # client 对象已过期
|
||||
```
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
other (M): 一个 SQLModel 或 Pydantic 模型实例,其数据将用于更新当前实例.
|
||||
extra_data (dict, optional): 一个额外的字典,用于更新当前实例的特定字段.
|
||||
exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供)
|
||||
的字段将被忽略. 默认为 True.
|
||||
exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}.
|
||||
load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性.
|
||||
例如 `User.permission`.
|
||||
refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值,
|
||||
设为 False 可节省一次数据库查询。默认为 True.
|
||||
commit (bool): 是否在更新后提交事务。如果为 False,只会 flush
|
||||
但不提交,适用于批量操作场景。默认为 True.
|
||||
|
||||
Returns:
|
||||
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
||||
"""
|
||||
self.sqlmodel_update(
|
||||
other.model_dump(exclude_unset=exclude_unset, exclude=exclude),
|
||||
update=extra_data
|
||||
)
|
||||
|
||||
session.add(self)
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
|
||||
if not refresh:
|
||||
return self
|
||||
|
||||
if load is not None:
|
||||
cls = type(self)
|
||||
await session.refresh(self)
|
||||
return await cls.get(session, cls.id == self.id, load=load)
|
||||
else:
|
||||
await session.refresh(self)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
async def delete(
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
instances: T | list[T] | None = None,
|
||||
*,
|
||||
condition: BinaryExpression | ClauseElement | None = None,
|
||||
commit: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
从数据库中删除记录,支持实例删除和条件删除两种模式。
|
||||
|
||||
Args:
|
||||
session: 用于数据库操作的异步会话对象
|
||||
instances: 要删除的单个模型实例或模型实例列表(实例删除模式)
|
||||
condition: WHERE 条件表达式(条件删除模式,直接执行 SQL DELETE)
|
||||
commit: 是否在删除后提交事务。默认为 True
|
||||
|
||||
Returns:
|
||||
删除的记录数(条件删除模式返回实际删除数,实例删除模式返回实例数)
|
||||
|
||||
Raises:
|
||||
ValueError: 同时提供 instances 和 condition,或两者都未提供
|
||||
|
||||
Usage:
|
||||
# 实例删除模式
|
||||
item = await Item.get(session, Item.id == 1)
|
||||
if item:
|
||||
await Item.delete(session, item)
|
||||
|
||||
items = await Item.get(session, Item.name.in_(["A", "B"]), fetch_mode="all")
|
||||
if items:
|
||||
await Item.delete(session, items)
|
||||
|
||||
# 条件删除模式(高效批量删除,不加载实例到内存)
|
||||
deleted_count = await Item.delete(
|
||||
session,
|
||||
condition=(Item.user_id == user_id) & (Item.status == "expired"),
|
||||
)
|
||||
"""
|
||||
if instances is not None and condition is not None:
|
||||
raise ValueError("不能同时提供 instances 和 condition 参数")
|
||||
if instances is None and condition is None:
|
||||
raise ValueError("必须提供 instances 或 condition 参数之一")
|
||||
|
||||
deleted_count = 0
|
||||
|
||||
if condition is not None:
|
||||
# 条件删除模式:直接执行 SQL DELETE
|
||||
stmt = sql_delete(cls).where(condition)
|
||||
result = await session.execute(stmt)
|
||||
deleted_count = result.rowcount
|
||||
else:
|
||||
# 实例删除模式
|
||||
if isinstance(instances, list):
|
||||
for instance in instances:
|
||||
await session.delete(instance)
|
||||
deleted_count = len(instances)
|
||||
else:
|
||||
await session.delete(instances)
|
||||
deleted_count = 1
|
||||
|
||||
if commit:
|
||||
await session.commit()
|
||||
|
||||
return deleted_count
|
||||
|
||||
@classmethod
|
||||
def _build_time_filters(
|
||||
cls: type[T],
|
||||
created_before_datetime: datetime | None = None,
|
||||
created_after_datetime: datetime | None = None,
|
||||
updated_before_datetime: datetime | None = None,
|
||||
updated_after_datetime: datetime | None = None,
|
||||
) -> list[BinaryExpression]:
|
||||
"""
|
||||
构建时间筛选条件列表
|
||||
|
||||
Args:
|
||||
created_before_datetime: 筛选 created_at < datetime 的记录
|
||||
created_after_datetime: 筛选 created_at >= datetime 的记录
|
||||
updated_before_datetime: 筛选 updated_at < datetime 的记录
|
||||
updated_after_datetime: 筛选 updated_at >= datetime 的记录
|
||||
|
||||
Returns:
|
||||
BinaryExpression 条件列表
|
||||
"""
|
||||
filters: list[BinaryExpression] = []
|
||||
if created_after_datetime is not None:
|
||||
filters.append(cls.created_at >= created_after_datetime)
|
||||
if created_before_datetime is not None:
|
||||
filters.append(cls.created_at < created_before_datetime)
|
||||
if updated_after_datetime is not None:
|
||||
filters.append(cls.updated_at >= updated_after_datetime)
|
||||
if updated_before_datetime is not None:
|
||||
filters.append(cls.updated_at < updated_before_datetime)
|
||||
return filters
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
condition: BinaryExpression | ClauseElement | None = None,
|
||||
*,
|
||||
offset: int | None = None,
|
||||
limit: int | None = None,
|
||||
fetch_mode: Literal["one", "first", "all"] = "first",
|
||||
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
|
||||
options: list | None = None,
|
||||
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||
order_by: list[ClauseElement] | None = None,
|
||||
filter: BinaryExpression | ClauseElement | None = None,
|
||||
with_for_update: bool = False,
|
||||
table_view: TableViewRequest | None = None,
|
||||
created_before_datetime: datetime | None = None,
|
||||
created_after_datetime: datetime | None = None,
|
||||
updated_before_datetime: datetime | None = None,
|
||||
updated_after_datetime: datetime | None = None,
|
||||
) -> T | list[T] | None:
|
||||
"""
|
||||
根据指定的条件异步地从数据库中获取一个或多个模型实例.
|
||||
|
||||
这是一个功能强大的通用查询方法,支持过滤、排序、分页、连接查询和关联关系预加载.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
condition (BinaryExpression | ClauseElement | None): 主要的查询过滤条件,
|
||||
例如 `User.id == 1`。
|
||||
当为 `None` 时,表示无条件查询(查询所有记录)。
|
||||
offset (int | None): 查询结果的起始偏移量, 用于分页.
|
||||
limit (int | None): 返回记录的最大数量, 用于分页.
|
||||
fetch_mode (Literal["one", "first", "all"]):
|
||||
- "one": 获取唯一的一条记录. 如果找不到或找到多条,会引发异常.
|
||||
- "first": 获取查询结果的第一条记录. 如果找不到,返回 `None`.
|
||||
- "all": 获取所有匹配的记录,返回一个列表.
|
||||
默认为 "first".
|
||||
join (type[T] | tuple[type[T], _OnClauseArgument] | None):
|
||||
要 JOIN 的模型类或一个包含模型类和 ON 子句的元组.
|
||||
例如 `User` 或 `(Profile, User.id == Profile.user_id)`.
|
||||
options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
|
||||
例如 `[selectinload(User.posts)]`.
|
||||
load (Relationship | list[Relationship] | None): `selectinload` 的快捷方式,用于预加载关联关系.
|
||||
可以是单个关系或关系列表。例如 `User.profile` 或 `[User.profile, User.posts]`.
|
||||
order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
|
||||
例如 `[User.name.asc(), User.created_at.desc()]`.
|
||||
filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
|
||||
|
||||
with_for_update (bool): 如果为 True, 在查询中使用 `FOR UPDATE` 锁定选定的行. 默认为 False.
|
||||
|
||||
table_view (TableViewRequest | None): TableViewRequest对象,如果提供则自动处理分页、排序和时间筛选。
|
||||
会覆盖offset、limit、order_by及时间筛选参数。
|
||||
这是推荐的分页排序方式,统一了所有LIST端点的参数格式。
|
||||
|
||||
created_before_datetime (datetime | None): 筛选 created_at < datetime 的记录
|
||||
created_after_datetime (datetime | None): 筛选 created_at >= datetime 的记录
|
||||
updated_before_datetime (datetime | None): 筛选 updated_at < datetime 的记录
|
||||
updated_after_datetime (datetime | None): 筛选 updated_at >= datetime 的记录
|
||||
|
||||
Returns:
|
||||
T | list[T] | None: 根据 `fetch_mode` 的设置,返回单个实例、实例列表或 `None`.
|
||||
|
||||
Raises:
|
||||
ValueError: 如果提供了无效的 `fetch_mode` 值.
|
||||
|
||||
Examples:
|
||||
# 使用table_view参数(推荐)
|
||||
users = await User.get(session, fetch_mode="all", table_view=table_view_args)
|
||||
|
||||
# 传统方式(向后兼容)
|
||||
users = await User.get(session, fetch_mode="all", offset=0, limit=20, order_by=[desc(User.created_at)])
|
||||
"""
|
||||
# 如果提供table_view,作为默认值使用(单独传入的参数优先级更高)
|
||||
if table_view:
|
||||
# 处理时间筛选(TimeFilterRequest 及其子类)
|
||||
if isinstance(table_view, TimeFilterRequest):
|
||||
if created_after_datetime is None and table_view.created_after_datetime is not None:
|
||||
created_after_datetime = table_view.created_after_datetime
|
||||
if created_before_datetime is None and table_view.created_before_datetime is not None:
|
||||
created_before_datetime = table_view.created_before_datetime
|
||||
if updated_after_datetime is None and table_view.updated_after_datetime is not None:
|
||||
updated_after_datetime = table_view.updated_after_datetime
|
||||
if updated_before_datetime is None and table_view.updated_before_datetime is not None:
|
||||
updated_before_datetime = table_view.updated_before_datetime
|
||||
# 处理分页排序(PaginationRequest 及其子类,包括 TableViewRequest)
|
||||
if isinstance(table_view, PaginationRequest):
|
||||
if offset is None:
|
||||
offset = table_view.offset
|
||||
if limit is None:
|
||||
limit = table_view.limit
|
||||
# 仅在未显式传入order_by时,从table_view构建排序子句
|
||||
if order_by is None:
|
||||
order_column = cls.created_at if table_view.order == "created_at" else cls.updated_at
|
||||
order_by = [desc(order_column) if table_view.desc else asc(order_column)]
|
||||
|
||||
statement = select(cls)
|
||||
|
||||
if condition is not None:
|
||||
statement = statement.where(condition)
|
||||
|
||||
# 应用时间筛选
|
||||
for time_filter in cls._build_time_filters(
|
||||
created_before_datetime, created_after_datetime,
|
||||
updated_before_datetime, updated_after_datetime
|
||||
):
|
||||
statement = statement.where(time_filter)
|
||||
|
||||
if join is not None:
|
||||
# 如果 join 是一个元组,解包它;否则直接使用
|
||||
if isinstance(join, tuple):
|
||||
statement = statement.join(*join)
|
||||
else:
|
||||
statement = statement.join(join)
|
||||
|
||||
if options:
|
||||
statement = statement.options(*options)
|
||||
|
||||
if load:
|
||||
# 标准化为列表
|
||||
load_list = load if isinstance(load, list) else [load]
|
||||
# 为每个关系添加 selectinload
|
||||
for rel in load_list:
|
||||
statement = statement.options(selectinload(rel))
|
||||
|
||||
if order_by is not None:
|
||||
statement = statement.order_by(*order_by)
|
||||
|
||||
if offset:
|
||||
statement = statement.offset(offset)
|
||||
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
if filter:
|
||||
statement = statement.filter(filter)
|
||||
|
||||
if with_for_update:
|
||||
statement = statement.with_for_update()
|
||||
|
||||
result = await session.exec(statement)
|
||||
|
||||
if fetch_mode == "one":
|
||||
return result.one()
|
||||
elif fetch_mode == "first":
|
||||
return result.first()
|
||||
elif fetch_mode == "all":
|
||||
return list(result.all())
|
||||
else:
|
||||
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
|
||||
|
||||
@classmethod
|
||||
async def count(
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
condition: BinaryExpression | ClauseElement | None = None,
|
||||
*,
|
||||
time_filter: TimeFilterRequest | None = None,
|
||||
created_before_datetime: datetime | None = None,
|
||||
created_after_datetime: datetime | None = None,
|
||||
updated_before_datetime: datetime | None = None,
|
||||
updated_after_datetime: datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
根据条件统计记录数量(支持时间筛选)
|
||||
|
||||
使用数据库层面的 COUNT() 聚合函数,比 get() + len() 更高效。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
condition: 查询条件,例如 `User.is_active == True`
|
||||
time_filter: TimeFilterRequest 对象(优先级更高)
|
||||
created_before_datetime: 筛选 created_at < datetime 的记录
|
||||
created_after_datetime: 筛选 created_at >= datetime 的记录
|
||||
updated_before_datetime: 筛选 updated_at < datetime 的记录
|
||||
updated_after_datetime: 筛选 updated_at >= datetime 的记录
|
||||
|
||||
Returns:
|
||||
符合条件的记录数量
|
||||
|
||||
Examples:
|
||||
# 统计所有用户
|
||||
total = await User.count(session)
|
||||
|
||||
# 统计激活的用户
|
||||
count = await User.count(
|
||||
session,
|
||||
User.is_active == True
|
||||
)
|
||||
|
||||
# 使用 TimeFilterRequest 进行时间筛选
|
||||
count = await User.count(session, time_filter=time_filter_request)
|
||||
|
||||
# 使用独立时间参数
|
||||
count = await User.count(
|
||||
session,
|
||||
created_after_datetime=datetime(2025, 1, 1),
|
||||
created_before_datetime=datetime(2025, 2, 1),
|
||||
)
|
||||
"""
|
||||
# time_filter 的时间筛选优先级更高
|
||||
if isinstance(time_filter, TimeFilterRequest):
|
||||
if time_filter.created_after_datetime is not None:
|
||||
created_after_datetime = time_filter.created_after_datetime
|
||||
if time_filter.created_before_datetime is not None:
|
||||
created_before_datetime = time_filter.created_before_datetime
|
||||
if time_filter.updated_after_datetime is not None:
|
||||
updated_after_datetime = time_filter.updated_after_datetime
|
||||
if time_filter.updated_before_datetime is not None:
|
||||
updated_before_datetime = time_filter.updated_before_datetime
|
||||
|
||||
statement = select(func.count()).select_from(cls)
|
||||
|
||||
# 应用查询条件
|
||||
if condition is not None:
|
||||
statement = statement.where(condition)
|
||||
|
||||
# 应用时间筛选
|
||||
for time_condition in cls._build_time_filters(
|
||||
created_before_datetime, created_after_datetime,
|
||||
updated_before_datetime, updated_after_datetime
|
||||
):
|
||||
statement = statement.where(time_condition)
|
||||
|
||||
result = await session.scalar(statement)
|
||||
return result or 0
|
||||
|
||||
@classmethod
|
||||
async def get_with_count(
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
condition: BinaryExpression | ClauseElement | None = None,
|
||||
*,
|
||||
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
|
||||
options: list | None = None,
|
||||
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||
order_by: list[ClauseElement] | None = None,
|
||||
filter: BinaryExpression | ClauseElement | None = None,
|
||||
table_view: TableViewRequest | None = None,
|
||||
) -> 'ListResponse[T]':
|
||||
"""
|
||||
获取分页列表及总数,直接返回 ListResponse
|
||||
|
||||
同时返回符合条件的记录列表和总数,用于分页场景。
|
||||
与 get() 方法类似,但固定 fetch_mode="all" 并返回 ListResponse。
|
||||
|
||||
注意:如果子类的 get() 方法支持额外参数(如 filter_params),
|
||||
子类应该覆盖此方法以确保 count 和 items 使用相同的过滤条件。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
condition: 查询条件
|
||||
join: JOIN 的模型类或元组
|
||||
options: SQLAlchemy 查询选项
|
||||
load: selectinload 预加载关系
|
||||
order_by: 排序子句
|
||||
filter: 附加过滤条件
|
||||
table_view: 分页排序参数(推荐使用)
|
||||
|
||||
Returns:
|
||||
ListResponse[T]: 包含 count 和 items 的分页响应
|
||||
|
||||
Examples:
|
||||
```python
|
||||
@router.get("", response_model=ListResponse[CharacterInfoResponse])
|
||||
async def list_characters(
|
||||
session: SessionDep,
|
||||
table_view: TableViewRequestDep
|
||||
) -> ListResponse[Character]:
|
||||
return await Character.get_with_count(session, table_view=table_view)
|
||||
```
|
||||
"""
|
||||
# 提取时间筛选参数(用于 count)
|
||||
time_filter: TimeFilterRequest | None = None
|
||||
if table_view is not None:
|
||||
time_filter = TimeFilterRequest(
|
||||
created_after_datetime=table_view.created_after_datetime,
|
||||
created_before_datetime=table_view.created_before_datetime,
|
||||
updated_after_datetime=table_view.updated_after_datetime,
|
||||
updated_before_datetime=table_view.updated_before_datetime,
|
||||
)
|
||||
|
||||
# 获取总数(不带分页限制)
|
||||
total_count = await cls.count(session, condition, time_filter=time_filter)
|
||||
|
||||
# 获取分页数据
|
||||
items = await cls.get(
|
||||
session,
|
||||
condition,
|
||||
fetch_mode="all",
|
||||
join=join,
|
||||
options=options,
|
||||
load=load,
|
||||
order_by=order_by,
|
||||
filter=filter,
|
||||
table_view=table_view,
|
||||
)
|
||||
|
||||
return ListResponse(count=total_count, items=items)
|
||||
|
||||
@classmethod
|
||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | list[RelationshipInfo] | None = None) -> T:
|
||||
"""
|
||||
根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常.
|
||||
|
||||
这个方法是对 `get` 方法的封装,专门用于处理那种"记录必须存在"的业务场景。
|
||||
如果记录未找到,它会直接引发 FastAPI 的 `HTTPException`, 而不是返回 `None`.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
id (int): 要查找的记录的主键 ID.
|
||||
load (Relationship | None): 可选的,用于预加载的关联属性.
|
||||
|
||||
Returns:
|
||||
T: 找到的模型实例.
|
||||
|
||||
Raises:
|
||||
HTTPException: 如果 ID 对应的记录不存在,则抛出状态码为 404 的异常.
|
||||
"""
|
||||
instance = await cls.get(session, cls.id == id, load=load)
|
||||
if not instance:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
return instance
|
||||
|
||||
class UUIDTableBaseMixin(TableBaseMixin):
|
||||
"""
|
||||
一个使用 UUID 作为主键的异步 CRUD 操作基础模型类 Mixin.
|
||||
|
||||
此类继承自 `TableBaseMixin`, 将主键 `id` 的类型覆盖为 `uuid.UUID`,
|
||||
并为新记录自动生成 UUID. 它继承了 `TableBaseMixin` 的所有 CRUD 方法.
|
||||
|
||||
Attributes:
|
||||
id (uuid.UUID): UUID 类型的主键, 在创建时自动生成.
|
||||
"""
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
"""覆盖 `TableBaseMixin` 的 id 字段,使用 UUID 作为主键."""
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: RelationshipInfo | list[RelationshipInfo] | None = None) -> T:
|
||||
"""
|
||||
根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
|
||||
|
||||
此方法覆盖了父类的同名方法,以确保 `id` 参数的类型注解为 `uuid.UUID`,
|
||||
从而提供更好的类型安全和代码提示.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
id (uuid.UUID): 要查找的记录的 UUID 主键.
|
||||
load (Relationship | None): 可选的,用于预加载的关联属性.
|
||||
|
||||
Returns:
|
||||
T: 找到的模型实例.
|
||||
|
||||
Raises:
|
||||
HTTPException: 如果 UUID 对应的记录不存在,则抛出状态码为 404 的异常.
|
||||
"""
|
||||
return await super().get_exist_one(session, id, load) # type: ignore
|
||||
@@ -1,28 +0,0 @@
|
||||
from typing import Literal
|
||||
from sqlmodel import Field, Column, String, DateTime
|
||||
from .base import TableBase, IdMixin
|
||||
from datetime import datetime
|
||||
|
||||
class Object(IdMixin, TableBase, table=True):
|
||||
|
||||
key: str = Field(index=True, nullable=False, description="物品外部ID")
|
||||
type: Literal['normal', 'car'] = Field(
|
||||
default='normal',
|
||||
description="物品类型",
|
||||
sa_column=Column(String, default='normal', nullable=False)
|
||||
)
|
||||
name: str = Field(nullable=False, description="物品名称")
|
||||
icon: str | None = Field(default=None, description="物品图标")
|
||||
status: Literal['ok', 'lost'] = Field(
|
||||
default='ok',
|
||||
description="物品状态",
|
||||
sa_column=Column(String, default='ok', nullable=False)
|
||||
)
|
||||
phone: str | None = Field(default=None, description="联系电话")
|
||||
context: str | None = Field(default=None, description="物品描述")
|
||||
find_ip: str | None = Field(default=None, description="最后一次发现的IP地址")
|
||||
lost_at: datetime | None = Field(
|
||||
default=None,
|
||||
description="物品标记为丢失的时间",
|
||||
sa_column=Column(DateTime, nullable=True)
|
||||
)
|
||||
@@ -1,20 +1,25 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal
|
||||
|
||||
from model.base import SQLModelBase
|
||||
|
||||
"""
|
||||
[TODO] 弃用,改成 ResponseBase:
|
||||
|
||||
class ResponseBase(BaseModel):
|
||||
code: int = 0
|
||||
msg: str = ""
|
||||
request_id: UUID
|
||||
|
||||
再根据需要继承
|
||||
"""
|
||||
class DefaultResponse(BaseModel):
|
||||
code: int = 0
|
||||
data: dict | list | bool | None = None
|
||||
data: str | dict | list | bool | SQLModelBase | None = None
|
||||
msg: str = ""
|
||||
|
||||
class ObjectData(BaseModel):
|
||||
id: int
|
||||
type: Literal['normal', 'car']
|
||||
key: str
|
||||
name: str
|
||||
icon: str
|
||||
status: Literal['ok', 'lost']
|
||||
phone: str
|
||||
context: str | None = None
|
||||
lost_description: str | None = None
|
||||
create_time: str
|
||||
lost_time: str | None = None
|
||||
# FastAPI 鉴权返回模型
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
|
||||
@@ -1,8 +1,19 @@
|
||||
from sqlmodel import Field
|
||||
from .base import TableBase
|
||||
from .base import TableBase, SQLModelBase
|
||||
|
||||
class Setting(TableBase, table=True):
|
||||
|
||||
type: str = Field(index=True, nullable=False, description="设置类型")
|
||||
name: str = Field(primary_key=True, nullable=False, description="设置名称") # name 为唯一主键
|
||||
value: str | None = Field(description="设置值")
|
||||
class SettingBase(SQLModelBase):
|
||||
type: str = Field(index=True)
|
||||
"""设置类型"""
|
||||
|
||||
name: str = Field(index=True, unique=True) # name 为唯一主键
|
||||
"""设置名称"""
|
||||
|
||||
value: str | None
|
||||
"""设置值"""
|
||||
|
||||
class Setting(SettingBase, TableBase, table=True):
|
||||
pass
|
||||
|
||||
class SettingResponse(SettingBase):
|
||||
pass
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
# FastAPI 鉴权模型
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
87
model/user.py
Normal file
87
model/user.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import EmailStr
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.orm.session import Session as SessionClass
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from .base import SQLModelBase, UUIDTableBase
|
||||
from .item import Item
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .firmware import Firmware
|
||||
|
||||
|
||||
class UserTypeEnum(StrEnum):
|
||||
normal_user = 'normal_user'
|
||||
admin = 'admin'
|
||||
super_admin = 'super_admin'
|
||||
|
||||
class UserBase(SQLModelBase):
|
||||
pass
|
||||
|
||||
class User(UserBase, UUIDTableBase, table=True):
|
||||
email: EmailStr = Field(index=True, unique=True)
|
||||
"""邮箱"""
|
||||
|
||||
nickname: str
|
||||
"""昵称"""
|
||||
|
||||
password: str
|
||||
"""Argon2算法哈希后的密码"""
|
||||
|
||||
two_factor_secret: str | None = None
|
||||
"""两步验证的密钥"""
|
||||
|
||||
role: UserTypeEnum = Field(default=UserTypeEnum.normal_user, index=True)
|
||||
"""用户的权限等级"""
|
||||
|
||||
items: list[Item] = Relationship(back_populates='user', cascade_delete=True)
|
||||
"""物品关系"""
|
||||
|
||||
firmwares: list['Firmware'] = Relationship(back_populates='uploaded_by', cascade_delete=True)
|
||||
"""上传的固件关系"""
|
||||
|
||||
_initializing: ClassVar[bool] = False
|
||||
"""标记当前是否处于初始化阶段,初始化阶段允许创建 super_admin"""
|
||||
|
||||
@event.listens_for(SessionClass, "before_flush")
|
||||
def check_super_admin_immutability(session, flush_context, instances):
|
||||
"""
|
||||
在事务刷新到数据库前,集中检查所有关于 super_admin 的不合法操作。
|
||||
此监听器确保超级管理员的角色和存在性是不可变的。
|
||||
"""
|
||||
# 检查1: 禁止创建新的 super_admin
|
||||
for obj in session.new:
|
||||
if isinstance(obj, User) and obj.role == UserTypeEnum.super_admin and not User._initializing:
|
||||
raise ValueError("业务规则:不允许创建新的超级管理员。")
|
||||
|
||||
# 检查2: 禁止删除已存在的 super_admin
|
||||
for obj in session.deleted:
|
||||
if isinstance(obj, User):
|
||||
state = sa.inspect(obj)
|
||||
# 直接从对象被删除前的状态获取角色,避免不必要的 lazy load
|
||||
original_role = state.committed_state.get('role')
|
||||
if original_role == UserTypeEnum.super_admin:
|
||||
username = state.committed_state.get('username', f'(ID: {obj.id})')
|
||||
raise ValueError(f"业务规则:不允许删除超级管理员 '{username}'。")
|
||||
|
||||
# 检查3: 禁止与 super_admin 相关的角色变更
|
||||
for obj in session.dirty:
|
||||
if isinstance(obj, User):
|
||||
state = sa.inspect(obj)
|
||||
# 仅在 'role' 字段确实被修改时才进行检查
|
||||
if "role" in state.committed_state:
|
||||
history = state.attrs.role.history
|
||||
original_role = history.deleted[0]
|
||||
new_role = history.added[0]
|
||||
|
||||
# 场景 a: 禁止将 super_admin 降级
|
||||
if original_role == UserTypeEnum.super_admin:
|
||||
raise ValueError(f"业务规则:不允许将超级管理员 '{obj.username}' 的角色降级。")
|
||||
|
||||
# 场景 b: 禁止将任何用户提升为 super_admin
|
||||
if new_role == UserTypeEnum.super_admin:
|
||||
raise ValueError(f"业务规则:不允许将用户 '{obj.username}' 提升为超级管理员。")
|
||||
0
model/version.py
Normal file
0
model/version.py
Normal file
@@ -1 +1,2 @@
|
||||
from .password import Password
|
||||
|
||||
|
||||
@@ -2,10 +2,24 @@ import secrets
|
||||
from loguru import logger
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import VerifyMismatchError
|
||||
from enum import StrEnum
|
||||
|
||||
_ph = PasswordHasher()
|
||||
|
||||
class Password():
|
||||
class PasswordStatus(StrEnum):
|
||||
"""密码校验状态枚举"""
|
||||
|
||||
VALID = "valid"
|
||||
"""密码校验通过"""
|
||||
|
||||
INVALID = "invalid"
|
||||
"""密码校验失败"""
|
||||
|
||||
EXPIRED = "expired"
|
||||
"""密码哈希已过时,建议重新哈希"""
|
||||
|
||||
class Password:
|
||||
"""密码处理工具类,包含密码生成、哈希和验证功能"""
|
||||
|
||||
@staticmethod
|
||||
def generate(
|
||||
@@ -21,6 +35,7 @@ class Password():
|
||||
"""
|
||||
return secrets.token_hex(length)
|
||||
|
||||
@staticmethod
|
||||
def hash(
|
||||
password: str
|
||||
) -> str:
|
||||
@@ -34,39 +49,31 @@ class Password():
|
||||
"""
|
||||
return _ph.hash(password)
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
stored_password: str,
|
||||
provided_password: str,
|
||||
debug: bool = False
|
||||
) -> bool:
|
||||
hash: str,
|
||||
password: str
|
||||
) -> PasswordStatus:
|
||||
"""
|
||||
验证存储的 Argon2 哈希值与用户提供的密码是否匹配。
|
||||
|
||||
:param stored_password: 数据库中存储的 Argon2 哈希字符串
|
||||
:param provided_password: 用户本次提供的密码
|
||||
:param debug: 是否输出调试信息
|
||||
:param hash: 数据库中存储的 Argon2 哈希字符串
|
||||
:param password: 用户本次提供的密码
|
||||
:return: 如果密码匹配返回 True, 否则返回 False
|
||||
"""
|
||||
if debug:
|
||||
logger.info(f"验证密码: (哈希) {stored_password}")
|
||||
|
||||
try:
|
||||
# verify 函数会自动解析 stored_password 中的盐和参数
|
||||
_ph.verify(stored_password, provided_password)
|
||||
_ph.verify(hash, password)
|
||||
|
||||
# 检查哈希参数是否已过时。如果返回True,
|
||||
# 意味着你应该使用新的参数重新哈希密码并更新存储。
|
||||
# 这是一个很好的实践,可以随着时间推移增强安全性。
|
||||
if _ph.check_needs_rehash(stored_password):
|
||||
if _ph.check_needs_rehash(hash):
|
||||
logger.warning("密码哈希参数已过时,建议重新哈希并更新。")
|
||||
return PasswordStatus.EXPIRED
|
||||
|
||||
return True
|
||||
return PasswordStatus.VALID
|
||||
except VerifyMismatchError:
|
||||
# 这是预期的异常,当密码不匹配时触发。
|
||||
if debug:
|
||||
logger.info("密码不匹配")
|
||||
return False
|
||||
except Exception as e:
|
||||
# 捕获其他可能的错误
|
||||
logger.error(f"密码验证过程中发生未知错误: {e}")
|
||||
return False
|
||||
return PasswordStatus.INVALID
|
||||
# 其他异常(如哈希格式错误)应该传播,让调用方感知系统问题
|
||||
|
||||
2
pkg/sender/__init__.py
Normal file
2
pkg/sender/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .wechat_bot import WeChatBot
|
||||
from .server_chan import ServerChatBot
|
||||
42
pkg/sender/server_chan.py
Normal file
42
pkg/sender/server_chan.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Literal
|
||||
from loguru import logger
|
||||
from model import Setting
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from pkg.utils import raise_internal_error, raise_service_unavailable
|
||||
import aiohttp
|
||||
|
||||
class ServerChatBot:
|
||||
async def get_url(session: AsyncSession):
|
||||
server_chan_key = await Setting.get(session, Setting.name == "server_chan_key")
|
||||
|
||||
if not server_chan_key.value:
|
||||
raise_internal_error("Server酱未配置,请联系管理员")
|
||||
|
||||
url = f"https://sctapi.ftqq.com/{server_chan_key.value}.send"
|
||||
return url
|
||||
|
||||
async def send_text(
|
||||
session: AsyncSession,
|
||||
title: str,
|
||||
description: str,
|
||||
) -> None:
|
||||
"""发送的 Markdown 消息。
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 数据库会话
|
||||
title (str): 需要发送的标题
|
||||
description (str): 需要发送的文本消息
|
||||
"""
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
async with http_session.post(
|
||||
url=await ServerChatBot.get_url(session),
|
||||
data={
|
||||
"title": title,
|
||||
"desp": description
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"Failed to send to Server Chan: {response.status}")
|
||||
raise_internal_error("Server酱服务不可用,请稍后再试")
|
||||
else:
|
||||
logger.info("Server Chan message sent successfully")
|
||||
102
pkg/sender/wechat_bot.py
Normal file
102
pkg/sender/wechat_bot.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from typing import Literal
|
||||
from loguru import logger
|
||||
from model import Setting
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from pkg.utils import raise_internal_error, raise_service_unavailable
|
||||
import aiohttp
|
||||
|
||||
webhook_url = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send"
|
||||
|
||||
class WeChatBot:
|
||||
async def get_key(session: AsyncSession):
|
||||
key = await Setting.get(session, Setting.name == "wechat_bot_key")
|
||||
|
||||
if not key.value:
|
||||
raise_internal_error("企业微信机器人未配置,请联系管理员")
|
||||
return key.value
|
||||
|
||||
async def send_text(
|
||||
session: AsyncSession,
|
||||
text: str,
|
||||
mentioned_all: bool = False,
|
||||
mentioned_list: list[str] = [],
|
||||
mentioned_mobile_list: list[str] = []
|
||||
) -> None:
|
||||
"""发送文本类型的消息。
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 数据库会话
|
||||
text (str): 需要发送的文本消息
|
||||
mentioned_all (bool, optional): 是否提及所有人 Defaults to False.
|
||||
mentioned_list (list[str], optional): 提及的用户列表 Defaults to [].
|
||||
mentioned_mobile_list (list[str], optional): 提及的手机号码列表 Defaults to [].
|
||||
"""
|
||||
key = await WeChatBot.get_key(session)
|
||||
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
async with http_session.post(
|
||||
url=f"{webhook_url}?key={key}",
|
||||
json={
|
||||
"msgtype": "text",
|
||||
"text": {
|
||||
"content": text
|
||||
},
|
||||
"mentioned_list": ["@all"] if mentioned_all else mentioned_list,
|
||||
"mentioned_mobile_list": ["@all"] if mentioned_all else mentioned_mobile_list
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"Failed to send WeChat message: {response.status}")
|
||||
raise_internal_error("企业微信机器人服务不可用,请稍后再试")
|
||||
else:
|
||||
resp_json = await response.json()
|
||||
if resp_json.get("errcode") != 0:
|
||||
logger.error(f"WeChat API error: {resp_json.get('errmsg')}")
|
||||
raise_service_unavailable("发送企业微信消息失败,请稍后再试或联系管理员")
|
||||
else:
|
||||
logger.info("WeChat message sent successfully")
|
||||
|
||||
async def send_markdown(
|
||||
session: AsyncSession,
|
||||
markdown: str,
|
||||
version: Literal['v1', 'v2'],
|
||||
mentioned_all: bool = False,
|
||||
mentioned_list: list[str] = [],
|
||||
mentioned_mobile_list: list[str] = []
|
||||
) -> None:
|
||||
key = await WeChatBot.get_key(session)
|
||||
|
||||
if version == 'v1':
|
||||
payload = {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {
|
||||
"content": markdown,
|
||||
"mentioned_list": ["@all"] if mentioned_all else mentioned_list,
|
||||
"mentioned_mobile_list": ["@all"] if mentioned_all else mentioned_mobile_list
|
||||
}
|
||||
}
|
||||
elif version == 'v2':
|
||||
payload = {
|
||||
"msgtype": "markdown_v2",
|
||||
"markdown_v2": {
|
||||
"content": markdown,
|
||||
"mentioned_list": ["@all"] if mentioned_all else mentioned_list,
|
||||
"mentioned_mobile_list": ["@all"] if mentioned_all else mentioned_mobile_list
|
||||
}
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
async with http_session.post(
|
||||
url=f"{webhook_url}?key={key}",
|
||||
json=payload
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"Failed to send WeChat message: {response.status}")
|
||||
raise_internal_error("企业微信机器人服务不可用,请稍后再试")
|
||||
else:
|
||||
resp_json = await response.json()
|
||||
if resp_json.get("errcode") != 0:
|
||||
logger.error(f"WeChat API error: {resp_json.get('errmsg')}")
|
||||
raise_service_unavailable("发送企业微信消息失败,请稍后再试或联系管理员")
|
||||
else:
|
||||
logger.info("WeChat message sent successfully")
|
||||
73
pkg/utils.py
Normal file
73
pkg/utils.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from starlette.status import (
|
||||
HTTP_400_BAD_REQUEST,
|
||||
HTTP_401_UNAUTHORIZED,
|
||||
HTTP_402_PAYMENT_REQUIRED,
|
||||
HTTP_403_FORBIDDEN,
|
||||
HTTP_404_NOT_FOUND,
|
||||
HTTP_409_CONFLICT,
|
||||
HTTP_429_TOO_MANY_REQUESTS,
|
||||
HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
HTTP_501_NOT_IMPLEMENTED,
|
||||
HTTP_503_SERVICE_UNAVAILABLE,
|
||||
HTTP_504_GATEWAY_TIMEOUT,
|
||||
)
|
||||
|
||||
# --- 400 ---
|
||||
|
||||
def ensure_request_param(to_check: Any, detail: str) -> None:
|
||||
"""
|
||||
Ensures a parameter exists. If not, raises a 400 Bad Request.
|
||||
This function returns None if the check passes.
|
||||
"""
|
||||
if not to_check:
|
||||
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
|
||||
|
||||
def raise_bad_request(detail: str = '') -> NoReturn:
|
||||
"""Raises an HTTP 400 Bad Request exception."""
|
||||
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
|
||||
|
||||
def raise_unauthorized(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 401 Unauthorized exception."""
|
||||
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail)
|
||||
|
||||
def raise_insufficient_quota(detail: str = "积分不足,请充值") -> NoReturn:
|
||||
"""Raises an HTTP 402 Payment Required exception."""
|
||||
raise HTTPException(status_code=HTTP_402_PAYMENT_REQUIRED, detail=detail)
|
||||
|
||||
def raise_forbidden(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 403 Forbidden exception."""
|
||||
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
def raise_not_found(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 404 Not Found exception."""
|
||||
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail)
|
||||
|
||||
def raise_conflict(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 409 Conflict exception."""
|
||||
raise HTTPException(status_code=HTTP_409_CONFLICT, detail=detail)
|
||||
|
||||
def raise_too_many_requests(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 429 Too Many Requests exception."""
|
||||
raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail=detail)
|
||||
|
||||
# --- 500 ---
|
||||
|
||||
def raise_internal_error(detail: str = "服务器出现故障,请稍后再试或联系管理员") -> NoReturn:
|
||||
"""Raises an HTTP 500 Internal Server Error exception."""
|
||||
raise HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
|
||||
|
||||
def raise_not_implemented(detail: str = "尚未支持这种方法") -> NoReturn:
|
||||
"""Raises an HTTP 501 Not Implemented exception."""
|
||||
raise HTTPException(status_code=HTTP_501_NOT_IMPLEMENTED, detail=detail)
|
||||
|
||||
def raise_service_unavailable(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 503 Service Unavailable exception."""
|
||||
raise HTTPException(status_code=HTTP_503_SERVICE_UNAVAILABLE, detail=detail)
|
||||
|
||||
def raise_gateway_timeout(detail: str) -> NoReturn:
|
||||
"""Raises an HTTP 504 Gateway Timeout exception."""
|
||||
raise HTTPException(status_code=HTTP_504_GATEWAY_TIMEOUT, detail=detail)
|
||||
17
pyproject.toml
Normal file
17
pyproject.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[project]
|
||||
name = "findreve"
|
||||
version = "2.0.0"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"aiohttp>=3.13.2",
|
||||
"aiosqlite>=0.22.0",
|
||||
"argon2-cffi>=25.1.0",
|
||||
"fastapi[standard]>=0.124.4",
|
||||
"loguru>=0.7.3",
|
||||
"pydantic-extra-types>=2.11.0",
|
||||
"pyjwt>=2.10.1",
|
||||
"semver>=3.0.4",
|
||||
"slowapi>=0.1.9",
|
||||
"sqlmodel>=0.0.27",
|
||||
]
|
||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
327
routes/admin.py
327
routes/admin.py
@@ -1,45 +1,14 @@
|
||||
from fastapi import APIRouter
|
||||
from typing import Annotated, Literal
|
||||
from fastapi import Depends, Query
|
||||
from fastapi import HTTPException
|
||||
import JWT
|
||||
import jwt
|
||||
from jwt import InvalidTokenError
|
||||
from model import database
|
||||
from model.response import DefaultResponse
|
||||
from model.items import Item
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from model import Setting
|
||||
from model.object import Object
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
# 验证是否为管理员
|
||||
async def is_admin(
|
||||
token: Annotated[str, Depends(JWT.oauth2_scheme)],
|
||||
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
|
||||
) -> Literal[True]:
|
||||
'''
|
||||
验证是否为管理员。
|
||||
|
||||
使用方法:
|
||||
>>> APIRouter(dependencies=[Depends(is_admin)])
|
||||
'''
|
||||
credentials_exception = HTTPException(
|
||||
status_code=401,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, await JWT.get_secret_key(), algorithms=[JWT.ALGORITHM])
|
||||
username = payload.get("sub")
|
||||
stored_account = await Setting.get(session, Setting.name == 'account')
|
||||
if username is None or not stored_account.value == username:
|
||||
raise credentials_exception
|
||||
else:
|
||||
return True
|
||||
except InvalidTokenError:
|
||||
raise credentials_exception
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
from middleware.admin import is_admin
|
||||
from middleware.dependencies import SessionDep
|
||||
from model import User, DefaultResponse
|
||||
from model.firmware import ChipTypeEnum
|
||||
from services import admin as admin_service
|
||||
|
||||
Router = APIRouter(
|
||||
prefix='/api/admin',
|
||||
@@ -64,182 +33,132 @@ async def verity_admin() -> DefaultResponse:
|
||||
return DefaultResponse(data=True)
|
||||
|
||||
@Router.get(
|
||||
path='/items',
|
||||
summary='获取物品信息',
|
||||
description='返回物品信息列表',
|
||||
path='api/admin/settings',
|
||||
summary='获取设置项',
|
||||
description='获取设置项, 留空则获取所有',
|
||||
response_model=DefaultResponse,
|
||||
response_description='物品信息列表'
|
||||
response_description='设置项列表'
|
||||
)
|
||||
async def get_items(
|
||||
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
|
||||
id: int | None = Query(default=None, ge=1, description='物品ID'),
|
||||
key: str | None = Query(default=None, description='物品序列号')):
|
||||
'''
|
||||
获得物品信息。
|
||||
async def get_settings(
|
||||
session: SessionDep,
|
||||
name: str | None = None
|
||||
) -> DefaultResponse:
|
||||
data = await admin_service.fetch_settings(session=session, name=name)
|
||||
return DefaultResponse(data=data)
|
||||
|
||||
不传参数返回所有信息,否则可传入 `id` 或 `key` 进行筛选。
|
||||
'''
|
||||
# 根据条件查询物品
|
||||
if id is not None:
|
||||
results = await Object.get(session, Object.id == id)
|
||||
results = [results] if results else []
|
||||
elif key is not None:
|
||||
results = await Object.get(session, Object.key == key)
|
||||
results = [results] if results else []
|
||||
else:
|
||||
results = await Object.get(session, None, fetch_mode="all")
|
||||
|
||||
if results:
|
||||
items = []
|
||||
for obj in results:
|
||||
items.append(Item(
|
||||
id=obj.id,
|
||||
type=obj.type,
|
||||
key=obj.key,
|
||||
name=obj.name,
|
||||
icon=obj.icon or "",
|
||||
status=obj.status or "",
|
||||
phone=int(obj.phone) if obj.phone and obj.phone.isdigit() else 0,
|
||||
lost_description=obj.context,
|
||||
find_ip=obj.find_ip,
|
||||
create_time=obj.created_at.isoformat(),
|
||||
lost_time=obj.lost_at.isoformat() if obj.lost_at else None
|
||||
))
|
||||
return DefaultResponse(data=items)
|
||||
else:
|
||||
return DefaultResponse(data=[])
|
||||
@Router.put(
|
||||
path='api/admin/settings',
|
||||
summary='更新设置项',
|
||||
description='更新设置项',
|
||||
response_model=DefaultResponse,
|
||||
response_description='更新结果'
|
||||
)
|
||||
async def update_settings(
|
||||
session: SessionDep,
|
||||
name: str,
|
||||
value: str
|
||||
) -> DefaultResponse:
|
||||
result = await admin_service.update_setting_value(session=session, name=name, value=value)
|
||||
return DefaultResponse(data=result)
|
||||
|
||||
|
||||
# 固件管理接口
|
||||
|
||||
@Router.post(
|
||||
path='/items',
|
||||
summary='添加物品信息',
|
||||
description='添加新的物品信息',
|
||||
path='/firmware',
|
||||
summary='上传固件包',
|
||||
description='管理员上传新的固件更新包',
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
response_description='上传成功'
|
||||
)
|
||||
async def upload_firmware(
|
||||
session: SessionDep,
|
||||
admin: Annotated[User, Depends(is_admin)],
|
||||
chip_type: ChipTypeEnum = Form(..., description='芯片类型'),
|
||||
version: str = Form(..., description='版本号'),
|
||||
description: str | None = Form(None, description='更新说明'),
|
||||
file: UploadFile = File(..., description='固件文件'),
|
||||
):
|
||||
"""
|
||||
上传固件包。
|
||||
|
||||
支持的文件格式:.bin
|
||||
文件大小限制:4MB
|
||||
"""
|
||||
await admin_service.upload_firmware(
|
||||
session=session,
|
||||
admin=admin,
|
||||
chip_type=chip_type,
|
||||
version=version,
|
||||
description=description,
|
||||
file=file,
|
||||
)
|
||||
|
||||
|
||||
@Router.get(
|
||||
path='/firmwares',
|
||||
summary='获取固件列表',
|
||||
description='获取已上传的固件列表',
|
||||
response_model=DefaultResponse,
|
||||
response_description='添加物品成功'
|
||||
response_description='固件列表'
|
||||
)
|
||||
async def add_items(
|
||||
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
|
||||
key: str,
|
||||
type: Literal['normal', 'car'],
|
||||
name: str,
|
||||
icon: str,
|
||||
phone: str
|
||||
async def list_firmwares(
|
||||
session: SessionDep,
|
||||
admin: Annotated[User, Depends(is_admin)],
|
||||
chip_type: ChipTypeEnum | None = Query(None, description='筛选芯片类型'),
|
||||
is_active: bool | None = Query(None, description='筛选启用状态'),
|
||||
) -> DefaultResponse:
|
||||
'''
|
||||
添加物品信息。
|
||||
|
||||
- **key**: 物品的关键字
|
||||
- **type**: 物品的类型
|
||||
- **name**: 物品的名称
|
||||
- **icon**: 物品的图标
|
||||
- **phone**: 联系电话
|
||||
'''
|
||||
|
||||
try:
|
||||
# 创建新物品对象
|
||||
new_object = Object(
|
||||
key=key,
|
||||
type=type,
|
||||
name=name,
|
||||
icon=icon,
|
||||
phone=phone
|
||||
"""
|
||||
获取固件列表。
|
||||
"""
|
||||
result = await admin_service.list_firmwares(
|
||||
session=session,
|
||||
chip_type=chip_type,
|
||||
is_active=is_active,
|
||||
)
|
||||
# 使用 base.py 中的 add 方法
|
||||
await Object.add(session, new_object)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
else:
|
||||
return DefaultResponse(data=True)
|
||||
return DefaultResponse(data=result)
|
||||
|
||||
@Router.patch(
|
||||
path='/items',
|
||||
summary='更新物品信息',
|
||||
description='更新现有物品的信息',
|
||||
response_model=DefaultResponse,
|
||||
response_description='更新物品成功'
|
||||
)
|
||||
async def update_items(
|
||||
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
|
||||
id: int = Query(ge=1),
|
||||
key: str | None = None,
|
||||
name: str | None = None,
|
||||
icon: str | None = None,
|
||||
status: str | None = None,
|
||||
phone: int | None = None,
|
||||
lost_description: str | None = None,
|
||||
find_ip: str | None = None,
|
||||
lost_time: str | None = None
|
||||
) -> DefaultResponse:
|
||||
'''
|
||||
更新物品信息。
|
||||
|
||||
只有 `id` 是必填参数,其余参数都是可选的,在不传入任何值的时候将不做任何更改。
|
||||
|
||||
- **id**: 物品的ID
|
||||
- **key**: 物品的序列号 **不建议修改此项,这样会导致生成的物品二维码直接失效**
|
||||
- **name**: 物品的名称
|
||||
- **icon**: 物品的图标
|
||||
- **status**: 物品的状态
|
||||
- **phone**: 联系电话
|
||||
- **lost_description**: 物品丢失描述
|
||||
- **find_ip**: 找到物品的IP
|
||||
- **lost_time**: 物品丢失时间
|
||||
|
||||
'''
|
||||
try:
|
||||
# 获取现有物品
|
||||
obj = await Object.get_exist_one(session, id)
|
||||
|
||||
# 更新字段
|
||||
if key is not None:
|
||||
obj.key = key
|
||||
if name is not None:
|
||||
obj.name = name
|
||||
if icon is not None:
|
||||
obj.icon = icon
|
||||
if status is not None:
|
||||
obj.status = status
|
||||
if phone is not None:
|
||||
obj.phone = str(phone)
|
||||
if lost_description is not None:
|
||||
obj.context = lost_description
|
||||
if find_ip is not None:
|
||||
obj.find_ip = find_ip
|
||||
if lost_time is not None:
|
||||
from datetime import datetime
|
||||
obj.lost_at = datetime.fromisoformat(lost_time)
|
||||
|
||||
# 保存更新
|
||||
await obj.save(session)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
else:
|
||||
return DefaultResponse(data=True)
|
||||
|
||||
@Router.delete(
|
||||
path='/items',
|
||||
summary='删除物品信息',
|
||||
description='删除指定的物品信息',
|
||||
response_model=DefaultResponse,
|
||||
response_description='删除物品成功'
|
||||
path='/firmware/{firmware_id}',
|
||||
summary='删除固件',
|
||||
description='删除指定的固件包',
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
response_description='删除成功'
|
||||
)
|
||||
async def delete_firmware(
|
||||
session: SessionDep,
|
||||
admin: Annotated[User, Depends(is_admin)],
|
||||
firmware_id: UUID,
|
||||
):
|
||||
"""
|
||||
删除固件包。
|
||||
"""
|
||||
await admin_service.delete_firmware(
|
||||
session=session,
|
||||
firmware_id=firmware_id,
|
||||
)
|
||||
async def delete_items(
|
||||
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
|
||||
id: int) -> DefaultResponse:
|
||||
'''
|
||||
删除物品信息。
|
||||
|
||||
- **id**: 物品的ID
|
||||
'''
|
||||
try:
|
||||
# 获取现有物品
|
||||
obj = await Object.get_exist_one(session, id)
|
||||
# 使用 base.py 中的 delete 方法
|
||||
await Object.delete(session, obj)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
else:
|
||||
return DefaultResponse(data=True)
|
||||
|
||||
@Router.patch(
|
||||
path='/firmware/{firmware_id}/status',
|
||||
summary='切换固件状态',
|
||||
description='启用或禁用固件',
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
response_description='操作成功'
|
||||
)
|
||||
async def toggle_firmware_status(
|
||||
session: SessionDep,
|
||||
admin: Annotated[User, Depends(is_admin)],
|
||||
firmware_id: UUID,
|
||||
is_active: bool = Query(..., description='目标状态'),
|
||||
):
|
||||
"""
|
||||
切换固件启用状态。
|
||||
"""
|
||||
await admin_service.toggle_firmware_status(
|
||||
session=session,
|
||||
firmware_id=firmware_id,
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
259
routes/object.py
259
routes/object.py
@@ -1,141 +1,170 @@
|
||||
import random
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from model.database import Database
|
||||
from model.response import DefaultResponse, ObjectData
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from dependencies import SessionDep
|
||||
from middleware.user import get_current_user
|
||||
from model import DefaultResponse, User
|
||||
from model.item import ItemDataUpdateRequest
|
||||
from services import object as object_service
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
Router = APIRouter(prefix='/api/object', tags=['物品 Object'])
|
||||
|
||||
@Router.get(
|
||||
path='/{item_key}',
|
||||
path='/items',
|
||||
summary='获取物品信息',
|
||||
description='返回物品信息列表',
|
||||
response_model=DefaultResponse,
|
||||
response_description='物品信息列表'
|
||||
)
|
||||
async def get_items(
|
||||
session: SessionDep,
|
||||
token: Annotated[User, Depends(get_current_user)],
|
||||
id: int | None = Query(default=None, ge=1, description='物品ID'),
|
||||
key: str | None = Query(default=None, description='物品序列号')):
|
||||
"""
|
||||
获得物品信息。
|
||||
|
||||
不传参数返回所有信息,否则可传入 `id` 或 `key` 进行筛选。
|
||||
"""
|
||||
items = await object_service.list_items(
|
||||
session=session,
|
||||
user=token,
|
||||
item_id=id,
|
||||
key=key,
|
||||
)
|
||||
return DefaultResponse(data=items)
|
||||
|
||||
@Router.post(
|
||||
path='/items',
|
||||
summary='添加物品信息',
|
||||
description='添加新的物品信息',
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
response_description='添加物品成功'
|
||||
)
|
||||
async def add_items(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
request: ItemDataUpdateRequest
|
||||
):
|
||||
"""
|
||||
添加物品信息。
|
||||
"""
|
||||
await object_service.create_item(
|
||||
session=session,
|
||||
user=user,
|
||||
request=request,
|
||||
)
|
||||
|
||||
@Router.patch(
|
||||
path='/items/{item_id}',
|
||||
summary='更新物品信息',
|
||||
description='更新现有物品的信息',
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
response_description='更新物品成功'
|
||||
)
|
||||
async def update_items(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
item_id: UUID,
|
||||
request: ItemDataUpdateRequest,
|
||||
):
|
||||
"""
|
||||
更新物品信息。
|
||||
|
||||
只有 `id` 是必填参数,其余参数都是可选的,在不传入任何值的时候将不做任何更改。
|
||||
|
||||
- **id**: 物品的ID
|
||||
- **key**: 物品的序列号
|
||||
- **name**: 物品的名称
|
||||
- **icon**: 物品的图标
|
||||
- **status**: 物品的状态
|
||||
- **phone**: 联系电话
|
||||
- **lost_description**: 物品丢失描述
|
||||
- **find_ip**: 找到物品的IP
|
||||
- **lost_time**: 物品丢失时间
|
||||
"""
|
||||
|
||||
await object_service.update_item(
|
||||
session=session,
|
||||
user=user,
|
||||
item_id=item_id,
|
||||
request=request,
|
||||
)
|
||||
|
||||
@Router.delete(
|
||||
path='/items/{item_id}',
|
||||
summary='删除物品信息',
|
||||
description='删除指定的物品信息',
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
response_description='删除物品成功'
|
||||
)
|
||||
async def delete_items(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
item_id: UUID
|
||||
):
|
||||
"""
|
||||
删除物品信息。
|
||||
- **id**: 物品的ID
|
||||
"""
|
||||
await object_service.delete_item(
|
||||
session=session,
|
||||
user=user,
|
||||
item_id=item_id,
|
||||
)
|
||||
|
||||
@Router.get(
|
||||
path='/{item_id}',
|
||||
summary="获取物品信息",
|
||||
description="根据物品键获取物品信息",
|
||||
response_model=DefaultResponse,
|
||||
response_description="物品信息"
|
||||
)
|
||||
async def get_object(item_key: str, request: Request):
|
||||
async def get_object(
|
||||
session: SessionDep,
|
||||
item_id: UUID,
|
||||
request: Request
|
||||
) -> DefaultResponse:
|
||||
"""
|
||||
获取物品信息 / Get object information
|
||||
"""
|
||||
|
||||
db = Database()
|
||||
await db.init_db()
|
||||
object_data = await db.get_object(key=item_key)
|
||||
|
||||
if object_data:
|
||||
if object_data[5] == 'lost':
|
||||
# 物品已标记为丢失,更新IP地址
|
||||
await db.update_object(id=object_data[0], find_ip=str(request.client.host))
|
||||
|
||||
# 添加一些随机延迟,类似JWT身份验证时根据延迟爆破引发的问题
|
||||
await asyncio.sleep(random.uniform(0.10, 0.30))
|
||||
|
||||
print(object_data)
|
||||
return DefaultResponse(data=ObjectData(
|
||||
id=object_data[0],
|
||||
type=object_data[1],
|
||||
key=object_data[2],
|
||||
name=object_data[3],
|
||||
icon=object_data[4],
|
||||
status=object_data[5],
|
||||
phone=object_data[6],
|
||||
lost_description=object_data[7],
|
||||
create_time=object_data[9],
|
||||
lost_time=object_data[10]
|
||||
).model_dump())
|
||||
else: return JSONResponse(
|
||||
status_code=404,
|
||||
content=DefaultResponse(
|
||||
code=404,
|
||||
msg='物品不存在或出现异常'
|
||||
).model_dump()
|
||||
data = await object_service.retrieve_object(
|
||||
session=session,
|
||||
item_id=item_id,
|
||||
client_host=str(request.client.host),
|
||||
)
|
||||
return DefaultResponse(data=data.model_dump())
|
||||
|
||||
@Router.put(
|
||||
path='/{item_id}',
|
||||
@Router.post(
|
||||
path='/{item_id}/notify_move_car',
|
||||
summary="通知车主进行挪车",
|
||||
description="向车主发送挪车通知",
|
||||
response_model=DefaultResponse,
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
response_description="挪车通知结果"
|
||||
)
|
||||
async def notify_move_car(
|
||||
item_id: int,
|
||||
phone: str = None,
|
||||
session: SessionDep,
|
||||
item_id: UUID,
|
||||
phone: str | None = None,
|
||||
):
|
||||
"""通知车主进行挪车 / Notify car owner to move the car
|
||||
"""
|
||||
通知车主进行挪车 / Notify car owner to move the car
|
||||
|
||||
Args:
|
||||
_request (Request): ...
|
||||
session (AsyncSession): 数据库会话 / Database session
|
||||
item_id (int): 物品ID / Item ID
|
||||
phone (str): 挪车发起者电话 / Phone number of the person initiating the move. Defaults to None.
|
||||
"""
|
||||
db = Database()
|
||||
await db.init_db()
|
||||
|
||||
# 检查是否存在该物品
|
||||
object_data = await db.get_object(id=item_id)
|
||||
if not object_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=DefaultResponse(
|
||||
code=404,
|
||||
msg='物品不存在或出现异常'
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
# 检查物品类型是否为车辆
|
||||
if object_data[1] != 'car':
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=DefaultResponse(
|
||||
code=400,
|
||||
msg='该物品不是车辆,无法发送挪车通知'
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
# 发起挪车通知(目前仅适配Server酱)
|
||||
server_chan_key = await db.get_setting('server_chan_key')
|
||||
if not server_chan_key:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=DefaultResponse(
|
||||
code=500,
|
||||
msg='未配置Server酱,无法发送挪车通知'
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
title = "挪车通知 - Findreve"
|
||||
description = f"您的车辆“{object_data[3]}”被请求挪车。\n\n"
|
||||
if phone:
|
||||
description += f"请求挪车者电话:[{phone}](tel:{phone})\n\n"
|
||||
description += "请尽快联系请求者并挪车。"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url=f"https://sctapi.ftqq.com/{server_chan_key}.send",
|
||||
data={
|
||||
"title": title,
|
||||
"desp": description
|
||||
}
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
resp_json = await resp.json()
|
||||
if resp_json.get('code') == 0:
|
||||
return DefaultResponse(msg='挪车通知发送成功')
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=DefaultResponse(
|
||||
code=500,
|
||||
msg=f"挪车通知发送失败,Server酱返回错误:{resp_json.get('message')}"
|
||||
).model_dump()
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=DefaultResponse(
|
||||
code=500,
|
||||
msg=f"挪车通知发送失败,HTTP状态码:{resp.status}"
|
||||
).model_dump()
|
||||
await object_service.notify_move_car(
|
||||
session=session,
|
||||
item_id=item_id,
|
||||
phone=phone,
|
||||
)
|
||||
98
routes/ota.py
Normal file
98
routes/ota.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""OTA API 路由,处理 ESP32/8266 设备的在线升级请求。"""
|
||||
|
||||
from fastapi import APIRouter, Query, status
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
from middleware.dependencies import SessionDep, DeviceDep
|
||||
from model import DefaultResponse
|
||||
from model.firmware import FirmwareCheckUpdateRequest, FirmwareCheckUpdateResponse
|
||||
from services import ota as ota_service
|
||||
|
||||
Router = APIRouter(prefix='/api/ota', tags=['OTA升级'])
|
||||
|
||||
|
||||
@Router.post(
|
||||
path='/check-update',
|
||||
summary='检查固件更新',
|
||||
description='设备通过 mTLS 认证后查询是否有新版本固件',
|
||||
response_model=DefaultResponse,
|
||||
response_description='更新检查结果'
|
||||
)
|
||||
async def check_update(
|
||||
session: SessionDep,
|
||||
device: DeviceDep,
|
||||
request_data: FirmwareCheckUpdateRequest,
|
||||
) -> DefaultResponse:
|
||||
"""
|
||||
检查固件更新。
|
||||
|
||||
设备需要提供有效的 mTLS 客户端证书,证书 CN 字段为设备序列号。
|
||||
"""
|
||||
result = await ota_service.check_firmware_update(
|
||||
session=session,
|
||||
device=device,
|
||||
chip_type=request_data.chip_type,
|
||||
current_version=request_data.current_version,
|
||||
)
|
||||
return DefaultResponse(data=result)
|
||||
|
||||
|
||||
@Router.get(
|
||||
path='/download/{firmware_id}',
|
||||
summary='下载固件包',
|
||||
description='下载指定的固件更新包',
|
||||
)
|
||||
async def download_firmware(
|
||||
session: SessionDep,
|
||||
device: DeviceDep,
|
||||
firmware_id: str,
|
||||
):
|
||||
"""
|
||||
下载固件包。
|
||||
|
||||
需要有效的设备证书,且下载会记录统计信息。
|
||||
"""
|
||||
return await ota_service.get_firmware_file(
|
||||
session=session,
|
||||
firmware_id=firmware_id,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
@Router.post(
|
||||
path='/report-version',
|
||||
summary='上报设备版本',
|
||||
description='设备上报当前运行的固件版本',
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
response_description='上报成功'
|
||||
)
|
||||
async def report_version(
|
||||
session: SessionDep,
|
||||
device: DeviceDep,
|
||||
version: str = Query(..., description='当前版本号'),
|
||||
):
|
||||
"""
|
||||
上报设备当前运行的固件版本。
|
||||
"""
|
||||
await ota_service.update_device_version(
|
||||
session=session,
|
||||
device=device,
|
||||
version=version,
|
||||
)
|
||||
|
||||
|
||||
@Router.post(
|
||||
path='/report-lost',
|
||||
summary='上报设备丢失',
|
||||
description='设备上报丢失状态',
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
response_description='上报成功'
|
||||
)
|
||||
async def report_lost(
|
||||
session: SessionDep,
|
||||
device: DeviceDep,
|
||||
):
|
||||
"""
|
||||
设备上报丢失状态(复用现有丢失处理逻辑)。
|
||||
"""
|
||||
await ota_service.report_device_lost(session=session, device=device)
|
||||
@@ -1,71 +1,34 @@
|
||||
# 导入库
|
||||
from typing import Annotated
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi import APIRouter
|
||||
import jwt, JWT
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from pkg import Password
|
||||
from loguru import logger
|
||||
|
||||
from model.token import Token
|
||||
from model import Setting, database
|
||||
from model import database
|
||||
from model.response import TokenResponse
|
||||
from services import session as session_service
|
||||
from pkg import utils
|
||||
|
||||
Router = APIRouter(tags=["令牌 session"])
|
||||
|
||||
# 创建令牌
|
||||
async def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, key=await JWT.get_secret_key(), algorithm='HS256')
|
||||
return encoded_jwt
|
||||
|
||||
# 验证账号密码
|
||||
async def authenticate_user(session: AsyncSession, username: str, password: str):
|
||||
# 验证账号和密码
|
||||
account = await Setting.get(session, Setting.name == 'account')
|
||||
stored_password = await Setting.get(session, Setting.name == 'password')
|
||||
|
||||
if not account or not stored_password:
|
||||
logger.error("Account or password not set in settings.")
|
||||
return False
|
||||
|
||||
if account.value != username or not Password.verify(stored_password.value, password):
|
||||
logger.error("Invalid username or password.")
|
||||
return False
|
||||
|
||||
return {'is_authenticated': True}
|
||||
|
||||
# FastAPI 登录路由 / FastAPI login route
|
||||
@Router.post(
|
||||
path="/api/token",
|
||||
summary="获取访问令牌",
|
||||
description="使用用户名和密码获取访问令牌",
|
||||
response_model=Token,
|
||||
response_model=TokenResponse,
|
||||
response_description="访问令牌"
|
||||
)
|
||||
async def login_for_access_token(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
|
||||
) -> Token:
|
||||
user = await authenticate_user(
|
||||
) -> TokenResponse:
|
||||
token_response = await session_service.login_for_access_token(
|
||||
session=session,
|
||||
username=form_data.username,
|
||||
password=form_data.password
|
||||
password=form_data.password,
|
||||
)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token_expires = timedelta(hours=1)
|
||||
access_token = await create_access_token(
|
||||
data={"sub": form_data.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
if not token_response:
|
||||
utils.raise_unauthorized("Incorrect username or password")
|
||||
|
||||
return token_response
|
||||
|
||||
19
routes/site.py
Normal file
19
routes/site.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from fastapi import APIRouter
|
||||
from model.response import DefaultResponse
|
||||
|
||||
Router = APIRouter(prefix='/api/site', tags=['站点 Site'])
|
||||
|
||||
@Router.get(
|
||||
path='/ping',
|
||||
summary='站点健康检查',
|
||||
description='检查站点是否在线',
|
||||
response_model=DefaultResponse,
|
||||
response_description='站点在线'
|
||||
)
|
||||
async def ping():
|
||||
"""
|
||||
站点健康检查接口。
|
||||
|
||||
:return: Findreve 版本号
|
||||
"""
|
||||
return DefaultResponse(data='pong')
|
||||
13
services/__init__.py
Normal file
13
services/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
服务层模块聚合。
|
||||
"""
|
||||
|
||||
from . import admin, object, session, site # noqa: F401
|
||||
|
||||
|
||||
__all__ = [
|
||||
"admin",
|
||||
"object",
|
||||
"session",
|
||||
"site",
|
||||
]
|
||||
233
services/admin.py
Normal file
233
services/admin.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
管理员相关业务逻辑。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import UploadFile
|
||||
from loguru import logger
|
||||
from pydantic_extra_types.semantic_version import SemanticVersion
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from model import Firmware, User, Setting, SettingResponse
|
||||
from model.firmware import ChipTypeEnum, FirmwareDataResponseAdmin
|
||||
from pkg import utils
|
||||
|
||||
# 固件存储目录
|
||||
FIRMWARE_STORAGE_PATH = Path("data/firmware")
|
||||
FIRMWARE_STORAGE_PATH.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 文件大小限制 4MB
|
||||
MAX_FIRMWARE_SIZE = 4 * 1024 * 1024
|
||||
|
||||
|
||||
async def fetch_settings(
|
||||
session: SessionDep,
|
||||
name: str | None = None,
|
||||
) -> List[SettingResponse]:
|
||||
"""
|
||||
按名称获取设置项,默认返回全部。
|
||||
"""
|
||||
data: list[SettingResponse] = []
|
||||
|
||||
if name:
|
||||
setting = await Setting.get(session, Setting.name == name)
|
||||
if setting:
|
||||
data.append(SettingResponse.model_validate(setting))
|
||||
else:
|
||||
utils.raise_not_found("Setting not found")
|
||||
else:
|
||||
settings: Iterable[Setting] | None = await Setting.get(session, fetch_mode="all")
|
||||
if settings:
|
||||
data = [SettingResponse.model_validate(s) for s in settings]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def update_setting_value(
|
||||
session: SessionDep,
|
||||
name: str,
|
||||
value: str,
|
||||
) -> bool:
|
||||
"""
|
||||
更新设置项的值。
|
||||
"""
|
||||
setting = await Setting.get(session, Setting.name == name)
|
||||
if not setting:
|
||||
utils.raise_not_found("Setting not found")
|
||||
|
||||
setting.value = value
|
||||
await Setting.save(session)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _calculate_md5(file_path: Path) -> str:
|
||||
"""计算文件的 MD5 值"""
|
||||
hash_md5 = hashlib.md5()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
hash_md5.update(chunk)
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
|
||||
async def upload_firmware(
|
||||
session: SessionDep,
|
||||
admin: User,
|
||||
chip_type: ChipTypeEnum,
|
||||
version: str,
|
||||
description: str | None,
|
||||
file: UploadFile,
|
||||
) -> None:
|
||||
"""
|
||||
上传固件包。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
admin: 管理员用户
|
||||
chip_type: 芯片类型
|
||||
version: 版本号
|
||||
description: 更新说明
|
||||
file: 上传的文件
|
||||
"""
|
||||
# 验证版本号格式
|
||||
try:
|
||||
version_obj = SemanticVersion(version)
|
||||
except ValueError:
|
||||
utils.raise_bad_request("Invalid semantic version format")
|
||||
|
||||
# 验证文件扩展名
|
||||
if not file.filename or not file.filename.endswith('.bin'):
|
||||
utils.raise_bad_request("Only .bin files are supported")
|
||||
|
||||
# 检查是否已存在相同芯片类型和版本的固件
|
||||
from sqlalchemy import and_
|
||||
existing = await Firmware.get(
|
||||
session,
|
||||
and_(
|
||||
Firmware.chip_type == chip_type,
|
||||
Firmware.version == str(version_obj)
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
utils.raise_conflict(f"Firmware {chip_type} v{version} already exists")
|
||||
|
||||
# 读取文件内容
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
# 验证文件大小
|
||||
if file_size > MAX_FIRMWARE_SIZE:
|
||||
utils.raise_bad_request(f"File size exceeds {MAX_FIRMWARE_SIZE} bytes")
|
||||
|
||||
if file_size == 0:
|
||||
utils.raise_bad_request("Empty file")
|
||||
|
||||
# 生成文件名
|
||||
safe_filename = f"{chip_type}_{version}_{file.filename}"
|
||||
file_path = FIRMWARE_STORAGE_PATH / safe_filename
|
||||
|
||||
# 写入文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算 MD5
|
||||
file_md5 = _calculate_md5(file_path)
|
||||
|
||||
# 创建数据库记录
|
||||
firmware = Firmware(
|
||||
chip_type=chip_type,
|
||||
version=str(version_obj),
|
||||
file_path=str(file_path),
|
||||
file_size=file_size,
|
||||
file_md5=file_md5,
|
||||
description=description,
|
||||
uploaded_by_id=admin.id,
|
||||
)
|
||||
|
||||
await Firmware.add(session, firmware)
|
||||
logger.info(f"Admin {admin.email} uploaded firmware {chip_type} v{version}")
|
||||
|
||||
|
||||
async def list_firmwares(
|
||||
session: SessionDep,
|
||||
chip_type: ChipTypeEnum | None,
|
||||
is_active: bool | None,
|
||||
) -> List[FirmwareDataResponseAdmin]:
|
||||
"""
|
||||
获取固件列表。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
chip_type: 筛选芯片类型
|
||||
is_active: 筛选启用状态
|
||||
|
||||
Returns:
|
||||
固件列表
|
||||
"""
|
||||
from sqlalchemy import and_
|
||||
|
||||
conditions = []
|
||||
|
||||
if chip_type:
|
||||
conditions.append(Firmware.chip_type == chip_type)
|
||||
if is_active is not None:
|
||||
conditions.append(Firmware.is_active == is_active)
|
||||
|
||||
if conditions:
|
||||
results = await Firmware.get(session, and_(*conditions), fetch_mode="all")
|
||||
else:
|
||||
results = await Firmware.get(session, None, fetch_mode="all")
|
||||
|
||||
if not results:
|
||||
return []
|
||||
|
||||
return [FirmwareDataResponseAdmin.model_validate(fw) for fw in results]
|
||||
|
||||
|
||||
async def delete_firmware(
|
||||
session: SessionDep,
|
||||
firmware_id: UUID,
|
||||
) -> None:
|
||||
"""
|
||||
删除固件包。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
firmware_id: 固件ID
|
||||
"""
|
||||
firmware = await Firmware.get(session, Firmware.id == firmware_id)
|
||||
if not firmware:
|
||||
utils.raise_not_found("Firmware not found")
|
||||
|
||||
# 删除文件
|
||||
file_path = Path(firmware.file_path)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
# 删除数据库记录
|
||||
await Firmware.delete(session, firmware)
|
||||
|
||||
|
||||
async def toggle_firmware_status(
|
||||
session: SessionDep,
|
||||
firmware_id: UUID,
|
||||
is_active: bool,
|
||||
) -> None:
|
||||
"""
|
||||
切换固件启用状态。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
firmware_id: 固件ID
|
||||
is_active: 目标状态
|
||||
"""
|
||||
firmware = await Firmware.get(session, Firmware.id == firmware_id)
|
||||
if not firmware:
|
||||
utils.raise_not_found("Firmware not found")
|
||||
|
||||
firmware.is_active = is_active
|
||||
await firmware.save(session)
|
||||
163
services/object.py
Normal file
163
services/object.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
物品相关业务逻辑。
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import status
|
||||
from loguru import logger
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from model import Item, ItemDataResponse, Setting, User
|
||||
from model.item import ItemDataUpdateRequest, ItemTypeEnum
|
||||
from pkg.sender import ServerChatBot, WeChatBot
|
||||
from pkg import utils
|
||||
|
||||
|
||||
async def list_items(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
item_id: int | None = None,
|
||||
key: str | None = None,
|
||||
) -> List[Item]:
|
||||
"""
|
||||
根据条件获取当前用户的物品列表。
|
||||
"""
|
||||
if item_id is not None:
|
||||
results = await Item.get(session, (Item.id == item_id) & (Item.user_id == user.id))
|
||||
results = [results] if results else []
|
||||
elif key is not None:
|
||||
results = await Item.get(session, (Item.key == key) & (Item.user_id == user.id))
|
||||
results = [results] if results else []
|
||||
else:
|
||||
results = await Item.get(session, Item.user_id == user.id, fetch_mode="all")
|
||||
|
||||
if not results:
|
||||
return []
|
||||
|
||||
items: list[Item] = []
|
||||
for obj in results:
|
||||
items.append(
|
||||
Item(
|
||||
id=obj.id,
|
||||
type=obj.type,
|
||||
key=obj.id,
|
||||
name=obj.name,
|
||||
icon=obj.icon or "",
|
||||
status=obj.status or "",
|
||||
phone=obj.phone if obj.phone and obj.phone.isdigit() else None,
|
||||
lost_description=obj.description,
|
||||
find_ip=obj.find_ip,
|
||||
create_time=obj.created_at.isoformat(),
|
||||
lost_time=obj.lost_at.isoformat() if obj.lost_at else None,
|
||||
)
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
async def create_item(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
request: ItemDataUpdateRequest,
|
||||
) -> None:
|
||||
"""
|
||||
创建新的物品信息。
|
||||
"""
|
||||
try:
|
||||
request_dict = request.model_dump()
|
||||
request_dict["user"] = user
|
||||
request_dict["user_id"] = user.id
|
||||
await Item.add(session, Item.model_validate(request_dict))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error(f"Failed to add item: {exc}")
|
||||
utils.raise_internal_error(str(exc))
|
||||
|
||||
|
||||
async def update_item(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
item_id: UUID,
|
||||
request: ItemDataUpdateRequest,
|
||||
) -> None:
|
||||
"""
|
||||
更新物品信息。
|
||||
"""
|
||||
obj = await Item.get(session, (Item.id == item_id) & (Item.user_id == user.id))
|
||||
if not obj:
|
||||
utils.raise_not_found("Item not found or access denied")
|
||||
|
||||
await obj.update(session, request, exclude_unset=True)
|
||||
|
||||
|
||||
async def delete_item(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
item_id: UUID,
|
||||
) -> None:
|
||||
"""
|
||||
删除指定物品。
|
||||
"""
|
||||
obj = await Item.get(session, (Item.id == item_id) & (Item.user_id == user.id))
|
||||
if not obj:
|
||||
utils.raise_not_found("Item not found or access denied")
|
||||
await Item.delete(session, obj)
|
||||
|
||||
|
||||
async def retrieve_object(
|
||||
session: AsyncSession,
|
||||
item_id: UUID,
|
||||
client_host: str,
|
||||
) -> ItemDataResponse:
|
||||
"""
|
||||
根据物品 ID 获取物品信息并视情况更新寻找者 IP。
|
||||
"""
|
||||
object_data = await Item.get(session, Item.id == item_id)
|
||||
|
||||
if not object_data:
|
||||
utils.raise_not_found("物品不存在或出现异常")
|
||||
|
||||
if object_data.status == "lost":
|
||||
object_data.find_ip = client_host
|
||||
object_data = await object_data.save(session)
|
||||
|
||||
return ItemDataResponse.model_validate(object_data)
|
||||
|
||||
|
||||
async def notify_move_car(
|
||||
session: AsyncSession,
|
||||
item_id: UUID,
|
||||
phone: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
向车主发送挪车通知。
|
||||
"""
|
||||
item_data = await Item.get_exist_one(session=session, id=item_id)
|
||||
|
||||
if item_data.type != ItemTypeEnum.car:
|
||||
utils.raise_bad_request("Item is not car")
|
||||
|
||||
server_chan_key = await Setting.get(session, Setting.name == "server_chan_key")
|
||||
wechat_bot_key = await Setting.get(session, Setting.name == "wechat_bot_key")
|
||||
if not (server_chan_key.value or wechat_bot_key.value):
|
||||
utils.raise_internal_error("未配置Server酱,无法发送挪车通知")
|
||||
|
||||
title = "挪车通知 - Findreve"
|
||||
description = (
|
||||
f"您的车辆“{item_data.name}”被请求挪车。\n"
|
||||
f"{f'请求挪车者电话:[{phone}](tel:{phone})' if phone else ''}\n"
|
||||
"请尽快联系请求者并挪车。"
|
||||
)
|
||||
|
||||
mentioned_channel = (await Setting.get(session, Setting.name == "mentioned_channel")).value
|
||||
|
||||
if mentioned_channel == "server_chan":
|
||||
await ServerChatBot.send_text(session=session, title=title, description=description)
|
||||
elif mentioned_channel == "wechat_bot":
|
||||
await WeChatBot.send_markdown(
|
||||
session=session,
|
||||
markdown=f"# {title}\n\n{description}",
|
||||
version="v1",
|
||||
)
|
||||
|
||||
return status.HTTP_204_NO_CONTENT
|
||||
168
services/ota.py
Normal file
168
services/ota.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""OTA 服务层,处理 ESP32/8266 设备的在线升级业务逻辑。"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger
|
||||
from pydantic_extra_types.semantic_version import SemanticVersion
|
||||
|
||||
from model import Firmware, Item
|
||||
from model.firmware import ChipTypeEnum, FirmwareCheckUpdateResponse
|
||||
from middleware.dependencies import SessionDep
|
||||
from model.item import ItemStatusEnum
|
||||
from pkg import utils
|
||||
|
||||
# 固件存储目录
|
||||
FIRMWARE_STORAGE_PATH = Path("data/firmware")
|
||||
FIRMWARE_STORAGE_PATH.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
async def check_firmware_update(
|
||||
session: SessionDep,
|
||||
device: Item,
|
||||
chip_type: ChipTypeEnum,
|
||||
current_version: str,
|
||||
) -> FirmwareCheckUpdateResponse:
|
||||
"""
|
||||
检查设备是否有可用的固件更新。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
device: 设备对象
|
||||
chip_type: 芯片类型
|
||||
current_version: 当前版本号
|
||||
|
||||
Returns:
|
||||
FirmwareCheckUpdateResponse: 更新检查结果
|
||||
"""
|
||||
# 验证当前版本格式
|
||||
try:
|
||||
current = SemanticVersion(current_version)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid version format from device {device.id}: {current_version}")
|
||||
utils.raise_bad_request("Invalid version format")
|
||||
|
||||
# 查找该芯片类型的最新启用固件
|
||||
all_firmwares = await Firmware.get(
|
||||
session,
|
||||
(Firmware.chip_type == chip_type) & (Firmware.is_active == True),
|
||||
fetch_mode="all"
|
||||
)
|
||||
|
||||
if not all_firmwares:
|
||||
return FirmwareCheckUpdateResponse(
|
||||
has_update=False,
|
||||
)
|
||||
|
||||
# 过滤出比当前版本新的固件
|
||||
newer_firmwares = []
|
||||
for fw in all_firmwares:
|
||||
try:
|
||||
fw_version = SemanticVersion(str(fw.version))
|
||||
if fw_version > current:
|
||||
newer_firmwares.append(fw)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid firmware version in database: {fw.version}")
|
||||
continue
|
||||
|
||||
if not newer_firmwares:
|
||||
return FirmwareCheckUpdateResponse(
|
||||
has_update=False,
|
||||
)
|
||||
|
||||
# 取最新版本
|
||||
latest = max(newer_firmwares, key=lambda fw: SemanticVersion(str(fw.version)))
|
||||
|
||||
return FirmwareCheckUpdateResponse(
|
||||
has_update=True,
|
||||
latest_version=str(latest.version),
|
||||
download_url=f"/api/ota/download/{latest.id}",
|
||||
file_size=latest.file_size,
|
||||
file_md5=latest.file_md5,
|
||||
description=latest.description,
|
||||
)
|
||||
|
||||
|
||||
async def get_firmware_file(
|
||||
session: SessionDep,
|
||||
firmware_id: str,
|
||||
device: Item,
|
||||
) -> FileResponse:
|
||||
"""
|
||||
获取固件文件并更新下载统计。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
firmware_id: 固件ID
|
||||
device: 设备对象
|
||||
|
||||
Returns:
|
||||
FileResponse: 固件文件响应
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
firmware = await Firmware.get(session, Firmware.id == UUID(firmware_id))
|
||||
|
||||
if not firmware:
|
||||
utils.raise_not_found("Firmware not found")
|
||||
|
||||
if not firmware.is_active:
|
||||
utils.raise_forbidden("Firmware is not available")
|
||||
|
||||
# 验证芯片类型匹配
|
||||
if device.chip_type != firmware.chip_type:
|
||||
utils.raise_forbidden("Firmware chip type mismatch")
|
||||
|
||||
# 更新下载计数
|
||||
firmware.downloaded_count += 1
|
||||
await firmware.save(session)
|
||||
|
||||
file_path = Path(firmware.file_path)
|
||||
if not file_path.exists():
|
||||
logger.error(f"Firmware file not found: {file_path}")
|
||||
utils.raise_internal_error("Firmware file not available")
|
||||
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
filename=file_path.name,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
|
||||
|
||||
async def update_device_version(
|
||||
session: SessionDep,
|
||||
device: Item,
|
||||
version: str,
|
||||
) -> None:
|
||||
"""
|
||||
更新设备上报的固件版本。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
device: 设备对象
|
||||
version: 版本号字符串
|
||||
"""
|
||||
try:
|
||||
SemanticVersion(version)
|
||||
except ValueError:
|
||||
utils.raise_bad_request("Invalid version format")
|
||||
|
||||
device.version = version
|
||||
await device.save(session)
|
||||
logger.info(f"Device {device.id} reported version: {version}")
|
||||
|
||||
|
||||
async def report_device_lost(
|
||||
session: SessionDep,
|
||||
device: Item,
|
||||
) -> None:
|
||||
"""
|
||||
设备上报丢失状态。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
device: 设备对象
|
||||
"""
|
||||
device.status = ItemStatusEnum.lost
|
||||
await device.save(session)
|
||||
logger.info(f"Device {device.id} reported as lost")
|
||||
60
services/session.py
Normal file
60
services/session.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
会话服务,负责处理登录与令牌生成逻辑。
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from typing import Any
|
||||
import jwt
|
||||
|
||||
from model import Setting, User
|
||||
from model.response import TokenResponse
|
||||
from pkg import Password, utils
|
||||
import JWT
|
||||
|
||||
async def create_access_token(
|
||||
session: AsyncSession,
|
||||
data: dict[str, Any],
|
||||
) -> str:
|
||||
"""
|
||||
创建访问令牌。
|
||||
"""
|
||||
to_encode = data.copy()
|
||||
jwt_exp_setting = await Setting.get(session, Setting.name == "jwt_token_exp")
|
||||
expire = datetime.now(timezone.utc) + timedelta(int(jwt_exp_setting.value))
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, key=await JWT.get_secret_key(), algorithm="HS256")
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
async def authenticate_user(
|
||||
session: AsyncSession,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> User:
|
||||
"""
|
||||
验证用户名和密码,返回认证后的用户。
|
||||
"""
|
||||
account = await User.get(session, User.email == username)
|
||||
|
||||
if not account or account.email != username or not Password.verify(account.password, password):
|
||||
utils.raise_unauthorized("Account or password is incorrect")
|
||||
|
||||
return account
|
||||
|
||||
|
||||
async def login_for_access_token(
|
||||
session: AsyncSession,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
登录并生成访问令牌。
|
||||
"""
|
||||
user = await authenticate_user(session=session, username=username, password=password)
|
||||
|
||||
access_token = await create_access_token(
|
||||
session=session,
|
||||
data={"sub": user.email},
|
||||
)
|
||||
return TokenResponse(access_token=access_token)
|
||||
3
services/site.py
Normal file
3
services/site.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
站点信息服务。
|
||||
"""
|
||||
Reference in New Issue
Block a user