单元测试:新建用户与用户组

This commit is contained in:
2025-07-14 15:13:05 +08:00
parent e011a1ea0e
commit 557a50f539
26 changed files with 396 additions and 242 deletions

View File

@@ -49,43 +49,58 @@ class Group(BaseModel, table=True):
sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"}
)
async def add_group(self, name: str, policies: Optional[str] = None, max_storage: int = 0,
share_enabled: bool = False, web_dav_enabled: bool = False,
speed_limit: int = 0, options: Optional[str] = None) -> "Group":
@staticmethod
async def create(
group: "Group"
) -> "Group":
"""
向数据库内添加用户组。
:param name: 用户组
:type name: str
:param policies: 允许的策略ID列表逗号分隔默认为 None
:type policies: Optional[str]
:param max_storage: 最大存储空间(字节),默认为 0
:type max_storage: int
:param share_enabled: 是否允许创建分享,默认为 False
:type share_enabled: bool
:param web_dav_enabled: 是否允许使用WebDAV默认为 False
:type web_dav_enabled: bool
:param speed_limit: 速度限制 (KB/s), 0为不限制默认为 0
:type speed_limit: int
:param options: 其他选项 (JSON格式),默认为 None
:type options: Optional[str]
:param group: 用户组对象
:type group: Group
:return: 新创建的用户组对象
:rtype: Group
"""
from .database import get_session
async for session in get_session():
try:
session.add(group)
await session.commit()
await session.refresh(group)
except Exception as e:
await session.rollback()
raise e
return group
@staticmethod
async def get(
id: int = None
) -> Optional["Group"]:
"""
获取用户组信息。
:param id: 用户组ID默认为 None
:type id: int
:return: 用户组对象或 None
:rtype: Optional[Group]
"""
from .database import get_session
from sqlmodel import select
session = get_session()
new_group = Group(
name=name,
policies=policies,
max_storage=max_storage,
share_enabled=share_enabled,
web_dav_enabled=web_dav_enabled,
speed_limit=speed_limit,
options=options
)
session.add(new_group)
if id is None:
return None
session.commit()
session.refresh(new_group)
async for session in get_session():
statement = select(Group).where(Group.id == id)
result = await session.exec(statement)
group = result.one_or_none()
if group:
return group
else:
return None