feat(mixin): add TableBaseMixin and UUIDTableBaseMixin for async CRUD operations
- Implemented TableBaseMixin providing generic CRUD methods and automatic timestamp management. - Introduced UUIDTableBaseMixin for models using UUID as primary keys. - Added ListResponse for standardized paginated responses. - Created TimeFilterRequest and PaginationRequest for filtering and pagination parameters. - Enhanced get_with_count method to return both item list and total count. - Included validation for time filter parameters in TimeFilterRequest. - Improved documentation and usage examples throughout the code.
This commit is contained in:
543
models/mixin/README.md
Normal file
543
models/mixin/README.md
Normal file
@@ -0,0 +1,543 @@
|
||||
# SQLModel Mixin Module
|
||||
|
||||
This module provides composable Mixin classes for SQLModel entities, enabling reusable functionality such as CRUD operations, polymorphic inheritance, JWT authentication, and standardized response DTOs.
|
||||
|
||||
## Module Overview
|
||||
|
||||
The `sqlmodels.mixin` module contains various Mixin classes that follow the "Composition over Inheritance" design philosophy. These mixins provide:
|
||||
|
||||
- **CRUD Operations**: Async database operations (add, save, update, delete, get, count)
|
||||
- **Polymorphic Inheritance**: Tools for joined table inheritance patterns
|
||||
- **JWT Authentication**: Token generation and validation
|
||||
- **Pagination & Sorting**: Standardized table view parameters
|
||||
- **Response DTOs**: Consistent id/timestamp fields for API responses
|
||||
|
||||
## Module Structure
|
||||
|
||||
```
|
||||
sqlmodels/mixin/
|
||||
├── __init__.py # Module exports
|
||||
├── polymorphic.py # PolymorphicBaseMixin, create_subclass_id_mixin, AutoPolymorphicIdentityMixin
|
||||
├── table.py # TableBaseMixin, UUIDTableBaseMixin, TableViewRequest
|
||||
├── info_response.py # Response DTO Mixins (IntIdInfoMixin, UUIDIdInfoMixin, etc.)
|
||||
└── jwt/ # JWT authentication
|
||||
├── __init__.py
|
||||
├── key.py # JWTKey database model
|
||||
├── payload.py # JWTPayloadBase
|
||||
├── manager.py # JWTManager singleton
|
||||
├── auth.py # JWTAuthMixin
|
||||
├── exceptions.py # JWT-related exceptions
|
||||
└── responses.py # TokenResponse DTO
|
||||
```
|
||||
|
||||
## Dependency Hierarchy
|
||||
|
||||
The module has a strict import order to avoid circular dependencies:
|
||||
|
||||
1. **polymorphic.py** - Only depends on `SQLModelBase`
|
||||
2. **table.py** - Depends on `polymorphic.py`
|
||||
3. **jwt/** - May depend on both `polymorphic.py` and `table.py`
|
||||
4. **info_response.py** - Only depends on `SQLModelBase`
|
||||
|
||||
## Core Components
|
||||
|
||||
### 1. TableBaseMixin
|
||||
|
||||
Base mixin for database table models with integer primary keys.
|
||||
|
||||
**Features:**
|
||||
- Provides CRUD methods: `add()`, `save()`, `update()`, `delete()`, `get()`, `count()`, `get_exist_one()`
|
||||
- Automatic timestamp management (`created_at`, `updated_at`)
|
||||
- Async relationship loading support (via `AsyncAttrs`)
|
||||
- Pagination and sorting via `TableViewRequest`
|
||||
- Polymorphic subclass loading support
|
||||
|
||||
**Fields:**
|
||||
- `id: int | None` - Integer primary key (auto-increment)
|
||||
- `created_at: datetime` - Record creation timestamp
|
||||
- `updated_at: datetime` - Record update timestamp (auto-updated)
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import TableBaseMixin
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
class User(SQLModelBase, TableBaseMixin, table=True):
|
||||
name: str
|
||||
email: str
|
||||
"""User email"""
|
||||
|
||||
# CRUD operations
|
||||
async def example(session: AsyncSession):
|
||||
# Add
|
||||
user = User(name="Alice", email="alice@example.com")
|
||||
user = await user.save(session)
|
||||
|
||||
# Get
|
||||
user = await User.get(session, User.id == 1)
|
||||
|
||||
# Update
|
||||
update_data = UserUpdateRequest(name="Alice Smith")
|
||||
user = await user.update(session, update_data)
|
||||
|
||||
# Delete
|
||||
await User.delete(session, user)
|
||||
|
||||
# Count
|
||||
count = await User.count(session, User.is_active == True)
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- `save()` and `update()` return refreshed instances - **always use the return value**:
|
||||
```python
|
||||
# ✅ Correct
|
||||
device = await device.save(session)
|
||||
return device
|
||||
|
||||
# ❌ Wrong - device is expired after commit
|
||||
await device.save(session)
|
||||
return device
|
||||
```
|
||||
|
||||
### 2. UUIDTableBaseMixin
|
||||
|
||||
Extends `TableBaseMixin` with UUID primary keys instead of integers.
|
||||
|
||||
**Differences from TableBaseMixin:**
|
||||
- `id: UUID` - UUID primary key (auto-generated via `uuid.uuid4()`)
|
||||
- `get_exist_one()` accepts `UUID` instead of `int`
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class Character(SQLModelBase, UUIDTableBaseMixin, table=True):
|
||||
name: str
|
||||
description: str | None = None
|
||||
"""Character description"""
|
||||
```
|
||||
|
||||
**Recommendation:** Use `UUIDTableBaseMixin` for most new models, as UUIDs provide better scalability and avoid ID collisions.
|
||||
|
||||
### 3. TableViewRequest
|
||||
|
||||
Standardized pagination and sorting parameters for LIST endpoints.
|
||||
|
||||
**Fields:**
|
||||
- `offset: int | None` - Skip first N records (default: 0)
|
||||
- `limit: int | None` - Return max N records (default: 50, max: 100)
|
||||
- `desc: bool | None` - Sort descending (default: True)
|
||||
- `order: Literal["created_at", "updated_at"] | None` - Sort field (default: "created_at")
|
||||
|
||||
**Usage with TableBaseMixin.get():**
|
||||
```python
|
||||
from dependencies import TableViewRequestDep
|
||||
|
||||
@router.get("/list")
|
||||
async def list_characters(
|
||||
session: SessionDep,
|
||||
table_view: TableViewRequestDep
|
||||
) -> list[Character]:
|
||||
"""List characters with pagination and sorting"""
|
||||
return await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
table_view=table_view # Automatically handles pagination and sorting
|
||||
)
|
||||
```
|
||||
|
||||
**Manual usage:**
|
||||
```python
|
||||
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
|
||||
characters = await Character.get(session, fetch_mode="all", table_view=table_view)
|
||||
```
|
||||
|
||||
**Backward Compatibility:**
|
||||
The traditional `offset`, `limit`, `order_by` parameters still work, but `table_view` is recommended for new code.
|
||||
|
||||
### 4. PolymorphicBaseMixin
|
||||
|
||||
Base mixin for joined table inheritance, automatically configuring polymorphic settings.
|
||||
|
||||
**Automatic Configuration:**
|
||||
- Defines `_polymorphic_name: str` field (indexed)
|
||||
- Sets `polymorphic_on='_polymorphic_name'`
|
||||
- Detects abstract classes (via ABC and abstract methods) and sets `polymorphic_abstract=True`
|
||||
|
||||
**Methods:**
|
||||
- `get_concrete_subclasses()` - Get all non-abstract subclasses (for `selectin_polymorphic`)
|
||||
- `get_polymorphic_discriminator()` - Get the polymorphic discriminator field name
|
||||
- `get_identity_to_class_map()` - Map `polymorphic_identity` to subclass types
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlmodels.mixin import PolymorphicBaseMixin, UUIDTableBaseMixin
|
||||
|
||||
class Tool(PolymorphicBaseMixin, UUIDTableBaseMixin, ABC):
|
||||
"""Abstract base class for all tools"""
|
||||
name: str
|
||||
description: str
|
||||
"""Tool description"""
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, params: dict) -> dict:
|
||||
"""Execute the tool"""
|
||||
pass
|
||||
```
|
||||
|
||||
**Why Single Underscore Prefix?**
|
||||
- SQLAlchemy maps single-underscore fields to database columns
|
||||
- Pydantic treats them as private (excluded from serialization)
|
||||
- Double-underscore fields would be excluded by SQLAlchemy (not mapped to database)
|
||||
|
||||
### 5. create_subclass_id_mixin()
|
||||
|
||||
Factory function to create ID mixins for subclasses in joined table inheritance.
|
||||
|
||||
**Purpose:** In joined table inheritance, subclasses need a foreign key pointing to the parent table's primary key. This function generates a mixin class providing that foreign key field.
|
||||
|
||||
**Signature:**
|
||||
```python
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type[SQLModelBase]:
|
||||
"""
|
||||
Args:
|
||||
parent_table_name: Parent table name (e.g., 'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
A mixin class containing id field (foreign key + primary key)
|
||||
"""
|
||||
```
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import create_subclass_id_mixin
|
||||
|
||||
# Create mixin for ASR subclasses
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
class FunASR(ASRSubclassIdMixin, ASR, AutoPolymorphicIdentityMixin, table=True):
|
||||
"""FunASR implementation"""
|
||||
pass
|
||||
```
|
||||
|
||||
**Important:** The ID mixin **must be first in the inheritance list** to ensure MRO (Method Resolution Order) correctly overrides the parent's `id` field.
|
||||
|
||||
### 6. AutoPolymorphicIdentityMixin
|
||||
|
||||
Automatically generates `polymorphic_identity` based on class name.
|
||||
|
||||
**Naming Convention:**
|
||||
- Format: `{parent_identity}.{classname_lowercase}`
|
||||
- If no parent identity exists, uses `{classname_lowercase}`
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import AutoPolymorphicIdentityMixin
|
||||
|
||||
class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
"""Base class for function-type tools"""
|
||||
pass
|
||||
# polymorphic_identity = 'function'
|
||||
|
||||
class GetWeatherFunction(Function, table=True):
|
||||
"""Weather query function"""
|
||||
pass
|
||||
# polymorphic_identity = 'function.getweatherfunction'
|
||||
```
|
||||
|
||||
**Manual Override:**
|
||||
```python
|
||||
class CustomTool(
|
||||
Tool,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_identity='custom_name', # Override auto-generated name
|
||||
table=True
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
### 7. JWTAuthMixin
|
||||
|
||||
Provides JWT token generation and validation for entity classes (User, Client).
|
||||
|
||||
**Methods:**
|
||||
- `async issue_jwt(session: AsyncSession) -> str` - Generate JWT token for current instance
|
||||
- `@classmethod async from_jwt(session: AsyncSession, token: str) -> Self` - Validate token and retrieve entity
|
||||
|
||||
**Requirements:**
|
||||
Subclasses must define:
|
||||
- `JWTPayload` - Payload model (inherits from `JWTPayloadBase`)
|
||||
- `jwt_key_purpose` - ClassVar specifying the JWT key purpose enum value
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import JWTAuthMixin, UUIDTableBaseMixin
|
||||
|
||||
class User(SQLModelBase, UUIDTableBaseMixin, JWTAuthMixin, table=True):
|
||||
JWTPayload = UserJWTPayload # Define payload model
|
||||
jwt_key_purpose: ClassVar[JWTKeyPurposeEnum] = JWTKeyPurposeEnum.user
|
||||
|
||||
email: str
|
||||
is_admin: bool = False
|
||||
is_active: bool = True
|
||||
"""User active status"""
|
||||
|
||||
# Generate token
|
||||
async def login(session: AsyncSession, user: User) -> str:
|
||||
token = await user.issue_jwt(session)
|
||||
return token
|
||||
|
||||
# Validate token
|
||||
async def verify(session: AsyncSession, token: str) -> User:
|
||||
user = await User.from_jwt(session, token)
|
||||
return user
|
||||
```
|
||||
|
||||
### 8. Response DTO Mixins
|
||||
|
||||
Mixins for standardized InfoResponse DTOs, defining id and timestamp fields.
|
||||
|
||||
**Available Mixins:**
|
||||
- `IntIdInfoMixin` - Integer ID field
|
||||
- `UUIDIdInfoMixin` - UUID ID field
|
||||
- `DatetimeInfoMixin` - `created_at` and `updated_at` fields
|
||||
- `IntIdDatetimeInfoMixin` - Integer ID + timestamps
|
||||
- `UUIDIdDatetimeInfoMixin` - UUID ID + timestamps
|
||||
|
||||
**Design Note:** These fields are non-nullable in DTOs because database records always have these values when returned.
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDIdDatetimeInfoMixin
|
||||
|
||||
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
|
||||
"""Character response DTO with id and timestamps"""
|
||||
pass # Inherits id, created_at, updated_at from mixin
|
||||
```
|
||||
|
||||
## Complete Joined Table Inheritance Example
|
||||
|
||||
Here's a complete example demonstrating polymorphic inheritance:
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import (
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
)
|
||||
|
||||
# 1. Define Base class (fields only, no table)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
"""Configuration name"""
|
||||
|
||||
base_url: str
|
||||
"""Service URL"""
|
||||
|
||||
# 2. Define abstract parent class (with table)
|
||||
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
"""Abstract base class for ASR configurations"""
|
||||
# PolymorphicBaseMixin automatically provides:
|
||||
# - _polymorphic_name field
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True (when ABC with abstract methods)
|
||||
|
||||
@abstractmethod
|
||||
async def transcribe(self, pcm_data: bytes) -> str:
|
||||
"""Transcribe audio to text"""
|
||||
pass
|
||||
|
||||
# 3. Create ID Mixin for second-level subclasses
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. Create second-level abstract class (if needed)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
"""FunASR abstract base (may have multiple implementations)"""
|
||||
pass
|
||||
# polymorphic_identity = 'funasr'
|
||||
|
||||
# 5. Create concrete implementation classes
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
"""FunASR local deployment"""
|
||||
# polymorphic_identity = 'funasr.funasrlocal'
|
||||
|
||||
async def transcribe(self, pcm_data: bytes) -> str:
|
||||
# Implementation...
|
||||
return "transcribed text"
|
||||
|
||||
# 6. Get all concrete subclasses (for selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# Returns: [FunASRLocal, ...]
|
||||
```
|
||||
|
||||
## Import Guidelines
|
||||
|
||||
**Standard Import:**
|
||||
```python
|
||||
from sqlmodels.mixin import (
|
||||
TableBaseMixin,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
TableViewRequest,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
JWTAuthMixin,
|
||||
UUIDIdDatetimeInfoMixin,
|
||||
now,
|
||||
now_date,
|
||||
)
|
||||
```
|
||||
|
||||
**Backward Compatibility:**
|
||||
Some exports are also available from `sqlmodels.base` for backward compatibility:
|
||||
```python
|
||||
# Legacy import path (still works)
|
||||
from sqlmodels.base import UUIDTableBase, TableViewRequest
|
||||
|
||||
# Recommended new import path
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin, TableViewRequest
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Mixin Order Matters
|
||||
|
||||
**Correct Order:**
|
||||
```python
|
||||
# ✅ ID Mixin first, then parent, then AutoPolymorphicIdentityMixin
|
||||
class SubTool(ToolSubclassIdMixin, Tool, AutoPolymorphicIdentityMixin, table=True):
|
||||
pass
|
||||
```
|
||||
|
||||
**Wrong Order:**
|
||||
```python
|
||||
# ❌ ID Mixin not first - won't override parent's id field
|
||||
class SubTool(Tool, ToolSubclassIdMixin, AutoPolymorphicIdentityMixin, table=True):
|
||||
pass
|
||||
```
|
||||
|
||||
### 2. Always Use Return Values from save() and update()
|
||||
|
||||
```python
|
||||
# ✅ Correct - use returned instance
|
||||
device = await device.save(session)
|
||||
return device
|
||||
|
||||
# ❌ Wrong - device is expired after commit
|
||||
await device.save(session)
|
||||
return device # AttributeError when accessing fields
|
||||
```
|
||||
|
||||
### 3. Prefer table_view Over Manual Pagination
|
||||
|
||||
```python
|
||||
# ✅ Recommended - consistent across all endpoints
|
||||
characters = await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
table_view=table_view
|
||||
)
|
||||
|
||||
# ⚠️ Works but not recommended - manual parameter management
|
||||
characters = await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
offset=0,
|
||||
limit=20,
|
||||
order_by=[desc(Character.created_at)]
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Polymorphic Loading for Many Subclasses
|
||||
|
||||
```python
|
||||
# When loading relationships with > 10 polymorphic subclasses, use load_polymorphic='all'
|
||||
tool_set = await ToolSet.get(
|
||||
session,
|
||||
ToolSet.id == tool_set_id,
|
||||
load=ToolSet.tools,
|
||||
load_polymorphic='all' # Two-phase query - only loads actual related subclasses
|
||||
)
|
||||
|
||||
# For fewer subclasses, specify the list explicitly
|
||||
tool_set = await ToolSet.get(
|
||||
session,
|
||||
ToolSet.id == tool_set_id,
|
||||
load=ToolSet.tools,
|
||||
load_polymorphic=[GetWeatherFunction, CodeInterpreterFunction]
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Response DTOs Should Inherit Base Classes
|
||||
|
||||
```python
|
||||
# ✅ Correct - inherits from CharacterBase
|
||||
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
|
||||
pass
|
||||
|
||||
# ❌ Wrong - doesn't inherit from CharacterBase
|
||||
class CharacterInfoResponse(SQLModelBase, UUIDIdDatetimeInfoMixin):
|
||||
name: str # Duplicated field definition
|
||||
description: str | None = None
|
||||
```
|
||||
|
||||
**Reason:** Inheriting from Base classes ensures:
|
||||
- Type checking via `isinstance(obj, XxxBase)`
|
||||
- Consistency across related DTOs
|
||||
- Future field additions automatically propagate
|
||||
|
||||
### 6. Use Specific Types, Not Containers
|
||||
|
||||
```python
|
||||
# ✅ Correct - specific DTO for config updates
|
||||
class GetWeatherFunctionUpdateRequest(GetWeatherFunctionConfigBase):
|
||||
weather_api_key: str | None = None
|
||||
default_location: str | None = None
|
||||
"""Default location"""
|
||||
|
||||
# ❌ Wrong - lose type safety
|
||||
class ToolUpdateRequest(SQLModelBase):
|
||||
config: dict[str, Any] # No field validation
|
||||
```
|
||||
|
||||
## Type Variables
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import T, M
|
||||
|
||||
T = TypeVar("T", bound="TableBaseMixin") # For CRUD methods
|
||||
M = TypeVar("M", bound="SQLModel") # For update() method
|
||||
```
|
||||
|
||||
## Utility Functions
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import now, now_date
|
||||
|
||||
# Lambda functions for default factories
|
||||
now = lambda: datetime.now()
|
||||
now_date = lambda: datetime.now().date()
|
||||
```
|
||||
|
||||
## Related Modules
|
||||
|
||||
- **sqlmodels.base** - Base classes (`SQLModelBase`, backward-compatible exports)
|
||||
- **dependencies** - FastAPI dependencies (`SessionDep`, `TableViewRequestDep`)
|
||||
- **sqlmodels.user** - User model with JWT authentication
|
||||
- **sqlmodels.client** - Client model with JWT authentication
|
||||
- **sqlmodels.character.llm.openai_compatibles.tools** - Polymorphic tool hierarchy
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- `POLYMORPHIC_NAME_DESIGN.md` - Design rationale for `_polymorphic_name` field
|
||||
- `CLAUDE.md` - Project coding standards and design philosophy
|
||||
- SQLAlchemy Documentation - [Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
|
||||
46
models/mixin/__init__.py
Normal file
46
models/mixin/__init__.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
SQLModel Mixin模块
|
||||
|
||||
提供各种Mixin类供SQLModel实体使用。
|
||||
|
||||
包含:
|
||||
- polymorphic: 联表继承工具(create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin)
|
||||
- table: 表基类(TableBaseMixin, UUIDTableBaseMixin)
|
||||
- table: 查询参数类(TimeFilterRequest, PaginationRequest, TableViewRequest)
|
||||
- jwt/: JWT认证相关(JWTAuthMixin, JWTManager, JWTKey等)- 需要时直接从 .jwt 导入
|
||||
- info_response: InfoResponse DTO的id/时间戳Mixin
|
||||
|
||||
导入顺序很重要,避免循环导入:
|
||||
1. polymorphic(只依赖 SQLModelBase)
|
||||
2. table(依赖 polymorphic)
|
||||
|
||||
注意:jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig,
|
||||
而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
|
||||
"""
|
||||
# polymorphic 必须先导入
|
||||
from .polymorphic import (
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
PolymorphicBaseMixin,
|
||||
)
|
||||
# table 依赖 polymorphic
|
||||
from .table import (
|
||||
TableBaseMixin,
|
||||
UUIDTableBaseMixin,
|
||||
TimeFilterRequest,
|
||||
PaginationRequest,
|
||||
TableViewRequest,
|
||||
ListResponse,
|
||||
T,
|
||||
now,
|
||||
now_date,
|
||||
)
|
||||
# jwt 不在此处导入(避免循环:jwt/manager.py → ServerConfig → mixin → jwt)
|
||||
# 需要时直接从 sqlmodels.mixin.jwt 导入
|
||||
from .info_response import (
|
||||
IntIdInfoMixin,
|
||||
UUIDIdInfoMixin,
|
||||
DatetimeInfoMixin,
|
||||
IntIdDatetimeInfoMixin,
|
||||
UUIDIdDatetimeInfoMixin,
|
||||
)
|
||||
46
models/mixin/info_response.py
Normal file
46
models/mixin/info_response.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
InfoResponse DTO Mixin模块
|
||||
|
||||
提供用于InfoResponse类型DTO的Mixin,统一定义id/created_at/updated_at字段。
|
||||
|
||||
设计说明:
|
||||
- 这些Mixin用于**响应DTO**,不是数据库表
|
||||
- 从数据库返回时这些字段永远不为空,所以定义为必填字段
|
||||
- TableBase中的id=None和default_factory=now是正确的(入库前为None,数据库生成)
|
||||
- 这些Mixin让DTO明确表示"返回给客户端时这些字段必定有值"
|
||||
"""
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from models.base import SQLModelBase
|
||||
|
||||
|
||||
class IntIdInfoMixin(SQLModelBase):
|
||||
"""整数ID响应mixin - 用于InfoResponse DTO"""
|
||||
id: int
|
||||
"""记录ID"""
|
||||
|
||||
|
||||
class UUIDIdInfoMixin(SQLModelBase):
|
||||
"""UUID ID响应mixin - 用于InfoResponse DTO"""
|
||||
id: UUID
|
||||
"""记录ID"""
|
||||
|
||||
|
||||
class DatetimeInfoMixin(SQLModelBase):
|
||||
"""时间戳响应mixin - 用于InfoResponse DTO"""
|
||||
created_at: datetime
|
||||
"""创建时间"""
|
||||
|
||||
updated_at: datetime
|
||||
"""更新时间"""
|
||||
|
||||
|
||||
class IntIdDatetimeInfoMixin(IntIdInfoMixin, DatetimeInfoMixin):
|
||||
"""整数ID + 时间戳响应mixin"""
|
||||
pass
|
||||
|
||||
|
||||
class UUIDIdDatetimeInfoMixin(UUIDIdInfoMixin, DatetimeInfoMixin):
|
||||
"""UUID ID + 时间戳响应mixin"""
|
||||
pass
|
||||
456
models/mixin/polymorphic.py
Normal file
456
models/mixin/polymorphic.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""
|
||||
联表继承(Joined Table Inheritance)的通用工具
|
||||
|
||||
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
|
||||
|
||||
Usage Example:
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import (
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
\"\"\"配置名称\"\"\"
|
||||
|
||||
base_url: str
|
||||
\"\"\"服务地址\"\"\"
|
||||
|
||||
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC
|
||||
):
|
||||
\"\"\"ASR配置的抽象基类\"\"\"
|
||||
# PolymorphicBaseMixin 自动提供:
|
||||
# - _polymorphic_name 字段
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True(当有抽象方法时)
|
||||
|
||||
# 3. 为第二层子类创建ID Mixin
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. 创建第二层抽象类(如果需要)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
\"\"\"FunASR的抽象基类,可能有多个实现\"\"\"
|
||||
pass
|
||||
|
||||
# 5. 创建具体实现类
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
\"\"\"FunASR本地部署版本\"\"\"
|
||||
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
|
||||
pass
|
||||
|
||||
# 6. 获取所有具体子类(用于 selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# 返回 [FunASRLocal, ...]
|
||||
"""
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from sqlalchemy import String, inspect
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlmodel import Field
|
||||
|
||||
from models.base.sqlmodel_base import SQLModelBase
|
||||
|
||||
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
|
||||
"""
|
||||
动态创建SubclassIdMixin类
|
||||
|
||||
在联表继承中,子类需要一个外键指向父表的主键。
|
||||
此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。
|
||||
|
||||
Args:
|
||||
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4)
|
||||
|
||||
Example:
|
||||
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
|
||||
... pass
|
||||
|
||||
Note:
|
||||
- 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id
|
||||
- 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase)
|
||||
- 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin)
|
||||
"""
|
||||
if not parent_table_name:
|
||||
raise ValueError("parent_table_name 不能为空")
|
||||
|
||||
# 转换为PascalCase作为类名
|
||||
class_name_parts = parent_table_name.split('_')
|
||||
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
|
||||
|
||||
# 使用闭包捕获parent_table_name
|
||||
_parent_table_name = parent_table_name
|
||||
|
||||
# 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields
|
||||
class SubclassIdMixin(SQLModelBase):
|
||||
# 定义id字段
|
||||
id: UUID = Field(
|
||||
default_factory=uuid.uuid4,
|
||||
foreign_key=f'{_parent_table_name}.id',
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复联表继承中子类字段的default_factory丢失问题。
|
||||
SQLAlchemy 的 InstrumentedAttribute 会污染从父类继承的字段,
|
||||
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
|
||||
|
||||
通过从 MRO 中查找父类的原始字段定义来获取正确的 default_factory,
|
||||
遵循单一真相原则(不硬编码 default_factory)。
|
||||
|
||||
需要修复的字段:
|
||||
- id: 主键(从父类获取 default_factory)
|
||||
- created_at: 创建时间戳(从父类获取 default_factory)
|
||||
- updated_at: 更新时间戳(从父类获取 default_factory)
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
def find_original_field_info(field_name: str) -> FieldInfo | None:
|
||||
"""从 MRO 中查找字段的原始定义(未被 InstrumentedAttribute 污染的)"""
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
||||
field_info = base.model_fields[field_name]
|
||||
# 跳过被 InstrumentedAttribute 污染的
|
||||
if not isinstance(field_info.default, InstrumentedAttribute):
|
||||
return field_info
|
||||
return None
|
||||
|
||||
# 动态检测所有需要修复的字段
|
||||
# 遵循单一真相原则:不硬编码字段列表,而是通过以下条件判断:
|
||||
# 1. default 是 InstrumentedAttribute(被 SQLAlchemy 污染)
|
||||
# 2. 原始定义有 default_factory 或明确的 default 值
|
||||
#
|
||||
# 覆盖场景:
|
||||
# - UUID主键(UUIDTableBaseMixin):id 有 default_factory=uuid.uuid4,需要修复
|
||||
# - int主键(TableBaseMixin):id 用 default=None,不需要修复(数据库自增)
|
||||
# - created_at/updated_at:有 default_factory=now,需要修复
|
||||
# - 外键字段(created_by_id等):有 default=None,需要修复
|
||||
# - 普通字段(name, temperature等):无 default_factory,不需要修复
|
||||
#
|
||||
# MRO 查找保证:
|
||||
# - 在多重继承场景下,MRO 顺序是确定性的
|
||||
# - find_original_field_info 会找到第一个未被污染且有该字段的父类
|
||||
for field_name, current_field in cls.model_fields.items():
|
||||
# 检查是否被污染(default 是 InstrumentedAttribute)
|
||||
if not isinstance(current_field.default, InstrumentedAttribute):
|
||||
continue # 未被污染,跳过
|
||||
|
||||
# 从父类查找原始定义
|
||||
original = find_original_field_info(field_name)
|
||||
if original is None:
|
||||
continue # 找不到原始定义,跳过
|
||||
|
||||
# 根据原始定义的 default/default_factory 来修复
|
||||
if original.default_factory:
|
||||
# 有 default_factory(如 uuid.uuid4, now)
|
||||
new_field = FieldInfo(
|
||||
default_factory=original.default_factory,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
elif original.default is not PydanticUndefined:
|
||||
# 有明确的 default 值(如 None, 0, ""),且不是 PydanticUndefined
|
||||
# PydanticUndefined 表示字段没有默认值(必填)
|
||||
new_field = FieldInfo(
|
||||
default=original.default,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
else:
|
||||
continue # 既没有 default_factory 也没有有效的 default,跳过
|
||||
|
||||
# 复制SQLModel特有的属性
|
||||
if hasattr(current_field, 'foreign_key'):
|
||||
new_field.foreign_key = current_field.foreign_key
|
||||
if hasattr(current_field, 'primary_key'):
|
||||
new_field.primary_key = current_field.primary_key
|
||||
|
||||
cls.model_fields[field_name] = new_field
|
||||
|
||||
# 设置类名和文档
|
||||
SubclassIdMixin.__name__ = class_name
|
||||
SubclassIdMixin.__qualname__ = class_name
|
||||
SubclassIdMixin.__doc__ = f"""
|
||||
{parent_table_name}子类的ID Mixin
|
||||
|
||||
用于{parent_table_name}的子类,提供外键指向父表。
|
||||
通过MRO确保此id字段覆盖继承的id字段。
|
||||
"""
|
||||
|
||||
return SubclassIdMixin
|
||||
|
||||
|
||||
class AutoPolymorphicIdentityMixin:
|
||||
"""
|
||||
自动生成polymorphic_identity的Mixin
|
||||
|
||||
使用此Mixin的类会自动根据类名生成polymorphic_identity。
|
||||
格式:{parent_polymorphic_identity}.{classname_lowercase}
|
||||
|
||||
如果没有父类的polymorphic_identity,则直接使用类名小写。
|
||||
|
||||
Example:
|
||||
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
|
||||
... __polymorphic_name: str
|
||||
...
|
||||
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function'
|
||||
...
|
||||
>>> class CodeInterpreterFunction(Function, table=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
|
||||
|
||||
Note:
|
||||
- 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留
|
||||
- 此Mixin应该在继承列表中靠后的位置(在表基类之前)
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
|
||||
"""
|
||||
子类化钩子,自动生成polymorphic_identity
|
||||
|
||||
Args:
|
||||
polymorphic_identity: 如果手动指定,则使用指定的值
|
||||
**kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True)
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 如果手动指定了polymorphic_identity,使用指定的值
|
||||
if polymorphic_identity is not None:
|
||||
identity = polymorphic_identity
|
||||
else:
|
||||
# 自动生成polymorphic_identity
|
||||
class_name = cls.__name__.lower()
|
||||
|
||||
# 尝试从父类获取polymorphic_identity作为前缀
|
||||
parent_identity = None
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
|
||||
parent_identity = base.__mapper_args__.get('polymorphic_identity')
|
||||
if parent_identity:
|
||||
break
|
||||
|
||||
# 构建identity
|
||||
if parent_identity:
|
||||
identity = f'{parent_identity}.{class_name}'
|
||||
else:
|
||||
identity = class_name
|
||||
|
||||
# 设置到__mapper_args__
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 只在尚未设置polymorphic_identity时设置
|
||||
if 'polymorphic_identity' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_identity'] = identity
|
||||
|
||||
|
||||
class PolymorphicBaseMixin:
|
||||
"""
|
||||
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
|
||||
|
||||
此 Mixin 自动设置以下内容:
|
||||
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
|
||||
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
|
||||
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
|
||||
|
||||
使用场景:
|
||||
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
|
||||
|
||||
用法示例:
|
||||
```python
|
||||
from abc import ABC
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
|
||||
|
||||
# 定义基类
|
||||
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
__tablename__ = 'mytool'
|
||||
|
||||
# 不需要手动定义 _polymorphic_name
|
||||
# 不需要手动设置 polymorphic_on
|
||||
# 不需要手动设置 polymorphic_abstract
|
||||
|
||||
# 定义子类
|
||||
class SpecificTool(MyTool):
|
||||
__tablename__ = 'specifictool'
|
||||
|
||||
# 会自动继承 polymorphic 配置
|
||||
```
|
||||
|
||||
自动行为:
|
||||
1. 定义 `_polymorphic_name: str` 字段(带索引)
|
||||
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
|
||||
3. 自动检测抽象类:
|
||||
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
|
||||
- 否则设置为 False
|
||||
|
||||
手动覆盖:
|
||||
可以在类定义时手动指定参数来覆盖自动行为:
|
||||
```python
|
||||
class MyTool(
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
|
||||
polymorphic_abstract=False # 强制不设为抽象类
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
注意事项:
|
||||
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
|
||||
- 适用于联表继承(joined table inheritance)场景
|
||||
- 子类会自动继承 _polymorphic_name 字段定义
|
||||
- 使用单下划线前缀是因为:
|
||||
* SQLAlchemy 会映射单下划线字段为数据库列
|
||||
* Pydantic 将其视为私有属性,不参与序列化
|
||||
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
|
||||
"""
|
||||
|
||||
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
|
||||
#
|
||||
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
|
||||
#
|
||||
# 为什么这样做:
|
||||
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
|
||||
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
|
||||
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
|
||||
# 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema)
|
||||
# 5. 内部代码仍然可以正常访问和修改此字段
|
||||
#
|
||||
# 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
|
||||
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
|
||||
"""
|
||||
多态鉴别器字段,用于标识具体的子类类型
|
||||
|
||||
注意:此字段使用单下划线前缀,表示内部使用。
|
||||
- ✅ 存储到数据库
|
||||
- ✅ 不出现在 API 序列化中
|
||||
- ✅ 防止外部直接修改
|
||||
"""
|
||||
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
polymorphic_on: str | None = None,
|
||||
polymorphic_abstract: bool | None = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
在子类定义时自动配置 polymorphic 设置
|
||||
|
||||
Args:
|
||||
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。
|
||||
设置为其他值可以使用不同的字段作为多态鉴别器。
|
||||
polymorphic_abstract: 是否为抽象类。
|
||||
- None: 自动检测(默认)
|
||||
- True: 强制设为抽象类
|
||||
- False: 强制设为非抽象类
|
||||
**kwargs: 传递给父类的其他参数
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 初始化 __mapper_args__(如果还没有)
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 设置 polymorphic_on(默认为 _polymorphic_name)
|
||||
if 'polymorphic_on' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
|
||||
|
||||
# 自动检测或设置 polymorphic_abstract
|
||||
if 'polymorphic_abstract' not in cls.__mapper_args__:
|
||||
if polymorphic_abstract is None:
|
||||
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
|
||||
has_abc = ABC in cls.__mro__
|
||||
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
|
||||
polymorphic_abstract = has_abc and has_abstract_methods
|
||||
|
||||
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
|
||||
|
||||
@classmethod
|
||||
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
递归获取当前类的所有具体(非抽象)子类
|
||||
|
||||
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
|
||||
可在任意多态基类上调用,返回该类的所有非抽象子类。
|
||||
|
||||
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
|
||||
"""
|
||||
result: list[type[PolymorphicBaseMixin]] = []
|
||||
for subclass in cls.__subclasses__():
|
||||
# 使用 inspect() 获取 mapper 的公开属性
|
||||
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
|
||||
mapper = inspect(subclass)
|
||||
if not mapper.polymorphic_abstract:
|
||||
result.append(subclass)
|
||||
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
|
||||
if hasattr(subclass, 'get_concrete_subclasses'):
|
||||
result.extend(subclass.get_concrete_subclasses())
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_polymorphic_discriminator(cls) -> str:
|
||||
"""
|
||||
获取多态鉴别字段名
|
||||
|
||||
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
|
||||
|
||||
:return: 多态鉴别字段名(如 '_polymorphic_name')
|
||||
:raises ValueError: 如果类未配置 polymorphic_on
|
||||
"""
|
||||
polymorphic_on = inspect(cls).polymorphic_on
|
||||
if polymorphic_on is None:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} 未配置 polymorphic_on,"
|
||||
f"请确保正确继承 PolymorphicBaseMixin"
|
||||
)
|
||||
return polymorphic_on.key
|
||||
|
||||
@classmethod
|
||||
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
获取 polymorphic_identity 到具体子类的映射
|
||||
|
||||
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
|
||||
|
||||
:return: identity 到子类的映射字典
|
||||
"""
|
||||
result: dict[str, type[PolymorphicBaseMixin]] = {}
|
||||
for subclass in cls.get_concrete_subclasses():
|
||||
identity = inspect(subclass).polymorphic_identity
|
||||
if identity:
|
||||
result[identity] = subclass
|
||||
return result
|
||||
927
models/mixin/table.py
Normal file
927
models/mixin/table.py
Normal file
@@ -0,0 +1,927 @@
|
||||
"""
|
||||
表基类 Mixin
|
||||
|
||||
提供 TableBaseMixin、UUIDTableBaseMixin 和 TableViewRequest。
|
||||
这些类实际上是 Mixin,为 SQLModel 模型提供 CRUD 操作和时间戳字段。
|
||||
|
||||
依赖关系:
|
||||
base/sqlmodel_base.py ← 最底层
|
||||
↓
|
||||
mixin/polymorphic.py ← 定义 PolymorphicBaseMixin
|
||||
↓
|
||||
mixin/table.py ← 当前文件,导入 PolymorphicBaseMixin
|
||||
↓
|
||||
base/__init__.py ← 从 mixin 重新导出(保持向后兼容)
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TypeVar, Literal, override, Any, ClassVar, Generic
|
||||
|
||||
# TODO(ListResponse泛型问题): SQLModel泛型类型JSON Schema生成bug
|
||||
# 已知问题: https://github.com/fastapi/sqlmodel/discussions/1002
|
||||
# 修复PR: https://github.com/fastapi/sqlmodel/pull/1275 (尚未合并)
|
||||
# 现象: SQLModel + Generic[T] 的 __pydantic_generic_metadata__ = {origin: None, args: ()}
|
||||
# 导致OpenAPI schema中泛型字段显示为{}而非正确的$ref
|
||||
# 当前方案: ListResponse继承BaseModel而非SQLModel (Discussion #1002推荐的workaround)
|
||||
# 未来: PR #1275合并后可改回继承SQLModelBase
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct
|
||||
from sqlalchemy.orm import selectinload, Relationship, with_polymorphic
|
||||
from sqlmodel import Field, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.sql._typing import _OnClauseArgument
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
from .polymorphic import PolymorphicBaseMixin
|
||||
from models.base.sqlmodel_base import SQLModelBase
|
||||
|
||||
# Type variables for generic type hints, improving code completion and analysis.
|
||||
T = TypeVar("T", bound="TableBaseMixin")
|
||||
M = TypeVar("M", bound="SQLModelBase")
|
||||
ItemT = TypeVar("ItemT")
|
||||
|
||||
|
||||
class ListResponse(BaseModel, Generic[ItemT]):
|
||||
"""
|
||||
泛型分页响应
|
||||
|
||||
用于所有LIST端点的标准化响应格式,包含记录总数和项目列表。
|
||||
与 TableBaseMixin.get_with_count() 配合使用。
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
@router.get("", response_model=ListResponse[CharacterInfoResponse])
|
||||
async def list_characters(...) -> ListResponse[Character]:
|
||||
return await Character.get_with_count(session, table_view=table_view)
|
||||
```
|
||||
|
||||
Attributes:
|
||||
count: 符合条件的记录总数(用于分页计算)
|
||||
items: 当前页的记录列表
|
||||
|
||||
Note:
|
||||
继承BaseModel而非SQLModelBase,因为SQLModel的metaclass与Generic冲突。
|
||||
详见文件顶部TODO注释。
|
||||
"""
|
||||
# 与SQLModelBase保持一致的配置
|
||||
model_config = ConfigDict(use_attribute_docstrings=True)
|
||||
|
||||
count: int
|
||||
"""符合条件的记录总数"""
|
||||
|
||||
items: list[ItemT]
|
||||
"""当前页的记录列表"""
|
||||
|
||||
|
||||
# Lambda functions to get the current time, used as default factories in model fields.
|
||||
now = lambda: datetime.now()
|
||||
now_date = lambda: datetime.now().date()
|
||||
|
||||
|
||||
# ==================== 查询参数请求类 ====================
|
||||
|
||||
class TimeFilterRequest(SQLModelBase):
|
||||
"""
|
||||
时间筛选请求参数
|
||||
|
||||
用于 count() 等只需要时间筛选的场景。
|
||||
纯数据类,只负责参数校验和携带,SQL子句构建由 TableBaseMixin 负责。
|
||||
|
||||
Raises:
|
||||
ValueError: 时间范围无效
|
||||
"""
|
||||
created_after_datetime: datetime | None = None
|
||||
"""创建时间起始筛选(created_at >= datetime),如果为None则不限制"""
|
||||
|
||||
created_before_datetime: datetime | None = None
|
||||
"""创建时间结束筛选(created_at < datetime),如果为None则不限制"""
|
||||
|
||||
updated_after_datetime: datetime | None = None
|
||||
"""更新时间起始筛选(updated_at >= datetime),如果为None则不限制"""
|
||||
|
||||
updated_before_datetime: datetime | None = None
|
||||
"""更新时间结束筛选(updated_at < datetime),如果为None则不限制"""
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""
|
||||
验证时间范围有效性
|
||||
|
||||
验证规则:
|
||||
1. 同类型:after 必须小于 before
|
||||
2. 跨类型:created_after 不能大于 updated_before(记录不可能在创建前被更新)
|
||||
"""
|
||||
# 同类型矛盾验证
|
||||
if self.created_after_datetime and self.created_before_datetime:
|
||||
if self.created_after_datetime >= self.created_before_datetime:
|
||||
raise ValueError("created_after_datetime 必须小于 created_before_datetime")
|
||||
if self.updated_after_datetime and self.updated_before_datetime:
|
||||
if self.updated_after_datetime >= self.updated_before_datetime:
|
||||
raise ValueError("updated_after_datetime 必须小于 updated_before_datetime")
|
||||
|
||||
# 跨类型矛盾验证:created_after >= updated_before 意味着要求创建时间晚于或等于更新时间上界,逻辑矛盾
|
||||
if self.created_after_datetime and self.updated_before_datetime:
|
||||
if self.created_after_datetime >= self.updated_before_datetime:
|
||||
raise ValueError(
|
||||
"created_after_datetime 不能大于或等于 updated_before_datetime"
|
||||
"(记录的更新时间不可能早于或等于创建时间)"
|
||||
)
|
||||
|
||||
|
||||
class PaginationRequest(SQLModelBase):
|
||||
"""
|
||||
分页排序请求参数
|
||||
|
||||
用于需要分页和排序的场景。
|
||||
纯数据类,只负责携带参数,SQL子句构建由 TableBaseMixin 负责。
|
||||
"""
|
||||
offset: int | None = Field(default=0, ge=0)
|
||||
"""偏移量(跳过前N条记录),必须为非负整数"""
|
||||
|
||||
limit: int | None = Field(default=50, le=100)
|
||||
"""每页数量(返回最多N条记录),默认50,最大100"""
|
||||
|
||||
desc: bool | None = True
|
||||
"""是否降序排序(True: 降序, False: 升序)"""
|
||||
|
||||
order: Literal["created_at", "updated_at"] | None = "created_at"
|
||||
"""排序字段(created_at: 创建时间, updated_at: 更新时间)"""
|
||||
|
||||
|
||||
class TableViewRequest(TimeFilterRequest, PaginationRequest):
|
||||
"""
|
||||
表格视图请求参数(分页、排序和时间筛选)
|
||||
|
||||
组合继承 TimeFilterRequest 和 PaginationRequest,用于 get() 等需要完整查询参数的场景。
|
||||
纯数据类,SQL子句构建由 TableBaseMixin 负责。
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
# 在端点中使用依赖注入
|
||||
@router.get("/list")
|
||||
async def list_items(
|
||||
session: SessionDep,
|
||||
table_view: TableViewRequestDep
|
||||
):
|
||||
items = await Item.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
table_view=table_view
|
||||
)
|
||||
return items
|
||||
|
||||
# 直接使用
|
||||
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
|
||||
items = await Item.get(session, fetch_mode="all", table_view=table_view)
|
||||
```
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# ==================== TableBaseMixin ====================
|
||||
|
||||
class TableBaseMixin(AsyncAttrs):
|
||||
"""
|
||||
一个异步 CRUD 操作的基础模型类 Mixin.
|
||||
|
||||
此类必须搭配SQLModelBase使用
|
||||
|
||||
此类为所有继承它的 SQLModel 模型提供了通用的数据库操作方法,
|
||||
例如 add, save, update, delete, 和 get. 它还包括自动管理
|
||||
的 `created_at` 和 `updated_at` 时间戳字段.
|
||||
|
||||
Attributes:
|
||||
id (int | None): 整数主键, 自动递增.
|
||||
created_at (datetime): 记录创建时的时间戳, 自动设置.
|
||||
updated_at (datetime): 记录每次更新时的时间戳, 自动更新.
|
||||
"""
|
||||
_is_table_mixin: ClassVar[bool] = True
|
||||
"""标记此类为表混入类的内部属性"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
接受并传递子类定义时的关键字参数
|
||||
|
||||
这允许元类 __DeclarativeMeta 处理的参数(如 table_args)
|
||||
能够正确传递,而不会在 __init_subclass__ 阶段报错。
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
|
||||
created_at: datetime = Field(default_factory=now)
|
||||
updated_at: datetime = Field(
|
||||
sa_type=DateTime,
|
||||
sa_column_kwargs={'default': now, 'onupdate': now},
|
||||
default_factory=now
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
|
||||
"""
|
||||
向数据库中添加一个新的或多个新的记录.
|
||||
|
||||
这个类方法可以接受单个模型实例或一个实例列表,并将它们
|
||||
一次性提交到数据库中。执行后,可以选择性地刷新这些实例以获取
|
||||
数据库生成的值(例如,自动递增的 ID).
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
|
||||
refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
|
||||
|
||||
Returns:
|
||||
T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例.
|
||||
|
||||
Usage:
|
||||
item1 = Item(name="Apple")
|
||||
item2 = Item(name="Banana")
|
||||
|
||||
# 添加多个实例
|
||||
added_items = await Item.add(session, [item1, item2])
|
||||
|
||||
# 添加单个实例
|
||||
item3 = Item(name="Cherry")
|
||||
added_item = await Item.add(session, item3)
|
||||
"""
|
||||
is_list = False
|
||||
if isinstance(instances, list):
|
||||
is_list = True
|
||||
session.add_all(instances)
|
||||
else:
|
||||
session.add(instances)
|
||||
|
||||
await session.commit()
|
||||
|
||||
if refresh:
|
||||
if is_list:
|
||||
for instance in instances:
|
||||
await session.refresh(instance)
|
||||
else:
|
||||
await session.refresh(instances)
|
||||
|
||||
return instances
|
||||
|
||||
async def save(
|
||||
self: T,
|
||||
session: AsyncSession,
|
||||
load: RelationshipInfo | None = None,
|
||||
refresh: bool = True
|
||||
) -> T:
|
||||
"""
|
||||
保存(插入或更新)当前模型实例到数据库.
|
||||
|
||||
这是一个实例方法,它将当前对象添加到会话中并提交更改。
|
||||
可以用于创建新记录或更新现有记录。还可以选择在保存后
|
||||
预加载(eager load)一个关联关系.
|
||||
|
||||
**重要**:调用此方法后,session中的所有对象都会过期(expired)。
|
||||
如果需要继续使用该对象,必须使用返回值:
|
||||
|
||||
```python
|
||||
# ✅ 正确:需要返回值时
|
||||
client = await client.save(session)
|
||||
return client
|
||||
|
||||
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
|
||||
await client.save(session, refresh=False)
|
||||
|
||||
# ❌ 错误:需要返回值但未使用
|
||||
await client.save(session)
|
||||
return client # client 对象已过期
|
||||
```
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性.
|
||||
例如 `User.posts`.
|
||||
refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值,
|
||||
设为 False 可节省一次数据库查询。默认为 True.
|
||||
|
||||
Returns:
|
||||
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
||||
"""
|
||||
session.add(self)
|
||||
await session.commit()
|
||||
|
||||
if not refresh:
|
||||
return self
|
||||
|
||||
if load is not None:
|
||||
cls = type(self)
|
||||
await session.refresh(self)
|
||||
# 如果指定了 load, 重新获取实例并加载关联关系
|
||||
return await cls.get(session, cls.id == self.id, load=load)
|
||||
else:
|
||||
await session.refresh(self)
|
||||
return self
|
||||
|
||||
async def update(
|
||||
self: T,
|
||||
session: AsyncSession,
|
||||
other: M,
|
||||
extra_data: dict[str, Any] | None = None,
|
||||
exclude_unset: bool = True,
|
||||
exclude: set[str] | None = None,
|
||||
load: RelationshipInfo | None = None,
|
||||
refresh: bool = True
|
||||
) -> T:
|
||||
"""
|
||||
使用另一个模型实例或字典中的数据来更新当前实例.
|
||||
|
||||
此方法将 `other` 对象中的数据合并到当前实例中。默认情况下,
|
||||
它只会更新 `other` 中被显式设置的字段.
|
||||
|
||||
**重要**:调用此方法后,session中的所有对象都会过期(expired)。
|
||||
如果需要继续使用该对象,必须使用返回值:
|
||||
|
||||
```python
|
||||
# ✅ 正确:需要返回值时
|
||||
client = await client.update(session, update_data)
|
||||
return client
|
||||
|
||||
# ✅ 正确:需要返回值且需要加载关系时
|
||||
user = await user.update(session, update_data, load=User.permission)
|
||||
return user
|
||||
|
||||
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
|
||||
await client.update(session, update_data, refresh=False)
|
||||
|
||||
# ❌ 错误:需要返回值但未使用
|
||||
await client.update(session, update_data)
|
||||
return client # client 对象已过期
|
||||
```
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
other (M): 一个 SQLModel 或 Pydantic 模型实例,其数据将用于更新当前实例.
|
||||
extra_data (dict, optional): 一个额外的字典,用于更新当前实例的特定字段.
|
||||
exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供)
|
||||
的字段将被忽略. 默认为 True.
|
||||
exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}.
|
||||
load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性.
|
||||
例如 `User.permission`.
|
||||
refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值,
|
||||
设为 False 可节省一次数据库查询。默认为 True.
|
||||
|
||||
Returns:
|
||||
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
||||
"""
|
||||
self.sqlmodel_update(
|
||||
other.model_dump(exclude_unset=exclude_unset, exclude=exclude),
|
||||
update=extra_data
|
||||
)
|
||||
|
||||
session.add(self)
|
||||
await session.commit()
|
||||
|
||||
if not refresh:
|
||||
return self
|
||||
|
||||
if load is not None:
|
||||
cls = type(self)
|
||||
await session.refresh(self)
|
||||
return await cls.get(session, cls.id == self.id, load=load)
|
||||
else:
|
||||
await session.refresh(self)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
async def delete(cls: type[T], session: AsyncSession, instances: T | list[T]) -> None:
|
||||
"""
|
||||
从数据库中删除一个或多个记录.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
instances (T | list[T]): 要删除的单个模型实例或模型实例列表.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Usage:
|
||||
item_to_delete = await Item.get(session, Item.id == 1)
|
||||
if item_to_delete:
|
||||
await Item.delete(session, item_to_delete)
|
||||
|
||||
items_to_delete = await Item.get(session, Item.name.in_(["Apple", "Banana"]), fetch_mode="all")
|
||||
if items_to_delete:
|
||||
await Item.delete(session, items_to_delete)
|
||||
"""
|
||||
if isinstance(instances, list):
|
||||
for instance in instances:
|
||||
await session.delete(instance)
|
||||
else:
|
||||
await session.delete(instances)
|
||||
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
def _build_time_filters(
|
||||
cls: type[T],
|
||||
created_before_datetime: datetime | None = None,
|
||||
created_after_datetime: datetime | None = None,
|
||||
updated_before_datetime: datetime | None = None,
|
||||
updated_after_datetime: datetime | None = None,
|
||||
) -> list[BinaryExpression]:
|
||||
"""
|
||||
构建时间筛选条件列表
|
||||
|
||||
Args:
|
||||
created_before_datetime: 筛选 created_at < datetime 的记录
|
||||
created_after_datetime: 筛选 created_at >= datetime 的记录
|
||||
updated_before_datetime: 筛选 updated_at < datetime 的记录
|
||||
updated_after_datetime: 筛选 updated_at >= datetime 的记录
|
||||
|
||||
Returns:
|
||||
BinaryExpression 条件列表
|
||||
"""
|
||||
filters: list[BinaryExpression] = []
|
||||
if created_after_datetime is not None:
|
||||
filters.append(cls.created_at >= created_after_datetime)
|
||||
if created_before_datetime is not None:
|
||||
filters.append(cls.created_at < created_before_datetime)
|
||||
if updated_after_datetime is not None:
|
||||
filters.append(cls.updated_at >= updated_after_datetime)
|
||||
if updated_before_datetime is not None:
|
||||
filters.append(cls.updated_at < updated_before_datetime)
|
||||
return filters
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
condition: BinaryExpression | ClauseElement | None = None,
|
||||
*,
|
||||
offset: int | None = None,
|
||||
limit: int | None = None,
|
||||
fetch_mode: Literal["one", "first", "all"] = "first",
|
||||
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
|
||||
options: list | None = None,
|
||||
load: RelationshipInfo | None = None,
|
||||
order_by: list[ClauseElement] | None = None,
|
||||
filter: BinaryExpression | ClauseElement | None = None,
|
||||
with_for_update: bool = False,
|
||||
table_view: TableViewRequest | None = None,
|
||||
load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
created_before_datetime: datetime | None = None,
|
||||
created_after_datetime: datetime | None = None,
|
||||
updated_before_datetime: datetime | None = None,
|
||||
updated_after_datetime: datetime | None = None,
|
||||
) -> T | list[T] | None:
|
||||
"""
|
||||
根据指定的条件异步地从数据库中获取一个或多个模型实例.
|
||||
|
||||
这是一个功能强大的通用查询方法,支持过滤、排序、分页、连接查询和关联关系预加载.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
condition (BinaryExpression | ClauseElement | None): 主要的查询过滤条件,
|
||||
例如 `User.id == 1`。
|
||||
当为 `None` 时,表示无条件查询(查询所有记录)。
|
||||
offset (int | None): 查询结果的起始偏移量, 用于分页.
|
||||
limit (int | None): 返回记录的最大数量, 用于分页.
|
||||
fetch_mode (Literal["one", "first", "all"]):
|
||||
- "one": 获取唯一的一条记录. 如果找不到或找到多条,会引发异常.
|
||||
- "first": 获取查询结果的第一条记录. 如果找不到,返回 `None`.
|
||||
- "all": 获取所有匹配的记录,返回一个列表.
|
||||
默认为 "first".
|
||||
join (type[T] | tuple[type[T], _OnClauseArgument] | None):
|
||||
要 JOIN 的模型类或一个包含模型类和 ON 子句的元组.
|
||||
例如 `User` 或 `(Profile, User.id == Profile.user_id)`.
|
||||
options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
|
||||
例如 `[selectinload(User.posts)]`.
|
||||
load (Relationship | None): `selectinload` 的快捷方式,用于预加载单个关联关系.
|
||||
例如 `User.profile`.
|
||||
order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
|
||||
例如 `[User.name.asc(), User.created_at.desc()]`.
|
||||
filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
|
||||
|
||||
with_for_update (bool): 如果为 True, 在查询中使用 `FOR UPDATE` 锁定选定的行. 默认为 False.
|
||||
|
||||
table_view (TableViewRequest | None): TableViewRequest对象,如果提供则自动处理分页、排序和时间筛选。
|
||||
会覆盖offset、limit、order_by及时间筛选参数。
|
||||
这是推荐的分页排序方式,统一了所有LIST端点的参数格式。
|
||||
|
||||
load_polymorphic: 多态子类加载选项,需要与 load 参数配合使用。
|
||||
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
|
||||
- 'all': 两阶段查询,只加载实际关联的子类(对于 > 10 个子类的场景有明显性能收益)
|
||||
- None(默认): 不使用多态加载
|
||||
|
||||
created_before_datetime (datetime | None): 筛选 created_at < datetime 的记录
|
||||
created_after_datetime (datetime | None): 筛选 created_at >= datetime 的记录
|
||||
updated_before_datetime (datetime | None): 筛选 updated_at < datetime 的记录
|
||||
updated_after_datetime (datetime | None): 筛选 updated_at >= datetime 的记录
|
||||
|
||||
Returns:
|
||||
T | list[T] | None: 根据 `fetch_mode` 的设置,返回单个实例、实例列表或 `None`.
|
||||
|
||||
Raises:
|
||||
ValueError: 如果提供了无效的 `fetch_mode` 值,或 load_polymorphic 未与 load 配合使用.
|
||||
|
||||
Examples:
|
||||
# 使用table_view参数(推荐)
|
||||
users = await User.get(session, fetch_mode="all", table_view=table_view_args)
|
||||
|
||||
# 传统方式(向后兼容)
|
||||
users = await User.get(session, fetch_mode="all", offset=0, limit=20, order_by=[desc(User.created_at)])
|
||||
|
||||
# 使用多态加载(加载联表继承的子类数据)
|
||||
tool_set = await ToolSet.get(
|
||||
session,
|
||||
ToolSet.id == tool_set_id,
|
||||
load=ToolSet.tools,
|
||||
load_polymorphic='all' # 只加载实际关联的子类
|
||||
)
|
||||
"""
|
||||
# 参数验证:load_polymorphic 需要与 load 配合使用
|
||||
if load_polymorphic is not None and load is None:
|
||||
raise ValueError(
|
||||
"load_polymorphic 参数需要与 load 参数配合使用,"
|
||||
"请同时指定要加载的关系"
|
||||
)
|
||||
|
||||
# 如果提供table_view,作为默认值使用(单独传入的参数优先级更高)
|
||||
if table_view:
|
||||
# 处理时间筛选(TimeFilterRequest 及其子类)
|
||||
if isinstance(table_view, TimeFilterRequest):
|
||||
if created_after_datetime is None and table_view.created_after_datetime is not None:
|
||||
created_after_datetime = table_view.created_after_datetime
|
||||
if created_before_datetime is None and table_view.created_before_datetime is not None:
|
||||
created_before_datetime = table_view.created_before_datetime
|
||||
if updated_after_datetime is None and table_view.updated_after_datetime is not None:
|
||||
updated_after_datetime = table_view.updated_after_datetime
|
||||
if updated_before_datetime is None and table_view.updated_before_datetime is not None:
|
||||
updated_before_datetime = table_view.updated_before_datetime
|
||||
# 处理分页排序(PaginationRequest 及其子类,包括 TableViewRequest)
|
||||
if isinstance(table_view, PaginationRequest):
|
||||
if offset is None:
|
||||
offset = table_view.offset
|
||||
if limit is None:
|
||||
limit = table_view.limit
|
||||
# 仅在未显式传入order_by时,从table_view构建排序子句
|
||||
if order_by is None:
|
||||
order_column = cls.created_at if table_view.order == "created_at" else cls.updated_at
|
||||
order_by = [desc(order_column) if table_view.desc else asc(order_column)]
|
||||
|
||||
# 对于多态基类,使用 with_polymorphic 预加载所有子类的列
|
||||
# 这避免了在响应序列化时的延迟加载问题(MissingGreenlet 错误)
|
||||
if issubclass(cls, PolymorphicBaseMixin):
|
||||
# '*' 表示加载所有子类
|
||||
polymorphic_cls = with_polymorphic(cls, '*')
|
||||
statement = select(polymorphic_cls)
|
||||
else:
|
||||
statement = select(cls)
|
||||
|
||||
if condition is not None:
|
||||
statement = statement.where(condition)
|
||||
|
||||
# 应用时间筛选
|
||||
for time_filter in cls._build_time_filters(
|
||||
created_before_datetime, created_after_datetime,
|
||||
updated_before_datetime, updated_after_datetime
|
||||
):
|
||||
statement = statement.where(time_filter)
|
||||
|
||||
if join is not None:
|
||||
# 如果 join 是一个元组,解包它;否则直接使用
|
||||
if isinstance(join, tuple):
|
||||
statement = statement.join(*join)
|
||||
else:
|
||||
statement = statement.join(join)
|
||||
|
||||
|
||||
if options:
|
||||
statement = statement.options(*options)
|
||||
|
||||
if load:
|
||||
# 处理多态加载
|
||||
if load_polymorphic is not None:
|
||||
target_class = load.property.mapper.class_
|
||||
|
||||
# 检查目标类是否继承自 PolymorphicBaseMixin
|
||||
if not issubclass(target_class, PolymorphicBaseMixin):
|
||||
raise ValueError(
|
||||
f"目标类 {target_class.__name__} 不是多态类,"
|
||||
f"请确保其继承自 PolymorphicBaseMixin"
|
||||
)
|
||||
|
||||
if load_polymorphic == 'all':
|
||||
# 两阶段查询:获取实际关联的多态类型
|
||||
subclasses_to_load = await cls._resolve_polymorphic_subclasses(
|
||||
session, condition, load, target_class
|
||||
)
|
||||
else:
|
||||
subclasses_to_load = load_polymorphic
|
||||
|
||||
if subclasses_to_load:
|
||||
# 关键:selectin_polymorphic 必须作为 selectinload 的链式子选项
|
||||
# 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading
|
||||
statement = statement.options(
|
||||
selectinload(load).selectin_polymorphic(subclasses_to_load)
|
||||
)
|
||||
else:
|
||||
statement = statement.options(selectinload(load))
|
||||
else:
|
||||
statement = statement.options(selectinload(load))
|
||||
|
||||
if order_by is not None:
|
||||
statement = statement.order_by(*order_by)
|
||||
|
||||
if offset:
|
||||
statement = statement.offset(offset)
|
||||
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
if filter:
|
||||
statement = statement.filter(filter)
|
||||
|
||||
if with_for_update:
|
||||
statement = statement.with_for_update()
|
||||
|
||||
result = await session.exec(statement)
|
||||
|
||||
if fetch_mode == "one":
|
||||
return result.one()
|
||||
elif fetch_mode == "first":
|
||||
return result.first()
|
||||
elif fetch_mode == "all":
|
||||
return list(result.all())
|
||||
else:
|
||||
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
|
||||
|
||||
@classmethod
|
||||
async def _resolve_polymorphic_subclasses(
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
condition: BinaryExpression | ClauseElement | None,
|
||||
load: RelationshipInfo,
|
||||
target_class: type[PolymorphicBaseMixin]
|
||||
) -> list[type[PolymorphicBaseMixin]]:
|
||||
"""
|
||||
查询实际关联的多态子类类型
|
||||
|
||||
通过查询多态鉴别字段确定实际存在的子类类型,
|
||||
避免加载所有可能的子类表(对于 > 10 个子类的场景有明显收益)。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param condition: 主查询的条件
|
||||
:param load: 关系属性
|
||||
:param target_class: 多态基类
|
||||
:return: 实际关联的子类列表
|
||||
"""
|
||||
# 获取多态鉴别字段(会抛出 ValueError 如果未配置)
|
||||
discriminator = target_class.get_polymorphic_discriminator()
|
||||
poly_name_col = getattr(target_class, discriminator)
|
||||
|
||||
# 获取关系属性
|
||||
relationship_property = load.property
|
||||
|
||||
# 构建查询获取实际的多态类型名称
|
||||
if relationship_property.secondary is not None:
|
||||
# 多对多关系:通过中间表查询
|
||||
secondary = relationship_property.secondary
|
||||
local_cols = list(relationship_property.local_columns)
|
||||
|
||||
type_query = (
|
||||
select(distinct(poly_name_col))
|
||||
.select_from(target_class)
|
||||
.join(secondary)
|
||||
.where(secondary.c[local_cols[0].name].in_(
|
||||
select(cls.id).where(condition) if condition is not None else select(cls.id)
|
||||
))
|
||||
)
|
||||
else:
|
||||
# 一对多关系:通过外键查询
|
||||
foreign_key_col = relationship_property.local_remote_pairs[0][1]
|
||||
type_query = (
|
||||
select(distinct(poly_name_col))
|
||||
.where(foreign_key_col.in_(
|
||||
select(cls.id).where(condition) if condition is not None else select(cls.id)
|
||||
))
|
||||
)
|
||||
|
||||
type_result = await session.exec(type_query)
|
||||
poly_names = list(type_result.all())
|
||||
|
||||
if not poly_names:
|
||||
return []
|
||||
|
||||
# 映射到子类(包含所有层级的具体子类)
|
||||
identity_map = target_class.get_identity_to_class_map()
|
||||
return [identity_map[name] for name in poly_names if name in identity_map]
|
||||
|
||||
@classmethod
|
||||
async def count(
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
condition: BinaryExpression | ClauseElement | None = None,
|
||||
*,
|
||||
time_filter: TimeFilterRequest | None = None,
|
||||
created_before_datetime: datetime | None = None,
|
||||
created_after_datetime: datetime | None = None,
|
||||
updated_before_datetime: datetime | None = None,
|
||||
updated_after_datetime: datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
根据条件统计记录数量(支持时间筛选)
|
||||
|
||||
使用数据库层面的 COUNT() 聚合函数,比 get() + len() 更高效。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
condition: 查询条件,例如 `User.is_active == True`
|
||||
time_filter: TimeFilterRequest 对象(优先级更高)
|
||||
created_before_datetime: 筛选 created_at < datetime 的记录
|
||||
created_after_datetime: 筛选 created_at >= datetime 的记录
|
||||
updated_before_datetime: 筛选 updated_at < datetime 的记录
|
||||
updated_after_datetime: 筛选 updated_at >= datetime 的记录
|
||||
|
||||
Returns:
|
||||
符合条件的记录数量
|
||||
|
||||
Examples:
|
||||
# 统计所有用户
|
||||
total = await User.count(session)
|
||||
|
||||
# 统计激活的虚拟客户端
|
||||
count = await Client.count(
|
||||
session,
|
||||
(Client.user_id == user_id) & (Client.type != ClientTypeEnum.physical) & (Client.is_active == True)
|
||||
)
|
||||
|
||||
# 使用 TimeFilterRequest 进行时间筛选
|
||||
count = await User.count(session, time_filter=time_filter_request)
|
||||
|
||||
# 使用独立时间参数
|
||||
count = await User.count(
|
||||
session,
|
||||
created_after_datetime=datetime(2025, 1, 1),
|
||||
created_before_datetime=datetime(2025, 2, 1),
|
||||
)
|
||||
"""
|
||||
# time_filter 的时间筛选优先级更高
|
||||
if isinstance(time_filter, TimeFilterRequest):
|
||||
if time_filter.created_after_datetime is not None:
|
||||
created_after_datetime = time_filter.created_after_datetime
|
||||
if time_filter.created_before_datetime is not None:
|
||||
created_before_datetime = time_filter.created_before_datetime
|
||||
if time_filter.updated_after_datetime is not None:
|
||||
updated_after_datetime = time_filter.updated_after_datetime
|
||||
if time_filter.updated_before_datetime is not None:
|
||||
updated_before_datetime = time_filter.updated_before_datetime
|
||||
|
||||
statement = select(func.count()).select_from(cls)
|
||||
|
||||
# 应用查询条件
|
||||
if condition is not None:
|
||||
statement = statement.where(condition)
|
||||
|
||||
# 应用时间筛选
|
||||
for time_condition in cls._build_time_filters(
|
||||
created_before_datetime, created_after_datetime,
|
||||
updated_before_datetime, updated_after_datetime
|
||||
):
|
||||
statement = statement.where(time_condition)
|
||||
|
||||
result = await session.scalar(statement)
|
||||
return result or 0
|
||||
|
||||
@classmethod
|
||||
async def get_with_count(
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
condition: BinaryExpression | ClauseElement | None = None,
|
||||
*,
|
||||
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
|
||||
options: list | None = None,
|
||||
load: RelationshipInfo | None = None,
|
||||
order_by: list[ClauseElement] | None = None,
|
||||
filter: BinaryExpression | ClauseElement | None = None,
|
||||
table_view: TableViewRequest | None = None,
|
||||
load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
) -> 'ListResponse[T]':
|
||||
"""
|
||||
获取分页列表及总数,直接返回 ListResponse
|
||||
|
||||
同时返回符合条件的记录列表和总数,用于分页场景。
|
||||
与 get() 方法类似,但固定 fetch_mode="all" 并返回 ListResponse。
|
||||
|
||||
注意:如果子类的 get() 方法支持额外参数(如 filter_params),
|
||||
子类应该覆盖此方法以确保 count 和 items 使用相同的过滤条件。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
condition: 查询条件
|
||||
join: JOIN 的模型类或元组
|
||||
options: SQLAlchemy 查询选项
|
||||
load: selectinload 预加载关系
|
||||
order_by: 排序子句
|
||||
filter: 附加过滤条件
|
||||
table_view: 分页排序参数(推荐使用)
|
||||
load_polymorphic: 多态子类加载选项
|
||||
|
||||
Returns:
|
||||
ListResponse[T]: 包含 count 和 items 的分页响应
|
||||
|
||||
Examples:
|
||||
```python
|
||||
@router.get("", response_model=ListResponse[CharacterInfoResponse])
|
||||
async def list_characters(
|
||||
session: SessionDep,
|
||||
table_view: TableViewRequestDep
|
||||
) -> ListResponse[Character]:
|
||||
return await Character.get_with_count(session, table_view=table_view)
|
||||
```
|
||||
"""
|
||||
# 提取时间筛选参数(用于 count)
|
||||
time_filter: TimeFilterRequest | None = None
|
||||
if table_view is not None:
|
||||
time_filter = TimeFilterRequest(
|
||||
created_after_datetime=table_view.created_after_datetime,
|
||||
created_before_datetime=table_view.created_before_datetime,
|
||||
updated_after_datetime=table_view.updated_after_datetime,
|
||||
updated_before_datetime=table_view.updated_before_datetime,
|
||||
)
|
||||
|
||||
# 获取总数(不带分页限制)
|
||||
total_count = await cls.count(session, condition, time_filter=time_filter)
|
||||
|
||||
# 获取分页数据
|
||||
items = await cls.get(
|
||||
session,
|
||||
condition,
|
||||
fetch_mode="all",
|
||||
join=join,
|
||||
options=options,
|
||||
load=load,
|
||||
order_by=order_by,
|
||||
filter=filter,
|
||||
table_view=table_view,
|
||||
load_polymorphic=load_polymorphic,
|
||||
)
|
||||
|
||||
return ListResponse(count=total_count, items=items)
|
||||
|
||||
@classmethod
|
||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | None = None) -> T:
|
||||
"""
|
||||
根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常.
|
||||
|
||||
这个方法是对 `get` 方法的封装,专门用于处理那种"记录必须存在"的业务场景。
|
||||
如果记录未找到,它会直接引发 FastAPI 的 `HTTPException`, 而不是返回 `None`.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
id (int): 要查找的记录的主键 ID.
|
||||
load (Relationship | None): 可选的,用于预加载的关联属性.
|
||||
|
||||
Returns:
|
||||
T: 找到的模型实例.
|
||||
|
||||
Raises:
|
||||
HTTPException: 如果 ID 对应的记录不存在,则抛出状态码为 404 的异常.
|
||||
"""
|
||||
instance = await cls.get(session, cls.id == id, load=load)
|
||||
if not instance:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
return instance
|
||||
|
||||
class UUIDTableBaseMixin(TableBaseMixin):
|
||||
"""
|
||||
一个使用 UUID 作为主键的异步 CRUD 操作基础模型类 Mixin.
|
||||
|
||||
此类继承自 `TableBaseMixin`, 将主键 `id` 的类型覆盖为 `uuid.UUID`,
|
||||
并为新记录自动生成 UUID. 它继承了 `TableBaseMixin` 的所有 CRUD 方法.
|
||||
|
||||
Attributes:
|
||||
id (uuid.UUID): UUID 类型的主键, 在创建时自动生成.
|
||||
"""
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
"""覆盖 `TableBaseMixin` 的 id 字段,使用 UUID 作为主键."""
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T:
|
||||
"""
|
||||
根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
|
||||
|
||||
此方法覆盖了父类的同名方法,以确保 `id` 参数的类型注解为 `uuid.UUID`,
|
||||
从而提供更好的类型安全和代码提示.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
id (uuid.UUID): 要查找的记录的 UUID 主键.
|
||||
load (Relationship | None): 可选的,用于预加载的关联属性.
|
||||
|
||||
Returns:
|
||||
T: 找到的模型实例.
|
||||
|
||||
Raises:
|
||||
HTTPException: 如果 UUID 对应的记录不存在,则抛出状态码为 404 的异常.
|
||||
"""
|
||||
# 类型检查器可能会警告这里的 `id` 类型不匹配超类方法,
|
||||
# 但在运行时这是正确的,因为超类方法内部的比较 (cls.id == id)
|
||||
# 会正确处理 UUID 类型。`type: ignore` 用于抑制此警告。
|
||||
return await super().get_exist_one(session, id, load) # type: ignore
|
||||
Reference in New Issue
Block a user