From 6c512805e83f3feb6fbe6e73f6fa02c29abce994 Mon Sep 17 00:00:00 2001 From: Yuerchu Date: Thu, 14 Aug 2025 22:30:40 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E5=88=9B=E5=BB=BA=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 3 + model/__init__.py | 10 +- model/base.py | 229 +++++++++++++++++++++++++++++++++++++++++++++ model/database.py | 128 +++++++------------------ model/migration.py | 51 ++++++++++ model/object.py | 48 ++++++++++ model/setting.py | 22 +++++ 7 files changed, 394 insertions(+), 97 deletions(-) create mode 100644 model/base.py create mode 100644 model/migration.py create mode 100644 model/object.py create mode 100644 model/setting.py diff --git a/README.md b/README.md index 04836d4..1fba819 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,9 @@ Upon launch, Findreve will create a SQLite database in the project's root direct display the administrator's account and password in the console. ## 构建 + +> 当前版本的 Findreve Core 无法正常工作,因为我们正在尝试[重构数据库组件以使用ORM](https://github.com/Findreve/Findreve/issues/8) + 你需要安装Python 3.8 以上的版本。然后,clone 本仓库到您的服务器并解压,然后安装下面的依赖: You need to have Python 3.8 or higher installed on your server. Then, clone this repository diff --git a/model/__init__.py b/model/__init__.py index d0d2fd8..bb1b942 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -1 +1,9 @@ -from . import token \ No newline at end of file +from . import token + +from .object import Object +from .setting import Setting + +__all__ = [ + "Object", + "Setting" +] \ No newline at end of file diff --git a/model/base.py b/model/base.py new file mode 100644 index 0000000..f2a228a --- /dev/null +++ b/model/base.py @@ -0,0 +1,229 @@ +from datetime import datetime, timezone +from typing import Optional, Type, TypeVar, Union, Literal, List + +from sqlalchemy import DateTime, BinaryExpression, ClauseElement +from sqlalchemy.orm import selectinload +from sqlmodel import SQLModel, Field, select, Relationship +from sqlalchemy.ext.asyncio.session import AsyncSession +from sqlalchemy.sql._typing import _OnClauseArgument +from sqlalchemy.ext.asyncio import AsyncAttrs + +B = TypeVar('B', bound='BaseModel') +M = TypeVar('M', bound='SQLModel') + +utcnow = lambda: datetime.now(tz=timezone.utc) + +class BaseModel(AsyncAttrs): + __abstract__ = True + + id: Optional[int] = Field(default=None, primary_key=True, description="主键ID") + created_at: datetime = Field( + default_factory=utcnow, + description="创建时间", + ) + updated_at: datetime = Field( + sa_type=DateTime, + description="更新时间", + sa_column_kwargs={"default": utcnow, "onupdate": utcnow}, + default_factory=utcnow + ) + deleted_at: Optional[datetime] = Field( + default=None, + description="删除时间", + sa_column={"nullable": True} + ) + + @classmethod + async def add( + cls: Type[B], + session: AsyncSession, + instances: B | List[B], + refresh: bool = True + ) -> B | List[B]: + """ + 新增一条记录 + + :param session: 异步会话对象 + :param instances: 实例或实例列表 + :param refresh: 是否刷新实例 + :return: 新增的实例或实例列表 + + Example: + + >>> from model.base import BaseModel + > from model.object import Object + > from database import Database + > import asyncio + > async def main(): + > async with Database.get_session() as session: + > obj = Object(key="12345", name="Test Object", icon="icon.png") + > added_obj = await BaseModel.add(session, obj) + > print(added_obj) + > asyncio.run(main()) + + """ + is_list = False + if isinstance(instances, list): + is_list = True + session.add_all(instances) + else: + session.add(instances) + + await session.commit() + + if refresh: + if is_list: + for instance in instances: + await session.refresh(instance) + else: + await session.refresh(instances) + + return instances + + async def save( + self: B, + session: AsyncSession, + load: Union[Relationship, None] + ): + """ + 保存当前实例到数据库 + + :param session: 异步会话对象 + :param load: 需要加载的关系属性 + :return: None + + """ + session.add(self) + await session.commit() + + if load is not None: + cls = type(self) + return await cls.get(session, self.id, load=load) + else: + await session.refresh(self) + return self + + async def update( + self: B, + session: AsyncSession, + other: M, + extra_data: dict = None, + exclude_unset: bool = True, + ) -> B: + """ + 更新当前实例 + + :param session: 异步会话对象 + :param other: 用于更新的实例 + :param extra_data: 额外的数据字典 + :param exclude_unset: 是否排除未设置的字段 + :return: 更新后的实例 + """ + self.sqlmodel_update( + other.model_dump( + exclude_unset=exclude_unset + ), + update=extra_data + ) + + session.add(self) + + await session.commit() + await session.refresh(self) + + return self + + @classmethod + async def delete( + cls: Type[B], + session: AsyncSession, + instance: B | list[B], + ) -> None: + """ + 删除实例 + + :param session: 异步会话对象 + :param instance: 实例或实例列表 + :return: None + """ + + if isinstance(instance, list): + for inst in instance: + await session.delete(inst) + else: + await session.delete(instance) + + await session.commit() + + @classmethod + async def get( + cls: Type[B], + session: AsyncSession, + condition: BinaryExpression | ClauseElement | None, + *, + offset: int | None = None, + limit: int | None = None, + fetch_mode: Literal["one", "first", "all"] = "first", + join: Type[B] | tuple[Type[B], _OnClauseArgument] | None = None, + options: list | None = None, + load: Union[Relationship, None] = None, + order_by: list[ClauseElement] | None = None + ) -> B | List[B] | None: + """ + 异步获取模型实例 + + 参数: + session: 异步数据库会话 + condition: SQLAlchemy查询条件,如Model.id == 1 + offset: 结果偏移量 + limit: 结果数量限制 + options: 查询选项,如selectinload(Model.relation),异步访问关系属性必备,不然会报错 + fetch_mode: 获取模式 - "one"/"all"/"first" + join: 要联接的模型类 + + 返回: + 根据fetch_mode返回相应的查询结果 + """ + statement = select(cls) + + if condition is not None: + statement = statement.where(condition) + + if join is not None: + statement = statement.join(*join) + + if options: + statement = statement.options(*options) + + if load: + statement = statement.options(selectinload(load)) + + if order_by is not None: + statement = statement.order_by(*order_by) + + if offset: + statement = statement.offset(offset) + + if limit: + statement = statement.limit(limit) + + result = await session.exec(statement) + + if fetch_mode == "one": + return result.one() + elif fetch_mode == "first": + return result.first() + elif fetch_mode == "all": + return list(result.all()) + else: + raise ValueError(f"无效的 fetch_mode: {fetch_mode}") + + @classmethod + async def get_exist_one(cls: Type[B], session: AsyncSession, id: int, load: Union[Relationship, None] = None) -> B: + """此方法和 await session.get(cls, 主键)的区别就是当不存在时不返回None, + 而是会抛出fastapi 404 异常""" + instance = await cls.get(session, cls.id == id, load=load) + if not instance: + from fastapi import HTTPException + raise HTTPException(status_code=404, detail="Not found") + return instance \ No newline at end of file diff --git a/model/database.py b/model/database.py index ddc7e62..e1804cd 100644 --- a/model/database.py +++ b/model/database.py @@ -1,20 +1,28 @@ -''' -Author: 于小丘 海枫 -Date: 2024-10-02 15:23:34 -LastEditors: Yuerchu admin@yuxiaoqiu.cn -LastEditTime: 2024-11-29 20:05:03 -FilePath: /Findreve/model.py -Description: Findreve 数据库组件 model - -Copyright (c) 2018-2024 by 于小丘Yuerchu, All Rights Reserved. -''' - import aiosqlite from datetime import datetime -import tool -import logging from typing import Optional +from sqlmodel import SQLModel +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from sqlmodel.ext.asyncio.session import AsyncSession +from sqlalchemy.orm import sessionmaker +from typing import AsyncGenerator + +ASYNC_DATABASE_URL = "sqlite+aiosqlite:///data.db" + +engine: AsyncEngine = create_async_engine( + ASYNC_DATABASE_URL, + echo=True, + connect_args={ + "check_same_thread": False + } if ASYNC_DATABASE_URL.startswith("sqlite") else None, + future=True, + # pool_size=POOL_SIZE, + # max_overflow=64, +) + +_async_session_factory = sessionmaker(engine, class_=AsyncSession) + # 数据库类 class Database: @@ -24,90 +32,18 @@ class Database: db_path: str = "data.db" # db_path 数据库文件路径,默认为 data.db ): self.db_path = db_path + + async def get_session() -> AsyncGenerator[AsyncSession, None]: + async with _async_session_factory() as session: + yield session - async def init_db(self): - """初始化数据库和表""" - logging.info("开始初始化数据库和表") - - create_objects_table = """ - CREATE TABLE IF NOT EXISTS fr_objects ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - key TEXT NOT NULL, - name TEXT NOT NULL, - icon TEXT, - status TEXT, - phone TEXT, - context TEXT, - find_ip TEXT, - create_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - lost_at TIMESTAMP - ) - """ - - create_settings_table = """ - CREATE TABLE IF NOT EXISTS fr_settings ( - type TEXT, - name TEXT PRIMARY KEY, - value TEXT - ) - """ - - async with aiosqlite.connect(self.db_path) as db: - logging.info("连接到数据库") - await db.execute(create_objects_table) - logging.info("创建或验证fr_objects表") - await db.execute(create_settings_table) - logging.info("创建或验证fr_settings表") - - # 初始化设置表数据 - async with db.execute("SELECT name FROM fr_settings WHERE name = 'version'") as cursor: - if not await cursor.fetchone(): - await db.execute( - "INSERT INTO fr_settings (type, name, value) VALUES (?, ?, ?)", - ('string', 'version', '1.0.0') - ) - logging.info("插入初始版本信息: version 1.0.0") - - async with db.execute("SELECT name FROM fr_settings WHERE name = 'ver'") as cursor: - if not await cursor.fetchone(): - await db.execute( - "INSERT INTO fr_settings (type, name, value) VALUES (?, ?, ?)", - ('int', 'ver', '1') - ) - logging.info("插入初始版本号: ver 1") - - async with db.execute("SELECT name FROM fr_settings WHERE name = 'account'") as cursor: - if not await cursor.fetchone(): - account = 'admin@yuxiaoqiu.cn' - await db.execute( - "INSERT INTO fr_settings (type, name, value) VALUES (?, ?, ?)", - ('string', 'account', account) - ) - logging.info(f"插入初始账号信息: {account}") - print(f"账号: {account}") - - async with db.execute("SELECT name FROM fr_settings WHERE name = 'password'") as cursor: - if not await cursor.fetchone(): - password = tool.generate_password() - hashed_password = tool.hash_password(password) - await db.execute( - "INSERT INTO fr_settings (type, name, value) VALUES (?, ?, ?)", - ('string', 'password', hashed_password) - ) - logging.info("插入初始密码信息") - print(f"密码(请牢记,后续不再显示): {password}") - - async with db.execute("SELECT name FROM fr_settings WHERE name = 'SECRET_KEY'") as cursor: - if not await cursor.fetchone(): - secret_key = tool.generate_password(64) - await db.execute( - "INSERT INTO fr_settings (type, name, value) VALUES (?, ?, ?)", - ('string', 'SECRET_KEY', secret_key) - ) - logging.info("插入初始密钥信息") - - await db.commit() - logging.info("数据库初始化完成并提交更改") + async def init_db( + self, + url: str = ASYNC_DATABASE_URL + ): + """创建数据库结构""" + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) async def add_object(self, key: str, name: str, icon: str = None, phone: str = None): """ diff --git a/model/migration.py b/model/migration.py new file mode 100644 index 0000000..6e92f4b --- /dev/null +++ b/model/migration.py @@ -0,0 +1,51 @@ +""" +# 初始化设置表数据 + async with db.execute("SELECT name FROM fr_settings WHERE name = 'version'") as cursor: + if not await cursor.fetchone(): + await db.execute( + "INSERT INTO fr_settings (type, name, value) VALUES (?, ?, ?)", + ('string', 'version', '1.0.0') + ) + logging.info("插入初始版本信息: version 1.0.0") + + async with db.execute("SELECT name FROM fr_settings WHERE name = 'ver'") as cursor: + if not await cursor.fetchone(): + await db.execute( + "INSERT INTO fr_settings (type, name, value) VALUES (?, ?, ?)", + ('int', 'ver', '1') + ) + logging.info("插入初始版本号: ver 1") + + async with db.execute("SELECT name FROM fr_settings WHERE name = 'account'") as cursor: + if not await cursor.fetchone(): + account = 'admin@yuxiaoqiu.cn' + await db.execute( + "INSERT INTO fr_settings (type, name, value) VALUES (?, ?, ?)", + ('string', 'account', account) + ) + logging.info(f"插入初始账号信息: {account}") + print(f"账号: {account}") + + async with db.execute("SELECT name FROM fr_settings WHERE name = 'password'") as cursor: + if not await cursor.fetchone(): + password = tool.generate_password() + hashed_password = tool.hash_password(password) + await db.execute( + "INSERT INTO fr_settings (type, name, value) VALUES (?, ?, ?)", + ('string', 'password', hashed_password) + ) + logging.info("插入初始密码信息") + print(f"密码(请牢记,后续不再显示): {password}") + + async with db.execute("SELECT name FROM fr_settings WHERE name = 'SECRET_KEY'") as cursor: + if not await cursor.fetchone(): + secret_key = tool.generate_password(64) + await db.execute( + "INSERT INTO fr_settings (type, name, value) VALUES (?, ?, ?)", + ('string', 'SECRET_KEY', secret_key) + ) + logging.info("插入初始密钥信息") + + await db.commit() + logging.info("数据库初始化完成并提交更改") +""" \ No newline at end of file diff --git a/model/object.py b/model/object.py new file mode 100644 index 0000000..220abfc --- /dev/null +++ b/model/object.py @@ -0,0 +1,48 @@ +# my_project/models/download.py + +from typing import Literal, Optional, TYPE_CHECKING +from sqlmodel import Field, Column, SQLModel, String, DateTime +from .base import BaseModel +from datetime import datetime + +from .base import BaseModel + +""" +原建表语句: + +CREATE TABLE IF NOT EXISTS fr_objects ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + key TEXT NOT NULL, + name TEXT NOT NULL, + icon TEXT, + status TEXT, + phone TEXT, + context TEXT, + find_ip TEXT, + create_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + lost_at TIMESTAMP +""" + +if TYPE_CHECKING: + pass + +class Object(SQLModel, BaseModel, table=True): + __tablename__ = 'fr_objects' + + key: str = Field(index=True, nullable=False, description="物品外部ID") + type: Literal['object', 'box'] = Field( + default='object', + description="物品类型", + sa_column=Column(String, default='object', nullable=False) + ) + name: str = Field(nullable=False, description="物品名称") + icon: Optional[str] = Field(default=None, description="物品图标") + status: Optional[str] = Field(default=None, description="物品状态") + phone: Optional[str] = Field(default=None, description="联系电话") + context: Optional[str] = Field(default=None, description="物品描述") + find_ip: Optional[str] = Field(default=None, description="最后一次发现的IP地址") + lost_at: Optional[datetime] = Field( + default=None, + description="物品标记为丢失的时间", + sa_column=Column(DateTime, nullable=True) + ) \ No newline at end of file diff --git a/model/setting.py b/model/setting.py new file mode 100644 index 0000000..f7cd143 --- /dev/null +++ b/model/setting.py @@ -0,0 +1,22 @@ +from typing import TYPE_CHECKING, Optional +from sqlmodel import Field, SQLModel +from .base import BaseModel + +""" +原建表语句: + +CREATE TABLE IF NOT EXISTS fr_settings ( + type TEXT, + name TEXT PRIMARY KEY, + value TEXT +""" + +if TYPE_CHECKING: + pass + +class Setting(SQLModel, BaseModel, table=True): + __tablename__ = 'fr_settings' + + type: str = Field(index=True, nullable=False, description="设置类型") + name: str = Field(index=True, primary_key=True, nullable=False, description="设置名称") + value: Optional[str] = Field(description="设置值") \ No newline at end of file