添加ESP系列设备的OTA

This commit is contained in:
2026-01-12 11:44:55 +08:00
parent 3580717087
commit 3f1bd0731b
10 changed files with 765 additions and 15 deletions

View File

@@ -2,17 +2,30 @@
管理员相关业务逻辑。
"""
import hashlib
from pathlib import Path
from typing import Iterable, List
from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi import UploadFile
from loguru import logger
from pydantic_extra_types.semantic_version import SemanticVersion
from model import Setting
from model import SettingResponse
from middleware.dependencies import SessionDep
from model import Firmware, User, Setting, SettingResponse
from model.firmware import ChipTypeEnum, FirmwareDataResponseAdmin
from pkg import utils
# 固件存储目录
FIRMWARE_STORAGE_PATH = Path("data/firmware")
FIRMWARE_STORAGE_PATH.mkdir(parents=True, exist_ok=True)
# 文件大小限制 4MB
MAX_FIRMWARE_SIZE = 4 * 1024 * 1024
async def fetch_settings(
session: AsyncSession,
session: SessionDep,
name: str | None = None,
) -> List[SettingResponse]:
"""
@@ -35,7 +48,7 @@ async def fetch_settings(
async def update_setting_value(
session: AsyncSession,
session: SessionDep,
name: str,
value: str,
) -> bool:
@@ -50,3 +63,171 @@ async def update_setting_value(
await Setting.save(session)
return True
def _calculate_md5(file_path: Path) -> str:
"""计算文件的 MD5 值"""
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
async def upload_firmware(
session: SessionDep,
admin: User,
chip_type: ChipTypeEnum,
version: str,
description: str | None,
file: UploadFile,
) -> None:
"""
上传固件包。
Args:
session: 数据库会话
admin: 管理员用户
chip_type: 芯片类型
version: 版本号
description: 更新说明
file: 上传的文件
"""
# 验证版本号格式
try:
version_obj = SemanticVersion(version)
except ValueError:
utils.raise_bad_request("Invalid semantic version format")
# 验证文件扩展名
if not file.filename or not file.filename.endswith('.bin'):
utils.raise_bad_request("Only .bin files are supported")
# 检查是否已存在相同芯片类型和版本的固件
from sqlalchemy import and_
existing = await Firmware.get(
session,
and_(
Firmware.chip_type == chip_type,
Firmware.version == str(version_obj)
)
)
if existing:
utils.raise_conflict(f"Firmware {chip_type} v{version} already exists")
# 读取文件内容
content = await file.read()
file_size = len(content)
# 验证文件大小
if file_size > MAX_FIRMWARE_SIZE:
utils.raise_bad_request(f"File size exceeds {MAX_FIRMWARE_SIZE} bytes")
if file_size == 0:
utils.raise_bad_request("Empty file")
# 生成文件名
safe_filename = f"{chip_type}_{version}_{file.filename}"
file_path = FIRMWARE_STORAGE_PATH / safe_filename
# 写入文件
with open(file_path, "wb") as f:
f.write(content)
# 计算 MD5
file_md5 = _calculate_md5(file_path)
# 创建数据库记录
firmware = Firmware(
chip_type=chip_type,
version=str(version_obj),
file_path=str(file_path),
file_size=file_size,
file_md5=file_md5,
description=description,
uploaded_by_id=admin.id,
)
await Firmware.add(session, firmware)
logger.info(f"Admin {admin.email} uploaded firmware {chip_type} v{version}")
async def list_firmwares(
session: SessionDep,
chip_type: ChipTypeEnum | None,
is_active: bool | None,
) -> List[FirmwareDataResponseAdmin]:
"""
获取固件列表。
Args:
session: 数据库会话
chip_type: 筛选芯片类型
is_active: 筛选启用状态
Returns:
固件列表
"""
from sqlalchemy import and_
conditions = []
if chip_type:
conditions.append(Firmware.chip_type == chip_type)
if is_active is not None:
conditions.append(Firmware.is_active == is_active)
if conditions:
results = await Firmware.get(session, and_(*conditions), fetch_mode="all")
else:
results = await Firmware.get(session, fetch_mode="all")
if not results:
return []
return [FirmwareDataResponseAdmin.model_validate(fw) for fw in results]
async def delete_firmware(
session: SessionDep,
firmware_id: UUID,
) -> None:
"""
删除固件包。
Args:
session: 数据库会话
firmware_id: 固件ID
"""
firmware = await Firmware.get(session, Firmware.id == firmware_id)
if not firmware:
utils.raise_not_found("Firmware not found")
# 删除文件
file_path = Path(firmware.file_path)
if file_path.exists():
file_path.unlink()
# 删除数据库记录
await Firmware.delete(session, firmware)
async def toggle_firmware_status(
session: SessionDep,
firmware_id: UUID,
is_active: bool,
) -> None:
"""
切换固件启用状态。
Args:
session: 数据库会话
firmware_id: 固件ID
is_active: 目标状态
"""
firmware = await Firmware.get(session, Firmware.id == firmware_id)
if not firmware:
utils.raise_not_found("Firmware not found")
firmware.is_active = is_active
await firmware.save(session)

168
services/ota.py Normal file
View File

@@ -0,0 +1,168 @@
"""OTA 服务层,处理 ESP32/8266 设备的在线升级业务逻辑。"""
from pathlib import Path
from fastapi.responses import FileResponse
from loguru import logger
from pydantic_extra_types.semantic_version import SemanticVersion
from model import Firmware, Item
from model.firmware import ChipTypeEnum, FirmwareCheckUpdateResponse
from middleware.dependencies import SessionDep
from model.item import ItemStatusEnum
from pkg import utils
# 固件存储目录
FIRMWARE_STORAGE_PATH = Path("data/firmware")
FIRMWARE_STORAGE_PATH.mkdir(parents=True, exist_ok=True)
async def check_firmware_update(
session: SessionDep,
device: Item,
chip_type: ChipTypeEnum,
current_version: str,
) -> FirmwareCheckUpdateResponse:
"""
检查设备是否有可用的固件更新。
Args:
session: 数据库会话
device: 设备对象
chip_type: 芯片类型
current_version: 当前版本号
Returns:
FirmwareCheckUpdateResponse: 更新检查结果
"""
# 验证当前版本格式
try:
current = SemanticVersion(current_version)
except ValueError:
logger.warning(f"Invalid version format from device {device.id}: {current_version}")
utils.raise_bad_request("Invalid version format")
# 查找该芯片类型的最新启用固件
all_firmwares = await Firmware.get(
session,
(Firmware.chip_type == chip_type) & (Firmware.is_active == True),
fetch_mode="all"
)
if not all_firmwares:
return FirmwareCheckUpdateResponse(
has_update=False,
)
# 过滤出比当前版本新的固件
newer_firmwares = []
for fw in all_firmwares:
try:
fw_version = SemanticVersion(str(fw.version))
if fw_version > current:
newer_firmwares.append(fw)
except ValueError:
logger.warning(f"Invalid firmware version in database: {fw.version}")
continue
if not newer_firmwares:
return FirmwareCheckUpdateResponse(
has_update=False,
)
# 取最新版本
latest = max(newer_firmwares, key=lambda fw: SemanticVersion(str(fw.version)))
return FirmwareCheckUpdateResponse(
has_update=True,
latest_version=str(latest.version),
download_url=f"/api/ota/download/{latest.id}",
file_size=latest.file_size,
file_md5=latest.file_md5,
description=latest.description,
)
async def get_firmware_file(
session: SessionDep,
firmware_id: str,
device: Item,
) -> FileResponse:
"""
获取固件文件并更新下载统计。
Args:
session: 数据库会话
firmware_id: 固件ID
device: 设备对象
Returns:
FileResponse: 固件文件响应
"""
from uuid import UUID
firmware = await Firmware.get(session, Firmware.id == UUID(firmware_id))
if not firmware:
utils.raise_not_found("Firmware not found")
if not firmware.is_active:
utils.raise_forbidden("Firmware is not available")
# 验证芯片类型匹配
if device.chip_type != firmware.chip_type:
utils.raise_forbidden("Firmware chip type mismatch")
# 更新下载计数
firmware.downloaded_count += 1
await firmware.save(session)
file_path = Path(firmware.file_path)
if not file_path.exists():
logger.error(f"Firmware file not found: {file_path}")
utils.raise_internal_error("Firmware file not available")
return FileResponse(
path=str(file_path),
filename=file_path.name,
media_type="application/octet-stream",
)
async def update_device_version(
session: SessionDep,
device: Item,
version: str,
) -> None:
"""
更新设备上报的固件版本。
Args:
session: 数据库会话
device: 设备对象
version: 版本号字符串
"""
try:
SemanticVersion(version)
except ValueError:
utils.raise_bad_request("Invalid version format")
device.version = version
await device.save(session)
logger.info(f"Device {device.id} reported version: {version}")
async def report_device_lost(
session: SessionDep,
device: Item,
) -> None:
"""
设备上报丢失状态。
Args:
session: 数据库会话
device: 设备对象
"""
device.status = ItemStatusEnum.lost
await device.save(session)
logger.info(f"Device {device.id} reported as lost")