feat(mixin): add TableBaseMixin and UUIDTableBaseMixin for async CRUD operations
- Implemented TableBaseMixin providing generic CRUD methods and automatic timestamp management. - Introduced UUIDTableBaseMixin for models using UUID as primary keys. - Added ListResponse for standardized paginated responses. - Created TimeFilterRequest and PaginationRequest for filtering and pagination parameters. - Enhanced get_with_count method to return both item list and total count. - Included validation for time filter parameters in TimeFilterRequest. - Improved documentation and usage examples throughout the code.
This commit is contained in:
456
models/mixin/polymorphic.py
Normal file
456
models/mixin/polymorphic.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""
|
||||
联表继承(Joined Table Inheritance)的通用工具
|
||||
|
||||
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
|
||||
|
||||
Usage Example:
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import (
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
\"\"\"配置名称\"\"\"
|
||||
|
||||
base_url: str
|
||||
\"\"\"服务地址\"\"\"
|
||||
|
||||
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC
|
||||
):
|
||||
\"\"\"ASR配置的抽象基类\"\"\"
|
||||
# PolymorphicBaseMixin 自动提供:
|
||||
# - _polymorphic_name 字段
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True(当有抽象方法时)
|
||||
|
||||
# 3. 为第二层子类创建ID Mixin
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. 创建第二层抽象类(如果需要)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
\"\"\"FunASR的抽象基类,可能有多个实现\"\"\"
|
||||
pass
|
||||
|
||||
# 5. 创建具体实现类
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
\"\"\"FunASR本地部署版本\"\"\"
|
||||
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
|
||||
pass
|
||||
|
||||
# 6. 获取所有具体子类(用于 selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# 返回 [FunASRLocal, ...]
|
||||
"""
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from sqlalchemy import String, inspect
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlmodel import Field
|
||||
|
||||
from models.base.sqlmodel_base import SQLModelBase
|
||||
|
||||
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
|
||||
"""
|
||||
动态创建SubclassIdMixin类
|
||||
|
||||
在联表继承中,子类需要一个外键指向父表的主键。
|
||||
此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。
|
||||
|
||||
Args:
|
||||
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4)
|
||||
|
||||
Example:
|
||||
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
|
||||
... pass
|
||||
|
||||
Note:
|
||||
- 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id
|
||||
- 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase)
|
||||
- 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin)
|
||||
"""
|
||||
if not parent_table_name:
|
||||
raise ValueError("parent_table_name 不能为空")
|
||||
|
||||
# 转换为PascalCase作为类名
|
||||
class_name_parts = parent_table_name.split('_')
|
||||
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
|
||||
|
||||
# 使用闭包捕获parent_table_name
|
||||
_parent_table_name = parent_table_name
|
||||
|
||||
# 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields
|
||||
class SubclassIdMixin(SQLModelBase):
|
||||
# 定义id字段
|
||||
id: UUID = Field(
|
||||
default_factory=uuid.uuid4,
|
||||
foreign_key=f'{_parent_table_name}.id',
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复联表继承中子类字段的default_factory丢失问题。
|
||||
SQLAlchemy 的 InstrumentedAttribute 会污染从父类继承的字段,
|
||||
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
|
||||
|
||||
通过从 MRO 中查找父类的原始字段定义来获取正确的 default_factory,
|
||||
遵循单一真相原则(不硬编码 default_factory)。
|
||||
|
||||
需要修复的字段:
|
||||
- id: 主键(从父类获取 default_factory)
|
||||
- created_at: 创建时间戳(从父类获取 default_factory)
|
||||
- updated_at: 更新时间戳(从父类获取 default_factory)
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
def find_original_field_info(field_name: str) -> FieldInfo | None:
|
||||
"""从 MRO 中查找字段的原始定义(未被 InstrumentedAttribute 污染的)"""
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
||||
field_info = base.model_fields[field_name]
|
||||
# 跳过被 InstrumentedAttribute 污染的
|
||||
if not isinstance(field_info.default, InstrumentedAttribute):
|
||||
return field_info
|
||||
return None
|
||||
|
||||
# 动态检测所有需要修复的字段
|
||||
# 遵循单一真相原则:不硬编码字段列表,而是通过以下条件判断:
|
||||
# 1. default 是 InstrumentedAttribute(被 SQLAlchemy 污染)
|
||||
# 2. 原始定义有 default_factory 或明确的 default 值
|
||||
#
|
||||
# 覆盖场景:
|
||||
# - UUID主键(UUIDTableBaseMixin):id 有 default_factory=uuid.uuid4,需要修复
|
||||
# - int主键(TableBaseMixin):id 用 default=None,不需要修复(数据库自增)
|
||||
# - created_at/updated_at:有 default_factory=now,需要修复
|
||||
# - 外键字段(created_by_id等):有 default=None,需要修复
|
||||
# - 普通字段(name, temperature等):无 default_factory,不需要修复
|
||||
#
|
||||
# MRO 查找保证:
|
||||
# - 在多重继承场景下,MRO 顺序是确定性的
|
||||
# - find_original_field_info 会找到第一个未被污染且有该字段的父类
|
||||
for field_name, current_field in cls.model_fields.items():
|
||||
# 检查是否被污染(default 是 InstrumentedAttribute)
|
||||
if not isinstance(current_field.default, InstrumentedAttribute):
|
||||
continue # 未被污染,跳过
|
||||
|
||||
# 从父类查找原始定义
|
||||
original = find_original_field_info(field_name)
|
||||
if original is None:
|
||||
continue # 找不到原始定义,跳过
|
||||
|
||||
# 根据原始定义的 default/default_factory 来修复
|
||||
if original.default_factory:
|
||||
# 有 default_factory(如 uuid.uuid4, now)
|
||||
new_field = FieldInfo(
|
||||
default_factory=original.default_factory,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
elif original.default is not PydanticUndefined:
|
||||
# 有明确的 default 值(如 None, 0, ""),且不是 PydanticUndefined
|
||||
# PydanticUndefined 表示字段没有默认值(必填)
|
||||
new_field = FieldInfo(
|
||||
default=original.default,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
else:
|
||||
continue # 既没有 default_factory 也没有有效的 default,跳过
|
||||
|
||||
# 复制SQLModel特有的属性
|
||||
if hasattr(current_field, 'foreign_key'):
|
||||
new_field.foreign_key = current_field.foreign_key
|
||||
if hasattr(current_field, 'primary_key'):
|
||||
new_field.primary_key = current_field.primary_key
|
||||
|
||||
cls.model_fields[field_name] = new_field
|
||||
|
||||
# 设置类名和文档
|
||||
SubclassIdMixin.__name__ = class_name
|
||||
SubclassIdMixin.__qualname__ = class_name
|
||||
SubclassIdMixin.__doc__ = f"""
|
||||
{parent_table_name}子类的ID Mixin
|
||||
|
||||
用于{parent_table_name}的子类,提供外键指向父表。
|
||||
通过MRO确保此id字段覆盖继承的id字段。
|
||||
"""
|
||||
|
||||
return SubclassIdMixin
|
||||
|
||||
|
||||
class AutoPolymorphicIdentityMixin:
|
||||
"""
|
||||
自动生成polymorphic_identity的Mixin
|
||||
|
||||
使用此Mixin的类会自动根据类名生成polymorphic_identity。
|
||||
格式:{parent_polymorphic_identity}.{classname_lowercase}
|
||||
|
||||
如果没有父类的polymorphic_identity,则直接使用类名小写。
|
||||
|
||||
Example:
|
||||
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
|
||||
... __polymorphic_name: str
|
||||
...
|
||||
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function'
|
||||
...
|
||||
>>> class CodeInterpreterFunction(Function, table=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
|
||||
|
||||
Note:
|
||||
- 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留
|
||||
- 此Mixin应该在继承列表中靠后的位置(在表基类之前)
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
|
||||
"""
|
||||
子类化钩子,自动生成polymorphic_identity
|
||||
|
||||
Args:
|
||||
polymorphic_identity: 如果手动指定,则使用指定的值
|
||||
**kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True)
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 如果手动指定了polymorphic_identity,使用指定的值
|
||||
if polymorphic_identity is not None:
|
||||
identity = polymorphic_identity
|
||||
else:
|
||||
# 自动生成polymorphic_identity
|
||||
class_name = cls.__name__.lower()
|
||||
|
||||
# 尝试从父类获取polymorphic_identity作为前缀
|
||||
parent_identity = None
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
|
||||
parent_identity = base.__mapper_args__.get('polymorphic_identity')
|
||||
if parent_identity:
|
||||
break
|
||||
|
||||
# 构建identity
|
||||
if parent_identity:
|
||||
identity = f'{parent_identity}.{class_name}'
|
||||
else:
|
||||
identity = class_name
|
||||
|
||||
# 设置到__mapper_args__
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 只在尚未设置polymorphic_identity时设置
|
||||
if 'polymorphic_identity' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_identity'] = identity
|
||||
|
||||
|
||||
class PolymorphicBaseMixin:
|
||||
"""
|
||||
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
|
||||
|
||||
此 Mixin 自动设置以下内容:
|
||||
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
|
||||
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
|
||||
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
|
||||
|
||||
使用场景:
|
||||
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
|
||||
|
||||
用法示例:
|
||||
```python
|
||||
from abc import ABC
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
|
||||
|
||||
# 定义基类
|
||||
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
__tablename__ = 'mytool'
|
||||
|
||||
# 不需要手动定义 _polymorphic_name
|
||||
# 不需要手动设置 polymorphic_on
|
||||
# 不需要手动设置 polymorphic_abstract
|
||||
|
||||
# 定义子类
|
||||
class SpecificTool(MyTool):
|
||||
__tablename__ = 'specifictool'
|
||||
|
||||
# 会自动继承 polymorphic 配置
|
||||
```
|
||||
|
||||
自动行为:
|
||||
1. 定义 `_polymorphic_name: str` 字段(带索引)
|
||||
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
|
||||
3. 自动检测抽象类:
|
||||
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
|
||||
- 否则设置为 False
|
||||
|
||||
手动覆盖:
|
||||
可以在类定义时手动指定参数来覆盖自动行为:
|
||||
```python
|
||||
class MyTool(
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
|
||||
polymorphic_abstract=False # 强制不设为抽象类
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
注意事项:
|
||||
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
|
||||
- 适用于联表继承(joined table inheritance)场景
|
||||
- 子类会自动继承 _polymorphic_name 字段定义
|
||||
- 使用单下划线前缀是因为:
|
||||
* SQLAlchemy 会映射单下划线字段为数据库列
|
||||
* Pydantic 将其视为私有属性,不参与序列化
|
||||
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
|
||||
"""
|
||||
|
||||
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
|
||||
#
|
||||
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
|
||||
#
|
||||
# 为什么这样做:
|
||||
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
|
||||
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
|
||||
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
|
||||
# 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema)
|
||||
# 5. 内部代码仍然可以正常访问和修改此字段
|
||||
#
|
||||
# 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
|
||||
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
|
||||
"""
|
||||
多态鉴别器字段,用于标识具体的子类类型
|
||||
|
||||
注意:此字段使用单下划线前缀,表示内部使用。
|
||||
- ✅ 存储到数据库
|
||||
- ✅ 不出现在 API 序列化中
|
||||
- ✅ 防止外部直接修改
|
||||
"""
|
||||
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
polymorphic_on: str | None = None,
|
||||
polymorphic_abstract: bool | None = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
在子类定义时自动配置 polymorphic 设置
|
||||
|
||||
Args:
|
||||
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。
|
||||
设置为其他值可以使用不同的字段作为多态鉴别器。
|
||||
polymorphic_abstract: 是否为抽象类。
|
||||
- None: 自动检测(默认)
|
||||
- True: 强制设为抽象类
|
||||
- False: 强制设为非抽象类
|
||||
**kwargs: 传递给父类的其他参数
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 初始化 __mapper_args__(如果还没有)
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 设置 polymorphic_on(默认为 _polymorphic_name)
|
||||
if 'polymorphic_on' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
|
||||
|
||||
# 自动检测或设置 polymorphic_abstract
|
||||
if 'polymorphic_abstract' not in cls.__mapper_args__:
|
||||
if polymorphic_abstract is None:
|
||||
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
|
||||
has_abc = ABC in cls.__mro__
|
||||
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
|
||||
polymorphic_abstract = has_abc and has_abstract_methods
|
||||
|
||||
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
|
||||
|
||||
@classmethod
|
||||
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
递归获取当前类的所有具体(非抽象)子类
|
||||
|
||||
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
|
||||
可在任意多态基类上调用,返回该类的所有非抽象子类。
|
||||
|
||||
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
|
||||
"""
|
||||
result: list[type[PolymorphicBaseMixin]] = []
|
||||
for subclass in cls.__subclasses__():
|
||||
# 使用 inspect() 获取 mapper 的公开属性
|
||||
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
|
||||
mapper = inspect(subclass)
|
||||
if not mapper.polymorphic_abstract:
|
||||
result.append(subclass)
|
||||
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
|
||||
if hasattr(subclass, 'get_concrete_subclasses'):
|
||||
result.extend(subclass.get_concrete_subclasses())
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_polymorphic_discriminator(cls) -> str:
|
||||
"""
|
||||
获取多态鉴别字段名
|
||||
|
||||
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
|
||||
|
||||
:return: 多态鉴别字段名(如 '_polymorphic_name')
|
||||
:raises ValueError: 如果类未配置 polymorphic_on
|
||||
"""
|
||||
polymorphic_on = inspect(cls).polymorphic_on
|
||||
if polymorphic_on is None:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} 未配置 polymorphic_on,"
|
||||
f"请确保正确继承 PolymorphicBaseMixin"
|
||||
)
|
||||
return polymorphic_on.key
|
||||
|
||||
@classmethod
|
||||
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
获取 polymorphic_identity 到具体子类的映射
|
||||
|
||||
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
|
||||
|
||||
:return: identity 到子类的映射字典
|
||||
"""
|
||||
result: dict[str, type[PolymorphicBaseMixin]] = {}
|
||||
for subclass in cls.get_concrete_subclasses():
|
||||
identity = inspect(subclass).polymorphic_identity
|
||||
if identity:
|
||||
result[identity] = subclass
|
||||
return result
|
||||
Reference in New Issue
Block a user