- Implement PhysicalFile model to manage physical file references and reference counting. - Create Policy model with associated options and group links for storage policies. - Introduce Redeem and Report models for handling redeem codes and reports. - Add Settings model for site configuration and user settings management. - Develop Share model for sharing objects with unique codes and associated metadata. - Implement SourceLink model for managing download links associated with objects. - Create StoragePack model for managing user storage packages. - Add Tag model for user-defined tags with manual and automatic types. - Implement Task model for managing background tasks with status tracking. - Develop User model with comprehensive user management features including authentication. - Introduce UserAuthn model for managing WebAuthn credentials. - Create WebDAV model for managing WebDAV accounts associated with users.
711 lines
28 KiB
Python
711 lines
28 KiB
Python
"""
|
||
联表继承(Joined Table Inheritance)的通用工具
|
||
|
||
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
|
||
|
||
Usage Example:
|
||
|
||
from sqlmodels.base import SQLModelBase
|
||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||
from sqlmodels.mixin.polymorphic import (
|
||
PolymorphicBaseMixin,
|
||
create_subclass_id_mixin,
|
||
AutoPolymorphicIdentityMixin
|
||
)
|
||
|
||
# 1. 定义Base类(只有字段,无表)
|
||
class ASRBase(SQLModelBase):
|
||
name: str
|
||
\"\"\"配置名称\"\"\"
|
||
|
||
base_url: str
|
||
\"\"\"服务地址\"\"\"
|
||
|
||
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
|
||
class ASR(
|
||
ASRBase,
|
||
UUIDTableBaseMixin,
|
||
PolymorphicBaseMixin,
|
||
ABC
|
||
):
|
||
\"\"\"ASR配置的抽象基类\"\"\"
|
||
# PolymorphicBaseMixin 自动提供:
|
||
# - _polymorphic_name 字段
|
||
# - polymorphic_on='_polymorphic_name'
|
||
# - polymorphic_abstract=True(当有抽象方法时)
|
||
|
||
# 3. 为第二层子类创建ID Mixin
|
||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||
|
||
# 4. 创建第二层抽象类(如果需要)
|
||
class FunASR(
|
||
ASRSubclassIdMixin,
|
||
ASR,
|
||
AutoPolymorphicIdentityMixin,
|
||
polymorphic_abstract=True
|
||
):
|
||
\"\"\"FunASR的抽象基类,可能有多个实现\"\"\"
|
||
pass
|
||
|
||
# 5. 创建具体实现类
|
||
class FunASRLocal(FunASR, table=True):
|
||
\"\"\"FunASR本地部署版本\"\"\"
|
||
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
|
||
pass
|
||
|
||
# 6. 获取所有具体子类(用于 selectin_polymorphic)
|
||
concrete_asrs = ASR.get_concrete_subclasses()
|
||
# 返回 [FunASRLocal, ...]
|
||
"""
|
||
import uuid
|
||
from abc import ABC
|
||
from uuid import UUID
|
||
|
||
from loguru import logger as l
|
||
from pydantic.fields import FieldInfo
|
||
from pydantic_core import PydanticUndefined
|
||
from sqlalchemy import Column, String, inspect
|
||
from sqlalchemy.orm import ColumnProperty, Mapped, mapped_column
|
||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||
from sqlmodel import Field
|
||
from sqlmodel.main import get_column_from_field
|
||
|
||
from sqlmodels.base.sqlmodel_base import SQLModelBase
|
||
|
||
# 用于延迟注册 STI 子类列的队列
|
||
# 在所有模型加载完成后,调用 register_sti_columns_for_all_subclasses() 处理
|
||
_sti_subclasses_to_register: list[type] = []
|
||
|
||
|
||
def register_sti_columns_for_all_subclasses() -> None:
|
||
"""
|
||
为所有已注册的 STI 子类执行列注册(第一阶段:添加列到表)
|
||
|
||
此函数应在 configure_mappers() 之前调用。
|
||
将 STI 子类的字段添加到父表的 metadata 中。
|
||
同时修复被 Column 对象污染的 model_fields。
|
||
"""
|
||
for cls in _sti_subclasses_to_register:
|
||
try:
|
||
cls._register_sti_columns()
|
||
except Exception as e:
|
||
l.warning(f"注册 STI 子类 {cls.__name__} 的列时出错: {e}")
|
||
|
||
# 修复被 Column 对象污染的 model_fields
|
||
# 必须在列注册后立即修复,因为 Column 污染在类定义时就已发生
|
||
try:
|
||
_fix_polluted_model_fields(cls)
|
||
except Exception as e:
|
||
l.warning(f"修复 STI 子类 {cls.__name__} 的 model_fields 时出错: {e}")
|
||
|
||
|
||
def register_sti_column_properties_for_all_subclasses() -> None:
|
||
"""
|
||
为所有已注册的 STI 子类添加列属性到 mapper(第二阶段)
|
||
|
||
此函数应在 configure_mappers() 之后调用。
|
||
将 STI 子类的字段作为 ColumnProperty 添加到 mapper 中。
|
||
"""
|
||
for cls in _sti_subclasses_to_register:
|
||
try:
|
||
cls._register_sti_column_properties()
|
||
except Exception as e:
|
||
l.warning(f"注册 STI 子类 {cls.__name__} 的列属性时出错: {e}")
|
||
|
||
# 清空队列
|
||
_sti_subclasses_to_register.clear()
|
||
|
||
|
||
def _fix_polluted_model_fields(cls: type) -> None:
|
||
"""
|
||
修复被 SQLAlchemy InstrumentedAttribute 或 Column 污染的 model_fields
|
||
|
||
当 SQLModel 类继承有表的父类时,SQLAlchemy 会在类上创建 InstrumentedAttribute
|
||
或 Column 对象替换原始的字段默认值。这会导致 Pydantic 在构建子类 model_fields
|
||
时错误地使用这些 SQLAlchemy 对象作为默认值。
|
||
|
||
此函数从 MRO 中查找原始的字段定义,并修复被污染的 model_fields。
|
||
|
||
:param cls: 要修复的类
|
||
"""
|
||
if not hasattr(cls, 'model_fields'):
|
||
return
|
||
|
||
def find_original_field_info(field_name: str) -> FieldInfo | None:
|
||
"""从 MRO 中查找字段的原始定义(未被污染的)"""
|
||
for base in cls.__mro__[1:]: # 跳过自己
|
||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
||
field_info = base.model_fields[field_name]
|
||
# 跳过被 InstrumentedAttribute 或 Column 污染的
|
||
if not isinstance(field_info.default, (InstrumentedAttribute, Column)):
|
||
return field_info
|
||
return None
|
||
|
||
for field_name, current_field in cls.model_fields.items():
|
||
# 检查是否被污染(default 是 InstrumentedAttribute 或 Column)
|
||
# Column 污染发生在 STI 继承链中:当 FunctionBase.show_arguments = True
|
||
# 被继承到有表的子类时,SQLModel 会创建一个 Column 对象替换原始默认值
|
||
if not isinstance(current_field.default, (InstrumentedAttribute, Column)):
|
||
continue # 未被污染,跳过
|
||
|
||
# 从父类查找原始定义
|
||
original = find_original_field_info(field_name)
|
||
if original is None:
|
||
continue # 找不到原始定义,跳过
|
||
|
||
# 根据原始定义的 default/default_factory 来修复
|
||
if original.default_factory:
|
||
# 有 default_factory(如 uuid.uuid4, now)
|
||
new_field = FieldInfo(
|
||
default_factory=original.default_factory,
|
||
annotation=current_field.annotation,
|
||
json_schema_extra=current_field.json_schema_extra,
|
||
)
|
||
elif original.default is not PydanticUndefined:
|
||
# 有明确的 default 值(如 None, 0, True),且不是 PydanticUndefined
|
||
# PydanticUndefined 表示字段没有默认值(必填)
|
||
new_field = FieldInfo(
|
||
default=original.default,
|
||
annotation=current_field.annotation,
|
||
json_schema_extra=current_field.json_schema_extra,
|
||
)
|
||
else:
|
||
continue # 既没有 default_factory 也没有有效的 default,跳过
|
||
|
||
# 复制 SQLModel 特有的属性
|
||
if hasattr(current_field, 'foreign_key'):
|
||
new_field.foreign_key = current_field.foreign_key
|
||
if hasattr(current_field, 'primary_key'):
|
||
new_field.primary_key = current_field.primary_key
|
||
|
||
cls.model_fields[field_name] = new_field
|
||
|
||
|
||
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
|
||
"""
|
||
动态创建SubclassIdMixin类
|
||
|
||
在联表继承中,子类需要一个外键指向父表的主键。
|
||
此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。
|
||
|
||
Args:
|
||
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function')
|
||
|
||
Returns:
|
||
一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4)
|
||
|
||
Example:
|
||
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
|
||
... pass
|
||
|
||
Note:
|
||
- 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id
|
||
- 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase)
|
||
- 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin)
|
||
"""
|
||
if not parent_table_name:
|
||
raise ValueError("parent_table_name 不能为空")
|
||
|
||
# 转换为PascalCase作为类名
|
||
class_name_parts = parent_table_name.split('_')
|
||
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
|
||
|
||
# 使用闭包捕获parent_table_name
|
||
_parent_table_name = parent_table_name
|
||
|
||
# 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields
|
||
class SubclassIdMixin(SQLModelBase):
|
||
# 定义id字段
|
||
id: UUID = Field(
|
||
default_factory=uuid.uuid4,
|
||
foreign_key=f'{_parent_table_name}.id',
|
||
primary_key=True,
|
||
)
|
||
|
||
@classmethod
|
||
def __pydantic_init_subclass__(cls, **kwargs):
|
||
"""
|
||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||
|
||
修复联表继承中子类字段的 default_factory 丢失问题。
|
||
SQLAlchemy 的 InstrumentedAttribute 或 Column 会污染从父类继承的字段,
|
||
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
|
||
"""
|
||
super().__pydantic_init_subclass__(**kwargs)
|
||
_fix_polluted_model_fields(cls)
|
||
|
||
# 设置类名和文档
|
||
SubclassIdMixin.__name__ = class_name
|
||
SubclassIdMixin.__qualname__ = class_name
|
||
SubclassIdMixin.__doc__ = f"""
|
||
{parent_table_name}子类的ID Mixin
|
||
|
||
用于{parent_table_name}的子类,提供外键指向父表。
|
||
通过MRO确保此id字段覆盖继承的id字段。
|
||
"""
|
||
|
||
return SubclassIdMixin
|
||
|
||
|
||
class AutoPolymorphicIdentityMixin:
|
||
"""
|
||
自动生成polymorphic_identity的Mixin,并支持STI子类列注册
|
||
|
||
使用此Mixin的类会自动根据类名生成polymorphic_identity。
|
||
格式:{parent_polymorphic_identity}.{classname_lowercase}
|
||
|
||
如果没有父类的polymorphic_identity,则直接使用类名小写。
|
||
|
||
**重要:数据库迁移注意事项**
|
||
|
||
编写数据迁移脚本时,必须使用完整的 polymorphic_identity 格式,包括父类前缀!
|
||
|
||
例如,对于以下继承链::
|
||
|
||
LLM (polymorphic_on='_polymorphic_name')
|
||
└── AnthropicCompatibleLLM (polymorphic_identity='anthropiccompatiblellm')
|
||
└── TuziAnthropicLLM (polymorphic_identity='anthropiccompatiblellm.tuzianthropicllm')
|
||
|
||
迁移脚本中设置 _polymorphic_name 时::
|
||
|
||
# ❌ 错误:缺少父类前缀
|
||
UPDATE llm SET _polymorphic_name = 'tuzianthropicllm' WHERE id = :id
|
||
|
||
# ✅ 正确:包含完整的继承链前缀
|
||
UPDATE llm SET _polymorphic_name = 'anthropiccompatiblellm.tuzianthropicllm' WHERE id = :id
|
||
|
||
**STI(单表继承)支持**:
|
||
当子类与父类共用同一张表(STI模式)时,此Mixin会自动将子类的新字段
|
||
添加到父表的列定义中。这解决了SQLModel在STI模式下子类字段不被
|
||
注册到父表的问题。
|
||
|
||
Example (JTI):
|
||
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
|
||
... __polymorphic_name: str
|
||
...
|
||
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||
... pass
|
||
... # polymorphic_identity 会自动设置为 'function'
|
||
...
|
||
>>> class CodeInterpreterFunction(Function, table=True):
|
||
... pass
|
||
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
|
||
|
||
Example (STI):
|
||
>>> class UserFile(UUIDTableBaseMixin, PolymorphicBaseMixin, table=True, polymorphic_abstract=True):
|
||
... user_id: UUID
|
||
...
|
||
>>> class PendingFile(UserFile, AutoPolymorphicIdentityMixin, table=True):
|
||
... upload_deadline: datetime | None = None # 自动添加到 userfile 表
|
||
... # polymorphic_identity 会自动设置为 'pendingfile'
|
||
|
||
Note:
|
||
- 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留
|
||
- 此Mixin应该在继承列表中靠后的位置(在表基类之前)
|
||
- STI模式下,新字段会在类定义时自动添加到父表的metadata中
|
||
"""
|
||
|
||
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
|
||
"""
|
||
子类化钩子,自动生成polymorphic_identity并处理STI列注册
|
||
|
||
Args:
|
||
polymorphic_identity: 如果手动指定,则使用指定的值
|
||
**kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True)
|
||
"""
|
||
super().__init_subclass__(**kwargs)
|
||
|
||
# 如果手动指定了polymorphic_identity,使用指定的值
|
||
if polymorphic_identity is not None:
|
||
identity = polymorphic_identity
|
||
else:
|
||
# 自动生成polymorphic_identity
|
||
class_name = cls.__name__.lower()
|
||
|
||
# 尝试从父类获取polymorphic_identity作为前缀
|
||
parent_identity = None
|
||
for base in cls.__mro__[1:]: # 跳过自己
|
||
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
|
||
parent_identity = base.__mapper_args__.get('polymorphic_identity')
|
||
if parent_identity:
|
||
break
|
||
|
||
# 构建identity
|
||
if parent_identity:
|
||
identity = f'{parent_identity}.{class_name}'
|
||
else:
|
||
identity = class_name
|
||
|
||
# 设置到__mapper_args__
|
||
if '__mapper_args__' not in cls.__dict__:
|
||
cls.__mapper_args__ = {}
|
||
|
||
# 只在尚未设置polymorphic_identity时设置
|
||
if 'polymorphic_identity' not in cls.__mapper_args__:
|
||
cls.__mapper_args__['polymorphic_identity'] = identity
|
||
|
||
# 注册 STI 子类列的延迟执行
|
||
# 由于 __init_subclass__ 在类定义过程中被调用,此时 model_fields 还不完整
|
||
# 需要在模块加载完成后调用 register_sti_columns_for_all_subclasses()
|
||
_sti_subclasses_to_register.append(cls)
|
||
|
||
@classmethod
|
||
def __pydantic_init_subclass__(cls, **kwargs):
|
||
"""
|
||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||
|
||
修复 STI 继承中子类字段被 Column 对象污染的问题。
|
||
当 FunctionBase.show_arguments = True 等字段被继承到有表的子类时,
|
||
SQLModel 会创建一个 Column 对象替换原始默认值,导致实例化时字段值不正确。
|
||
"""
|
||
super().__pydantic_init_subclass__(**kwargs)
|
||
_fix_polluted_model_fields(cls)
|
||
|
||
@classmethod
|
||
def _register_sti_columns(cls) -> None:
|
||
"""
|
||
将STI子类的新字段注册到父表的列定义中
|
||
|
||
检测当前类是否是STI子类(与父类共用同一张表),
|
||
如果是,则将子类定义的新字段添加到父表的metadata中。
|
||
|
||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
||
"""
|
||
# 查找父表(在 MRO 中找到第一个有 __table__ 的父类)
|
||
parent_table = None
|
||
parent_fields: set[str] = set()
|
||
|
||
for base in cls.__mro__[1:]:
|
||
if hasattr(base, '__table__') and base.__table__ is not None:
|
||
parent_table = base.__table__
|
||
# 收集父类的所有字段名
|
||
if hasattr(base, 'model_fields'):
|
||
parent_fields.update(base.model_fields.keys())
|
||
break
|
||
|
||
if parent_table is None:
|
||
return # 没有找到父表,可能是根类
|
||
|
||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
||
# JTI 类有自己独立的表,不需要将列注册到父表
|
||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
||
if cls.__table__.name != parent_table.name:
|
||
return # JTI,跳过 STI 列注册
|
||
|
||
# 获取当前类的新字段(不在父类中的字段)
|
||
if not hasattr(cls, 'model_fields'):
|
||
return
|
||
|
||
existing_columns = {col.name for col in parent_table.columns}
|
||
|
||
for field_name, field_info in cls.model_fields.items():
|
||
# 跳过从父类继承的字段
|
||
if field_name in parent_fields:
|
||
continue
|
||
|
||
# 跳过私有字段和ClassVar
|
||
if field_name.startswith('_'):
|
||
continue
|
||
|
||
# 跳过已存在的列
|
||
if field_name in existing_columns:
|
||
continue
|
||
|
||
# 使用 SQLModel 的内置 API 创建列
|
||
try:
|
||
column = get_column_from_field(field_info)
|
||
column.name = field_name
|
||
column.key = field_name
|
||
# STI子类字段在数据库层面必须可空,因为其他子类的行不会有这些字段的值
|
||
# Pydantic层面的约束仍然有效(创建特定子类时会验证必填字段)
|
||
column.nullable = True
|
||
|
||
# 将列添加到父表
|
||
parent_table.append_column(column)
|
||
except Exception as e:
|
||
l.warning(f"为 {cls.__name__} 创建列 {field_name} 失败: {e}")
|
||
|
||
@classmethod
|
||
def _register_sti_column_properties(cls) -> None:
|
||
"""
|
||
将 STI 子类的列作为 ColumnProperty 添加到 mapper
|
||
|
||
此方法在 configure_mappers() 之后调用,将已添加到表中的列
|
||
注册为 mapper 的属性,使 ORM 查询能正确识别这些列。
|
||
|
||
**重要**:子类的列属性会同时注册到子类和父类的 mapper 上。
|
||
这确保了查询父类时,SELECT 语句包含所有 STI 子类的列,
|
||
避免在响应序列化时触发懒加载(MissingGreenlet 错误)。
|
||
|
||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
||
"""
|
||
# 查找父表和父类(在 MRO 中找到第一个有 __table__ 的父类)
|
||
parent_table = None
|
||
parent_class = None
|
||
for base in cls.__mro__[1:]:
|
||
if hasattr(base, '__table__') and base.__table__ is not None:
|
||
parent_table = base.__table__
|
||
parent_class = base
|
||
break
|
||
|
||
if parent_table is None:
|
||
return # 没有找到父表,可能是根类
|
||
|
||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
||
# JTI 类有自己独立的表,不需要将列属性注册到 mapper
|
||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
||
if cls.__table__.name != parent_table.name:
|
||
return # JTI,跳过 STI 列属性注册
|
||
|
||
# 获取子类和父类的 mapper
|
||
child_mapper = inspect(cls).mapper
|
||
parent_mapper = inspect(parent_class).mapper
|
||
local_table = child_mapper.local_table
|
||
|
||
# 查找父类的所有字段名
|
||
parent_fields: set[str] = set()
|
||
if hasattr(parent_class, 'model_fields'):
|
||
parent_fields.update(parent_class.model_fields.keys())
|
||
|
||
if not hasattr(cls, 'model_fields'):
|
||
return
|
||
|
||
# 获取两个 mapper 已有的列属性
|
||
child_existing_props = {p.key for p in child_mapper.column_attrs}
|
||
parent_existing_props = {p.key for p in parent_mapper.column_attrs}
|
||
|
||
for field_name in cls.model_fields:
|
||
# 跳过从父类继承的字段
|
||
if field_name in parent_fields:
|
||
continue
|
||
|
||
# 跳过私有字段
|
||
if field_name.startswith('_'):
|
||
continue
|
||
|
||
# 检查表中是否有这个列
|
||
if field_name not in local_table.columns:
|
||
continue
|
||
|
||
column = local_table.columns[field_name]
|
||
|
||
# 添加到子类的 mapper(如果尚不存在)
|
||
if field_name not in child_existing_props:
|
||
try:
|
||
prop = ColumnProperty(column)
|
||
child_mapper.add_property(field_name, prop)
|
||
except Exception as e:
|
||
l.warning(f"为 {cls.__name__} 添加列属性 {field_name} 失败: {e}")
|
||
|
||
# 同时添加到父类的 mapper(确保查询父类时 SELECT 包含所有 STI 子类的列)
|
||
if field_name not in parent_existing_props:
|
||
try:
|
||
prop = ColumnProperty(column)
|
||
parent_mapper.add_property(field_name, prop)
|
||
except Exception as e:
|
||
l.warning(f"为父类 {parent_class.__name__} 添加子类 {cls.__name__} 的列属性 {field_name} 失败: {e}")
|
||
|
||
|
||
class PolymorphicBaseMixin:
|
||
"""
|
||
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
|
||
|
||
此 Mixin 自动设置以下内容:
|
||
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
|
||
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
|
||
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
|
||
|
||
使用场景:
|
||
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
|
||
|
||
用法示例:
|
||
```python
|
||
from abc import ABC
|
||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
|
||
|
||
# 定义基类
|
||
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||
__tablename__ = 'mytool'
|
||
|
||
# 不需要手动定义 _polymorphic_name
|
||
# 不需要手动设置 polymorphic_on
|
||
# 不需要手动设置 polymorphic_abstract
|
||
|
||
# 定义子类
|
||
class SpecificTool(MyTool):
|
||
__tablename__ = 'specifictool'
|
||
|
||
# 会自动继承 polymorphic 配置
|
||
```
|
||
|
||
自动行为:
|
||
1. 定义 `_polymorphic_name: str` 字段(带索引)
|
||
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
|
||
3. 自动检测抽象类:
|
||
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
|
||
- 否则设置为 False
|
||
|
||
手动覆盖:
|
||
可以在类定义时手动指定参数来覆盖自动行为:
|
||
```python
|
||
class MyTool(
|
||
UUIDTableBaseMixin,
|
||
PolymorphicBaseMixin,
|
||
ABC,
|
||
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
|
||
polymorphic_abstract=False # 强制不设为抽象类
|
||
):
|
||
pass
|
||
```
|
||
|
||
注意事项:
|
||
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
|
||
- 适用于联表继承(joined table inheritance)场景
|
||
- 子类会自动继承 _polymorphic_name 字段定义
|
||
- 使用单下划线前缀是因为:
|
||
* SQLAlchemy 会映射单下划线字段为数据库列
|
||
* Pydantic 将其视为私有属性,不参与序列化
|
||
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
|
||
"""
|
||
|
||
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
|
||
#
|
||
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
|
||
#
|
||
# 为什么这样做:
|
||
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
|
||
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
|
||
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
|
||
# 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema)
|
||
# 5. 内部代码仍然可以正常访问和修改此字段
|
||
#
|
||
# 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
|
||
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
|
||
"""
|
||
多态鉴别器字段,用于标识具体的子类类型
|
||
|
||
注意:此字段使用单下划线前缀,表示内部使用。
|
||
- ✅ 存储到数据库
|
||
- ✅ 不出现在 API 序列化中
|
||
- ✅ 防止外部直接修改
|
||
"""
|
||
|
||
def __init_subclass__(
|
||
cls,
|
||
polymorphic_on: str | None = None,
|
||
polymorphic_abstract: bool | None = None,
|
||
**kwargs
|
||
):
|
||
"""
|
||
在子类定义时自动配置 polymorphic 设置
|
||
|
||
Args:
|
||
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。
|
||
设置为其他值可以使用不同的字段作为多态鉴别器。
|
||
polymorphic_abstract: 是否为抽象类。
|
||
- None: 自动检测(默认)
|
||
- True: 强制设为抽象类
|
||
- False: 强制设为非抽象类
|
||
**kwargs: 传递给父类的其他参数
|
||
"""
|
||
super().__init_subclass__(**kwargs)
|
||
|
||
# 初始化 __mapper_args__(如果还没有)
|
||
if '__mapper_args__' not in cls.__dict__:
|
||
cls.__mapper_args__ = {}
|
||
|
||
# 设置 polymorphic_on(默认为 _polymorphic_name)
|
||
if 'polymorphic_on' not in cls.__mapper_args__:
|
||
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
|
||
|
||
# 自动检测或设置 polymorphic_abstract
|
||
if 'polymorphic_abstract' not in cls.__mapper_args__:
|
||
if polymorphic_abstract is None:
|
||
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
|
||
has_abc = ABC in cls.__mro__
|
||
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
|
||
polymorphic_abstract = has_abc and has_abstract_methods
|
||
|
||
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
|
||
|
||
@classmethod
|
||
def _is_joined_table_inheritance(cls) -> bool:
|
||
"""
|
||
检测当前类是否使用联表继承(Joined Table Inheritance)
|
||
|
||
通过检查子类是否有独立的表来判断:
|
||
- JTI: 子类有独立的 local_table(与父类不同)
|
||
- STI: 子类与父类共用同一个 local_table
|
||
|
||
:return: True 表示 JTI,False 表示 STI 或无子类
|
||
"""
|
||
mapper = inspect(cls)
|
||
base_table_name = mapper.local_table.name
|
||
|
||
# 检查所有直接子类
|
||
for subclass in cls.__subclasses__():
|
||
sub_mapper = inspect(subclass)
|
||
# 如果任何子类有不同的表名,说明是 JTI
|
||
if sub_mapper.local_table.name != base_table_name:
|
||
return True
|
||
|
||
return False
|
||
|
||
@classmethod
|
||
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
|
||
"""
|
||
递归获取当前类的所有具体(非抽象)子类
|
||
|
||
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
|
||
可在任意多态基类上调用,返回该类的所有非抽象子类。
|
||
|
||
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
|
||
"""
|
||
result: list[type[PolymorphicBaseMixin]] = []
|
||
for subclass in cls.__subclasses__():
|
||
# 使用 inspect() 获取 mapper 的公开属性
|
||
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
|
||
mapper = inspect(subclass)
|
||
if not mapper.polymorphic_abstract:
|
||
result.append(subclass)
|
||
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
|
||
if hasattr(subclass, 'get_concrete_subclasses'):
|
||
result.extend(subclass.get_concrete_subclasses())
|
||
return result
|
||
|
||
@classmethod
|
||
def get_polymorphic_discriminator(cls) -> str:
|
||
"""
|
||
获取多态鉴别字段名
|
||
|
||
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
|
||
|
||
:return: 多态鉴别字段名(如 '_polymorphic_name')
|
||
:raises ValueError: 如果类未配置 polymorphic_on
|
||
"""
|
||
polymorphic_on = inspect(cls).polymorphic_on
|
||
if polymorphic_on is None:
|
||
raise ValueError(
|
||
f"{cls.__name__} 未配置 polymorphic_on,"
|
||
f"请确保正确继承 PolymorphicBaseMixin"
|
||
)
|
||
return polymorphic_on.key
|
||
|
||
@classmethod
|
||
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
|
||
"""
|
||
获取 polymorphic_identity 到具体子类的映射
|
||
|
||
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
|
||
|
||
:return: identity 到子类的映射字典
|
||
"""
|
||
result: dict[str, type[PolymorphicBaseMixin]] = {}
|
||
for subclass in cls.get_concrete_subclasses():
|
||
identity = inspect(subclass).polymorphic_identity
|
||
if identity:
|
||
result[identity] = subclass
|
||
return result
|