添加ESP系列设备的OTA
This commit is contained in:
4
app.py
4
app.py
@@ -7,7 +7,7 @@ from slowapi.util import get_remote_address
|
|||||||
from slowapi.errors import RateLimitExceeded
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
|
||||||
from pkg.utils import raise_internal_error
|
from pkg.utils import raise_internal_error
|
||||||
from routes import (session, admin, object)
|
from routes import (session, admin, object, ota)
|
||||||
from model.database import Database
|
from model.database import Database
|
||||||
import os
|
import os
|
||||||
import pkg.conf
|
import pkg.conf
|
||||||
@@ -15,7 +15,7 @@ from pkg import utils
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
Router = [admin, session, object]
|
Router = [admin, session, object, ota]
|
||||||
|
|
||||||
# Findreve 的生命周期
|
# Findreve 的生命周期
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
from typing import Annotated, TypeAlias
|
from typing import Annotated, TypeAlias
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends, Request
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from model.database import Database
|
from model.database import Database
|
||||||
from model.mixin.table import TableViewRequest
|
from model.mixin.table import TableViewRequest
|
||||||
|
from model import Item
|
||||||
|
from model.item import ItemTypeEnum
|
||||||
|
from pkg import utils
|
||||||
|
|
||||||
SessionDep: TypeAlias = Annotated[AsyncSession, Depends(Database.get_session)]
|
SessionDep: TypeAlias = Annotated[AsyncSession, Depends(Database.get_session)]
|
||||||
"""数据库会话依赖,用于路由函数中获取数据库会话"""
|
"""数据库会话依赖,用于路由函数中获取数据库会话"""
|
||||||
@@ -12,3 +15,48 @@ SessionDep: TypeAlias = Annotated[AsyncSession, Depends(Database.get_session)]
|
|||||||
# 新增:表格视图请求依赖(用于分页排序)
|
# 新增:表格视图请求依赖(用于分页排序)
|
||||||
TableViewRequestDep: TypeAlias = Annotated[TableViewRequest, Depends()]
|
TableViewRequestDep: TypeAlias = Annotated[TableViewRequest, Depends()]
|
||||||
"""分页排序请求依赖,用于 LIST 端点"""
|
"""分页排序请求依赖,用于 LIST 端点"""
|
||||||
|
|
||||||
|
|
||||||
|
async def get_device_from_cert(
|
||||||
|
request: Request,
|
||||||
|
session: SessionDep,
|
||||||
|
) -> Item:
|
||||||
|
"""
|
||||||
|
从 mTLS 客户端证书中提取设备序列号并验证设备。
|
||||||
|
|
||||||
|
客户端证书的 CN (Common Name) 字段应存储设备序列号 (UUID)。
|
||||||
|
反向代理(Nginx/Apache)验证证书后,通过 HTTP Header 将 CN 传递给 FastAPI。
|
||||||
|
|
||||||
|
Nginx 配置示例:
|
||||||
|
proxy_set_header X-Client-CN $ssl_client_s_dn_cn;
|
||||||
|
|
||||||
|
Apache 配置示例:
|
||||||
|
RequestHeader set X-Client-CN "%{SSL_CLIENT_S_DN_CN}s"
|
||||||
|
"""
|
||||||
|
# 从 Header 获取设备序列号(由反向代理注入)
|
||||||
|
serial_number = request.headers.get("X-Client-CN")
|
||||||
|
|
||||||
|
if not serial_number:
|
||||||
|
utils.raise_unauthorized("Device certificate required")
|
||||||
|
|
||||||
|
# 验证 UUID 格式
|
||||||
|
try:
|
||||||
|
from uuid import UUID
|
||||||
|
serial_uuid = UUID(serial_number)
|
||||||
|
except ValueError:
|
||||||
|
utils.raise_unauthorized("Invalid device serial number format")
|
||||||
|
|
||||||
|
# 查找设备
|
||||||
|
device = await Item.get(session, Item.id == serial_uuid)
|
||||||
|
|
||||||
|
if not device:
|
||||||
|
utils.raise_not_found("Device not found")
|
||||||
|
|
||||||
|
if device.type != ItemTypeEnum.esp32:
|
||||||
|
utils.raise_forbidden("Not an ESP device")
|
||||||
|
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
|
DeviceDep: TypeAlias = Annotated[Item, Depends(get_device_from_cert)]
|
||||||
|
"""设备认证依赖,通过 mTLS 证书验证 ESP 设备"""
|
||||||
|
|||||||
@@ -3,6 +3,15 @@ from .setting import Setting, SettingResponse
|
|||||||
from .item import Item, ItemDataResponse, ItemTypeEnum, ItemStatusEnum
|
from .item import Item, ItemDataResponse, ItemTypeEnum, ItemStatusEnum
|
||||||
from .user import User, UserTypeEnum
|
from .user import User, UserTypeEnum
|
||||||
from .database import Database
|
from .database import Database
|
||||||
|
from .firmware import (
|
||||||
|
Firmware,
|
||||||
|
FirmwareDataResponse,
|
||||||
|
FirmwareDataResponseAdmin,
|
||||||
|
FirmwareUploadRequest,
|
||||||
|
FirmwareCheckUpdateRequest,
|
||||||
|
FirmwareCheckUpdateResponse,
|
||||||
|
ChipTypeEnum,
|
||||||
|
)
|
||||||
|
|
||||||
# 新增:从 foxline 项目移植的 Mixin 组件
|
# 新增:从 foxline 项目移植的 Mixin 组件
|
||||||
from .mixin.table import (
|
from .mixin.table import (
|
||||||
@@ -27,6 +36,14 @@ __all__ = [
|
|||||||
"User",
|
"User",
|
||||||
"UserTypeEnum",
|
"UserTypeEnum",
|
||||||
"Database",
|
"Database",
|
||||||
|
# 固件相关
|
||||||
|
"Firmware",
|
||||||
|
"FirmwareDataResponse",
|
||||||
|
"FirmwareDataResponseAdmin",
|
||||||
|
"FirmwareUploadRequest",
|
||||||
|
"FirmwareCheckUpdateRequest",
|
||||||
|
"FirmwareCheckUpdateResponse",
|
||||||
|
"ChipTypeEnum",
|
||||||
# 新增的 Mixin 组件
|
# 新增的 Mixin 组件
|
||||||
"TableBaseMixin",
|
"TableBaseMixin",
|
||||||
"UUIDTableBaseMixin",
|
"UUIDTableBaseMixin",
|
||||||
|
|||||||
125
model/firmware.py
Normal file
125
model/firmware.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""固件包数据模型,用于 ESP32/8266 OTA 在线升级功能。"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlmodel import Field, Relationship, String, Text
|
||||||
|
|
||||||
|
from .base import SQLModelBase, UUIDTableBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
|
class ChipTypeEnum(StrEnum):
|
||||||
|
"""ESP 芯片类型枚举"""
|
||||||
|
esp32 = 'esp32'
|
||||||
|
esp8266 = 'esp8266'
|
||||||
|
esp32s2 = 'esp32s2'
|
||||||
|
esp32s3 = 'esp32s3'
|
||||||
|
esp32c3 = 'esp32c3'
|
||||||
|
|
||||||
|
|
||||||
|
class FirmwareBase(SQLModelBase):
|
||||||
|
chip_type: ChipTypeEnum = Field(index=True)
|
||||||
|
"""芯片类型"""
|
||||||
|
|
||||||
|
version: str = Field(sa_type=String(64), index=True)
|
||||||
|
"""固件版本号,遵循语义化版本规范"""
|
||||||
|
|
||||||
|
file_path: str
|
||||||
|
"""固件文件存储路径"""
|
||||||
|
|
||||||
|
file_size: int
|
||||||
|
"""固件文件大小(字节)"""
|
||||||
|
|
||||||
|
file_md5: str = Field(max_length=32)
|
||||||
|
"""固件文件 MD5 校验值"""
|
||||||
|
|
||||||
|
description: str | None = Field(default=None, sa_type=Text)
|
||||||
|
"""固件更新说明"""
|
||||||
|
|
||||||
|
is_active: bool = Field(default=True, index=True)
|
||||||
|
"""是否启用该固件版本"""
|
||||||
|
|
||||||
|
|
||||||
|
class Firmware(FirmwareBase, UUIDTableBase, table=True):
|
||||||
|
"""固件包表"""
|
||||||
|
|
||||||
|
uploaded_by_id: UUID = Field(foreign_key='user.id', ondelete='RESTRICT')
|
||||||
|
"""上传者用户ID"""
|
||||||
|
|
||||||
|
downloaded_count: int = Field(default=0)
|
||||||
|
"""下载次数统计"""
|
||||||
|
|
||||||
|
uploaded_at: datetime = Field(default_factory=datetime.now)
|
||||||
|
"""上传时间"""
|
||||||
|
|
||||||
|
uploaded_by: 'User' = Relationship(back_populates='firmwares')
|
||||||
|
|
||||||
|
|
||||||
|
# DTO 定义
|
||||||
|
|
||||||
|
class FirmwareDataResponse(FirmwareBase):
|
||||||
|
"""固件信息响应"""
|
||||||
|
id: UUID
|
||||||
|
"""固件ID"""
|
||||||
|
|
||||||
|
downloaded_count: int
|
||||||
|
"""下载次数"""
|
||||||
|
|
||||||
|
uploaded_at: datetime
|
||||||
|
"""上传时间"""
|
||||||
|
|
||||||
|
download_url: str | None = None
|
||||||
|
"""下载地址"""
|
||||||
|
|
||||||
|
|
||||||
|
class FirmwareDataResponseAdmin(FirmwareDataResponse):
|
||||||
|
"""固件信息响应(管理员)"""
|
||||||
|
uploaded_by_id: UUID
|
||||||
|
"""上传者ID"""
|
||||||
|
|
||||||
|
|
||||||
|
class FirmwareUploadRequest(SQLModelBase):
|
||||||
|
"""固件上传请求"""
|
||||||
|
chip_type: ChipTypeEnum
|
||||||
|
"""芯片类型"""
|
||||||
|
|
||||||
|
version: str
|
||||||
|
"""版本号字符串"""
|
||||||
|
|
||||||
|
description: str | None = None
|
||||||
|
"""更新说明"""
|
||||||
|
|
||||||
|
|
||||||
|
class FirmwareCheckUpdateRequest(SQLModelBase):
|
||||||
|
"""设备检查更新请求"""
|
||||||
|
chip_type: ChipTypeEnum
|
||||||
|
"""芯片类型"""
|
||||||
|
|
||||||
|
current_version: str
|
||||||
|
"""当前版本号"""
|
||||||
|
|
||||||
|
|
||||||
|
class FirmwareCheckUpdateResponse(SQLModelBase):
|
||||||
|
"""检查更新响应"""
|
||||||
|
has_update: bool
|
||||||
|
"""是否有可用更新"""
|
||||||
|
|
||||||
|
latest_version: str | None = None
|
||||||
|
"""最新版本号"""
|
||||||
|
|
||||||
|
download_url: str | None = None
|
||||||
|
"""下载地址"""
|
||||||
|
|
||||||
|
file_size: int | None = None
|
||||||
|
"""文件大小"""
|
||||||
|
|
||||||
|
file_md5: str | None = None
|
||||||
|
"""文件MD5"""
|
||||||
|
|
||||||
|
description: str | None = None
|
||||||
|
"""更新说明"""
|
||||||
@@ -6,6 +6,7 @@ from sqlmodel import Field, Relationship, String
|
|||||||
from pydantic_extra_types.semantic_version import SemanticVersion
|
from pydantic_extra_types.semantic_version import SemanticVersion
|
||||||
|
|
||||||
from .base import SQLModelBase, UUIDTableBase
|
from .base import SQLModelBase, UUIDTableBase
|
||||||
|
from .firmware import ChipTypeEnum
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
@@ -41,6 +42,9 @@ class ItemBase(SQLModelBase):
|
|||||||
version: SemanticVersion = Field(sa_type=String(64))
|
version: SemanticVersion = Field(sa_type=String(64))
|
||||||
"""版本号"""
|
"""版本号"""
|
||||||
|
|
||||||
|
chip_type: ChipTypeEnum | None = Field(default=None, index=True)
|
||||||
|
"""ESP设备芯片类型,仅当type=esp32时有值"""
|
||||||
|
|
||||||
class Item(ItemBase, UUIDTableBase, table=True):
|
class Item(ItemBase, UUIDTableBase, table=True):
|
||||||
expires_at: datetime | None = None
|
expires_at: datetime | None = None
|
||||||
"""物品过期时间"""
|
"""物品过期时间"""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import ClassVar
|
from typing import TYPE_CHECKING, ClassVar
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from pydantic import EmailStr
|
from pydantic import EmailStr
|
||||||
@@ -10,6 +10,9 @@ from sqlmodel import Field, Relationship
|
|||||||
from .base import SQLModelBase, UUIDTableBase
|
from .base import SQLModelBase, UUIDTableBase
|
||||||
from .item import Item
|
from .item import Item
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .firmware import Firmware
|
||||||
|
|
||||||
|
|
||||||
class UserTypeEnum(StrEnum):
|
class UserTypeEnum(StrEnum):
|
||||||
normal_user = 'normal_user'
|
normal_user = 'normal_user'
|
||||||
@@ -38,6 +41,9 @@ class User(UserBase, UUIDTableBase, table=True):
|
|||||||
items: list[Item] = Relationship(back_populates='user', cascade_delete=True)
|
items: list[Item] = Relationship(back_populates='user', cascade_delete=True)
|
||||||
"""物品关系"""
|
"""物品关系"""
|
||||||
|
|
||||||
|
firmwares: list['Firmware'] = Relationship(back_populates='uploaded_by', cascade_delete=True)
|
||||||
|
"""上传的固件关系"""
|
||||||
|
|
||||||
_initializing: ClassVar[bool] = False
|
_initializing: ClassVar[bool] = False
|
||||||
"""标记当前是否处于初始化阶段,初始化阶段允许创建 super_admin"""
|
"""标记当前是否处于初始化阶段,初始化阶段允许创建 super_admin"""
|
||||||
|
|
||||||
|
|||||||
115
routes/admin.py
115
routes/admin.py
@@ -1,11 +1,13 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from starlette.status import HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
from middleware.admin import is_admin
|
from middleware.admin import is_admin
|
||||||
from model import database
|
from middleware.dependencies import SessionDep
|
||||||
from model.response import DefaultResponse
|
from model import User, DefaultResponse
|
||||||
|
from model.firmware import ChipTypeEnum
|
||||||
from services import admin as admin_service
|
from services import admin as admin_service
|
||||||
|
|
||||||
Router = APIRouter(
|
Router = APIRouter(
|
||||||
@@ -38,7 +40,7 @@ async def verity_admin() -> DefaultResponse:
|
|||||||
response_description='设置项列表'
|
response_description='设置项列表'
|
||||||
)
|
)
|
||||||
async def get_settings(
|
async def get_settings(
|
||||||
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
|
session: SessionDep,
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
) -> DefaultResponse:
|
) -> DefaultResponse:
|
||||||
data = await admin_service.fetch_settings(session=session, name=name)
|
data = await admin_service.fetch_settings(session=session, name=name)
|
||||||
@@ -53,9 +55,110 @@ async def get_settings(
|
|||||||
response_description='更新结果'
|
response_description='更新结果'
|
||||||
)
|
)
|
||||||
async def update_settings(
|
async def update_settings(
|
||||||
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
|
session: SessionDep,
|
||||||
name: str,
|
name: str,
|
||||||
value: str
|
value: str
|
||||||
) -> DefaultResponse:
|
) -> DefaultResponse:
|
||||||
result = await admin_service.update_setting_value(session=session, name=name, value=value)
|
result = await admin_service.update_setting_value(session=session, name=name, value=value)
|
||||||
return DefaultResponse(data=result)
|
return DefaultResponse(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
# 固件管理接口
|
||||||
|
|
||||||
|
@Router.post(
|
||||||
|
path='/firmware',
|
||||||
|
summary='上传固件包',
|
||||||
|
description='管理员上传新的固件更新包',
|
||||||
|
status_code=HTTP_204_NO_CONTENT,
|
||||||
|
response_description='上传成功'
|
||||||
|
)
|
||||||
|
async def upload_firmware(
|
||||||
|
session: SessionDep,
|
||||||
|
admin: Annotated[User, Depends(is_admin)],
|
||||||
|
chip_type: ChipTypeEnum = Form(..., description='芯片类型'),
|
||||||
|
version: str = Form(..., description='版本号'),
|
||||||
|
description: str | None = Form(None, description='更新说明'),
|
||||||
|
file: UploadFile = File(..., description='固件文件'),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
上传固件包。
|
||||||
|
|
||||||
|
支持的文件格式:.bin
|
||||||
|
文件大小限制:4MB
|
||||||
|
"""
|
||||||
|
await admin_service.upload_firmware(
|
||||||
|
session=session,
|
||||||
|
admin=admin,
|
||||||
|
chip_type=chip_type,
|
||||||
|
version=version,
|
||||||
|
description=description,
|
||||||
|
file=file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@Router.get(
|
||||||
|
path='/firmwares',
|
||||||
|
summary='获取固件列表',
|
||||||
|
description='获取已上传的固件列表',
|
||||||
|
response_model=DefaultResponse,
|
||||||
|
response_description='固件列表'
|
||||||
|
)
|
||||||
|
async def list_firmwares(
|
||||||
|
session: SessionDep,
|
||||||
|
admin: Annotated[User, Depends(is_admin)],
|
||||||
|
chip_type: ChipTypeEnum | None = Query(None, description='筛选芯片类型'),
|
||||||
|
is_active: bool | None = Query(None, description='筛选启用状态'),
|
||||||
|
) -> DefaultResponse:
|
||||||
|
"""
|
||||||
|
获取固件列表。
|
||||||
|
"""
|
||||||
|
result = await admin_service.list_firmwares(
|
||||||
|
session=session,
|
||||||
|
chip_type=chip_type,
|
||||||
|
is_active=is_active,
|
||||||
|
)
|
||||||
|
return DefaultResponse(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@Router.delete(
|
||||||
|
path='/firmware/{firmware_id}',
|
||||||
|
summary='删除固件',
|
||||||
|
description='删除指定的固件包',
|
||||||
|
status_code=HTTP_204_NO_CONTENT,
|
||||||
|
response_description='删除成功'
|
||||||
|
)
|
||||||
|
async def delete_firmware(
|
||||||
|
session: SessionDep,
|
||||||
|
admin: Annotated[User, Depends(is_admin)],
|
||||||
|
firmware_id: UUID,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
删除固件包。
|
||||||
|
"""
|
||||||
|
await admin_service.delete_firmware(
|
||||||
|
session=session,
|
||||||
|
firmware_id=firmware_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@Router.patch(
|
||||||
|
path='/firmware/{firmware_id}/status',
|
||||||
|
summary='切换固件状态',
|
||||||
|
description='启用或禁用固件',
|
||||||
|
status_code=HTTP_204_NO_CONTENT,
|
||||||
|
response_description='操作成功'
|
||||||
|
)
|
||||||
|
async def toggle_firmware_status(
|
||||||
|
session: SessionDep,
|
||||||
|
admin: Annotated[User, Depends(is_admin)],
|
||||||
|
firmware_id: UUID,
|
||||||
|
is_active: bool = Query(..., description='目标状态'),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
切换固件启用状态。
|
||||||
|
"""
|
||||||
|
await admin_service.toggle_firmware_status(
|
||||||
|
session=session,
|
||||||
|
firmware_id=firmware_id,
|
||||||
|
is_active=is_active,
|
||||||
|
)
|
||||||
|
|||||||
98
routes/ota.py
Normal file
98
routes/ota.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""OTA API 路由,处理 ESP32/8266 设备的在线升级请求。"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Query, status
|
||||||
|
from starlette.status import HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
from middleware.dependencies import SessionDep, DeviceDep
|
||||||
|
from model import DefaultResponse
|
||||||
|
from model.firmware import FirmwareCheckUpdateRequest, FirmwareCheckUpdateResponse
|
||||||
|
from services import ota as ota_service
|
||||||
|
|
||||||
|
Router = APIRouter(prefix='/api/ota', tags=['OTA升级'])
|
||||||
|
|
||||||
|
|
||||||
|
@Router.post(
|
||||||
|
path='/check-update',
|
||||||
|
summary='检查固件更新',
|
||||||
|
description='设备通过 mTLS 认证后查询是否有新版本固件',
|
||||||
|
response_model=DefaultResponse,
|
||||||
|
response_description='更新检查结果'
|
||||||
|
)
|
||||||
|
async def check_update(
|
||||||
|
session: SessionDep,
|
||||||
|
device: DeviceDep,
|
||||||
|
request_data: FirmwareCheckUpdateRequest,
|
||||||
|
) -> DefaultResponse:
|
||||||
|
"""
|
||||||
|
检查固件更新。
|
||||||
|
|
||||||
|
设备需要提供有效的 mTLS 客户端证书,证书 CN 字段为设备序列号。
|
||||||
|
"""
|
||||||
|
result = await ota_service.check_firmware_update(
|
||||||
|
session=session,
|
||||||
|
device=device,
|
||||||
|
chip_type=request_data.chip_type,
|
||||||
|
current_version=request_data.current_version,
|
||||||
|
)
|
||||||
|
return DefaultResponse(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@Router.get(
|
||||||
|
path='/download/{firmware_id}',
|
||||||
|
summary='下载固件包',
|
||||||
|
description='下载指定的固件更新包',
|
||||||
|
)
|
||||||
|
async def download_firmware(
|
||||||
|
session: SessionDep,
|
||||||
|
device: DeviceDep,
|
||||||
|
firmware_id: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
下载固件包。
|
||||||
|
|
||||||
|
需要有效的设备证书,且下载会记录统计信息。
|
||||||
|
"""
|
||||||
|
return await ota_service.get_firmware_file(
|
||||||
|
session=session,
|
||||||
|
firmware_id=firmware_id,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@Router.post(
|
||||||
|
path='/report-version',
|
||||||
|
summary='上报设备版本',
|
||||||
|
description='设备上报当前运行的固件版本',
|
||||||
|
status_code=HTTP_204_NO_CONTENT,
|
||||||
|
response_description='上报成功'
|
||||||
|
)
|
||||||
|
async def report_version(
|
||||||
|
session: SessionDep,
|
||||||
|
device: DeviceDep,
|
||||||
|
version: str = Query(..., description='当前版本号'),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
上报设备当前运行的固件版本。
|
||||||
|
"""
|
||||||
|
await ota_service.update_device_version(
|
||||||
|
session=session,
|
||||||
|
device=device,
|
||||||
|
version=version,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@Router.post(
|
||||||
|
path='/report-lost',
|
||||||
|
summary='上报设备丢失',
|
||||||
|
description='设备上报丢失状态',
|
||||||
|
status_code=HTTP_204_NO_CONTENT,
|
||||||
|
response_description='上报成功'
|
||||||
|
)
|
||||||
|
async def report_lost(
|
||||||
|
session: SessionDep,
|
||||||
|
device: DeviceDep,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
设备上报丢失状态(复用现有丢失处理逻辑)。
|
||||||
|
"""
|
||||||
|
await ota_service.report_device_lost(session=session, device=device)
|
||||||
@@ -2,17 +2,30 @@
|
|||||||
管理员相关业务逻辑。
|
管理员相关业务逻辑。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from pathlib import Path
|
||||||
from typing import Iterable, List
|
from typing import Iterable, List
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from fastapi import UploadFile
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic_extra_types.semantic_version import SemanticVersion
|
||||||
|
|
||||||
from model import Setting
|
from middleware.dependencies import SessionDep
|
||||||
from model import SettingResponse
|
from model import Firmware, User, Setting, SettingResponse
|
||||||
|
from model.firmware import ChipTypeEnum, FirmwareDataResponseAdmin
|
||||||
from pkg import utils
|
from pkg import utils
|
||||||
|
|
||||||
|
# 固件存储目录
|
||||||
|
FIRMWARE_STORAGE_PATH = Path("data/firmware")
|
||||||
|
FIRMWARE_STORAGE_PATH.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 文件大小限制 4MB
|
||||||
|
MAX_FIRMWARE_SIZE = 4 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
async def fetch_settings(
|
async def fetch_settings(
|
||||||
session: AsyncSession,
|
session: SessionDep,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
) -> List[SettingResponse]:
|
) -> List[SettingResponse]:
|
||||||
"""
|
"""
|
||||||
@@ -35,7 +48,7 @@ async def fetch_settings(
|
|||||||
|
|
||||||
|
|
||||||
async def update_setting_value(
|
async def update_setting_value(
|
||||||
session: AsyncSession,
|
session: SessionDep,
|
||||||
name: str,
|
name: str,
|
||||||
value: str,
|
value: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@@ -50,3 +63,171 @@ async def update_setting_value(
|
|||||||
await Setting.save(session)
|
await Setting.save(session)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_md5(file_path: Path) -> str:
|
||||||
|
"""计算文件的 MD5 值"""
|
||||||
|
hash_md5 = hashlib.md5()
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
for chunk in iter(lambda: f.read(4096), b""):
|
||||||
|
hash_md5.update(chunk)
|
||||||
|
return hash_md5.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
async def upload_firmware(
|
||||||
|
session: SessionDep,
|
||||||
|
admin: User,
|
||||||
|
chip_type: ChipTypeEnum,
|
||||||
|
version: str,
|
||||||
|
description: str | None,
|
||||||
|
file: UploadFile,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
上传固件包。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: 数据库会话
|
||||||
|
admin: 管理员用户
|
||||||
|
chip_type: 芯片类型
|
||||||
|
version: 版本号
|
||||||
|
description: 更新说明
|
||||||
|
file: 上传的文件
|
||||||
|
"""
|
||||||
|
# 验证版本号格式
|
||||||
|
try:
|
||||||
|
version_obj = SemanticVersion(version)
|
||||||
|
except ValueError:
|
||||||
|
utils.raise_bad_request("Invalid semantic version format")
|
||||||
|
|
||||||
|
# 验证文件扩展名
|
||||||
|
if not file.filename or not file.filename.endswith('.bin'):
|
||||||
|
utils.raise_bad_request("Only .bin files are supported")
|
||||||
|
|
||||||
|
# 检查是否已存在相同芯片类型和版本的固件
|
||||||
|
from sqlalchemy import and_
|
||||||
|
existing = await Firmware.get(
|
||||||
|
session,
|
||||||
|
and_(
|
||||||
|
Firmware.chip_type == chip_type,
|
||||||
|
Firmware.version == str(version_obj)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
utils.raise_conflict(f"Firmware {chip_type} v{version} already exists")
|
||||||
|
|
||||||
|
# 读取文件内容
|
||||||
|
content = await file.read()
|
||||||
|
file_size = len(content)
|
||||||
|
|
||||||
|
# 验证文件大小
|
||||||
|
if file_size > MAX_FIRMWARE_SIZE:
|
||||||
|
utils.raise_bad_request(f"File size exceeds {MAX_FIRMWARE_SIZE} bytes")
|
||||||
|
|
||||||
|
if file_size == 0:
|
||||||
|
utils.raise_bad_request("Empty file")
|
||||||
|
|
||||||
|
# 生成文件名
|
||||||
|
safe_filename = f"{chip_type}_{version}_{file.filename}"
|
||||||
|
file_path = FIRMWARE_STORAGE_PATH / safe_filename
|
||||||
|
|
||||||
|
# 写入文件
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
# 计算 MD5
|
||||||
|
file_md5 = _calculate_md5(file_path)
|
||||||
|
|
||||||
|
# 创建数据库记录
|
||||||
|
firmware = Firmware(
|
||||||
|
chip_type=chip_type,
|
||||||
|
version=str(version_obj),
|
||||||
|
file_path=str(file_path),
|
||||||
|
file_size=file_size,
|
||||||
|
file_md5=file_md5,
|
||||||
|
description=description,
|
||||||
|
uploaded_by_id=admin.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await Firmware.add(session, firmware)
|
||||||
|
logger.info(f"Admin {admin.email} uploaded firmware {chip_type} v{version}")
|
||||||
|
|
||||||
|
|
||||||
|
async def list_firmwares(
|
||||||
|
session: SessionDep,
|
||||||
|
chip_type: ChipTypeEnum | None,
|
||||||
|
is_active: bool | None,
|
||||||
|
) -> List[FirmwareDataResponseAdmin]:
|
||||||
|
"""
|
||||||
|
获取固件列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: 数据库会话
|
||||||
|
chip_type: 筛选芯片类型
|
||||||
|
is_active: 筛选启用状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
固件列表
|
||||||
|
"""
|
||||||
|
from sqlalchemy import and_
|
||||||
|
|
||||||
|
conditions = []
|
||||||
|
|
||||||
|
if chip_type:
|
||||||
|
conditions.append(Firmware.chip_type == chip_type)
|
||||||
|
if is_active is not None:
|
||||||
|
conditions.append(Firmware.is_active == is_active)
|
||||||
|
|
||||||
|
if conditions:
|
||||||
|
results = await Firmware.get(session, and_(*conditions), fetch_mode="all")
|
||||||
|
else:
|
||||||
|
results = await Firmware.get(session, fetch_mode="all")
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [FirmwareDataResponseAdmin.model_validate(fw) for fw in results]
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_firmware(
|
||||||
|
session: SessionDep,
|
||||||
|
firmware_id: UUID,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
删除固件包。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: 数据库会话
|
||||||
|
firmware_id: 固件ID
|
||||||
|
"""
|
||||||
|
firmware = await Firmware.get(session, Firmware.id == firmware_id)
|
||||||
|
if not firmware:
|
||||||
|
utils.raise_not_found("Firmware not found")
|
||||||
|
|
||||||
|
# 删除文件
|
||||||
|
file_path = Path(firmware.file_path)
|
||||||
|
if file_path.exists():
|
||||||
|
file_path.unlink()
|
||||||
|
|
||||||
|
# 删除数据库记录
|
||||||
|
await Firmware.delete(session, firmware)
|
||||||
|
|
||||||
|
|
||||||
|
async def toggle_firmware_status(
|
||||||
|
session: SessionDep,
|
||||||
|
firmware_id: UUID,
|
||||||
|
is_active: bool,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
切换固件启用状态。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: 数据库会话
|
||||||
|
firmware_id: 固件ID
|
||||||
|
is_active: 目标状态
|
||||||
|
"""
|
||||||
|
firmware = await Firmware.get(session, Firmware.id == firmware_id)
|
||||||
|
if not firmware:
|
||||||
|
utils.raise_not_found("Firmware not found")
|
||||||
|
|
||||||
|
firmware.is_active = is_active
|
||||||
|
await firmware.save(session)
|
||||||
|
|||||||
168
services/ota.py
Normal file
168
services/ota.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""OTA 服务层,处理 ESP32/8266 设备的在线升级业务逻辑。"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic_extra_types.semantic_version import SemanticVersion
|
||||||
|
|
||||||
|
from model import Firmware, Item
|
||||||
|
from model.firmware import ChipTypeEnum, FirmwareCheckUpdateResponse
|
||||||
|
from middleware.dependencies import SessionDep
|
||||||
|
from model.item import ItemStatusEnum
|
||||||
|
from pkg import utils
|
||||||
|
|
||||||
|
# 固件存储目录
|
||||||
|
FIRMWARE_STORAGE_PATH = Path("data/firmware")
|
||||||
|
FIRMWARE_STORAGE_PATH.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def check_firmware_update(
|
||||||
|
session: SessionDep,
|
||||||
|
device: Item,
|
||||||
|
chip_type: ChipTypeEnum,
|
||||||
|
current_version: str,
|
||||||
|
) -> FirmwareCheckUpdateResponse:
|
||||||
|
"""
|
||||||
|
检查设备是否有可用的固件更新。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: 数据库会话
|
||||||
|
device: 设备对象
|
||||||
|
chip_type: 芯片类型
|
||||||
|
current_version: 当前版本号
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FirmwareCheckUpdateResponse: 更新检查结果
|
||||||
|
"""
|
||||||
|
# 验证当前版本格式
|
||||||
|
try:
|
||||||
|
current = SemanticVersion(current_version)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f"Invalid version format from device {device.id}: {current_version}")
|
||||||
|
utils.raise_bad_request("Invalid version format")
|
||||||
|
|
||||||
|
# 查找该芯片类型的最新启用固件
|
||||||
|
all_firmwares = await Firmware.get(
|
||||||
|
session,
|
||||||
|
(Firmware.chip_type == chip_type) & (Firmware.is_active == True),
|
||||||
|
fetch_mode="all"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not all_firmwares:
|
||||||
|
return FirmwareCheckUpdateResponse(
|
||||||
|
has_update=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 过滤出比当前版本新的固件
|
||||||
|
newer_firmwares = []
|
||||||
|
for fw in all_firmwares:
|
||||||
|
try:
|
||||||
|
fw_version = SemanticVersion(str(fw.version))
|
||||||
|
if fw_version > current:
|
||||||
|
newer_firmwares.append(fw)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f"Invalid firmware version in database: {fw.version}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not newer_firmwares:
|
||||||
|
return FirmwareCheckUpdateResponse(
|
||||||
|
has_update=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 取最新版本
|
||||||
|
latest = max(newer_firmwares, key=lambda fw: SemanticVersion(str(fw.version)))
|
||||||
|
|
||||||
|
return FirmwareCheckUpdateResponse(
|
||||||
|
has_update=True,
|
||||||
|
latest_version=str(latest.version),
|
||||||
|
download_url=f"/api/ota/download/{latest.id}",
|
||||||
|
file_size=latest.file_size,
|
||||||
|
file_md5=latest.file_md5,
|
||||||
|
description=latest.description,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_firmware_file(
|
||||||
|
session: SessionDep,
|
||||||
|
firmware_id: str,
|
||||||
|
device: Item,
|
||||||
|
) -> FileResponse:
|
||||||
|
"""
|
||||||
|
获取固件文件并更新下载统计。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: 数据库会话
|
||||||
|
firmware_id: 固件ID
|
||||||
|
device: 设备对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FileResponse: 固件文件响应
|
||||||
|
"""
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
firmware = await Firmware.get(session, Firmware.id == UUID(firmware_id))
|
||||||
|
|
||||||
|
if not firmware:
|
||||||
|
utils.raise_not_found("Firmware not found")
|
||||||
|
|
||||||
|
if not firmware.is_active:
|
||||||
|
utils.raise_forbidden("Firmware is not available")
|
||||||
|
|
||||||
|
# 验证芯片类型匹配
|
||||||
|
if device.chip_type != firmware.chip_type:
|
||||||
|
utils.raise_forbidden("Firmware chip type mismatch")
|
||||||
|
|
||||||
|
# 更新下载计数
|
||||||
|
firmware.downloaded_count += 1
|
||||||
|
await firmware.save(session)
|
||||||
|
|
||||||
|
file_path = Path(firmware.file_path)
|
||||||
|
if not file_path.exists():
|
||||||
|
logger.error(f"Firmware file not found: {file_path}")
|
||||||
|
utils.raise_internal_error("Firmware file not available")
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
path=str(file_path),
|
||||||
|
filename=file_path.name,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def update_device_version(
|
||||||
|
session: SessionDep,
|
||||||
|
device: Item,
|
||||||
|
version: str,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
更新设备上报的固件版本。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: 数据库会话
|
||||||
|
device: 设备对象
|
||||||
|
version: 版本号字符串
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
SemanticVersion(version)
|
||||||
|
except ValueError:
|
||||||
|
utils.raise_bad_request("Invalid version format")
|
||||||
|
|
||||||
|
device.version = version
|
||||||
|
await device.save(session)
|
||||||
|
logger.info(f"Device {device.id} reported version: {version}")
|
||||||
|
|
||||||
|
|
||||||
|
async def report_device_lost(
|
||||||
|
session: SessionDep,
|
||||||
|
device: Item,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
设备上报丢失状态。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: 数据库会话
|
||||||
|
device: 设备对象
|
||||||
|
"""
|
||||||
|
device.status = ItemStatusEnum.lost
|
||||||
|
await device.save(session)
|
||||||
|
logger.info(f"Device {device.id} reported as lost")
|
||||||
Reference in New Issue
Block a user