Refactor code structure for improved readability and maintainability

This commit is contained in:
2025-11-27 20:56:48 +08:00
parent 83276c8b95
commit 1533d9e89c
42 changed files with 5282 additions and 330 deletions

View File

@@ -1,9 +1,7 @@
# my_project/models/__init__.py
from . import response
# 将所有模型导入到这个包的命名空间中
from .base import TableBase
from .user import User
from .download import Download
from .file import File
from .folder import Folder
@@ -19,12 +17,6 @@ from .source_link import SourceLink
from .storage_pack import StoragePack
from .tag import Tag
from .task import Task
from .user import User
from .webdav import WebDAV
# 可以定义一个 __all__ 列表来明确指定可以被 from .models import * 导入的内容
__all__ = [
"TableBase", "Download", "File", "Folder", "Group", "Node", "Order",
"Policy", "Redeem", "Report", "Setting", "Share", "SourceLink",
"StoragePack", "Tag", "Task", "User", "WebDAV"
]
from .database import engine, get_session

View File

@@ -1,28 +0,0 @@
from typing import Optional
from sqlmodel import SQLModel, Field
from sqlalchemy import DateTime
from datetime import datetime, timezone
from sqlalchemy.ext.asyncio import AsyncAttrs
utcnow = lambda: datetime.now(tz=timezone.utc)
class TableBase(SQLModel, AsyncAttrs):
__abstract__ = True
id: Optional[int] = Field(default=None, primary_key=True, description="主键ID")
created_at: datetime = Field(
sa_type=DateTime,
default_factory=utcnow,
description="创建时间",
)
updated_at: datetime = Field(
sa_type=DateTime,
sa_column_kwargs={"default": utcnow, "onupdate": utcnow},
default_factory=utcnow,
description="更新时间",
)
deleted_at: Optional[datetime] = Field(
default=None,
nullable=True,
description="删除时间",
)

2
models/base/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from .sqlmodel_base import SQLModelBase
from .table_base import TableBase, UUIDTableBase, now, now_date

View File

@@ -0,0 +1,5 @@
from pydantic import ConfigDict
from sqlmodel import SQLModel
class SQLModelBase(SQLModel):
model_config = ConfigDict(use_attribute_docstrings=True, validate_by_name=True)

202
models/base/table_base.py Normal file
View File

