feat(mixin): add TableBaseMixin and UUIDTableBaseMixin for async CRUD operations

- Implemented TableBaseMixin providing generic CRUD methods and automatic timestamp management.
- Introduced UUIDTableBaseMixin for models using UUID as primary keys.
- Added ListResponse for standardized paginated responses.
- Created TimeFilterRequest and PaginationRequest for filtering and pagination parameters.
- Enhanced get_with_count method to return both item list and total count.
- Included validation for time filter parameters in TimeFilterRequest.
- Improved documentation and usage examples throughout the code.
This commit is contained in:
2025-12-22 18:29:14 +08:00
parent 47a4756227
commit a5efda9c23
44 changed files with 4306 additions and 497 deletions

543
models/mixin/README.md Normal file
View File

@@ -0,0 +1,543 @@
# SQLModel Mixin Module
This module provides composable Mixin classes for SQLModel entities, enabling reusable functionality such as CRUD operations, polymorphic inheritance, JWT authentication, and standardized response DTOs.
## Module Overview
The `sqlmodels.mixin` module contains various Mixin classes that follow the "Composition over Inheritance" design philosophy. These mixins provide:
- **CRUD Operations**: Async database operations (add, save, update, delete, get, count)
- **Polymorphic Inheritance**: Tools for joined table inheritance patterns
- **JWT Authentication**: Token generation and validation
- **Pagination & Sorting**: Standardized table view parameters
- **Response DTOs**: Consistent id/timestamp fields for API responses
## Module Structure
```
sqlmodels/mixin/
├── __init__.py # Module exports
├── polymorphic.py # PolymorphicBaseMixin, create_subclass_id_mixin, AutoPolymorphicIdentityMixin
├── table.py # TableBaseMixin, UUIDTableBaseMixin, TableViewRequest
├── info_response.py # Response DTO Mixins (IntIdInfoMixin, UUIDIdInfoMixin, etc.)
└── jwt/ # JWT authentication
├── __init__.py
├── key.py # JWTKey database model
├── payload.py # JWTPayloadBase
├── manager.py # JWTManager singleton
├── auth.py # JWTAuthMixin
├── exceptions.py # JWT-related exceptions
└── responses.py # TokenResponse DTO
```
## Dependency Hierarchy
The module has a strict import order to avoid circular dependencies:
1. **polymorphic.py** - Only depends on `SQLModelBase`
2. **table.py** - Depends on `polymorphic.py`
3. **jwt/** - May depend on both `polymorphic.py` and `table.py`
4. **info_response.py** - Only depends on `SQLModelBase`
## Core Components
### 1. TableBaseMixin
Base mixin for database table models with integer primary keys.
**Features:**
- Provides CRUD methods: `add()`, `save()`, `update()`, `delete()`, `get()`, `count()`, `get_exist_one()`
- Automatic timestamp management (`created_at`, `updated_at`)
- Async relationship loading support (via `AsyncAttrs`)
- Pagination and sorting via `TableViewRequest`
- Polymorphic subclass loading support
**Fields:**
- `id: int | None` - Integer primary key (auto-increment)
- `created_at: datetime` - Record creation timestamp
- `updated_at: datetime` - Record update timestamp (auto-updated)
**Usage:**
```python
from sqlmodels.mixin import TableBaseMixin
from sqlmodels.base import SQLModelBase
class User(SQLModelBase, TableBaseMixin, table=True):
name: str
email: str
"""User email"""
# CRUD operations
async def example(session: AsyncSession):
# Add
user = User(name="Alice", email="alice@example.com")
user = await user.save(session)
# Get
user = await User.get(session, User.id == 1)
# Update
update_data = UserUpdateRequest(name="Alice Smith")
user = await user.update(session, update_data)
# Delete
await User.delete(session, user)
# Count
count = await User.count(session, User.is_active == True)
```
**Important Notes:**
- `save()` and `update()` return refreshed instances - **always use the return value**:
```python
# ✅ Correct
device = await device.save(session)
return device
# ❌ Wrong - device is expired after commit
await device.save(session)
return device
```
### 2. UUIDTableBaseMixin
Extends `TableBaseMixin` with UUID primary keys instead of integers.
**Differences from TableBaseMixin:**
- `id: UUID` - UUID primary key (auto-generated via `uuid.uuid4()`)
- `get_exist_one()` accepts `UUID` instead of `int`
**Usage:**
```python
from sqlmodels.mixin import UUIDTableBaseMixin
class Character(SQLModelBase, UUIDTableBaseMixin, table=True):
name: str
description: str | None = None
"""Character description"""
```
**Recommendation:** Use `UUIDTableBaseMixin` for most new models, as UUIDs provide better scalability and avoid ID collisions.
### 3. TableViewRequest
Standardized pagination and sorting parameters for LIST endpoints.
**Fields:**
- `offset: int | None` - Skip first N records (default: 0)
- `limit: int | None` - Return max N records (default: 50, max: 100)
- `desc: bool | None` - Sort descending (default: True)
- `order: Literal["created_at", "updated_at"] | None` - Sort field (default: "created_at")
**Usage with TableBaseMixin.get():**
```python
from dependencies import TableViewRequestDep
@router.get("/list")
async def list_characters(
session: SessionDep,
table_view: TableViewRequestDep
) -> list[Character]:
"""List characters with pagination and sorting"""
return await Character.get(
session,
fetch_mode="all",
table_view=table_view # Automatically handles pagination and sorting
)
```
**Manual usage:**
```python
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
characters = await Character.get(session, fetch_mode="all", table_view=table_view)
```
**Backward Compatibility:**
The traditional `offset`, `limit`, `order_by` parameters still work, but `table_view` is recommended for new code.
### 4. PolymorphicBaseMixin
Base mixin for joined table inheritance, automatically configuring polymorphic settings.
**Automatic Configuration:**
- Defines `_polymorphic_name: str` field (indexed)
- Sets `polymorphic_on='_polymorphic_name'`
- Detects abstract classes (via ABC and abstract methods) and sets `polymorphic_abstract=True`
**Methods:**
- `get_concrete_subclasses()` - Get all non-abstract subclasses (for `selectin_polymorphic`)
- `get_polymorphic_discriminator()` - Get the polymorphic discriminator field name
- `get_identity_to_class_map()` - Map `polymorphic_identity` to subclass types
**Usage:**
```python
from abc import ABC, abstractmethod
from sqlmodels.mixin import PolymorphicBaseMixin, UUIDTableBaseMixin
class Tool(PolymorphicBaseMixin, UUIDTableBaseMixin, ABC):
"""Abstract base class for all tools"""
name: str
description: str
"""Tool description"""
@abstractmethod
async def execute(self, params: dict) -> dict:
"""Execute the tool"""
pass
```
**Why Single Underscore Prefix?**
- SQLAlchemy maps single-underscore fields to database columns
- Pydantic treats them as private (excluded from serialization)
- Double-underscore fields would be excluded by SQLAlchemy (not mapped to database)
### 5. create_subclass_id_mixin()
Factory function to create ID mixins for subclasses in joined table inheritance.
**Purpose:** In joined table inheritance, subclasses need a foreign key pointing to the parent table's primary key. This function generates a mixin class providing that foreign key field.
**Signature:**
```python
def create_subclass_id_mixin(parent_table_name: str) -> type[SQLModelBase]:
"""
Args:
parent_table_name: Parent table name (e.g., 'asr', 'tts', 'tool', 'function')
Returns:
A mixin class containing id field (foreign key + primary key)
"""
```
**Usage:**
```python
from sqlmodels.mixin import create_subclass_id_mixin
# Create mixin for ASR subclasses
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
class FunASR(ASRSubclassIdMixin, ASR, AutoPolymorphicIdentityMixin, table=True):
"""FunASR implementation"""
pass
```
**Important:** The ID mixin **must be first in the inheritance list** to ensure MRO (Method Resolution Order) correctly overrides the parent's `id` field.
### 6. AutoPolymorphicIdentityMixin
Automatically generates `polymorphic_identity` based on class name.
**Naming Convention:**
- Format: `{parent_identity}.{classname_lowercase}`
- If no parent identity exists, uses `{classname_lowercase}`
**Usage:**
```python
from sqlmodels.mixin import AutoPolymorphicIdentityMixin
class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
"""Base class for function-type tools"""
pass
# polymorphic_identity = 'function'
class GetWeatherFunction(Function, table=True):
"""Weather query function"""
pass
# polymorphic_identity = 'function.getweatherfunction'
```
**Manual Override:**
```python
class CustomTool(
Tool,
AutoPolymorphicIdentityMixin,
polymorphic_identity='custom_name', # Override auto-generated name
table=True
):
pass
```
### 7. JWTAuthMixin
Provides JWT token generation and validation for entity classes (User, Client).
**Methods:**
- `async issue_jwt(session: AsyncSession) -> str` - Generate JWT token for current instance
- `@classmethod async from_jwt(session: AsyncSession, token: str) -> Self` - Validate token and retrieve entity
**Requirements:**
Subclasses must define:
- `JWTPayload` - Payload model (inherits from `JWTPayloadBase`)
- `jwt_key_purpose` - ClassVar specifying the JWT key purpose enum value
**Usage:**
```python
from sqlmodels.mixin import JWTAuthMixin, UUIDTableBaseMixin
class User(SQLModelBase, UUIDTableBaseMixin, JWTAuthMixin, table=True):
JWTPayload = UserJWTPayload # Define payload model
jwt_key_purpose: ClassVar[JWTKeyPurposeEnum] = JWTKeyPurposeEnum.user
email: str
is_admin: bool = False
is_active: bool = True
"""User active status"""
# Generate token
async def login(session: AsyncSession, user: User) -> str:
token = await user.issue_jwt(session)
return token
# Validate token
async def verify(session: AsyncSession, token: str) -> User:
user = await User.from_jwt(session, token)
return user
```
### 8. Response DTO Mixins
Mixins for standardized InfoResponse DTOs, defining id and timestamp fields.
**Available Mixins:**
- `IntIdInfoMixin` - Integer ID field
- `UUIDIdInfoMixin` - UUID ID field
- `DatetimeInfoMixin` - `created_at` and `updated_at` fields
- `IntIdDatetimeInfoMixin` - Integer ID + timestamps
- `UUIDIdDatetimeInfoMixin` - UUID ID + timestamps
**Design Note:** These fields are non-nullable in DTOs because database records always have these values when returned.
**Usage:**
```python
from sqlmodels.mixin import UUIDIdDatetimeInfoMixin
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
"""Character response DTO with id and timestamps"""
pass # Inherits id, created_at, updated_at from mixin
```
## Complete Joined Table Inheritance Example
Here's a complete example demonstrating polymorphic inheritance:
```python
from abc import ABC, abstractmethod
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import (
UUIDTableBaseMixin,
PolymorphicBaseMixin,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin,
)
# 1. Define Base class (fields only, no table)
class ASRBase(SQLModelBase):
name: str
"""Configuration name"""
base_url: str
"""Service URL"""
# 2. Define abstract parent class (with table)
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
"""Abstract base class for ASR configurations"""
# PolymorphicBaseMixin automatically provides:
# - _polymorphic_name field
# - polymorphic_on='_polymorphic_name'
# - polymorphic_abstract=True (when ABC with abstract methods)
@abstractmethod
async def transcribe(self, pcm_data: bytes) -> str:
"""Transcribe audio to text"""
pass
# 3. Create ID Mixin for second-level subclasses
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
# 4. Create second-level abstract class (if needed)
class FunASR(
ASRSubclassIdMixin,
ASR,
AutoPolymorphicIdentityMixin,
polymorphic_abstract=True
):
"""FunASR abstract base (may have multiple implementations)"""
pass
# polymorphic_identity = 'funasr'
# 5. Create concrete implementation classes
class FunASRLocal(FunASR, table=True):
"""FunASR local deployment"""
# polymorphic_identity = 'funasr.funasrlocal'
async def transcribe(self, pcm_data: bytes) -> str:
# Implementation...
return "transcribed text"
# 6. Get all concrete subclasses (for selectin_polymorphic)
concrete_asrs = ASR.get_concrete_subclasses()
# Returns: [FunASRLocal, ...]
```
## Import Guidelines
**Standard Import:**
```python
from sqlmodels.mixin import (
TableBaseMixin,
UUIDTableBaseMixin,
PolymorphicBaseMixin,
TableViewRequest,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin,
JWTAuthMixin,
UUIDIdDatetimeInfoMixin,
now,
now_date,
)
```
**Backward Compatibility:**
Some exports are also available from `sqlmodels.base` for backward compatibility:
```python
# Legacy import path (still works)
from sqlmodels.base import UUIDTableBase, TableViewRequest
# Recommended new import path
from sqlmodels.mixin import UUIDTableBaseMixin, TableViewRequest
```
## Best Practices
### 1. Mixin Order Matters
**Correct Order:**
```python
# ✅ ID Mixin first, then parent, then AutoPolymorphicIdentityMixin
class SubTool(ToolSubclassIdMixin, Tool, AutoPolymorphicIdentityMixin, table=True):
pass
```
**Wrong Order:**
```python
# ❌ ID Mixin not first - won't override parent's id field
class SubTool(Tool, ToolSubclassIdMixin, AutoPolymorphicIdentityMixin, table=True):
pass
```
### 2. Always Use Return Values from save() and update()
```python
# ✅ Correct - use returned instance
device = await device.save(session)
return device
# ❌ Wrong - device is expired after commit
await device.save(session)
return device # AttributeError when accessing fields
```
### 3. Prefer table_view Over Manual Pagination
```python
# ✅ Recommended - consistent across all endpoints
characters = await Character.get(
session,
fetch_mode="all",
table_view=table_view
)
# ⚠️ Works but not recommended - manual parameter management
characters = await Character.get(
session,
fetch_mode="all",
offset=0,
limit=20,
order_by=[desc(Character.created_at)]
)
```
### 4. Polymorphic Loading for Many Subclasses
```python
# When loading relationships with > 10 polymorphic subclasses, use load_polymorphic='all'
tool_set = await ToolSet.get(
session,
ToolSet.id == tool_set_id,
load=ToolSet.tools,
load_polymorphic='all' # Two-phase query - only loads actual related subclasses
)
# For fewer subclasses, specify the list explicitly
tool_set = await ToolSet.get(
session,
ToolSet.id == tool_set_id,
load=ToolSet.tools,
load_polymorphic=[GetWeatherFunction, CodeInterpreterFunction]
)
```
### 5. Response DTOs Should Inherit Base Classes
```python
# ✅ Correct - inherits from CharacterBase
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
pass
# ❌ Wrong - doesn't inherit from CharacterBase
class CharacterInfoResponse(SQLModelBase, UUIDIdDatetimeInfoMixin):
name: str # Duplicated field definition
description: str | None = None
```
**Reason:** Inheriting from Base classes ensures:
- Type checking via `isinstance(obj, XxxBase)`
- Consistency across related DTOs
- Future field additions automatically propagate
### 6. Use Specific Types, Not Containers
```python
# ✅ Correct - specific DTO for config updates
class GetWeatherFunctionUpdateRequest(GetWeatherFunctionConfigBase):
weather_api_key: str | None = None
default_location: str | None = None
"""Default location"""
# ❌ Wrong - lose type safety
class ToolUpdateRequest(SQLModelBase):
config: dict[str, Any] # No field validation
```
## Type Variables
```python
from sqlmodels.mixin import T, M
T = TypeVar("T", bound="TableBaseMixin") # For CRUD methods
M = TypeVar("M", bound="SQLModel") # For update() method
```
## Utility Functions
```python
from sqlmodels.mixin import now, now_date
# Lambda functions for default factories
now = lambda: datetime.now()
now_date = lambda: datetime.now().date()
```
## Related Modules
- **sqlmodels.base** - Base classes (`SQLModelBase`, backward-compatible exports)
- **dependencies** - FastAPI dependencies (`SessionDep`, `TableViewRequestDep`)
- **sqlmodels.user** - User model with JWT authentication
- **sqlmodels.client** - Client model with JWT authentication
- **sqlmodels.character.llm.openai_compatibles.tools** - Polymorphic tool hierarchy
## Additional Resources
- `POLYMORPHIC_NAME_DESIGN.md` - Design rationale for `_polymorphic_name` field
- `CLAUDE.md` - Project coding standards and design philosophy
- SQLAlchemy Documentation - [Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)

46
models/mixin/__init__.py Normal file
View File

@@ -0,0 +1,46 @@
"""
SQLModel Mixin模块
提供各种Mixin类供SQLModel实体使用。
包含:
- polymorphic: 联表继承工具create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin
- table: 表基类TableBaseMixin, UUIDTableBaseMixin
- table: 查询参数类TimeFilterRequest, PaginationRequest, TableViewRequest
- jwt/: JWT认证相关JWTAuthMixin, JWTManager, JWTKey等- 需要时直接从 .jwt 导入
- info_response: InfoResponse DTO的id/时间戳Mixin
导入顺序很重要,避免循环导入:
1. polymorphic只依赖 SQLModelBase
2. table依赖 polymorphic
注意jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig
而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
"""
# polymorphic 必须先导入
from .polymorphic import (
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin,
PolymorphicBaseMixin,
)
# table 依赖 polymorphic
from .table import (
TableBaseMixin,
UUIDTableBaseMixin,
TimeFilterRequest,
PaginationRequest,
TableViewRequest,
ListResponse,
T,
now,
now_date,
)
# jwt 不在此处导入避免循环jwt/manager.py → ServerConfig → mixin → jwt
# 需要时直接从 sqlmodels.mixin.jwt 导入
from .info_response import (
IntIdInfoMixin,
UUIDIdInfoMixin,
DatetimeInfoMixin,
IntIdDatetimeInfoMixin,
UUIDIdDatetimeInfoMixin,
)

View File

@@ -0,0 +1,46 @@
"""
InfoResponse DTO Mixin模块
提供用于InfoResponse类型DTO的Mixin统一定义id/created_at/updated_at字段。
设计说明:
- 这些Mixin用于**响应DTO**,不是数据库表
- 从数据库返回时这些字段永远不为空,所以定义为必填字段
- TableBase中的id=None和default_factory=now是正确的入库前为None数据库生成
- 这些Mixin让DTO明确表示"返回给客户端时这些字段必定有值"
"""
from datetime import datetime
from uuid import UUID
from models.base import SQLModelBase
class IntIdInfoMixin(SQLModelBase):
"""整数ID响应mixin - 用于InfoResponse DTO"""
id: int
"""记录ID"""
class UUIDIdInfoMixin(SQLModelBase):
"""UUID ID响应mixin - 用于InfoResponse DTO"""
id: UUID
"""记录ID"""
class DatetimeInfoMixin(SQLModelBase):
"""时间戳响应mixin - 用于InfoResponse DTO"""
created_at: datetime
"""创建时间"""
updated_at: datetime
"""更新时间"""
class IntIdDatetimeInfoMixin(IntIdInfoMixin, DatetimeInfoMixin):
"""整数ID + 时间戳响应mixin"""
pass
class UUIDIdDatetimeInfoMixin(UUIDIdInfoMixin, DatetimeInfoMixin):
"""UUID ID + 时间戳响应mixin"""
pass

456
models/mixin/polymorphic.py Normal file
View File

@@ -0,0 +1,456 @@
"""
联表继承Joined Table Inheritance的通用工具
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
Usage Example:
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import (
PolymorphicBaseMixin,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin
)
# 1. 定义Base类只有字段无表
class ASRBase(SQLModelBase):
name: str
\"\"\"配置名称\"\"\"
base_url: str
\"\"\"服务地址\"\"\"
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
class ASR(
ASRBase,
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC
):
\"\"\"ASR配置的抽象基类\"\"\"
# PolymorphicBaseMixin 自动提供:
# - _polymorphic_name 字段
# - polymorphic_on='_polymorphic_name'
# - polymorphic_abstract=True当有抽象方法时
# 3. 为第二层子类创建ID Mixin
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
# 4. 创建第二层抽象类(如果需要)
class FunASR(
ASRSubclassIdMixin,
ASR,
AutoPolymorphicIdentityMixin,
polymorphic_abstract=True
):
\"\"\"FunASR的抽象基类可能有多个实现\"\"\"
pass
# 5. 创建具体实现类
class FunASRLocal(FunASR, table=True):
\"\"\"FunASR本地部署版本\"\"\"
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
pass
# 6. 获取所有具体子类(用于 selectin_polymorphic
concrete_asrs = ASR.get_concrete_subclasses()
# 返回 [FunASRLocal, ...]
"""
import uuid
from abc import ABC
from uuid import UUID
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from sqlalchemy import String, inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlmodel import Field
from models.base.sqlmodel_base import SQLModelBase
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
"""
动态创建SubclassIdMixin类
在联表继承中,子类需要一个外键指向父表的主键。
此函数生成一个Mixin类提供这个外键字段并自动生成UUID。
Args:
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function'
Returns:
一个Mixin类包含id字段外键 + 主键 + default_factory=uuid.uuid4
Example:
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
... pass
Note:
- 生成的Mixin应该放在继承列表的第一位确保通过MRO覆盖UUIDTableBaseMixin的id
- 生成的类名为 {ParentTableName}SubclassIdMixinPascalCase
- 本项目所有联表继承均使用UUID主键UUIDTableBaseMixin
"""
if not parent_table_name:
raise ValueError("parent_table_name 不能为空")
# 转换为PascalCase作为类名
class_name_parts = parent_table_name.split('_')
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
# 使用闭包捕获parent_table_name
_parent_table_name = parent_table_name
# 创建带有__init_subclass__的mixin类用于在子类定义后修复model_fields
class SubclassIdMixin(SQLModelBase):
# 定义id字段
id: UUID = Field(
default_factory=uuid.uuid4,
foreign_key=f'{_parent_table_name}.id',
primary_key=True,
)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
修复联表继承中子类字段的default_factory丢失问题。
SQLAlchemy 的 InstrumentedAttribute 会污染从父类继承的字段,
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
通过从 MRO 中查找父类的原始字段定义来获取正确的 default_factory
遵循单一真相原则(不硬编码 default_factory
需要修复的字段:
- id: 主键(从父类获取 default_factory
- created_at: 创建时间戳(从父类获取 default_factory
- updated_at: 更新时间戳(从父类获取 default_factory
"""
super().__pydantic_init_subclass__(**kwargs)
if not hasattr(cls, 'model_fields'):
return
def find_original_field_info(field_name: str) -> FieldInfo | None:
"""从 MRO 中查找字段的原始定义(未被 InstrumentedAttribute 污染的)"""
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, 'model_fields') and field_name in base.model_fields:
field_info = base.model_fields[field_name]
# 跳过被 InstrumentedAttribute 污染的
if not isinstance(field_info.default, InstrumentedAttribute):
return field_info
return None
# 动态检测所有需要修复的字段
# 遵循单一真相原则:不硬编码字段列表,而是通过以下条件判断:
# 1. default 是 InstrumentedAttribute被 SQLAlchemy 污染)
# 2. 原始定义有 default_factory 或明确的 default 值
#
# 覆盖场景:
# - UUID主键UUIDTableBaseMixinid 有 default_factory=uuid.uuid4需要修复
# - int主键TableBaseMixinid 用 default=None不需要修复数据库自增
# - created_at/updated_at有 default_factory=now需要修复
# - 外键字段created_by_id等有 default=None需要修复
# - 普通字段name, temperature等无 default_factory不需要修复
#
# MRO 查找保证:
# - 在多重继承场景下MRO 顺序是确定性的
# - find_original_field_info 会找到第一个未被污染且有该字段的父类
for field_name, current_field in cls.model_fields.items():
# 检查是否被污染default 是 InstrumentedAttribute
if not isinstance(current_field.default, InstrumentedAttribute):
continue # 未被污染,跳过
# 从父类查找原始定义
original = find_original_field_info(field_name)
if original is None:
continue # 找不到原始定义,跳过
# 根据原始定义的 default/default_factory 来修复
if original.default_factory:
# 有 default_factory如 uuid.uuid4, now
new_field = FieldInfo(
default_factory=original.default_factory,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
elif original.default is not PydanticUndefined:
# 有明确的 default 值(如 None, 0, ""),且不是 PydanticUndefined
# PydanticUndefined 表示字段没有默认值(必填)
new_field = FieldInfo(
default=original.default,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
else:
continue # 既没有 default_factory 也没有有效的 default跳过
# 复制SQLModel特有的属性
if hasattr(current_field, 'foreign_key'):
new_field.foreign_key = current_field.foreign_key
if hasattr(current_field, 'primary_key'):
new_field.primary_key = current_field.primary_key
cls.model_fields[field_name] = new_field
# 设置类名和文档
SubclassIdMixin.__name__ = class_name
SubclassIdMixin.__qualname__ = class_name
SubclassIdMixin.__doc__ = f"""
{parent_table_name}子类的ID Mixin
用于{parent_table_name}的子类,提供外键指向父表。
通过MRO确保此id字段覆盖继承的id字段。
"""
return SubclassIdMixin
class AutoPolymorphicIdentityMixin:
"""
自动生成polymorphic_identity的Mixin
使用此Mixin的类会自动根据类名生成polymorphic_identity。
格式:{parent_polymorphic_identity}.{classname_lowercase}
如果没有父类的polymorphic_identity则直接使用类名小写。
Example:
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
... __polymorphic_name: str
...
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
... pass
... # polymorphic_identity 会自动设置为 'function'
...
>>> class CodeInterpreterFunction(Function, table=True):
... pass
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
Note:
- 如果手动在__mapper_args__中指定了polymorphic_identity会被保留
- 此Mixin应该在继承列表中靠后的位置在表基类之前
"""
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
"""
子类化钩子自动生成polymorphic_identity
Args:
polymorphic_identity: 如果手动指定,则使用指定的值
**kwargs: 其他SQLModel参数如table=True, polymorphic_abstract=True
"""
super().__init_subclass__(**kwargs)
# 如果手动指定了polymorphic_identity使用指定的值
if polymorphic_identity is not None:
identity = polymorphic_identity
else:
# 自动生成polymorphic_identity
class_name = cls.__name__.lower()
# 尝试从父类获取polymorphic_identity作为前缀
parent_identity = None
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
parent_identity = base.__mapper_args__.get('polymorphic_identity')
if parent_identity:
break
# 构建identity
if parent_identity:
identity = f'{parent_identity}.{class_name}'
else:
identity = class_name
# 设置到__mapper_args__
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 只在尚未设置polymorphic_identity时设置
if 'polymorphic_identity' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_identity'] = identity
class PolymorphicBaseMixin:
"""
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
此 Mixin 自动设置以下内容:
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
使用场景:
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
用法示例:
```python
from abc import ABC
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
# 定义基类
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
__tablename__ = 'mytool'
# 不需要手动定义 _polymorphic_name
# 不需要手动设置 polymorphic_on
# 不需要手动设置 polymorphic_abstract
# 定义子类
class SpecificTool(MyTool):
__tablename__ = 'specifictool'
# 会自动继承 polymorphic 配置
```
自动行为:
1. 定义 `_polymorphic_name: str` 字段(带索引)
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
3. 自动检测抽象类:
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
- 否则设置为 False
手动覆盖:
可以在类定义时手动指定参数来覆盖自动行为:
```python
class MyTool(
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC,
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
polymorphic_abstract=False # 强制不设为抽象类
):
pass
```
注意事项:
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
- 适用于联表继承joined table inheritance场景
- 子类会自动继承 _polymorphic_name 字段定义
- 使用单下划线前缀是因为:
* SQLAlchemy 会映射单下划线字段为数据库列
* Pydantic 将其视为私有属性,不参与序列化
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
"""
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
#
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
#
# 为什么这样做:
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
# 4. 字段不出现在 Pydantic 序列化中model_dump() 和 JSON schema
# 5. 内部代码仍然可以正常访问和修改此字段
#
# 详细说明请参考sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
"""
多态鉴别器字段,用于标识具体的子类类型
注意:此字段使用单下划线前缀,表示内部使用。
- ✅ 存储到数据库
- ✅ 不出现在 API 序列化中
- ✅ 防止外部直接修改
"""
def __init_subclass__(
cls,
polymorphic_on: str | None = None,
polymorphic_abstract: bool | None = None,
**kwargs
):
"""
在子类定义时自动配置 polymorphic 设置
Args:
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'
设置为其他值可以使用不同的字段作为多态鉴别器。
polymorphic_abstract: 是否为抽象类。
- None: 自动检测(默认)
- True: 强制设为抽象类
- False: 强制设为非抽象类
**kwargs: 传递给父类的其他参数
"""
super().__init_subclass__(**kwargs)
# 初始化 __mapper_args__如果还没有
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 设置 polymorphic_on默认为 _polymorphic_name
if 'polymorphic_on' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
# 自动检测或设置 polymorphic_abstract
if 'polymorphic_abstract' not in cls.__mapper_args__:
if polymorphic_abstract is None:
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
has_abc = ABC in cls.__mro__
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
polymorphic_abstract = has_abc and has_abstract_methods
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
@classmethod
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
"""
递归获取当前类的所有具体(非抽象)子类
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
可在任意多态基类上调用,返回该类的所有非抽象子类。
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
"""
result: list[type[PolymorphicBaseMixin]] = []
for subclass in cls.__subclasses__():
# 使用 inspect() 获取 mapper 的公开属性
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
mapper = inspect(subclass)
if not mapper.polymorphic_abstract:
result.append(subclass)
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
if hasattr(subclass, 'get_concrete_subclasses'):
result.extend(subclass.get_concrete_subclasses())
return result
@classmethod
def get_polymorphic_discriminator(cls) -> str:
"""
获取多态鉴别字段名
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
:return: 多态鉴别字段名(如 '_polymorphic_name'
:raises ValueError: 如果类未配置 polymorphic_on
"""
polymorphic_on = inspect(cls).polymorphic_on
if polymorphic_on is None:
raise ValueError(
f"{cls.__name__} 未配置 polymorphic_on"
f"请确保正确继承 PolymorphicBaseMixin"
)
return polymorphic_on.key
@classmethod
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
"""
获取 polymorphic_identity 到具体子类的映射
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
:return: identity 到子类的映射字典
"""
result: dict[str, type[PolymorphicBaseMixin]] = {}
for subclass in cls.get_concrete_subclasses():
identity = inspect(subclass).polymorphic_identity
if identity:
result[identity] = subclass
return result

927
models/mixin/table.py Normal file
View File

@@ -0,0 +1,927 @@
"""
表基类 Mixin
提供 TableBaseMixin、UUIDTableBaseMixin 和 TableViewRequest。
这些类实际上是 Mixin为 SQLModel 模型提供 CRUD 操作和时间戳字段。
依赖关系:
base/sqlmodel_base.py ← 最底层
mixin/polymorphic.py ← 定义 PolymorphicBaseMixin
mixin/table.py ← 当前文件,导入 PolymorphicBaseMixin
base/__init__.py ← 从 mixin 重新导出(保持向后兼容)
"""
import uuid
from datetime import datetime
from typing import TypeVar, Literal, override, Any, ClassVar, Generic
# TODO(ListResponse泛型问题): SQLModel泛型类型JSON Schema生成bug
# 已知问题: https://github.com/fastapi/sqlmodel/discussions/1002
# 修复PR: https://github.com/fastapi/sqlmodel/pull/1275 (尚未合并)
# 现象: SQLModel + Generic[T] 的 __pydantic_generic_metadata__ = {origin: None, args: ()}
# 导致OpenAPI schema中泛型字段显示为{}而非正确的$ref
# 当前方案: ListResponse继承BaseModel而非SQLModel (Discussion #1002推荐的workaround)
# 未来: PR #1275合并后可改回继承SQLModelBase
from pydantic import BaseModel, ConfigDict
from fastapi import HTTPException
from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct
from sqlalchemy.orm import selectinload, Relationship, with_polymorphic
from sqlmodel import Field, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.sql._typing import _OnClauseArgument
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel.main import RelationshipInfo
from .polymorphic import PolymorphicBaseMixin
from models.base.sqlmodel_base import SQLModelBase
# Type variables for generic type hints, improving code completion and analysis.
T = TypeVar("T", bound="TableBaseMixin")
M = TypeVar("M", bound="SQLModelBase")
ItemT = TypeVar("ItemT")
class ListResponse(BaseModel, Generic[ItemT]):
"""
泛型分页响应
用于所有LIST端点的标准化响应格式包含记录总数和项目列表。
与 TableBaseMixin.get_with_count() 配合使用。
使用示例:
```python
@router.get("", response_model=ListResponse[CharacterInfoResponse])
async def list_characters(...) -> ListResponse[Character]:
return await Character.get_with_count(session, table_view=table_view)
```
Attributes:
count: 符合条件的记录总数(用于分页计算)
items: 当前页的记录列表
Note:
继承BaseModel而非SQLModelBase因为SQLModel的metaclass与Generic冲突。
详见文件顶部TODO注释。
"""
# 与SQLModelBase保持一致的配置
model_config = ConfigDict(use_attribute_docstrings=True)
count: int
"""符合条件的记录总数"""
items: list[ItemT]
"""当前页的记录列表"""
# Lambda functions to get the current time, used as default factories in model fields.
now = lambda: datetime.now()
now_date = lambda: datetime.now().date()
# ==================== 查询参数请求类 ====================
class TimeFilterRequest(SQLModelBase):
"""
时间筛选请求参数
用于 count() 等只需要时间筛选的场景。
纯数据类只负责参数校验和携带SQL子句构建由 TableBaseMixin 负责。
Raises:
ValueError: 时间范围无效
"""
created_after_datetime: datetime | None = None
"""创建时间起始筛选created_at >= datetime如果为None则不限制"""
created_before_datetime: datetime | None = None
"""创建时间结束筛选created_at < datetime如果为None则不限制"""
updated_after_datetime: datetime | None = None
"""更新时间起始筛选updated_at >= datetime如果为None则不限制"""
updated_before_datetime: datetime | None = None
"""更新时间结束筛选updated_at < datetime如果为None则不限制"""
def model_post_init(self, __context: Any) -> None:
"""
验证时间范围有效性
验证规则:
1. 同类型after 必须小于 before
2. 跨类型created_after 不能大于 updated_before记录不可能在创建前被更新
"""
# 同类型矛盾验证
if self.created_after_datetime and self.created_before_datetime:
if self.created_after_datetime >= self.created_before_datetime:
raise ValueError("created_after_datetime 必须小于 created_before_datetime")
if self.updated_after_datetime and self.updated_before_datetime:
if self.updated_after_datetime >= self.updated_before_datetime:
raise ValueError("updated_after_datetime 必须小于 updated_before_datetime")
# 跨类型矛盾验证created_after >= updated_before 意味着要求创建时间晚于或等于更新时间上界,逻辑矛盾
if self.created_after_datetime and self.updated_before_datetime:
if self.created_after_datetime >= self.updated_before_datetime:
raise ValueError(
"created_after_datetime 不能大于或等于 updated_before_datetime"
"(记录的更新时间不可能早于或等于创建时间)"
)
class PaginationRequest(SQLModelBase):
"""
分页排序请求参数
用于需要分页和排序的场景。
纯数据类只负责携带参数SQL子句构建由 TableBaseMixin 负责。
"""
offset: int | None = Field(default=0, ge=0)
"""偏移量跳过前N条记录必须为非负整数"""
limit: int | None = Field(default=50, le=100)
"""每页数量返回最多N条记录默认50最大100"""
desc: bool | None = True
"""是否降序排序True: 降序, False: 升序)"""
order: Literal["created_at", "updated_at"] | None = "created_at"
"""排序字段created_at: 创建时间, updated_at: 更新时间)"""
class TableViewRequest(TimeFilterRequest, PaginationRequest):
"""
表格视图请求参数(分页、排序和时间筛选)
组合继承 TimeFilterRequest 和 PaginationRequest用于 get() 等需要完整查询参数的场景。
纯数据类SQL子句构建由 TableBaseMixin 负责。
使用示例:
```python
# 在端点中使用依赖注入
@router.get("/list")
async def list_items(
session: SessionDep,
table_view: TableViewRequestDep
):
items = await Item.get(
session,
fetch_mode="all",
table_view=table_view
)
return items
# 直接使用
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
items = await Item.get(session, fetch_mode="all", table_view=table_view)
```
"""
pass
# ==================== TableBaseMixin ====================
class TableBaseMixin(AsyncAttrs):
"""
一个异步 CRUD 操作的基础模型类 Mixin.
此类必须搭配SQLModelBase使用
此类为所有继承它的 SQLModel 模型提供了通用的数据库操作方法,
例如 add, save, update, delete, 和 get. 它还包括自动管理
的 `created_at` 和 `updated_at` 时间戳字段.
Attributes:
id (int | None): 整数主键, 自动递增.
created_at (datetime): 记录创建时的时间戳, 自动设置.
updated_at (datetime): 记录每次更新时的时间戳, 自动更新.
"""
_is_table_mixin: ClassVar[bool] = True
"""标记此类为表混入类的内部属性"""
def __init_subclass__(cls, **kwargs):
"""
接受并传递子类定义时的关键字参数
这允许元类 __DeclarativeMeta 处理的参数(如 table_args
能够正确传递,而不会在 __init_subclass__ 阶段报错。
"""
super().__init_subclass__(**kwargs)
id: int | None = Field(default=None, primary_key=True)
created_at: datetime = Field(default_factory=now)
updated_at: datetime = Field(
sa_type=DateTime,
sa_column_kwargs={'default': now, 'onupdate': now},
default_factory=now
)
@classmethod
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
"""
向数据库中添加一个新的或多个新的记录.
这个类方法可以接受单个模型实例或一个实例列表,并将它们
一次性提交到数据库中。执行后,可以选择性地刷新这些实例以获取
数据库生成的值(例如,自动递增的 ID.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
Returns:
T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例.
Usage:
item1 = Item(name="Apple")
item2 = Item(name="Banana")
# 添加多个实例
added_items = await Item.add(session, [item1, item2])
# 添加单个实例
item3 = Item(name="Cherry")
added_item = await Item.add(session, item3)
"""
is_list = False
if isinstance(instances, list):
is_list = True
session.add_all(instances)
else:
session.add(instances)
await session.commit()
if refresh:
if is_list:
for instance in instances:
await session.refresh(instance)
else:
await session.refresh(instances)
return instances
async def save(
self: T,
session: AsyncSession,
load: RelationshipInfo | None = None,
refresh: bool = True
) -> T:
"""
保存(插入或更新)当前模型实例到数据库.
这是一个实例方法,它将当前对象添加到会话中并提交更改。
可以用于创建新记录或更新现有记录。还可以选择在保存后
预加载eager load一个关联关系.
**重要**调用此方法后session中的所有对象都会过期expired
如果需要继续使用该对象,必须使用返回值:
```python
# ✅ 正确:需要返回值时
client = await client.save(session)
return client
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
await client.save(session, refresh=False)
# ❌ 错误:需要返回值但未使用
await client.save(session)
return client # client 对象已过期
```
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性.
例如 `User.posts`.
refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值,
设为 False 可节省一次数据库查询。默认为 True.
Returns:
T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self.
"""
session.add(self)
await session.commit()
if not refresh:
return self
if load is not None:
cls = type(self)
await session.refresh(self)
# 如果指定了 load, 重新获取实例并加载关联关系
return await cls.get(session, cls.id == self.id, load=load)
else:
await session.refresh(self)
return self
async def update(
self: T,
session: AsyncSession,
other: M,
extra_data: dict[str, Any] | None = None,
exclude_unset: bool = True,
exclude: set[str] | None = None,
load: RelationshipInfo | None = None,
refresh: bool = True
) -> T:
"""
使用另一个模型实例或字典中的数据来更新当前实例.
此方法将 `other` 对象中的数据合并到当前实例中。默认情况下,
它只会更新 `other` 中被显式设置的字段.
**重要**调用此方法后session中的所有对象都会过期expired
如果需要继续使用该对象,必须使用返回值:
```python
# ✅ 正确:需要返回值时
client = await client.update(session, update_data)
return client
# ✅ 正确:需要返回值且需要加载关系时
user = await user.update(session, update_data, load=User.permission)
return user
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
await client.update(session, update_data, refresh=False)
# ❌ 错误:需要返回值但未使用
await client.update(session, update_data)
return client # client 对象已过期
```
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
other (M): 一个 SQLModel 或 Pydantic 模型实例,其数据将用于更新当前实例.
extra_data (dict, optional): 一个额外的字典,用于更新当前实例的特定字段.
exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供)
的字段将被忽略. 默认为 True.
exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}.
load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性.
例如 `User.permission`.
refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值,
设为 False 可节省一次数据库查询。默认为 True.
Returns:
T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self.
"""
self.sqlmodel_update(
other.model_dump(exclude_unset=exclude_unset, exclude=exclude),
update=extra_data
)
session.add(self)
await session.commit()
if not refresh:
return self
if load is not None:
cls = type(self)
await session.refresh(self)
return await cls.get(session, cls.id == self.id, load=load)
else:
await session.refresh(self)
return self
@classmethod
async def delete(cls: type[T], session: AsyncSession, instances: T | list[T]) -> None:
"""
从数据库中删除一个或多个记录.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
instances (T | list[T]): 要删除的单个模型实例或模型实例列表.
Returns:
None
Usage:
item_to_delete = await Item.get(session, Item.id == 1)
if item_to_delete:
await Item.delete(session, item_to_delete)
items_to_delete = await Item.get(session, Item.name.in_(["Apple", "Banana"]), fetch_mode="all")
if items_to_delete:
await Item.delete(session, items_to_delete)
"""
if isinstance(instances, list):
for instance in instances:
await session.delete(instance)
else:
await session.delete(instances)
await session.commit()
@classmethod
def _build_time_filters(
cls: type[T],
created_before_datetime: datetime | None = None,
created_after_datetime: datetime | None = None,
updated_before_datetime: datetime | None = None,
updated_after_datetime: datetime | None = None,
) -> list[BinaryExpression]:
"""
构建时间筛选条件列表
Args:
created_before_datetime: 筛选 created_at < datetime 的记录
created_after_datetime: 筛选 created_at >= datetime 的记录
updated_before_datetime: 筛选 updated_at < datetime 的记录
updated_after_datetime: 筛选 updated_at >= datetime 的记录
Returns:
BinaryExpression 条件列表
"""
filters: list[BinaryExpression] = []
if created_after_datetime is not None:
filters.append(cls.created_at >= created_after_datetime)
if created_before_datetime is not None:
filters.append(cls.created_at < created_before_datetime)
if updated_after_datetime is not None:
filters.append(cls.updated_at >= updated_after_datetime)
if updated_before_datetime is not None:
filters.append(cls.updated_at < updated_before_datetime)
return filters
@classmethod
async def get(
cls: type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None = None,
*,
offset: int | None = None,
limit: int | None = None,
fetch_mode: Literal["one", "first", "all"] = "first",
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
options: list | None = None,
load: RelationshipInfo | None = None,
order_by: list[ClauseElement] | None = None,
filter: BinaryExpression | ClauseElement | None = None,
with_for_update: bool = False,
table_view: TableViewRequest | None = None,
load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
created_before_datetime: datetime | None = None,
created_after_datetime: datetime | None = None,
updated_before_datetime: datetime | None = None,
updated_after_datetime: datetime | None = None,
) -> T | list[T] | None:
"""
根据指定的条件异步地从数据库中获取一个或多个模型实例.
这是一个功能强大的通用查询方法,支持过滤、排序、分页、连接查询和关联关系预加载.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
condition (BinaryExpression | ClauseElement | None): 主要的查询过滤条件,
例如 `User.id == 1`。
当为 `None` 时,表示无条件查询(查询所有记录)。
offset (int | None): 查询结果的起始偏移量, 用于分页.
limit (int | None): 返回记录的最大数量, 用于分页.
fetch_mode (Literal["one", "first", "all"]):
- "one": 获取唯一的一条记录. 如果找不到或找到多条,会引发异常.
- "first": 获取查询结果的第一条记录. 如果找不到,返回 `None`.
- "all": 获取所有匹配的记录,返回一个列表.
默认为 "first".
join (type[T] | tuple[type[T], _OnClauseArgument] | None):
要 JOIN 的模型类或一个包含模型类和 ON 子句的元组.
例如 `User` 或 `(Profile, User.id == Profile.user_id)`.
options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
例如 `[selectinload(User.posts)]`.
load (Relationship | None): `selectinload` 的快捷方式,用于预加载单个关联关系.
例如 `User.profile`.
order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
例如 `[User.name.asc(), User.created_at.desc()]`.
filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
with_for_update (bool): 如果为 True, 在查询中使用 `FOR UPDATE` 锁定选定的行. 默认为 False.
table_view (TableViewRequest | None): TableViewRequest对象如果提供则自动处理分页、排序和时间筛选。
会覆盖offset、limit、order_by及时间筛选参数。
这是推荐的分页排序方式统一了所有LIST端点的参数格式。
load_polymorphic: 多态子类加载选项,需要与 load 参数配合使用。
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
- 'all': 两阶段查询,只加载实际关联的子类(对于 > 10 个子类的场景有明显性能收益)
- None默认: 不使用多态加载
created_before_datetime (datetime | None): 筛选 created_at < datetime 的记录
created_after_datetime (datetime | None): 筛选 created_at >= datetime 的记录
updated_before_datetime (datetime | None): 筛选 updated_at < datetime 的记录
updated_after_datetime (datetime | None): 筛选 updated_at >= datetime 的记录
Returns:
T | list[T] | None: 根据 `fetch_mode` 的设置,返回单个实例、实例列表或 `None`.
Raises:
ValueError: 如果提供了无效的 `fetch_mode` 值,或 load_polymorphic 未与 load 配合使用.
Examples:
# 使用table_view参数推荐
users = await User.get(session, fetch_mode="all", table_view=table_view_args)
# 传统方式(向后兼容)
users = await User.get(session, fetch_mode="all", offset=0, limit=20, order_by=[desc(User.created_at)])
# 使用多态加载(加载联表继承的子类数据)
tool_set = await ToolSet.get(
session,
ToolSet.id == tool_set_id,
load=ToolSet.tools,
load_polymorphic='all' # 只加载实际关联的子类
)
"""
# 参数验证load_polymorphic 需要与 load 配合使用
if load_polymorphic is not None and load is None:
raise ValueError(
"load_polymorphic 参数需要与 load 参数配合使用,"
"请同时指定要加载的关系"
)
# 如果提供table_view作为默认值使用单独传入的参数优先级更高
if table_view:
# 处理时间筛选TimeFilterRequest 及其子类)
if isinstance(table_view, TimeFilterRequest):
if created_after_datetime is None and table_view.created_after_datetime is not None:
created_after_datetime = table_view.created_after_datetime
if created_before_datetime is None and table_view.created_before_datetime is not None:
created_before_datetime = table_view.created_before_datetime
if updated_after_datetime is None and table_view.updated_after_datetime is not None:
updated_after_datetime = table_view.updated_after_datetime
if updated_before_datetime is None and table_view.updated_before_datetime is not None:
updated_before_datetime = table_view.updated_before_datetime
# 处理分页排序PaginationRequest 及其子类,包括 TableViewRequest
if isinstance(table_view, PaginationRequest):
if offset is None:
offset = table_view.offset
if limit is None:
limit = table_view.limit
# 仅在未显式传入order_by时从table_view构建排序子句
if order_by is None:
order_column = cls.created_at if table_view.order == "created_at" else cls.updated_at
order_by = [desc(order_column) if table_view.desc else asc(order_column)]
# 对于多态基类,使用 with_polymorphic 预加载所有子类的列
# 这避免了在响应序列化时的延迟加载问题MissingGreenlet 错误)
if issubclass(cls, PolymorphicBaseMixin):
# '*' 表示加载所有子类
polymorphic_cls = with_polymorphic(cls, '*')
statement = select(polymorphic_cls)
else:
statement = select(cls)
if condition is not None:
statement = statement.where(condition)
# 应用时间筛选
for time_filter in cls._build_time_filters(
created_before_datetime, created_after_datetime,
updated_before_datetime, updated_after_datetime
):
statement = statement.where(time_filter)
if join is not None:
# 如果 join 是一个元组,解包它;否则直接使用
if isinstance(join, tuple):
statement = statement.join(*join)
else:
statement = statement.join(join)
if options:
statement = statement.options(*options)
if load:
# 处理多态加载
if load_polymorphic is not None:
target_class = load.property.mapper.class_
# 检查目标类是否继承自 PolymorphicBaseMixin
if not issubclass(target_class, PolymorphicBaseMixin):
raise ValueError(
f"目标类 {target_class.__name__} 不是多态类,"
f"请确保其继承自 PolymorphicBaseMixin"
)
if load_polymorphic == 'all':
# 两阶段查询:获取实际关联的多态类型
subclasses_to_load = await cls._resolve_polymorphic_subclasses(
session, condition, load, target_class
)
else:
subclasses_to_load = load_polymorphic
if subclasses_to_load:
# 关键selectin_polymorphic 必须作为 selectinload 的链式子选项
# 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading
statement = statement.options(
selectinload(load).selectin_polymorphic(subclasses_to_load)
)
else:
statement = statement.options(selectinload(load))
else:
statement = statement.options(selectinload(load))
if order_by is not None:
statement = statement.order_by(*order_by)
if offset:
statement = statement.offset(offset)
if limit:
statement = statement.limit(limit)
if filter:
statement = statement.filter(filter)
if with_for_update:
statement = statement.with_for_update()
result = await session.exec(statement)
if fetch_mode == "one":
return result.one()
elif fetch_mode == "first":
return result.first()
elif fetch_mode == "all":
return list(result.all())
else:
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
@classmethod
async def _resolve_polymorphic_subclasses(
cls: type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None,
load: RelationshipInfo,
target_class: type[PolymorphicBaseMixin]
) -> list[type[PolymorphicBaseMixin]]:
"""
查询实际关联的多态子类类型
通过查询多态鉴别字段确定实际存在的子类类型,
避免加载所有可能的子类表(对于 > 10 个子类的场景有明显收益)。
:param session: 数据库会话
:param condition: 主查询的条件
:param load: 关系属性
:param target_class: 多态基类
:return: 实际关联的子类列表
"""
# 获取多态鉴别字段(会抛出 ValueError 如果未配置)
discriminator = target_class.get_polymorphic_discriminator()
poly_name_col = getattr(target_class, discriminator)
# 获取关系属性
relationship_property = load.property
# 构建查询获取实际的多态类型名称
if relationship_property.secondary is not None:
# 多对多关系:通过中间表查询
secondary = relationship_property.secondary
local_cols = list(relationship_property.local_columns)
type_query = (
select(distinct(poly_name_col))
.select_from(target_class)
.join(secondary)
.where(secondary.c[local_cols[0].name].in_(
select(cls.id).where(condition) if condition is not None else select(cls.id)
))
)
else:
# 一对多关系:通过外键查询
foreign_key_col = relationship_property.local_remote_pairs[0][1]
type_query = (
select(distinct(poly_name_col))
.where(foreign_key_col.in_(
select(cls.id).where(condition) if condition is not None else select(cls.id)
))
)
type_result = await session.exec(type_query)
poly_names = list(type_result.all())
if not poly_names:
return []
# 映射到子类(包含所有层级的具体子类)
identity_map = target_class.get_identity_to_class_map()
return [identity_map[name] for name in poly_names if name in identity_map]
@classmethod
async def count(
cls: type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None = None,
*,
time_filter: TimeFilterRequest | None = None,
created_before_datetime: datetime | None = None,
created_after_datetime: datetime | None = None,
updated_before_datetime: datetime | None = None,
updated_after_datetime: datetime | None = None,
) -> int:
"""
根据条件统计记录数量(支持时间筛选)
使用数据库层面的 COUNT() 聚合函数,比 get() + len() 更高效。
Args:
session: 数据库会话
condition: 查询条件,例如 `User.is_active == True`
time_filter: TimeFilterRequest 对象(优先级更高)
created_before_datetime: 筛选 created_at < datetime 的记录
created_after_datetime: 筛选 created_at >= datetime 的记录
updated_before_datetime: 筛选 updated_at < datetime 的记录
updated_after_datetime: 筛选 updated_at >= datetime 的记录
Returns:
符合条件的记录数量
Examples:
# 统计所有用户
total = await User.count(session)
# 统计激活的虚拟客户端
count = await Client.count(
session,
(Client.user_id == user_id) & (Client.type != ClientTypeEnum.physical) & (Client.is_active == True)
)
# 使用 TimeFilterRequest 进行时间筛选
count = await User.count(session, time_filter=time_filter_request)
# 使用独立时间参数
count = await User.count(
session,
created_after_datetime=datetime(2025, 1, 1),
created_before_datetime=datetime(2025, 2, 1),
)
"""
# time_filter 的时间筛选优先级更高
if isinstance(time_filter, TimeFilterRequest):
if time_filter.created_after_datetime is not None:
created_after_datetime = time_filter.created_after_datetime
if time_filter.created_before_datetime is not None:
created_before_datetime = time_filter.created_before_datetime
if time_filter.updated_after_datetime is not None:
updated_after_datetime = time_filter.updated_after_datetime
if time_filter.updated_before_datetime is not None:
updated_before_datetime = time_filter.updated_before_datetime
statement = select(func.count()).select_from(cls)
# 应用查询条件
if condition is not None:
statement = statement.where(condition)
# 应用时间筛选
for time_condition in cls._build_time_filters(
created_before_datetime, created_after_datetime,
updated_before_datetime, updated_after_datetime
):
statement = statement.where(time_condition)
result = await session.scalar(statement)
return result or 0
@classmethod
async def get_with_count(
cls: type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None = None,
*,
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
options: list | None = None,
load: RelationshipInfo | None = None,
order_by: list[ClauseElement] | None = None,
filter: BinaryExpression | ClauseElement | None = None,
table_view: TableViewRequest | None = None,
load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
) -> 'ListResponse[T]':
"""
获取分页列表及总数,直接返回 ListResponse
同时返回符合条件的记录列表和总数,用于分页场景。
与 get() 方法类似,但固定 fetch_mode="all" 并返回 ListResponse。
注意:如果子类的 get() 方法支持额外参数(如 filter_params
子类应该覆盖此方法以确保 count 和 items 使用相同的过滤条件。
Args:
session: 数据库会话
condition: 查询条件
join: JOIN 的模型类或元组
options: SQLAlchemy 查询选项
load: selectinload 预加载关系
order_by: 排序子句
filter: 附加过滤条件
table_view: 分页排序参数(推荐使用)
load_polymorphic: 多态子类加载选项
Returns:
ListResponse[T]: 包含 count 和 items 的分页响应
Examples:
```python
@router.get("", response_model=ListResponse[CharacterInfoResponse])
async def list_characters(
session: SessionDep,
table_view: TableViewRequestDep
) -> ListResponse[Character]:
return await Character.get_with_count(session, table_view=table_view)
```
"""
# 提取时间筛选参数(用于 count
time_filter: TimeFilterRequest | None = None
if table_view is not None:
time_filter = TimeFilterRequest(
created_after_datetime=table_view.created_after_datetime,
created_before_datetime=table_view.created_before_datetime,
updated_after_datetime=table_view.updated_after_datetime,
updated_before_datetime=table_view.updated_before_datetime,
)
# 获取总数(不带分页限制)
total_count = await cls.count(session, condition, time_filter=time_filter)
# 获取分页数据
items = await cls.get(
session,
condition,
fetch_mode="all",
join=join,
options=options,
load=load,
order_by=order_by,
filter=filter,
table_view=table_view,
load_polymorphic=load_polymorphic,
)
return ListResponse(count=total_count, items=items)
@classmethod
async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | None = None) -> T:
"""
根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常.
这个方法是对 `get` 方法的封装,专门用于处理那种"记录必须存在"的业务场景。
如果记录未找到,它会直接引发 FastAPI 的 `HTTPException`, 而不是返回 `None`.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
id (int): 要查找的记录的主键 ID.
load (Relationship | None): 可选的,用于预加载的关联属性.
Returns:
T: 找到的模型实例.
Raises:
HTTPException: 如果 ID 对应的记录不存在,则抛出状态码为 404 的异常.
"""
instance = await cls.get(session, cls.id == id, load=load)
if not instance:
raise HTTPException(status_code=404, detail="Not found")
return instance
class UUIDTableBaseMixin(TableBaseMixin):
"""
一个使用 UUID 作为主键的异步 CRUD 操作基础模型类 Mixin.
此类继承自 `TableBaseMixin`, 将主键 `id` 的类型覆盖为 `uuid.UUID`
并为新记录自动生成 UUID. 它继承了 `TableBaseMixin` 的所有 CRUD 方法.
Attributes:
id (uuid.UUID): UUID 类型的主键, 在创建时自动生成.
"""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
"""覆盖 `TableBaseMixin` 的 id 字段,使用 UUID 作为主键."""
@override
@classmethod
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T:
"""
根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
此方法覆盖了父类的同名方法,以确保 `id` 参数的类型注解为 `uuid.UUID`,
从而提供更好的类型安全和代码提示.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
id (uuid.UUID): 要查找的记录的 UUID 主键.
load (Relationship | None): 可选的,用于预加载的关联属性.
Returns:
T: 找到的模型实例.
Raises:
HTTPException: 如果 UUID 对应的记录不存在,则抛出状态码为 404 的异常.
"""
# 类型检查器可能会警告这里的 `id` 类型不匹配超类方法,
# 但在运行时这是正确的,因为超类方法内部的比较 (cls.id == id)
# 会正确处理 UUID 类型。`type: ignore` 用于抑制此警告。
return await super().get_exist_one(session, id, load) # type: ignore