添加ESP系列设备的OTA

This commit is contained in:
2026-01-12 11:44:55 +08:00
parent 3580717087
commit 3f1bd0731b
10 changed files with 765 additions and 15 deletions

4
app.py
View File

@@ -7,7 +7,7 @@ from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from pkg.utils import raise_internal_error
from routes import (session, admin, object)
from routes import (session, admin, object, ota)
from model.database import Database
import os
import pkg.conf
@@ -15,7 +15,7 @@ from pkg import utils
from loguru import logger
Router = [admin, session, object]
Router = [admin, session, object, ota]
# Findreve 的生命周期
@asynccontextmanager

View File

@@ -1,10 +1,13 @@
from typing import Annotated, TypeAlias
from fastapi import Depends
from fastapi import Depends, Request
from sqlmodel.ext.asyncio.session import AsyncSession
from model.database import Database
from model.mixin.table import TableViewRequest
from model import Item
from model.item import ItemTypeEnum
from pkg import utils
SessionDep: TypeAlias = Annotated[AsyncSession, Depends(Database.get_session)]
"""数据库会话依赖,用于路由函数中获取数据库会话"""
@@ -12,3 +15,48 @@ SessionDep: TypeAlias = Annotated[AsyncSession, Depends(Database.get_session)]
# 新增:表格视图请求依赖(用于分页排序)
TableViewRequestDep: TypeAlias = Annotated[TableViewRequest, Depends()]
"""分页排序请求依赖,用于 LIST 端点"""
async def get_device_from_cert(
request: Request,
session: SessionDep,
) -> Item:
"""
从 mTLS 客户端证书中提取设备序列号并验证设备。
客户端证书的 CN (Common Name) 字段应存储设备序列号 (UUID)。
反向代理Nginx/Apache验证证书后通过 HTTP Header 将 CN 传递给 FastAPI。
Nginx 配置示例:
proxy_set_header X-Client-CN $ssl_client_s_dn_cn;
Apache 配置示例:
RequestHeader set X-Client-CN "%{SSL_CLIENT_S_DN_CN}s"
"""
# 从 Header 获取设备序列号(由反向代理注入)
serial_number = request.headers.get("X-Client-CN")
if not serial_number:
utils.raise_unauthorized("Device certificate required")
# 验证 UUID 格式
try:
from uuid import UUID
serial_uuid = UUID(serial_number)
except ValueError:
utils.raise_unauthorized("Invalid device serial number format")
# 查找设备
device = await Item.get(session, Item.id == serial_uuid)
if not device:
utils.raise_not_found("Device not found")
if device.type != ItemTypeEnum.esp32:
utils.raise_forbidden("Not an ESP device")
return device
DeviceDep: TypeAlias = Annotated[Item, Depends(get_device_from_cert)]
"""设备认证依赖,通过 mTLS 证书验证 ESP 设备"""

View File

