feat: add models for physical files, policies, and user management

- Implement PhysicalFile model to manage physical file references and reference counting.
- Create Policy model with associated options and group links for storage policies.
- Introduce Redeem and Report models for handling redeem codes and reports.
- Add Settings model for site configuration and user settings management.
- Develop Share model for sharing objects with unique codes and associated metadata.
- Implement SourceLink model for managing download links associated with objects.
- Create StoragePack model for managing user storage packages.
- Add Tag model for user-defined tags with manual and automatic types.
- Implement Task model for managing background tasks with status tracking.
- Develop User model with comprehensive user management features including authentication.
- Introduce UserAuthn model for managing WebAuthn credentials.
- Create WebDAV model for managing WebDAV accounts associated with users.
This commit is contained in:
2026-02-10 16:25:49 +08:00
parent 62c671e07b
commit 209cb24ab4
92 changed files with 3640 additions and 1444 deletions

543
sqlmodels/mixin/README.md Normal file
View File

@@ -0,0 +1,543 @@
# SQLModel Mixin Module
This module provides composable Mixin classes for SQLModel entities, enabling reusable functionality such as CRUD operations, polymorphic inheritance, JWT authentication, and standardized response DTOs.
## Module Overview
The `sqlmodels.mixin` module contains various Mixin classes that follow the "Composition over Inheritance" design philosophy. These mixins provide:
- **CRUD Operations**: Async database operations (add, save, update, delete, get, count)
- **Polymorphic Inheritance**: Tools for joined table inheritance patterns
- **JWT Authentication**: Token generation and validation
- **Pagination & Sorting**: Standardized table view parameters
- **Response DTOs**: Consistent id/timestamp fields for API responses
## Module Structure
```
sqlmodels/mixin/
├── __init__.py # Module exports
├── polymorphic.py # PolymorphicBaseMixin, create_subclass_id_mixin, AutoPolymorphicIdentityMixin
├── table.py # TableBaseMixin, UUIDTableBaseMixin, TableViewRequest
├── info_response.py # Response DTO Mixins (IntIdInfoMixin, UUIDIdInfoMixin, etc.)
└── jwt/ # JWT authentication
├── __init__.py
├── key.py # JWTKey database model
├── payload.py # JWTPayloadBase
├── manager.py # JWTManager singleton
├── auth.py # JWTAuthMixin
├── exceptions.py # JWT-related exceptions
└── responses.py # TokenResponse DTO
```
## Dependency Hierarchy
The module has a strict import order to avoid circular dependencies:
1. **polymorphic.py** - Only depends on `SQLModelBase`
2. **table.py** - Depends on `polymorphic.py`
3. **jwt/** - May depend on both `polymorphic.py` and `table.py`
4. **info_response.py** - Only depends on `SQLModelBase`
## Core Components
### 1. TableBaseMixin
Base mixin for database table models with integer primary keys.
**Features:**
- Provides CRUD methods: `add()`, `save()`, `update()`, `delete()`, `get()`, `count()`, `get_exist_one()`
- Automatic timestamp management (`created_at`, `updated_at`)
- Async relationship loading support (via `AsyncAttrs`)
- Pagination and sorting via `TableViewRequest`
- Polymorphic subclass loading support
**Fields:**
- `id: int | None` - Integer primary key (auto-increment)
- `created_at: datetime` - Record creation timestamp
- `updated_at: datetime` - Record update timestamp (auto-updated)
**Usage:**
```python
from sqlmodels.mixin import TableBaseMixin
from sqlmodels.base import SQLModelBase
class User(SQLModelBase, TableBaseMixin, table=True):
name: str
email: str
"""User email"""
# CRUD operations
async def example(session: AsyncSession):
# Add
user = User(name="Alice", email="alice@example.com")
user = await user.save(session)
# Get
user = await User.get(session, User.id == 1)
# Update
update_data = UserUpdateRequest(name="Alice Smith")
user = await user.update(session, update_data)
# Delete
await User.delete(session, user)
# Count
count = await User.count(session, User.is_active == True)
```
**Important Notes:**
- `save()` and `update()` return refreshed instances - **always use the return value**:
```python
# ✅ Correct
device = await device.save(session)
return device
# ❌ Wrong - device is expired after commit
await device.save(session)
return device
```
### 2. UUIDTableBaseMixin
Extends `TableBaseMixin` with UUID primary keys instead of integers.
**Differences from TableBaseMixin:**
- `id: UUID` - UUID primary key (auto-generated via `uuid.uuid4()`)
- `get_exist_one()` accepts `UUID` instead of `int`
**Usage:**
```python
from sqlmodels.mixin import UUIDTableBaseMixin
class Character(SQLModelBase, UUIDTableBaseMixin, table=True):
name: str
description: str | None = None
"""Character description"""
```
**Recommendation:** Use `UUIDTableBaseMixin` for most new models, as UUIDs provide better scalability and avoid ID collisions.
### 3. TableViewRequest
Standardized pagination and sorting parameters for LIST endpoints.
**Fields:**
- `offset: int | None` - Skip first N records (default: 0)
- `limit: int | None` - Return max N records (default: 50, max: 100)
- `desc: bool | None` - Sort descending (default: True)
- `order: Literal["created_at", "updated_at"] | None` - Sort field (default: "created_at")
**Usage with TableBaseMixin.get():**
```python
from dependencies import TableViewRequestDep
@router.get("/list")
async def list_characters(
session: SessionDep,
table_view: TableViewRequestDep
) -> list[Character]:
"""List characters with pagination and sorting"""
return await Character.get(
session,
fetch_mode="all",
table_view=table_view # Automatically handles pagination and sorting
)
```
**Manual usage:**
```python
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
characters = await Character.get(session, fetch_mode="all", table_view=table_view)
```
**Backward Compatibility:**
The traditional `offset`, `limit`, `order_by` parameters still work, but `table_view` is recommended for new code.
### 4. PolymorphicBaseMixin
Base mixin for joined table inheritance, automatically configuring polymorphic settings.
**Automatic Configuration:**
- Defines `_polymorphic_name: str` field (indexed)
- Sets `polymorphic_on='_polymorphic_name'`
- Detects abstract classes (via ABC and abstract methods) and sets `polymorphic_abstract=True`
**Methods:**
- `get_concrete_subclasses()` - Get all non-abstract subclasses (for `selectin_polymorphic`)
- `get_polymorphic_discriminator()` - Get the polymorphic discriminator field name
- `get_identity_to_class_map()` - Map `polymorphic_identity` to subclass types
**Usage:**
```python
from abc import ABC, abstractmethod
from sqlmodels.mixin import PolymorphicBaseMixin, UUIDTableBaseMixin
class Tool(PolymorphicBaseMixin, UUIDTableBaseMixin, ABC):
"""Abstract base class for all tools"""
name: str
description: str
"""Tool description"""
@abstractmethod
async def execute(self, params: dict) -> dict:
"""Execute the tool"""
pass
```
**Why Single Underscore Prefix?**
- SQLAlchemy maps single-underscore fields to database columns
- Pydantic treats them as private (excluded from serialization)
- Double-underscore fields would be excluded by SQLAlchemy (not mapped to database)
### 5. create_subclass_id_mixin()
Factory function to create ID mixins for subclasses in joined table inheritance.
**Purpose:** In joined table inheritance, subclasses need a foreign key pointing to the parent table's primary key. This function generates a mixin class providing that foreign key field.
**Signature:**
```python
def create_subclass_id_mixin(parent_table_name: str) -> type[SQLModelBase]:
"""
Args:
parent_table_name: Parent table name (e.g., 'asr', 'tts', 'tool', 'function')
Returns:
A mixin class containing id field (foreign key + primary key)
"""
```
**Usage:**
```python
from sqlmodels.mixin import create_subclass_id_mixin
# Create mixin for ASR subclasses
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
class FunASR(ASRSubclassIdMixin, ASR, AutoPolymorphicIdentityMixin, table=True):
"""FunASR implementation"""
pass
```
**Important:** The ID mixin **must be first in the inheritance list** to ensure MRO (Method Resolution Order) correctly overrides the parent's `id` field.
### 6. AutoPolymorphicIdentityMixin
Automatically generates `polymorphic_identity` based on class name.
**Naming Convention:**
- Format: `{parent_identity}.{classname_lowercase}`
- If no parent identity exists, uses `{classname_lowercase}`
**Usage:**
```python
from sqlmodels.mixin import AutoPolymorphicIdentityMixin
class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
"""Base class for function-type tools"""
pass
# polymorphic_identity = 'function'
class GetWeatherFunction(Function, table=True):
"""Weather query function"""
pass
# polymorphic_identity = 'function.getweatherfunction'
```
**Manual Override:**
```python
class CustomTool(
Tool,
AutoPolymorphicIdentityMixin,
polymorphic_identity='custom_name', # Override auto-generated name
table=True
):
pass
```
### 7. JWTAuthMixin
Provides JWT token generation and validation for entity classes (User, Client).
**Methods:**
- `async issue_jwt(session: AsyncSession) -> str` - Generate JWT token for current instance
- `@classmethod async from_jwt(session: AsyncSession, token: str) -> Self` - Validate token and retrieve entity
**Requirements:**
Subclasses must define:
- `JWTPayload` - Payload model (inherits from `JWTPayloadBase`)
- `jwt_key_purpose` - ClassVar specifying the JWT key purpose enum value
**Usage:**
```python
from sqlmodels.mixin import JWTAuthMixin, UUIDTableBaseMixin
class User(SQLModelBase, UUIDTableBaseMixin, JWTAuthMixin, table=True):
JWTPayload = UserJWTPayload # Define payload model
jwt_key_purpose: ClassVar[JWTKeyPurposeEnum] = JWTKeyPurposeEnum.user
email: str
is_admin: bool = False
is_active: bool = True
"""User active status"""
# Generate token
async def login(session: AsyncSession, user: User) -> str:
token = await user.issue_jwt(session)
return token
# Validate token
async def verify(session: AsyncSession, token: str) -> User:
user = await User.from_jwt(session, token)
return user
```
### 8. Response DTO Mixins
Mixins for standardized InfoResponse DTOs, defining id and timestamp fields.
**Available Mixins:**
- `IntIdInfoMixin` - Integer ID field
- `UUIDIdInfoMixin` - UUID ID field
- `DatetimeInfoMixin` - `created_at` and `updated_at` fields
- `IntIdDatetimeInfoMixin` - Integer ID + timestamps
- `UUIDIdDatetimeInfoMixin` - UUID ID + timestamps
**Design Note:** These fields are non-nullable in DTOs because database records always have these values when returned.
**Usage:**
```python
from sqlmodels.mixin import UUIDIdDatetimeInfoMixin
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
"""Character response DTO with id and timestamps"""
pass # Inherits id, created_at, updated_at from mixin
```
## Complete Joined Table Inheritance Example
Here's a complete example demonstrating polymorphic inheritance:
```python
from abc import ABC, abstractmethod
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import (
UUIDTableBaseMixin,
PolymorphicBaseMixin,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin,
)
# 1. Define Base class (fields only, no table)
class ASRBase(SQLModelBase):
name: str
"""Configuration name"""
base_url: str
"""Service URL"""
# 2. Define abstract parent class (with table)
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
"""Abstract base class for ASR configurations"""
# PolymorphicBaseMixin automatically provides:
# - _polymorphic_name field
# - polymorphic_on='_polymorphic_name'
# - polymorphic_abstract=True (when ABC with abstract methods)
@abstractmethod
async def transcribe(self, pcm_data: bytes) -> str:
"""Transcribe audio to text"""
pass
# 3. Create ID Mixin for second-level subclasses
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
# 4. Create second-level abstract class (if needed)
class FunASR(
ASRSubclassIdMixin,
ASR,
AutoPolymorphicIdentityMixin,
polymorphic_abstract=True
):
"""FunASR abstract base (may have multiple implementations)"""
pass
# polymorphic_identity = 'funasr'
# 5. Create concrete implementation classes
class FunASRLocal(FunASR, table=True):
"""FunASR local deployment"""
# polymorphic_identity = 'funasr.funasrlocal'
async def transcribe(self, pcm_data: bytes) -> str:
# Implementation...
return "transcribed text"
# 6. Get all concrete subclasses (for selectin_polymorphic)
concrete_asrs = ASR.get_concrete_subclasses()
# Returns: [FunASRLocal, ...]
```
## Import Guidelines
**Standard Import:**
```python
from sqlmodels.mixin import (
TableBaseMixin,
UUIDTableBaseMixin,
PolymorphicBaseMixin,
TableViewRequest,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin,
JWTAuthMixin,
UUIDIdDatetimeInfoMixin,
now,
now_date,
)
```
**Backward Compatibility:**
Some exports are also available from `sqlmodels.base` for backward compatibility:
```python
# Legacy import path (still works)
from sqlmodels.base import UUIDTableBase, TableViewRequest
# Recommended new import path
from sqlmodels.mixin import UUIDTableBaseMixin, TableViewRequest
```
## Best Practices
### 1. Mixin Order Matters
**Correct Order:**
```python
# ✅ ID Mixin first, then parent, then AutoPolymorphicIdentityMixin
class SubTool(ToolSubclassIdMixin, Tool, AutoPolymorphicIdentityMixin, table=True):
pass
```
**Wrong Order:**
```python
# ❌ ID Mixin not first - won't override parent's id field
class SubTool(Tool, ToolSubclassIdMixin, AutoPolymorphicIdentityMixin, table=True):
pass
```
### 2. Always Use Return Values from save() and update()
```python
# ✅ Correct - use returned instance
device = await device.save(session)
return device
# ❌ Wrong - device is expired after commit
await device.save(session)
return device # AttributeError when accessing fields
```
### 3. Prefer table_view Over Manual Pagination
```python
# ✅ Recommended - consistent across all endpoints
characters = await Character.get(
session,
fetch_mode="all",
table_view=table_view
)
# ⚠️ Works but not recommended - manual parameter management
characters = await Character.get(
session,
fetch_mode="all",
offset=0,
limit=20,
order_by=[desc(Character.created_at)]
)
```
### 4. Polymorphic Loading for Many Subclasses
```python
# When loading relationships with > 10 polymorphic subclasses, use load_polymorphic='all'
tool_set = await ToolSet.get(
session,
ToolSet.id == tool_set_id,
load=ToolSet.tools,
load_polymorphic='all' # Two-phase query - only loads actual related subclasses
)
# For fewer subclasses, specify the list explicitly
tool_set = await ToolSet.get(
session,
ToolSet.id == tool_set_id,
load=ToolSet.tools,
load_polymorphic=[GetWeatherFunction, CodeInterpreterFunction]
)
```
### 5. Response DTOs Should Inherit Base Classes
```python
# ✅ Correct - inherits from CharacterBase
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
pass
# ❌ Wrong - doesn't inherit from CharacterBase
class CharacterInfoResponse(SQLModelBase, UUIDIdDatetimeInfoMixin):
name: str # Duplicated field definition
description: str | None = None
```
**Reason:** Inheriting from Base classes ensures:
- Type checking via `isinstance(obj, XxxBase)`
- Consistency across related DTOs
- Future field additions automatically propagate
### 6. Use Specific Types, Not Containers
```python
# ✅ Correct - specific DTO for config updates
class GetWeatherFunctionUpdateRequest(GetWeatherFunctionConfigBase):
weather_api_key: str | None = None
default_location: str | None = None
"""Default location"""
# ❌ Wrong - lose type safety
class ToolUpdateRequest(SQLModelBase):
config: dict[str, Any] # No field validation
```
## Type Variables
```python
from sqlmodels.mixin import T, M
T = TypeVar("T", bound="TableBaseMixin") # For CRUD methods
M = TypeVar("M", bound="SQLModel") # For update() method
```
## Utility Functions
```python
from sqlmodels.mixin import now, now_date
# Lambda functions for default factories
now = lambda: datetime.now()
now_date = lambda: datetime.now().date()
```
## Related Modules
- **sqlmodels.base** - Base classes (`SQLModelBase`, backward-compatible exports)
- **dependencies** - FastAPI dependencies (`SessionDep`, `TableViewRequestDep`)
- **sqlmodels.user** - User model with JWT authentication
- **sqlmodels.client** - Client model with JWT authentication
- **sqlmodels.character.llm.openai_compatibles.tools** - Polymorphic tool hierarchy
## Additional Resources
- `POLYMORPHIC_NAME_DESIGN.md` - Design rationale for `_polymorphic_name` field
- `CLAUDE.md` - Project coding standards and design philosophy
- SQLAlchemy Documentation - [Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)

View File

@@ -0,0 +1,62 @@
"""
SQLModel Mixin模块
提供各种Mixin类供SQLModel实体使用。
包含:
- polymorphic: 联表继承工具create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin
- optimistic_lock: 乐观锁OptimisticLockMixin, OptimisticLockError
- table: 表基类TableBaseMixin, UUIDTableBaseMixin
- table: 查询参数类TimeFilterRequest, PaginationRequest, TableViewRequest
- relation_preload: 关系预加载RelationPreloadMixin, requires_relations
- jwt/: JWT认证相关JWTAuthMixin, JWTManager, JWTKey等- 需要时直接从 .jwt 导入
- info_response: InfoResponse DTO的id/时间戳Mixin
导入顺序很重要,避免循环导入:
1. polymorphic只依赖 SQLModelBase
2. optimistic_lock只依赖 SQLAlchemy
3. table依赖 polymorphic 和 optimistic_lock
4. relation_preload只依赖 SQLModelBase
注意jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig
而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
"""
# polymorphic 必须先导入
from .polymorphic import (
AutoPolymorphicIdentityMixin,
PolymorphicBaseMixin,
create_subclass_id_mixin,
register_sti_column_properties_for_all_subclasses,
register_sti_columns_for_all_subclasses,
)
# optimistic_lock 只依赖 SQLAlchemy必须在 table 之前
from .optimistic_lock import (
OptimisticLockError,
OptimisticLockMixin,
)
# table 依赖 polymorphic 和 optimistic_lock
from .table import (
ListResponse,
PaginationRequest,
T,
TableBaseMixin,
TableViewRequest,
TimeFilterRequest,
UUIDTableBaseMixin,
now,
now_date,
)
# relation_preload 只依赖 SQLModelBase
from .relation_preload import (
RelationPreloadMixin,
requires_relations,
)
# jwt 不在此处导入避免循环jwt/manager.py → ServerConfig → mixin → jwt
# 需要时直接从 sqlmodels.mixin.jwt 导入
from .info_response import (
DatetimeInfoMixin,
IntIdDatetimeInfoMixin,
IntIdInfoMixin,
UUIDIdDatetimeInfoMixin,
UUIDIdInfoMixin,
)

View File

@@ -0,0 +1,46 @@
"""
InfoResponse DTO Mixin模块
提供用于InfoResponse类型DTO的Mixin统一定义id/created_at/updated_at字段。
设计说明:
- 这些Mixin用于**响应DTO**,不是数据库表
- 从数据库返回时这些字段永远不为空,所以定义为必填字段
- TableBase中的id=None和default_factory=now是正确的入库前为None数据库生成
- 这些Mixin让DTO明确表示"返回给客户端时这些字段必定有值"
"""
from datetime import datetime
from uuid import UUID
from sqlmodels.base import SQLModelBase
class IntIdInfoMixin(SQLModelBase):
"""整数ID响应mixin - 用于InfoResponse DTO"""
id: int
"""记录ID"""
class UUIDIdInfoMixin(SQLModelBase):
"""UUID ID响应mixin - 用于InfoResponse DTO"""
id: UUID
"""记录ID"""
class DatetimeInfoMixin(SQLModelBase):
"""时间戳响应mixin - 用于InfoResponse DTO"""
created_at: datetime
"""创建时间"""
updated_at: datetime
"""更新时间"""
class IntIdDatetimeInfoMixin(IntIdInfoMixin, DatetimeInfoMixin):
"""整数ID + 时间戳响应mixin"""
pass
class UUIDIdDatetimeInfoMixin(UUIDIdInfoMixin, DatetimeInfoMixin):
"""UUID ID + 时间戳响应mixin"""
pass

View File

@@ -0,0 +1,90 @@
"""
乐观锁 Mixin
提供基于 SQLAlchemy version_id_col 机制的乐观锁支持。
乐观锁适用场景:
- 涉及"状态转换"的表(如:待支付 -> 已支付)
- 涉及"数值变动"的表(如:余额、库存)
不适用场景:
- 日志表、纯插入表、低价值统计表
- 能用 UPDATE table SET col = col + 1 解决的简单计数问题
使用示例:
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
status: OrderStatusEnum
amount: Decimal
# save/update 时自动检查版本号
# 如果版本号不匹配(其他事务已修改),会抛出 OptimisticLockError
try:
order = await order.save(session)
except OptimisticLockError as e:
# 处理冲突:重新查询并重试,或报错给用户
l.warning(f"乐观锁冲突: {e}")
"""
from typing import ClassVar
from sqlalchemy.orm.exc import StaleDataError
class OptimisticLockError(Exception):
"""
乐观锁冲突异常
当 save/update 操作检测到版本号不匹配时抛出。
这意味着在读取和写入之间,其他事务已经修改了该记录。
Attributes:
model_class: 发生冲突的模型类名
record_id: 记录 ID如果可用
expected_version: 期望的版本号(如果可用)
original_error: 原始的 StaleDataError
"""
def __init__(
self,
message: str,
model_class: str | None = None,
record_id: str | None = None,
expected_version: int | None = None,
original_error: StaleDataError | None = None,
):
super().__init__(message)
self.model_class = model_class
self.record_id = record_id
self.expected_version = expected_version
self.original_error = original_error
class OptimisticLockMixin:
"""
乐观锁 Mixin
使用 SQLAlchemy 的 version_id_col 机制实现乐观锁。
每次 UPDATE 时自动检查并增加版本号,如果版本号不匹配(即其他事务已修改),
session.commit() 会抛出 StaleDataError被 save/update 方法捕获并转换为 OptimisticLockError。
原理:
1. 每条记录有一个 version 字段,初始值为 0
2. 每次 UPDATE 时SQLAlchemy 生成的 SQL 类似:
UPDATE table SET ..., version = version + 1 WHERE id = ? AND version = ?
3. 如果 WHERE 条件不匹配version 已被其他事务修改),
UPDATE 影响 0 行SQLAlchemy 抛出 StaleDataError
继承顺序:
OptimisticLockMixin 必须放在 TableBaseMixin/UUIDTableBaseMixin 之前:
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
...
配套重试:
如果加了乐观锁,业务层需要处理 OptimisticLockError
- 报错给用户:"数据已被修改,请刷新后重试"
- 自动重试:重新查询最新数据并再次尝试
"""
_has_optimistic_lock: ClassVar[bool] = True
"""标记此类启用了乐观锁"""
version: int = 0
"""乐观锁版本号,每次更新自动递增"""

View File

@@ -0,0 +1,710 @@
"""
联表继承Joined Table Inheritance的通用工具
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
Usage Example:
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import (
PolymorphicBaseMixin,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin
)
# 1. 定义Base类只有字段无表
class ASRBase(SQLModelBase):
name: str
\"\"\"配置名称\"\"\"
base_url: str
\"\"\"服务地址\"\"\"
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
class ASR(
ASRBase,
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC
):
\"\"\"ASR配置的抽象基类\"\"\"
# PolymorphicBaseMixin 自动提供:
# - _polymorphic_name 字段
# - polymorphic_on='_polymorphic_name'
# - polymorphic_abstract=True当有抽象方法时
# 3. 为第二层子类创建ID Mixin
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
# 4. 创建第二层抽象类(如果需要)
class FunASR(
ASRSubclassIdMixin,
ASR,
AutoPolymorphicIdentityMixin,
polymorphic_abstract=True
):
\"\"\"FunASR的抽象基类可能有多个实现\"\"\"
pass
# 5. 创建具体实现类
class FunASRLocal(FunASR, table=True):
\"\"\"FunASR本地部署版本\"\"\"
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
pass
# 6. 获取所有具体子类(用于 selectin_polymorphic
concrete_asrs = ASR.get_concrete_subclasses()
# 返回 [FunASRLocal, ...]
"""
import uuid
from abc import ABC
from uuid import UUID
from loguru import logger as l
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from sqlalchemy import Column, String, inspect
from sqlalchemy.orm import ColumnProperty, Mapped, mapped_column
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlmodel import Field
from sqlmodel.main import get_column_from_field
from sqlmodels.base.sqlmodel_base import SQLModelBase
# 用于延迟注册 STI 子类列的队列
# 在所有模型加载完成后,调用 register_sti_columns_for_all_subclasses() 处理
_sti_subclasses_to_register: list[type] = []
def register_sti_columns_for_all_subclasses() -> None:
"""
为所有已注册的 STI 子类执行列注册(第一阶段:添加列到表)
此函数应在 configure_mappers() 之前调用。
将 STI 子类的字段添加到父表的 metadata 中。
同时修复被 Column 对象污染的 model_fields。
"""
for cls in _sti_subclasses_to_register:
try:
cls._register_sti_columns()
except Exception as e:
l.warning(f"注册 STI 子类 {cls.__name__} 的列时出错: {e}")
# 修复被 Column 对象污染的 model_fields
# 必须在列注册后立即修复,因为 Column 污染在类定义时就已发生
try:
_fix_polluted_model_fields(cls)
except Exception as e:
l.warning(f"修复 STI 子类 {cls.__name__} 的 model_fields 时出错: {e}")
def register_sti_column_properties_for_all_subclasses() -> None:
"""
为所有已注册的 STI 子类添加列属性到 mapper第二阶段
此函数应在 configure_mappers() 之后调用。
将 STI 子类的字段作为 ColumnProperty 添加到 mapper 中。
"""
for cls in _sti_subclasses_to_register:
try:
cls._register_sti_column_properties()
except Exception as e:
l.warning(f"注册 STI 子类 {cls.__name__} 的列属性时出错: {e}")
# 清空队列
_sti_subclasses_to_register.clear()
def _fix_polluted_model_fields(cls: type) -> None:
"""
修复被 SQLAlchemy InstrumentedAttribute 或 Column 污染的 model_fields
当 SQLModel 类继承有表的父类时SQLAlchemy 会在类上创建 InstrumentedAttribute
或 Column 对象替换原始的字段默认值。这会导致 Pydantic 在构建子类 model_fields
时错误地使用这些 SQLAlchemy 对象作为默认值。
此函数从 MRO 中查找原始的字段定义,并修复被污染的 model_fields。
:param cls: 要修复的类
"""
if not hasattr(cls, 'model_fields'):
return
def find_original_field_info(field_name: str) -> FieldInfo | None:
"""从 MRO 中查找字段的原始定义(未被污染的)"""
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, 'model_fields') and field_name in base.model_fields:
field_info = base.model_fields[field_name]
# 跳过被 InstrumentedAttribute 或 Column 污染的
if not isinstance(field_info.default, (InstrumentedAttribute, Column)):
return field_info
return None
for field_name, current_field in cls.model_fields.items():
# 检查是否被污染default 是 InstrumentedAttribute 或 Column
# Column 污染发生在 STI 继承链中:当 FunctionBase.show_arguments = True
# 被继承到有表的子类时SQLModel 会创建一个 Column 对象替换原始默认值
if not isinstance(current_field.default, (InstrumentedAttribute, Column)):
continue # 未被污染,跳过
# 从父类查找原始定义
original = find_original_field_info(field_name)
if original is None:
continue # 找不到原始定义,跳过
# 根据原始定义的 default/default_factory 来修复
if original.default_factory:
# 有 default_factory如 uuid.uuid4, now
new_field = FieldInfo(
default_factory=original.default_factory,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
elif original.default is not PydanticUndefined:
# 有明确的 default 值(如 None, 0, True且不是 PydanticUndefined
# PydanticUndefined 表示字段没有默认值(必填)
new_field = FieldInfo(
default=original.default,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
else:
continue # 既没有 default_factory 也没有有效的 default跳过
# 复制 SQLModel 特有的属性
if hasattr(current_field, 'foreign_key'):
new_field.foreign_key = current_field.foreign_key
if hasattr(current_field, 'primary_key'):
new_field.primary_key = current_field.primary_key
cls.model_fields[field_name] = new_field
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
"""
动态创建SubclassIdMixin类
在联表继承中,子类需要一个外键指向父表的主键。
此函数生成一个Mixin类提供这个外键字段并自动生成UUID。
Args:
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function'
Returns:
一个Mixin类包含id字段外键 + 主键 + default_factory=uuid.uuid4
Example:
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
... pass
Note:
- 生成的Mixin应该放在继承列表的第一位确保通过MRO覆盖UUIDTableBaseMixin的id
- 生成的类名为 {ParentTableName}SubclassIdMixinPascalCase
- 本项目所有联表继承均使用UUID主键UUIDTableBaseMixin
"""
if not parent_table_name:
raise ValueError("parent_table_name 不能为空")
# 转换为PascalCase作为类名
class_name_parts = parent_table_name.split('_')
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
# 使用闭包捕获parent_table_name
_parent_table_name = parent_table_name
# 创建带有__init_subclass__的mixin类用于在子类定义后修复model_fields
class SubclassIdMixin(SQLModelBase):
# 定义id字段
id: UUID = Field(
default_factory=uuid.uuid4,
foreign_key=f'{_parent_table_name}.id',
primary_key=True,
)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
修复联表继承中子类字段的 default_factory 丢失问题。
SQLAlchemy 的 InstrumentedAttribute 或 Column 会污染从父类继承的字段,
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
"""
super().__pydantic_init_subclass__(**kwargs)
_fix_polluted_model_fields(cls)
# 设置类名和文档
SubclassIdMixin.__name__ = class_name
SubclassIdMixin.__qualname__ = class_name
SubclassIdMixin.__doc__ = f"""
{parent_table_name}子类的ID Mixin
用于{parent_table_name}的子类,提供外键指向父表。
通过MRO确保此id字段覆盖继承的id字段。
"""
return SubclassIdMixin
class AutoPolymorphicIdentityMixin:
"""
自动生成polymorphic_identity的Mixin并支持STI子类列注册
使用此Mixin的类会自动根据类名生成polymorphic_identity。
格式:{parent_polymorphic_identity}.{classname_lowercase}
如果没有父类的polymorphic_identity则直接使用类名小写。
**重要:数据库迁移注意事项**
编写数据迁移脚本时,必须使用完整的 polymorphic_identity 格式,包括父类前缀!
例如,对于以下继承链::
LLM (polymorphic_on='_polymorphic_name')
└── AnthropicCompatibleLLM (polymorphic_identity='anthropiccompatiblellm')
└── TuziAnthropicLLM (polymorphic_identity='anthropiccompatiblellm.tuzianthropicllm')
迁移脚本中设置 _polymorphic_name 时::
# ❌ 错误:缺少父类前缀
UPDATE llm SET _polymorphic_name = 'tuzianthropicllm' WHERE id = :id
# ✅ 正确:包含完整的继承链前缀
UPDATE llm SET _polymorphic_name = 'anthropiccompatiblellm.tuzianthropicllm' WHERE id = :id
**STI单表继承支持**
当子类与父类共用同一张表STI模式此Mixin会自动将子类的新字段
添加到父表的列定义中。这解决了SQLModel在STI模式下子类字段不被
注册到父表的问题。
Example (JTI):
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
... __polymorphic_name: str
...
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
... pass
... # polymorphic_identity 会自动设置为 'function'
...
>>> class CodeInterpreterFunction(Function, table=True):
... pass
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
Example (STI):
>>> class UserFile(UUIDTableBaseMixin, PolymorphicBaseMixin, table=True, polymorphic_abstract=True):
... user_id: UUID
...
>>> class PendingFile(UserFile, AutoPolymorphicIdentityMixin, table=True):
... upload_deadline: datetime | None = None # 自动添加到 userfile 表
... # polymorphic_identity 会自动设置为 'pendingfile'
Note:
- 如果手动在__mapper_args__中指定了polymorphic_identity会被保留
- 此Mixin应该在继承列表中靠后的位置在表基类之前
- STI模式下新字段会在类定义时自动添加到父表的metadata中
"""
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
"""
子类化钩子自动生成polymorphic_identity并处理STI列注册
Args:
polymorphic_identity: 如果手动指定,则使用指定的值
**kwargs: 其他SQLModel参数如table=True, polymorphic_abstract=True
"""
super().__init_subclass__(**kwargs)
# 如果手动指定了polymorphic_identity使用指定的值
if polymorphic_identity is not None:
identity = polymorphic_identity
else:
# 自动生成polymorphic_identity
class_name = cls.__name__.lower()
# 尝试从父类获取polymorphic_identity作为前缀
parent_identity = None
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
parent_identity = base.__mapper_args__.get('polymorphic_identity')
if parent_identity:
break
# 构建identity
if parent_identity:
identity = f'{parent_identity}.{class_name}'
else:
identity = class_name
# 设置到__mapper_args__
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 只在尚未设置polymorphic_identity时设置
if 'polymorphic_identity' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_identity'] = identity
# 注册 STI 子类列的延迟执行
# 由于 __init_subclass__ 在类定义过程中被调用,此时 model_fields 还不完整
# 需要在模块加载完成后调用 register_sti_columns_for_all_subclasses()
_sti_subclasses_to_register.append(cls)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
修复 STI 继承中子类字段被 Column 对象污染的问题。
当 FunctionBase.show_arguments = True 等字段被继承到有表的子类时,
SQLModel 会创建一个 Column 对象替换原始默认值,导致实例化时字段值不正确。
"""
super().__pydantic_init_subclass__(**kwargs)
_fix_polluted_model_fields(cls)
@classmethod
def _register_sti_columns(cls) -> None:
"""
将STI子类的新字段注册到父表的列定义中
检测当前类是否是STI子类与父类共用同一张表
如果是则将子类定义的新字段添加到父表的metadata中。
JTI联表继承类会被自动跳过因为它们有自己独立的表。
"""
# 查找父表(在 MRO 中找到第一个有 __table__ 的父类)
parent_table = None
parent_fields: set[str] = set()
for base in cls.__mro__[1:]:
if hasattr(base, '__table__') and base.__table__ is not None:
parent_table = base.__table__
# 收集父类的所有字段名
if hasattr(base, 'model_fields'):
parent_fields.update(base.model_fields.keys())
break
if parent_table is None:
return # 没有找到父表,可能是根类
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
# JTI 类有自己独立的表,不需要将列注册到父表
if hasattr(cls, '__table__') and cls.__table__ is not None:
if cls.__table__.name != parent_table.name:
return # JTI跳过 STI 列注册
# 获取当前类的新字段(不在父类中的字段)
if not hasattr(cls, 'model_fields'):
return
existing_columns = {col.name for col in parent_table.columns}
for field_name, field_info in cls.model_fields.items():
# 跳过从父类继承的字段
if field_name in parent_fields:
continue
# 跳过私有字段和ClassVar
if field_name.startswith('_'):
continue
# 跳过已存在的列
if field_name in existing_columns:
continue
# 使用 SQLModel 的内置 API 创建列
try:
column = get_column_from_field(field_info)
column.name = field_name
column.key = field_name
# STI子类字段在数据库层面必须可空因为其他子类的行不会有这些字段的值
# Pydantic层面的约束仍然有效创建特定子类时会验证必填字段
column.nullable = True
# 将列添加到父表
parent_table.append_column(column)
except Exception as e:
l.warning(f"{cls.__name__} 创建列 {field_name} 失败: {e}")
@classmethod
def _register_sti_column_properties(cls) -> None:
"""
将 STI 子类的列作为 ColumnProperty 添加到 mapper
此方法在 configure_mappers() 之后调用,将已添加到表中的列
注册为 mapper 的属性,使 ORM 查询能正确识别这些列。
**重要**:子类的列属性会同时注册到子类和父类的 mapper 上。
这确保了查询父类时SELECT 语句包含所有 STI 子类的列,
避免在响应序列化时触发懒加载MissingGreenlet 错误)。
JTI联表继承类会被自动跳过因为它们有自己独立的表。
"""
# 查找父表和父类(在 MRO 中找到第一个有 __table__ 的父类)
parent_table = None
parent_class = None
for base in cls.__mro__[1:]:
if hasattr(base, '__table__') and base.__table__ is not None:
parent_table = base.__table__
parent_class = base
break
if parent_table is None:
return # 没有找到父表,可能是根类
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
# JTI 类有自己独立的表,不需要将列属性注册到 mapper
if hasattr(cls, '__table__') and cls.__table__ is not None:
if cls.__table__.name != parent_table.name:
return # JTI跳过 STI 列属性注册
# 获取子类和父类的 mapper
child_mapper = inspect(cls).mapper
parent_mapper = inspect(parent_class).mapper
local_table = child_mapper.local_table
# 查找父类的所有字段名
parent_fields: set[str] = set()
if hasattr(parent_class, 'model_fields'):
parent_fields.update(parent_class.model_fields.keys())
if not hasattr(cls, 'model_fields'):
return
# 获取两个 mapper 已有的列属性
child_existing_props = {p.key for p in child_mapper.column_attrs}
parent_existing_props = {p.key for p in parent_mapper.column_attrs}
for field_name in cls.model_fields:
# 跳过从父类继承的字段
if field_name in parent_fields:
continue
# 跳过私有字段
if field_name.startswith('_'):
continue
# 检查表中是否有这个列
if field_name not in local_table.columns:
continue
column = local_table.columns[field_name]
# 添加到子类的 mapper如果尚不存在
if field_name not in child_existing_props:
try:
prop = ColumnProperty(column)
child_mapper.add_property(field_name, prop)
except Exception as e:
l.warning(f"{cls.__name__} 添加列属性 {field_name} 失败: {e}")
# 同时添加到父类的 mapper确保查询父类时 SELECT 包含所有 STI 子类的列)
if field_name not in parent_existing_props:
try:
prop = ColumnProperty(column)
parent_mapper.add_property(field_name, prop)
except Exception as e:
l.warning(f"为父类 {parent_class.__name__} 添加子类 {cls.__name__} 的列属性 {field_name} 失败: {e}")
class PolymorphicBaseMixin:
"""
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
此 Mixin 自动设置以下内容:
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
使用场景:
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
用法示例:
```python
from abc import ABC
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
# 定义基类
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
__tablename__ = 'mytool'
# 不需要手动定义 _polymorphic_name
# 不需要手动设置 polymorphic_on
# 不需要手动设置 polymorphic_abstract
# 定义子类
class SpecificTool(MyTool):
__tablename__ = 'specifictool'
# 会自动继承 polymorphic 配置
```
自动行为:
1. 定义 `_polymorphic_name: str` 字段(带索引)
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
3. 自动检测抽象类:
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
- 否则设置为 False
手动覆盖:
可以在类定义时手动指定参数来覆盖自动行为:
```python
class MyTool(
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC,
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
polymorphic_abstract=False # 强制不设为抽象类
):
pass
```
注意事项:
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
- 适用于联表继承joined table inheritance场景
- 子类会自动继承 _polymorphic_name 字段定义
- 使用单下划线前缀是因为:
* SQLAlchemy 会映射单下划线字段为数据库列
* Pydantic 将其视为私有属性,不参与序列化
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
"""
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
#
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
#
# 为什么这样做:
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
# 4. 字段不出现在 Pydantic 序列化中model_dump() 和 JSON schema
# 5. 内部代码仍然可以正常访问和修改此字段
#
# 详细说明请参考sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
"""
多态鉴别器字段,用于标识具体的子类类型
注意:此字段使用单下划线前缀,表示内部使用。
- ✅ 存储到数据库
- ✅ 不出现在 API 序列化中
- ✅ 防止外部直接修改
"""
def __init_subclass__(
cls,
polymorphic_on: str | None = None,
polymorphic_abstract: bool | None = None,
**kwargs
):
"""
在子类定义时自动配置 polymorphic 设置
Args:
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'
设置为其他值可以使用不同的字段作为多态鉴别器。
polymorphic_abstract: 是否为抽象类。
- None: 自动检测(默认)
- True: 强制设为抽象类
- False: 强制设为非抽象类
**kwargs: 传递给父类的其他参数
"""
super().__init_subclass__(**kwargs)
# 初始化 __mapper_args__如果还没有
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 设置 polymorphic_on默认为 _polymorphic_name
if 'polymorphic_on' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
# 自动检测或设置 polymorphic_abstract
if 'polymorphic_abstract' not in cls.__mapper_args__:
if polymorphic_abstract is None:
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
has_abc = ABC in cls.__mro__
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
polymorphic_abstract = has_abc and has_abstract_methods
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
@classmethod
def _is_joined_table_inheritance(cls) -> bool:
"""
检测当前类是否使用联表继承Joined Table Inheritance
通过检查子类是否有独立的表来判断:
- JTI: 子类有独立的 local_table与父类不同
- STI: 子类与父类共用同一个 local_table
:return: True 表示 JTIFalse 表示 STI 或无子类
"""
mapper = inspect(cls)
base_table_name = mapper.local_table.name
# 检查所有直接子类
for subclass in cls.__subclasses__():
sub_mapper = inspect(subclass)
# 如果任何子类有不同的表名,说明是 JTI
if sub_mapper.local_table.name != base_table_name:
return True
return False
@classmethod
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
"""
递归获取当前类的所有具体(非抽象)子类
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
可在任意多态基类上调用,返回该类的所有非抽象子类。
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
"""
result: list[type[PolymorphicBaseMixin]] = []
for subclass in cls.__subclasses__():
# 使用 inspect() 获取 mapper 的公开属性
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
mapper = inspect(subclass)
if not mapper.polymorphic_abstract:
result.append(subclass)
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
if hasattr(subclass, 'get_concrete_subclasses'):
result.extend(subclass.get_concrete_subclasses())
return result
@classmethod
def get_polymorphic_discriminator(cls) -> str:
"""
获取多态鉴别字段名
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
:return: 多态鉴别字段名(如 '_polymorphic_name'
:raises ValueError: 如果类未配置 polymorphic_on
"""
polymorphic_on = inspect(cls).polymorphic_on
if polymorphic_on is None:
raise ValueError(
f"{cls.__name__} 未配置 polymorphic_on"
f"请确保正确继承 PolymorphicBaseMixin"
)
return polymorphic_on.key
@classmethod
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
"""
获取 polymorphic_identity 到具体子类的映射
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
:return: identity 到子类的映射字典
"""
result: dict[str, type[PolymorphicBaseMixin]] = {}
for subclass in cls.get_concrete_subclasses():
identity = inspect(subclass).polymorphic_identity
if identity:
result[identity] = subclass
return result

View File

@@ -0,0 +1,470 @@
"""
关系预加载 Mixin
提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。
设计原则:
- 按需加载:只加载被调用方法需要的关系
- 增量加载:已加载的关系不重复加载
- 查询最优:相同关系只查询一次,不同关系增量查询
- 零侵入:调用方无需任何改动
- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire
使用方式:
from sqlmodels.mixin import RelationPreloadMixin, requires_relations
class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True):
kling_video_generator: KlingO1Generator = Relationship(...)
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
async def cost(self, params, context, session) -> ToolCost:
# 自动加载,可以安全访问
price = self.kling_video_generator.kling_o1.pro_price_per_second
...
# 调用方 - 无需任何改动
await tool.cost(params, context, session) # 自动加载 cost 需要的关系
await tool._call(...) # 关系相同则跳过,否则增量加载
支持 AsyncGenerator
@requires_relations('twitter_api')
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
yield ToolResponse(...) # 装饰器正确处理 async generator
"""
import inspect as python_inspect
from functools import wraps
from typing import Callable, TypeVar, ParamSpec, Any
from loguru import logger as l
from sqlalchemy import inspect as sa_inspect
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import RelationshipInfo
P = ParamSpec('P')
R = TypeVar('R')
def _extract_session(
func: Callable,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> AsyncSession | None:
"""
从方法参数中提取 AsyncSession
按以下顺序查找:
1. kwargs 中名为 'session' 的参数
2. 根据函数签名定位 'session' 参数的位置,从 args 提取
3. kwargs 中类型为 AsyncSession 的参数
"""
# 1. 优先从 kwargs 查找
if 'session' in kwargs:
return kwargs['session']
# 2. 从函数签名定位位置参数
try:
sig = python_inspect.signature(func)
param_names = list(sig.parameters.keys())
if 'session' in param_names:
# 计算位置(减去 self
idx = param_names.index('session') - 1
if 0 <= idx < len(args):
return args[idx]
except (ValueError, TypeError):
pass
# 3. 遍历 kwargs 找 AsyncSession 类型
for value in kwargs.values():
if isinstance(value, AsyncSession):
return value
return None
def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool:
"""
检查对象的关系是否已加载(独立函数版本)
Args:
obj: 要检查的对象
rel_name: 关系属性名
Returns:
True 如果关系已加载False 如果未加载或已过期
"""
try:
state = sa_inspect(obj)
return rel_name not in state.unloaded
except Exception:
return False
def _find_relation_to_class(from_class: type, to_class: type) -> str | None:
"""
在类中查找指向目标类的关系属性名
Args:
from_class: 源类
to_class: 目标类
Returns:
关系属性名,如果找不到则返回 None
Example:
_find_relation_to_class(KlingO1VideoFunction, KlingO1Generator)
# 返回 'kling_video_generator'
"""
for attr_name in dir(from_class):
try:
attr = getattr(from_class, attr_name, None)
if attr is None:
continue
# 检查是否是 SQLAlchemy InstrumentedAttribute关系属性
# parent.class_ 是关系所在的类property.mapper.class_ 是关系指向的目标类
if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'):
target_class = attr.property.mapper.class_
if target_class == to_class:
return attr_name
except AttributeError:
continue
return None
def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
装饰器:声明方法需要的关系,自动按需增量加载
参数格式:
- 字符串:本类属性名,如 'kling_video_generator'
- RelationshipInfo外部类属性如 KlingO1Generator.kling_o1
行为:
- 方法调用时自动检查关系是否已加载
- 未加载的关系会被增量加载(单次查询)
- 已加载的关系直接跳过
支持:
- 普通 async 方法:`async def cost(...) -> ToolCost`
- AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]`
Example:
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
async def cost(self, params, context, session) -> ToolCost:
# self.kling_video_generator.kling_o1 已自动加载
...
@requires_relations('twitter_api')
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
yield ToolResponse(...) # AsyncGenerator 正确处理
验证:
- 字符串格式的关系名在类创建时__init_subclass__验证
- 拼写错误会在导入时抛出 AttributeError
"""
def decorator(func: Callable[P, R]) -> Callable[P, R]:
# 检测是否是 async generator 函数
is_async_gen = python_inspect.isasyncgenfunction(func)
if is_async_gen:
# AsyncGenerator 需要特殊处理wrapper 也必须是 async generator
@wraps(func)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
session = _extract_session(func, args, kwargs)
if session is not None:
await self._ensure_relations_loaded(session, relations)
# 委托给原始 async generator逐个 yield 值
async for item in func(self, *args, **kwargs):
yield item # type: ignore
else:
# 普通 async 函数await 并返回结果
@wraps(func)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
session = _extract_session(func, args, kwargs)
if session is not None:
await self._ensure_relations_loaded(session, relations)
return await func(self, *args, **kwargs)
# 保存关系声明供验证和内省使用
wrapper._required_relations = relations # type: ignore
return wrapper
return decorator
class RelationPreloadMixin:
"""
关系预加载 Mixin
提供按需增量加载能力,确保 SQL 查询数理论最优。
特性:
- 按需加载:只加载被调用方法需要的关系
- 增量加载:已加载的关系不重复加载
- 原地更新:直接修改 self无需替换实例
- 导入时验证:字符串关系名在类创建时验证
- Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire
"""
def __init_subclass__(cls, **kwargs) -> None:
"""类创建时验证所有 @requires_relations 声明"""
super().__init_subclass__(**kwargs)
# 收集类及其父类的所有注解(包含普通字段)
all_annotations: set[str] = set()
for klass in cls.__mro__:
if hasattr(klass, '__annotations__'):
all_annotations.update(klass.__annotations__.keys())
# 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__
sqlmodel_relationships: set[str] = set()
for klass in cls.__mro__:
if hasattr(klass, '__sqlmodel_relationships__'):
sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys())
# 合并所有可用的属性名
all_available_names = all_annotations | sqlmodel_relationships
for method_name in dir(cls):
if method_name.startswith('__'):
continue
try:
method = getattr(cls, method_name, None)
except AttributeError:
continue
if method is None or not hasattr(method, '_required_relations'):
continue
# 验证字符串格式的关系名
for spec in method._required_relations:
if isinstance(spec, str):
# 检查注解、Relationship 或已有属性
if spec not in all_available_names and not hasattr(cls, spec):
raise AttributeError(
f"{cls.__name__}.{method_name} 声明了关系 '{spec}'"
f"{cls.__name__} 没有此属性"
)
def _is_relation_loaded(self, rel_name: str) -> bool:
"""
检查关系是否真正已加载(基于 SQLAlchemy inspect
使用 SQLAlchemy 的 inspect 检测真实加载状态,
自动处理 commit 导致的 expire 问题。
Args:
rel_name: 关系属性名
Returns:
True 如果关系已加载False 如果未加载或已过期
"""
try:
state = sa_inspect(self)
# unloaded 包含未加载的关系属性名
return rel_name not in state.unloaded
except Exception:
# 对象可能未被 SQLAlchemy 管理
return False
async def _ensure_relations_loaded(
self,
session: AsyncSession,
relations: tuple[str | RelationshipInfo, ...],
) -> None:
"""
确保指定关系已加载,只加载未加载的部分
基于 SQLAlchemy inspect 检测真实状态,自动处理:
- 首次访问的关系
- commit 后 expire 的关系
- 嵌套关系(如 KlingO1Generator.kling_o1
Args:
session: 数据库会话
relations: 需要的关系规格
"""
# 找出真正未加载的关系(基于 SQLAlchemy inspect
to_load: list[str | RelationshipInfo] = []
# 区分直接关系和嵌套关系的 key
direct_keys: set[str] = set() # 本类的直接关系属性名
nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名
for rel in relations:
if isinstance(rel, str):
# 直接关系:检查本类的关系是否已加载
if not self._is_relation_loaded(rel):
to_load.append(rel)
direct_keys.add(rel)
else:
# 嵌套关系InstrumentedAttribute如 KlingO1Generator.kling_o1
# 1. 查找指向父类的关系属性
parent_class = rel.parent.class_
parent_attr = _find_relation_to_class(self.__class__, parent_class)
if parent_attr is None:
# 找不到路径,可能是配置错误,但仍尝试加载
l.warning(
f"无法找到从 {self.__class__.__name__}{parent_class.__name__} 的关系路径,"
f"无法检查 {rel.key} 是否已加载"
)
to_load.append(rel)
continue
# 2. 检查父对象是否已加载
if not self._is_relation_loaded(parent_attr):
# 父对象未加载,需要同时加载父对象和嵌套关系
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
to_load.append(parent_attr)
nested_parent_keys.add(parent_attr)
to_load.append(rel)
else:
# 3. 父对象已加载,检查嵌套关系是否已加载
parent_obj = getattr(self, parent_attr)
if not _is_obj_relation_loaded(parent_obj, rel.key):
# 嵌套关系未加载:需要同时传递父关系和嵌套关系
# 因为 _build_load_chains 需要完整的链来构建 selectinload
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
to_load.append(parent_attr)
nested_parent_keys.add(parent_attr)
to_load.append(rel)
if not to_load:
return # 全部已加载,跳过
# 构建 load 参数
load_options = self._specs_to_load_options(to_load)
if not load_options:
return
# 安全地获取主键值(避免触发懒加载)
state = sa_inspect(self)
pk_tuple = state.key[1] if state.key else None
if pk_tuple is None:
l.warning(f"无法获取 {self.__class__.__name__} 的主键值")
return
# 主键是元组,取第一个值(假设单列主键)
pk_value = pk_tuple[0]
# 单次查询加载缺失的关系
fresh = await self.__class__.get(
session,
self.__class__.id == pk_value,
load=load_options,
)
if fresh is None:
l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在")
return
# 原地复制到 self只复制直接关系嵌套关系通过父关系自动可访问
all_direct_keys = direct_keys | nested_parent_keys
for key in all_direct_keys:
value = getattr(fresh, key, None)
object.__setattr__(self, key, value)
def _specs_to_load_options(
self,
specs: list[str | RelationshipInfo],
) -> list[RelationshipInfo]:
"""
将关系规格转换为 load 参数
- 字符串 → cls.{name}
- RelationshipInfo → 直接使用
"""
result: list[RelationshipInfo] = []
for spec in specs:
if isinstance(spec, str):
rel = getattr(self.__class__, spec, None)
if rel is not None:
result.append(rel)
else:
l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在")
else:
result.append(spec)
return result
# ==================== 可选的手动预加载 API ====================
@classmethod
def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]:
"""
获取指定方法声明的关系(用于外部预加载场景)
Args:
method_name: 方法名
Returns:
RelationshipInfo 列表
"""
method = getattr(cls, method_name, None)
if method is None or not hasattr(method, '_required_relations'):
return []
result: list[RelationshipInfo] = []
for spec in method._required_relations:
if isinstance(spec, str):
rel = getattr(cls, spec, None)
if rel:
result.append(rel)
else:
result.append(spec)
return result
@classmethod
def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]:
"""
获取多个方法的关系并去重(用于批量预加载场景)
Args:
method_names: 方法名列表
Returns:
去重后的 RelationshipInfo 列表
"""
seen: set[str] = set()
result: list[RelationshipInfo] = []
for method_name in method_names:
for rel in cls.get_relations_for_method(method_name):
key = rel.key
if key not in seen:
seen.add(key)
result.append(rel)
return result
async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin':
"""
手动预加载指定方法的关系(可选优化 API
当需要确保在调用方法前完成所有加载时使用。
通常情况下不需要调用此方法,装饰器会自动处理。
Args:
session: 数据库会话
method_names: 方法名列表
Returns:
self支持链式调用
Example:
# 可选:显式预加载(通常不需要)
tool = await tool.preload_for(session, 'cost', '_call')
"""
all_relations: list[str | RelationshipInfo] = []
for method_name in method_names:
method = getattr(self.__class__, method_name, None)
if method and hasattr(method, '_required_relations'):
all_relations.extend(method._required_relations)
if all_relations:
await self._ensure_relations_loaded(session, tuple(all_relations))
return self

1247
sqlmodels/mixin/table.py Normal file

File diff suppressed because it is too large Load Diff