diff --git a/app.py b/app.py index 614920a..3c949bc 100644 --- a/app.py +++ b/app.py @@ -7,7 +7,7 @@ from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded from pkg.utils import raise_internal_error -from routes import (session, admin, object) +from routes import (session, admin, object, ota) from model.database import Database import os import pkg.conf @@ -15,7 +15,7 @@ from pkg import utils from loguru import logger -Router = [admin, session, object] +Router = [admin, session, object, ota] # Findreve 的生命周期 @asynccontextmanager diff --git a/middleware/dependencies.py b/middleware/dependencies.py index e34d161..e54ed29 100644 --- a/middleware/dependencies.py +++ b/middleware/dependencies.py @@ -1,10 +1,13 @@ from typing import Annotated, TypeAlias -from fastapi import Depends +from fastapi import Depends, Request from sqlmodel.ext.asyncio.session import AsyncSession from model.database import Database from model.mixin.table import TableViewRequest +from model import Item +from model.item import ItemTypeEnum +from pkg import utils SessionDep: TypeAlias = Annotated[AsyncSession, Depends(Database.get_session)] """数据库会话依赖,用于路由函数中获取数据库会话""" @@ -12,3 +15,48 @@ SessionDep: TypeAlias = Annotated[AsyncSession, Depends(Database.get_session)] # 新增:表格视图请求依赖(用于分页排序) TableViewRequestDep: TypeAlias = Annotated[TableViewRequest, Depends()] """分页排序请求依赖,用于 LIST 端点""" + + +async def get_device_from_cert( + request: Request, + session: SessionDep, +) -> Item: + """ + 从 mTLS 客户端证书中提取设备序列号并验证设备。 + + 客户端证书的 CN (Common Name) 字段应存储设备序列号 (UUID)。 + 反向代理(Nginx/Apache)验证证书后,通过 HTTP Header 将 CN 传递给 FastAPI。 + + Nginx 配置示例: + proxy_set_header X-Client-CN $ssl_client_s_dn_cn; + + Apache 配置示例: + RequestHeader set X-Client-CN "%{SSL_CLIENT_S_DN_CN}s" + """ + # 从 Header 获取设备序列号(由反向代理注入) + serial_number = request.headers.get("X-Client-CN") + + if not serial_number: + utils.raise_unauthorized("Device certificate required") + + # 验证 UUID 格式 + try: + from uuid import UUID + serial_uuid = UUID(serial_number) + except ValueError: + utils.raise_unauthorized("Invalid device serial number format") + + # 查找设备 + device = await Item.get(session, Item.id == serial_uuid) + + if not device: + utils.raise_not_found("Device not found") + + if device.type != ItemTypeEnum.esp32: + utils.raise_forbidden("Not an ESP device") + + return device + + +DeviceDep: TypeAlias = Annotated[Item, Depends(get_device_from_cert)] +"""设备认证依赖,通过 mTLS 证书验证 ESP 设备""" diff --git a/model/__init__.py b/model/__init__.py index a37490a..e6098f2 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -3,6 +3,15 @@ from .setting import Setting, SettingResponse from .item import Item, ItemDataResponse, ItemTypeEnum, ItemStatusEnum from .user import User, UserTypeEnum from .database import Database +from .firmware import ( + Firmware, + FirmwareDataResponse, + FirmwareDataResponseAdmin, + FirmwareUploadRequest, + FirmwareCheckUpdateRequest, + FirmwareCheckUpdateResponse, + ChipTypeEnum, +) # 新增:从 foxline 项目移植的 Mixin 组件 from .mixin.table import ( @@ -27,6 +36,14 @@ __all__ = [ "User", "UserTypeEnum", "Database", + # 固件相关 + "Firmware", + "FirmwareDataResponse", + "FirmwareDataResponseAdmin", + "FirmwareUploadRequest", + "FirmwareCheckUpdateRequest", + "FirmwareCheckUpdateResponse", + "ChipTypeEnum", # 新增的 Mixin 组件 "TableBaseMixin", "UUIDTableBaseMixin", diff --git a/model/firmware.py b/model/firmware.py new file mode 100644 index 0000000..8251880 --- /dev/null +++ b/model/firmware.py @@ -0,0 +1,125 @@ +"""固件包数据模型,用于 ESP32/8266 OTA 在线升级功能。""" + +from datetime import datetime +from enum import StrEnum +from typing import TYPE_CHECKING +from uuid import UUID + +from sqlmodel import Field, Relationship, String, Text + +from .base import SQLModelBase, UUIDTableBase + +if TYPE_CHECKING: + from .user import User + + +class ChipTypeEnum(StrEnum): + """ESP 芯片类型枚举""" + esp32 = 'esp32' + esp8266 = 'esp8266' + esp32s2 = 'esp32s2' + esp32s3 = 'esp32s3' + esp32c3 = 'esp32c3' + + +class FirmwareBase(SQLModelBase): + chip_type: ChipTypeEnum = Field(index=True) + """芯片类型""" + + version: str = Field(sa_type=String(64), index=True) + """固件版本号,遵循语义化版本规范""" + + file_path: str + """固件文件存储路径""" + + file_size: int + """固件文件大小(字节)""" + + file_md5: str = Field(max_length=32) + """固件文件 MD5 校验值""" + + description: str | None = Field(default=None, sa_type=Text) + """固件更新说明""" + + is_active: bool = Field(default=True, index=True) + """是否启用该固件版本""" + + +class Firmware(FirmwareBase, UUIDTableBase, table=True): + """固件包表""" + + uploaded_by_id: UUID = Field(foreign_key='user.id', ondelete='RESTRICT') + """上传者用户ID""" + + downloaded_count: int = Field(default=0) + """下载次数统计""" + + uploaded_at: datetime = Field(default_factory=datetime.now) + """上传时间""" + + uploaded_by: 'User' = Relationship(back_populates='firmwares') + + +# DTO 定义 + +class FirmwareDataResponse(FirmwareBase): + """固件信息响应""" + id: UUID + """固件ID""" + + downloaded_count: int + """下载次数""" + + uploaded_at: datetime + """上传时间""" + + download_url: str | None = None + """下载地址""" + + +class FirmwareDataResponseAdmin(FirmwareDataResponse): + """固件信息响应(管理员)""" + uploaded_by_id: UUID + """上传者ID""" + + +class FirmwareUploadRequest(SQLModelBase): + """固件上传请求""" + chip_type: ChipTypeEnum + """芯片类型""" + + version: str + """版本号字符串""" + + description: str | None = None + """更新说明""" + + +class FirmwareCheckUpdateRequest(SQLModelBase): + """设备检查更新请求""" + chip_type: ChipTypeEnum + """芯片类型""" + + current_version: str + """当前版本号""" + + +class FirmwareCheckUpdateResponse(SQLModelBase): + """检查更新响应""" + has_update: bool + """是否有可用更新""" + + latest_version: str | None = None + """最新版本号""" + + download_url: str | None = None + """下载地址""" + + file_size: int | None = None + """文件大小""" + + file_md5: str | None = None + """文件MD5""" + + description: str | None = None + """更新说明""" diff --git a/model/item.py b/model/item.py index 5ad7184..5bd4bdf 100644 --- a/model/item.py +++ b/model/item.py @@ -6,6 +6,7 @@ from sqlmodel import Field, Relationship, String from pydantic_extra_types.semantic_version import SemanticVersion from .base import SQLModelBase, UUIDTableBase +from .firmware import ChipTypeEnum if TYPE_CHECKING: from .user import User @@ -41,6 +42,9 @@ class ItemBase(SQLModelBase): version: SemanticVersion = Field(sa_type=String(64)) """版本号""" + chip_type: ChipTypeEnum | None = Field(default=None, index=True) + """ESP设备芯片类型,仅当type=esp32时有值""" + class Item(ItemBase, UUIDTableBase, table=True): expires_at: datetime | None = None """物品过期时间""" diff --git a/model/user.py b/model/user.py index 95beb17..373b34c 100644 --- a/model/user.py +++ b/model/user.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar import sqlalchemy as sa from pydantic import EmailStr @@ -10,6 +10,9 @@ from sqlmodel import Field, Relationship from .base import SQLModelBase, UUIDTableBase from .item import Item +if TYPE_CHECKING: + from .firmware import Firmware + class UserTypeEnum(StrEnum): normal_user = 'normal_user' @@ -38,6 +41,9 @@ class User(UserBase, UUIDTableBase, table=True): items: list[Item] = Relationship(back_populates='user', cascade_delete=True) """物品关系""" + firmwares: list['Firmware'] = Relationship(back_populates='uploaded_by', cascade_delete=True) + """上传的固件关系""" + _initializing: ClassVar[bool] = False """标记当前是否处于初始化阶段,初始化阶段允许创建 super_admin""" diff --git a/routes/admin.py b/routes/admin.py index 1caa81b..d2e6398 100644 --- a/routes/admin.py +++ b/routes/admin.py @@ -1,11 +1,13 @@ from typing import Annotated +from uuid import UUID -from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncSession +from fastapi import APIRouter, Depends, File, Form, Query, UploadFile +from starlette.status import HTTP_204_NO_CONTENT from middleware.admin import is_admin -from model import database -from model.response import DefaultResponse +from middleware.dependencies import SessionDep +from model import User, DefaultResponse +from model.firmware import ChipTypeEnum from services import admin as admin_service Router = APIRouter( @@ -38,7 +40,7 @@ async def verity_admin() -> DefaultResponse: response_description='设置项列表' ) async def get_settings( - session: Annotated[AsyncSession, Depends(database.Database.get_session)], + session: SessionDep, name: str | None = None ) -> DefaultResponse: data = await admin_service.fetch_settings(session=session, name=name) @@ -53,9 +55,110 @@ async def get_settings( response_description='更新结果' ) async def update_settings( - session: Annotated[AsyncSession, Depends(database.Database.get_session)], + session: SessionDep, name: str, value: str ) -> DefaultResponse: result = await admin_service.update_setting_value(session=session, name=name, value=value) return DefaultResponse(data=result) + + +# 固件管理接口 + +@Router.post( + path='/firmware', + summary='上传固件包', + description='管理员上传新的固件更新包', + status_code=HTTP_204_NO_CONTENT, + response_description='上传成功' +) +async def upload_firmware( + session: SessionDep, + admin: Annotated[User, Depends(is_admin)], + chip_type: ChipTypeEnum = Form(..., description='芯片类型'), + version: str = Form(..., description='版本号'), + description: str | None = Form(None, description='更新说明'), + file: UploadFile = File(..., description='固件文件'), +): + """ + 上传固件包。 + + 支持的文件格式:.bin + 文件大小限制:4MB + """ + await admin_service.upload_firmware( + session=session, + admin=admin, + chip_type=chip_type, + version=version, + description=description, + file=file, + ) + + +@Router.get( + path='/firmwares', + summary='获取固件列表', + description='获取已上传的固件列表', + response_model=DefaultResponse, + response_description='固件列表' +) +async def list_firmwares( + session: SessionDep, + admin: Annotated[User, Depends(is_admin)], + chip_type: ChipTypeEnum | None = Query(None, description='筛选芯片类型'), + is_active: bool | None = Query(None, description='筛选启用状态'), +) -> DefaultResponse: + """ + 获取固件列表。 + """ + result = await admin_service.list_firmwares( + session=session, + chip_type=chip_type, + is_active=is_active, + ) + return DefaultResponse(data=result) + + +@Router.delete( + path='/firmware/{firmware_id}', + summary='删除固件', + description='删除指定的固件包', + status_code=HTTP_204_NO_CONTENT, + response_description='删除成功' +) +async def delete_firmware( + session: SessionDep, + admin: Annotated[User, Depends(is_admin)], + firmware_id: UUID, +): + """ + 删除固件包。 + """ + await admin_service.delete_firmware( + session=session, + firmware_id=firmware_id, + ) + + +@Router.patch( + path='/firmware/{firmware_id}/status', + summary='切换固件状态', + description='启用或禁用固件', + status_code=HTTP_204_NO_CONTENT, + response_description='操作成功' +) +async def toggle_firmware_status( + session: SessionDep, + admin: Annotated[User, Depends(is_admin)], + firmware_id: UUID, + is_active: bool = Query(..., description='目标状态'), +): + """ + 切换固件启用状态。 + """ + await admin_service.toggle_firmware_status( + session=session, + firmware_id=firmware_id, + is_active=is_active, + ) diff --git a/routes/ota.py b/routes/ota.py new file mode 100644 index 0000000..a82b96d --- /dev/null +++ b/routes/ota.py @@ -0,0 +1,98 @@ +"""OTA API 路由,处理 ESP32/8266 设备的在线升级请求。""" + +from fastapi import APIRouter, Query, status +from starlette.status import HTTP_204_NO_CONTENT + +from middleware.dependencies import SessionDep, DeviceDep +from model import DefaultResponse +from model.firmware import FirmwareCheckUpdateRequest, FirmwareCheckUpdateResponse +from services import ota as ota_service + +Router = APIRouter(prefix='/api/ota', tags=['OTA升级']) + + +@Router.post( + path='/check-update', + summary='检查固件更新', + description='设备通过 mTLS 认证后查询是否有新版本固件', + response_model=DefaultResponse, + response_description='更新检查结果' +) +async def check_update( + session: SessionDep, + device: DeviceDep, + request_data: FirmwareCheckUpdateRequest, +) -> DefaultResponse: + """ + 检查固件更新。 + + 设备需要提供有效的 mTLS 客户端证书,证书 CN 字段为设备序列号。 + """ + result = await ota_service.check_firmware_update( + session=session, + device=device, + chip_type=request_data.chip_type, + current_version=request_data.current_version, + ) + return DefaultResponse(data=result) + + +@Router.get( + path='/download/{firmware_id}', + summary='下载固件包', + description='下载指定的固件更新包', +) +async def download_firmware( + session: SessionDep, + device: DeviceDep, + firmware_id: str, +): + """ + 下载固件包。 + + 需要有效的设备证书,且下载会记录统计信息。 + """ + return await ota_service.get_firmware_file( + session=session, + firmware_id=firmware_id, + device=device, + ) + + +@Router.post( + path='/report-version', + summary='上报设备版本', + description='设备上报当前运行的固件版本', + status_code=HTTP_204_NO_CONTENT, + response_description='上报成功' +) +async def report_version( + session: SessionDep, + device: DeviceDep, + version: str = Query(..., description='当前版本号'), +): + """ + 上报设备当前运行的固件版本。 + """ + await ota_service.update_device_version( + session=session, + device=device, + version=version, + ) + + +@Router.post( + path='/report-lost', + summary='上报设备丢失', + description='设备上报丢失状态', + status_code=HTTP_204_NO_CONTENT, + response_description='上报成功' +) +async def report_lost( + session: SessionDep, + device: DeviceDep, +): + """ + 设备上报丢失状态(复用现有丢失处理逻辑)。 + """ + await ota_service.report_device_lost(session=session, device=device) diff --git a/services/admin.py b/services/admin.py index 4115964..0618377 100644 --- a/services/admin.py +++ b/services/admin.py @@ -2,17 +2,30 @@ 管理员相关业务逻辑。 """ +import hashlib +from pathlib import Path from typing import Iterable, List +from uuid import UUID -from sqlmodel.ext.asyncio.session import AsyncSession +from fastapi import UploadFile +from loguru import logger +from pydantic_extra_types.semantic_version import SemanticVersion -from model import Setting -from model import SettingResponse +from middleware.dependencies import SessionDep +from model import Firmware, User, Setting, SettingResponse +from model.firmware import ChipTypeEnum, FirmwareDataResponseAdmin from pkg import utils +# 固件存储目录 +FIRMWARE_STORAGE_PATH = Path("data/firmware") +FIRMWARE_STORAGE_PATH.mkdir(parents=True, exist_ok=True) + +# 文件大小限制 4MB +MAX_FIRMWARE_SIZE = 4 * 1024 * 1024 + async def fetch_settings( - session: AsyncSession, + session: SessionDep, name: str | None = None, ) -> List[SettingResponse]: """ @@ -35,7 +48,7 @@ async def fetch_settings( async def update_setting_value( - session: AsyncSession, + session: SessionDep, name: str, value: str, ) -> bool: @@ -50,3 +63,171 @@ async def update_setting_value( await Setting.save(session) return True + + +def _calculate_md5(file_path: Path) -> str: + """计算文件的 MD5 值""" + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +async def upload_firmware( + session: SessionDep, + admin: User, + chip_type: ChipTypeEnum, + version: str, + description: str | None, + file: UploadFile, +) -> None: + """ + 上传固件包。 + + Args: + session: 数据库会话 + admin: 管理员用户 + chip_type: 芯片类型 + version: 版本号 + description: 更新说明 + file: 上传的文件 + """ + # 验证版本号格式 + try: + version_obj = SemanticVersion(version) + except ValueError: + utils.raise_bad_request("Invalid semantic version format") + + # 验证文件扩展名 + if not file.filename or not file.filename.endswith('.bin'): + utils.raise_bad_request("Only .bin files are supported") + + # 检查是否已存在相同芯片类型和版本的固件 + from sqlalchemy import and_ + existing = await Firmware.get( + session, + and_( + Firmware.chip_type == chip_type, + Firmware.version == str(version_obj) + ) + ) + if existing: + utils.raise_conflict(f"Firmware {chip_type} v{version} already exists") + + # 读取文件内容 + content = await file.read() + file_size = len(content) + + # 验证文件大小 + if file_size > MAX_FIRMWARE_SIZE: + utils.raise_bad_request(f"File size exceeds {MAX_FIRMWARE_SIZE} bytes") + + if file_size == 0: + utils.raise_bad_request("Empty file") + + # 生成文件名 + safe_filename = f"{chip_type}_{version}_{file.filename}" + file_path = FIRMWARE_STORAGE_PATH / safe_filename + + # 写入文件 + with open(file_path, "wb") as f: + f.write(content) + + # 计算 MD5 + file_md5 = _calculate_md5(file_path) + + # 创建数据库记录 + firmware = Firmware( + chip_type=chip_type, + version=str(version_obj), + file_path=str(file_path), + file_size=file_size, + file_md5=file_md5, + description=description, + uploaded_by_id=admin.id, + ) + + await Firmware.add(session, firmware) + logger.info(f"Admin {admin.email} uploaded firmware {chip_type} v{version}") + + +async def list_firmwares( + session: SessionDep, + chip_type: ChipTypeEnum | None, + is_active: bool | None, +) -> List[FirmwareDataResponseAdmin]: + """ + 获取固件列表。 + + Args: + session: 数据库会话 + chip_type: 筛选芯片类型 + is_active: 筛选启用状态 + + Returns: + 固件列表 + """ + from sqlalchemy import and_ + + conditions = [] + + if chip_type: + conditions.append(Firmware.chip_type == chip_type) + if is_active is not None: + conditions.append(Firmware.is_active == is_active) + + if conditions: + results = await Firmware.get(session, and_(*conditions), fetch_mode="all") + else: + results = await Firmware.get(session, fetch_mode="all") + + if not results: + return [] + + return [FirmwareDataResponseAdmin.model_validate(fw) for fw in results] + + +async def delete_firmware( + session: SessionDep, + firmware_id: UUID, +) -> None: + """ + 删除固件包。 + + Args: + session: 数据库会话 + firmware_id: 固件ID + """ + firmware = await Firmware.get(session, Firmware.id == firmware_id) + if not firmware: + utils.raise_not_found("Firmware not found") + + # 删除文件 + file_path = Path(firmware.file_path) + if file_path.exists(): + file_path.unlink() + + # 删除数据库记录 + await Firmware.delete(session, firmware) + + +async def toggle_firmware_status( + session: SessionDep, + firmware_id: UUID, + is_active: bool, +) -> None: + """ + 切换固件启用状态。 + + Args: + session: 数据库会话 + firmware_id: 固件ID + is_active: 目标状态 + """ + firmware = await Firmware.get(session, Firmware.id == firmware_id) + if not firmware: + utils.raise_not_found("Firmware not found") + + firmware.is_active = is_active + await firmware.save(session) diff --git a/services/ota.py b/services/ota.py new file mode 100644 index 0000000..220b62c --- /dev/null +++ b/services/ota.py @@ -0,0 +1,168 @@ +"""OTA 服务层,处理 ESP32/8266 设备的在线升级业务逻辑。""" + +from pathlib import Path + +from fastapi.responses import FileResponse +from loguru import logger +from pydantic_extra_types.semantic_version import SemanticVersion + +from model import Firmware, Item +from model.firmware import ChipTypeEnum, FirmwareCheckUpdateResponse +from middleware.dependencies import SessionDep +from model.item import ItemStatusEnum +from pkg import utils + +# 固件存储目录 +FIRMWARE_STORAGE_PATH = Path("data/firmware") +FIRMWARE_STORAGE_PATH.mkdir(parents=True, exist_ok=True) + + +async def check_firmware_update( + session: SessionDep, + device: Item, + chip_type: ChipTypeEnum, + current_version: str, +) -> FirmwareCheckUpdateResponse: + """ + 检查设备是否有可用的固件更新。 + + Args: + session: 数据库会话 + device: 设备对象 + chip_type: 芯片类型 + current_version: 当前版本号 + + Returns: + FirmwareCheckUpdateResponse: 更新检查结果 + """ + # 验证当前版本格式 + try: + current = SemanticVersion(current_version) + except ValueError: + logger.warning(f"Invalid version format from device {device.id}: {current_version}") + utils.raise_bad_request("Invalid version format") + + # 查找该芯片类型的最新启用固件 + all_firmwares = await Firmware.get( + session, + (Firmware.chip_type == chip_type) & (Firmware.is_active == True), + fetch_mode="all" + ) + + if not all_firmwares: + return FirmwareCheckUpdateResponse( + has_update=False, + ) + + # 过滤出比当前版本新的固件 + newer_firmwares = [] + for fw in all_firmwares: + try: + fw_version = SemanticVersion(str(fw.version)) + if fw_version > current: + newer_firmwares.append(fw) + except ValueError: + logger.warning(f"Invalid firmware version in database: {fw.version}") + continue + + if not newer_firmwares: + return FirmwareCheckUpdateResponse( + has_update=False, + ) + + # 取最新版本 + latest = max(newer_firmwares, key=lambda fw: SemanticVersion(str(fw.version))) + + return FirmwareCheckUpdateResponse( + has_update=True, + latest_version=str(latest.version), + download_url=f"/api/ota/download/{latest.id}", + file_size=latest.file_size, + file_md5=latest.file_md5, + description=latest.description, + ) + + +async def get_firmware_file( + session: SessionDep, + firmware_id: str, + device: Item, +) -> FileResponse: + """ + 获取固件文件并更新下载统计。 + + Args: + session: 数据库会话 + firmware_id: 固件ID + device: 设备对象 + + Returns: + FileResponse: 固件文件响应 + """ + from uuid import UUID + + firmware = await Firmware.get(session, Firmware.id == UUID(firmware_id)) + + if not firmware: + utils.raise_not_found("Firmware not found") + + if not firmware.is_active: + utils.raise_forbidden("Firmware is not available") + + # 验证芯片类型匹配 + if device.chip_type != firmware.chip_type: + utils.raise_forbidden("Firmware chip type mismatch") + + # 更新下载计数 + firmware.downloaded_count += 1 + await firmware.save(session) + + file_path = Path(firmware.file_path) + if not file_path.exists(): + logger.error(f"Firmware file not found: {file_path}") + utils.raise_internal_error("Firmware file not available") + + return FileResponse( + path=str(file_path), + filename=file_path.name, + media_type="application/octet-stream", + ) + + +async def update_device_version( + session: SessionDep, + device: Item, + version: str, +) -> None: + """ + 更新设备上报的固件版本。 + + Args: + session: 数据库会话 + device: 设备对象 + version: 版本号字符串 + """ + try: + SemanticVersion(version) + except ValueError: + utils.raise_bad_request("Invalid version format") + + device.version = version + await device.save(session) + logger.info(f"Device {device.id} reported version: {version}") + + +async def report_device_lost( + session: SessionDep, + device: Item, +) -> None: + """ + 设备上报丢失状态。 + + Args: + session: 数据库会话 + device: 设备对象 + """ + device.status = ItemStatusEnum.lost + await device.save(session) + logger.info(f"Device {device.id} reported as lost")