新增读取用户与其单元测试
This commit is contained in:
@@ -108,3 +108,38 @@ class User(BaseModel, table=True):
|
||||
await session.rollback()
|
||||
raise e
|
||||
return user
|
||||
|
||||
async def get(
|
||||
id: int = None,
|
||||
email: str = None
|
||||
) -> Optional["User"]:
|
||||
"""
|
||||
获取用户信息。
|
||||
|
||||
:param id: 用户ID,默认为 None
|
||||
:type id: int
|
||||
:param email: 用户邮箱,默认为 None
|
||||
:type email: str
|
||||
:return: 用户对象或 None
|
||||
:rtype: Optional[User]
|
||||
"""
|
||||
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
session = get_session()
|
||||
|
||||
if id is None and email is None:
|
||||
return None
|
||||
|
||||
async for session in get_session():
|
||||
query = select(User)
|
||||
if id is not None:
|
||||
query = query.where(User.id == id)
|
||||
if email is not None:
|
||||
query = query.where(User.email == email)
|
||||
|
||||
result = await session.exec(query)
|
||||
user = result.one_or_none()
|
||||
|
||||
return user
|
||||
@@ -27,3 +27,13 @@ async def test_user_curd():
|
||||
assert created_user.email == 'test_user'
|
||||
assert created_user.password == 'test_password'
|
||||
assert created_user.group_id == created_group.id
|
||||
|
||||
# 测试查 Read
|
||||
fetched_user = await User.get(id=created_user.id)
|
||||
|
||||
assert fetched_user is not None
|
||||
assert fetched_user.email == 'test_user'
|
||||
assert fetched_user.password == 'test_password'
|
||||
assert fetched_user.group_id == created_group.id
|
||||
|
||||
# 测试改 Update
|
||||
Reference in New Issue
Block a user