@@ -0,0 +1,202 @@
import uuid
from datetime import datetime, timezone
from typing import Union, List, TypeVar, Type, Literal, override, Optional, Any
from fastapi import HTTPException
from sqlalchemy import DateTime, BinaryExpression, ClauseElement
from sqlalchemy.orm import selectinload
from sqlmodel import Field, select, Relationship
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.sql._typing import _OnClauseArgument
from sqlalchemy.ext.asyncio import AsyncAttrs
from .sqlmodel_base import SQLModelBase
T = TypeVar("T", bound="TableBase")
M = TypeVar("M", bound="SQLModel")
now = lambda: datetime.now()
now_date = lambda: datetime.now().date()
class TableBase(SQLModelBase, AsyncAttrs):
id: int | None = Field(default=None, primary_key=True)
created_at: datetime = Field(default_factory=now)
updated_at: datetime = Field(
sa_type=DateTime,
sa_column_kwargs={"default": now, "onupdate": now},
default_factory=now
)
@classmethod
async def add(cls: Type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | List[T]:
"""
新增一条记录
:param session: 数据库会话
:param instances:
:param refresh:
:return: 新增的实例对象
usage:
item1 = Item(...)
item2 = Item(...)
Item.add(session, [item1, item2])
item1_id = item1.id
"""
is_list = False
if isinstance(instances, list):
is_list = True
session.add_all(instances)
else:
session.add(instances)
await session.commit()
if refresh:
if is_list:
for instance in instances:
await session.refresh(instance)
else:
await session.refresh(instances)
return instances
async def save(self: T, session: AsyncSession, load: Optional[Relationship] = None) -> T:
session.add(self)
await session.commit()
if load is not None:
cls = type(self)
return await cls.get(session, cls.id == self.id, load=load)
else:
await session.refresh(self)
return self
async def update(
self: T,
session: AsyncSession,
other: M,
extra_data: dict = None,
exclude_unset: bool = True
) -> T:
"""
更新记录
:param session: 数据库会话
:param other:
:param extra_data:
:param exclude_unset:
:return:
"""
self.sqlmodel_update(other.model_dump(exclude_unset=exclude_unset), update=extra_data)
session.add(self)
await session.commit()
await session.refresh(self)
return self
@classmethod
async def delete(cls: Type[T], session: AsyncSession, instances: T | list[T]) -> None:
"""
删除一些记录
:param session: 数据库会话
:param instances:
:return: None
usage:
item1 = Item.get(...)
item2 = Item.get(...)
Item.delete(session, [item1, item2])
"""
if isinstance(instances, list):
for instance in instances:
await session.delete(instance)
else:
await session.delete(instances)
await session.commit()
@classmethod
async def get(
cls: Type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None,
*,
offset: int | None = None,
limit: int | None = None,
fetch_mode: Literal["one", "first", "all"] = "first",
join: Type[T] | tuple[Type[T], _OnClauseArgument] | None = None,
options: list | None = None,
load: Union[Relationship, None] = None,
order_by: list[ClauseElement] | None = None
) -> T | List[T] | None:
"""
异步获取模型实例
参数:
session: 异步数据库会话
condition: SQLAlchemy查询条件如Model.id == 1
offset: 结果偏移量
limit: 结果数量限制
options: 查询选项如selectinload(Model.relation),异步访问关系属性必备,不然会报错
fetch_mode: 获取模式 - "one"/"all"/"first"
join: 要联接的模型类
返回:
根据fetch_mode返回相应的查询结果
"""
statement = select(cls)
if condition is not None:
statement = statement.where(condition)
if join is not None:
statement = statement.join(*join)
if options:
statement = statement.options(*options)
if load:
statement = statement.options(selectinload(load))
if order_by is not None:
statement = statement.order_by(*order_by)
if offset:
statement = statement.offset(offset)
if limit:
statement = statement.limit(limit)
result = await session.exec(statement)
if fetch_mode == "one":
return result.one()
elif fetch_mode == "first":
return result.first()
elif fetch_mode == "all":
return list(result.all())
else:
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
@classmethod
async def get_exist_one(cls: Type[T], session: AsyncSession, id: int, load: Union[Relationship, None] = None) -> T:
"""此方法和 await session.get(cls, 主键)的区别就是当不存在时不返回None
而是会抛出fastapi 404 异常"""
instance = await cls.get(session, cls.id == id, load=load)
if not instance:
raise HTTPException(status_code=404, detail="Not found")
return instance
class UUIDTableBase(TableBase):
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
"""override"""
@override
async def get_exist_one(cls: Type[T], session: AsyncSession, id: uuid.UUID, load: Union[Relationship, None] = None) -> T:
return super().get_exist_one(session, id, load) # type: ignore

View File

@@ -29,4 +29,5 @@ async def init_db(
):
"""创建数据库结构"""
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
await conn.run_sync(SQLModel.metadata.create_all)

View File

@@ -1,15 +1,16 @@
# my_project/models/download.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint
from .base import TableBase
from .base import SQLModelBase, UUIDTableBase
if TYPE_CHECKING:
from .user import User
from .task import Task
from .node import Node
class Download(TableBase, table=True):
class DownloadBase(SQLModelBase):
pass
class Download(DownloadBase, UUIDTableBase, table=True):
__tablename__ = 'downloads'
__table_args__ = (
UniqueConstraint("node_id", "g_id", name="uq_download_node_gid"),
@@ -36,4 +37,6 @@ class Download(TableBase, table=True):
# 关系
user: "User" = Relationship(back_populates="downloads")
task: Optional["Task"] = Relationship(back_populates="downloads")
node: "Node" = Relationship(back_populates="downloads")
node: "Node" = Relationship(back_populates="downloads")

View File

@@ -1,4 +1,3 @@
# my_project/models/file.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index

View File

@@ -1,4 +1,3 @@
# my_project/models/folder.py
from typing import Optional, List, TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint

View File

@@ -1,4 +1,3 @@
# my_project/models/group.py
from typing import Optional, List, TYPE_CHECKING
from sqlmodel import Field, Relationship, text, Column, JSON

View File

@@ -1,8 +1,9 @@
from .setting import Setting
from pkg.conf.appmeta import BackendVersion
from .response import ThemeModel
from pkg.password.pwd import Password
from pkg.log import log
from loguru import logger as log
async def migration() -> None:
"""
@@ -188,37 +189,39 @@ async def init_default_group() -> None:
raise RuntimeError(f"无法创建初始游客用户组: {e}")
async def init_default_user() -> None:
log.info('初始化管理员用户...')
from .user import User
from .group import Group
# 检查管理员用户是否存在
admin_user = await User.get(id=1)
if not admin_user:
# 创建初始管理员用户
# 获取管理员
admin_group = await Group.get(id=1)
if not admin_group:
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
# 生成管理员密码
from pkg.password.pwd import Password
admin_password = Password.generate(8)
hashed_admin_password = Password.hash(admin_password)
admin_user = User(
email="admin@yxqi.cn",
nick="admin",
status=True, # 正常状态
group_id=admin_group.id,
password=hashed_admin_password,
from .database import get_session
async for session in get_session():
# 检查管理员用户是否存在
admin_user = await User.get(session, User.id == 1)
if not admin_user:
# 创建初始管理员用户
# 获取管理员组
admin_group = await Group.get(id=1)
if not admin_group:
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
# 生成管理员密码
from pkg.password.pwd import Password
admin_password = Password.generate(8)
hashed_admin_password = Password.hash(admin_password)
admin_user = User(
email="admin@yxqi.cn",
nick="admin",
status=True, # 正常状态
group_id=admin_group.id,
password=hashed_admin_password,
)
admin_user = await User.create(admin_user)
log.info(f'初始管理员账号:[bold]admin@yxqi.cn[/bold]')
log.info(f'初始管理员密码:[bold]{admin_password}[/bold]')
admin_user = await admin_user.save(session)
log.info(f'初始管理员账号:[bold]admin@yxqi.cn[/bold]')
log.info(f'初始管理员密码:[bold]{admin_password}[/bold]')

View File

@@ -1,4 +1,3 @@
# my_project/models/node.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, text, Column, func, DateTime

View File

@@ -1,4 +1,3 @@
# my_project/models/order.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, Column, func, DateTime

View File

@@ -1,4 +1,3 @@
# my_project/models/policy.py
from typing import Optional, List, TYPE_CHECKING
from sqlmodel import Field, Relationship, text, Column, func, DateTime

View File

@@ -1,4 +1,3 @@
# my_project/models/redeem.py
from typing import Optional
from sqlmodel import Field, text, Column, func, DateTime

View File

@@ -1,4 +1,3 @@
# my_project/models/report.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, Column, func, DateTime

View File

@@ -1,4 +1,3 @@
# my_project/models/setting.py
from typing import Optional, Literal
from sqlmodel import Field, UniqueConstraint, Column, func, DateTime

View File

@@ -1,4 +1,3 @@
# my_project/models/share.py
from typing import Optional, TYPE_CHECKING
from datetime import datetime

View File

@@ -1,4 +1,3 @@
# my_project/models/source_link.py
from typing import TYPE_CHECKING, Optional
from sqlmodel import Field, Relationship, Index

View File

@@ -1,4 +1,3 @@
# my_project/models/storage_pack.py
from typing import Optional, TYPE_CHECKING
from datetime import datetime

View File

@@ -1,4 +1,3 @@
# my_project/models/tag.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime

View File

@@ -1,4 +1,3 @@
# my_project/models/task.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, CheckConstraint
@@ -15,7 +14,6 @@ class Task(TableBase, table=True):
CheckConstraint("progress BETWEEN 0 AND 100", name="ck_task_progress_range"),
)
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务状态: 0=排队中, 1=处理中, 2=完成, 3=错误")
type: int = Field(description="任务类型")
progress: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务进度 (0-100)")

View File

@@ -1,47 +1,58 @@
# my_project/models/user.py
from typing import Optional, TYPE_CHECKING
from datetime import datetime
from sqlmodel import Field, Relationship, UniqueConstraint
from .base import TableBase
from .database import get_session
from sqlmodel import select
# TYPE_CHECKING 用于解决循环导入问题,只在类型检查时导入
if TYPE_CHECKING:
from .group import Group
from .download import Download
from .file import File
from .folder import Folder
from .order import Order
from .share import Share
from .storage_pack import StoragePack
from .tag import Tag
from .task import Task
from .webdav import WebDAV
from .group import Group
from .download import Download
from .file import File
from .folder import Folder
from .order import Order
from .share import Share
from .storage_pack import StoragePack
from .tag import Tag
from .task import Task
from .webdav import WebDAV
class User(TableBase, table=True):
__tablename__ = 'users'
email: str = Field(max_length=100, unique=True, index=True, description="用户邮箱,唯一")
phone: str = Field(default=None, nullable=True, index=True, description="用户手机号,唯一")
email: str = Field(max_length=100, unique=True, index=True)
"""用户邮箱,唯一"""
nick: Optional[str] = Field(default=None, max_length=50, description="用户昵称")
password: str = Field(max_length=255, description="用户密码(加密后)")
status: Optional[bool] = Field(default=None, sa_column_kwargs={"server_default": "0"}, description="用户状态: True=正常, None=未激活, False=封禁")
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="已用存储空间(字节)")
two_factor: Optional[str] = Field(default=None, max_length=255, description="两步验证密钥")
avatar: Optional[str] = Field(default=None, max_length=255, description="头像地址")
options: Optional[str] = Field(default=None, description="用户个人设置 (JSON格式)")
authn: Optional[str] = Field(default=None, description="WebAuthn 凭证")
open_id: Optional[str] = Field(default=None, max_length=255, unique=True, index=True, description="第三方登录OpenID")
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="用户积分")
group_expires: Optional[datetime] = Field(default=None, description="当前用户组过期时间")
phone: Optional[str] = Field(default=None, max_length=255, unique=True, index=True, description="手机号")
phone: str | None = Field(default=None, nullable=True, index=True)
"""用户手机号,唯一"""
nick: str | None = Field(default=None, max_length=50)
"""用户昵称"""
password: str = Field(max_length=255)
"""用户密码(加密后)"""
status: bool | None = Field(default=None, sa_column_kwargs={"server_default": "0"})
"""用户状态: True=正常, None=未激活, False=封禁"""
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""已用存储空间(字节)"""
two_factor: str | None = Field(default=None, max_length=255)
"""两步验证密钥"""
avatar: str | None = Field(default=None, max_length=255)
"""头像地址"""
options: str | None = Field(default=None)
"""用户个人设置 (JSON格式)"""
authn: str | None = Field(default=None)
"""WebAuthn 凭证"""
open_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
"""第三方登录OpenID"""
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""用户积分"""
group_expires: datetime | None = Field(default=None)
"""当前用户组过期时间"""
phone: str | None = Field(default=None, max_length=255, unique=True, index=True)
"""手机号"""
# 外键
group_id: int = Field(foreign_key="groups.id", index=True, description="所属用户组ID")
previous_group_id: Optional[int] = Field(default=None, foreign_key="groups.id", description="之前的用户组ID用于过期后恢复")
group_id: int = Field(foreign_key="groups.id", index=True)
"""所属用户组ID"""
previous_group_id: int | None = Field(default=None, foreign_key="groups.id")
"""之前的用户组ID用于过期后恢复"""
# 关系
group: "Group" = Relationship(
@@ -66,152 +77,4 @@ class User(TableBase, table=True):
tags: list["Tag"] = Relationship(back_populates="user")
tasks: list["Task"] = Relationship(back_populates="user")
webdavs: list["WebDAV"] = Relationship(back_populates="user")
@staticmethod
async def create(
user: Optional["User"] = None,
**kwargs
):
"""
向数据库内添加用户。
:param user: User 实例
:type user: User
"""
if not user:
user = User(**kwargs)
from .database import get_session
async for session in get_session():
try:
session.add(user)
await session.commit()
await session.refresh(user)
except Exception as e:
await session.rollback()
raise e
return user
@staticmethod
async def get(
id: int = None,
email: str = None
) -> Optional["User"]:
"""
获取用户信息。
:param id: 用户ID默认为 None
:type id: int
:param email: 用户邮箱,默认为 None
:type email: str
:return: 用户对象或 None
:rtype: Optional[User]
"""
session = get_session()
if id is None and email is None:
return None
async for session in get_session():
query = select(User)
if id is not None:
query = query.where(User.id == id)
if email is not None:
query = query.where(User.email == email)
result = await session.exec(query)
user = result.one_or_none()
return user
@staticmethod
async def update(
id: int,
email: Optional[str] = None,
nick: Optional[str] = None,
password: Optional[str] = None,
status: Optional[int] = None,
storage: Optional[int] = None,
two_factor: Optional[str] = None,
avatar: Optional[str] = None,
options: Optional[str] = None,
authn: Optional[str] = None,
open_id: Optional[str] = None,
score: Optional[int] = None,
group_id: Optional[int] = None
) -> "User":
"""
更新用户信息。
:return: 更新后的用户对象
:rtype: User
"""
async for session in get_session():
try:
statement = select(User).where(User.id == id)
result = await session.exec(statement)
user = result.first()
if user is None:
raise ValueError(f"User with id {id} not found.")
if email is not None:
user.email = email
if nick is not None:
user.nick = nick
if password is not None:
user.password = password
if status is not None:
user.status = status
if storage is not None:
user.storage = storage
if two_factor is not None:
user.two_factor = two_factor
if avatar is not None:
user.avatar = avatar
if options is not None:
user.options = options
if authn is not None:
user.authn = authn
if open_id is not None:
user.open_id = open_id
if score is not None:
user.score = score
if group_id is not None:
user.group_id = group_id
await session.commit()
await session.refresh(user)
return user
except Exception as e:
await session.rollback()
raise e
@staticmethod
async def delete(
id: int
) -> None:
"""
删除用户。
:param id: 用户ID
:type id: int
"""
if id == 1:
raise ValueError("Cannot delete the default admin user with id 1.")
async for session in get_session():
try:
statement = select(User).where(User.id == id)
result = await session.exec(statement)
user = result.one_or_none()
if user is None:
raise ValueError(f"User with id {id} not found.")
await session.delete(user)
await session.commit()
except Exception as e:
await session.rollback()
raise e

View File

@@ -1,4 +1,3 @@
# my_project/models/webdav.py
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime