diff --git a/models/user.py b/models/user.py index ec928f9..4a571f1 100644 --- a/models/user.py +++ b/models/user.py @@ -107,4 +107,39 @@ class User(BaseModel, table=True): except Exception as e: await session.rollback() raise e + return user + + async def get( + id: int = None, + email: str = None + ) -> Optional["User"]: + """ + 获取用户信息。 + + :param id: 用户ID,默认为 None + :type id: int + :param email: 用户邮箱,默认为 None + :type email: str + :return: 用户对象或 None + :rtype: Optional[User] + """ + + from .database import get_session + from sqlmodel import select + + session = get_session() + + if id is None and email is None: + return None + + async for session in get_session(): + query = select(User) + if id is not None: + query = query.where(User.id == id) + if email is not None: + query = query.where(User.email == email) + + result = await session.exec(query) + user = result.one_or_none() + return user \ No newline at end of file diff --git a/tests/test_db_user.py b/tests/test_db_user.py index 698d32d..6455016 100644 --- a/tests/test_db_user.py +++ b/tests/test_db_user.py @@ -26,4 +26,14 @@ async def test_user_curd(): assert created_user.id is not None assert created_user.email == 'test_user' assert created_user.password == 'test_password' - assert created_user.group_id == created_group.id \ No newline at end of file + assert created_user.group_id == created_group.id + + # 测试查 Read + fetched_user = await User.get(id=created_user.id) + + assert fetched_user is not None + assert fetched_user.email == 'test_user' + assert fetched_user.password == 'test_password' + assert fetched_user.group_id == created_group.id + + # 测试改 Update \ No newline at end of file