@@ -3,6 +3,15 @@ from .setting import Setting, SettingResponse
from .item import Item, ItemDataResponse, ItemTypeEnum, ItemStatusEnum
from .user import User, UserTypeEnum
from .database import Database
from .firmware import (
Firmware,
FirmwareDataResponse,
FirmwareDataResponseAdmin,
FirmwareUploadRequest,
FirmwareCheckUpdateRequest,
FirmwareCheckUpdateResponse,
ChipTypeEnum,
)
# 新增:从 foxline 项目移植的 Mixin 组件
from .mixin.table import (
@@ -27,6 +36,14 @@ __all__ = [
"User",
"UserTypeEnum",
"Database",
# 固件相关
"Firmware",
"FirmwareDataResponse",
"FirmwareDataResponseAdmin",
"FirmwareUploadRequest",
"FirmwareCheckUpdateRequest",
"FirmwareCheckUpdateResponse",
"ChipTypeEnum",
# 新增的 Mixin 组件
"TableBaseMixin",
"UUIDTableBaseMixin",

125
model/firmware.py Normal file
View File

@@ -0,0 +1,125 @@
"""固件包数据模型,用于 ESP32/8266 OTA 在线升级功能。"""
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, String, Text
from .base import SQLModelBase, UUIDTableBase
if TYPE_CHECKING:
from .user import User
class ChipTypeEnum(StrEnum):
"""ESP 芯片类型枚举"""
esp32 = 'esp32'
esp8266 = 'esp8266'
esp32s2 = 'esp32s2'
esp32s3 = 'esp32s3'
esp32c3 = 'esp32c3'
class FirmwareBase(SQLModelBase):
chip_type: ChipTypeEnum = Field(index=True)
"""芯片类型"""
version: str = Field(sa_type=String(64), index=True)
"""固件版本号,遵循语义化版本规范"""
file_path: str
"""固件文件存储路径"""
file_size: int
"""固件文件大小(字节)"""
file_md5: str = Field(max_length=32)
"""固件文件 MD5 校验值"""
description: str | None = Field(default=None, sa_type=Text)
"""固件更新说明"""
is_active: bool = Field(default=True, index=True)
"""是否启用该固件版本"""
class Firmware(FirmwareBase, UUIDTableBase, table=True):
"""固件包表"""
uploaded_by_id: UUID = Field(foreign_key='user.id', ondelete='RESTRICT')
"""上传者用户ID"""
downloaded_count: int = Field(default=0)
"""下载次数统计"""
uploaded_at: datetime = Field(default_factory=datetime.now)
"""上传时间"""
uploaded_by: 'User' = Relationship(back_populates='firmwares')
# DTO 定义
class FirmwareDataResponse(FirmwareBase):
"""固件信息响应"""
id: UUID
"""固件ID"""
downloaded_count: int
"""下载次数"""
uploaded_at: datetime
"""上传时间"""
download_url: str | None = None
"""下载地址"""
class FirmwareDataResponseAdmin(FirmwareDataResponse):
"""固件信息响应(管理员)"""
uploaded_by_id: UUID
"""上传者ID"""
class FirmwareUploadRequest(SQLModelBase):
"""固件上传请求"""
chip_type: ChipTypeEnum
"""芯片类型"""
version: str
"""版本号字符串"""
description: str | None = None
"""更新说明"""
class FirmwareCheckUpdateRequest(SQLModelBase):
"""设备检查更新请求"""
chip_type: ChipTypeEnum
"""芯片类型"""
current_version: str
"""当前版本号"""
class FirmwareCheckUpdateResponse(SQLModelBase):
"""检查更新响应"""
has_update: bool
"""是否有可用更新"""
latest_version: str | None = None
"""最新版本号"""
download_url: str | None = None
"""下载地址"""
file_size: int | None = None
"""文件大小"""
file_md5: str | None = None
"""文件MD5"""
description: str | None = None
"""更新说明"""

View File

@@ -6,6 +6,7 @@ from sqlmodel import Field, Relationship, String
from pydantic_extra_types.semantic_version import SemanticVersion
from .base import SQLModelBase, UUIDTableBase
from .firmware import ChipTypeEnum
if TYPE_CHECKING:
from .user import User
@@ -41,6 +42,9 @@ class ItemBase(SQLModelBase):
version: SemanticVersion = Field(sa_type=String(64))
"""版本号"""
chip_type: ChipTypeEnum | None = Field(default=None, index=True)
"""ESP设备芯片类型仅当type=esp32时有值"""
class Item(ItemBase, UUIDTableBase, table=True):
expires_at: datetime | None = None
"""物品过期时间"""

View File

@@ -1,5 +1,5 @@
from enum import StrEnum
from typing import ClassVar
from typing import TYPE_CHECKING, ClassVar
import sqlalchemy as sa
from pydantic import EmailStr
@@ -10,6 +10,9 @@ from sqlmodel import Field, Relationship
from .base import SQLModelBase, UUIDTableBase
from .item import Item
if TYPE_CHECKING:
from .firmware import Firmware
class UserTypeEnum(StrEnum):
normal_user = 'normal_user'
@@ -38,6 +41,9 @@ class User(UserBase, UUIDTableBase, table=True):
items: list[Item] = Relationship(back_populates='user', cascade_delete=True)
"""物品关系"""
firmwares: list['Firmware'] = Relationship(back_populates='uploaded_by', cascade_delete=True)
"""上传的固件关系"""
_initializing: ClassVar[bool] = False
"""标记当前是否处于初始化阶段,初始化阶段允许创建 super_admin"""

View File

@@ -1,11 +1,13 @@
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
from starlette.status import HTTP_204_NO_CONTENT
from middleware.admin import is_admin
from model import database
from model.response import DefaultResponse
from middleware.dependencies import SessionDep
from model import User, DefaultResponse
from model.firmware import ChipTypeEnum
from services import admin as admin_service
Router = APIRouter(
@@ -38,7 +40,7 @@ async def verity_admin() -> DefaultResponse:
response_description='设置项列表'
)
async def get_settings(
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
session: SessionDep,
name: str | None = None
) -> DefaultResponse:
data = await admin_service.fetch_settings(session=session, name=name)
@@ -53,9 +55,110 @@ async def get_settings(
response_description='更新结果'
)
async def update_settings(
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
session: SessionDep,
name: str,
value: str
) -> DefaultResponse:
result = await admin_service.update_setting_value(session=session, name=name, value=value)
return DefaultResponse(data=result)
# 固件管理接口
@Router.post(
path='/firmware',
summary='上传固件包',
description='管理员上传新的固件更新包',
status_code=HTTP_204_NO_CONTENT,
response_description='上传成功'
)
async def upload_firmware(
session: SessionDep,
admin: Annotated[User, Depends(is_admin)],
chip_type: ChipTypeEnum = Form(..., description='芯片类型'),
version: str = Form(..., description='版本号'),
description: str | None = Form(None, description='更新说明'),
file: UploadFile = File(..., description='固件文件'),
):
"""
上传固件包。
支持的文件格式:.bin
文件大小限制4MB
"""
await admin_service.upload_firmware(
session=session,
admin=admin,
chip_type=chip_type,
version=version,
description=description,
file=file,
)
@Router.get(
path='/firmwares',
summary='获取固件列表',
description='获取已上传的固件列表',
response_model=DefaultResponse,
response_description='固件列表'
)
async def list_firmwares(
session: SessionDep,
admin: Annotated[User, Depends(is_admin)],
chip_type: ChipTypeEnum | None = Query(None, description='筛选芯片类型'),
is_active: bool | None = Query(None, description='筛选启用状态'),
) -> DefaultResponse:
"""
获取固件列表。
"""
result = await admin_service.list_firmwares(
session=session,
chip_type=chip_type,
is_active=is_active,
)
return DefaultResponse(data=result)
@Router.delete(
path='/firmware/{firmware_id}',
summary='删除固件',
description='删除指定的固件包',
status_code=HTTP_204_NO_CONTENT,
response_description='删除成功'
)
async def delete_firmware(
session: SessionDep,
admin: Annotated[User, Depends(is_admin)],
firmware_id: UUID,
):
"""
删除固件包。
"""
await admin_service.delete_firmware(
session=session,
firmware_id=firmware_id,
)
@Router.patch(
path='/firmware/{firmware_id}/status',
summary='切换固件状态',
description='启用或禁用固件',
status_code=HTTP_204_NO_CONTENT,
response_description='操作成功'
)
async def toggle_firmware_status(
session: SessionDep,
admin: Annotated[User, Depends(is_admin)],
firmware_id: UUID,
is_active: bool = Query(..., description='目标状态'),
):
"""
切换固件启用状态。
"""
await admin_service.toggle_firmware_status(
session=session,
firmware_id=firmware_id,
is_active=is_active,
)

98
routes/ota.py Normal file
View File

@@ -0,0 +1,98 @@
"""OTA API 路由,处理 ESP32/8266 设备的在线升级请求。"""
from fastapi import APIRouter, Query, status
from starlette.status import HTTP_204_NO_CONTENT
from middleware.dependencies import SessionDep, DeviceDep
from model import DefaultResponse
from model.firmware import FirmwareCheckUpdateRequest, FirmwareCheckUpdateResponse
from services import ota as ota_service
Router = APIRouter(prefix='/api/ota', tags=['OTA升级'])
@Router.post(
path='/check-update',
summary='检查固件更新',
description='设备通过 mTLS 认证后查询是否有新版本固件',
response_model=DefaultResponse,
response_description='更新检查结果'
)
async def check_update(
session: SessionDep,
device: DeviceDep,
request_data: FirmwareCheckUpdateRequest,
) -> DefaultResponse:
"""
检查固件更新。
设备需要提供有效的 mTLS 客户端证书,证书 CN 字段为设备序列号。
"""
result = await ota_service.check_firmware_update(
session=session,
device=device,
chip_type=request_data.chip_type,
current_version=request_data.current_version,
)
return DefaultResponse(data=result)
@Router.get(
path='/download/{firmware_id}',
summary='下载固件包',
description='下载指定的固件更新包',
)
async def download_firmware(
session: SessionDep,
device: DeviceDep,
firmware_id: str,
):
"""
下载固件包。
需要有效的设备证书,且下载会记录统计信息。
"""
return await ota_service.get_firmware_file(
session=session,
firmware_id=firmware_id,
device=device,
)
@Router.post(
path='/report-version',
summary='上报设备版本',
description='设备上报当前运行的固件版本',
status_code=HTTP_204_NO_CONTENT,
response_description='上报成功'
)
async def report_version(
session: SessionDep,
device: DeviceDep,
version: str = Query(..., description='当前版本号'),
):
"""
上报设备当前运行的固件版本。
"""
await ota_service.update_device_version(
session=session,
device=device,
version=version,
)
@Router.post(
path='/report-lost',
summary='上报设备丢失',
description='设备上报丢失状态',
status_code=HTTP_204_NO_CONTENT,
response_description='上报成功'
)
async def report_lost(
session: SessionDep,
device: DeviceDep,
):
"""
设备上报丢失状态(复用现有丢失处理逻辑)。
"""
await ota_service.report_device_lost(session=session, device=device)

View File

@@ -2,17 +2,30 @@
管理员相关业务逻辑。
"""
import hashlib
from pathlib import Path
from typing import Iterable, List
from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi import UploadFile
from loguru import logger
from pydantic_extra_types.semantic_version import SemanticVersion
from model import Setting
from model import SettingResponse
from middleware.dependencies import SessionDep
from model import Firmware, User, Setting, SettingResponse
from model.firmware import ChipTypeEnum, FirmwareDataResponseAdmin
from pkg import utils
# 固件存储目录
FIRMWARE_STORAGE_PATH = Path("data/firmware")
FIRMWARE_STORAGE_PATH.mkdir(parents=True, exist_ok=True)
# 文件大小限制 4MB
MAX_FIRMWARE_SIZE = 4 * 1024 * 1024
async def fetch_settings(
session: AsyncSession,
session: SessionDep,
name: str | None = None,
) -> List[SettingResponse]:
"""
@@ -35,7 +48,7 @@ async def fetch_settings(
async def update_setting_value(
session: AsyncSession,
session: SessionDep,
name: str,
value: str,
) -> bool:
@@ -50,3 +63,171 @@ async def update_setting_value(
await Setting.save(session)
return True
def _calculate_md5(file_path: Path) -> str:
"""计算文件的 MD5 值"""
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
async def upload_firmware(
session: SessionDep,
admin: User,
chip_type: ChipTypeEnum,
version: str,
description: str | None,
file: UploadFile,
) -> None:
"""
上传固件包。
Args:
session: 数据库会话
admin: 管理员用户
chip_type: 芯片类型
version: 版本号
description: 更新说明
file: 上传的文件
"""
# 验证版本号格式
try:
version_obj = SemanticVersion(version)
except ValueError:
utils.raise_bad_request("Invalid semantic version format")
# 验证文件扩展名
if not file.filename or not file.filename.endswith('.bin'):
utils.raise_bad_request("Only .bin files are supported")
# 检查是否已存在相同芯片类型和版本的固件
from sqlalchemy import and_
existing = await Firmware.get(
session,
and_(
Firmware.chip_type == chip_type,
Firmware.version == str(version_obj)
)
)
if existing:
utils.raise_conflict(f"Firmware {chip_type} v{version} already exists")
# 读取文件内容
content = await file.read()
file_size = len(content)
# 验证文件大小
if file_size > MAX_FIRMWARE_SIZE:
utils.raise_bad_request(f"File size exceeds {MAX_FIRMWARE_SIZE} bytes")
if file_size == 0:
utils.raise_bad_request("Empty file")
# 生成文件名
safe_filename = f"{chip_type}_{version}_{file.filename}"
file_path = FIRMWARE_STORAGE_PATH / safe_filename
# 写入文件
with open(file_path, "wb") as f:
f.write(content)
# 计算 MD5
file_md5 = _calculate_md5(file_path)
# 创建数据库记录
firmware = Firmware(
chip_type=chip_type,
version=str(version_obj),
file_path=str(file_path),
file_size=file_size,
file_md5=file_md5,
description=description,
uploaded_by_id=admin.id,
)
await Firmware.add(session, firmware)
logger.info(f"Admin {admin.email} uploaded firmware {chip_type} v{version}")
async def list_firmwares(
session: SessionDep,
chip_type: ChipTypeEnum | None,
is_active: bool | None,
) -> List[FirmwareDataResponseAdmin]:
"""
获取固件列表。
Args:
session: 数据库会话
chip_type: 筛选芯片类型
is_active: 筛选启用状态
Returns:
固件列表
"""
from sqlalchemy import and_
conditions = []
if chip_type:
conditions.append(Firmware.chip_type == chip_type)
if is_active is not None:
conditions.append(Firmware.is_active == is_active)
if conditions:
results = await Firmware.get(session, and_(*conditions), fetch_mode="all")
else:
results = await Firmware.get(session, fetch_mode="all")
if not results:
return []
return [FirmwareDataResponseAdmin.model_validate(fw) for fw in results]
async def delete_firmware(
session: SessionDep,
firmware_id: UUID,
) -> None:
"""
删除固件包。
Args:
session: 数据库会话
firmware_id: 固件ID
"""
firmware = await Firmware.get(session, Firmware.id == firmware_id)
if not firmware:
utils.raise_not_found("Firmware not found")
# 删除文件
file_path = Path(firmware.file_path)
if file_path.exists():
file_path.unlink()
# 删除数据库记录
await Firmware.delete(session, firmware)
async def toggle_firmware_status(
session: SessionDep,
firmware_id: UUID,
is_active: bool,
) -> None:
"""
切换固件启用状态。
Args:
session: 数据库会话
firmware_id: 固件ID
is_active: 目标状态
"""
firmware = await Firmware.get(session, Firmware.id == firmware_id)
if not firmware:
utils.raise_not_found("Firmware not found")
firmware.is_active = is_active
await firmware.save(session)

168
services/ota.py Normal file
View File

@@ -0,0 +1,168 @@
"""OTA 服务层,处理 ESP32/8266 设备的在线升级业务逻辑。"""
from pathlib import Path
from fastapi.responses import FileResponse
from loguru import logger
from pydantic_extra_types.semantic_version import SemanticVersion
from model import Firmware, Item
from model.firmware import ChipTypeEnum, FirmwareCheckUpdateResponse
from middleware.dependencies import SessionDep
from model.item import ItemStatusEnum
from pkg import utils
# 固件存储目录
FIRMWARE_STORAGE_PATH = Path("data/firmware")
FIRMWARE_STORAGE_PATH.mkdir(parents=True, exist_ok=True)
async def check_firmware_update(
session: SessionDep,
device: Item,
chip_type: ChipTypeEnum,
current_version: str,
) -> FirmwareCheckUpdateResponse:
"""
检查设备是否有可用的固件更新。
Args:
session: 数据库会话
device: 设备对象
chip_type: 芯片类型
current_version: 当前版本号
Returns:
FirmwareCheckUpdateResponse: 更新检查结果
"""
# 验证当前版本格式
try:
current = SemanticVersion(current_version)
except ValueError:
logger.warning(f"Invalid version format from device {device.id}: {current_version}")
utils.raise_bad_request("Invalid version format")
# 查找该芯片类型的最新启用固件
all_firmwares = await Firmware.get(
session,
(Firmware.chip_type == chip_type) & (Firmware.is_active == True),
fetch_mode="all"
)
if not all_firmwares:
return FirmwareCheckUpdateResponse(
has_update=False,
)
# 过滤出比当前版本新的固件
newer_firmwares = []
for fw in all_firmwares:
try:
fw_version = SemanticVersion(str(fw.version))
if fw_version > current:
newer_firmwares.append(fw)
except ValueError:
logger.warning(f"Invalid firmware version in database: {fw.version}")
continue
if not newer_firmwares:
return FirmwareCheckUpdateResponse(
has_update=False,
)
# 取最新版本
latest = max(newer_firmwares, key=lambda fw: SemanticVersion(str(fw.version)))
return FirmwareCheckUpdateResponse(
has_update=True,
latest_version=str(latest.version),
download_url=f"/api/ota/download/{latest.id}",
file_size=latest.file_size,
file_md5=latest.file_md5,
description=latest.description,
)
async def get_firmware_file(
session: SessionDep,
firmware_id: str,
device: Item,
) -> FileResponse:
"""
获取固件文件并更新下载统计。
Args:
session: 数据库会话
firmware_id: 固件ID
device: 设备对象
Returns:
FileResponse: 固件文件响应
"""
from uuid import UUID
firmware = await Firmware.get(session, Firmware.id == UUID(firmware_id))
if not firmware:
utils.raise_not_found("Firmware not found")
if not firmware.is_active:
utils.raise_forbidden("Firmware is not available")
# 验证芯片类型匹配
if device.chip_type != firmware.chip_type:
utils.raise_forbidden("Firmware chip type mismatch")
# 更新下载计数
firmware.downloaded_count += 1
await firmware.save(session)
file_path = Path(firmware.file_path)
if not file_path.exists():
logger.error(f"Firmware file not found: {file_path}")
utils.raise_internal_error("Firmware file not available")
return FileResponse(
path=str(file_path),
filename=file_path.name,
media_type="application/octet-stream",
)
async def update_device_version(
session: SessionDep,
device: Item,
version: str,
) -> None:
"""
更新设备上报的固件版本。
Args:
session: 数据库会话
device: 设备对象
version: 版本号字符串
"""
try:
SemanticVersion(version)
except ValueError:
utils.raise_bad_request("Invalid version format")
device.version = version
await device.save(session)
logger.info(f"Device {device.id} reported version: {version}")
async def report_device_lost(
session: SessionDep,
device: Item,
) -> None:
"""
设备上报丢失状态。
Args:
session: 数据库会话
device: 设备对象
"""
device.status = ItemStatusEnum.lost
await device.save(session)
logger.info(f"Device {device.id} reported as lost")