Files
disknext/sqlmodels/mixin/table.py
于小丘 209cb24ab4 feat: add models for physical files, policies, and user management
- Implement PhysicalFile model to manage physical file references and reference counting.
- Create Policy model with associated options and group links for storage policies.
- Introduce Redeem and Report models for handling redeem codes and reports.
- Add Settings model for site configuration and user settings management.
- Develop Share model for sharing objects with unique codes and associated metadata.
- Implement SourceLink model for managing download links associated with objects.
- Create StoragePack model for managing user storage packages.
- Add Tag model for user-defined tags with manual and automatic types.
- Implement Task model for managing background tasks with status tracking.
- Develop User model with comprehensive user management features including authentication.
- Introduce UserAuthn model for managing WebAuthn credentials.
- Create WebDAV model for managing WebDAV accounts associated with users.
2026-02-10 19:07:48 +08:00

1248 lines
54 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
表基类 Mixin
提供 TableBaseMixin、UUIDTableBaseMixin 和 TableViewRequest。
这些类实际上是 Mixin为 SQLModel 模型提供 CRUD 操作和时间戳字段。
依赖关系:
base/sqlmodel_base.py ← 最底层
mixin/polymorphic.py ← 定义 PolymorphicBaseMixin
mixin/table.py ← 当前文件,导入 PolymorphicBaseMixin
base/__init__.py ← 从 mixin 重新导出(保持向后兼容)
维护须知:
增删功能时必须更新 __version__ 字段(遵循语义化版本)
版本历史:
0.1.0 - delete() 方法支持条件删除condition 参数)
"""
__version__ = "0.1.0"
import uuid
from datetime import datetime
from typing import TypeVar, Literal, override, Any, ClassVar, Generic
# TODO(ListResponse泛型问题): SQLModel泛型类型JSON Schema生成bug
# 已知问题: https://github.com/fastapi/sqlmodel/discussions/1002
# 修复PR: https://github.com/fastapi/sqlmodel/pull/1275 (尚未合并)
# 现象: SQLModel + Generic[T] 的 __pydantic_generic_metadata__ = {origin: None, args: ()}
# 导致OpenAPI schema中泛型字段显示为{}而非正确的$ref
# 当前方案: ListResponse继承BaseModel而非SQLModel (Discussion #1002推荐的workaround)
# 未来: PR #1275合并后可改回继承SQLModelBase
from pydantic import BaseModel, ConfigDict
from fastapi import HTTPException
from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct, delete as sql_delete, inspect
from sqlalchemy.orm import selectinload, Relationship, with_polymorphic
from sqlalchemy.orm.exc import StaleDataError
from sqlmodel import Field, select
from .optimistic_lock import OptimisticLockError
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.sql._typing import _OnClauseArgument
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel.main import RelationshipInfo
from .polymorphic import PolymorphicBaseMixin
from sqlmodels.base.sqlmodel_base import SQLModelBase
# Type variables for generic type hints, improving code completion and analysis.
T = TypeVar("T", bound="TableBaseMixin")
M = TypeVar("M", bound="SQLModelBase")
ItemT = TypeVar("ItemT")
class ListResponse(BaseModel, Generic[ItemT]):
"""
泛型分页响应
用于所有LIST端点的标准化响应格式包含记录总数和项目列表。
与 TableBaseMixin.get_with_count() 配合使用。
使用示例:
```python
@router.get("", response_model=ListResponse[CharacterInfoResponse])
async def list_characters(...) -> ListResponse[Character]:
return await Character.get_with_count(session, table_view=table_view)
```
Attributes:
count: 符合条件的记录总数(用于分页计算)
items: 当前页的记录列表
Note:
继承BaseModel而非SQLModelBase因为SQLModel的metaclass与Generic冲突。
详见文件顶部TODO注释。
"""
# 与SQLModelBase保持一致的配置
model_config = ConfigDict(use_attribute_docstrings=True)
count: int
"""符合条件的记录总数"""
items: list[ItemT]
"""当前页的记录列表"""
# Lambda functions to get the current time, used as default factories in model fields.
now = lambda: datetime.now()
now_date = lambda: datetime.now().date()
# ==================== 查询参数请求类 ====================
class TimeFilterRequest(SQLModelBase):
"""
时间筛选请求参数
用于 count() 等只需要时间筛选的场景。
纯数据类只负责参数校验和携带SQL子句构建由 TableBaseMixin 负责。
Raises:
ValueError: 时间范围无效
"""
created_after_datetime: datetime | None = None
"""创建时间起始筛选created_at >= datetime如果为None则不限制"""
created_before_datetime: datetime | None = None
"""创建时间结束筛选created_at < datetime如果为None则不限制"""
updated_after_datetime: datetime | None = None
"""更新时间起始筛选updated_at >= datetime如果为None则不限制"""
updated_before_datetime: datetime | None = None
"""更新时间结束筛选updated_at < datetime如果为None则不限制"""
def model_post_init(self, __context: Any) -> None:
"""
验证时间范围有效性
验证规则:
1. 同类型after 必须小于 before
2. 跨类型created_after 不能大于 updated_before记录不可能在创建前被更新
"""
# 同类型矛盾验证
if self.created_after_datetime and self.created_before_datetime:
if self.created_after_datetime >= self.created_before_datetime:
raise ValueError("created_after_datetime 必须小于 created_before_datetime")
if self.updated_after_datetime and self.updated_before_datetime:
if self.updated_after_datetime >= self.updated_before_datetime:
raise ValueError("updated_after_datetime 必须小于 updated_before_datetime")
# 跨类型矛盾验证created_after >= updated_before 意味着要求创建时间晚于或等于更新时间上界,逻辑矛盾
if self.created_after_datetime and self.updated_before_datetime:
if self.created_after_datetime >= self.updated_before_datetime:
raise ValueError(
"created_after_datetime 不能大于或等于 updated_before_datetime"
"(记录的更新时间不可能早于或等于创建时间)"
)
class PaginationRequest(SQLModelBase):
"""
分页排序请求参数
用于需要分页和排序的场景。
纯数据类只负责携带参数SQL子句构建由 TableBaseMixin 负责。
"""
offset: int | None = Field(default=0, ge=0)
"""偏移量跳过前N条记录必须为非负整数"""
limit: int | None = Field(default=50, le=100)
"""每页数量返回最多N条记录默认50最大100"""
desc: bool | None = True
"""是否降序排序True: 降序, False: 升序)"""
order: Literal["created_at", "updated_at"] | None = "created_at"
"""排序字段created_at: 创建时间, updated_at: 更新时间)"""
class TableViewRequest(TimeFilterRequest, PaginationRequest):
"""
表格视图请求参数(分页、排序和时间筛选)
组合继承 TimeFilterRequest 和 PaginationRequest用于 get() 等需要完整查询参数的场景。
纯数据类SQL子句构建由 TableBaseMixin 负责。
使用示例:
```python
# 在端点中使用依赖注入
@router.get("/list")
async def list_items(
session: SessionDep,
table_view: TableViewRequestDep
):
items = await Item.get(
session,
fetch_mode="all",
table_view=table_view
)
return items
# 直接使用
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
items = await Item.get(session, fetch_mode="all", table_view=table_view)
```
"""
pass
# ==================== TableBaseMixin ====================
class TableBaseMixin(AsyncAttrs):
"""
一个异步 CRUD 操作的基础模型类 Mixin.
此类必须搭配SQLModelBase使用
此类为所有继承它的 SQLModel 模型提供了通用的数据库操作方法,
例如 add, save, update, delete, 和 get. 它还包括自动管理
的 `created_at` 和 `updated_at` 时间戳字段.
Attributes:
id (int | None): 整数主键, 自动递增.
created_at (datetime): 记录创建时的时间戳, 自动设置.
updated_at (datetime): 记录每次更新时的时间戳, 自动更新.
"""
_has_table_mixin: ClassVar[bool] = True
"""标记此类继承了表混入类的内部属性"""
def __init_subclass__(cls, **kwargs):
"""
接受并传递子类定义时的关键字参数
这允许元类 __DeclarativeMeta 处理的参数(如 table_args
能够正确传递,而不会在 __init_subclass__ 阶段报错。
"""
super().__init_subclass__(**kwargs)
id: int | None = Field(default=None, primary_key=True)
created_at: datetime = Field(default_factory=now)
updated_at: datetime = Field(
sa_type=DateTime,
sa_column_kwargs={'default': now, 'onupdate': now},
default_factory=now
)
@classmethod
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
"""
向数据库中添加一个新的或多个新的记录.
这个类方法可以接受单个模型实例或一个实例列表,并将它们
一次性提交到数据库中。执行后,可以选择性地刷新这些实例以获取
数据库生成的值(例如,自动递增的 ID.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
Returns:
T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例.
Usage:
item1 = Item(name="Apple")
item2 = Item(name="Banana")
# 添加多个实例
added_items = await Item.add(session, [item1, item2])
# 添加单个实例
item3 = Item(name="Cherry")
added_item = await Item.add(session, item3)
"""
is_list = False
if isinstance(instances, list):
is_list = True
session.add_all(instances)
else:
session.add(instances)
await session.commit()
if refresh:
if is_list:
for instance in instances:
await session.refresh(instance)
else:
await session.refresh(instances)
return instances
async def save(
self: T,
session: AsyncSession,
load: RelationshipInfo | list[RelationshipInfo] | None = None,
refresh: bool = True,
commit: bool = True,
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
optimistic_retry_count: int = 0,
) -> T:
"""
保存(插入或更新)当前模型实例到数据库.
这是一个实例方法,它将当前对象添加到会话中并提交更改。
可以用于创建新记录或更新现有记录。还可以选择在保存后
预加载eager load一个关联关系.
**重要**调用此方法后session中的所有对象都会过期expired
如果需要继续使用该对象,必须使用返回值:
```python
# ✅ 正确:需要返回值时
client = await client.save(session)
return client
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
await client.save(session, refresh=False)
# ✅ 正确:批量操作时延迟提交
for item in items:
item = await item.save(session, commit=False)
await session.commit()
# ✅ 正确:保存后需要访问多态关系时
tool_set = await tool_set.save(session, load=ToolSet.tools, jti_subclasses='all')
return tool_set # tools 关系已正确加载子类数据
# ✅ 正确:启用乐观锁自动重试
order = await order.save(session, optimistic_retry_count=3)
# ❌ 错误:需要返回值但未使用
await client.save(session)
return client # client 对象已过期
```
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性.
例如 `User.posts`.
refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值,
设为 False 可节省一次数据库查询。默认为 True.
commit (bool): 是否在保存后提交事务。如果为 False只会 flush 获取 ID
但不提交,适用于批量操作场景。默认为 True.
jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
- 'all': 两阶段查询,只加载实际关联的子类
- None默认: 不使用多态加载
optimistic_retry_count (int): 乐观锁冲突时的自动重试次数。默认为 0不重试
重试时会重新查询最新数据,将当前修改合并后再次保存。
Returns:
T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self.
Raises:
OptimisticLockError: 如果启用了乐观锁且版本号不匹配,且重试次数已耗尽
"""
cls = type(self)
instance = self
retries_remaining = optimistic_retry_count
current_data: dict[str, Any] | None = None # 延迟计算,仅在需要重试时
while True:
session.add(instance)
try:
if commit:
await session.commit()
else:
await session.flush()
break # 成功,退出循环
except StaleDataError as e:
await session.rollback()
if retries_remaining <= 0:
raise OptimisticLockError(
message=f"{cls.__name__} 乐观锁冲突:记录已被其他事务修改",
model_class=cls.__name__,
record_id=str(getattr(instance, 'id', None)),
expected_version=getattr(instance, 'version', None),
original_error=e,
) from e
# 失败后重试:重新查询最新数据并合并修改
retries_remaining -= 1
if current_data is None:
current_data = self.model_dump(exclude={'id', 'version', 'created_at', 'updated_at'})
fresh = await cls.get(session, cls.id == self.id)
if fresh is None:
raise OptimisticLockError(
message=f"{cls.__name__} 重试失败:记录已被删除",
model_class=cls.__name__,
record_id=str(getattr(self, 'id', None)),
original_error=e,
) from e
for key, value in current_data.items():
if hasattr(fresh, key):
setattr(fresh, key, value)
instance = fresh
if not refresh:
return instance
if load is not None:
await session.refresh(instance)
return await cls.get(session, cls.id == instance.id, load=load, jti_subclasses=jti_subclasses)
else:
await session.refresh(instance)
return instance
async def update(
self: T,
session: AsyncSession,
other: M,
extra_data: dict[str, Any] | None = None,
exclude_unset: bool = True,
exclude: set[str] | None = None,
load: RelationshipInfo | list[RelationshipInfo] | None = None,
refresh: bool = True,
commit: bool = True,
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
optimistic_retry_count: int = 0,
) -> T:
"""
使用另一个模型实例或字典中的数据来更新当前实例.
此方法将 `other` 对象中的数据合并到当前实例中。默认情况下,
它只会更新 `other` 中被显式设置的字段.
**重要**调用此方法后session中的所有对象都会过期expired
如果需要继续使用该对象,必须使用返回值:
```python
# ✅ 正确:需要返回值时
client = await client.update(session, update_data)
return client
# ✅ 正确:需要返回值且需要加载关系时
user = await user.update(session, update_data, load=User.permission)
return user
# ✅ 正确:更新后需要访问多态关系时
tool_set = await tool_set.update(session, data, load=ToolSet.tools, jti_subclasses='all')
return tool_set # tools 关系已正确加载子类数据
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
await client.update(session, update_data, refresh=False)
# ✅ 正确:批量操作时延迟提交
for item in items:
item = await item.update(session, data, commit=False)
await session.commit()
# ✅ 正确:启用乐观锁自动重试
order = await order.update(session, update_data, optimistic_retry_count=3)
# ❌ 错误:需要返回值但未使用
await client.update(session, update_data)
return client # client 对象已过期
```
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
other (M): 一个 SQLModel 或 Pydantic 模型实例,其数据将用于更新当前实例.
extra_data (dict, optional): 一个额外的字典,用于更新当前实例的特定字段.
exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供)
的字段将被忽略. 默认为 True.
exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}.
load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性.
例如 `User.permission`.
refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值,
设为 False 可节省一次数据库查询。默认为 True.
commit (bool): 是否在更新后提交事务。如果为 False只会 flush
但不提交,适用于批量操作场景。默认为 True.
jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
- 'all': 两阶段查询,只加载实际关联的子类
- None默认: 不使用多态加载
optimistic_retry_count (int): 乐观锁冲突时的自动重试次数。默认为 0不重试
重试时会重新查询最新数据,将 other 的更新重新应用后再次保存。
Returns:
T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self.
Raises:
OptimisticLockError: 如果启用了乐观锁且版本号不匹配,且重试次数已耗尽
"""
cls = type(self)
update_data = other.model_dump(exclude_unset=exclude_unset, exclude=exclude)
instance = self
retries_remaining = optimistic_retry_count
while True:
instance.sqlmodel_update(update_data, update=extra_data)
session.add(instance)
try:
if commit:
await session.commit()
else:
await session.flush()
break # 成功,退出循环
except StaleDataError as e:
await session.rollback()
if retries_remaining <= 0:
raise OptimisticLockError(
message=f"{cls.__name__} 乐观锁冲突:记录已被其他事务修改",
model_class=cls.__name__,
record_id=str(getattr(instance, 'id', None)),
expected_version=getattr(instance, 'version', None),
original_error=e,
) from e
# 失败后重试:重新查询最新数据并重新应用更新
retries_remaining -= 1
fresh = await cls.get(session, cls.id == self.id)
if fresh is None:
raise OptimisticLockError(
message=f"{cls.__name__} 重试失败:记录已被删除",
model_class=cls.__name__,
record_id=str(getattr(self, 'id', None)),
original_error=e,
) from e
instance = fresh
if not refresh:
return instance
if load is not None:
await session.refresh(instance)
return await cls.get(session, cls.id == instance.id, load=load, jti_subclasses=jti_subclasses)
else:
await session.refresh(instance)
return instance
@classmethod
async def delete(
cls: type[T],
session: AsyncSession,
instances: T | list[T] | None = None,
*,
condition: BinaryExpression | ClauseElement | None = None,
commit: bool = True,
) -> int:
"""
从数据库中删除记录,支持实例删除和条件删除两种模式。
Args:
session: 用于数据库操作的异步会话对象
instances: 要删除的单个模型实例或模型实例列表(实例删除模式)
condition: WHERE 条件表达式(条件删除模式,直接执行 SQL DELETE
commit: 是否在删除后提交事务。默认为 True
Returns:
删除的记录数(条件删除模式返回实际删除数,实例删除模式返回实例数)
Raises:
ValueError: 同时提供 instances 和 condition或两者都未提供
Usage:
# 实例删除模式
item = await Item.get(session, Item.id == 1)
if item:
await Item.delete(session, item)
items = await Item.get(session, Item.name.in_(["A", "B"]), fetch_mode="all")
if items:
await Item.delete(session, items)
# 条件删除模式(高效批量删除,不加载实例到内存)
deleted_count = await Item.delete(
session,
condition=(Item.user_id == user_id) & (Item.status == "expired"),
)
"""
if instances is not None and condition is not None:
raise ValueError("不能同时提供 instances 和 condition 参数")
if instances is None and condition is None:
raise ValueError("必须提供 instances 或 condition 参数之一")
deleted_count = 0
if condition is not None:
# 条件删除模式:直接执行 SQL DELETE
stmt = sql_delete(cls).where(condition)
result = await session.execute(stmt)
deleted_count = result.rowcount
else:
# 实例删除模式
if isinstance(instances, list):
for instance in instances:
await session.delete(instance)
deleted_count = len(instances)
else:
await session.delete(instances)
deleted_count = 1
if commit:
await session.commit()
return deleted_count
@classmethod
def _build_time_filters(
cls: type[T],
created_before_datetime: datetime | None = None,
created_after_datetime: datetime | None = None,
updated_before_datetime: datetime | None = None,
updated_after_datetime: datetime | None = None,
) -> list[BinaryExpression]:
"""
构建时间筛选条件列表
Args:
created_before_datetime: 筛选 created_at < datetime 的记录
created_after_datetime: 筛选 created_at >= datetime 的记录
updated_before_datetime: 筛选 updated_at < datetime 的记录
updated_after_datetime: 筛选 updated_at >= datetime 的记录
Returns:
BinaryExpression 条件列表
"""
filters: list[BinaryExpression] = []
if created_after_datetime is not None:
filters.append(cls.created_at >= created_after_datetime)
if created_before_datetime is not None:
filters.append(cls.created_at < created_before_datetime)
if updated_after_datetime is not None:
filters.append(cls.updated_at >= updated_after_datetime)
if updated_before_datetime is not None:
filters.append(cls.updated_at < updated_before_datetime)
return filters
@classmethod
async def get(
cls: type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None = None,
*,
offset: int | None = None,
limit: int | None = None,
fetch_mode: Literal["one", "first", "all"] = "first",
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
options: list | None = None,
load: RelationshipInfo | list[RelationshipInfo] | None = None,
order_by: list[ClauseElement] | None = None,
filter: BinaryExpression | ClauseElement | None = None,
with_for_update: bool = False,
table_view: TableViewRequest | None = None,
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
populate_existing: bool = False,
created_before_datetime: datetime | None = None,
created_after_datetime: datetime | None = None,
updated_before_datetime: datetime | None = None,
updated_after_datetime: datetime | None = None,
) -> T | list[T] | None:
"""
根据指定的条件异步地从数据库中获取一个或多个模型实例.
这是一个功能强大的通用查询方法,支持过滤、排序、分页、连接查询和关联关系预加载.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
condition (BinaryExpression | ClauseElement | None): 主要的查询过滤条件,
例如 `User.id == 1`。
当为 `None` 时,表示无条件查询(查询所有记录)。
offset (int | None): 查询结果的起始偏移量, 用于分页.
limit (int | None): 返回记录的最大数量, 用于分页.
fetch_mode (Literal["one", "first", "all"]):
- "one": 获取唯一的一条记录. 如果找不到或找到多条,会引发异常.
- "first": 获取查询结果的第一条记录. 如果找不到,返回 `None`.
- "all": 获取所有匹配的记录,返回一个列表.
默认为 "first".
join (type[T] | tuple[type[T], _OnClauseArgument] | None):
要 JOIN 的模型类或一个包含模型类和 ON 子句的元组.
例如 `User` 或 `(Profile, User.id == Profile.user_id)`.
options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
例如 `[selectinload(User.posts)]`.
load (Relationship | list[Relationship] | None): `selectinload` 的快捷方式,用于预加载关联关系.
可以是单个关系或关系列表。支持嵌套关系预加载:
当传入多个关系时,会自动检测依赖关系并构建链式 selectinload。
例如 `[NodeGroupNode.element_links, NodeGroupElementLink.node]`
会自动构建 `selectinload(element_links).selectinload(node)`。
order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
例如 `[User.name.asc(), User.created_at.desc()]`.
filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
with_for_update (bool): 如果为 True, 在查询中使用 `FOR UPDATE` 锁定选定的行. 默认为 False.
table_view (TableViewRequest | None): TableViewRequest对象如果提供则自动处理分页、排序和时间筛选。
会覆盖offset、limit、order_by及时间筛选参数。
这是推荐的分页排序方式统一了所有LIST端点的参数格式。
jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
- 'all': 两阶段查询,只加载实际关联的子类(对于 > 10 个子类的场景有明显性能收益)
- None默认: 不使用多态加载
populate_existing (bool): 如果为 True强制用数据库数据覆盖 session 中已存在的对象identity map
用于批量刷新对象,避免循环调用 session.refresh() 导致的 N 次查询。
注意只刷新标量字段不影响运行时属性_开头的属性
对于 STI单表继承对象推荐按子类分组查询以包含子类字段。默认为 False。
created_before_datetime (datetime | None): 筛选 created_at < datetime 的记录
created_after_datetime (datetime | None): 筛选 created_at >= datetime 的记录
updated_before_datetime (datetime | None): 筛选 updated_at < datetime 的记录
updated_after_datetime (datetime | None): 筛选 updated_at >= datetime 的记录
Returns:
T | list[T] | None: 根据 `fetch_mode` 的设置,返回单个实例、实例列表或 `None`.
Raises:
ValueError: 如果提供了无效的 `fetch_mode` 值,或 jti_subclasses 未与 load 配合使用.
Examples:
# 使用table_view参数推荐
users = await User.get(session, fetch_mode="all", table_view=table_view_args)
# 传统方式(向后兼容)
users = await User.get(session, fetch_mode="all", offset=0, limit=20, order_by=[desc(User.created_at)])
# 使用多态加载(加载联表继承的子类数据)
tool_set = await ToolSet.get(
session,
ToolSet.id == tool_set_id,
load=ToolSet.tools,
jti_subclasses='all' # 只加载实际关联的子类
)
"""
# 参数验证jti_subclasses 需要与 load 配合使用
if jti_subclasses is not None and load is None:
raise ValueError(
"jti_subclasses 参数需要与 load 参数配合使用,"
"请同时指定要加载的关系"
)
# 如果提供table_view作为默认值使用单独传入的参数优先级更高
if table_view:
# 处理时间筛选TimeFilterRequest 及其子类)
if isinstance(table_view, TimeFilterRequest):
if created_after_datetime is None and table_view.created_after_datetime is not None:
created_after_datetime = table_view.created_after_datetime
if created_before_datetime is None and table_view.created_before_datetime is not None:
created_before_datetime = table_view.created_before_datetime
if updated_after_datetime is None and table_view.updated_after_datetime is not None:
updated_after_datetime = table_view.updated_after_datetime
if updated_before_datetime is None and table_view.updated_before_datetime is not None:
updated_before_datetime = table_view.updated_before_datetime
# 处理分页排序PaginationRequest 及其子类,包括 TableViewRequest
if isinstance(table_view, PaginationRequest):
if offset is None:
offset = table_view.offset
if limit is None:
limit = table_view.limit
# 仅在未显式传入order_by时从table_view构建排序子句
if order_by is None:
order_column = cls.created_at if table_view.order == "created_at" else cls.updated_at
order_by = [desc(order_column) if table_view.desc else asc(order_column)]
# 对于多态基类,使用 with_polymorphic 预加载所有子类的列
# 这避免了在响应序列化时的延迟加载问题MissingGreenlet 错误)
polymorphic_cls = None # 保存多态实体,用于子类关系预加载
is_polymorphic = issubclass(cls, PolymorphicBaseMixin)
is_jti = is_polymorphic and cls._is_joined_table_inheritance()
is_sti = is_polymorphic and not cls._is_joined_table_inheritance()
# JTI 模式:总是使用 with_polymorphic避免 N+1 查询)
# STI 模式:不使用 with_polymorphic批量刷新时请按子类分组查询
if is_jti:
# '*' 表示加载所有子类
polymorphic_cls = with_polymorphic(cls, '*')
statement = select(polymorphic_cls)
else:
statement = select(cls)
# 对于 STI单表继承子类自动添加多态过滤条件
# SQLAlchemy/SQLModel 在 STI 模式下不会自动添加 WHERE discriminator = 'identity' 过滤
# 这是已知行为,参考:
# - https://github.com/sqlalchemy/sqlalchemy/issues/5018 (bulk operations 不自动添加多态过滤)
# - https://github.com/fastapi/sqlmodel/issues/488 (SQLModel STI 支持不完整)
# 社区最佳实践是显式添加多态过滤条件
if issubclass(cls, PolymorphicBaseMixin) and not cls._is_joined_table_inheritance():
mapper = inspect(cls)
# 检查是否有 polymorphic_identity 且不是抽象类
if mapper.polymorphic_identity is not None and not mapper.polymorphic_abstract:
poly_on = mapper.polymorphic_on
if poly_on is not None:
statement = statement.where(poly_on == mapper.polymorphic_identity)
if condition is not None:
statement = statement.where(condition)
# 应用时间筛选
for time_filter in cls._build_time_filters(
created_before_datetime, created_after_datetime,
updated_before_datetime, updated_after_datetime
):
statement = statement.where(time_filter)
if join is not None:
# 如果 join 是一个元组,解包它;否则直接使用
if isinstance(join, tuple):
statement = statement.join(*join)
else:
statement = statement.join(join)
if options:
statement = statement.options(*options)
if load:
# 标准化为列表
load_list = load if isinstance(load, list) else [load]
# 构建链式 selectinload支持嵌套关系预加载
# 例如load=[NodeGroupNode.element_links, NodeGroupElementLink.node]
# 会构建selectinload(element_links).selectinload(node)
load_chains = cls._build_load_chains(load_list)
# 处理多态加载(仅支持单链且只有一个关系)
if jti_subclasses is not None:
if len(load_chains) > 1 or len(load_chains[0]) > 1:
raise ValueError(
"jti_subclasses 仅支持单个关系(无嵌套链),请不要传入多个关系"
)
single_load = load_chains[0][0]
target_class = single_load.property.mapper.class_
# 检查目标类是否继承自 PolymorphicBaseMixin
if not issubclass(target_class, PolymorphicBaseMixin):
raise ValueError(
f"目标类 {target_class.__name__} 不是多态类,"
f"请确保其继承自 PolymorphicBaseMixin"
)
if jti_subclasses == 'all':
# 两阶段查询:获取实际关联的多态类型
subclasses_to_load = await cls._resolve_polymorphic_subclasses(
session, condition, single_load, target_class
)
else:
subclasses_to_load = jti_subclasses
if subclasses_to_load:
# 关键selectin_polymorphic 必须作为 selectinload 的链式子选项
# 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading
statement = statement.options(
selectinload(single_load).selectin_polymorphic(subclasses_to_load)
)
else:
statement = statement.options(selectinload(single_load))
else:
# 为每条链构建链式 selectinload
for chain in load_chains:
# 获取第一个关系并检查是否需要通过多态实体访问
first_rel = chain[0]
first_rel_parent = first_rel.property.parent.class_
# 如果关系的 parent_class 是当前类的子类(不是 cls 本身),
# 且当前是多态查询,则需要通过 polymorphic_cls.SubclassName 访问
if (
polymorphic_cls is not None
and first_rel_parent is not cls
and issubclass(first_rel_parent, cls)
):
# 通过多态实体访问子类的关系属性
# 例如polymorphic_cls.NodeGroupNode.element_links
subclass_alias = getattr(polymorphic_cls, first_rel_parent.__name__)
rel_name = first_rel.key
first_rel_via_poly = getattr(subclass_alias, rel_name)
loader = selectinload(first_rel_via_poly)
else:
loader = selectinload(first_rel)
for rel in chain[1:]:
loader = loader.selectinload(rel)
statement = statement.options(loader)
if order_by is not None:
statement = statement.order_by(*order_by)
if offset:
statement = statement.offset(offset)
if limit:
statement = statement.limit(limit)
if filter:
statement = statement.filter(filter)
if with_for_update:
# 对于联表继承的多态模型,使用 FOR UPDATE OF <主表> 来避免 PostgreSQL 的限制
# PostgreSQL 不支持在 LEFT OUTER JOIN 的可空侧使用 FOR UPDATE
if issubclass(cls, PolymorphicBaseMixin):
statement = statement.with_for_update(of=cls)
else:
statement = statement.with_for_update()
if populate_existing:
# 强制用数据库数据覆盖 identity map 中的对象
# 用于批量刷新,避免循环 refresh() 的 N 次查询
statement = statement.execution_options(populate_existing=True)
result = await session.exec(statement)
if fetch_mode == "one":
return result.one()
elif fetch_mode == "first":
return result.first()
elif fetch_mode == "all":
return list(result.all())
else:
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
@staticmethod
def _build_load_chains(load_list: list[RelationshipInfo]) -> list[list[RelationshipInfo]]:
"""
将关系列表构建为链式加载结构
自动检测关系之间的依赖关系,构建嵌套预加载链。
例如:[NodeGroupNode.element_links, NodeGroupElementLink.node]
会构建:[[element_links, node]](一条链)
算法:
1. 获取每个关系的 parent class 和 target class
2. 如果关系 B 的 parent class 等于关系 A 的 target class则 B 链在 A 后面
3. 独立的关系各自成为一条链
Args:
load_list: 关系属性列表
Returns:
链式关系列表,每条链是一个关系列表
"""
if not load_list:
return []
# 构建关系信息:{关系: (parent_class, target_class)}
rel_info: dict[RelationshipInfo, tuple[type, type]] = {}
for rel in load_list:
parent_class = rel.property.parent.class_
target_class = rel.property.mapper.class_
rel_info[rel] = (parent_class, target_class)
# 构建依赖图:{关系: 其前置关系}
predecessors: dict[RelationshipInfo, RelationshipInfo | None] = {rel: None for rel in load_list}
for rel_b in load_list:
parent_b, _ = rel_info[rel_b]
for rel_a in load_list:
if rel_a is rel_b:
continue
_, target_a = rel_info[rel_a]
# 如果 B 的 parent 精确等于 A 的 target则 B 链在 A 后面
# 使用精确匹配避免继承关系导致的误判(如 NodeGroupNode 是 CanvasNode 子类)
if parent_b is target_a:
predecessors[rel_b] = rel_a
break
# 找出所有链的起点(没有前置关系的)
roots = [rel for rel, pred in predecessors.items() if pred is None]
# 构建链
chains: list[list[RelationshipInfo]] = []
used: set[RelationshipInfo] = set()
for root in roots:
chain = [root]
used.add(root)
# 找后续节点
current = root
while True:
# 找以 current 的 target 为 parent 的关系
_, current_target = rel_info[current]
next_rel = None
for rel, (parent, _) in rel_info.items():
if rel not in used and parent is current_target:
next_rel = rel
break
if next_rel is None:
break
chain.append(next_rel)
used.add(next_rel)
current = next_rel
chains.append(chain)
return chains
@classmethod
async def _resolve_polymorphic_subclasses(
cls: type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None,
load: RelationshipInfo,
target_class: type[PolymorphicBaseMixin]
) -> list[type[PolymorphicBaseMixin]]:
"""
查询实际关联的多态子类类型
通过查询多态鉴别字段确定实际存在的子类类型,
避免加载所有可能的子类表(对于 > 10 个子类的场景有明显收益)。
:param session: 数据库会话
:param condition: 主查询的条件
:param load: 关系属性
:param target_class: 多态基类
:return: 实际关联的子类列表
"""
# 获取多态鉴别字段(会抛出 ValueError 如果未配置)
discriminator = target_class.get_polymorphic_discriminator()
poly_name_col = getattr(target_class, discriminator)
# 获取关系属性
relationship_property = load.property
# 构建查询获取实际的多态类型名称
if relationship_property.secondary is not None:
# 多对多关系:通过中间表查询
secondary = relationship_property.secondary
local_cols = list(relationship_property.local_columns)
type_query = (
select(distinct(poly_name_col))
.select_from(target_class)
.join(secondary)
.where(secondary.c[local_cols[0].name].in_(
select(cls.id).where(condition) if condition is not None else select(cls.id)
))
)
else:
# 多对一/一对多关系:通过外键查询
# local_remote_pairs[0] = (local_fk_col, remote_pk_col)
# 对于多对一local 是当前类的外键remote 是目标类的主键
local_fk_col = relationship_property.local_remote_pairs[0][0]
remote_pk_col = relationship_property.local_remote_pairs[0][1]
type_query = (
select(distinct(poly_name_col))
.where(remote_pk_col.in_(
select(local_fk_col).where(condition) if condition is not None else select(local_fk_col)
))
)
type_result = await session.exec(type_query)
poly_names = list(type_result.all())
if not poly_names:
return []
# 映射到子类(包含所有层级的具体子类)
identity_map = target_class.get_identity_to_class_map()
return [identity_map[name] for name in poly_names if name in identity_map]
@classmethod
async def count(
cls: type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None = None,
*,
time_filter: TimeFilterRequest | None = None,
created_before_datetime: datetime | None = None,
created_after_datetime: datetime | None = None,
updated_before_datetime: datetime | None = None,
updated_after_datetime: datetime | None = None,
) -> int:
"""
根据条件统计记录数量(支持时间筛选)
使用数据库层面的 COUNT() 聚合函数,比 get() + len() 更高效。
Args:
session: 数据库会话
condition: 查询条件,例如 `User.is_active == True`
time_filter: TimeFilterRequest 对象(优先级更高)
created_before_datetime: 筛选 created_at < datetime 的记录
created_after_datetime: 筛选 created_at >= datetime 的记录
updated_before_datetime: 筛选 updated_at < datetime 的记录
updated_after_datetime: 筛选 updated_at >= datetime 的记录
Returns:
符合条件的记录数量
Examples:
# 统计所有用户
total = await User.count(session)
# 统计激活的虚拟客户端
count = await Client.count(
session,
(Client.user_id == user_id) & (Client.type != ClientTypeEnum.physical) & (Client.is_active == True)
)
# 使用 TimeFilterRequest 进行时间筛选
count = await User.count(session, time_filter=time_filter_request)
# 使用独立时间参数
count = await User.count(
session,
created_after_datetime=datetime(2025, 1, 1),
created_before_datetime=datetime(2025, 2, 1),
)
"""
# time_filter 的时间筛选优先级更高
if isinstance(time_filter, TimeFilterRequest):
if time_filter.created_after_datetime is not None:
created_after_datetime = time_filter.created_after_datetime
if time_filter.created_before_datetime is not None:
created_before_datetime = time_filter.created_before_datetime
if time_filter.updated_after_datetime is not None:
updated_after_datetime = time_filter.updated_after_datetime
if time_filter.updated_before_datetime is not None:
updated_before_datetime = time_filter.updated_before_datetime
statement = select(func.count()).select_from(cls)
# 应用查询条件
if condition is not None:
statement = statement.where(condition)
# 应用时间筛选
for time_condition in cls._build_time_filters(
created_before_datetime, created_after_datetime,
updated_before_datetime, updated_after_datetime
):
statement = statement.where(time_condition)
result = await session.scalar(statement)
return result or 0
@classmethod
async def get_with_count(
cls: type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None = None,
*,
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
options: list | None = None,
load: RelationshipInfo | list[RelationshipInfo] | None = None,
order_by: list[ClauseElement] | None = None,
filter: BinaryExpression | ClauseElement | None = None,
table_view: TableViewRequest | None = None,
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
) -> 'ListResponse[T]':
"""
获取分页列表及总数,直接返回 ListResponse
同时返回符合条件的记录列表和总数,用于分页场景。
与 get() 方法类似,但固定 fetch_mode="all" 并返回 ListResponse。
注意:如果子类的 get() 方法支持额外参数(如 filter_params
子类应该覆盖此方法以确保 count 和 items 使用相同的过滤条件。
Args:
session: 数据库会话
condition: 查询条件
join: JOIN 的模型类或元组
options: SQLAlchemy 查询选项
load: selectinload 预加载关系
order_by: 排序子句
filter: 附加过滤条件
table_view: 分页排序参数(推荐使用)
jti_subclasses: 多态子类加载选项
Returns:
ListResponse[T]: 包含 count 和 items 的分页响应
Examples:
```python
@router.get("", response_model=ListResponse[CharacterInfoResponse])
async def list_characters(
session: SessionDep,
table_view: TableViewRequestDep
) -> ListResponse[Character]:
return await Character.get_with_count(session, table_view=table_view)
```
"""
# 提取时间筛选参数(用于 count
time_filter: TimeFilterRequest | None = None
if table_view is not None:
time_filter = TimeFilterRequest(
created_after_datetime=table_view.created_after_datetime,
created_before_datetime=table_view.created_before_datetime,
updated_after_datetime=table_view.updated_after_datetime,
updated_before_datetime=table_view.updated_before_datetime,
)
# 获取总数(不带分页限制)
total_count = await cls.count(session, condition, time_filter=time_filter)
# 获取分页数据
items = await cls.get(
session,
condition,
fetch_mode="all",
join=join,
options=options,
load=load,
order_by=order_by,
filter=filter,
table_view=table_view,
jti_subclasses=jti_subclasses,
)
return ListResponse(count=total_count, items=items)
@classmethod
async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | list[RelationshipInfo] | None = None) -> T:
"""
根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常.
这个方法是对 `get` 方法的封装,专门用于处理那种"记录必须存在"的业务场景。
如果记录未找到,它会直接引发 FastAPI 的 `HTTPException`, 而不是返回 `None`.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
id (int): 要查找的记录的主键 ID.
load (Relationship | None): 可选的,用于预加载的关联属性.
Returns:
T: 找到的模型实例.
Raises:
HTTPException: 如果 ID 对应的记录不存在,则抛出状态码为 404 的异常.
"""
instance = await cls.get(session, cls.id == id, load=load)
if not instance:
raise HTTPException(status_code=404, detail="Not found")
return instance
class UUIDTableBaseMixin(TableBaseMixin):
"""
一个使用 UUID 作为主键的异步 CRUD 操作基础模型类 Mixin.
此类继承自 `TableBaseMixin`, 将主键 `id` 的类型覆盖为 `uuid.UUID`
并为新记录自动生成 UUID. 它继承了 `TableBaseMixin` 的所有 CRUD 方法.
Attributes:
id (uuid.UUID): UUID 类型的主键, 在创建时自动生成.
"""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
"""覆盖 `TableBaseMixin` 的 id 字段,使用 UUID 作为主键."""
@override
@classmethod
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T:
"""
根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
此方法覆盖了父类的同名方法,以确保 `id` 参数的类型注解为 `uuid.UUID`,
从而提供更好的类型安全和代码提示.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
id (uuid.UUID): 要查找的记录的 UUID 主键.
load (Relationship | None): 可选的,用于预加载的关联属性.
Returns:
T: 找到的模型实例.
Raises:
HTTPException: 如果 UUID 对应的记录不存在,则抛出状态码为 404 的异常.
"""
# 类型检查器可能会警告这里的 `id` 类型不匹配超类方法,
# 但在运行时这是正确的,因为超类方法内部的比较 (cls.id == id)
# 会正确处理 UUID 类型。`type: ignore` 用于抑制此警告。
return await super().get_exist_one(session, id, load) # type: ignore