完成数据库设置表的增删改查
This commit is contained in:
@@ -4,6 +4,7 @@ from sqlmodel import SQLModel
|
|||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///database.db"
|
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///database.db"
|
||||||
|
|
||||||
@@ -20,12 +21,13 @@ engine = create_async_engine(
|
|||||||
|
|
||||||
_async_session_factory = sessionmaker(engine, class_=AsyncSession)
|
_async_session_factory = sessionmaker(engine, class_=AsyncSession)
|
||||||
|
|
||||||
async def get_session():
|
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
async with _async_session_factory() as session:
|
async with _async_session_factory() as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
async def init_db():
|
async def init_db(
|
||||||
"""初始化数据库"""
|
url: str = ASYNC_DATABASE_URL
|
||||||
# 创建所有表
|
):
|
||||||
|
"""创建数据库结构"""
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(SQLModel.metadata.create_all)
|
await conn.run_sync(SQLModel.metadata.create_all)
|
||||||
@@ -40,5 +40,52 @@ class Group(BaseModel, table=True):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 关系:一个组可以有多个用户
|
# 关系:一个组可以有多个用户
|
||||||
users: List["User"] = Relationship(back_populates="group")
|
users: List["User"] = Relationship(
|
||||||
previous_users: List["User"] = Relationship(back_populates="previous_group")
|
back_populates="group",
|
||||||
|
sa_relationship_kwargs={"foreign_keys": "User.group_id"}
|
||||||
|
)
|
||||||
|
previous_users: List["User"] = Relationship(
|
||||||
|
back_populates="previous_group",
|
||||||
|
sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def add_group(self, name: str, policies: Optional[str] = None, max_storage: int = 0,
|
||||||
|
share_enabled: bool = False, web_dav_enabled: bool = False,
|
||||||
|
speed_limit: int = 0, options: Optional[str] = None) -> "Group":
|
||||||
|
"""
|
||||||
|
向数据库内添加用户组。
|
||||||
|
|
||||||
|
:param name: 用户组名
|
||||||
|
:type name: str
|
||||||
|
:param policies: 允许的策略ID列表,逗号分隔,默认为 None
|
||||||
|
:type policies: Optional[str]
|
||||||
|
:param max_storage: 最大存储空间(字节),默认为 0
|
||||||
|
:type max_storage: int
|
||||||
|
:param share_enabled: 是否允许创建分享,默认为 False
|
||||||
|
:type share_enabled: bool
|
||||||
|
:param web_dav_enabled: 是否允许使用WebDAV,默认为 False
|
||||||
|
:type web_dav_enabled: bool
|
||||||
|
:param speed_limit: 速度限制 (KB/s), 0为不限制,默认为 0
|
||||||
|
:type speed_limit: int
|
||||||
|
:param options: 其他选项 (JSON格式),默认为 None
|
||||||
|
:type options: Optional[str]
|
||||||
|
:return: 新创建的用户组对象
|
||||||
|
:rtype: Group
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .database import get_session
|
||||||
|
session = get_session()
|
||||||
|
|
||||||
|
new_group = Group(
|
||||||
|
name=name,
|
||||||
|
policies=policies,
|
||||||
|
max_storage=max_storage,
|
||||||
|
share_enabled=share_enabled,
|
||||||
|
web_dav_enabled=web_dav_enabled,
|
||||||
|
speed_limit=speed_limit,
|
||||||
|
options=options
|
||||||
|
)
|
||||||
|
session.add(new_group)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
session.refresh(new_group)
|
||||||
@@ -97,3 +97,13 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
|||||||
Setting(name="pwa_theme_color", value="#000000", type="pwa"),
|
Setting(name="pwa_theme_color", value="#000000", type="pwa"),
|
||||||
Setting(name="pwa_background_color", value="#ffffff", type="pwa"),
|
Setting(name="pwa_background_color", value="#ffffff", type="pwa"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
async def init_default_settings() -> None:
|
||||||
|
from .setting import Setting
|
||||||
|
|
||||||
|
for setting in default_settings:
|
||||||
|
await Setting.add(
|
||||||
|
type=setting.type,
|
||||||
|
name=setting.name,
|
||||||
|
value=setting.value
|
||||||
|
)
|
||||||
@@ -72,14 +72,126 @@ class Setting(BaseModel, table=True):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def add(
|
async def add(
|
||||||
type: SETTINGS_TYPE,
|
type: SETTINGS_TYPE = None,
|
||||||
name: str,
|
name: str = None,
|
||||||
value: Optional[str] = None
|
value: Optional[str] = None
|
||||||
):
|
) -> None:
|
||||||
pass
|
"""
|
||||||
|
向数据库内添加设置项目。
|
||||||
|
|
||||||
|
:param type: 设置类型/分组
|
||||||
|
:type type: SETTINGS_TYPE
|
||||||
|
:param name: 设置项名称
|
||||||
|
:type name: str
|
||||||
|
:param value: 设置值,默认为 None
|
||||||
|
:type value: Optional[str]
|
||||||
|
"""
|
||||||
|
from .database import get_session
|
||||||
|
|
||||||
|
if isinstance(value, (dict, list)):
|
||||||
|
value = str(value)
|
||||||
|
|
||||||
|
async for session in get_session():
|
||||||
|
new_setting = Setting(type=type, name=name, value=value)
|
||||||
|
session.add(new_setting)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
async def get(
|
async def get(
|
||||||
type: SETTINGS_TYPE,
|
type: SETTINGS_TYPE,
|
||||||
name: str
|
name: str
|
||||||
):
|
) -> Optional['Setting']:
|
||||||
pass
|
"""
|
||||||
|
从数据库中获取指定类型和名称的设置项。
|
||||||
|
|
||||||
|
:param type: 设置类型/分组
|
||||||
|
:type type: SETTINGS_TYPE
|
||||||
|
:param name: 设置项名称
|
||||||
|
:type name: str
|
||||||
|
|
||||||
|
:return: 返回设置项对象,如果不存在则返回 None
|
||||||
|
:rtype: Optional[Setting]
|
||||||
|
"""
|
||||||
|
from .database import get_session
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
async for session in get_session():
|
||||||
|
statment = select(Setting).where(
|
||||||
|
Setting.type == type,
|
||||||
|
Setting.name == name
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await session.exec(statment)
|
||||||
|
return result.one_or_none()
|
||||||
|
|
||||||
|
async def set(
|
||||||
|
type: SETTINGS_TYPE,
|
||||||
|
name: str,
|
||||||
|
value: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
更新指定类型和名称的设置项的值。
|
||||||
|
|
||||||
|
:param type: 设置类型/分组
|
||||||
|
:type type: SETTINGS_TYPE
|
||||||
|
:param name: 设置项名称
|
||||||
|
:type name: str
|
||||||
|
:param value: 新的设置值,默认为 None
|
||||||
|
:type value: Optional[str]
|
||||||
|
|
||||||
|
:raises ValueError: 如果设置项不存在,则抛出异常
|
||||||
|
"""
|
||||||
|
from .database import get_session
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
if isinstance(value, (dict, list)):
|
||||||
|
value = str(value)
|
||||||
|
|
||||||
|
async for session in get_session():
|
||||||
|
statment = select(Setting).where(
|
||||||
|
Setting.type == type,
|
||||||
|
Setting.name == name
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await session.exec(statment)
|
||||||
|
setting = result.one_or_none()
|
||||||
|
|
||||||
|
if not setting:
|
||||||
|
raise ValueError(f"Setting {type}.{name} does not exist.")
|
||||||
|
|
||||||
|
# 设置项存在,更新数据
|
||||||
|
setting.value = value
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def delete(
|
||||||
|
type: SETTINGS_TYPE,
|
||||||
|
name: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
删除指定类型和名称的设置项。
|
||||||
|
|
||||||
|
:param type: 设置类型/分组
|
||||||
|
:type type: SETTINGS_TYPE
|
||||||
|
:param name: 设置项名称
|
||||||
|
:type name: str
|
||||||
|
|
||||||
|
:raises ValueError: 如果设置项不存在,则抛出异常
|
||||||
|
"""
|
||||||
|
from .database import get_session
|
||||||
|
from sqlmodel import select, delete
|
||||||
|
|
||||||
|
async for session in get_session():
|
||||||
|
statment = select(Setting).where(
|
||||||
|
Setting.type == type,
|
||||||
|
Setting.name == name
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await session.exec(statment)
|
||||||
|
setting = result.one_or_none()
|
||||||
|
|
||||||
|
if not setting:
|
||||||
|
raise ValueError(f"Setting {type}.{name} does not exist.")
|
||||||
|
|
||||||
|
# 设置项存在,删除数据
|
||||||
|
await session.delete(setting)
|
||||||
|
await session.commit()
|
||||||
|
|||||||
@@ -69,8 +69,14 @@ class User(BaseModel, table=True):
|
|||||||
previous_group_id: Optional[int] = Field(default=None, foreign_key="groups.id", description="之前的用户组ID(用于过期后恢复)")
|
previous_group_id: Optional[int] = Field(default=None, foreign_key="groups.id", description="之前的用户组ID(用于过期后恢复)")
|
||||||
|
|
||||||
# 关系
|
# 关系
|
||||||
group: "Group" = Relationship(back_populates="users")
|
group: "Group" = Relationship(
|
||||||
previous_group: Optional["Group"] = Relationship(back_populates="previous_users")
|
back_populates="users",
|
||||||
|
sa_relationship_kwargs={"foreign_keys": "User.group_id"}
|
||||||
|
)
|
||||||
|
previous_group: Optional["Group"] = Relationship(
|
||||||
|
back_populates="previous_users",
|
||||||
|
sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"}
|
||||||
|
)
|
||||||
|
|
||||||
downloads: list["Download"] = Relationship(back_populates="user")
|
downloads: list["Download"] = Relationship(back_populates="user")
|
||||||
files: list["File"] = Relationship(back_populates="user")
|
files: list["File"] = Relationship(back_populates="user")
|
||||||
|
|||||||
@@ -1,8 +1,76 @@
|
|||||||
from models import database
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_initialize_db():
|
async def test_initialize_db():
|
||||||
"""Fixture to initialize the database before tests."""
|
"""测试创建数据库结构"""
|
||||||
await database.init_db()
|
from models import database
|
||||||
|
|
||||||
|
await database.init_db(url='sqlite:///:memory:')
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
"""测试获取数据库连接Session"""
|
||||||
|
from models import database
|
||||||
|
|
||||||
|
await database.init_db(url='sqlite:///:memory:')
|
||||||
|
|
||||||
|
async for session in database.get_session():
|
||||||
|
yield session
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialize_db():
|
||||||
|
"""测试数据库创建并初始化配置"""
|
||||||
|
from models import migration
|
||||||
|
from models import database
|
||||||
|
|
||||||
|
await database.init_db(url='sqlite:///:memory:')
|
||||||
|
|
||||||
|
await migration.init_default_settings()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_settings():
|
||||||
|
"""测试数据库的增删改查"""
|
||||||
|
from models import database
|
||||||
|
from models.setting import Setting
|
||||||
|
|
||||||
|
await database.init_db(url='sqlite:///:memory:')
|
||||||
|
|
||||||
|
# 测试增 Create
|
||||||
|
await Setting.add(
|
||||||
|
type='example_type',
|
||||||
|
name='example_name',
|
||||||
|
value='example_value')
|
||||||
|
|
||||||
|
# 测试查 Read
|
||||||
|
setting = await Setting.get(
|
||||||
|
type='example_type',
|
||||||
|
name='example_name')
|
||||||
|
|
||||||
|
assert setting is not None, "设置项应该存在"
|
||||||
|
assert setting.value == 'example_value', "设置值不匹配"
|
||||||
|
|
||||||
|
# 测试改 Update
|
||||||
|
await Setting.set(
|
||||||
|
type='example_type',
|
||||||
|
name='example_name',
|
||||||
|
value='updated_value')
|
||||||
|
|
||||||
|
after_update_setting = await Setting.get(
|
||||||
|
type='example_type',
|
||||||
|
name='example_name'
|
||||||
|
)
|
||||||
|
|
||||||
|
assert after_update_setting is not None, "设置项应该存在"
|
||||||
|
assert after_update_setting.value == 'updated_value', "更新后的设置值不匹配"
|
||||||
|
|
||||||
|
# 测试删 Delete
|
||||||
|
await Setting.delete(
|
||||||
|
type='example_type',
|
||||||
|
name='example_name')
|
||||||
|
|
||||||
|
after_delete_setting = await Setting.get(
|
||||||
|
type='example_type',
|
||||||
|
name='example_name'
|
||||||
|
)
|
||||||
|
|
||||||
|
assert after_delete_setting is None, "设置项应该被删除"
|
||||||
Reference in New Issue
Block a user