Refactor code structure for improved readability and maintainability
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
# my_project/models/__init__.py
|
||||
|
||||
from . import response
|
||||
|
||||
# 将所有模型导入到这个包的命名空间中
|
||||
from .base import TableBase
|
||||
from .user import User
|
||||
|
||||
from .download import Download
|
||||
from .file import File
|
||||
from .folder import Folder
|
||||
@@ -19,12 +17,6 @@ from .source_link import SourceLink
|
||||
from .storage_pack import StoragePack
|
||||
from .tag import Tag
|
||||
from .task import Task
|
||||
from .user import User
|
||||
from .webdav import WebDAV
|
||||
|
||||
# 可以定义一个 __all__ 列表来明确指定可以被 from .models import * 导入的内容
|
||||
__all__ = [
|
||||
"TableBase", "Download", "File", "Folder", "Group", "Node", "Order",
|
||||
"Policy", "Redeem", "Report", "Setting", "Share", "SourceLink",
|
||||
"StoragePack", "Tag", "Task", "User", "WebDAV"
|
||||
]
|
||||
from .database import engine, get_session
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
from typing import Optional
|
||||
from sqlmodel import SQLModel, Field
|
||||
from sqlalchemy import DateTime
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
|
||||
utcnow = lambda: datetime.now(tz=timezone.utc)
|
||||
|
||||
class TableBase(SQLModel, AsyncAttrs):
|
||||
__abstract__ = True
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True, description="主键ID")
|
||||
created_at: datetime = Field(
|
||||
sa_type=DateTime,
|
||||
default_factory=utcnow,
|
||||
description="创建时间",
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
sa_type=DateTime,
|
||||
sa_column_kwargs={"default": utcnow, "onupdate": utcnow},
|
||||
default_factory=utcnow,
|
||||
description="更新时间",
|
||||
)
|
||||
deleted_at: Optional[datetime] = Field(
|
||||
default=None,
|
||||
nullable=True,
|
||||
description="删除时间",
|
||||
)
|
||||
2
models/base/__init__.py
Normal file
2
models/base/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .sqlmodel_base import SQLModelBase
|
||||
from .table_base import TableBase, UUIDTableBase, now, now_date
|
||||
5
models/base/sqlmodel_base.py
Normal file
5
models/base/sqlmodel_base.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from pydantic import ConfigDict
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
class SQLModelBase(SQLModel):
|
||||
model_config = ConfigDict(use_attribute_docstrings=True, validate_by_name=True)
|
||||
202
models/base/table_base.py
Normal file
202
models/base/table_base.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Union, List, TypeVar, Type, Literal, override, Optional, Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import DateTime, BinaryExpression, ClauseElement
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import Field, select, Relationship
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.sql._typing import _OnClauseArgument
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
|
||||
from .sqlmodel_base import SQLModelBase
|
||||
|
||||
T = TypeVar("T", bound="TableBase")
|
||||
M = TypeVar("M", bound="SQLModel")
|
||||
|
||||
now = lambda: datetime.now()
|
||||
now_date = lambda: datetime.now().date()
|
||||
|
||||
class TableBase(SQLModelBase, AsyncAttrs):
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
|
||||
created_at: datetime = Field(default_factory=now)
|
||||
updated_at: datetime = Field(
|
||||
sa_type=DateTime,
|
||||
sa_column_kwargs={"default": now, "onupdate": now},
|
||||
default_factory=now
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def add(cls: Type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | List[T]:
|
||||
"""
|
||||
新增一条记录
|
||||
:param session: 数据库会话
|
||||
:param instances:
|
||||
:param refresh:
|
||||
:return: 新增的实例对象
|
||||
|
||||
usage:
|
||||
item1 = Item(...)
|
||||
item2 = Item(...)
|
||||
|
||||
Item.add(session, [item1, item2])
|
||||
|
||||
item1_id = item1.id
|
||||
"""
|
||||
is_list = False
|
||||
if isinstance(instances, list):
|
||||
is_list = True
|
||||
session.add_all(instances)
|
||||
else:
|
||||
session.add(instances)
|
||||
|
||||
await session.commit()
|
||||
|
||||
if refresh:
|
||||
if is_list:
|
||||
for instance in instances:
|
||||
await session.refresh(instance)
|
||||
else:
|
||||
await session.refresh(instances)
|
||||
|
||||
return instances
|
||||
|
||||
async def save(self: T, session: AsyncSession, load: Optional[Relationship] = None) -> T:
|
||||
session.add(self)
|
||||
await session.commit()
|
||||
|
||||
if load is not None:
|
||||
cls = type(self)
|
||||
return await cls.get(session, cls.id == self.id, load=load)
|
||||
else:
|
||||
await session.refresh(self)
|
||||
return self
|
||||
|
||||
async def update(
|
||||
self: T,
|
||||
session: AsyncSession,
|
||||
other: M,
|
||||
extra_data: dict = None,
|
||||
exclude_unset: bool = True
|
||||
) -> T:
|
||||
"""
|
||||
更新记录
|
||||
:param session: 数据库会话
|
||||
:param other:
|
||||
:param extra_data:
|
||||
:param exclude_unset:
|
||||
:return:
|
||||
"""
|
||||
self.sqlmodel_update(other.model_dump(exclude_unset=exclude_unset), update=extra_data)
|
||||
|
||||
session.add(self)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(self)
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
async def delete(cls: Type[T], session: AsyncSession, instances: T | list[T]) -> None:
|
||||
"""
|
||||
删除一些记录
|
||||
:param session: 数据库会话
|
||||
:param instances:
|
||||
:return: None
|
||||
|
||||
usage:
|
||||
item1 = Item.get(...)
|
||||
item2 = Item.get(...)
|
||||
|
||||
Item.delete(session, [item1, item2])
|
||||
|
||||
"""
|
||||
if isinstance(instances, list):
|
||||
for instance in instances:
|
||||
await session.delete(instance)
|
||||
else:
|
||||
await session.delete(instances)
|
||||
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls: Type[T],
|
||||
session: AsyncSession,
|
||||
condition: BinaryExpression | ClauseElement | None,
|
||||
*,
|
||||
offset: int | None = None,
|
||||
limit: int | None = None,
|
||||
fetch_mode: Literal["one", "first", "all"] = "first",
|
||||
join: Type[T] | tuple[Type[T], _OnClauseArgument] | None = None,
|
||||
options: list | None = None,
|
||||
load: Union[Relationship, None] = None,
|
||||
order_by: list[ClauseElement] | None = None
|
||||
) -> T | List[T] | None:
|
||||
"""
|
||||
异步获取模型实例
|
||||
|
||||
参数:
|
||||
session: 异步数据库会话
|
||||
condition: SQLAlchemy查询条件,如Model.id == 1
|
||||
offset: 结果偏移量
|
||||
limit: 结果数量限制
|
||||
options: 查询选项,如selectinload(Model.relation),异步访问关系属性必备,不然会报错
|
||||
fetch_mode: 获取模式 - "one"/"all"/"first"
|
||||
join: 要联接的模型类
|
||||
|
||||
返回:
|
||||
根据fetch_mode返回相应的查询结果
|
||||
"""
|
||||
statement = select(cls)
|
||||
|
||||
if condition is not None:
|
||||
statement = statement.where(condition)
|
||||
|
||||
if join is not None:
|
||||
statement = statement.join(*join)
|
||||
|
||||
if options:
|
||||
statement = statement.options(*options)
|
||||
|
||||
if load:
|
||||
statement = statement.options(selectinload(load))
|
||||
|
||||
if order_by is not None:
|
||||
statement = statement.order_by(*order_by)
|
||||
|
||||
if offset:
|
||||
statement = statement.offset(offset)
|
||||
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
result = await session.exec(statement)
|
||||
|
||||
if fetch_mode == "one":
|
||||
return result.one()
|
||||
elif fetch_mode == "first":
|
||||
return result.first()
|
||||
elif fetch_mode == "all":
|
||||
return list(result.all())
|
||||
else:
|
||||
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
|
||||
|
||||
@classmethod
|
||||
async def get_exist_one(cls: Type[T], session: AsyncSession, id: int, load: Union[Relationship, None] = None) -> T:
|
||||
"""此方法和 await session.get(cls, 主键)的区别就是当不存在时不返回None,
|
||||
而是会抛出fastapi 404 异常"""
|
||||
instance = await cls.get(session, cls.id == id, load=load)
|
||||
if not instance:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
return instance
|
||||
|
||||
class UUIDTableBase(TableBase):
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
"""override"""
|
||||
|
||||
@override
|
||||
async def get_exist_one(cls: Type[T], session: AsyncSession, id: uuid.UUID, load: Union[Relationship, None] = None) -> T:
|
||||
return super().get_exist_one(session, id, load) # type: ignore
|
||||
@@ -29,4 +29,5 @@ async def init_db(
|
||||
):
|
||||
"""创建数据库结构"""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
# my_project/models/download.py
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
from .base import TableBase
|
||||
from .base import SQLModelBase, UUIDTableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
from .task import Task
|
||||
from .node import Node
|
||||
|
||||
class Download(TableBase, table=True):
|
||||
class DownloadBase(SQLModelBase):
|
||||
pass
|
||||
|
||||
class Download(DownloadBase, UUIDTableBase, table=True):
|
||||
__tablename__ = 'downloads'
|
||||
__table_args__ = (
|
||||
UniqueConstraint("node_id", "g_id", name="uq_download_node_gid"),
|
||||
@@ -36,4 +37,6 @@ class Download(TableBase, table=True):
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="downloads")
|
||||
task: Optional["Task"] = Relationship(back_populates="downloads")
|
||||
node: "Node" = Relationship(back_populates="downloads")
|
||||
node: "Node" = Relationship(back_populates="downloads")
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/file.py
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/folder.py
|
||||
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/group.py
|
||||
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, text, Column, JSON
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
|
||||
from .setting import Setting
|
||||
from pkg.conf.appmeta import BackendVersion
|
||||
from .response import ThemeModel
|
||||
from pkg.password.pwd import Password
|
||||
from pkg.log import log
|
||||
from loguru import logger as log
|
||||
|
||||
async def migration() -> None:
|
||||
"""
|
||||
@@ -188,37 +189,39 @@ async def init_default_group() -> None:
|
||||
raise RuntimeError(f"无法创建初始游客用户组: {e}")
|
||||
|
||||
async def init_default_user() -> None:
|
||||
|
||||
|
||||
log.info('初始化管理员用户...')
|
||||
|
||||
|
||||
from .user import User
|
||||
from .group import Group
|
||||
|
||||
# 检查管理员用户是否存在
|
||||
admin_user = await User.get(id=1)
|
||||
|
||||
if not admin_user:
|
||||
# 创建初始管理员用户
|
||||
|
||||
# 获取管理员组
|
||||
admin_group = await Group.get(id=1)
|
||||
if not admin_group:
|
||||
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
|
||||
|
||||
# 生成管理员密码
|
||||
from pkg.password.pwd import Password
|
||||
admin_password = Password.generate(8)
|
||||
hashed_admin_password = Password.hash(admin_password)
|
||||
|
||||
admin_user = User(
|
||||
email="admin@yxqi.cn",
|
||||
nick="admin",
|
||||
status=True, # 正常状态
|
||||
group_id=admin_group.id,
|
||||
password=hashed_admin_password,
|
||||
from .database import get_session
|
||||
|
||||
async for session in get_session():
|
||||
# 检查管理员用户是否存在
|
||||
admin_user = await User.get(session, User.id == 1)
|
||||
|
||||
if not admin_user:
|
||||
# 创建初始管理员用户
|
||||
|
||||
# 获取管理员组
|
||||
admin_group = await Group.get(id=1)
|
||||
if not admin_group:
|
||||
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
|
||||
|
||||
# 生成管理员密码
|
||||
from pkg.password.pwd import Password
|
||||
admin_password = Password.generate(8)
|
||||
hashed_admin_password = Password.hash(admin_password)
|
||||
|
||||
admin_user = User(
|
||||
email="admin@yxqi.cn",
|
||||
nick="admin",
|
||||
status=True, # 正常状态
|
||||
group_id=admin_group.id,
|
||||
password=hashed_admin_password,
|
||||
)
|
||||
|
||||
admin_user = await User.create(admin_user)
|
||||
|
||||
log.info(f'初始管理员账号:[bold]admin@yxqi.cn[/bold]')
|
||||
log.info(f'初始管理员密码:[bold]{admin_password}[/bold]')
|
||||
|
||||
admin_user = await admin_user.save(session)
|
||||
|
||||
log.info(f'初始管理员账号:[bold]admin@yxqi.cn[/bold]')
|
||||
log.info(f'初始管理员密码:[bold]{admin_password}[/bold]')
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/node.py
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, text, Column, func, DateTime
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/order.py
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, Column, func, DateTime
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/policy.py
|
||||
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, text, Column, func, DateTime
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/redeem.py
|
||||
|
||||
from typing import Optional
|
||||
from sqlmodel import Field, text, Column, func, DateTime
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/report.py
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, Column, func, DateTime
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/setting.py
|
||||
|
||||
from typing import Optional, Literal
|
||||
from sqlmodel import Field, UniqueConstraint, Column, func, DateTime
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/share.py
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/source_link.py
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from sqlmodel import Field, Relationship, Index
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/storage_pack.py
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/tag.py
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/task.py
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, CheckConstraint
|
||||
@@ -15,7 +14,6 @@ class Task(TableBase, table=True):
|
||||
CheckConstraint("progress BETWEEN 0 AND 100", name="ck_task_progress_range"),
|
||||
)
|
||||
|
||||
|
||||
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务状态: 0=排队中, 1=处理中, 2=完成, 3=错误")
|
||||
type: int = Field(description="任务类型")
|
||||
progress: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务进度 (0-100)")
|
||||
|
||||
225
models/user.py
225
models/user.py
@@ -1,47 +1,58 @@
|
||||
# my_project/models/user.py
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
from .base import TableBase
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
# TYPE_CHECKING 用于解决循环导入问题,只在类型检查时导入
|
||||
if TYPE_CHECKING:
|
||||
from .group import Group
|
||||
from .download import Download
|
||||
from .file import File
|
||||
from .folder import Folder
|
||||
from .order import Order
|
||||
from .share import Share
|
||||
from .storage_pack import StoragePack
|
||||
from .tag import Tag
|
||||
from .task import Task
|
||||
from .webdav import WebDAV
|
||||
from .group import Group
|
||||
from .download import Download
|
||||
from .file import File
|
||||
from .folder import Folder
|
||||
from .order import Order
|
||||
from .share import Share
|
||||
from .storage_pack import StoragePack
|
||||
from .tag import Tag
|
||||
from .task import Task
|
||||
from .webdav import WebDAV
|
||||
|
||||
class User(TableBase, table=True):
|
||||
__tablename__ = 'users'
|
||||
|
||||
email: str = Field(max_length=100, unique=True, index=True, description="用户邮箱,唯一")
|
||||
phone: str = Field(default=None, nullable=True, index=True, description="用户手机号,唯一")
|
||||
email: str = Field(max_length=100, unique=True, index=True)
|
||||
"""用户邮箱,唯一"""
|
||||
|
||||
nick: Optional[str] = Field(default=None, max_length=50, description="用户昵称")
|
||||
password: str = Field(max_length=255, description="用户密码(加密后)")
|
||||
status: Optional[bool] = Field(default=None, sa_column_kwargs={"server_default": "0"}, description="用户状态: True=正常, None=未激活, False=封禁")
|
||||
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="已用存储空间(字节)")
|
||||
two_factor: Optional[str] = Field(default=None, max_length=255, description="两步验证密钥")
|
||||
avatar: Optional[str] = Field(default=None, max_length=255, description="头像地址")
|
||||
options: Optional[str] = Field(default=None, description="用户个人设置 (JSON格式)")
|
||||
authn: Optional[str] = Field(default=None, description="WebAuthn 凭证")
|
||||
open_id: Optional[str] = Field(default=None, max_length=255, unique=True, index=True, description="第三方登录OpenID")
|
||||
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="用户积分")
|
||||
group_expires: Optional[datetime] = Field(default=None, description="当前用户组过期时间")
|
||||
phone: Optional[str] = Field(default=None, max_length=255, unique=True, index=True, description="手机号")
|
||||
phone: str | None = Field(default=None, nullable=True, index=True)
|
||||
"""用户手机号,唯一"""
|
||||
|
||||
nick: str | None = Field(default=None, max_length=50)
|
||||
"""用户昵称"""
|
||||
password: str = Field(max_length=255)
|
||||
"""用户密码(加密后)"""
|
||||
status: bool | None = Field(default=None, sa_column_kwargs={"server_default": "0"})
|
||||
"""用户状态: True=正常, None=未激活, False=封禁"""
|
||||
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""已用存储空间(字节)"""
|
||||
two_factor: str | None = Field(default=None, max_length=255)
|
||||
"""两步验证密钥"""
|
||||
avatar: str | None = Field(default=None, max_length=255)
|
||||
"""头像地址"""
|
||||
options: str | None = Field(default=None)
|
||||
"""用户个人设置 (JSON格式)"""
|
||||
authn: str | None = Field(default=None)
|
||||
"""WebAuthn 凭证"""
|
||||
open_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
|
||||
"""第三方登录OpenID"""
|
||||
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""用户积分"""
|
||||
group_expires: datetime | None = Field(default=None)
|
||||
"""当前用户组过期时间"""
|
||||
phone: str | None = Field(default=None, max_length=255, unique=True, index=True)
|
||||
"""手机号"""
|
||||
|
||||
# 外键
|
||||
group_id: int = Field(foreign_key="groups.id", index=True, description="所属用户组ID")
|
||||
previous_group_id: Optional[int] = Field(default=None, foreign_key="groups.id", description="之前的用户组ID(用于过期后恢复)")
|
||||
group_id: int = Field(foreign_key="groups.id", index=True)
|
||||
"""所属用户组ID"""
|
||||
previous_group_id: int | None = Field(default=None, foreign_key="groups.id")
|
||||
"""之前的用户组ID(用于过期后恢复)"""
|
||||
|
||||
# 关系
|
||||
group: "Group" = Relationship(
|
||||
@@ -66,152 +77,4 @@ class User(TableBase, table=True):
|
||||
tags: list["Tag"] = Relationship(back_populates="user")
|
||||
tasks: list["Task"] = Relationship(back_populates="user")
|
||||
webdavs: list["WebDAV"] = Relationship(back_populates="user")
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
user: Optional["User"] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
向数据库内添加用户。
|
||||
|
||||
:param user: User 实例
|
||||
:type user: User
|
||||
"""
|
||||
if not user:
|
||||
user = User(**kwargs)
|
||||
|
||||
from .database import get_session
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def get(
|
||||
id: int = None,
|
||||
email: str = None
|
||||
) -> Optional["User"]:
|
||||
"""
|
||||
获取用户信息。
|
||||
|
||||
:param id: 用户ID,默认为 None
|
||||
:type id: int
|
||||
:param email: 用户邮箱,默认为 None
|
||||
:type email: str
|
||||
:return: 用户对象或 None
|
||||
:rtype: Optional[User]
|
||||
"""
|
||||
session = get_session()
|
||||
|
||||
if id is None and email is None:
|
||||
return None
|
||||
|
||||
async for session in get_session():
|
||||
query = select(User)
|
||||
if id is not None:
|
||||
query = query.where(User.id == id)
|
||||
if email is not None:
|
||||
query = query.where(User.email == email)
|
||||
|
||||
result = await session.exec(query)
|
||||
user = result.one_or_none()
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def update(
|
||||
id: int,
|
||||
email: Optional[str] = None,
|
||||
nick: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
status: Optional[int] = None,
|
||||
storage: Optional[int] = None,
|
||||
two_factor: Optional[str] = None,
|
||||
avatar: Optional[str] = None,
|
||||
options: Optional[str] = None,
|
||||
authn: Optional[str] = None,
|
||||
open_id: Optional[str] = None,
|
||||
score: Optional[int] = None,
|
||||
group_id: Optional[int] = None
|
||||
) -> "User":
|
||||
"""
|
||||
更新用户信息。
|
||||
|
||||
:return: 更新后的用户对象
|
||||
:rtype: User
|
||||
"""
|
||||
async for session in get_session():
|
||||
try:
|
||||
statement = select(User).where(User.id == id)
|
||||
result = await session.exec(statement)
|
||||
user = result.first()
|
||||
|
||||
if user is None:
|
||||
raise ValueError(f"User with id {id} not found.")
|
||||
|
||||
if email is not None:
|
||||
user.email = email
|
||||
if nick is not None:
|
||||
user.nick = nick
|
||||
if password is not None:
|
||||
user.password = password
|
||||
if status is not None:
|
||||
user.status = status
|
||||
if storage is not None:
|
||||
user.storage = storage
|
||||
if two_factor is not None:
|
||||
user.two_factor = two_factor
|
||||
if avatar is not None:
|
||||
user.avatar = avatar
|
||||
if options is not None:
|
||||
user.options = options
|
||||
if authn is not None:
|
||||
user.authn = authn
|
||||
if open_id is not None:
|
||||
user.open_id = open_id
|
||||
if score is not None:
|
||||
user.score = score
|
||||
if group_id is not None:
|
||||
user.group_id = group_id
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
async def delete(
|
||||
id: int
|
||||
) -> None:
|
||||
"""
|
||||
删除用户。
|
||||
|
||||
:param id: 用户ID
|
||||
:type id: int
|
||||
"""
|
||||
if id == 1:
|
||||
raise ValueError("Cannot delete the default admin user with id 1.")
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
statement = select(User).where(User.id == id)
|
||||
result = await session.exec(statement)
|
||||
user = result.one_or_none()
|
||||
|
||||
if user is None:
|
||||
raise ValueError(f"User with id {id} not found.")
|
||||
|
||||
await session.delete(user)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# my_project/models/webdav.py
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime
|
||||
|
||||
Reference in New Issue
Block a user