From e84b3a7dee7cf5d85c13e6aab60484152eef9a39 Mon Sep 17 00:00:00 2001 From: Yuerchu Date: Tue, 1 Jul 2025 23:50:16 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E8=AE=BE=E7=BD=AE=E8=A1=A8=E7=9A=84=E5=A2=9E=E5=88=A0=E6=94=B9?= =?UTF-8?q?=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/database.py | 10 +-- models/group.py | 51 +++++++++++++++- models/migration.py | 12 +++- models/setting.py | 134 +++++++++++++++++++++++++++++++++++++---- models/user.py | 10 ++- tests/test_database.py | 76 +++++++++++++++++++++-- 6 files changed, 269 insertions(+), 24 deletions(-) diff --git a/models/database.py b/models/database.py index 1764bb2..e91dffd 100644 --- a/models/database.py +++ b/models/database.py @@ -4,6 +4,7 @@ from sqlmodel import SQLModel from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel.ext.asyncio.session import AsyncSession from sqlalchemy.orm import sessionmaker +from typing import AsyncGenerator ASYNC_DATABASE_URL = "sqlite+aiosqlite:///database.db" @@ -20,12 +21,13 @@ engine = create_async_engine( _async_session_factory = sessionmaker(engine, class_=AsyncSession) -async def get_session(): +async def get_session() -> AsyncGenerator[AsyncSession, None]: async with _async_session_factory() as session: yield session -async def init_db(): - """初始化数据库""" - # 创建所有表 +async def init_db( + url: str = ASYNC_DATABASE_URL +): + """创建数据库结构""" async with engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) \ No newline at end of file diff --git a/models/group.py b/models/group.py index c8d31b3..35fd11d 100644 --- a/models/group.py +++ b/models/group.py @@ -40,5 +40,52 @@ class Group(BaseModel, table=True): ) # 关系:一个组可以有多个用户 - users: List["User"] = Relationship(back_populates="group") - previous_users: List["User"] = Relationship(back_populates="previous_group") \ No newline at end of file + users: List["User"] = Relationship( + back_populates="group", + sa_relationship_kwargs={"foreign_keys": "User.group_id"} + ) + previous_users: List["User"] = Relationship( + back_populates="previous_group", + sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"} + ) + + async def add_group(self, name: str, policies: Optional[str] = None, max_storage: int = 0, + share_enabled: bool = False, web_dav_enabled: bool = False, + speed_limit: int = 0, options: Optional[str] = None) -> "Group": + """ + 向数据库内添加用户组。 + + :param name: 用户组名 + :type name: str + :param policies: 允许的策略ID列表,逗号分隔,默认为 None + :type policies: Optional[str] + :param max_storage: 最大存储空间(字节),默认为 0 + :type max_storage: int + :param share_enabled: 是否允许创建分享,默认为 False + :type share_enabled: bool + :param web_dav_enabled: 是否允许使用WebDAV,默认为 False + :type web_dav_enabled: bool + :param speed_limit: 速度限制 (KB/s), 0为不限制,默认为 0 + :type speed_limit: int + :param options: 其他选项 (JSON格式),默认为 None + :type options: Optional[str] + :return: 新创建的用户组对象 + :rtype: Group + """ + + from .database import get_session + session = get_session() + + new_group = Group( + name=name, + policies=policies, + max_storage=max_storage, + share_enabled=share_enabled, + web_dav_enabled=web_dav_enabled, + speed_limit=speed_limit, + options=options + ) + session.add(new_group) + + session.commit() + session.refresh(new_group) \ No newline at end of file diff --git a/models/migration.py b/models/migration.py index 5204e6e..42edf2d 100644 --- a/models/migration.py +++ b/models/migration.py @@ -96,4 +96,14 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti Setting(name="pwa_display", value="standalone", type="pwa"), Setting(name="pwa_theme_color", value="#000000", type="pwa"), Setting(name="pwa_background_color", value="#ffffff", type="pwa"), -] \ No newline at end of file +] + +async def init_default_settings() -> None: + from .setting import Setting + + for setting in default_settings: + await Setting.add( + type=setting.type, + name=setting.name, + value=setting.value + ) \ No newline at end of file diff --git a/models/setting.py b/models/setting.py index bbeeafe..3885564 100644 --- a/models/setting.py +++ b/models/setting.py @@ -71,15 +71,127 @@ class Setting(BaseModel, table=True): ), ) -async def add( - type: SETTINGS_TYPE, - name: str, - value: Optional[str] = None -): - pass + async def add( + type: SETTINGS_TYPE = None, + name: str = None, + value: Optional[str] = None + ) -> None: + """ + 向数据库内添加设置项目。 + + :param type: 设置类型/分组 + :type type: SETTINGS_TYPE + :param name: 设置项名称 + :type name: str + :param value: 设置值,默认为 None + :type value: Optional[str] + """ + from .database import get_session + + if isinstance(value, (dict, list)): + value = str(value) + + async for session in get_session(): + new_setting = Setting(type=type, name=name, value=value) + session.add(new_setting) + + await session.commit() -async def get( - type: SETTINGS_TYPE, - name: str -): - pass \ No newline at end of file + async def get( + type: SETTINGS_TYPE, + name: str + ) -> Optional['Setting']: + """ + 从数据库中获取指定类型和名称的设置项。 + + :param type: 设置类型/分组 + :type type: SETTINGS_TYPE + :param name: 设置项名称 + :type name: str + + :return: 返回设置项对象,如果不存在则返回 None + :rtype: Optional[Setting] + """ + from .database import get_session + from sqlmodel import select + + async for session in get_session(): + statment = select(Setting).where( + Setting.type == type, + Setting.name == name + ) + + result = await session.exec(statment) + return result.one_or_none() + + async def set( + type: SETTINGS_TYPE, + name: str, + value: Optional[str] = None + ) -> None: + """ + 更新指定类型和名称的设置项的值。 + + :param type: 设置类型/分组 + :type type: SETTINGS_TYPE + :param name: 设置项名称 + :type name: str + :param value: 新的设置值,默认为 None + :type value: Optional[str] + + :raises ValueError: 如果设置项不存在,则抛出异常 + """ + from .database import get_session + from sqlmodel import select + + if isinstance(value, (dict, list)): + value = str(value) + + async for session in get_session(): + statment = select(Setting).where( + Setting.type == type, + Setting.name == name + ) + + result = await session.exec(statment) + setting = result.one_or_none() + + if not setting: + raise ValueError(f"Setting {type}.{name} does not exist.") + + # 设置项存在,更新数据 + setting.value = value + await session.commit() + + async def delete( + type: SETTINGS_TYPE, + name: str + ) -> None: + """ + 删除指定类型和名称的设置项。 + + :param type: 设置类型/分组 + :type type: SETTINGS_TYPE + :param name: 设置项名称 + :type name: str + + :raises ValueError: 如果设置项不存在,则抛出异常 + """ + from .database import get_session + from sqlmodel import select, delete + + async for session in get_session(): + statment = select(Setting).where( + Setting.type == type, + Setting.name == name + ) + + result = await session.exec(statment) + setting = result.one_or_none() + + if not setting: + raise ValueError(f"Setting {type}.{name} does not exist.") + + # 设置项存在,删除数据 + await session.delete(setting) + await session.commit() diff --git a/models/user.py b/models/user.py index 1f8aa70..e898029 100644 --- a/models/user.py +++ b/models/user.py @@ -69,8 +69,14 @@ class User(BaseModel, table=True): previous_group_id: Optional[int] = Field(default=None, foreign_key="groups.id", description="之前的用户组ID(用于过期后恢复)") # 关系 - group: "Group" = Relationship(back_populates="users") - previous_group: Optional["Group"] = Relationship(back_populates="previous_users") + group: "Group" = Relationship( + back_populates="users", + sa_relationship_kwargs={"foreign_keys": "User.group_id"} + ) + previous_group: Optional["Group"] = Relationship( + back_populates="previous_users", + sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"} + ) downloads: list["Download"] = Relationship(back_populates="user") files: list["File"] = Relationship(back_populates="user") diff --git a/tests/test_database.py b/tests/test_database.py index cb3fe33..463abd7 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,8 +1,76 @@ -from models import database - import pytest @pytest.mark.asyncio async def test_initialize_db(): - """Fixture to initialize the database before tests.""" - await database.init_db() \ No newline at end of file + """测试创建数据库结构""" + from models import database + + await database.init_db(url='sqlite:///:memory:') + +@pytest.fixture +async def db_session(): + """测试获取数据库连接Session""" + from models import database + + await database.init_db(url='sqlite:///:memory:') + + async for session in database.get_session(): + yield session + +@pytest.mark.asyncio +async def test_initialize_db(): + """测试数据库创建并初始化配置""" + from models import migration + from models import database + + await database.init_db(url='sqlite:///:memory:') + + await migration.init_default_settings() + +@pytest.mark.asyncio +async def test_add_settings(): + """测试数据库的增删改查""" + from models import database + from models.setting import Setting + + await database.init_db(url='sqlite:///:memory:') + + # 测试增 Create + await Setting.add( + type='example_type', + name='example_name', + value='example_value') + + # 测试查 Read + setting = await Setting.get( + type='example_type', + name='example_name') + + assert setting is not None, "设置项应该存在" + assert setting.value == 'example_value', "设置值不匹配" + + # 测试改 Update + await Setting.set( + type='example_type', + name='example_name', + value='updated_value') + + after_update_setting = await Setting.get( + type='example_type', + name='example_name' + ) + + assert after_update_setting is not None, "设置项应该存在" + assert after_update_setting.value == 'updated_value', "更新后的设置值不匹配" + + # 测试删 Delete + await Setting.delete( + type='example_type', + name='example_name') + + after_delete_setting = await Setting.get( + type='example_type', + name='example_name' + ) + + assert after_delete_setting is None, "设置项应该被删除" \ No newline at end of file