feat: add models for physical files, policies, and user management
- Implement PhysicalFile model to manage physical file references and reference counting. - Create Policy model with associated options and group links for storage policies. - Introduce Redeem and Report models for handling redeem codes and reports. - Add Settings model for site configuration and user settings management. - Develop Share model for sharing objects with unique codes and associated metadata. - Implement SourceLink model for managing download links associated with objects. - Create StoragePack model for managing user storage packages. - Add Tag model for user-defined tags with manual and automatic types. - Implement Task model for managing background tasks with status tracking. - Develop User model with comprehensive user management features including authentication. - Introduce UserAuthn model for managing WebAuthn credentials. - Create WebDAV model for managing WebDAV accounts associated with users.
This commit is contained in:
543
sqlmodels/mixin/README.md
Normal file
543
sqlmodels/mixin/README.md
Normal file
@@ -0,0 +1,543 @@
|
||||
# SQLModel Mixin Module
|
||||
|
||||
This module provides composable Mixin classes for SQLModel entities, enabling reusable functionality such as CRUD operations, polymorphic inheritance, JWT authentication, and standardized response DTOs.
|
||||
|
||||
## Module Overview
|
||||
|
||||
The `sqlmodels.mixin` module contains various Mixin classes that follow the "Composition over Inheritance" design philosophy. These mixins provide:
|
||||
|
||||
- **CRUD Operations**: Async database operations (add, save, update, delete, get, count)
|
||||
- **Polymorphic Inheritance**: Tools for joined table inheritance patterns
|
||||
- **JWT Authentication**: Token generation and validation
|
||||
- **Pagination & Sorting**: Standardized table view parameters
|
||||
- **Response DTOs**: Consistent id/timestamp fields for API responses
|
||||
|
||||
## Module Structure
|
||||
|
||||
```
|
||||
sqlmodels/mixin/
|
||||
├── __init__.py # Module exports
|
||||
├── polymorphic.py # PolymorphicBaseMixin, create_subclass_id_mixin, AutoPolymorphicIdentityMixin
|
||||
├── table.py # TableBaseMixin, UUIDTableBaseMixin, TableViewRequest
|
||||
├── info_response.py # Response DTO Mixins (IntIdInfoMixin, UUIDIdInfoMixin, etc.)
|
||||
└── jwt/ # JWT authentication
|
||||
├── __init__.py
|
||||
├── key.py # JWTKey database model
|
||||
├── payload.py # JWTPayloadBase
|
||||
├── manager.py # JWTManager singleton
|
||||
├── auth.py # JWTAuthMixin
|
||||
├── exceptions.py # JWT-related exceptions
|
||||
└── responses.py # TokenResponse DTO
|
||||
```
|
||||
|
||||
## Dependency Hierarchy
|
||||
|
||||
The module has a strict import order to avoid circular dependencies:
|
||||
|
||||
1. **polymorphic.py** - Only depends on `SQLModelBase`
|
||||
2. **table.py** - Depends on `polymorphic.py`
|
||||
3. **jwt/** - May depend on both `polymorphic.py` and `table.py`
|
||||
4. **info_response.py** - Only depends on `SQLModelBase`
|
||||
|
||||
## Core Components
|
||||
|
||||
### 1. TableBaseMixin
|
||||
|
||||
Base mixin for database table models with integer primary keys.
|
||||
|
||||
**Features:**
|
||||
- Provides CRUD methods: `add()`, `save()`, `update()`, `delete()`, `get()`, `count()`, `get_exist_one()`
|
||||
- Automatic timestamp management (`created_at`, `updated_at`)
|
||||
- Async relationship loading support (via `AsyncAttrs`)
|
||||
- Pagination and sorting via `TableViewRequest`
|
||||
- Polymorphic subclass loading support
|
||||
|
||||
**Fields:**
|
||||
- `id: int | None` - Integer primary key (auto-increment)
|
||||
- `created_at: datetime` - Record creation timestamp
|
||||
- `updated_at: datetime` - Record update timestamp (auto-updated)
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import TableBaseMixin
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
class User(SQLModelBase, TableBaseMixin, table=True):
|
||||
name: str
|
||||
email: str
|
||||
"""User email"""
|
||||
|
||||
# CRUD operations
|
||||
async def example(session: AsyncSession):
|
||||
# Add
|
||||
user = User(name="Alice", email="alice@example.com")
|
||||
user = await user.save(session)
|
||||
|
||||
# Get
|
||||
user = await User.get(session, User.id == 1)
|
||||
|
||||
# Update
|
||||
update_data = UserUpdateRequest(name="Alice Smith")
|
||||
user = await user.update(session, update_data)
|
||||
|
||||
# Delete
|
||||
await User.delete(session, user)
|
||||
|
||||
# Count
|
||||
count = await User.count(session, User.is_active == True)
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- `save()` and `update()` return refreshed instances - **always use the return value**:
|
||||
```python
|
||||
# ✅ Correct
|
||||
device = await device.save(session)
|
||||
return device
|
||||
|
||||
# ❌ Wrong - device is expired after commit
|
||||
await device.save(session)
|
||||
return device
|
||||
```
|
||||
|
||||
### 2. UUIDTableBaseMixin
|
||||
|
||||
Extends `TableBaseMixin` with UUID primary keys instead of integers.
|
||||
|
||||
**Differences from TableBaseMixin:**
|
||||
- `id: UUID` - UUID primary key (auto-generated via `uuid.uuid4()`)
|
||||
- `get_exist_one()` accepts `UUID` instead of `int`
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class Character(SQLModelBase, UUIDTableBaseMixin, table=True):
|
||||
name: str
|
||||
description: str | None = None
|
||||
"""Character description"""
|
||||
```
|
||||
|
||||
**Recommendation:** Use `UUIDTableBaseMixin` for most new models, as UUIDs provide better scalability and avoid ID collisions.
|
||||
|
||||
### 3. TableViewRequest
|
||||
|
||||
Standardized pagination and sorting parameters for LIST endpoints.
|
||||
|
||||
**Fields:**
|
||||
- `offset: int | None` - Skip first N records (default: 0)
|
||||
- `limit: int | None` - Return max N records (default: 50, max: 100)
|
||||
- `desc: bool | None` - Sort descending (default: True)
|
||||
- `order: Literal["created_at", "updated_at"] | None` - Sort field (default: "created_at")
|
||||
|
||||
**Usage with TableBaseMixin.get():**
|
||||
```python
|
||||
from dependencies import TableViewRequestDep
|
||||
|
||||
@router.get("/list")
|
||||
async def list_characters(
|
||||
session: SessionDep,
|
||||
table_view: TableViewRequestDep
|
||||
) -> list[Character]:
|
||||
"""List characters with pagination and sorting"""
|
||||
return await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
table_view=table_view # Automatically handles pagination and sorting
|
||||
)
|
||||
```
|
||||
|
||||
**Manual usage:**
|
||||
```python
|
||||
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
|
||||
characters = await Character.get(session, fetch_mode="all", table_view=table_view)
|
||||
```
|
||||
|
||||
**Backward Compatibility:**
|
||||
The traditional `offset`, `limit`, `order_by` parameters still work, but `table_view` is recommended for new code.
|
||||
|
||||
### 4. PolymorphicBaseMixin
|
||||
|
||||
Base mixin for joined table inheritance, automatically configuring polymorphic settings.
|
||||
|
||||
**Automatic Configuration:**
|
||||
- Defines `_polymorphic_name: str` field (indexed)
|
||||
- Sets `polymorphic_on='_polymorphic_name'`
|
||||
- Detects abstract classes (via ABC and abstract methods) and sets `polymorphic_abstract=True`
|
||||
|
||||
**Methods:**
|
||||
- `get_concrete_subclasses()` - Get all non-abstract subclasses (for `selectin_polymorphic`)
|
||||
- `get_polymorphic_discriminator()` - Get the polymorphic discriminator field name
|
||||
- `get_identity_to_class_map()` - Map `polymorphic_identity` to subclass types
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlmodels.mixin import PolymorphicBaseMixin, UUIDTableBaseMixin
|
||||
|
||||
class Tool(PolymorphicBaseMixin, UUIDTableBaseMixin, ABC):
|
||||
"""Abstract base class for all tools"""
|
||||
name: str
|
||||
description: str
|
||||
"""Tool description"""
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, params: dict) -> dict:
|
||||
"""Execute the tool"""
|
||||
pass
|
||||
```
|
||||
|
||||
**Why Single Underscore Prefix?**
|
||||
- SQLAlchemy maps single-underscore fields to database columns
|
||||
- Pydantic treats them as private (excluded from serialization)
|
||||
- Double-underscore fields would be excluded by SQLAlchemy (not mapped to database)
|
||||
|
||||
### 5. create_subclass_id_mixin()
|
||||
|
||||
Factory function to create ID mixins for subclasses in joined table inheritance.
|
||||
|
||||
**Purpose:** In joined table inheritance, subclasses need a foreign key pointing to the parent table's primary key. This function generates a mixin class providing that foreign key field.
|
||||
|
||||
**Signature:**
|
||||
```python
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type[SQLModelBase]:
|
||||
"""
|
||||
Args:
|
||||
parent_table_name: Parent table name (e.g., 'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
A mixin class containing id field (foreign key + primary key)
|
||||
"""
|
||||
```
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import create_subclass_id_mixin
|
||||
|
||||
# Create mixin for ASR subclasses
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
class FunASR(ASRSubclassIdMixin, ASR, AutoPolymorphicIdentityMixin, table=True):
|
||||
"""FunASR implementation"""
|
||||
pass
|
||||
```
|
||||
|
||||
**Important:** The ID mixin **must be first in the inheritance list** to ensure MRO (Method Resolution Order) correctly overrides the parent's `id` field.
|
||||
|
||||
### 6. AutoPolymorphicIdentityMixin
|
||||
|
||||
Automatically generates `polymorphic_identity` based on class name.
|
||||
|
||||
**Naming Convention:**
|
||||
- Format: `{parent_identity}.{classname_lowercase}`
|
||||
- If no parent identity exists, uses `{classname_lowercase}`
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import AutoPolymorphicIdentityMixin
|
||||
|
||||
class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
"""Base class for function-type tools"""
|
||||
pass
|
||||
# polymorphic_identity = 'function'
|
||||
|
||||
class GetWeatherFunction(Function, table=True):
|
||||
"""Weather query function"""
|
||||
pass
|
||||
# polymorphic_identity = 'function.getweatherfunction'
|
||||
```
|
||||
|
||||
**Manual Override:**
|
||||
```python
|
||||
class CustomTool(
|
||||
Tool,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_identity='custom_name', # Override auto-generated name
|
||||
table=True
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
### 7. JWTAuthMixin
|
||||
|
||||
Provides JWT token generation and validation for entity classes (User, Client).
|
||||
|
||||
**Methods:**
|
||||
- `async issue_jwt(session: AsyncSession) -> str` - Generate JWT token for current instance
|
||||
- `@classmethod async from_jwt(session: AsyncSession, token: str) -> Self` - Validate token and retrieve entity
|
||||
|
||||
**Requirements:**
|
||||
Subclasses must define:
|
||||
- `JWTPayload` - Payload model (inherits from `JWTPayloadBase`)
|
||||
- `jwt_key_purpose` - ClassVar specifying the JWT key purpose enum value
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import JWTAuthMixin, UUIDTableBaseMixin
|
||||
|
||||
class User(SQLModelBase, UUIDTableBaseMixin, JWTAuthMixin, table=True):
|
||||
JWTPayload = UserJWTPayload # Define payload model
|
||||
jwt_key_purpose: ClassVar[JWTKeyPurposeEnum] = JWTKeyPurposeEnum.user
|
||||
|
||||
email: str
|
||||
is_admin: bool = False
|
||||
is_active: bool = True
|
||||
"""User active status"""
|
||||
|
||||
# Generate token
|
||||
async def login(session: AsyncSession, user: User) -> str:
|
||||
token = await user.issue_jwt(session)
|
||||
return token
|
||||
|
||||
# Validate token
|
||||
async def verify(session: AsyncSession, token: str) -> User:
|
||||
user = await User.from_jwt(session, token)
|
||||
return user
|
||||
```
|
||||
|
||||
### 8. Response DTO Mixins
|
||||
|
||||
Mixins for standardized InfoResponse DTOs, defining id and timestamp fields.
|
||||
|
||||
**Available Mixins:**
|
||||
- `IntIdInfoMixin` - Integer ID field
|
||||
- `UUIDIdInfoMixin` - UUID ID field
|
||||
- `DatetimeInfoMixin` - `created_at` and `updated_at` fields
|
||||
- `IntIdDatetimeInfoMixin` - Integer ID + timestamps
|
||||
- `UUIDIdDatetimeInfoMixin` - UUID ID + timestamps
|
||||
|
||||
**Design Note:** These fields are non-nullable in DTOs because database records always have these values when returned.
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDIdDatetimeInfoMixin
|
||||
|
||||
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
|
||||
"""Character response DTO with id and timestamps"""
|
||||
pass # Inherits id, created_at, updated_at from mixin
|
||||
```
|
||||
|
||||
## Complete Joined Table Inheritance Example
|
||||
|
||||
Here's a complete example demonstrating polymorphic inheritance:
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import (
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
)
|
||||
|
||||
# 1. Define Base class (fields only, no table)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
"""Configuration name"""
|
||||
|
||||
base_url: str
|
||||
"""Service URL"""
|
||||
|
||||
# 2. Define abstract parent class (with table)
|
||||
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
"""Abstract base class for ASR configurations"""
|
||||
# PolymorphicBaseMixin automatically provides:
|
||||
# - _polymorphic_name field
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True (when ABC with abstract methods)
|
||||
|
||||
@abstractmethod
|
||||
async def transcribe(self, pcm_data: bytes) -> str:
|
||||
"""Transcribe audio to text"""
|
||||
pass
|
||||
|
||||
# 3. Create ID Mixin for second-level subclasses
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. Create second-level abstract class (if needed)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
"""FunASR abstract base (may have multiple implementations)"""
|
||||
pass
|
||||
# polymorphic_identity = 'funasr'
|
||||
|
||||
# 5. Create concrete implementation classes
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
"""FunASR local deployment"""
|
||||
# polymorphic_identity = 'funasr.funasrlocal'
|
||||
|
||||
async def transcribe(self, pcm_data: bytes) -> str:
|
||||
# Implementation...
|
||||
return "transcribed text"
|
||||
|
||||
# 6. Get all concrete subclasses (for selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# Returns: [FunASRLocal, ...]
|
||||
```
|
||||
|
||||
## Import Guidelines
|
||||
|
||||
**Standard Import:**
|
||||
```python
|
||||
from sqlmodels.mixin import (
|
||||
TableBaseMixin,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
TableViewRequest,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
JWTAuthMixin,
|
||||
UUIDIdDatetimeInfoMixin,
|
||||
now,
|
||||
now_date,
|
||||
)
|
||||
```
|
||||
|
||||
**Backward Compatibility:**
|
||||
Some exports are also available from `sqlmodels.base` for backward compatibility:
|
||||
```python
|
||||
# Legacy import path (still works)
|
||||
from sqlmodels.base import UUIDTableBase, TableViewRequest
|
||||
|
||||
# Recommended new import path
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin, TableViewRequest
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Mixin Order Matters
|
||||
|
||||
**Correct Order:**
|
||||
```python
|
||||
# ✅ ID Mixin first, then parent, then AutoPolymorphicIdentityMixin
|
||||
class SubTool(ToolSubclassIdMixin, Tool, AutoPolymorphicIdentityMixin, table=True):
|
||||
pass
|
||||
```
|
||||
|
||||
**Wrong Order:**
|
||||
```python
|
||||
# ❌ ID Mixin not first - won't override parent's id field
|
||||
class SubTool(Tool, ToolSubclassIdMixin, AutoPolymorphicIdentityMixin, table=True):
|
||||
pass
|
||||
```
|
||||
|
||||
### 2. Always Use Return Values from save() and update()
|
||||
|
||||
```python
|
||||
# ✅ Correct - use returned instance
|
||||
device = await device.save(session)
|
||||
return device
|
||||
|
||||
# ❌ Wrong - device is expired after commit
|
||||
await device.save(session)
|
||||
return device # AttributeError when accessing fields
|
||||
```
|
||||
|
||||
### 3. Prefer table_view Over Manual Pagination
|
||||
|
||||
```python
|
||||
# ✅ Recommended - consistent across all endpoints
|
||||
characters = await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
table_view=table_view
|
||||
)
|
||||
|
||||
# ⚠️ Works but not recommended - manual parameter management
|
||||
characters = await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
offset=0,
|
||||
limit=20,
|
||||
order_by=[desc(Character.created_at)]
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Polymorphic Loading for Many Subclasses
|
||||
|
||||
```python
|
||||
# When loading relationships with > 10 polymorphic subclasses, use load_polymorphic='all'
|
||||
tool_set = await ToolSet.get(
|
||||
session,
|
||||
ToolSet.id == tool_set_id,
|
||||
load=ToolSet.tools,
|
||||
load_polymorphic='all' # Two-phase query - only loads actual related subclasses
|
||||
)
|
||||
|
||||
# For fewer subclasses, specify the list explicitly
|
||||
tool_set = await ToolSet.get(
|
||||
session,
|
||||
ToolSet.id == tool_set_id,
|
||||
load=ToolSet.tools,
|
||||
load_polymorphic=[GetWeatherFunction, CodeInterpreterFunction]
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Response DTOs Should Inherit Base Classes
|
||||
|
||||
```python
|
||||
# ✅ Correct - inherits from CharacterBase
|
||||
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
|
||||
pass
|
||||
|
||||
# ❌ Wrong - doesn't inherit from CharacterBase
|
||||
class CharacterInfoResponse(SQLModelBase, UUIDIdDatetimeInfoMixin):
|
||||
name: str # Duplicated field definition
|
||||
description: str | None = None
|
||||
```
|
||||
|
||||
**Reason:** Inheriting from Base classes ensures:
|
||||
- Type checking via `isinstance(obj, XxxBase)`
|
||||
- Consistency across related DTOs
|
||||
- Future field additions automatically propagate
|
||||
|
||||
### 6. Use Specific Types, Not Containers
|
||||
|
||||
```python
|
||||
# ✅ Correct - specific DTO for config updates
|
||||
class GetWeatherFunctionUpdateRequest(GetWeatherFunctionConfigBase):
|
||||
weather_api_key: str | None = None
|
||||
default_location: str | None = None
|
||||
"""Default location"""
|
||||
|
||||
# ❌ Wrong - lose type safety
|
||||
class ToolUpdateRequest(SQLModelBase):
|
||||
config: dict[str, Any] # No field validation
|
||||
```
|
||||
|
||||
## Type Variables
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import T, M
|
||||
|
||||
T = TypeVar("T", bound="TableBaseMixin") # For CRUD methods
|
||||
M = TypeVar("M", bound="SQLModel") # For update() method
|
||||
```
|
||||
|
||||
## Utility Functions
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import now, now_date
|
||||
|
||||
# Lambda functions for default factories
|
||||
now = lambda: datetime.now()
|
||||
now_date = lambda: datetime.now().date()
|
||||
```
|
||||
|
||||
## Related Modules
|
||||
|
||||
- **sqlmodels.base** - Base classes (`SQLModelBase`, backward-compatible exports)
|
||||
- **dependencies** - FastAPI dependencies (`SessionDep`, `TableViewRequestDep`)
|
||||
- **sqlmodels.user** - User model with JWT authentication
|
||||
- **sqlmodels.client** - Client model with JWT authentication
|
||||
- **sqlmodels.character.llm.openai_compatibles.tools** - Polymorphic tool hierarchy
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- `POLYMORPHIC_NAME_DESIGN.md` - Design rationale for `_polymorphic_name` field
|
||||
- `CLAUDE.md` - Project coding standards and design philosophy
|
||||
- SQLAlchemy Documentation - [Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
|
||||
62
sqlmodels/mixin/__init__.py
Normal file
62
sqlmodels/mixin/__init__.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
SQLModel Mixin模块
|
||||
|
||||
提供各种Mixin类供SQLModel实体使用。
|
||||
|
||||
包含:
|
||||
- polymorphic: 联表继承工具(create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin)
|
||||
- optimistic_lock: 乐观锁(OptimisticLockMixin, OptimisticLockError)
|
||||
- table: 表基类(TableBaseMixin, UUIDTableBaseMixin)
|
||||
- table: 查询参数类(TimeFilterRequest, PaginationRequest, TableViewRequest)
|
||||
- relation_preload: 关系预加载(RelationPreloadMixin, requires_relations)
|
||||
- jwt/: JWT认证相关(JWTAuthMixin, JWTManager, JWTKey等)- 需要时直接从 .jwt 导入
|
||||
- info_response: InfoResponse DTO的id/时间戳Mixin
|
||||
|
||||
导入顺序很重要,避免循环导入:
|
||||
1. polymorphic(只依赖 SQLModelBase)
|
||||
2. optimistic_lock(只依赖 SQLAlchemy)
|
||||
3. table(依赖 polymorphic 和 optimistic_lock)
|
||||
4. relation_preload(只依赖 SQLModelBase)
|
||||
|
||||
注意:jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig,
|
||||
而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
|
||||
"""
|
||||
# polymorphic 必须先导入
|
||||
from .polymorphic import (
|
||||
AutoPolymorphicIdentityMixin,
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
register_sti_column_properties_for_all_subclasses,
|
||||
register_sti_columns_for_all_subclasses,
|
||||
)
|
||||
# optimistic_lock 只依赖 SQLAlchemy,必须在 table 之前
|
||||
from .optimistic_lock import (
|
||||
OptimisticLockError,
|
||||
OptimisticLockMixin,
|
||||
)
|
||||
# table 依赖 polymorphic 和 optimistic_lock
|
||||
from .table import (
|
||||
ListResponse,
|
||||
PaginationRequest,
|
||||
T,
|
||||
TableBaseMixin,
|
||||
TableViewRequest,
|
||||
TimeFilterRequest,
|
||||
UUIDTableBaseMixin,
|
||||
now,
|
||||
now_date,
|
||||
)
|
||||
# relation_preload 只依赖 SQLModelBase
|
||||
from .relation_preload import (
|
||||
RelationPreloadMixin,
|
||||
requires_relations,
|
||||
)
|
||||
# jwt 不在此处导入(避免循环:jwt/manager.py → ServerConfig → mixin → jwt)
|
||||
# 需要时直接从 sqlmodels.mixin.jwt 导入
|
||||
from .info_response import (
|
||||
DatetimeInfoMixin,
|
||||
IntIdDatetimeInfoMixin,
|
||||
IntIdInfoMixin,
|
||||
UUIDIdDatetimeInfoMixin,
|
||||
UUIDIdInfoMixin,
|
||||
)
|
||||
46
sqlmodels/mixin/info_response.py
Normal file
46
sqlmodels/mixin/info_response.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
InfoResponse DTO Mixin模块
|
||||
|
||||
提供用于InfoResponse类型DTO的Mixin,统一定义id/created_at/updated_at字段。
|
||||
|
||||
设计说明:
|
||||
- 这些Mixin用于**响应DTO**,不是数据库表
|
||||
- 从数据库返回时这些字段永远不为空,所以定义为必填字段
|
||||
- TableBase中的id=None和default_factory=now是正确的(入库前为None,数据库生成)
|
||||
- 这些Mixin让DTO明确表示"返回给客户端时这些字段必定有值"
|
||||
"""
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
|
||||
class IntIdInfoMixin(SQLModelBase):
|
||||
"""整数ID响应mixin - 用于InfoResponse DTO"""
|
||||
id: int
|
||||
"""记录ID"""
|
||||
|
||||
|
||||
class UUIDIdInfoMixin(SQLModelBase):
|
||||
"""UUID ID响应mixin - 用于InfoResponse DTO"""
|
||||
id: UUID
|
||||
"""记录ID"""
|
||||
|
||||
|
||||
class DatetimeInfoMixin(SQLModelBase):
|
||||
"""时间戳响应mixin - 用于InfoResponse DTO"""
|
||||
created_at: datetime
|
||||
"""创建时间"""
|
||||
|
||||
updated_at: datetime
|
||||
"""更新时间"""
|
||||
|
||||
|
||||
class IntIdDatetimeInfoMixin(IntIdInfoMixin, DatetimeInfoMixin):
|
||||
"""整数ID + 时间戳响应mixin"""
|
||||
pass
|
||||
|
||||
|
||||
class UUIDIdDatetimeInfoMixin(UUIDIdInfoMixin, DatetimeInfoMixin):
|
||||
"""UUID ID + 时间戳响应mixin"""
|
||||
pass
|
||||
90
sqlmodels/mixin/optimistic_lock.py
Normal file
90
sqlmodels/mixin/optimistic_lock.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
乐观锁 Mixin
|
||||
|
||||
提供基于 SQLAlchemy version_id_col 机制的乐观锁支持。
|
||||
|
||||
乐观锁适用场景:
|
||||
- 涉及"状态转换"的表(如:待支付 -> 已支付)
|
||||
- 涉及"数值变动"的表(如:余额、库存)
|
||||
|
||||
不适用场景:
|
||||
- 日志表、纯插入表、低价值统计表
|
||||
- 能用 UPDATE table SET col = col + 1 解决的简单计数问题
|
||||
|
||||
使用示例:
|
||||
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
|
||||
status: OrderStatusEnum
|
||||
amount: Decimal
|
||||
|
||||
# save/update 时自动检查版本号
|
||||
# 如果版本号不匹配(其他事务已修改),会抛出 OptimisticLockError
|
||||
try:
|
||||
order = await order.save(session)
|
||||
except OptimisticLockError as e:
|
||||
# 处理冲突:重新查询并重试,或报错给用户
|
||||
l.warning(f"乐观锁冲突: {e}")
|
||||
"""
|
||||
from typing import ClassVar
|
||||
|
||||
from sqlalchemy.orm.exc import StaleDataError
|
||||
|
||||
|
||||
class OptimisticLockError(Exception):
|
||||
"""
|
||||
乐观锁冲突异常
|
||||
|
||||
当 save/update 操作检测到版本号不匹配时抛出。
|
||||
这意味着在读取和写入之间,其他事务已经修改了该记录。
|
||||
|
||||
Attributes:
|
||||
model_class: 发生冲突的模型类名
|
||||
record_id: 记录 ID(如果可用)
|
||||
expected_version: 期望的版本号(如果可用)
|
||||
original_error: 原始的 StaleDataError
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
model_class: str | None = None,
|
||||
record_id: str | None = None,
|
||||
expected_version: int | None = None,
|
||||
original_error: StaleDataError | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.model_class = model_class
|
||||
self.record_id = record_id
|
||||
self.expected_version = expected_version
|
||||
self.original_error = original_error
|
||||
|
||||
|
||||
class OptimisticLockMixin:
|
||||
"""
|
||||
乐观锁 Mixin
|
||||
|
||||
使用 SQLAlchemy 的 version_id_col 机制实现乐观锁。
|
||||
每次 UPDATE 时自动检查并增加版本号,如果版本号不匹配(即其他事务已修改),
|
||||
session.commit() 会抛出 StaleDataError,被 save/update 方法捕获并转换为 OptimisticLockError。
|
||||
|
||||
原理:
|
||||
1. 每条记录有一个 version 字段,初始值为 0
|
||||
2. 每次 UPDATE 时,SQLAlchemy 生成的 SQL 类似:
|
||||
UPDATE table SET ..., version = version + 1 WHERE id = ? AND version = ?
|
||||
3. 如果 WHERE 条件不匹配(version 已被其他事务修改),
|
||||
UPDATE 影响 0 行,SQLAlchemy 抛出 StaleDataError
|
||||
|
||||
继承顺序:
|
||||
OptimisticLockMixin 必须放在 TableBaseMixin/UUIDTableBaseMixin 之前:
|
||||
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
|
||||
...
|
||||
|
||||
配套重试:
|
||||
如果加了乐观锁,业务层需要处理 OptimisticLockError:
|
||||
- 报错给用户:"数据已被修改,请刷新后重试"
|
||||
- 自动重试:重新查询最新数据并再次尝试
|
||||
"""
|
||||
_has_optimistic_lock: ClassVar[bool] = True
|
||||
"""标记此类启用了乐观锁"""
|
||||
|
||||
version: int = 0
|
||||
"""乐观锁版本号,每次更新自动递增"""
|
||||
710
sqlmodels/mixin/polymorphic.py
Normal file
710
sqlmodels/mixin/polymorphic.py
Normal file
@@ -0,0 +1,710 @@
|
||||
"""
|
||||
联表继承(Joined Table Inheritance)的通用工具
|
||||
|
||||
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
|
||||
|
||||
Usage Example:
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import (
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
\"\"\"配置名称\"\"\"
|
||||
|
||||
base_url: str
|
||||
\"\"\"服务地址\"\"\"
|
||||
|
||||
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC
|
||||
):
|
||||
\"\"\"ASR配置的抽象基类\"\"\"
|
||||
# PolymorphicBaseMixin 自动提供:
|
||||
# - _polymorphic_name 字段
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True(当有抽象方法时)
|
||||
|
||||
# 3. 为第二层子类创建ID Mixin
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. 创建第二层抽象类(如果需要)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
\"\"\"FunASR的抽象基类,可能有多个实现\"\"\"
|
||||
pass
|
||||
|
||||
# 5. 创建具体实现类
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
\"\"\"FunASR本地部署版本\"\"\"
|
||||
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
|
||||
pass
|
||||
|
||||
# 6. 获取所有具体子类(用于 selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# 返回 [FunASRLocal, ...]
|
||||
"""
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from sqlalchemy import Column, String, inspect
|
||||
from sqlalchemy.orm import ColumnProperty, Mapped, mapped_column
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlmodel import Field
|
||||
from sqlmodel.main import get_column_from_field
|
||||
|
||||
from sqlmodels.base.sqlmodel_base import SQLModelBase
|
||||
|
||||
# 用于延迟注册 STI 子类列的队列
|
||||
# 在所有模型加载完成后,调用 register_sti_columns_for_all_subclasses() 处理
|
||||
_sti_subclasses_to_register: list[type] = []
|
||||
|
||||
|
||||
def register_sti_columns_for_all_subclasses() -> None:
|
||||
"""
|
||||
为所有已注册的 STI 子类执行列注册(第一阶段:添加列到表)
|
||||
|
||||
此函数应在 configure_mappers() 之前调用。
|
||||
将 STI 子类的字段添加到父表的 metadata 中。
|
||||
同时修复被 Column 对象污染的 model_fields。
|
||||
"""
|
||||
for cls in _sti_subclasses_to_register:
|
||||
try:
|
||||
cls._register_sti_columns()
|
||||
except Exception as e:
|
||||
l.warning(f"注册 STI 子类 {cls.__name__} 的列时出错: {e}")
|
||||
|
||||
# 修复被 Column 对象污染的 model_fields
|
||||
# 必须在列注册后立即修复,因为 Column 污染在类定义时就已发生
|
||||
try:
|
||||
_fix_polluted_model_fields(cls)
|
||||
except Exception as e:
|
||||
l.warning(f"修复 STI 子类 {cls.__name__} 的 model_fields 时出错: {e}")
|
||||
|
||||
|
||||
def register_sti_column_properties_for_all_subclasses() -> None:
|
||||
"""
|
||||
为所有已注册的 STI 子类添加列属性到 mapper(第二阶段)
|
||||
|
||||
此函数应在 configure_mappers() 之后调用。
|
||||
将 STI 子类的字段作为 ColumnProperty 添加到 mapper 中。
|
||||
"""
|
||||
for cls in _sti_subclasses_to_register:
|
||||
try:
|
||||
cls._register_sti_column_properties()
|
||||
except Exception as e:
|
||||
l.warning(f"注册 STI 子类 {cls.__name__} 的列属性时出错: {e}")
|
||||
|
||||
# 清空队列
|
||||
_sti_subclasses_to_register.clear()
|
||||
|
||||
|
||||
def _fix_polluted_model_fields(cls: type) -> None:
|
||||
"""
|
||||
修复被 SQLAlchemy InstrumentedAttribute 或 Column 污染的 model_fields
|
||||
|
||||
当 SQLModel 类继承有表的父类时,SQLAlchemy 会在类上创建 InstrumentedAttribute
|
||||
或 Column 对象替换原始的字段默认值。这会导致 Pydantic 在构建子类 model_fields
|
||||
时错误地使用这些 SQLAlchemy 对象作为默认值。
|
||||
|
||||
此函数从 MRO 中查找原始的字段定义,并修复被污染的 model_fields。
|
||||
|
||||
:param cls: 要修复的类
|
||||
"""
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
def find_original_field_info(field_name: str) -> FieldInfo | None:
|
||||
"""从 MRO 中查找字段的原始定义(未被污染的)"""
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
||||
field_info = base.model_fields[field_name]
|
||||
# 跳过被 InstrumentedAttribute 或 Column 污染的
|
||||
if not isinstance(field_info.default, (InstrumentedAttribute, Column)):
|
||||
return field_info
|
||||
return None
|
||||
|
||||
for field_name, current_field in cls.model_fields.items():
|
||||
# 检查是否被污染(default 是 InstrumentedAttribute 或 Column)
|
||||
# Column 污染发生在 STI 继承链中:当 FunctionBase.show_arguments = True
|
||||
# 被继承到有表的子类时,SQLModel 会创建一个 Column 对象替换原始默认值
|
||||
if not isinstance(current_field.default, (InstrumentedAttribute, Column)):
|
||||
continue # 未被污染,跳过
|
||||
|
||||
# 从父类查找原始定义
|
||||
original = find_original_field_info(field_name)
|
||||
if original is None:
|
||||
continue # 找不到原始定义,跳过
|
||||
|
||||
# 根据原始定义的 default/default_factory 来修复
|
||||
if original.default_factory:
|
||||
# 有 default_factory(如 uuid.uuid4, now)
|
||||
new_field = FieldInfo(
|
||||
default_factory=original.default_factory,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
elif original.default is not PydanticUndefined:
|
||||
# 有明确的 default 值(如 None, 0, True),且不是 PydanticUndefined
|
||||
# PydanticUndefined 表示字段没有默认值(必填)
|
||||
new_field = FieldInfo(
|
||||
default=original.default,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
else:
|
||||
continue # 既没有 default_factory 也没有有效的 default,跳过
|
||||
|
||||
# 复制 SQLModel 特有的属性
|
||||
if hasattr(current_field, 'foreign_key'):
|
||||
new_field.foreign_key = current_field.foreign_key
|
||||
if hasattr(current_field, 'primary_key'):
|
||||
new_field.primary_key = current_field.primary_key
|
||||
|
||||
cls.model_fields[field_name] = new_field
|
||||
|
||||
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
|
||||
"""
|
||||
动态创建SubclassIdMixin类
|
||||
|
||||
在联表继承中,子类需要一个外键指向父表的主键。
|
||||
此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。
|
||||
|
||||
Args:
|
||||
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4)
|
||||
|
||||
Example:
|
||||
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
|
||||
... pass
|
||||
|
||||
Note:
|
||||
- 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id
|
||||
- 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase)
|
||||
- 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin)
|
||||
"""
|
||||
if not parent_table_name:
|
||||
raise ValueError("parent_table_name 不能为空")
|
||||
|
||||
# 转换为PascalCase作为类名
|
||||
class_name_parts = parent_table_name.split('_')
|
||||
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
|
||||
|
||||
# 使用闭包捕获parent_table_name
|
||||
_parent_table_name = parent_table_name
|
||||
|
||||
# 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields
|
||||
class SubclassIdMixin(SQLModelBase):
|
||||
# 定义id字段
|
||||
id: UUID = Field(
|
||||
default_factory=uuid.uuid4,
|
||||
foreign_key=f'{_parent_table_name}.id',
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复联表继承中子类字段的 default_factory 丢失问题。
|
||||
SQLAlchemy 的 InstrumentedAttribute 或 Column 会污染从父类继承的字段,
|
||||
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
_fix_polluted_model_fields(cls)
|
||||
|
||||
# 设置类名和文档
|
||||
SubclassIdMixin.__name__ = class_name
|
||||
SubclassIdMixin.__qualname__ = class_name
|
||||
SubclassIdMixin.__doc__ = f"""
|
||||
{parent_table_name}子类的ID Mixin
|
||||
|
||||
用于{parent_table_name}的子类,提供外键指向父表。
|
||||
通过MRO确保此id字段覆盖继承的id字段。
|
||||
"""
|
||||
|
||||
return SubclassIdMixin
|
||||
|
||||
|
||||
class AutoPolymorphicIdentityMixin:
|
||||
"""
|
||||
自动生成polymorphic_identity的Mixin,并支持STI子类列注册
|
||||
|
||||
使用此Mixin的类会自动根据类名生成polymorphic_identity。
|
||||
格式:{parent_polymorphic_identity}.{classname_lowercase}
|
||||
|
||||
如果没有父类的polymorphic_identity,则直接使用类名小写。
|
||||
|
||||
**重要:数据库迁移注意事项**
|
||||
|
||||
编写数据迁移脚本时,必须使用完整的 polymorphic_identity 格式,包括父类前缀!
|
||||
|
||||
例如,对于以下继承链::
|
||||
|
||||
LLM (polymorphic_on='_polymorphic_name')
|
||||
└── AnthropicCompatibleLLM (polymorphic_identity='anthropiccompatiblellm')
|
||||
└── TuziAnthropicLLM (polymorphic_identity='anthropiccompatiblellm.tuzianthropicllm')
|
||||
|
||||
迁移脚本中设置 _polymorphic_name 时::
|
||||
|
||||
# ❌ 错误:缺少父类前缀
|
||||
UPDATE llm SET _polymorphic_name = 'tuzianthropicllm' WHERE id = :id
|
||||
|
||||
# ✅ 正确:包含完整的继承链前缀
|
||||
UPDATE llm SET _polymorphic_name = 'anthropiccompatiblellm.tuzianthropicllm' WHERE id = :id
|
||||
|
||||
**STI(单表继承)支持**:
|
||||
当子类与父类共用同一张表(STI模式)时,此Mixin会自动将子类的新字段
|
||||
添加到父表的列定义中。这解决了SQLModel在STI模式下子类字段不被
|
||||
注册到父表的问题。
|
||||
|
||||
Example (JTI):
|
||||
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
|
||||
... __polymorphic_name: str
|
||||
...
|
||||
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function'
|
||||
...
|
||||
>>> class CodeInterpreterFunction(Function, table=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
|
||||
|
||||
Example (STI):
|
||||
>>> class UserFile(UUIDTableBaseMixin, PolymorphicBaseMixin, table=True, polymorphic_abstract=True):
|
||||
... user_id: UUID
|
||||
...
|
||||
>>> class PendingFile(UserFile, AutoPolymorphicIdentityMixin, table=True):
|
||||
... upload_deadline: datetime | None = None # 自动添加到 userfile 表
|
||||
... # polymorphic_identity 会自动设置为 'pendingfile'
|
||||
|
||||
Note:
|
||||
- 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留
|
||||
- 此Mixin应该在继承列表中靠后的位置(在表基类之前)
|
||||
- STI模式下,新字段会在类定义时自动添加到父表的metadata中
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
|
||||
"""
|
||||
子类化钩子,自动生成polymorphic_identity并处理STI列注册
|
||||
|
||||
Args:
|
||||
polymorphic_identity: 如果手动指定,则使用指定的值
|
||||
**kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True)
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 如果手动指定了polymorphic_identity,使用指定的值
|
||||
if polymorphic_identity is not None:
|
||||
identity = polymorphic_identity
|
||||
else:
|
||||
# 自动生成polymorphic_identity
|
||||
class_name = cls.__name__.lower()
|
||||
|
||||
# 尝试从父类获取polymorphic_identity作为前缀
|
||||
parent_identity = None
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
|
||||
parent_identity = base.__mapper_args__.get('polymorphic_identity')
|
||||
if parent_identity:
|
||||
break
|
||||
|
||||
# 构建identity
|
||||
if parent_identity:
|
||||
identity = f'{parent_identity}.{class_name}'
|
||||
else:
|
||||
identity = class_name
|
||||
|
||||
# 设置到__mapper_args__
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 只在尚未设置polymorphic_identity时设置
|
||||
if 'polymorphic_identity' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_identity'] = identity
|
||||
|
||||
# 注册 STI 子类列的延迟执行
|
||||
# 由于 __init_subclass__ 在类定义过程中被调用,此时 model_fields 还不完整
|
||||
# 需要在模块加载完成后调用 register_sti_columns_for_all_subclasses()
|
||||
_sti_subclasses_to_register.append(cls)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复 STI 继承中子类字段被 Column 对象污染的问题。
|
||||
当 FunctionBase.show_arguments = True 等字段被继承到有表的子类时,
|
||||
SQLModel 会创建一个 Column 对象替换原始默认值,导致实例化时字段值不正确。
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
_fix_polluted_model_fields(cls)
|
||||
|
||||
@classmethod
|
||||
def _register_sti_columns(cls) -> None:
|
||||
"""
|
||||
将STI子类的新字段注册到父表的列定义中
|
||||
|
||||
检测当前类是否是STI子类(与父类共用同一张表),
|
||||
如果是,则将子类定义的新字段添加到父表的metadata中。
|
||||
|
||||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
||||
"""
|
||||
# 查找父表(在 MRO 中找到第一个有 __table__ 的父类)
|
||||
parent_table = None
|
||||
parent_fields: set[str] = set()
|
||||
|
||||
for base in cls.__mro__[1:]:
|
||||
if hasattr(base, '__table__') and base.__table__ is not None:
|
||||
parent_table = base.__table__
|
||||
# 收集父类的所有字段名
|
||||
if hasattr(base, 'model_fields'):
|
||||
parent_fields.update(base.model_fields.keys())
|
||||
break
|
||||
|
||||
if parent_table is None:
|
||||
return # 没有找到父表,可能是根类
|
||||
|
||||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
||||
# JTI 类有自己独立的表,不需要将列注册到父表
|
||||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
||||
if cls.__table__.name != parent_table.name:
|
||||
return # JTI,跳过 STI 列注册
|
||||
|
||||
# 获取当前类的新字段(不在父类中的字段)
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
existing_columns = {col.name for col in parent_table.columns}
|
||||
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
# 跳过从父类继承的字段
|
||||
if field_name in parent_fields:
|
||||
continue
|
||||
|
||||
# 跳过私有字段和ClassVar
|
||||
if field_name.startswith('_'):
|
||||
continue
|
||||
|
||||
# 跳过已存在的列
|
||||
if field_name in existing_columns:
|
||||
continue
|
||||
|
||||
# 使用 SQLModel 的内置 API 创建列
|
||||
try:
|
||||
column = get_column_from_field(field_info)
|
||||
column.name = field_name
|
||||
column.key = field_name
|
||||
# STI子类字段在数据库层面必须可空,因为其他子类的行不会有这些字段的值
|
||||
# Pydantic层面的约束仍然有效(创建特定子类时会验证必填字段)
|
||||
column.nullable = True
|
||||
|
||||
# 将列添加到父表
|
||||
parent_table.append_column(column)
|
||||
except Exception as e:
|
||||
l.warning(f"为 {cls.__name__} 创建列 {field_name} 失败: {e}")
|
||||
|
||||
@classmethod
|
||||
def _register_sti_column_properties(cls) -> None:
|
||||
"""
|
||||
将 STI 子类的列作为 ColumnProperty 添加到 mapper
|
||||
|
||||
此方法在 configure_mappers() 之后调用,将已添加到表中的列
|
||||
注册为 mapper 的属性,使 ORM 查询能正确识别这些列。
|
||||
|
||||
**重要**:子类的列属性会同时注册到子类和父类的 mapper 上。
|
||||
这确保了查询父类时,SELECT 语句包含所有 STI 子类的列,
|
||||
避免在响应序列化时触发懒加载(MissingGreenlet 错误)。
|
||||
|
||||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
||||
"""
|
||||
# 查找父表和父类(在 MRO 中找到第一个有 __table__ 的父类)
|
||||
parent_table = None
|
||||
parent_class = None
|
||||
for base in cls.__mro__[1:]:
|
||||
if hasattr(base, '__table__') and base.__table__ is not None:
|
||||
parent_table = base.__table__
|
||||
parent_class = base
|
||||
break
|
||||
|
||||
if parent_table is None:
|
||||
return # 没有找到父表,可能是根类
|
||||
|
||||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
||||
# JTI 类有自己独立的表,不需要将列属性注册到 mapper
|
||||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
||||
if cls.__table__.name != parent_table.name:
|
||||
return # JTI,跳过 STI 列属性注册
|
||||
|
||||
# 获取子类和父类的 mapper
|
||||
child_mapper = inspect(cls).mapper
|
||||
parent_mapper = inspect(parent_class).mapper
|
||||
local_table = child_mapper.local_table
|
||||
|
||||
# 查找父类的所有字段名
|
||||
parent_fields: set[str] = set()
|
||||
if hasattr(parent_class, 'model_fields'):
|
||||
parent_fields.update(parent_class.model_fields.keys())
|
||||
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
# 获取两个 mapper 已有的列属性
|
||||
child_existing_props = {p.key for p in child_mapper.column_attrs}
|
||||
parent_existing_props = {p.key for p in parent_mapper.column_attrs}
|
||||
|
||||
for field_name in cls.model_fields:
|
||||
# 跳过从父类继承的字段
|
||||
if field_name in parent_fields:
|
||||
continue
|
||||
|
||||
# 跳过私有字段
|
||||
if field_name.startswith('_'):
|
||||
continue
|
||||
|
||||
# 检查表中是否有这个列
|
||||
if field_name not in local_table.columns:
|
||||
continue
|
||||
|
||||
column = local_table.columns[field_name]
|
||||
|
||||
# 添加到子类的 mapper(如果尚不存在)
|
||||
if field_name not in child_existing_props:
|
||||
try:
|
||||
prop = ColumnProperty(column)
|
||||
child_mapper.add_property(field_name, prop)
|
||||
except Exception as e:
|
||||
l.warning(f"为 {cls.__name__} 添加列属性 {field_name} 失败: {e}")
|
||||
|
||||
# 同时添加到父类的 mapper(确保查询父类时 SELECT 包含所有 STI 子类的列)
|
||||
if field_name not in parent_existing_props:
|
||||
try:
|
||||
prop = ColumnProperty(column)
|
||||
parent_mapper.add_property(field_name, prop)
|
||||
except Exception as e:
|
||||
l.warning(f"为父类 {parent_class.__name__} 添加子类 {cls.__name__} 的列属性 {field_name} 失败: {e}")
|
||||
|
||||
|
||||
class PolymorphicBaseMixin:
|
||||
"""
|
||||
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
|
||||
|
||||
此 Mixin 自动设置以下内容:
|
||||
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
|
||||
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
|
||||
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
|
||||
|
||||
使用场景:
|
||||
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
|
||||
|
||||
用法示例:
|
||||
```python
|
||||
from abc import ABC
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
|
||||
|
||||
# 定义基类
|
||||
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
__tablename__ = 'mytool'
|
||||
|
||||
# 不需要手动定义 _polymorphic_name
|
||||
# 不需要手动设置 polymorphic_on
|
||||
# 不需要手动设置 polymorphic_abstract
|
||||
|
||||
# 定义子类
|
||||
class SpecificTool(MyTool):
|
||||
__tablename__ = 'specifictool'
|
||||
|
||||
# 会自动继承 polymorphic 配置
|
||||
```
|
||||
|
||||
自动行为:
|
||||
1. 定义 `_polymorphic_name: str` 字段(带索引)
|
||||
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
|
||||
3. 自动检测抽象类:
|
||||
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
|
||||
- 否则设置为 False
|
||||
|
||||
手动覆盖:
|
||||
可以在类定义时手动指定参数来覆盖自动行为:
|
||||
```python
|
||||
class MyTool(
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
|
||||
polymorphic_abstract=False # 强制不设为抽象类
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
注意事项:
|
||||
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
|
||||
- 适用于联表继承(joined table inheritance)场景
|
||||
- 子类会自动继承 _polymorphic_name 字段定义
|
||||
- 使用单下划线前缀是因为:
|
||||
* SQLAlchemy 会映射单下划线字段为数据库列
|
||||
* Pydantic 将其视为私有属性,不参与序列化
|
||||
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
|
||||
"""
|
||||
|
||||
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
|
||||
#
|
||||
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
|
||||
#
|
||||
# 为什么这样做:
|
||||
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
|
||||
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
|
||||
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
|
||||
# 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema)
|
||||
# 5. 内部代码仍然可以正常访问和修改此字段
|
||||
#
|
||||
# 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
|
||||
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
|
||||
"""
|
||||
多态鉴别器字段,用于标识具体的子类类型
|
||||
|
||||
注意:此字段使用单下划线前缀,表示内部使用。
|
||||
- ✅ 存储到数据库
|
||||
- ✅ 不出现在 API 序列化中
|
||||
- ✅ 防止外部直接修改
|
||||
"""
|
||||
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
polymorphic_on: str | None = None,
|
||||
polymorphic_abstract: bool | None = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
在子类定义时自动配置 polymorphic 设置
|
||||
|
||||
Args:
|
||||
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。
|
||||
设置为其他值可以使用不同的字段作为多态鉴别器。
|
||||
polymorphic_abstract: 是否为抽象类。
|
||||
- None: 自动检测(默认)
|
||||
- True: 强制设为抽象类
|
||||
- False: 强制设为非抽象类
|
||||
**kwargs: 传递给父类的其他参数
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 初始化 __mapper_args__(如果还没有)
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 设置 polymorphic_on(默认为 _polymorphic_name)
|
||||
if 'polymorphic_on' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
|
||||
|
||||
# 自动检测或设置 polymorphic_abstract
|
||||
if 'polymorphic_abstract' not in cls.__mapper_args__:
|
||||
if polymorphic_abstract is None:
|
||||
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
|
||||
has_abc = ABC in cls.__mro__
|
||||
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
|
||||
polymorphic_abstract = has_abc and has_abstract_methods
|
||||
|
||||
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
|
||||
|
||||
@classmethod
|
||||
def _is_joined_table_inheritance(cls) -> bool:
|
||||
"""
|
||||
检测当前类是否使用联表继承(Joined Table Inheritance)
|
||||
|
||||
通过检查子类是否有独立的表来判断:
|
||||
- JTI: 子类有独立的 local_table(与父类不同)
|
||||
- STI: 子类与父类共用同一个 local_table
|
||||
|
||||
:return: True 表示 JTI,False 表示 STI 或无子类
|
||||
"""
|
||||
mapper = inspect(cls)
|
||||
base_table_name = mapper.local_table.name
|
||||
|
||||
# 检查所有直接子类
|
||||
for subclass in cls.__subclasses__():
|
||||
sub_mapper = inspect(subclass)
|
||||
# 如果任何子类有不同的表名,说明是 JTI
|
||||
if sub_mapper.local_table.name != base_table_name:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
递归获取当前类的所有具体(非抽象)子类
|
||||
|
||||
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
|
||||
可在任意多态基类上调用,返回该类的所有非抽象子类。
|
||||
|
||||
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
|
||||
"""
|
||||
result: list[type[PolymorphicBaseMixin]] = []
|
||||
for subclass in cls.__subclasses__():
|
||||
# 使用 inspect() 获取 mapper 的公开属性
|
||||
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
|
||||
mapper = inspect(subclass)
|
||||
if not mapper.polymorphic_abstract:
|
||||
result.append(subclass)
|
||||
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
|
||||
if hasattr(subclass, 'get_concrete_subclasses'):
|
||||
result.extend(subclass.get_concrete_subclasses())
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_polymorphic_discriminator(cls) -> str:
|
||||
"""
|
||||
获取多态鉴别字段名
|
||||
|
||||
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
|
||||
|
||||
:return: 多态鉴别字段名(如 '_polymorphic_name')
|
||||
:raises ValueError: 如果类未配置 polymorphic_on
|
||||
"""
|
||||
polymorphic_on = inspect(cls).polymorphic_on
|
||||
if polymorphic_on is None:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} 未配置 polymorphic_on,"
|
||||
f"请确保正确继承 PolymorphicBaseMixin"
|
||||
)
|
||||
return polymorphic_on.key
|
||||
|
||||
@classmethod
|
||||
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
获取 polymorphic_identity 到具体子类的映射
|
||||
|
||||
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
|
||||
|
||||
:return: identity 到子类的映射字典
|
||||
"""
|
||||
result: dict[str, type[PolymorphicBaseMixin]] = {}
|
||||
for subclass in cls.get_concrete_subclasses():
|
||||
identity = inspect(subclass).polymorphic_identity
|
||||
if identity:
|
||||
result[identity] = subclass
|
||||
return result
|
||||
470
sqlmodels/mixin/relation_preload.py
Normal file
470
sqlmodels/mixin/relation_preload.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
关系预加载 Mixin
|
||||
|
||||
提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。
|
||||
|
||||
设计原则:
|
||||
- 按需加载:只加载被调用方法需要的关系
|
||||
- 增量加载:已加载的关系不重复加载
|
||||
- 查询最优:相同关系只查询一次,不同关系增量查询
|
||||
- 零侵入:调用方无需任何改动
|
||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire
|
||||
|
||||
使用方式:
|
||||
from sqlmodels.mixin import RelationPreloadMixin, requires_relations
|
||||
|
||||
class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True):
|
||||
kling_video_generator: KlingO1Generator = Relationship(...)
|
||||
|
||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||||
async def cost(self, params, context, session) -> ToolCost:
|
||||
# 自动加载,可以安全访问
|
||||
price = self.kling_video_generator.kling_o1.pro_price_per_second
|
||||
...
|
||||
|
||||
# 调用方 - 无需任何改动
|
||||
await tool.cost(params, context, session) # 自动加载 cost 需要的关系
|
||||
await tool._call(...) # 关系相同则跳过,否则增量加载
|
||||
|
||||
支持 AsyncGenerator:
|
||||
@requires_relations('twitter_api')
|
||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||||
yield ToolResponse(...) # 装饰器正确处理 async generator
|
||||
"""
|
||||
import inspect as python_inspect
|
||||
from functools import wraps
|
||||
from typing import Callable, TypeVar, ParamSpec, Any
|
||||
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
P = ParamSpec('P')
|
||||
R = TypeVar('R')
|
||||
|
||||
|
||||
def _extract_session(
|
||||
func: Callable,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> AsyncSession | None:
|
||||
"""
|
||||
从方法参数中提取 AsyncSession
|
||||
|
||||
按以下顺序查找:
|
||||
1. kwargs 中名为 'session' 的参数
|
||||
2. 根据函数签名定位 'session' 参数的位置,从 args 提取
|
||||
3. kwargs 中类型为 AsyncSession 的参数
|
||||
"""
|
||||
# 1. 优先从 kwargs 查找
|
||||
if 'session' in kwargs:
|
||||
return kwargs['session']
|
||||
|
||||
# 2. 从函数签名定位位置参数
|
||||
try:
|
||||
sig = python_inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
|
||||
if 'session' in param_names:
|
||||
# 计算位置(减去 self)
|
||||
idx = param_names.index('session') - 1
|
||||
if 0 <= idx < len(args):
|
||||
return args[idx]
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# 3. 遍历 kwargs 找 AsyncSession 类型
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, AsyncSession):
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool:
|
||||
"""
|
||||
检查对象的关系是否已加载(独立函数版本)
|
||||
|
||||
Args:
|
||||
obj: 要检查的对象
|
||||
rel_name: 关系属性名
|
||||
|
||||
Returns:
|
||||
True 如果关系已加载,False 如果未加载或已过期
|
||||
"""
|
||||
try:
|
||||
state = sa_inspect(obj)
|
||||
return rel_name not in state.unloaded
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _find_relation_to_class(from_class: type, to_class: type) -> str | None:
|
||||
"""
|
||||
在类中查找指向目标类的关系属性名
|
||||
|
||||
Args:
|
||||
from_class: 源类
|
||||
to_class: 目标类
|
||||
|
||||
Returns:
|
||||
关系属性名,如果找不到则返回 None
|
||||
|
||||
Example:
|
||||
_find_relation_to_class(KlingO1VideoFunction, KlingO1Generator)
|
||||
# 返回 'kling_video_generator'
|
||||
"""
|
||||
for attr_name in dir(from_class):
|
||||
try:
|
||||
attr = getattr(from_class, attr_name, None)
|
||||
if attr is None:
|
||||
continue
|
||||
# 检查是否是 SQLAlchemy InstrumentedAttribute(关系属性)
|
||||
# parent.class_ 是关系所在的类,property.mapper.class_ 是关系指向的目标类
|
||||
if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'):
|
||||
target_class = attr.property.mapper.class_
|
||||
if target_class == to_class:
|
||||
return attr_name
|
||||
except AttributeError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""
|
||||
装饰器:声明方法需要的关系,自动按需增量加载
|
||||
|
||||
参数格式:
|
||||
- 字符串:本类属性名,如 'kling_video_generator'
|
||||
- RelationshipInfo:外部类属性,如 KlingO1Generator.kling_o1
|
||||
|
||||
行为:
|
||||
- 方法调用时自动检查关系是否已加载
|
||||
- 未加载的关系会被增量加载(单次查询)
|
||||
- 已加载的关系直接跳过
|
||||
|
||||
支持:
|
||||
- 普通 async 方法:`async def cost(...) -> ToolCost`
|
||||
- AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]`
|
||||
|
||||
Example:
|
||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||||
async def cost(self, params, context, session) -> ToolCost:
|
||||
# self.kling_video_generator.kling_o1 已自动加载
|
||||
...
|
||||
|
||||
@requires_relations('twitter_api')
|
||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||||
yield ToolResponse(...) # AsyncGenerator 正确处理
|
||||
|
||||
验证:
|
||||
- 字符串格式的关系名在类创建时(__init_subclass__)验证
|
||||
- 拼写错误会在导入时抛出 AttributeError
|
||||
"""
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
# 检测是否是 async generator 函数
|
||||
is_async_gen = python_inspect.isasyncgenfunction(func)
|
||||
|
||||
if is_async_gen:
|
||||
# AsyncGenerator 需要特殊处理:wrapper 也必须是 async generator
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
session = _extract_session(func, args, kwargs)
|
||||
if session is not None:
|
||||
await self._ensure_relations_loaded(session, relations)
|
||||
# 委托给原始 async generator,逐个 yield 值
|
||||
async for item in func(self, *args, **kwargs):
|
||||
yield item # type: ignore
|
||||
else:
|
||||
# 普通 async 函数:await 并返回结果
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
session = _extract_session(func, args, kwargs)
|
||||
if session is not None:
|
||||
await self._ensure_relations_loaded(session, relations)
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
# 保存关系声明供验证和内省使用
|
||||
wrapper._required_relations = relations # type: ignore
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RelationPreloadMixin:
|
||||
"""
|
||||
关系预加载 Mixin
|
||||
|
||||
提供按需增量加载能力,确保 SQL 查询数理论最优。
|
||||
|
||||
特性:
|
||||
- 按需加载:只加载被调用方法需要的关系
|
||||
- 增量加载:已加载的关系不重复加载
|
||||
- 原地更新:直接修改 self,无需替换实例
|
||||
- 导入时验证:字符串关系名在类创建时验证
|
||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
"""类创建时验证所有 @requires_relations 声明"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 收集类及其父类的所有注解(包含普通字段)
|
||||
all_annotations: set[str] = set()
|
||||
for klass in cls.__mro__:
|
||||
if hasattr(klass, '__annotations__'):
|
||||
all_annotations.update(klass.__annotations__.keys())
|
||||
|
||||
# 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__)
|
||||
sqlmodel_relationships: set[str] = set()
|
||||
for klass in cls.__mro__:
|
||||
if hasattr(klass, '__sqlmodel_relationships__'):
|
||||
sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys())
|
||||
|
||||
# 合并所有可用的属性名
|
||||
all_available_names = all_annotations | sqlmodel_relationships
|
||||
|
||||
for method_name in dir(cls):
|
||||
if method_name.startswith('__'):
|
||||
continue
|
||||
|
||||
try:
|
||||
method = getattr(cls, method_name, None)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if method is None or not hasattr(method, '_required_relations'):
|
||||
continue
|
||||
|
||||
# 验证字符串格式的关系名
|
||||
for spec in method._required_relations:
|
||||
if isinstance(spec, str):
|
||||
# 检查注解、Relationship 或已有属性
|
||||
if spec not in all_available_names and not hasattr(cls, spec):
|
||||
raise AttributeError(
|
||||
f"{cls.__name__}.{method_name} 声明了关系 '{spec}',"
|
||||
f"但 {cls.__name__} 没有此属性"
|
||||
)
|
||||
|
||||
def _is_relation_loaded(self, rel_name: str) -> bool:
|
||||
"""
|
||||
检查关系是否真正已加载(基于 SQLAlchemy inspect)
|
||||
|
||||
使用 SQLAlchemy 的 inspect 检测真实加载状态,
|
||||
自动处理 commit 导致的 expire 问题。
|
||||
|
||||
Args:
|
||||
rel_name: 关系属性名
|
||||
|
||||
Returns:
|
||||
True 如果关系已加载,False 如果未加载或已过期
|
||||
"""
|
||||
try:
|
||||
state = sa_inspect(self)
|
||||
# unloaded 包含未加载的关系属性名
|
||||
return rel_name not in state.unloaded
|
||||
except Exception:
|
||||
# 对象可能未被 SQLAlchemy 管理
|
||||
return False
|
||||
|
||||
async def _ensure_relations_loaded(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
relations: tuple[str | RelationshipInfo, ...],
|
||||
) -> None:
|
||||
"""
|
||||
确保指定关系已加载,只加载未加载的部分
|
||||
|
||||
基于 SQLAlchemy inspect 检测真实状态,自动处理:
|
||||
- 首次访问的关系
|
||||
- commit 后 expire 的关系
|
||||
- 嵌套关系(如 KlingO1Generator.kling_o1)
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
relations: 需要的关系规格
|
||||
"""
|
||||
# 找出真正未加载的关系(基于 SQLAlchemy inspect)
|
||||
to_load: list[str | RelationshipInfo] = []
|
||||
# 区分直接关系和嵌套关系的 key
|
||||
direct_keys: set[str] = set() # 本类的直接关系属性名
|
||||
nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名
|
||||
|
||||
for rel in relations:
|
||||
if isinstance(rel, str):
|
||||
# 直接关系:检查本类的关系是否已加载
|
||||
if not self._is_relation_loaded(rel):
|
||||
to_load.append(rel)
|
||||
direct_keys.add(rel)
|
||||
else:
|
||||
# 嵌套关系(InstrumentedAttribute):如 KlingO1Generator.kling_o1
|
||||
# 1. 查找指向父类的关系属性
|
||||
parent_class = rel.parent.class_
|
||||
parent_attr = _find_relation_to_class(self.__class__, parent_class)
|
||||
|
||||
if parent_attr is None:
|
||||
# 找不到路径,可能是配置错误,但仍尝试加载
|
||||
l.warning(
|
||||
f"无法找到从 {self.__class__.__name__} 到 {parent_class.__name__} 的关系路径,"
|
||||
f"无法检查 {rel.key} 是否已加载"
|
||||
)
|
||||
to_load.append(rel)
|
||||
continue
|
||||
|
||||
# 2. 检查父对象是否已加载
|
||||
if not self._is_relation_loaded(parent_attr):
|
||||
# 父对象未加载,需要同时加载父对象和嵌套关系
|
||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||||
to_load.append(parent_attr)
|
||||
nested_parent_keys.add(parent_attr)
|
||||
to_load.append(rel)
|
||||
else:
|
||||
# 3. 父对象已加载,检查嵌套关系是否已加载
|
||||
parent_obj = getattr(self, parent_attr)
|
||||
if not _is_obj_relation_loaded(parent_obj, rel.key):
|
||||
# 嵌套关系未加载:需要同时传递父关系和嵌套关系
|
||||
# 因为 _build_load_chains 需要完整的链来构建 selectinload
|
||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||||
to_load.append(parent_attr)
|
||||
nested_parent_keys.add(parent_attr)
|
||||
to_load.append(rel)
|
||||
|
||||
if not to_load:
|
||||
return # 全部已加载,跳过
|
||||
|
||||
# 构建 load 参数
|
||||
load_options = self._specs_to_load_options(to_load)
|
||||
if not load_options:
|
||||
return
|
||||
|
||||
# 安全地获取主键值(避免触发懒加载)
|
||||
state = sa_inspect(self)
|
||||
pk_tuple = state.key[1] if state.key else None
|
||||
if pk_tuple is None:
|
||||
l.warning(f"无法获取 {self.__class__.__name__} 的主键值")
|
||||
return
|
||||
# 主键是元组,取第一个值(假设单列主键)
|
||||
pk_value = pk_tuple[0]
|
||||
|
||||
# 单次查询加载缺失的关系
|
||||
fresh = await self.__class__.get(
|
||||
session,
|
||||
self.__class__.id == pk_value,
|
||||
load=load_options,
|
||||
)
|
||||
|
||||
if fresh is None:
|
||||
l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在")
|
||||
return
|
||||
|
||||
# 原地复制到 self(只复制直接关系,嵌套关系通过父关系自动可访问)
|
||||
all_direct_keys = direct_keys | nested_parent_keys
|
||||
for key in all_direct_keys:
|
||||
value = getattr(fresh, key, None)
|
||||
object.__setattr__(self, key, value)
|
||||
|
||||
def _specs_to_load_options(
|
||||
self,
|
||||
specs: list[str | RelationshipInfo],
|
||||
) -> list[RelationshipInfo]:
|
||||
"""
|
||||
将关系规格转换为 load 参数
|
||||
|
||||
- 字符串 → cls.{name}
|
||||
- RelationshipInfo → 直接使用
|
||||
"""
|
||||
result: list[RelationshipInfo] = []
|
||||
|
||||
for spec in specs:
|
||||
if isinstance(spec, str):
|
||||
rel = getattr(self.__class__, spec, None)
|
||||
if rel is not None:
|
||||
result.append(rel)
|
||||
else:
|
||||
l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在")
|
||||
else:
|
||||
result.append(spec)
|
||||
|
||||
return result
|
||||
|
||||
# ==================== 可选的手动预加载 API ====================
|
||||
|
||||
@classmethod
|
||||
def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]:
|
||||
"""
|
||||
获取指定方法声明的关系(用于外部预加载场景)
|
||||
|
||||
Args:
|
||||
method_name: 方法名
|
||||
|
||||
Returns:
|
||||
RelationshipInfo 列表
|
||||
"""
|
||||
method = getattr(cls, method_name, None)
|
||||
if method is None or not hasattr(method, '_required_relations'):
|
||||
return []
|
||||
|
||||
result: list[RelationshipInfo] = []
|
||||
for spec in method._required_relations:
|
||||
if isinstance(spec, str):
|
||||
rel = getattr(cls, spec, None)
|
||||
if rel:
|
||||
result.append(rel)
|
||||
else:
|
||||
result.append(spec)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]:
|
||||
"""
|
||||
获取多个方法的关系并去重(用于批量预加载场景)
|
||||
|
||||
Args:
|
||||
method_names: 方法名列表
|
||||
|
||||
Returns:
|
||||
去重后的 RelationshipInfo 列表
|
||||
"""
|
||||
seen: set[str] = set()
|
||||
result: list[RelationshipInfo] = []
|
||||
|
||||
for method_name in method_names:
|
||||
for rel in cls.get_relations_for_method(method_name):
|
||||
key = rel.key
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
result.append(rel)
|
||||
|
||||
return result
|
||||
|
||||
async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin':
|
||||
"""
|
||||
手动预加载指定方法的关系(可选优化 API)
|
||||
|
||||
当需要确保在调用方法前完成所有加载时使用。
|
||||
通常情况下不需要调用此方法,装饰器会自动处理。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
method_names: 方法名列表
|
||||
|
||||
Returns:
|
||||
self(支持链式调用)
|
||||
|
||||
Example:
|
||||
# 可选:显式预加载(通常不需要)
|
||||
tool = await tool.preload_for(session, 'cost', '_call')
|
||||
"""
|
||||
all_relations: list[str | RelationshipInfo] = []
|
||||
|
||||
for method_name in method_names:
|
||||
method = getattr(self.__class__, method_name, None)
|
||||
if method and hasattr(method, '_required_relations'):
|
||||
all_relations.extend(method._required_relations)
|
||||
|
||||
if all_relations:
|
||||
await self._ensure_relations_loaded(session, tuple(all_relations))
|
||||
|
||||
return self
|
||||
1247
sqlmodels/mixin/table.py
Normal file
1247
sqlmodels/mixin/table.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user