单元测试:新建用户与用户组
This commit is contained in:
@@ -6,7 +6,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from typing import AsyncGenerator
|
||||
|
||||
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///database.db"
|
||||
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
engine = create_async_engine(
|
||||
ASYNC_DATABASE_URL,
|
||||
|
||||
@@ -49,43 +49,58 @@ class Group(BaseModel, table=True):
|
||||
sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"}
|
||||
)
|
||||
|
||||
async def add_group(self, name: str, policies: Optional[str] = None, max_storage: int = 0,
|
||||
share_enabled: bool = False, web_dav_enabled: bool = False,
|
||||
speed_limit: int = 0, options: Optional[str] = None) -> "Group":
|
||||
@staticmethod
|
||||
async def create(
|
||||
group: "Group"
|
||||
) -> "Group":
|
||||
"""
|
||||
向数据库内添加用户组。
|
||||
|
||||
:param name: 用户组名
|
||||
:type name: str
|
||||
:param policies: 允许的策略ID列表,逗号分隔,默认为 None
|
||||
:type policies: Optional[str]
|
||||
:param max_storage: 最大存储空间(字节),默认为 0
|
||||
:type max_storage: int
|
||||
:param share_enabled: 是否允许创建分享,默认为 False
|
||||
:type share_enabled: bool
|
||||
:param web_dav_enabled: 是否允许使用WebDAV,默认为 False
|
||||
:type web_dav_enabled: bool
|
||||
:param speed_limit: 速度限制 (KB/s), 0为不限制,默认为 0
|
||||
:type speed_limit: int
|
||||
:param options: 其他选项 (JSON格式),默认为 None
|
||||
:type options: Optional[str]
|
||||
:param group: 用户组对象
|
||||
:type group: Group
|
||||
:return: 新创建的用户组对象
|
||||
:rtype: Group
|
||||
"""
|
||||
|
||||
from .database import get_session
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
session.add(group)
|
||||
await session.commit()
|
||||
await session.refresh(group)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
return group
|
||||
|
||||
@staticmethod
|
||||
async def get(
|
||||
id: int = None
|
||||
) -> Optional["Group"]:
|
||||
"""
|
||||
获取用户组信息。
|
||||
|
||||
:param id: 用户组ID,默认为 None
|
||||
:type id: int
|
||||
|
||||
:return: 用户组对象或 None
|
||||
:rtype: Optional[Group]
|
||||
"""
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
session = get_session()
|
||||
|
||||
new_group = Group(
|
||||
name=name,
|
||||
policies=policies,
|
||||
max_storage=max_storage,
|
||||
share_enabled=share_enabled,
|
||||
web_dav_enabled=web_dav_enabled,
|
||||
speed_limit=speed_limit,
|
||||
options=options
|
||||
)
|
||||
session.add(new_group)
|
||||
if id is None:
|
||||
return None
|
||||
|
||||
session.commit()
|
||||
session.refresh(new_group)
|
||||
async for session in get_session():
|
||||
statement = select(Group).where(Group.id == id)
|
||||
result = await session.exec(statement)
|
||||
group = result.one_or_none()
|
||||
|
||||
if group:
|
||||
return group
|
||||
else:
|
||||
return None
|
||||
@@ -113,4 +113,55 @@ async def init_default_settings() -> None:
|
||||
type=setting.type,
|
||||
name=setting.name,
|
||||
value=setting.value
|
||||
)
|
||||
)
|
||||
|
||||
async def init_default_group() -> None:
|
||||
from .group import Group
|
||||
|
||||
try:
|
||||
# 未找到初始管理组时,则创建
|
||||
if not Group.get(id=1):
|
||||
Group.add(
|
||||
name="管理员",
|
||||
max_storage=1 * 1024 * 1024 * 1024, # 1GB
|
||||
share_enabled=True,
|
||||
web_dav_enabled=True,
|
||||
options={
|
||||
"ArchiveDownload": True,
|
||||
"ArchiveTask": True,
|
||||
"ShareDownload": True,
|
||||
"Aria2": True,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法创建管理员用户组: {e}") from e
|
||||
|
||||
try:
|
||||
# 未找到初始注册会员时,则创建
|
||||
if not Group.get(id=2):
|
||||
Group.add(
|
||||
name="注册会员",
|
||||
max_storage=1 * 1024 * 1024 * 1024, # 1GB
|
||||
share_enabled=True,
|
||||
web_dav_enabled=True,
|
||||
options={
|
||||
"ShareDownload": True,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法创建初始注册会员用户组: {e}") from e
|
||||
|
||||
try:
|
||||
# 未找到初始游客组时,则创建
|
||||
if not Group.get(id=3):
|
||||
Group.add(
|
||||
name="游客",
|
||||
policies="[]",
|
||||
share_enabled=False,
|
||||
web_dav_enabled=False,
|
||||
options={
|
||||
"ShareDownload": True,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法创建初始游客用户组: {e}") from e
|
||||
|
||||
@@ -71,6 +71,7 @@ class Setting(BaseModel, table=True):
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def add(
|
||||
type: SETTINGS_TYPE = None,
|
||||
name: str = None,
|
||||
@@ -97,6 +98,7 @@ class Setting(BaseModel, table=True):
|
||||
|
||||
await session.commit()
|
||||
|
||||
@staticmethod
|
||||
async def get(
|
||||
type: SETTINGS_TYPE,
|
||||
name: str,
|
||||
@@ -138,6 +140,7 @@ class Setting(BaseModel, table=True):
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
@staticmethod
|
||||
async def set(
|
||||
type: SETTINGS_TYPE,
|
||||
name: str,
|
||||
@@ -177,6 +180,7 @@ class Setting(BaseModel, table=True):
|
||||
setting.value = value
|
||||
await session.commit()
|
||||
|
||||
@staticmethod
|
||||
async def delete(
|
||||
type: SETTINGS_TYPE,
|
||||
name: str
|
||||
|
||||
@@ -86,4 +86,25 @@ class User(BaseModel, table=True):
|
||||
storage_packs: list["StoragePack"] = Relationship(back_populates="user")
|
||||
tags: list["Tag"] = Relationship(back_populates="user")
|
||||
tasks: list["Task"] = Relationship(back_populates="user")
|
||||
webdavs: list["WebDAV"] = Relationship(back_populates="user")
|
||||
webdavs: list["WebDAV"] = Relationship(back_populates="user")
|
||||
|
||||
async def create(
|
||||
user: "User"
|
||||
):
|
||||
"""
|
||||
向数据库内添加用户。
|
||||
|
||||
:param user: User 实例
|
||||
:type user: User
|
||||
"""
|
||||
from .database import get_session
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
return user
|
||||
Reference in New Issue
Block a user