新增读取用户与其单元测试
This commit is contained in:
@@ -108,3 +108,38 @@ class User(BaseModel, table=True):
|
|||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise e
|
raise e
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
async def get(
|
||||||
|
id: int = None,
|
||||||
|
email: str = None
|
||||||
|
) -> Optional["User"]:
|
||||||
|
"""
|
||||||
|
获取用户信息。
|
||||||
|
|
||||||
|
:param id: 用户ID,默认为 None
|
||||||
|
:type id: int
|
||||||
|
:param email: 用户邮箱,默认为 None
|
||||||
|
:type email: str
|
||||||
|
:return: 用户对象或 None
|
||||||
|
:rtype: Optional[User]
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .database import get_session
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
session = get_session()
|
||||||
|
|
||||||
|
if id is None and email is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async for session in get_session():
|
||||||
|
query = select(User)
|
||||||
|
if id is not None:
|
||||||
|
query = query.where(User.id == id)
|
||||||
|
if email is not None:
|
||||||
|
query = query.where(User.email == email)
|
||||||
|
|
||||||
|
result = await session.exec(query)
|
||||||
|
user = result.one_or_none()
|
||||||
|
|
||||||
|
return user
|
||||||
@@ -27,3 +27,13 @@ async def test_user_curd():
|
|||||||
assert created_user.email == 'test_user'
|
assert created_user.email == 'test_user'
|
||||||
assert created_user.password == 'test_password'
|
assert created_user.password == 'test_password'
|
||||||
assert created_user.group_id == created_group.id
|
assert created_user.group_id == created_group.id
|
||||||
|
|
||||||
|
# 测试查 Read
|
||||||
|
fetched_user = await User.get(id=created_user.id)
|
||||||
|
|
||||||
|
assert fetched_user is not None
|
||||||
|
assert fetched_user.email == 'test_user'
|
||||||
|
assert fetched_user.password == 'test_password'
|
||||||
|
assert fetched_user.group_id == created_group.id
|
||||||
|
|
||||||
|
# 测试改 Update
|
||||||
Reference in New Issue
Block a user