169 lines
4.4 KiB
Python
169 lines
4.4 KiB
Python
"""OTA 服务层,处理 ESP32/8266 设备的在线升级业务逻辑。"""
|
|
|
|
from pathlib import Path
|
|
|
|
from fastapi.responses import FileResponse
|
|
from loguru import logger
|
|
from pydantic_extra_types.semantic_version import SemanticVersion
|
|
|
|
from model import Firmware, Item
|
|
from model.firmware import ChipTypeEnum, FirmwareCheckUpdateResponse
|
|
from middleware.dependencies import SessionDep
|
|
from model.item import ItemStatusEnum
|
|
from pkg import utils
|
|
|
|
# 固件存储目录
|
|
FIRMWARE_STORAGE_PATH = Path("data/firmware")
|
|
FIRMWARE_STORAGE_PATH.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
async def check_firmware_update(
|
|
session: SessionDep,
|
|
device: Item,
|
|
chip_type: ChipTypeEnum,
|
|
current_version: str,
|
|
) -> FirmwareCheckUpdateResponse:
|
|
"""
|
|
检查设备是否有可用的固件更新。
|
|
|
|
Args:
|
|
session: 数据库会话
|
|
device: 设备对象
|
|
chip_type: 芯片类型
|
|
current_version: 当前版本号
|
|
|
|
Returns:
|
|
FirmwareCheckUpdateResponse: 更新检查结果
|
|
"""
|
|
# 验证当前版本格式
|
|
try:
|
|
current = SemanticVersion(current_version)
|
|
except ValueError:
|
|
logger.warning(f"Invalid version format from device {device.id}: {current_version}")
|
|
utils.raise_bad_request("Invalid version format")
|
|
|
|
# 查找该芯片类型的最新启用固件
|
|
all_firmwares = await Firmware.get(
|
|
session,
|
|
(Firmware.chip_type == chip_type) & (Firmware.is_active == True),
|
|
fetch_mode="all"
|
|
)
|
|
|
|
if not all_firmwares:
|
|
return FirmwareCheckUpdateResponse(
|
|
has_update=False,
|
|
)
|
|
|
|
# 过滤出比当前版本新的固件
|
|
newer_firmwares = []
|
|
for fw in all_firmwares:
|
|
try:
|
|
fw_version = SemanticVersion(str(fw.version))
|
|
if fw_version > current:
|
|
newer_firmwares.append(fw)
|
|
except ValueError:
|
|
logger.warning(f"Invalid firmware version in database: {fw.version}")
|
|
continue
|
|
|
|
if not newer_firmwares:
|
|
return FirmwareCheckUpdateResponse(
|
|
has_update=False,
|
|
)
|
|
|
|
# 取最新版本
|
|
latest = max(newer_firmwares, key=lambda fw: SemanticVersion(str(fw.version)))
|
|
|
|
return FirmwareCheckUpdateResponse(
|
|
has_update=True,
|
|
latest_version=str(latest.version),
|
|
download_url=f"/api/ota/download/{latest.id}",
|
|
file_size=latest.file_size,
|
|
file_md5=latest.file_md5,
|
|
description=latest.description,
|
|
)
|
|
|
|
|
|
async def get_firmware_file(
|
|
session: SessionDep,
|
|
firmware_id: str,
|
|
device: Item,
|
|
) -> FileResponse:
|
|
"""
|
|
获取固件文件并更新下载统计。
|
|
|
|
Args:
|
|
session: 数据库会话
|
|
firmware_id: 固件ID
|
|
device: 设备对象
|
|
|
|
Returns:
|
|
FileResponse: 固件文件响应
|
|
"""
|
|
from uuid import UUID
|
|
|
|
firmware = await Firmware.get(session, Firmware.id == UUID(firmware_id))
|
|
|
|
if not firmware:
|
|
utils.raise_not_found("Firmware not found")
|
|
|
|
if not firmware.is_active:
|
|
utils.raise_forbidden("Firmware is not available")
|
|
|
|
# 验证芯片类型匹配
|
|
if device.chip_type != firmware.chip_type:
|
|
utils.raise_forbidden("Firmware chip type mismatch")
|
|
|
|
# 更新下载计数
|
|
firmware.downloaded_count += 1
|
|
await firmware.save(session)
|
|
|
|
file_path = Path(firmware.file_path)
|
|
if not file_path.exists():
|
|
logger.error(f"Firmware file not found: {file_path}")
|
|
utils.raise_internal_error("Firmware file not available")
|
|
|
|
return FileResponse(
|
|
path=str(file_path),
|
|
filename=file_path.name,
|
|
media_type="application/octet-stream",
|
|
)
|
|
|
|
|
|
async def update_device_version(
|
|
session: SessionDep,
|
|
device: Item,
|
|
version: str,
|
|
) -> None:
|
|
"""
|
|
更新设备上报的固件版本。
|
|
|
|
Args:
|
|
session: 数据库会话
|
|
device: 设备对象
|
|
version: 版本号字符串
|
|
"""
|
|
try:
|
|
SemanticVersion(version)
|
|
except ValueError:
|
|
utils.raise_bad_request("Invalid version format")
|
|
|
|
device.version = version
|
|
await device.save(session)
|
|
logger.info(f"Device {device.id} reported version: {version}")
|
|
|
|
|
|
async def report_device_lost(
|
|
session: SessionDep,
|
|
device: Item,
|
|
) -> None:
|
|
"""
|
|
设备上报丢失状态。
|
|
|
|
Args:
|
|
session: 数据库会话
|
|
device: 设备对象
|
|
"""
|
|
device.status = ItemStatusEnum.lost
|
|
await device.save(session)
|
|
logger.info(f"Device {device.id} reported as lost")
|