feat(mixin): add TableBaseMixin and UUIDTableBaseMixin for async CRUD operations

- Implemented TableBaseMixin providing generic CRUD methods and automatic timestamp management.
- Introduced UUIDTableBaseMixin for models using UUID as primary keys.
- Added ListResponse for standardized paginated responses.
- Created TimeFilterRequest and PaginationRequest for filtering and pagination parameters.
- Enhanced get_with_count method to return both item list and total count.
- Included validation for time filter parameters in TimeFilterRequest.
- Improved documentation and usage examples throughout the code.
This commit is contained in:
2025-12-22 18:29:14 +08:00
parent 47a4756227
commit a5efda9c23
44 changed files with 4306 additions and 497 deletions

8
.idea/.gitignore generated vendored
View File

@@ -1,8 +0,0 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

17
.idea/Server.iml generated
View File

@@ -1,17 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/.venv" />
</content>
<orderEntry type="jdk" jdkName="Python 3.13 (Server)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="py.test" />
</component>
</module>

View File

@@ -1,5 +0,0 @@
<component name="ProjectCodeStyleConfiguration">
<state>
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
</state>
</component>

View File

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="AgentMigrationStateService">
<option name="migrationStatus" value="COMPLETED" />
</component>
</project>

View File

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="AskMigrationStateService">
<option name="migrationStatus" value="COMPLETED" />
</component>
</project>

View File

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Ask2AgentMigrationStateService">
<option name="migrationStatus" value="COMPLETED" />
</component>
</project>

View File

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="EditMigrationStateService">
<option name="migrationStatus" value="COMPLETED" />
</component>
</project>

View File

@@ -1,3 +0,0 @@
<component name="ProjectDictionaryState">
<dictionary name="project" />
</component>

View File

@@ -1,6 +0,0 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

View File

@@ -1,17 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="MaterialThemeProjectNewConfig">
<option name="metadata">
<MTProjectMetadataState>
<option name="migrated" value="true" />
<option name="pristineConfig" value="false" />
<option name="userId" value="298ea09f:198c11a97b9:-7ffe" />
</MTProjectMetadataState>
</option>
<option name="titleBarState">
<MTProjectTitleBarConfigState>
<option name="overrideColor" value="false" />
</MTProjectTitleBarConfigState>
</option>
</component>
</project>

7
.idea/misc.xml generated
View File

@@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.13 (Server)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.13 (Server)" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated
View File

@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Server.iml" filepath="$PROJECT_DIR$/.idea/Server.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated
View File

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

View File

@@ -25,7 +25,7 @@ app = FastAPI(
)
# 挂载路由
app.include_router(router, prefix='/api')
app.include_router(router)
# 防止直接运行 main.py
if __name__ == "__main__":

View File

@@ -228,7 +228,6 @@ models/
| `source_name` | `str?` | 源文件名(仅文件) |
| `size` | `int` | 文件大小(字节),目录为 0 |
| `upload_session_id` | `str?` | 分块上传会话 ID |
| `file_metadata` | `str?` | 文件元数据JSON 格式) |
| `parent_id` | `UUID?` | 父目录外键NULL 表示根目录) |
| `owner_id` | `UUID` | 所有者用户(外键) |
| `policy_id` | `UUID` | 存储策略(外键) |
@@ -237,9 +236,35 @@ models/
- 同一父目录下名称唯一
- 名称不能包含斜杠
**关系**:
- `metadata`: 一对一关联 FileMetadata
---
### 9. SourceLink源链接
### 9. FileMetadata文件元数据
**表名**: `filemetadata`
**基类**: `UUIDTableBase`
| 字段 | 类型 | 说明 |
|------|------|------|
| `id` | `UUID` | 主键 |
| `object_id` | `UUID` | 关联的对象(外键,唯一) |
| `width` | `int?` | 图片/视频宽度 |
| `height` | `int?` | 图片/视频高度 |
| `duration` | `float?` | 音视频时长(秒) |
| `mime_type` | `str?` | MIME类型 |
| `bit_rate` | `int?` | 比特率 |
| `sample_rate` | `int?` | 采样率 |
| `channels` | `int?` | 音频通道数 |
| `codec` | `str?` | 编解码器 |
| `frame_rate` | `float?` | 视频帧率 |
| `orientation` | `int?` | 图片方向 |
| `has_thumbnail` | `bool` | 是否有缩略图 |
---
### 10. SourceLink源链接
**表名**: `sourcelink`
**基类**: `TableBase`
@@ -253,7 +278,7 @@ models/
---
### 10. Share分享
### 11. Share分享
**表名**: `share`
**基类**: `TableBase`
@@ -275,7 +300,7 @@ models/
---
### 11. Report举报
### 12. Report举报
**表名**: `report`
**基类**: `TableBase`
@@ -289,7 +314,7 @@ models/
---
### 12. Tag标签
### 13. Tag标签
**表名**: `tag`
**基类**: `TableBase`
@@ -300,7 +325,7 @@ models/
| `name` | `str` | 标签名称 |
| `icon` | `str?` | 标签图标 |
| `color` | `str?` | 标签颜色 |
| `type` | `int` | 标签类型:0=手动1=自动 |
| `type` | `TagType` | 标签类型:manual/automatic |
| `expression` | `str?` | 自动标签的匹配表达式 |
| `user_id` | `UUID` | 所属用户(外键) |
@@ -308,7 +333,7 @@ models/
---
### 13. Task任务
### 14. Task任务
**表名**: `task`
**基类**: `TableBase`
@@ -316,16 +341,35 @@ models/
| 字段 | 类型 | 说明 |
|------|------|------|
| `id` | `int` | 主键 |
| `status` | `int` | 任务状态0=排队中1=处理中2=完成3=错误 |
| `type` | `int` | 任务类型 |
| `status` | `TaskStatus` | 任务状态queued/running/completed/error |
| `type` | `int` | 任务类型[TODO] 待定义枚举) |
| `progress` | `int` | 任务进度0-100 |
| `error` | `str?` | 错误信息 |
| `props` | `str?` | 任务属性JSON 格式) |
| `user_id` | `UUID` | 所属用户(外键) |
**索引**: `ix_task_status`, `ix_task_user_status`
**关系**:
- `props`: 一对一关联 TaskProps
- `downloads`: 一对多关联 Download
---
### 14. Download离线下载
### 15. TaskProps任务属性
**表名**: `taskprops`
**基类**: `SQLModelBase`
| 字段 | 类型 | 说明 |
|------|------|------|
| `task_id` | `int` | 关联的任务(外键,主键) |
| `source_path` | `str?` | 源路径 |
| `dest_path` | `str?` | 目标路径 |
| `file_ids` | `str?` | 文件ID列表逗号分隔 |
---
### 16. Download离线下载
**表名**: `download`
**基类**: `UUIDTableBase`
@@ -333,15 +377,14 @@ models/
| 字段 | 类型 | 说明 |
|------|------|------|
| `id` | `UUID` | 主键 |
| `status` | `int` | 下载状态0=进行中1=完成2=错误 |
| `type` | `int` | 任务类型 |
| `status` | `DownloadStatus` | 下载状态running/completed/error |
| `type` | `int` | 任务类型[TODO] 待定义枚举) |
| `source` | `str` | 来源 URL 或标识 |
| `total_size` | `int` | 总大小(字节) |
| `downloaded_size` | `int` | 已下载大小(字节) |
| `g_id` | `str?` | Aria2 GID |
| `speed` | `int` | 下载速度bytes/s |
| `parent` | `str?` | 父任务标识 |
| `attrs` | `str?` | 额外属性JSON 格式) |
| `error` | `str?` | 错误信息 |
| `dst` | `str` | 目标存储路径 |
| `user_id` | `UUID` | 所属用户(外键) |
@@ -350,9 +393,52 @@ models/
**约束**: 同一节点下 g_id 唯一
**索引**: `ix_download_status`, `ix_download_user_status`
**关系**:
- `aria2_info`: 一对一关联 DownloadAria2Info
- `aria2_files`: 一对多关联 DownloadAria2File
---
### 15. Node节点
### 17. DownloadAria2InfoAria2下载信息
**表名**: `downloadaria2info`
**基类**: `SQLModelBase`
| 字段 | 类型 | 说明 |
|------|------|------|
| `download_id` | `UUID` | 关联的下载任务(外键,主键) |
| `info_hash` | `str?` | InfoHashBT种子 |
| `piece_length` | `int` | 分片大小 |
| `num_pieces` | `int` | 分片数量 |
| `num_seeders` | `int` | 做种人数 |
| `connections` | `int` | 连接数 |
| `upload_speed` | `int` | 上传速度bytes/s |
| `upload_length` | `int` | 已上传大小(字节) |
| `error_code` | `str?` | 错误代码 |
| `error_message` | `str?` | 错误信息 |
---
### 18. DownloadAria2FileAria2下载文件
**表名**: `downloadaria2file`
**基类**: `TableBase`
| 字段 | 类型 | 说明 |
|------|------|------|
| `id` | `int` | 主键 |
| `download_id` | `UUID` | 关联的下载任务(外键) |
| `file_index` | `int` | 文件索引从1开始 |
| `path` | `str` | 文件路径 |
| `length` | `int` | 文件大小(字节) |
| `completed_length` | `int` | 已完成大小(字节) |
| `is_selected` | `bool` | 是否选中下载 |
---
### 19. Node节点
**表名**: `node`
**基类**: `TableBase`
@@ -360,19 +446,41 @@ models/
| 字段 | 类型 | 说明 |
|------|------|------|
| `id` | `int` | 主键 |
| `status` | `int` | 节点状态0=正常1=离线 |
| `status` | `NodeStatus` | 节点状态online/offline |
| `name` | `str` | 节点名称,唯一 |
| `type` | `int` | 节点类型 |
| `type` | `int` | 节点类型[TODO] 待定义枚举) |
| `server` | `str` | 节点地址IP 或域名) |
| `slave_key` | `str?` | 从机通讯密钥 |
| `master_key` | `str?` | 主机通讯密钥 |
| `aria2_enabled` | `bool` | 是否启用 Aria2 |
| `aria2_options` | `str?` | Aria2 配置JSON 格式) |
| `rank` | `int` | 节点排序权重 |
**索引**: `ix_node_status`
**关系**:
- `aria2_config`: 一对一关联 Aria2Configuration
- `downloads`: 一对多关联 Download
---
### 16. Order订单
### 20. Aria2ConfigurationAria2配置
**表名**: `aria2configuration`
**基类**: `TableBase`
| 字段 | 类型 | 说明 |
|------|------|------|
| `id` | `int` | 主键 |
| `node_id` | `int` | 关联的节点(外键,唯一) |
| `rpc_url` | `str?` | RPC地址 |
| `rpc_secret` | `str?` | RPC密钥 |
| `temp_path` | `str?` | 临时下载路径 |
| `max_concurrent` | `int` | 最大并发数1-50默认5 |
| `timeout` | `int` | 请求超时时间默认300 |
---
### 21. Order订单
**表名**: `order`
**基类**: `TableBase`
@@ -381,18 +489,18 @@ models/
|------|------|------|
| `id` | `int` | 主键 |
| `order_no` | `str` | 订单号,唯一 |
| `type` | `int` | 订单类型 |
| `type` | `int` | 订单类型[TODO] 待定义枚举) |
| `method` | `str?` | 支付方式 |
| `product_id` | `int?` | 商品 ID |
| `num` | `int` | 购买数量 |
| `name` | `str` | 商品名称 |
| `price` | `int` | 订单价格(分) |
| `status` | `int` | 订单状态0=待支付1=已完成2=已取消 |
| `status` | `OrderStatus` | 订单状态pending/completed/cancelled |
| `user_id` | `UUID` | 所属用户(外键) |
---
### 17. Redeem兑换码
### 22. Redeem兑换码
**表名**: `redeem`
**基类**: `TableBase`
@@ -400,7 +508,7 @@ models/
| 字段 | 类型 | 说明 |
|------|------|------|
| `id` | `int` | 主键 |
| `type` | `int` | 兑换码类型 |
| `type` | `int` | 兑换码类型[TODO] 待定义枚举) |
| `product_id` | `int?` | 关联的商品/权益 ID |
| `num` | `int` | 可兑换数量/时长等 |
| `code` | `str` | 兑换码,唯一 |
@@ -408,7 +516,7 @@ models/
---
### 18. StoragePack容量包
### 23. StoragePack容量包
**表名**: `storagepack`
**基类**: `TableBase`
@@ -424,7 +532,7 @@ models/
---
### 19. WebDAVWebDAV 账户)
### 24. WebDAVWebDAV 账户)
**表名**: `webdav`
**基类**: `TableBase`
@@ -443,7 +551,7 @@ models/
---
### 20. Setting系统设置
### 25. Setting系统设置
**表名**: `setting`
**基类**: `TableBase`
@@ -467,23 +575,39 @@ models/
### 一对一关系
```
┌─────────────────────────────────────────────────────────┐
┌───────────────────────────────────────────────────────────────────
│ 一对一关系 │
├─────────────────────────────────────────────────────────┤
├───────────────────────────────────────────────────────────────────
│ │
│ Group ◄────────────────────────> GroupOptions │
│ Group ◄────────────────────────> GroupOptions
│ group_id (unique FK) │
│ │
│ Policy ◄───────────────────────> PolicyOptions │
│ Policy ◄───────────────────────> PolicyOptions
│ policy_id (unique FK) │
│ │
└─────────────────────────────────────────────────────────┘
│ Object ◄────────────────────────> FileMetadata │
│ object_id (unique FK) │
│ │
│ Node ◄──────────────────────────> Aria2Configuration │
│ node_id (unique FK) │
│ │
│ Task ◄──────────────────────────> TaskProps │
│ task_id (PK/FK) │
│ │
│ Download ◄──────────────────────> DownloadAria2Info │
│ download_id (PK/FK) │
│ │
└───────────────────────────────────────────────────────────────────┘
```
| 主表 | 从表 | 外键 | 说明 |
|------|------|------|------|
| Group | GroupOptions | `group_id` (unique) | 每个用户组有且仅有一个选项配置 |
| Policy | PolicyOptions | `policy_id` (unique) | 每个存储策略有且仅有一个扩展选项 |
| Object | FileMetadata | `object_id` (unique) | 每个文件对象有且仅有一个元数据 |
| Node | Aria2Configuration | `node_id` (unique) | 每个节点有且仅有一个 Aria2 配置 |
| Task | TaskProps | `task_id` (PK) | 每个任务有且仅有一个属性配置 |
| Download | DownloadAria2Info | `download_id` (PK) | 每个下载任务有且仅有一个 Aria2 信息 |
---
@@ -540,6 +664,7 @@ models/
| **Share** | Report | `share_id` | 分享的举报 |
| **Task** | Download | `task_id` | 任务关联的下载 |
| **Node** | Download | `node_id` | 节点执行的下载任务 |
| **Download** | DownloadAria2File | `download_id` | 下载任务的文件列表 |
---
@@ -666,6 +791,56 @@ class AvatarType(StrEnum):
FILE = "file" # 自定义文件
```
### TagType
```python
class TagType(StrEnum):
MANUAL = "manual" # 手动标签
AUTOMATIC = "automatic" # 自动标签
```
### TaskStatus
```python
class TaskStatus(StrEnum):
QUEUED = "queued" # 排队中
RUNNING = "running" # 处理中
COMPLETED = "completed" # 已完成
ERROR = "error" # 错误
```
### DownloadStatus
```python
class DownloadStatus(StrEnum):
RUNNING = "running" # 进行中
COMPLETED = "completed" # 已完成
ERROR = "error" # 错误
```
### NodeStatus
```python
class NodeStatus(StrEnum):
ONLINE = "online" # 正常
OFFLINE = "offline" # 离线
```
### OrderStatus
```python
class OrderStatus(StrEnum):
PENDING = "pending" # 待支付
COMPLETED = "completed" # 已完成
CANCELLED = "cancelled" # 已取消
```
### 待定义枚举([TODO]
以下枚举已定义框架,具体值待业务需求确定:
- `TaskType` - 任务类型
- `DownloadType` - 下载类型
- `NodeType` - 节点类型
- `OrderType` - 订单类型
- `RedeemType` - 兑换码类型
- `ReportReason` - 举报原因
---
## DTO 模型

View File

@@ -14,12 +14,27 @@ from .user import (
from .user_authn import AuthnResponse, UserAuthn
from .color import ThemeResponse
from .download import Download
from .download import (
Download,
DownloadAria2File,
DownloadAria2Info,
DownloadAria2InfoBase,
DownloadStatus,
DownloadType,
)
from .node import (
Aria2Configuration,
Aria2ConfigurationBase,
Node,
NodeStatus,
NodeType,
)
from .group import Group, GroupBase, GroupOptions, GroupOptionsBase, GroupResponse
from .node import Node
from .object import (
DirectoryCreateRequest,
DirectoryResponse,
FileMetadata,
FileMetadataBase,
Object,
ObjectBase,
ObjectDeleteRequest,
@@ -28,16 +43,16 @@ from .object import (
ObjectType,
PolicyResponse,
)
from .order import Order
from .order import Order, OrderStatus, OrderType
from .policy import Policy, PolicyOptions, PolicyOptionsBase, PolicyType
from .redeem import Redeem
from .report import Report
from .redeem import Redeem, RedeemType
from .report import Report, ReportReason
from .setting import Setting, SettingsType, SiteConfigResponse
from .share import Share
from .source_link import SourceLink
from .storage_pack import StoragePack
from .tag import Tag
from .task import Task
from .tag import Tag, TagType
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType
from .webdav import WebDAV
from .database import engine, get_session

657
models/base/README.md Normal file
View File

@@ -0,0 +1,657 @@
# SQLModels Base Module
This module provides `SQLModelBase`, the root base class for all SQLModel models in this project. It includes a custom metaclass with automatic type injection and Python 3.14 compatibility.
**Note**: Table base classes (`TableBaseMixin`, `UUIDTableBaseMixin`) and polymorphic utilities have been migrated to the [`sqlmodels.mixin`](../mixin/README.md) module. See the mixin documentation for CRUD operations, polymorphic inheritance patterns, and pagination utilities.
## Table of Contents
- [Overview](#overview)
- [Migration Notice](#migration-notice)
- [Python 3.14 Compatibility](#python-314-compatibility)
- [Core Component](#core-component)
- [SQLModelBase](#sqlmodelbase)
- [Metaclass Features](#metaclass-features)
- [Automatic sa_type Injection](#automatic-sa_type-injection)
- [Table Configuration](#table-configuration)
- [Polymorphic Support](#polymorphic-support)
- [Custom Types Integration](#custom-types-integration)
- [Best Practices](#best-practices)
- [Troubleshooting](#troubleshooting)
## Overview
The `sqlmodels.base` module provides `SQLModelBase`, the foundational base class for all SQLModel models. It features:
- **Smart metaclass** that automatically extracts and injects SQLAlchemy types from type annotations
- **Python 3.14 compatibility** through comprehensive PEP 649/749 support
- **Flexible configuration** through class parameters and automatic docstring support
- **Type-safe annotations** with automatic validation
All models in this project should directly or indirectly inherit from `SQLModelBase`.
---
## Migration Notice
As of the recent refactoring, the following components have been moved:
| Component | Old Location | New Location |
|-----------|-------------|--------------|
| `TableBase``TableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
| `UUIDTableBase``UUIDTableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
| `PolymorphicBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
| `create_subclass_id_mixin()` | `sqlmodels.base` | `sqlmodels.mixin` |
| `AutoPolymorphicIdentityMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
| `TableViewRequest` | `sqlmodels.base` | `sqlmodels.mixin` |
| `now()`, `now_date()` | `sqlmodels.base` | `sqlmodels.mixin` |
**Update your imports**:
```python
# ❌ Old (deprecated)
from sqlmodels.base import TableBase, UUIDTableBase
# ✅ New (correct)
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
```
For detailed documentation on table mixins, CRUD operations, and polymorphic patterns, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
---
## Python 3.14 Compatibility
### Overview
This module provides full compatibility with **Python 3.14's PEP 649** (Deferred Evaluation of Annotations) and **PEP 749** (making it the default).
**Key Changes in Python 3.14**:
- Annotations are no longer evaluated at class definition time
- Type hints are stored as deferred code objects
- `__annotate__` function generates annotations on demand
- Forward references become `ForwardRef` objects
### Implementation Strategy
We use **`typing.get_type_hints()`** as the universal annotations resolver:
```python
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[...]:
# Create temporary proxy class
temp_cls = type('AnnotationProxy', (object,), dict(attrs))
# Use get_type_hints with include_extras=True
evaluated = get_type_hints(
temp_cls,
globalns=module_globals,
localns=localns,
include_extras=True # Preserve Annotated metadata
)
return dict(evaluated), {}, module_globals, localns
```
**Why `get_type_hints()`?**
- ✅ Works across Python 3.10-3.14+
- ✅ Handles PEP 649 automatically
- ✅ Preserves `Annotated` metadata (with `include_extras=True`)
- ✅ Resolves forward references
- ✅ Recommended by Python documentation
### SQLModel Compatibility Patch
**Problem**: SQLModel's `get_sqlalchemy_type()` doesn't recognize custom types with `__sqlmodel_sa_type__` attribute.
**Solution**: Global monkey-patch that checks for SQLAlchemy type before falling back to original logic:
```python
if sys.version_info >= (3, 14):
def _patched_get_sqlalchemy_type(field):
annotation = getattr(field, 'annotation', None)
if annotation is not None:
# Priority 1: Check __sqlmodel_sa_type__ attribute
# Handles NumpyVector[dims, dtype] and similar custom types
if hasattr(annotation, '__sqlmodel_sa_type__'):
return annotation.__sqlmodel_sa_type__
# Priority 2: Check Annotated metadata
if get_origin(annotation) is Annotated:
for metadata in get_args(annotation)[1:]:
if hasattr(metadata, '__sqlmodel_sa_type__'):
return metadata.__sqlmodel_sa_type__
# ... handle ForwardRef, ClassVar, etc.
return _original_get_sqlalchemy_type(field)
```
### Supported Patterns
#### Pattern 1: Direct Custom Type Usage
```python
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
from sqlmodels.mixin import UUIDTableBaseMixin
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
"""Voice embedding - sa_type automatically extracted"""
```
#### Pattern 2: Annotated Wrapper
```python
from typing import Annotated
from sqlmodels.mixin import UUIDTableBaseMixin
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: EmbeddingVector
```
#### Pattern 3: Array Type
```python
from sqlmodels.sqlmodel_types.dialects.postgresql import Array
from sqlmodels.mixin import TableBaseMixin
class ServerConfig(TableBaseMixin, table=True):
protocols: Array[ProtocolEnum]
"""Allowed protocols - sa_type from Array handler"""
```
### Migration from Python 3.13
**No code changes required!** The implementation is transparent:
- Uses `typing.get_type_hints()` which works in both Python 3.13 and 3.14
- Custom types already use `__sqlmodel_sa_type__` attribute
- Monkey-patch only activates for Python 3.14+
---
## Core Component
### SQLModelBase
`SQLModelBase` is the root base class for all SQLModel models. It uses a custom metaclass (`__DeclarativeMeta`) that provides advanced features beyond standard SQLModel capabilities.
**Key Features**:
- Automatic `use_attribute_docstrings` configuration (use docstrings instead of `Field(description=...)`)
- Automatic `validate_by_name` configuration
- Custom metaclass for sa_type injection and polymorphic setup
- Integration with Pydantic v2
- Python 3.14 PEP 649 compatibility
**Usage**:
```python
from sqlmodels.base import SQLModelBase
class UserBase(SQLModelBase):
name: str
"""User's display name"""
email: str
"""User's email address"""
```
**Important Notes**:
- Use **docstrings** for field descriptions, not `Field(description=...)`
- Do NOT override `model_config` in subclasses (it's already configured in SQLModelBase)
- This class should be used for non-table models (DTOs, request/response models)
**For table models**, use mixins from `sqlmodels.mixin`:
- `TableBaseMixin` - Integer primary key with timestamps
- `UUIDTableBaseMixin` - UUID primary key with timestamps
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete table mixin documentation.
---
## Metaclass Features
### Automatic sa_type Injection
The metaclass automatically extracts SQLAlchemy types from custom type annotations, enabling clean syntax for complex database types.
**Before** (verbose):
```python
from sqlmodels.sqlmodel_types.dialects.postgresql.numpy_vector import _NumpyVectorSQLAlchemyType
from sqlmodels.mixin import UUIDTableBaseMixin
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: np.ndarray = Field(
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
)
```
**After** (clean):
```python
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
from sqlmodels.mixin import UUIDTableBaseMixin
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
"""Speaker voice embedding"""
```
**How It Works**:
The metaclass uses a three-tier detection strategy:
1. **Direct `__sqlmodel_sa_type__` attribute** (Priority 1)
```python
if hasattr(annotation, '__sqlmodel_sa_type__'):
return annotation.__sqlmodel_sa_type__
```
2. **Annotated metadata** (Priority 2)
```python
# For Annotated[np.ndarray, NumpyVector[256, np.float32]]
if get_origin(annotation) is typing.Annotated:
for item in metadata_items:
if hasattr(item, '__sqlmodel_sa_type__'):
return item.__sqlmodel_sa_type__
```
3. **Pydantic Core Schema metadata** (Priority 3)
```python
schema = annotation.__get_pydantic_core_schema__(...)
if schema['metadata'].get('sa_type'):
return schema['metadata']['sa_type']
```
After extracting `sa_type`, the metaclass:
- Creates `Field(sa_type=sa_type)` if no Field is defined
- Injects `sa_type` into existing Field if not already set
- Respects explicit `Field(sa_type=...)` (no override)
**Supported Patterns**:
```python
from sqlmodels.mixin import UUIDTableBaseMixin
# Pattern 1: Direct usage (recommended)
class Model(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
# Pattern 2: With Field constraints
class Model(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32] = Field(nullable=False)
# Pattern 3: Annotated wrapper
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
class Model(UUIDTableBaseMixin, table=True):
embedding: EmbeddingVector
# Pattern 4: Explicit sa_type (override)
class Model(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32] = Field(
sa_type=_NumpyVectorSQLAlchemyType(128, np.float16)
)
```
### Table Configuration
The metaclass provides smart defaults and flexible configuration:
**Automatic `table=True`**:
```python
# Classes inheriting from TableBaseMixin automatically get table=True
from sqlmodels.mixin import UUIDTableBaseMixin
class MyModel(UUIDTableBaseMixin): # table=True is automatic
pass
```
**Convenient mapper arguments**:
```python
# Instead of verbose __mapper_args__
from sqlmodels.mixin import UUIDTableBaseMixin
class MyModel(
UUIDTableBaseMixin,
polymorphic_on='_polymorphic_name',
polymorphic_abstract=True
):
pass
# Equivalent to:
class MyModel(UUIDTableBaseMixin):
__mapper_args__ = {
'polymorphic_on': '_polymorphic_name',
'polymorphic_abstract': True
}
```
**Smart merging**:
```python
# Dictionary and keyword arguments are merged
from sqlmodels.mixin import UUIDTableBaseMixin
class MyModel(
UUIDTableBaseMixin,
mapper_args={'version_id_col': 'version'},
polymorphic_on='type' # Merged into __mapper_args__
):
pass
```
### Polymorphic Support
The metaclass supports SQLAlchemy's joined table inheritance through convenient parameters:
**Supported parameters**:
- `polymorphic_on`: Discriminator column name
- `polymorphic_identity`: Identity value for this class
- `polymorphic_abstract`: Whether this is an abstract base
- `table_args`: SQLAlchemy table arguments
- `table_name`: Override table name (becomes `__tablename__`)
**For complete polymorphic inheritance patterns**, including `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
---
## Custom Types Integration
### Using NumpyVector
The `NumpyVector` type demonstrates automatic sa_type injection:
```python
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
from sqlmodels.mixin import UUIDTableBaseMixin
import numpy as np
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
"""Speaker voice embedding - sa_type automatically injected"""
```
**How NumpyVector works**:
```python
# NumpyVector[dims, dtype] returns a class with:
class _NumpyVectorType:
__sqlmodel_sa_type__ = _NumpyVectorSQLAlchemyType(dimensions, dtype)
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
return handler.generate_schema(np.ndarray)
```
This dual approach ensures:
1. Metaclass can extract `sa_type` via `__sqlmodel_sa_type__`
2. Pydantic can validate as `np.ndarray`
### Creating Custom SQLAlchemy Types
To create types that work with automatic injection, provide one of:
**Option 1: `__sqlmodel_sa_type__` attribute** (preferred):
```python
from sqlalchemy import TypeDecorator, String
class UpperCaseString(TypeDecorator):
impl = String
def process_bind_param(self, value, dialect):
return value.upper() if value else value
class UpperCaseType:
__sqlmodel_sa_type__ = UpperCaseString()
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
return core_schema.str_schema()
# Usage
from sqlmodels.mixin import UUIDTableBaseMixin
class MyModel(UUIDTableBaseMixin, table=True):
code: UpperCaseType # Automatically uses UpperCaseString()
```
**Option 2: Pydantic metadata with sa_type**:
```python
def __get_pydantic_core_schema__(self, source_type, handler):
return core_schema.json_or_python_schema(
json_schema=core_schema.str_schema(),
python_schema=core_schema.str_schema(),
metadata={'sa_type': UpperCaseString()}
)
```
**Option 3: Using Annotated**:
```python
from typing import Annotated
from sqlmodels.mixin import UUIDTableBaseMixin
UpperCase = Annotated[str, UpperCaseType()]
class MyModel(UUIDTableBaseMixin, table=True):
code: UpperCase
```
---
## Best Practices
### 1. Inherit from correct base classes
```python
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
# ✅ For non-table models (DTOs, requests, responses)
class UserBase(SQLModelBase):
name: str
# ✅ For table models with UUID primary key
class User(UserBase, UUIDTableBaseMixin, table=True):
email: str
# ✅ For table models with custom primary key
class LegacyUser(TableBaseMixin, table=True):
id: int = Field(primary_key=True)
username: str
```
### 2. Use docstrings for field descriptions
```python
from sqlmodels.mixin import UUIDTableBaseMixin
# ✅ Recommended
class User(UUIDTableBaseMixin, table=True):
name: str
"""User's display name"""
# ❌ Avoid
class User(UUIDTableBaseMixin, table=True):
name: str = Field(description="User's display name")
```
**Why?** SQLModelBase has `use_attribute_docstrings=True`, so docstrings automatically become field descriptions in API docs.
### 3. Leverage automatic sa_type injection
```python
from sqlmodels.mixin import UUIDTableBaseMixin
# ✅ Clean and recommended
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
"""Voice embedding"""
# ❌ Verbose and unnecessary
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: np.ndarray = Field(
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
)
```
### 4. Follow polymorphic naming conventions
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete polymorphic inheritance patterns using `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`.
### 5. Separate Base, Parent, and Implementation classes
```python
from abc import ABC, abstractmethod
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import UUIDTableBaseMixin, PolymorphicBaseMixin
# ✅ Recommended structure
class ASRBase(SQLModelBase):
"""Pure data fields, no table"""
name: str
base_url: str
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
"""Abstract parent with table"""
@abstractmethod
async def transcribe(self, audio: bytes) -> str:
pass
class WhisperASR(ASR, table=True):
"""Concrete implementation"""
model_size: str
async def transcribe(self, audio: bytes) -> str:
# Implementation
pass
```
**Why?**
- Base class can be reused for DTOs
- Parent class defines the polymorphic hierarchy
- Implementation classes are clean and focused
---
## Troubleshooting
### Issue: ValueError: X has no matching SQLAlchemy type
**Solution**: Ensure your custom type provides `__sqlmodel_sa_type__` attribute or proper Pydantic metadata with `sa_type`.
```python
# ✅ Provide __sqlmodel_sa_type__
class MyType:
__sqlmodel_sa_type__ = MyCustomSQLAlchemyType()
```
### Issue: Can't generate DDL for NullType()
**Symptoms**: Error during table creation saying a column has `NullType`.
**Root Cause**: Custom type's `sa_type` not detected by SQLModel.
**Solution**:
1. Ensure your type has `__sqlmodel_sa_type__` class attribute
2. Check that the monkey-patch is active (`sys.version_info >= (3, 14)`)
3. Verify type annotation is correct (not a string forward reference)
```python
from sqlmodels.mixin import UUIDTableBaseMixin
# ✅ Correct
class Model(UUIDTableBaseMixin, table=True):
data: NumpyVector[256, np.float32] # __sqlmodel_sa_type__ detected
# ❌ Wrong (string annotation)
class Model(UUIDTableBaseMixin, table=True):
data: 'NumpyVector[256, np.float32]' # sa_type lost
```
### Issue: Polymorphic identity conflicts
**Symptoms**: SQLAlchemy raises errors about duplicate polymorphic identities.
**Solution**:
1. Check that each concrete class has a unique identity
2. Use `AutoPolymorphicIdentityMixin` for automatic naming
3. Manually specify identity if needed:
```python
class MyClass(Parent, polymorphic_identity='unique.name', table=True):
pass
```
### Issue: Python 3.14 annotation errors
**Symptoms**: Errors related to `__annotations__` or type resolution.
**Solution**: The implementation uses `get_type_hints()` which handles PEP 649 automatically. If issues persist:
1. Check for manual `__annotations__` manipulation (avoid it)
2. Ensure all types are properly imported
3. Avoid `from __future__ import annotations` (can cause SQLModel issues)
### Issue: Polymorphic and CRUD-related errors
For issues related to polymorphic inheritance, CRUD operations, or table mixins, see the troubleshooting section in [`sqlmodels/mixin/README.md`](../mixin/README.md).
---
## Implementation Details
For developers modifying this module:
**Core files**:
- `sqlmodel_base.py` - Contains `__DeclarativeMeta` and `SQLModelBase`
- `../mixin/table.py` - Contains `TableBaseMixin` and `UUIDTableBaseMixin`
- `../mixin/polymorphic.py` - Contains `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`
**Key functions in this module**:
1. **`_resolve_annotations(attrs: dict[str, Any])`**
- Uses `typing.get_type_hints()` for Python 3.14 compatibility
- Returns tuple: `(annotations, annotation_strings, globalns, localns)`
- Preserves `Annotated` metadata with `include_extras=True`
2. **`_extract_sa_type_from_annotation(annotation: Any) -> Any | None`**
- Extracts SQLAlchemy type from type annotations
- Supports `__sqlmodel_sa_type__`, `Annotated`, and Pydantic core schema
- Called by metaclass during class creation
3. **`_patched_get_sqlalchemy_type(field)`** (Python 3.14+)
- Global monkey-patch for SQLModel
- Checks `__sqlmodel_sa_type__` before falling back to original logic
- Handles custom types like `NumpyVector` and `Array`
4. **`__DeclarativeMeta.__new__()`**
- Processes class definition parameters
- Injects `sa_type` into field definitions
- Sets up `__mapper_args__`, `__table_args__`, etc.
- Handles Python 3.14 annotations via `get_type_hints()`
**Metaclass processing order**:
1. Check if class should be a table (`_is_table_mixin`)
2. Collect `__mapper_args__` from kwargs and explicit dict
3. Process `table_args`, `table_name`, `abstract` parameters
4. Resolve annotations using `get_type_hints()`
5. For each field, try to extract `sa_type` and inject into Field
6. Call parent metaclass with cleaned kwargs
For table mixin implementation details, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
---
## See Also
**Project Documentation**:
- [SQLModel Mixin Documentation](../mixin/README.md) - Table mixins, CRUD operations, polymorphic patterns
- [Project Coding Standards (CLAUDE.md)](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/CLAUDE.md)
- [Custom SQLModel Types Guide](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/sqlmodels/sqlmodel_types/README.md)
**External References**:
- [SQLAlchemy Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
- [Pydantic V2 Documentation](https://docs.pydantic.dev/latest/)
- [SQLModel Documentation](https://sqlmodel.tiangolo.com/)
- [PEP 649: Deferred Evaluation of Annotations](https://peps.python.org/pep-0649/)
- [PEP 749: Implementing PEP 649](https://peps.python.org/pep-0749/)
- [Python Annotations Best Practices](https://docs.python.org/3/howto/annotations.html)

View File

@@ -1,2 +1,12 @@
"""
SQLModel 基础模块
包含:
- SQLModelBase: 所有 SQLModel 类的基类(真正的基类)
注意:
TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 models.mixin
为了避免循环导入,此处不再重新导出它们
请直接从 models.mixin 导入这些类
"""
from .sqlmodel_base import SQLModelBase
from .table_base import TableBase, UUIDTableBase, now, now_date

View File

@@ -1,5 +1,846 @@
from pydantic import ConfigDict
from sqlmodel import SQLModel
import sys
import typing
from typing import Any, Mapping, get_args, get_origin, get_type_hints
from pydantic import ConfigDict
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined as Undefined
from sqlalchemy.orm import Mapped
from sqlmodel import Field, SQLModel
from sqlmodel.main import SQLModelMetaclass
# Python 3.14+ PEP 649支持
if sys.version_info >= (3, 14):
import annotationlib
# 全局Monkey-patch: 修复SQLModel在Python 3.14上的兼容性问题
import sqlmodel.main
_original_get_sqlalchemy_type = sqlmodel.main.get_sqlalchemy_type
def _patched_get_sqlalchemy_type(field):
"""
修复SQLModel的get_sqlalchemy_type函数处理Python 3.14的类型问题。
问题:
1. ForwardRef对象来自Relationship字段会导致issubclass错误
2. typing._GenericAlias对象如ClassVar[T])也会导致同样问题
3. list/dict等泛型类型在没有Field/Relationship时可能导致错误
4. Mapped类型在Python 3.14下可能出现在annotation中
5. Annotated类型可能包含sa_type metadata如Array[T]
6. 自定义类型如NumpyVector有__sqlmodel_sa_type__属性
7. Pydantic已处理的Annotated类型会将metadata存储在field.metadata中
解决:
- 优先检查field.metadata中的__get_pydantic_core_schema__Pydantic已处理的情况
- 检测__sqlmodel_sa_type__属性NumpyVector等
- 检测Relationship/ClassVar等返回None
- 对于Annotated类型尝试提取sa_type metadata
- 其他情况调用原始函数
"""
# 优先检查 field.metadataPydantic已处理Annotated类型的情况
# 当使用 Array[T] 或 Annotated[T, metadata] 时Pydantic会将metadata存储在这里
metadata = getattr(field, 'metadata', None)
if metadata:
# metadata是一个列表包含所有Annotated的元数据项
for metadata_item in metadata:
# 检查metadata_item是否有__get_pydantic_core_schema__方法
if hasattr(metadata_item, '__get_pydantic_core_schema__'):
try:
# 调用获取schema
schema = metadata_item.__get_pydantic_core_schema__(None, None)
# 检查schema的metadata中是否有sa_type
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError):
# Pydantic schema获取可能失败类型不匹配、缺少属性等
# 这是正常情况继续检查下一个metadata项
pass
annotation = getattr(field, 'annotation', None)
if annotation is not None:
# 优先检查 __sqlmodel_sa_type__ 属性
# 这处理 NumpyVector[dims, dtype] 等自定义类型
if hasattr(annotation, '__sqlmodel_sa_type__'):
return annotation.__sqlmodel_sa_type__
# 检查自定义类型如JSON100K的 __get_pydantic_core_schema__ 方法
# 这些类型在schema的metadata中定义sa_type
if hasattr(annotation, '__get_pydantic_core_schema__'):
try:
# 调用获取schema传None作为handler因为我们只需要metadata
schema = annotation.__get_pydantic_core_schema__(annotation, lambda x: None)
# 检查schema的metadata中是否有sa_type
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError):
# Schema获取失败继续其他检查
pass
anno_type_name = type(annotation).__name__
# ForwardRef: Relationship字段的annotation
if anno_type_name == 'ForwardRef':
return None
# AnnotatedAlias: 检查是否有sa_type metadata如Array[T]
if anno_type_name == 'AnnotatedAlias' or anno_type_name == '_AnnotatedAlias':
from typing import get_origin, get_args
import typing
# 尝试提取Annotated的metadata
if hasattr(typing, 'get_args'):
args = get_args(annotation)
# args[0]是实际类型args[1:]是metadata
for metadata in args[1:]:
# 检查metadata是否有__get_pydantic_core_schema__方法
if hasattr(metadata, '__get_pydantic_core_schema__'):
try:
# 调用获取schema
schema = metadata.__get_pydantic_core_schema__(None, None)
# 检查schema中是否有sa_type
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError):
# Annotated metadata的schema获取可能失败
# 这是正常的类型检查过程继续检查下一个metadata
pass
# _GenericAlias或GenericAlias: typing泛型类型
if anno_type_name in ('_GenericAlias', 'GenericAlias'):
from typing import get_origin
import typing
origin = get_origin(annotation)
# ClassVar必须跳过
if origin is typing.ClassVar:
return None
# list/dict/tuple/set等内置泛型如果字段没有明确的Field或Relationship也跳过
# 这通常意味着它是Relationship字段或类变量
if origin in (list, dict, tuple, set):
# 检查field_info是否存在且有意义
# Relationship字段会有特殊的field_info
field_info = getattr(field, 'field_info', None)
if field_info is None:
return None
# Mapped: SQLAlchemy 2.0的Mapped类型SQLModel不应该处理
# 这可能是从父类继承的字段或Python 3.14注解处理的副作用
# 检查类型名称和annotation的字符串表示
if 'Mapped' in anno_type_name or 'Mapped' in str(annotation):
return None
# 检查annotation是否是Mapped类或其实例
try:
from sqlalchemy.orm import Mapped as SAMapped
# 检查origin对于Mapped[T]这种泛型)
from typing import get_origin
if get_origin(annotation) is SAMapped:
return None
# 检查类型本身
if annotation is SAMapped or isinstance(annotation, type) and issubclass(annotation, SAMapped):
return None
except (ImportError, TypeError):
# 如果SQLAlchemy没有Mapped或检查失败继续
pass
# 其他情况正常处理
return _original_get_sqlalchemy_type(field)
sqlmodel.main.get_sqlalchemy_type = _patched_get_sqlalchemy_type
# 第二个Monkey-patch: 修复继承表类中InstrumentedAttribute作为默认值的问题
# 在Python 3.14 + SQLModel组合下当子类如SMSBaoProvider继承父类如VerificationCodeProvider
# 父类的关系字段如server_config会在子类的model_fields中出现
# 但其default值错误地设置为InstrumentedAttribute对象而不是None
# 这导致实例化时尝试设置InstrumentedAttribute为字段值触发SQLAlchemy内部错误
import sqlmodel._compat as _compat
from sqlalchemy.orm import attributes as _sa_attributes
_original_sqlmodel_table_construct = _compat.sqlmodel_table_construct
def _patched_sqlmodel_table_construct(self_instance, values):
"""
修复sqlmodel_table_construct跳过InstrumentedAttribute默认值
问题:
- 继承自polymorphic基类的表类如FishAudioTTS, SMSBaoProvider
- 其model_fields中的继承字段default值为InstrumentedAttribute
- 原函数尝试将InstrumentedAttribute设置为字段值
- SQLAlchemy无法处理抛出 '_sa_instance_state' 错误
解决:
- 只设置用户提供的值和非InstrumentedAttribute默认值
- InstrumentedAttribute默认值跳过让SQLAlchemy自己处理
"""
cls = type(self_instance)
# 收集要设置的字段值
fields_to_set = {}
for name, field in cls.model_fields.items():
# 如果用户提供了值,直接使用
if name in values:
fields_to_set[name] = values[name]
continue
# 否则检查默认值
# 跳过InstrumentedAttribute默认值 - 这些是继承字段的错误默认值
if isinstance(field.default, _sa_attributes.InstrumentedAttribute):
continue
# 使用正常的默认值
if field.default is not Undefined:
fields_to_set[name] = field.default
elif field.default_factory is not None:
fields_to_set[name] = field.get_default(call_default_factory=True)
# 设置属性 - 只设置非InstrumentedAttribute值
for key, value in fields_to_set.items():
if not isinstance(value, _sa_attributes.InstrumentedAttribute):
setattr(self_instance, key, value)
# 设置Pydantic内部属性
object.__setattr__(self_instance, '__pydantic_fields_set__', set(values.keys()))
if not cls.__pydantic_root_model__:
_extra = None
if cls.model_config.get('extra') == 'allow':
_extra = {}
for k, v in values.items():
if k not in cls.model_fields:
_extra[k] = v
object.__setattr__(self_instance, '__pydantic_extra__', _extra)
if cls.__pydantic_post_init__:
self_instance.model_post_init(None)
elif not cls.__pydantic_root_model__:
object.__setattr__(self_instance, '__pydantic_private__', None)
# 设置关系
for key in self_instance.__sqlmodel_relationships__:
value = values.get(key, Undefined)
if value is not Undefined:
setattr(self_instance, key, value)
return self_instance
_compat.sqlmodel_table_construct = _patched_sqlmodel_table_construct
else:
annotationlib = None
def _extract_sa_type_from_annotation(annotation: Any) -> Any | None:
"""
从类型注解中提取SQLAlchemy类型。
支持以下形式:
1. NumpyVector[256, np.float32] - 直接使用类型有__sqlmodel_sa_type__属性
2. Annotated[np.ndarray, NumpyVector[256, np.float32]] - Annotated包装
3. 任何有__get_pydantic_core_schema__且返回metadata['sa_type']的类型
Args:
annotation: 字段的类型注解
Returns:
提取到的SQLAlchemy类型如果没有则返回None
"""
# 方法1直接检查类型本身是否有__sqlmodel_sa_type__属性
# 这涵盖了 NumpyVector[256, np.float32] 这种直接使用的情况
if hasattr(annotation, '__sqlmodel_sa_type__'):
return annotation.__sqlmodel_sa_type__
# 方法2检查是否为Annotated类型
if get_origin(annotation) is typing.Annotated:
# 获取元数据项(跳过第一个实际类型参数)
args = get_args(annotation)
if len(args) >= 2:
metadata_items = args[1:] # 第一个是实际类型,后面都是元数据
# 遍历元数据查找包含sa_type的项
for item in metadata_items:
# 检查元数据项是否有__sqlmodel_sa_type__属性
if hasattr(item, '__sqlmodel_sa_type__'):
return item.__sqlmodel_sa_type__
# 检查是否有__get_pydantic_core_schema__方法
if hasattr(item, '__get_pydantic_core_schema__'):
try:
# 调用该方法获取core schema
schema = item.__get_pydantic_core_schema__(
annotation,
lambda x: None # 虚拟handler
)
# 检查schema的metadata中是否有sa_type
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError, ValueError):
# Pydantic core schema获取可能失败
# - TypeError: 参数不匹配
# - AttributeError: metadata不存在
# - KeyError: schema结构不符合预期
# - ValueError: 无效的类型定义
# 这是正常的类型探测过程继续检查下一个metadata项
pass
# 方法3检查类型本身是否有__get_pydantic_core_schema__
# 虽然NumpyVector已经在方法1处理但这是通用的fallback
if hasattr(annotation, '__get_pydantic_core_schema__'):
try:
schema = annotation.__get_pydantic_core_schema__(
annotation,
lambda x: None # 虚拟handler
)
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError, ValueError):
# 类型本身的schema获取失败
# 这是正常的fallback机制annotation可能不支持此协议
pass
return None
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[
dict[str, Any],
dict[str, str],
Mapping[str, Any],
Mapping[str, Any],
]:
"""
Resolve annotations from a class namespace with Python 3.14 (PEP 649) support.
This helper prefers evaluated annotations (Format.VALUE) so that `typing.Annotated`
metadata and custom types remain accessible. Forward references that cannot be
evaluated are replaced with typing.ForwardRef placeholders to avoid aborting the
whole resolution process.
"""
raw_annotations = attrs.get('__annotations__') or {}
try:
base_annotations = dict(raw_annotations)
except TypeError:
base_annotations = {}
module_name = attrs.get('__module__')
module_globals: dict[str, Any]
if module_name and module_name in sys.modules:
module_globals = dict(sys.modules[module_name].__dict__)
else:
module_globals = {}
module_globals.setdefault('__builtins__', __builtins__)
localns: dict[str, Any] = dict(attrs)
try:
temp_cls = type('AnnotationProxy', (object,), dict(attrs))
temp_cls.__module__ = module_name
extras_kw = {'include_extras': True} if sys.version_info >= (3, 10) else {}
evaluated = get_type_hints(
temp_cls,
globalns=module_globals,
localns=localns,
**extras_kw,
)
except (NameError, AttributeError, TypeError, RecursionError):
# get_type_hints可能失败的原因
# - NameError: 前向引用无法解析(类型尚未定义)
# - AttributeError: 模块或类型不存在
# - TypeError: 无效的类型注解
# - RecursionError: 循环依赖的类型定义
# 这是正常情况,回退到原始注解字符串
evaluated = base_annotations
return dict(evaluated), {}, module_globals, localns
def _evaluate_annotation_from_string(
field_name: str,
annotation_strings: dict[str, str],
current_type: Any,
globalns: Mapping[str, Any],
localns: Mapping[str, Any],
) -> Any:
"""
Attempt to re-evaluate the original annotation string for a field.
This is used as a fallback when the resolved annotation lost its metadata
(e.g., Annotated wrappers) and we need to recover custom sa_type data.
"""
if not annotation_strings:
return current_type
expr = annotation_strings.get(field_name)
if not expr or not isinstance(expr, str):
return current_type
try:
return eval(expr, globalns, localns)
except (NameError, SyntaxError, AttributeError, TypeError):
# eval可能失败的原因
# - NameError: 类型名称在namespace中不存在
# - SyntaxError: 注解字符串有语法错误
# - AttributeError: 访问不存在的模块属性
# - TypeError: 无效的类型表达式
# 这是正常的fallback机制返回当前已解析的类型
return current_type
class __DeclarativeMeta(SQLModelMetaclass):
"""
一个智能的混合模式元类,它提供了灵活性和清晰度:
1. **自动设置 `table=True`**: 如果一个类继承了 `TableBaseMixin`,则自动应用 `table=True`。
2. **明确的字典参数**: 支持 `mapper_args={...}`, `table_args={...}`, `table_name='...'`。
3. **便捷的关键字参数**: 支持最常见的 mapper 参数作为顶级关键字(如 `polymorphic_on`)。
4. **智能合并**: 当字典和关键字同时提供时,会自动合并,且关键字参数有更高优先级。
"""
_KNOWN_MAPPER_KEYS = {
"polymorphic_on",
"polymorphic_identity",
"polymorphic_abstract",
"version_id_col",
"concrete",
}
def __new__(cls, name, bases, attrs, **kwargs):
# 1. 约定优于配置:自动设置 table=True
is_intended_as_table = any(getattr(b, '_is_table_mixin', False) for b in bases)
if is_intended_as_table and 'table' not in kwargs:
kwargs['table'] = True
# 2. 智能合并 __mapper_args__
collected_mapper_args = {}
# 首先,处理明确的 mapper_args 字典 (优先级较低)
if 'mapper_args' in kwargs:
collected_mapper_args.update(kwargs.pop('mapper_args'))
# 其次,处理便捷的关键字参数 (优先级更高)
for key in cls._KNOWN_MAPPER_KEYS:
if key in kwargs:
# .pop() 获取值并移除,避免传递给父类
collected_mapper_args[key] = kwargs.pop(key)
# 如果收集到了任何 mapper 参数,则更新到类的属性中
if collected_mapper_args:
existing = attrs.get('__mapper_args__', {}).copy()
existing.update(collected_mapper_args)
attrs['__mapper_args__'] = existing
# 3. 处理其他明确的参数
if 'table_args' in kwargs:
attrs['__table_args__'] = kwargs.pop('table_args')
if 'table_name' in kwargs:
attrs['__tablename__'] = kwargs.pop('table_name')
if 'abstract' in kwargs:
attrs['__abstract__'] = kwargs.pop('abstract')
# 4. 从Annotated元数据中提取sa_type并注入到Field
# 重要必须在调用父类__new__之前处理因为SQLModel会消费annotations
#
# Python 3.14兼容性问题:
# - SQLModel在Python 3.14上会因为ClassVar[T]类型而崩溃issubclass错误
# - 我们必须在SQLModel看到annotations之前过滤掉ClassVar字段
# - 虽然PEP 749建议不修改__annotations__但这是修复SQLModel bug的必要措施
#
# 获取annotations的策略
# - Python 3.14+: 优先从__annotate__获取如果存在
# - fallback: 从__annotations__读取如果存在
# - 最终fallback: 空字典
annotations, annotation_strings, eval_globals, eval_locals = _resolve_annotations(attrs)
if annotations:
attrs['__annotations__'] = annotations
if annotationlib is not None:
# 在Python 3.14中禁用descriptor转为普通dict
attrs['__annotate__'] = None
for field_name, field_type in annotations.items():
field_type = _evaluate_annotation_from_string(
field_name,
annotation_strings,
field_type,
eval_globals,
eval_locals,
)
# 跳过字符串或ForwardRef类型注解让SQLModel自己处理
if isinstance(field_type, str) or isinstance(field_type, typing.ForwardRef):
continue
# 跳过特殊类型的字段
origin = get_origin(field_type)
# 跳过 ClassVar 字段 - 它们不是数据库字段
if origin is typing.ClassVar:
continue
# 跳过 Mapped 字段 - SQLAlchemy 2.0+ 的声明式字段,已经有 mapped_column
if origin is Mapped:
continue
# 尝试从注解中提取sa_type
sa_type = _extract_sa_type_from_annotation(field_type)
if sa_type is not None:
# 检查字段是否已有Field定义
field_value = attrs.get(field_name, Undefined)
if field_value is Undefined:
# 没有Field定义创建一个新的Field并注入sa_type
attrs[field_name] = Field(sa_type=sa_type)
elif isinstance(field_value, FieldInfo):
# 已有Field定义检查是否已设置sa_type
# 注意:只有在未设置时才注入,尊重显式配置
# SQLModel使用Undefined作为"未设置"的标记
if not hasattr(field_value, 'sa_type') or field_value.sa_type is Undefined:
field_value.sa_type = sa_type
# 如果field_value是其他类型如默认值不处理
# SQLModel会在后续处理中将其转换为Field
# 5. 调用父类的 __new__ 方法,传入被清理过的 kwargs
result = super().__new__(cls, name, bases, attrs, **kwargs)
# 6. 修复:在联表继承场景下,继承父类的 __sqlmodel_relationships__
# SQLModel 为每个 table=True 的类创建新的空 __sqlmodel_relationships__
# 这导致子类丢失父类的关系定义,触发错误的 Column 创建
# 必须在 super().__new__() 之后修复,因为 SQLModel 会覆盖我们预设的值
if kwargs.get('table', False):
for base in bases:
if hasattr(base, '__sqlmodel_relationships__'):
for rel_name, rel_info in base.__sqlmodel_relationships__.items():
# 只继承子类没有重新定义的关系
if rel_name not in result.__sqlmodel_relationships__:
result.__sqlmodel_relationships__[rel_name] = rel_info
# 同时修复被错误创建的 Column - 恢复为父类的 relationship
if hasattr(base, rel_name):
base_attr = getattr(base, rel_name)
setattr(result, rel_name, base_attr)
# 7. 检测:禁止子类重定义父类的 Relationship 字段
# 子类重定义同名的 Relationship 字段会导致 SQLAlchemy 关系映射混乱,
# 应该在类定义时立即报错,而不是在运行时出现难以调试的问题。
for base in bases:
parent_relationships = getattr(base, '__sqlmodel_relationships__', {})
for rel_name in parent_relationships:
# 检查当前类是否在 attrs 中重新定义了这个关系字段
if rel_name in attrs:
raise TypeError(
f"{name} 不允许重定义父类 {base.__name__} 的 Relationship 字段 '{rel_name}'"
f"如需修改关系配置,请在父类中修改。"
)
# 8. 修复:从 model_fields/__pydantic_fields__ 中移除 Relationship 字段
# SQLModel 0.0.27 bug子类会错误地继承父类的 Relationship 字段到 model_fields
# 这导致 Pydantic 尝试为 Relationship 字段生成 schema因为类型是
# Mapped[list['Character']] 这种前向引用Pydantic 无法解析,
# 导致 __pydantic_complete__ = False
#
# 修复策略:
# - 检查类的 __sqlmodel_relationships__ 属性
# - 从 model_fields 和 __pydantic_fields__ 中移除这些字段
# - Relationship 字段由 SQLAlchemy 管理,不需要 Pydantic 参与
relationships = getattr(result, '__sqlmodel_relationships__', {})
if relationships:
model_fields = getattr(result, 'model_fields', {})
pydantic_fields = getattr(result, '__pydantic_fields__', {})
fields_removed = False
for rel_name in relationships:
if rel_name in model_fields:
del model_fields[rel_name]
fields_removed = True
if rel_name in pydantic_fields:
del pydantic_fields[rel_name]
fields_removed = True
# 如果移除了字段,重新构建 Pydantic 模式
# 注意:只在有字段被移除时才 rebuild避免不必要的开销
if fields_removed and hasattr(result, 'model_rebuild'):
result.model_rebuild(force=True)
return result
def __init__(
cls,
classname: str,
bases: tuple[type, ...],
dict_: dict[str, typing.Any],
**kw: typing.Any,
) -> None:
"""
重写 SQLModel 的 __init__ 以支持联表继承Joined Table Inheritance
SQLModel 原始行为:
- 如果任何基类是表模型,则不调用 DeclarativeMeta.__init__
- 这阻止了子类创建自己的表
修复逻辑:
- 检测联表继承场景(子类有自己的 __tablename__ 且有外键指向父表)
- 强制调用 DeclarativeMeta.__init__ 来创建子表
"""
from sqlmodel.main import is_table_model_class, DeclarativeMeta, ModelMetaclass
# 检查是否是表模型
if not is_table_model_class(cls):
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
return
# 检查是否有基类是表模型
base_is_table = any(is_table_model_class(base) for base in bases)
if not base_is_table:
# 没有基类是表模型,走正常的 SQLModel 流程
# 处理关系字段
cls._setup_relationships()
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
return
# 关键:检测联表继承场景
# 条件:
# 1. 当前类的 __tablename__ 与父类不同(表示需要新表)
# 2. 当前类有字段带有 foreign_key 指向父表
current_tablename = getattr(cls, '__tablename__', None)
# 查找父表信息
parent_table = None
parent_tablename = None
for base in bases:
if is_table_model_class(base) and hasattr(base, '__tablename__'):
parent_tablename = base.__tablename__
break
# 检查是否有不同的 tablename
has_different_tablename = (
current_tablename is not None
and parent_tablename is not None
and current_tablename != parent_tablename
)
# 检查是否有外键字段指向父表的主键
# 注意:由于字段合并,我们需要检查直接基类的 model_fields
# 而不是当前类的合并后的 model_fields
has_fk_to_parent = False
def _normalize_tablename(name: str) -> str:
"""标准化表名以进行比较(移除下划线,转小写)"""
return name.replace('_', '').lower()
def _fk_matches_parent(fk_str: str, parent_table: str) -> bool:
"""检查 FK 字符串是否指向父表"""
if not fk_str or not parent_table:
return False
# FK 格式: "tablename.column" 或 "schema.tablename.column"
parts = fk_str.split('.')
if len(parts) >= 2:
fk_table = parts[-2] # 取倒数第二个作为表名
# 标准化比较(处理下划线差异)
return _normalize_tablename(fk_table) == _normalize_tablename(parent_table)
return False
if has_different_tablename and parent_tablename:
# 首先检查当前类的 model_fields
for field_name, field_info in cls.model_fields.items():
fk = getattr(field_info, 'foreign_key', None)
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
has_fk_to_parent = True
break
# 如果没找到,检查直接基类的 model_fields解决 mixin 字段被覆盖的问题)
if not has_fk_to_parent:
for base in bases:
if hasattr(base, 'model_fields'):
for field_name, field_info in base.model_fields.items():
fk = getattr(field_info, 'foreign_key', None)
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
has_fk_to_parent = True
break
if has_fk_to_parent:
break
is_joined_inheritance = has_different_tablename and has_fk_to_parent
if is_joined_inheritance:
# 联表继承:需要创建子表
# 修复外键字段:由于字段合并,外键信息可能丢失
# 需要从基类的 mixin 中找回外键信息,并重建列
from sqlalchemy import Column, ForeignKey, inspect as sa_inspect
from sqlalchemy.dialects.postgresql import UUID as SA_UUID
from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.orm.attributes import InstrumentedAttribute
# 联表继承:子表只应该有 idFK 到父表)+ 子类特有的字段
# 所有继承自祖先表的列都不应该在子表中重复创建
# 收集整个继承链中所有祖先表的列名(这些列不应该在子表中重复)
# 需要遍历整个 MRO因为可能是多级继承如 Tool -> Function -> GetWeatherFunction
ancestor_column_names: set[str] = set()
for ancestor in cls.__mro__:
if ancestor is cls:
continue # 跳过当前类
if is_table_model_class(ancestor):
try:
# 使用 inspect() 获取 mapper 的公开属性
# 源码确认: mapper.local_table 是公开属性 (mapper.py:979-998)
mapper = sa_inspect(ancestor)
for col in mapper.local_table.columns:
# 跳过 _polymorphic_name 列(鉴别器,由根父表管理)
if col.name.startswith('_polymorphic'):
continue
ancestor_column_names.add(col.name)
except NoInspectionAvailable:
continue
# 找到子类自己定义的字段(不在父类中的)
child_own_fields: set[str] = set()
for field_name in cls.model_fields:
# 检查这个字段是否是在当前类直接定义的(不是继承的)
# 通过检查父类是否有这个字段来判断
is_inherited = False
for base in bases:
if hasattr(base, 'model_fields') and field_name in base.model_fields:
is_inherited = True
break
if not is_inherited:
child_own_fields.add(field_name)
# 从子类类属性中移除父表已有的列定义
# 这样 SQLAlchemy 就不会在子表中创建这些列
fk_field_name = None
for base in bases:
if hasattr(base, 'model_fields'):
for field_name, field_info in base.model_fields.items():
fk = getattr(field_info, 'foreign_key', None)
pk = getattr(field_info, 'primary_key', False)
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
fk_field_name = field_name
# 找到了外键字段,重建它
# 创建一个新的 Column 对象包含外键约束
new_col = Column(
field_name,
SA_UUID(as_uuid=True),
ForeignKey(fk),
primary_key=pk if pk else False
)
setattr(cls, field_name, new_col)
break
else:
continue
break
# 移除继承自祖先表的列属性(除了 FK/PK 和子类自己的字段)
# 这防止 SQLAlchemy 在子表中创建重复列
# 注意:在 __init__ 阶段,列是 Column 对象,不是 InstrumentedAttribute
for col_name in ancestor_column_names:
if col_name == fk_field_name:
continue # 保留 FK/PK 列(子表的主键,同时是父表的外键)
if col_name == 'id':
continue # id 会被 FK 字段覆盖
if col_name in child_own_fields:
continue # 保留子类自己定义的字段
# 检查类属性是否是 Column 或 InstrumentedAttribute
if col_name in cls.__dict__:
attr = cls.__dict__[col_name]
# Column 对象或 InstrumentedAttribute 都需要删除
if isinstance(attr, (Column, InstrumentedAttribute)):
try:
delattr(cls, col_name)
except AttributeError:
pass
# 找到子类自己定义的关系(不在父类中的)
# 继承的关系会从父类自动获取,只需要设置子类新增的关系
child_own_relationships: set[str] = set()
for rel_name in cls.__sqlmodel_relationships__:
is_inherited = False
for base in bases:
if hasattr(base, '__sqlmodel_relationships__') and rel_name in base.__sqlmodel_relationships__:
is_inherited = True
break
if not is_inherited:
child_own_relationships.add(rel_name)
# 只为子类自己定义的新关系调用关系设置
if child_own_relationships:
cls._setup_relationships(only_these=child_own_relationships)
# 强制调用 DeclarativeMeta.__init__
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
else:
# 非联表继承:单表继承或正常 Pydantic 模型
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
def _setup_relationships(cls, only_these: set[str] | None = None) -> None:
"""
设置 SQLAlchemy 关系字段(从 SQLModel 源码复制)
Args:
only_these: 如果提供,只设置这些关系(用于 joined table inheritance 子类)
如果为 None设置所有关系默认行为
"""
from sqlalchemy.orm import relationship, Mapped
from sqlalchemy import inspect
from sqlmodel.main import get_relationship_to
from typing import get_origin
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
# 如果指定了 only_these只设置这些关系
if only_these is not None and rel_name not in only_these:
continue
if rel_info.sa_relationship:
setattr(cls, rel_name, rel_info.sa_relationship)
continue
raw_ann = cls.__annotations__[rel_name]
origin: typing.Any = get_origin(raw_ann)
if origin is Mapped:
ann = raw_ann.__args__[0]
else:
ann = raw_ann
cls.__annotations__[rel_name] = Mapped[ann]
relationship_to = get_relationship_to(
name=rel_name, rel_info=rel_info, annotation=ann
)
rel_kwargs: dict[str, typing.Any] = {}
if rel_info.back_populates:
rel_kwargs["back_populates"] = rel_info.back_populates
if rel_info.cascade_delete:
rel_kwargs["cascade"] = "all, delete-orphan"
if rel_info.passive_deletes:
rel_kwargs["passive_deletes"] = rel_info.passive_deletes
if rel_info.link_model:
ins = inspect(rel_info.link_model)
local_table = getattr(ins, "local_table")
if local_table is None:
raise RuntimeError(
f"Couldn't find secondary table for {rel_info.link_model}"
)
rel_kwargs["secondary"] = local_table
rel_args: list[typing.Any] = []
if rel_info.sa_relationship_args:
rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs)
rel_value = relationship(relationship_to, *rel_args, **rel_kwargs)
setattr(cls, rel_name, rel_value)
class SQLModelBase(SQLModel, metaclass=__DeclarativeMeta):
"""此类必须和TableBase系列类搭配使用"""
class SQLModelBase(SQLModel):
model_config = ConfigDict(use_attribute_docstrings=True, validate_by_name=True)

View File

@@ -1,203 +0,0 @@
import uuid
from datetime import datetime
from typing import Union, List, TypeVar, Type, Literal, override, Optional
from fastapi import HTTPException
from sqlalchemy import DateTime, BinaryExpression, ClauseElement
from sqlalchemy.orm import selectinload
from sqlmodel import Field, select, Relationship
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.sql._typing import _OnClauseArgument
from sqlalchemy.ext.asyncio import AsyncAttrs
from .sqlmodel_base import SQLModelBase
T = TypeVar("T", bound="TableBase")
M = TypeVar("M", bound="SQLModel")
now = lambda: datetime.now()
now_date = lambda: datetime.now().date()
class TableBase(SQLModelBase, AsyncAttrs):
id: int | None = Field(default=None, primary_key=True)
created_at: datetime = Field(default_factory=now)
updated_at: datetime = Field(
sa_type=DateTime,
sa_column_kwargs={"default": now, "onupdate": now},
default_factory=now
)
@classmethod
async def add(cls: Type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | List[T]:
"""
新增一条记录
:param session: 数据库会话
:param instances:
:param refresh:
:return: 新增的实例对象
usage:
item1 = Item(...)
item2 = Item(...)
Item.add(session, [item1, item2])
item1_id = item1.id
"""
is_list = False
if isinstance(instances, list):
is_list = True
session.add_all(instances)
else:
session.add(instances)
await session.commit()
if refresh:
if is_list:
for instance in instances:
await session.refresh(instance)
else:
await session.refresh(instances)
return instances
async def save(self: T, session: AsyncSession, load: Optional[Relationship] = None) -> T:
session.add(self)
await session.commit()
if load is not None:
cls = type(self)
return await cls.get(session, cls.id == self.id, load=load)
else:
await session.refresh(self)
return self
async def update(
self: T,
session: AsyncSession,
other: M,
extra_data: dict = None,
exclude_unset: bool = True
) -> T:
"""
更新记录
:param session: 数据库会话
:param other:
:param extra_data:
:param exclude_unset:
:return:
"""
self.sqlmodel_update(other.model_dump(exclude_unset=exclude_unset), update=extra_data)
session.add(self)
await session.commit()
await session.refresh(self)
return self
@classmethod
async def delete(cls: Type[T], session: AsyncSession, instances: T | list[T]) -> None:
"""
删除一些记录
:param session: 数据库会话
:param instances:
:return: None
usage:
item1 = Item.get(...)
item2 = Item.get(...)
Item.delete(session, [item1, item2])
"""
if isinstance(instances, list):
for instance in instances:
await session.delete(instance)
else:
await session.delete(instances)
await session.commit()
@classmethod
async def get(
cls: Type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None,
*,
offset: int | None = None,
limit: int | None = None,
fetch_mode: Literal["one", "first", "all"] = "first",
join: Type[T] | tuple[Type[T], _OnClauseArgument] | None = None,
options: list | None = None,
load: Union[Relationship, None] = None,
order_by: list[ClauseElement] | None = None
) -> T | List[T] | None:
"""
异步获取模型实例
参数:
session: 异步数据库会话
condition: SQLAlchemy查询条件如Model.id == 1
offset: 结果偏移量
limit: 结果数量限制
options: 查询选项如selectinload(Model.relation),异步访问关系属性必备,不然会报错
fetch_mode: 获取模式 - "one"/"all"/"first"
join: 要联接的模型类
返回:
根据fetch_mode返回相应的查询结果
"""
statement = select(cls)
if condition is not None:
statement = statement.where(condition)
if join is not None:
statement = statement.join(*join)
if options:
statement = statement.options(*options)
if load:
statement = statement.options(selectinload(load))
if order_by is not None:
statement = statement.order_by(*order_by)
if offset:
statement = statement.offset(offset)
if limit:
statement = statement.limit(limit)
result = await session.exec(statement)
if fetch_mode == "one":
return result.one()
elif fetch_mode == "first":
return result.first()
elif fetch_mode == "all":
return list(result.all())
else:
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
@classmethod
async def get_exist_one(cls: Type[T], session: AsyncSession, id: int, load: Union[Relationship, None] = None) -> T:
"""此方法和 await session.get(cls, 主键)的区别就是当不存在时不返回None
而是会抛出fastapi 404 异常"""
instance = await cls.get(session, cls.id == id, load=load)
if not instance:
raise HTTPException(status_code=404, detail="Not found")
return instance
class UUIDTableBase(TableBase):
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
"""override"""
@classmethod
@override
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Union[Relationship, None] = None) -> T:
return await super().get_exist_one(session, id, load) # type: ignore

View File

@@ -1,46 +1,176 @@
from typing import Optional, TYPE_CHECKING
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint
from sqlmodel import Field, Relationship, UniqueConstraint, Index
from .base import SQLModelBase
from .mixin import UUIDTableBaseMixin, TableBaseMixin
class DownloadStatus(StrEnum):
"""下载状态枚举"""
RUNNING = "running"
"""进行中"""
COMPLETED = "completed"
"""已完成"""
ERROR = "error"
"""错误"""
class DownloadType(StrEnum):
"""下载类型枚举"""
# [TODO] 补充具体下载类型
pass
from .base import SQLModelBase, UUIDTableBase
if TYPE_CHECKING:
from .user import User
from .task import Task
from .node import Node
# ==================== Aria2 信息模型 ====================
class DownloadAria2InfoBase(SQLModelBase):
"""Aria2下载信息基础模型"""
info_hash: str | None = Field(default=None, max_length=40)
"""InfoHashBT种子"""
piece_length: int = 0
"""分片大小"""
num_pieces: int = 0
"""分片数量"""
num_seeders: int = 0
"""做种人数"""
connections: int = 0
"""连接数"""
upload_speed: int = 0
"""上传速度bytes/s"""
upload_length: int = 0
"""已上传大小(字节)"""
error_code: str | None = None
"""错误代码"""
error_message: str | None = None
"""错误信息"""
class DownloadAria2Info(DownloadAria2InfoBase, SQLModelBase, table=True):
"""Aria2下载信息模型与Download一对一关联"""
download_id: UUID = Field(foreign_key="download.id", primary_key=True)
"""关联的下载任务UUID"""
# 反向关系
download: "Download" = Relationship(back_populates="aria2_info")
"""关联的下载任务"""
class DownloadAria2File(SQLModelBase, TableBaseMixin):
"""Aria2下载文件列表与Download一对多关联"""
download_id: UUID = Field(foreign_key="download.id", index=True)
"""关联的下载任务UUID"""
file_index: int = Field(ge=1)
"""文件索引从1开始"""
path: str
"""文件路径"""
length: int = 0
"""文件大小(字节)"""
completed_length: int = 0
"""已完成大小(字节)"""
is_selected: bool = True
"""是否选中下载"""
# 反向关系
download: "Download" = Relationship(back_populates="aria2_files")
"""关联的下载任务"""
# ==================== 主模型 ====================
class DownloadBase(SQLModelBase):
pass
class Download(DownloadBase, UUIDTableBase, table=True):
class Download(DownloadBase, UUIDTableBaseMixin):
"""离线下载任务模型"""
__table_args__ = (
UniqueConstraint("node_id", "g_id", name="uq_download_node_gid"),
Index("ix_download_status", "status"),
Index("ix_download_user_status", "user_id", "status"),
)
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="下载状态: 0=进行中, 1=完成, 2=错误")
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务类型")
source: str = Field(description="来源URL或标识")
total_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="总大小(字节)")
downloaded_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="已下载大小(字节)")
g_id: str | None = Field(default=None, index=True, description="Aria2 GID")
speed: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="下载速度 (bytes/s)")
parent: str | None = Field(default=None, description="父任务标识")
attrs: str | None = Field(default=None, description="额外属性 (JSON格式)")
# attrs 示例: {"gid":"65c5faf38374cc63","status":"removed","totalLength":"0","completedLength":"0","uploadLength":"0","bitfield":"","downloadSpeed":"0","uploadSpeed":"0","infoHash":"ca159db2b1e78f6e95fd972be72251f967f639d4","numSeeders":"0","seeder":"","pieceLength":"16384","numPieces":"0","connections":"0","errorCode":"31","errorMessage":"","followedBy":null,"belongsTo":"","dir":"/data/ccaaDown/aria2/7a208304-9126-46d2-ba47-a6959f236a07","files":[{"index":"1","path":"[METADATA]zh-cn_windows_11_consumer_editions_version_21h2_updated_aug_2022_x64_dvd_a29983d5.iso","length":"0","completedLength":"0","selected":"true","uris":[]}],"bittorrent":{"announceList":[["udp://tracker.opentrackr.org:1337/announce"],["udp://9.rarbg.com:2810/announce"],["udp://tracker.openbittorrent.com:6969/announce"],["https://opentracker.i2p.rocks:443/announce"],["http://tracker.openbittorrent.com:80/announce"],["udp://open.stealth.si:80/announce"],["udp://tracker.torrent.eu.org:451/announce"],["udp://exodus.desync.com:6969/announce"],["udp://tracker.tiny-vps.com:6969/announce"],["udp://tracker.pomf.se:80/announce"],["udp://tracker.moeking.me:6969/announce"],["udp://tracker.dler.org:6969/announce"],["udp://open.demonii.com:1337/announce"],["udp://explodie.org:6969/announce"],["udp://chouchou.top:8080/announce"],["udp://bt.oiyo.tk:6969/announce"],["https://tracker.nanoha.org:443/announce"],["https://tracker.lilithraws.org:443/announce"],["http://tracker3.ctix.cn:8080/announce"],["http://tracker.nucozer-tracker.ml:2710/announce"]],"comment":"","creationDate":0,"mode":"","info":{"name":""}}}
error: str | None = Field(default=None, description="错误信息")
dst: str = Field(description="目标存储路径")
status: DownloadStatus = Field(default=DownloadStatus.RUNNING, sa_column_kwargs={"server_default": "'running'"})
"""下载状态"""
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""任务类型 [TODO] 待定义枚举"""
source: str
"""来源URL或标识"""
total_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""总大小(字节)"""
downloaded_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""已下载大小(字节)"""
g_id: str | None = Field(default=None, index=True)
"""Aria2 GID"""
speed: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""下载速度bytes/s"""
parent: str | None = Field(default=None, max_length=255)
"""父任务标识"""
error: str | None = Field(default=None)
"""错误信息"""
dst: str
"""目标存储路径"""
# 外键
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
task_id: int | None = Field(default=None, foreign_key="task.id", index=True, description="关联的任务ID")
node_id: int = Field(foreign_key="node.id", index=True, description="执行下载的节点ID")
user_id: UUID = Field(foreign_key="user.id", index=True)
"""所属用户UUID"""
task_id: int | None = Field(default=None, foreign_key="task.id", index=True)
"""关联的任务ID"""
node_id: int = Field(foreign_key="node.id", index=True)
"""执行下载的节点ID"""
# 关系
aria2_info: DownloadAria2Info | None = Relationship(
back_populates="download",
sa_relationship_kwargs={"uselist": False},
)
"""Aria2下载信息"""
aria2_files: list[DownloadAria2File] = Relationship(back_populates="download")
"""Aria2文件列表"""
user: "User" = Relationship(back_populates="downloads")
task: Optional["Task"] = Relationship(back_populates="downloads")
"""所属用户"""
task: "Task" = Relationship(back_populates="downloads")
"""关联的任务"""
node: "Node" = Relationship(back_populates="downloads")
"""执行下载的节点"""

View File

@@ -4,7 +4,8 @@ from uuid import UUID
from sqlmodel import Field, Relationship, text
from .base import TableBase, SQLModelBase, UUIDTableBase
from .base import SQLModelBase
from .mixin import TableBaseMixin, UUIDTableBaseMixin
if TYPE_CHECKING:
from .user import User
@@ -75,7 +76,7 @@ class GroupResponse(GroupBase, GroupOptionsBase):
from .policy import GroupPolicyLink
class GroupOptions(GroupOptionsBase, TableBase, table=True):
class GroupOptions(GroupOptionsBase, TableBaseMixin):
"""用户组选项模型"""
group_id: UUID = Field(foreign_key="group.id", unique=True)
@@ -100,7 +101,7 @@ class GroupOptions(GroupOptionsBase, TableBase, table=True):
group: "Group" = Relationship(back_populates="options")
class Group(GroupBase, UUIDTableBase, table=True):
class Group(GroupBase, UUIDTableBaseMixin):
"""用户组模型"""
name: str = Field(max_length=255, unique=True)
@@ -134,14 +135,17 @@ class Group(GroupBase, UUIDTableBase, table=True):
)
# 关系:一个组可以有多个用户
user: list["User"] = Relationship(
users: list["User"] = Relationship(
back_populates="group",
sa_relationship_kwargs={"foreign_keys": "User.group_id"}
)
previous_user: list["User"] = Relationship(
"""当前属于该组的用户列表"""
previous_users: list["User"] = Relationship(
back_populates="previous_group",
sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"}
)
"""之前属于该组的用户列表(用于过期后恢复)"""
def to_response(self) -> "GroupResponse":
"""转换为响应 DTO"""

543
models/mixin/README.md Normal file
View File

@@ -0,0 +1,543 @@
# SQLModel Mixin Module
This module provides composable Mixin classes for SQLModel entities, enabling reusable functionality such as CRUD operations, polymorphic inheritance, JWT authentication, and standardized response DTOs.
## Module Overview
The `sqlmodels.mixin` module contains various Mixin classes that follow the "Composition over Inheritance" design philosophy. These mixins provide:
- **CRUD Operations**: Async database operations (add, save, update, delete, get, count)
- **Polymorphic Inheritance**: Tools for joined table inheritance patterns
- **JWT Authentication**: Token generation and validation
- **Pagination & Sorting**: Standardized table view parameters
- **Response DTOs**: Consistent id/timestamp fields for API responses
## Module Structure
```
sqlmodels/mixin/
├── __init__.py # Module exports
├── polymorphic.py # PolymorphicBaseMixin, create_subclass_id_mixin, AutoPolymorphicIdentityMixin
├── table.py # TableBaseMixin, UUIDTableBaseMixin, TableViewRequest
├── info_response.py # Response DTO Mixins (IntIdInfoMixin, UUIDIdInfoMixin, etc.)
└── jwt/ # JWT authentication
├── __init__.py
├── key.py # JWTKey database model
├── payload.py # JWTPayloadBase
├── manager.py # JWTManager singleton
├── auth.py # JWTAuthMixin
├── exceptions.py # JWT-related exceptions
└── responses.py # TokenResponse DTO
```
## Dependency Hierarchy
The module has a strict import order to avoid circular dependencies:
1. **polymorphic.py** - Only depends on `SQLModelBase`
2. **table.py** - Depends on `polymorphic.py`
3. **jwt/** - May depend on both `polymorphic.py` and `table.py`
4. **info_response.py** - Only depends on `SQLModelBase`
## Core Components
### 1. TableBaseMixin
Base mixin for database table models with integer primary keys.
**Features:**
- Provides CRUD methods: `add()`, `save()`, `update()`, `delete()`, `get()`, `count()`, `get_exist_one()`
- Automatic timestamp management (`created_at`, `updated_at`)
- Async relationship loading support (via `AsyncAttrs`)
- Pagination and sorting via `TableViewRequest`
- Polymorphic subclass loading support
**Fields:**
- `id: int | None` - Integer primary key (auto-increment)
- `created_at: datetime` - Record creation timestamp
- `updated_at: datetime` - Record update timestamp (auto-updated)
**Usage:**
```python
from sqlmodels.mixin import TableBaseMixin
from sqlmodels.base import SQLModelBase
class User(SQLModelBase, TableBaseMixin, table=True):
name: str
email: str
"""User email"""
# CRUD operations
async def example(session: AsyncSession):
# Add
user = User(name="Alice", email="alice@example.com")
user = await user.save(session)
# Get
user = await User.get(session, User.id == 1)
# Update
update_data = UserUpdateRequest(name="Alice Smith")
user = await user.update(session, update_data)
# Delete
await User.delete(session, user)
# Count
count = await User.count(session, User.is_active == True)
```
**Important Notes:**
- `save()` and `update()` return refreshed instances - **always use the return value**:
```python
# ✅ Correct
device = await device.save(session)
return device
# ❌ Wrong - device is expired after commit
await device.save(session)
return device
```
### 2. UUIDTableBaseMixin
Extends `TableBaseMixin` with UUID primary keys instead of integers.
**Differences from TableBaseMixin:**
- `id: UUID` - UUID primary key (auto-generated via `uuid.uuid4()`)
- `get_exist_one()` accepts `UUID` instead of `int`
**Usage:**
```python
from sqlmodels.mixin import UUIDTableBaseMixin
class Character(SQLModelBase, UUIDTableBaseMixin, table=True):
name: str
description: str | None = None
"""Character description"""
```
**Recommendation:** Use `UUIDTableBaseMixin` for most new models, as UUIDs provide better scalability and avoid ID collisions.
### 3. TableViewRequest
Standardized pagination and sorting parameters for LIST endpoints.
**Fields:**
- `offset: int | None` - Skip first N records (default: 0)
- `limit: int | None` - Return max N records (default: 50, max: 100)
- `desc: bool | None` - Sort descending (default: True)
- `order: Literal["created_at", "updated_at"] | None` - Sort field (default: "created_at")
**Usage with TableBaseMixin.get():**
```python
from dependencies import TableViewRequestDep
@router.get("/list")
async def list_characters(
session: SessionDep,
table_view: TableViewRequestDep
) -> list[Character]:
"""List characters with pagination and sorting"""
return await Character.get(
session,
fetch_mode="all",
table_view=table_view # Automatically handles pagination and sorting
)
```
**Manual usage:**
```python
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
characters = await Character.get(session, fetch_mode="all", table_view=table_view)
```
**Backward Compatibility:**
The traditional `offset`, `limit`, `order_by` parameters still work, but `table_view` is recommended for new code.
### 4. PolymorphicBaseMixin
Base mixin for joined table inheritance, automatically configuring polymorphic settings.
**Automatic Configuration:**
- Defines `_polymorphic_name: str` field (indexed)
- Sets `polymorphic_on='_polymorphic_name'`
- Detects abstract classes (via ABC and abstract methods) and sets `polymorphic_abstract=True`
**Methods:**
- `get_concrete_subclasses()` - Get all non-abstract subclasses (for `selectin_polymorphic`)
- `get_polymorphic_discriminator()` - Get the polymorphic discriminator field name
- `get_identity_to_class_map()` - Map `polymorphic_identity` to subclass types
**Usage:**
```python
from abc import ABC, abstractmethod
from sqlmodels.mixin import PolymorphicBaseMixin, UUIDTableBaseMixin
class Tool(PolymorphicBaseMixin, UUIDTableBaseMixin, ABC):
"""Abstract base class for all tools"""
name: str
description: str
"""Tool description"""
@abstractmethod
async def execute(self, params: dict) -> dict:
"""Execute the tool"""
pass
```
**Why Single Underscore Prefix?**
- SQLAlchemy maps single-underscore fields to database columns
- Pydantic treats them as private (excluded from serialization)
- Double-underscore fields would be excluded by SQLAlchemy (not mapped to database)
### 5. create_subclass_id_mixin()
Factory function to create ID mixins for subclasses in joined table inheritance.
**Purpose:** In joined table inheritance, subclasses need a foreign key pointing to the parent table's primary key. This function generates a mixin class providing that foreign key field.
**Signature:**
```python
def create_subclass_id_mixin(parent_table_name: str) -> type[SQLModelBase]:
"""
Args:
parent_table_name: Parent table name (e.g., 'asr', 'tts', 'tool', 'function')
Returns:
A mixin class containing id field (foreign key + primary key)
"""
```
**Usage:**
```python
from sqlmodels.mixin import create_subclass_id_mixin
# Create mixin for ASR subclasses
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
class FunASR(ASRSubclassIdMixin, ASR, AutoPolymorphicIdentityMixin, table=True):
"""FunASR implementation"""
pass
```
**Important:** The ID mixin **must be first in the inheritance list** to ensure MRO (Method Resolution Order) correctly overrides the parent's `id` field.
### 6. AutoPolymorphicIdentityMixin
Automatically generates `polymorphic_identity` based on class name.
**Naming Convention:**
- Format: `{parent_identity}.{classname_lowercase}`
- If no parent identity exists, uses `{classname_lowercase}`
**Usage:**
```python
from sqlmodels.mixin import AutoPolymorphicIdentityMixin
class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
"""Base class for function-type tools"""
pass
# polymorphic_identity = 'function'
class GetWeatherFunction(Function, table=True):
"""Weather query function"""
pass
# polymorphic_identity = 'function.getweatherfunction'
```
**Manual Override:**
```python
class CustomTool(
Tool,
AutoPolymorphicIdentityMixin,
polymorphic_identity='custom_name', # Override auto-generated name
table=True
):
pass
```
### 7. JWTAuthMixin
Provides JWT token generation and validation for entity classes (User, Client).
**Methods:**
- `async issue_jwt(session: AsyncSession) -> str` - Generate JWT token for current instance
- `@classmethod async from_jwt(session: AsyncSession, token: str) -> Self` - Validate token and retrieve entity
**Requirements:**
Subclasses must define:
- `JWTPayload` - Payload model (inherits from `JWTPayloadBase`)
- `jwt_key_purpose` - ClassVar specifying the JWT key purpose enum value
**Usage:**
```python
from sqlmodels.mixin import JWTAuthMixin, UUIDTableBaseMixin
class User(SQLModelBase, UUIDTableBaseMixin, JWTAuthMixin, table=True):
JWTPayload = UserJWTPayload # Define payload model
jwt_key_purpose: ClassVar[JWTKeyPurposeEnum] = JWTKeyPurposeEnum.user
email: str
is_admin: bool = False
is_active: bool = True
"""User active status"""
# Generate token
async def login(session: AsyncSession, user: User) -> str:
token = await user.issue_jwt(session)
return token
# Validate token
async def verify(session: AsyncSession, token: str) -> User:
user = await User.from_jwt(session, token)
return user
```
### 8. Response DTO Mixins
Mixins for standardized InfoResponse DTOs, defining id and timestamp fields.
**Available Mixins:**
- `IntIdInfoMixin` - Integer ID field
- `UUIDIdInfoMixin` - UUID ID field
- `DatetimeInfoMixin` - `created_at` and `updated_at` fields
- `IntIdDatetimeInfoMixin` - Integer ID + timestamps
- `UUIDIdDatetimeInfoMixin` - UUID ID + timestamps
**Design Note:** These fields are non-nullable in DTOs because database records always have these values when returned.
**Usage:**
```python
from sqlmodels.mixin import UUIDIdDatetimeInfoMixin
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
"""Character response DTO with id and timestamps"""
pass # Inherits id, created_at, updated_at from mixin
```
## Complete Joined Table Inheritance Example
Here's a complete example demonstrating polymorphic inheritance:
```python
from abc import ABC, abstractmethod
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import (
UUIDTableBaseMixin,
PolymorphicBaseMixin,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin,
)
# 1. Define Base class (fields only, no table)
class ASRBase(SQLModelBase):
name: str
"""Configuration name"""
base_url: str
"""Service URL"""
# 2. Define abstract parent class (with table)
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
"""Abstract base class for ASR configurations"""
# PolymorphicBaseMixin automatically provides:
# - _polymorphic_name field
# - polymorphic_on='_polymorphic_name'
# - polymorphic_abstract=True (when ABC with abstract methods)
@abstractmethod
async def transcribe(self, pcm_data: bytes) -> str:
"""Transcribe audio to text"""
pass
# 3. Create ID Mixin for second-level subclasses
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
# 4. Create second-level abstract class (if needed)
class FunASR(
ASRSubclassIdMixin,
ASR,
AutoPolymorphicIdentityMixin,
polymorphic_abstract=True
):
"""FunASR abstract base (may have multiple implementations)"""
pass
# polymorphic_identity = 'funasr'
# 5. Create concrete implementation classes
class FunASRLocal(FunASR, table=True):
"""FunASR local deployment"""
# polymorphic_identity = 'funasr.funasrlocal'
async def transcribe(self, pcm_data: bytes) -> str:
# Implementation...
return "transcribed text"
# 6. Get all concrete subclasses (for selectin_polymorphic)
concrete_asrs = ASR.get_concrete_subclasses()
# Returns: [FunASRLocal, ...]
```
## Import Guidelines
**Standard Import:**
```python
from sqlmodels.mixin import (
TableBaseMixin,
UUIDTableBaseMixin,
PolymorphicBaseMixin,
TableViewRequest,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin,
JWTAuthMixin,
UUIDIdDatetimeInfoMixin,
now,
now_date,
)
```
**Backward Compatibility:**
Some exports are also available from `sqlmodels.base` for backward compatibility:
```python
# Legacy import path (still works)
from sqlmodels.base import UUIDTableBase, TableViewRequest
# Recommended new import path
from sqlmodels.mixin import UUIDTableBaseMixin, TableViewRequest
```
## Best Practices
### 1. Mixin Order Matters
**Correct Order:**
```python
# ✅ ID Mixin first, then parent, then AutoPolymorphicIdentityMixin
class SubTool(ToolSubclassIdMixin, Tool, AutoPolymorphicIdentityMixin, table=True):
pass
```
**Wrong Order:**
```python
# ❌ ID Mixin not first - won't override parent's id field
class SubTool(Tool, ToolSubclassIdMixin, AutoPolymorphicIdentityMixin, table=True):
pass
```
### 2. Always Use Return Values from save() and update()
```python
# ✅ Correct - use returned instance
device = await device.save(session)
return device
# ❌ Wrong - device is expired after commit
await device.save(session)
return device # AttributeError when accessing fields
```
### 3. Prefer table_view Over Manual Pagination
```python
# ✅ Recommended - consistent across all endpoints
characters = await Character.get(
session,
fetch_mode="all",
table_view=table_view
)
# ⚠️ Works but not recommended - manual parameter management
characters = await Character.get(
session,
fetch_mode="all",
offset=0,
limit=20,
order_by=[desc(Character.created_at)]
)
```
### 4. Polymorphic Loading for Many Subclasses
```python
# When loading relationships with > 10 polymorphic subclasses, use load_polymorphic='all'
tool_set = await ToolSet.get(
session,
ToolSet.id == tool_set_id,
load=ToolSet.tools,
load_polymorphic='all' # Two-phase query - only loads actual related subclasses
)
# For fewer subclasses, specify the list explicitly
tool_set = await ToolSet.get(
session,
ToolSet.id == tool_set_id,
load=ToolSet.tools,
load_polymorphic=[GetWeatherFunction, CodeInterpreterFunction]
)
```
### 5. Response DTOs Should Inherit Base Classes
```python
# ✅ Correct - inherits from CharacterBase
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
pass
# ❌ Wrong - doesn't inherit from CharacterBase
class CharacterInfoResponse(SQLModelBase, UUIDIdDatetimeInfoMixin):
name: str # Duplicated field definition
description: str | None = None
```
**Reason:** Inheriting from Base classes ensures:
- Type checking via `isinstance(obj, XxxBase)`
- Consistency across related DTOs
- Future field additions automatically propagate
### 6. Use Specific Types, Not Containers
```python
# ✅ Correct - specific DTO for config updates
class GetWeatherFunctionUpdateRequest(GetWeatherFunctionConfigBase):
weather_api_key: str | None = None
default_location: str | None = None
"""Default location"""
# ❌ Wrong - lose type safety
class ToolUpdateRequest(SQLModelBase):
config: dict[str, Any] # No field validation
```
## Type Variables
```python
from sqlmodels.mixin import T, M
T = TypeVar("T", bound="TableBaseMixin") # For CRUD methods
M = TypeVar("M", bound="SQLModel") # For update() method
```
## Utility Functions
```python
from sqlmodels.mixin import now, now_date
# Lambda functions for default factories
now = lambda: datetime.now()
now_date = lambda: datetime.now().date()
```
## Related Modules
- **sqlmodels.base** - Base classes (`SQLModelBase`, backward-compatible exports)
- **dependencies** - FastAPI dependencies (`SessionDep`, `TableViewRequestDep`)
- **sqlmodels.user** - User model with JWT authentication
- **sqlmodels.client** - Client model with JWT authentication
- **sqlmodels.character.llm.openai_compatibles.tools** - Polymorphic tool hierarchy
## Additional Resources
- `POLYMORPHIC_NAME_DESIGN.md` - Design rationale for `_polymorphic_name` field
- `CLAUDE.md` - Project coding standards and design philosophy
- SQLAlchemy Documentation - [Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)

46
models/mixin/__init__.py Normal file
View File

@@ -0,0 +1,46 @@
"""
SQLModel Mixin模块
提供各种Mixin类供SQLModel实体使用。
包含:
- polymorphic: 联表继承工具create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin
- table: 表基类TableBaseMixin, UUIDTableBaseMixin
- table: 查询参数类TimeFilterRequest, PaginationRequest, TableViewRequest
- jwt/: JWT认证相关JWTAuthMixin, JWTManager, JWTKey等- 需要时直接从 .jwt 导入
- info_response: InfoResponse DTO的id/时间戳Mixin
导入顺序很重要,避免循环导入:
1. polymorphic只依赖 SQLModelBase
2. table依赖 polymorphic
注意jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig
而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
"""
# polymorphic 必须先导入
from .polymorphic import (
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin,
PolymorphicBaseMixin,
)
# table 依赖 polymorphic
from .table import (
TableBaseMixin,
UUIDTableBaseMixin,
TimeFilterRequest,
PaginationRequest,
TableViewRequest,
ListResponse,
T,
now,
now_date,
)
# jwt 不在此处导入避免循环jwt/manager.py → ServerConfig → mixin → jwt
# 需要时直接从 sqlmodels.mixin.jwt 导入
from .info_response import (
IntIdInfoMixin,
UUIDIdInfoMixin,
DatetimeInfoMixin,
IntIdDatetimeInfoMixin,
UUIDIdDatetimeInfoMixin,
)

View File

@@ -0,0 +1,46 @@
"""
InfoResponse DTO Mixin模块
提供用于InfoResponse类型DTO的Mixin统一定义id/created_at/updated_at字段。
设计说明:
- 这些Mixin用于**响应DTO**,不是数据库表
- 从数据库返回时这些字段永远不为空,所以定义为必填字段
- TableBase中的id=None和default_factory=now是正确的入库前为None数据库生成
- 这些Mixin让DTO明确表示"返回给客户端时这些字段必定有值"
"""
from datetime import datetime
from uuid import UUID
from models.base import SQLModelBase
class IntIdInfoMixin(SQLModelBase):
"""整数ID响应mixin - 用于InfoResponse DTO"""
id: int
"""记录ID"""
class UUIDIdInfoMixin(SQLModelBase):
"""UUID ID响应mixin - 用于InfoResponse DTO"""
id: UUID
"""记录ID"""
class DatetimeInfoMixin(SQLModelBase):
"""时间戳响应mixin - 用于InfoResponse DTO"""
created_at: datetime
"""创建时间"""
updated_at: datetime
"""更新时间"""
class IntIdDatetimeInfoMixin(IntIdInfoMixin, DatetimeInfoMixin):
"""整数ID + 时间戳响应mixin"""
pass
class UUIDIdDatetimeInfoMixin(UUIDIdInfoMixin, DatetimeInfoMixin):
"""UUID ID + 时间戳响应mixin"""
pass

456
models/mixin/polymorphic.py Normal file
View File

@@ -0,0 +1,456 @@
"""
联表继承Joined Table Inheritance的通用工具
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
Usage Example:
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import (
PolymorphicBaseMixin,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin
)
# 1. 定义Base类只有字段无表
class ASRBase(SQLModelBase):
name: str
\"\"\"配置名称\"\"\"
base_url: str
\"\"\"服务地址\"\"\"
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
class ASR(
ASRBase,
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC
):
\"\"\"ASR配置的抽象基类\"\"\"
# PolymorphicBaseMixin 自动提供:
# - _polymorphic_name 字段
# - polymorphic_on='_polymorphic_name'
# - polymorphic_abstract=True当有抽象方法时
# 3. 为第二层子类创建ID Mixin
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
# 4. 创建第二层抽象类(如果需要)
class FunASR(
ASRSubclassIdMixin,
ASR,
AutoPolymorphicIdentityMixin,
polymorphic_abstract=True
):
\"\"\"FunASR的抽象基类可能有多个实现\"\"\"
pass
# 5. 创建具体实现类
class FunASRLocal(FunASR, table=True):
\"\"\"FunASR本地部署版本\"\"\"
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
pass
# 6. 获取所有具体子类(用于 selectin_polymorphic
concrete_asrs = ASR.get_concrete_subclasses()
# 返回 [FunASRLocal, ...]
"""
import uuid
from abc import ABC
from uuid import UUID
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from sqlalchemy import String, inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlmodel import Field
from models.base.sqlmodel_base import SQLModelBase
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
"""
动态创建SubclassIdMixin类
在联表继承中,子类需要一个外键指向父表的主键。
此函数生成一个Mixin类提供这个外键字段并自动生成UUID。
Args:
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function'
Returns:
一个Mixin类包含id字段外键 + 主键 + default_factory=uuid.uuid4
Example:
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
... pass
Note:
- 生成的Mixin应该放在继承列表的第一位确保通过MRO覆盖UUIDTableBaseMixin的id
- 生成的类名为 {ParentTableName}SubclassIdMixinPascalCase
- 本项目所有联表继承均使用UUID主键UUIDTableBaseMixin
"""
if not parent_table_name:
raise ValueError("parent_table_name 不能为空")
# 转换为PascalCase作为类名
class_name_parts = parent_table_name.split('_')
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
# 使用闭包捕获parent_table_name
_parent_table_name = parent_table_name
# 创建带有__init_subclass__的mixin类用于在子类定义后修复model_fields
class SubclassIdMixin(SQLModelBase):
# 定义id字段
id: UUID = Field(
default_factory=uuid.uuid4,
foreign_key=f'{_parent_table_name}.id',
primary_key=True,
)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
修复联表继承中子类字段的default_factory丢失问题。
SQLAlchemy 的 InstrumentedAttribute 会污染从父类继承的字段,
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
通过从 MRO 中查找父类的原始字段定义来获取正确的 default_factory
遵循单一真相原则(不硬编码 default_factory
需要修复的字段:
- id: 主键(从父类获取 default_factory
- created_at: 创建时间戳(从父类获取 default_factory
- updated_at: 更新时间戳(从父类获取 default_factory
"""
super().__pydantic_init_subclass__(**kwargs)
if not hasattr(cls, 'model_fields'):
return
def find_original_field_info(field_name: str) -> FieldInfo | None:
"""从 MRO 中查找字段的原始定义(未被 InstrumentedAttribute 污染的)"""
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, 'model_fields') and field_name in base.model_fields:
field_info = base.model_fields[field_name]
# 跳过被 InstrumentedAttribute 污染的
if not isinstance(field_info.default, InstrumentedAttribute):
return field_info
return None
# 动态检测所有需要修复的字段
# 遵循单一真相原则:不硬编码字段列表,而是通过以下条件判断:
# 1. default 是 InstrumentedAttribute被 SQLAlchemy 污染)
# 2. 原始定义有 default_factory 或明确的 default 值
#
# 覆盖场景:
# - UUID主键UUIDTableBaseMixinid 有 default_factory=uuid.uuid4需要修复
# - int主键TableBaseMixinid 用 default=None不需要修复数据库自增
# - created_at/updated_at有 default_factory=now需要修复
# - 外键字段created_by_id等有 default=None需要修复
# - 普通字段name, temperature等无 default_factory不需要修复
#
# MRO 查找保证:
# - 在多重继承场景下MRO 顺序是确定性的
# - find_original_field_info 会找到第一个未被污染且有该字段的父类
for field_name, current_field in cls.model_fields.items():
# 检查是否被污染default 是 InstrumentedAttribute
if not isinstance(current_field.default, InstrumentedAttribute):
continue # 未被污染,跳过
# 从父类查找原始定义
original = find_original_field_info(field_name)
if original is None:
continue # 找不到原始定义,跳过
# 根据原始定义的 default/default_factory 来修复
if original.default_factory:
# 有 default_factory如 uuid.uuid4, now
new_field = FieldInfo(
default_factory=original.default_factory,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
elif original.default is not PydanticUndefined:
# 有明确的 default 值(如 None, 0, ""),且不是 PydanticUndefined
# PydanticUndefined 表示字段没有默认值(必填)
new_field = FieldInfo(
default=original.default,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
else:
continue # 既没有 default_factory 也没有有效的 default跳过
# 复制SQLModel特有的属性
if hasattr(current_field, 'foreign_key'):
new_field.foreign_key = current_field.foreign_key
if hasattr(current_field, 'primary_key'):
new_field.primary_key = current_field.primary_key
cls.model_fields[field_name] = new_field
# 设置类名和文档
SubclassIdMixin.__name__ = class_name
SubclassIdMixin.__qualname__ = class_name
SubclassIdMixin.__doc__ = f"""
{parent_table_name}子类的ID Mixin
用于{parent_table_name}的子类,提供外键指向父表。
通过MRO确保此id字段覆盖继承的id字段。
"""
return SubclassIdMixin
class AutoPolymorphicIdentityMixin:
"""
自动生成polymorphic_identity的Mixin
使用此Mixin的类会自动根据类名生成polymorphic_identity。
格式:{parent_polymorphic_identity}.{classname_lowercase}
如果没有父类的polymorphic_identity则直接使用类名小写。
Example:
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
... __polymorphic_name: str
...
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
... pass
... # polymorphic_identity 会自动设置为 'function'
...
>>> class CodeInterpreterFunction(Function, table=True):
... pass
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
Note:
- 如果手动在__mapper_args__中指定了polymorphic_identity会被保留
- 此Mixin应该在继承列表中靠后的位置在表基类之前
"""
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
"""
子类化钩子自动生成polymorphic_identity
Args:
polymorphic_identity: 如果手动指定,则使用指定的值
**kwargs: 其他SQLModel参数如table=True, polymorphic_abstract=True
"""
super().__init_subclass__(**kwargs)
# 如果手动指定了polymorphic_identity使用指定的值
if polymorphic_identity is not None:
identity = polymorphic_identity
else:
# 自动生成polymorphic_identity
class_name = cls.__name__.lower()
# 尝试从父类获取polymorphic_identity作为前缀
parent_identity = None
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
parent_identity = base.__mapper_args__.get('polymorphic_identity')
if parent_identity:
break
# 构建identity
if parent_identity:
identity = f'{parent_identity}.{class_name}'
else:
identity = class_name
# 设置到__mapper_args__
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 只在尚未设置polymorphic_identity时设置
if 'polymorphic_identity' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_identity'] = identity
class PolymorphicBaseMixin:
"""
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
此 Mixin 自动设置以下内容:
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
使用场景:
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
用法示例:
```python
from abc import ABC
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
# 定义基类
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
__tablename__ = 'mytool'
# 不需要手动定义 _polymorphic_name
# 不需要手动设置 polymorphic_on
# 不需要手动设置 polymorphic_abstract
# 定义子类
class SpecificTool(MyTool):
__tablename__ = 'specifictool'
# 会自动继承 polymorphic 配置
```
自动行为:
1. 定义 `_polymorphic_name: str` 字段(带索引)
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
3. 自动检测抽象类:
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
- 否则设置为 False
手动覆盖:
可以在类定义时手动指定参数来覆盖自动行为:
```python
class MyTool(
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC,
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
polymorphic_abstract=False # 强制不设为抽象类
):
pass
```
注意事项:
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
- 适用于联表继承joined table inheritance场景
- 子类会自动继承 _polymorphic_name 字段定义
- 使用单下划线前缀是因为:
* SQLAlchemy 会映射单下划线字段为数据库列
* Pydantic 将其视为私有属性,不参与序列化
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
"""
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
#
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
#
# 为什么这样做:
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
# 4. 字段不出现在 Pydantic 序列化中model_dump() 和 JSON schema
# 5. 内部代码仍然可以正常访问和修改此字段
#
# 详细说明请参考sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
"""
多态鉴别器字段,用于标识具体的子类类型
注意:此字段使用单下划线前缀,表示内部使用。
- ✅ 存储到数据库
- ✅ 不出现在 API 序列化中
- ✅ 防止外部直接修改
"""
def __init_subclass__(
cls,
polymorphic_on: str | None = None,
polymorphic_abstract: bool | None = None,
**kwargs
):
"""
在子类定义时自动配置 polymorphic 设置
Args:
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'
设置为其他值可以使用不同的字段作为多态鉴别器。
polymorphic_abstract: 是否为抽象类。
- None: 自动检测(默认)
- True: 强制设为抽象类
- False: 强制设为非抽象类
**kwargs: 传递给父类的其他参数
"""
super().__init_subclass__(**kwargs)
# 初始化 __mapper_args__如果还没有
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 设置 polymorphic_on默认为 _polymorphic_name
if 'polymorphic_on' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
# 自动检测或设置 polymorphic_abstract
if 'polymorphic_abstract' not in cls.__mapper_args__:
if polymorphic_abstract is None:
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
has_abc = ABC in cls.__mro__
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
polymorphic_abstract = has_abc and has_abstract_methods
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
@classmethod
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
"""
递归获取当前类的所有具体(非抽象)子类
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
可在任意多态基类上调用,返回该类的所有非抽象子类。
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
"""
result: list[type[PolymorphicBaseMixin]] = []
for subclass in cls.__subclasses__():
# 使用 inspect() 获取 mapper 的公开属性
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
mapper = inspect(subclass)
if not mapper.polymorphic_abstract:
result.append(subclass)
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
if hasattr(subclass, 'get_concrete_subclasses'):
result.extend(subclass.get_concrete_subclasses())
return result
@classmethod
def get_polymorphic_discriminator(cls) -> str:
"""
获取多态鉴别字段名
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
:return: 多态鉴别字段名(如 '_polymorphic_name'
:raises ValueError: 如果类未配置 polymorphic_on
"""
polymorphic_on = inspect(cls).polymorphic_on
if polymorphic_on is None:
raise ValueError(
f"{cls.__name__} 未配置 polymorphic_on"
f"请确保正确继承 PolymorphicBaseMixin"
)
return polymorphic_on.key
@classmethod
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
"""
获取 polymorphic_identity 到具体子类的映射
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
:return: identity 到子类的映射字典
"""
result: dict[str, type[PolymorphicBaseMixin]] = {}
for subclass in cls.get_concrete_subclasses():
identity = inspect(subclass).polymorphic_identity
if identity:
result[identity] = subclass
return result

927
models/mixin/table.py Normal file
View File

@@ -0,0 +1,927 @@
"""
表基类 Mixin
提供 TableBaseMixin、UUIDTableBaseMixin 和 TableViewRequest。
这些类实际上是 Mixin为 SQLModel 模型提供 CRUD 操作和时间戳字段。
依赖关系:
base/sqlmodel_base.py ← 最底层
mixin/polymorphic.py ← 定义 PolymorphicBaseMixin
mixin/table.py ← 当前文件,导入 PolymorphicBaseMixin
base/__init__.py ← 从 mixin 重新导出(保持向后兼容)
"""
import uuid
from datetime import datetime
from typing import TypeVar, Literal, override, Any, ClassVar, Generic
# TODO(ListResponse泛型问题): SQLModel泛型类型JSON Schema生成bug
# 已知问题: https://github.com/fastapi/sqlmodel/discussions/1002
# 修复PR: https://github.com/fastapi/sqlmodel/pull/1275 (尚未合并)
# 现象: SQLModel + Generic[T] 的 __pydantic_generic_metadata__ = {origin: None, args: ()}
# 导致OpenAPI schema中泛型字段显示为{}而非正确的$ref
# 当前方案: ListResponse继承BaseModel而非SQLModel (Discussion #1002推荐的workaround)
# 未来: PR #1275合并后可改回继承SQLModelBase
from pydantic import BaseModel, ConfigDict
from fastapi import HTTPException
from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct
from sqlalchemy.orm import selectinload, Relationship, with_polymorphic
from sqlmodel import Field, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.sql._typing import _OnClauseArgument
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel.main import RelationshipInfo
from .polymorphic import PolymorphicBaseMixin
from models.base.sqlmodel_base import SQLModelBase
# Type variables for generic type hints, improving code completion and analysis.
T = TypeVar("T", bound="TableBaseMixin")
M = TypeVar("M", bound="SQLModelBase")
ItemT = TypeVar("ItemT")
class ListResponse(BaseModel, Generic[ItemT]):
"""
泛型分页响应
用于所有LIST端点的标准化响应格式包含记录总数和项目列表。
与 TableBaseMixin.get_with_count() 配合使用。
使用示例:
```python
@router.get("", response_model=ListResponse[CharacterInfoResponse])
async def list_characters(...) -> ListResponse[Character]:
return await Character.get_with_count(session, table_view=table_view)
```
Attributes:
count: 符合条件的记录总数(用于分页计算)
items: 当前页的记录列表
Note:
继承BaseModel而非SQLModelBase因为SQLModel的metaclass与Generic冲突。
详见文件顶部TODO注释。
"""
# 与SQLModelBase保持一致的配置
model_config = ConfigDict(use_attribute_docstrings=True)
count: int
"""符合条件的记录总数"""
items: list[ItemT]
"""当前页的记录列表"""
# Lambda functions to get the current time, used as default factories in model fields.
now = lambda: datetime.now()
now_date = lambda: datetime.now().date()
# ==================== 查询参数请求类 ====================
class TimeFilterRequest(SQLModelBase):
"""
时间筛选请求参数
用于 count() 等只需要时间筛选的场景。
纯数据类只负责参数校验和携带SQL子句构建由 TableBaseMixin 负责。
Raises:
ValueError: 时间范围无效
"""
created_after_datetime: datetime | None = None
"""创建时间起始筛选created_at >= datetime如果为None则不限制"""
created_before_datetime: datetime | None = None
"""创建时间结束筛选created_at < datetime如果为None则不限制"""
updated_after_datetime: datetime | None = None
"""更新时间起始筛选updated_at >= datetime如果为None则不限制"""
updated_before_datetime: datetime | None = None
"""更新时间结束筛选updated_at < datetime如果为None则不限制"""
def model_post_init(self, __context: Any) -> None:
"""
验证时间范围有效性
验证规则:
1. 同类型after 必须小于 before
2. 跨类型created_after 不能大于 updated_before记录不可能在创建前被更新
"""
# 同类型矛盾验证
if self.created_after_datetime and self.created_before_datetime:
if self.created_after_datetime >= self.created_before_datetime:
raise ValueError("created_after_datetime 必须小于 created_before_datetime")
if self.updated_after_datetime and self.updated_before_datetime:
if self.updated_after_datetime >= self.updated_before_datetime:
raise ValueError("updated_after_datetime 必须小于 updated_before_datetime")
# 跨类型矛盾验证created_after >= updated_before 意味着要求创建时间晚于或等于更新时间上界,逻辑矛盾
if self.created_after_datetime and self.updated_before_datetime:
if self.created_after_datetime >= self.updated_before_datetime:
raise ValueError(
"created_after_datetime 不能大于或等于 updated_before_datetime"
"(记录的更新时间不可能早于或等于创建时间)"
)
class PaginationRequest(SQLModelBase):
"""
分页排序请求参数
用于需要分页和排序的场景。
纯数据类只负责携带参数SQL子句构建由 TableBaseMixin 负责。
"""
offset: int | None = Field(default=0, ge=0)
"""偏移量跳过前N条记录必须为非负整数"""
limit: int | None = Field(default=50, le=100)
"""每页数量返回最多N条记录默认50最大100"""
desc: bool | None = True
"""是否降序排序True: 降序, False: 升序)"""
order: Literal["created_at", "updated_at"] | None = "created_at"
"""排序字段created_at: 创建时间, updated_at: 更新时间)"""
class TableViewRequest(TimeFilterRequest, PaginationRequest):
"""
表格视图请求参数(分页、排序和时间筛选)
组合继承 TimeFilterRequest 和 PaginationRequest用于 get() 等需要完整查询参数的场景。
纯数据类SQL子句构建由 TableBaseMixin 负责。
使用示例:
```python
# 在端点中使用依赖注入
@router.get("/list")
async def list_items(
session: SessionDep,
table_view: TableViewRequestDep
):
items = await Item.get(
session,
fetch_mode="all",
table_view=table_view
)
return items
# 直接使用
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
items = await Item.get(session, fetch_mode="all", table_view=table_view)
```
"""
pass
# ==================== TableBaseMixin ====================
class TableBaseMixin(AsyncAttrs):
"""
一个异步 CRUD 操作的基础模型类 Mixin.
此类必须搭配SQLModelBase使用
此类为所有继承它的 SQLModel 模型提供了通用的数据库操作方法,
例如 add, save, update, delete, 和 get. 它还包括自动管理
的 `created_at` 和 `updated_at` 时间戳字段.
Attributes:
id (int | None): 整数主键, 自动递增.
created_at (datetime): 记录创建时的时间戳, 自动设置.
updated_at (datetime): 记录每次更新时的时间戳, 自动更新.
"""
_is_table_mixin: ClassVar[bool] = True
"""标记此类为表混入类的内部属性"""
def __init_subclass__(cls, **kwargs):
"""
接受并传递子类定义时的关键字参数
这允许元类 __DeclarativeMeta 处理的参数(如 table_args
能够正确传递,而不会在 __init_subclass__ 阶段报错。
"""
super().__init_subclass__(**kwargs)
id: int | None = Field(default=None, primary_key=True)
created_at: datetime = Field(default_factory=now)
updated_at: datetime = Field(
sa_type=DateTime,
sa_column_kwargs={'default': now, 'onupdate': now},
default_factory=now
)
@classmethod
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
"""
向数据库中添加一个新的或多个新的记录.
这个类方法可以接受单个模型实例或一个实例列表,并将它们
一次性提交到数据库中。执行后,可以选择性地刷新这些实例以获取
数据库生成的值(例如,自动递增的 ID.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
Returns:
T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例.
Usage:
item1 = Item(name="Apple")
item2 = Item(name="Banana")
# 添加多个实例
added_items = await Item.add(session, [item1, item2])
# 添加单个实例
item3 = Item(name="Cherry")
added_item = await Item.add(session, item3)
"""
is_list = False
if isinstance(instances, list):
is_list = True
session.add_all(instances)
else:
session.add(instances)
await session.commit()
if refresh:
if is_list:
for instance in instances:
await session.refresh(instance)
else:
await session.refresh(instances)
return instances
async def save(
self: T,
session: AsyncSession,
load: RelationshipInfo | None = None,
refresh: bool = True
) -> T:
"""
保存(插入或更新)当前模型实例到数据库.
这是一个实例方法,它将当前对象添加到会话中并提交更改。
可以用于创建新记录或更新现有记录。还可以选择在保存后
预加载eager load一个关联关系.
**重要**调用此方法后session中的所有对象都会过期expired
如果需要继续使用该对象,必须使用返回值:
```python
# ✅ 正确:需要返回值时
client = await client.save(session)
return client
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
await client.save(session, refresh=False)
# ❌ 错误:需要返回值但未使用
await client.save(session)
return client # client 对象已过期
```
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性.
例如 `User.posts`.
refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值,
设为 False 可节省一次数据库查询。默认为 True.
Returns:
T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self.
"""
session.add(self)
await session.commit()
if not refresh:
return self
if load is not None:
cls = type(self)
await session.refresh(self)
# 如果指定了 load, 重新获取实例并加载关联关系
return await cls.get(session, cls.id == self.id, load=load)
else:
await session.refresh(self)
return self
async def update(
self: T,
session: AsyncSession,
other: M,
extra_data: dict[str, Any] | None = None,
exclude_unset: bool = True,
exclude: set[str] | None = None,
load: RelationshipInfo | None = None,
refresh: bool = True
) -> T:
"""
使用另一个模型实例或字典中的数据来更新当前实例.
此方法将 `other` 对象中的数据合并到当前实例中。默认情况下,
它只会更新 `other` 中被显式设置的字段.
**重要**调用此方法后session中的所有对象都会过期expired
如果需要继续使用该对象,必须使用返回值:
```python
# ✅ 正确:需要返回值时
client = await client.update(session, update_data)
return client
# ✅ 正确:需要返回值且需要加载关系时
user = await user.update(session, update_data, load=User.permission)
return user
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
await client.update(session, update_data, refresh=False)
# ❌ 错误:需要返回值但未使用
await client.update(session, update_data)
return client # client 对象已过期
```
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
other (M): 一个 SQLModel 或 Pydantic 模型实例,其数据将用于更新当前实例.
extra_data (dict, optional): 一个额外的字典,用于更新当前实例的特定字段.
exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供)
的字段将被忽略. 默认为 True.
exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}.
load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性.
例如 `User.permission`.
refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值,
设为 False 可节省一次数据库查询。默认为 True.
Returns:
T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self.
"""
self.sqlmodel_update(
other.model_dump(exclude_unset=exclude_unset, exclude=exclude),
update=extra_data
)
session.add(self)
await session.commit()
if not refresh:
return self
if load is not None:
cls = type(self)
await session.refresh(self)
return await cls.get(session, cls.id == self.id, load=load)
else:
await session.refresh(self)
return self
@classmethod
async def delete(cls: type[T], session: AsyncSession, instances: T | list[T]) -> None:
"""
从数据库中删除一个或多个记录.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
instances (T | list[T]): 要删除的单个模型实例或模型实例列表.
Returns:
None
Usage:
item_to_delete = await Item.get(session, Item.id == 1)
if item_to_delete:
await Item.delete(session, item_to_delete)
items_to_delete = await Item.get(session, Item.name.in_(["Apple", "Banana"]), fetch_mode="all")
if items_to_delete:
await Item.delete(session, items_to_delete)
"""
if isinstance(instances, list):
for instance in instances:
await session.delete(instance)
else:
await session.delete(instances)
await session.commit()
@classmethod
def _build_time_filters(
cls: type[T],
created_before_datetime: datetime | None = None,
created_after_datetime: datetime | None = None,
updated_before_datetime: datetime | None = None,
updated_after_datetime: datetime | None = None,
) -> list[BinaryExpression]:
"""
构建时间筛选条件列表
Args:
created_before_datetime: 筛选 created_at < datetime 的记录
created_after_datetime: 筛选 created_at >= datetime 的记录
updated_before_datetime: 筛选 updated_at < datetime 的记录
updated_after_datetime: 筛选 updated_at >= datetime 的记录
Returns:
BinaryExpression 条件列表
"""
filters: list[BinaryExpression] = []
if created_after_datetime is not None:
filters.append(cls.created_at >= created_after_datetime)
if created_before_datetime is not None:
filters.append(cls.created_at < created_before_datetime)
if updated_after_datetime is not None:
filters.append(cls.updated_at >= updated_after_datetime)
if updated_before_datetime is not None:
filters.append(cls.updated_at < updated_before_datetime)
return filters
@classmethod
async def get(
cls: type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None = None,
*,
offset: int | None = None,
limit: int | None = None,
fetch_mode: Literal["one", "first", "all"] = "first",
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
options: list | None = None,
load: RelationshipInfo | None = None,
order_by: list[ClauseElement] | None = None,
filter: BinaryExpression | ClauseElement | None = None,
with_for_update: bool = False,
table_view: TableViewRequest | None = None,
load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
created_before_datetime: datetime | None = None,
created_after_datetime: datetime | None = None,
updated_before_datetime: datetime | None = None,
updated_after_datetime: datetime | None = None,
) -> T | list[T] | None:
"""
根据指定的条件异步地从数据库中获取一个或多个模型实例.
这是一个功能强大的通用查询方法,支持过滤、排序、分页、连接查询和关联关系预加载.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
condition (BinaryExpression | ClauseElement | None): 主要的查询过滤条件,
例如 `User.id == 1`。
当为 `None` 时,表示无条件查询(查询所有记录)。
offset (int | None): 查询结果的起始偏移量, 用于分页.
limit (int | None): 返回记录的最大数量, 用于分页.
fetch_mode (Literal["one", "first", "all"]):
- "one": 获取唯一的一条记录. 如果找不到或找到多条,会引发异常.
- "first": 获取查询结果的第一条记录. 如果找不到,返回 `None`.
- "all": 获取所有匹配的记录,返回一个列表.
默认为 "first".
join (type[T] | tuple[type[T], _OnClauseArgument] | None):
要 JOIN 的模型类或一个包含模型类和 ON 子句的元组.
例如 `User` 或 `(Profile, User.id == Profile.user_id)`.
options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
例如 `[selectinload(User.posts)]`.
load (Relationship | None): `selectinload` 的快捷方式,用于预加载单个关联关系.
例如 `User.profile`.
order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
例如 `[User.name.asc(), User.created_at.desc()]`.
filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
with_for_update (bool): 如果为 True, 在查询中使用 `FOR UPDATE` 锁定选定的行. 默认为 False.
table_view (TableViewRequest | None): TableViewRequest对象如果提供则自动处理分页、排序和时间筛选。
会覆盖offset、limit、order_by及时间筛选参数。
这是推荐的分页排序方式统一了所有LIST端点的参数格式。
load_polymorphic: 多态子类加载选项,需要与 load 参数配合使用。
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
- 'all': 两阶段查询,只加载实际关联的子类(对于 > 10 个子类的场景有明显性能收益)
- None默认: 不使用多态加载
created_before_datetime (datetime | None): 筛选 created_at < datetime 的记录
created_after_datetime (datetime | None): 筛选 created_at >= datetime 的记录
updated_before_datetime (datetime | None): 筛选 updated_at < datetime 的记录
updated_after_datetime (datetime | None): 筛选 updated_at >= datetime 的记录
Returns:
T | list[T] | None: 根据 `fetch_mode` 的设置,返回单个实例、实例列表或 `None`.
Raises:
ValueError: 如果提供了无效的 `fetch_mode` 值,或 load_polymorphic 未与 load 配合使用.
Examples:
# 使用table_view参数推荐
users = await User.get(session, fetch_mode="all", table_view=table_view_args)
# 传统方式(向后兼容)
users = await User.get(session, fetch_mode="all", offset=0, limit=20, order_by=[desc(User.created_at)])
# 使用多态加载(加载联表继承的子类数据)
tool_set = await ToolSet.get(
session,
ToolSet.id == tool_set_id,
load=ToolSet.tools,
load_polymorphic='all' # 只加载实际关联的子类
)
"""
# 参数验证load_polymorphic 需要与 load 配合使用
if load_polymorphic is not None and load is None:
raise ValueError(
"load_polymorphic 参数需要与 load 参数配合使用,"
"请同时指定要加载的关系"
)
# 如果提供table_view作为默认值使用单独传入的参数优先级更高
if table_view:
# 处理时间筛选TimeFilterRequest 及其子类)
if isinstance(table_view, TimeFilterRequest):
if created_after_datetime is None and table_view.created_after_datetime is not None:
created_after_datetime = table_view.created_after_datetime
if created_before_datetime is None and table_view.created_before_datetime is not None:
created_before_datetime = table_view.created_before_datetime
if updated_after_datetime is None and table_view.updated_after_datetime is not None:
updated_after_datetime = table_view.updated_after_datetime
if updated_before_datetime is None and table_view.updated_before_datetime is not None:
updated_before_datetime = table_view.updated_before_datetime
# 处理分页排序PaginationRequest 及其子类,包括 TableViewRequest
if isinstance(table_view, PaginationRequest):
if offset is None:
offset = table_view.offset
if limit is None:
limit = table_view.limit
# 仅在未显式传入order_by时从table_view构建排序子句
if order_by is None:
order_column = cls.created_at if table_view.order == "created_at" else cls.updated_at
order_by = [desc(order_column) if table_view.desc else asc(order_column)]
# 对于多态基类,使用 with_polymorphic 预加载所有子类的列
# 这避免了在响应序列化时的延迟加载问题MissingGreenlet 错误)
if issubclass(cls, PolymorphicBaseMixin):
# '*' 表示加载所有子类
polymorphic_cls = with_polymorphic(cls, '*')
statement = select(polymorphic_cls)
else:
statement = select(cls)
if condition is not None:
statement = statement.where(condition)
# 应用时间筛选
for time_filter in cls._build_time_filters(
created_before_datetime, created_after_datetime,
updated_before_datetime, updated_after_datetime
):
statement = statement.where(time_filter)
if join is not None:
# 如果 join 是一个元组,解包它;否则直接使用
if isinstance(join, tuple):
statement = statement.join(*join)
else:
statement = statement.join(join)
if options:
statement = statement.options(*options)
if load:
# 处理多态加载
if load_polymorphic is not None:
target_class = load.property.mapper.class_
# 检查目标类是否继承自 PolymorphicBaseMixin
if not issubclass(target_class, PolymorphicBaseMixin):
raise ValueError(
f"目标类 {target_class.__name__} 不是多态类,"
f"请确保其继承自 PolymorphicBaseMixin"
)
if load_polymorphic == 'all':
# 两阶段查询:获取实际关联的多态类型
subclasses_to_load = await cls._resolve_polymorphic_subclasses(
session, condition, load, target_class
)
else:
subclasses_to_load = load_polymorphic
if subclasses_to_load:
# 关键selectin_polymorphic 必须作为 selectinload 的链式子选项
# 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading
statement = statement.options(
selectinload(load).selectin_polymorphic(subclasses_to_load)
)
else:
statement = statement.options(selectinload(load))
else:
statement = statement.options(selectinload(load))
if order_by is not None:
statement = statement.order_by(*order_by)
if offset:
statement = statement.offset(offset)
if limit:
statement = statement.limit(limit)
if filter:
statement = statement.filter(filter)
if with_for_update:
statement = statement.with_for_update()
result = await session.exec(statement)
if fetch_mode == "one":
return result.one()
elif fetch_mode == "first":
return result.first()
elif fetch_mode == "all":
return list(result.all())
else:
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
@classmethod
async def _resolve_polymorphic_subclasses(
cls: type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None,
load: RelationshipInfo,
target_class: type[PolymorphicBaseMixin]
) -> list[type[PolymorphicBaseMixin]]:
"""
查询实际关联的多态子类类型
通过查询多态鉴别字段确定实际存在的子类类型,
避免加载所有可能的子类表(对于 > 10 个子类的场景有明显收益)。
:param session: 数据库会话
:param condition: 主查询的条件
:param load: 关系属性
:param target_class: 多态基类
:return: 实际关联的子类列表
"""
# 获取多态鉴别字段(会抛出 ValueError 如果未配置)
discriminator = target_class.get_polymorphic_discriminator()
poly_name_col = getattr(target_class, discriminator)
# 获取关系属性
relationship_property = load.property
# 构建查询获取实际的多态类型名称
if relationship_property.secondary is not None:
# 多对多关系:通过中间表查询
secondary = relationship_property.secondary
local_cols = list(relationship_property.local_columns)
type_query = (
select(distinct(poly_name_col))
.select_from(target_class)
.join(secondary)
.where(secondary.c[local_cols[0].name].in_(
select(cls.id).where(condition) if condition is not None else select(cls.id)
))
)
else:
# 一对多关系:通过外键查询
foreign_key_col = relationship_property.local_remote_pairs[0][1]
type_query = (
select(distinct(poly_name_col))
.where(foreign_key_col.in_(
select(cls.id).where(condition) if condition is not None else select(cls.id)
))
)
type_result = await session.exec(type_query)
poly_names = list(type_result.all())
if not poly_names:
return []
# 映射到子类(包含所有层级的具体子类)
identity_map = target_class.get_identity_to_class_map()
return [identity_map[name] for name in poly_names if name in identity_map]
@classmethod
async def count(
cls: type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None = None,
*,
time_filter: TimeFilterRequest | None = None,
created_before_datetime: datetime | None = None,
created_after_datetime: datetime | None = None,
updated_before_datetime: datetime | None = None,
updated_after_datetime: datetime | None = None,
) -> int:
"""
根据条件统计记录数量(支持时间筛选)
使用数据库层面的 COUNT() 聚合函数,比 get() + len() 更高效。
Args:
session: 数据库会话
condition: 查询条件,例如 `User.is_active == True`
time_filter: TimeFilterRequest 对象(优先级更高)
created_before_datetime: 筛选 created_at < datetime 的记录
created_after_datetime: 筛选 created_at >= datetime 的记录
updated_before_datetime: 筛选 updated_at < datetime 的记录
updated_after_datetime: 筛选 updated_at >= datetime 的记录
Returns:
符合条件的记录数量
Examples:
# 统计所有用户
total = await User.count(session)
# 统计激活的虚拟客户端
count = await Client.count(
session,
(Client.user_id == user_id) & (Client.type != ClientTypeEnum.physical) & (Client.is_active == True)
)
# 使用 TimeFilterRequest 进行时间筛选
count = await User.count(session, time_filter=time_filter_request)
# 使用独立时间参数
count = await User.count(
session,
created_after_datetime=datetime(2025, 1, 1),
created_before_datetime=datetime(2025, 2, 1),
)
"""
# time_filter 的时间筛选优先级更高
if isinstance(time_filter, TimeFilterRequest):
if time_filter.created_after_datetime is not None:
created_after_datetime = time_filter.created_after_datetime
if time_filter.created_before_datetime is not None:
created_before_datetime = time_filter.created_before_datetime
if time_filter.updated_after_datetime is not None:
updated_after_datetime = time_filter.updated_after_datetime
if time_filter.updated_before_datetime is not None:
updated_before_datetime = time_filter.updated_before_datetime
statement = select(func.count()).select_from(cls)
# 应用查询条件
if condition is not None:
statement = statement.where(condition)
# 应用时间筛选
for time_condition in cls._build_time_filters(
created_before_datetime, created_after_datetime,
updated_before_datetime, updated_after_datetime
):
statement = statement.where(time_condition)
result = await session.scalar(statement)
return result or 0
@classmethod
async def get_with_count(
cls: type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None = None,
*,
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
options: list | None = None,
load: RelationshipInfo | None = None,
order_by: list[ClauseElement] | None = None,
filter: BinaryExpression | ClauseElement | None = None,
table_view: TableViewRequest | None = None,
load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
) -> 'ListResponse[T]':
"""
获取分页列表及总数,直接返回 ListResponse
同时返回符合条件的记录列表和总数,用于分页场景。
与 get() 方法类似,但固定 fetch_mode="all" 并返回 ListResponse。
注意:如果子类的 get() 方法支持额外参数(如 filter_params
子类应该覆盖此方法以确保 count 和 items 使用相同的过滤条件。
Args:
session: 数据库会话
condition: 查询条件
join: JOIN 的模型类或元组
options: SQLAlchemy 查询选项
load: selectinload 预加载关系
order_by: 排序子句
filter: 附加过滤条件
table_view: 分页排序参数(推荐使用)
load_polymorphic: 多态子类加载选项
Returns:
ListResponse[T]: 包含 count 和 items 的分页响应
Examples:
```python
@router.get("", response_model=ListResponse[CharacterInfoResponse])
async def list_characters(
session: SessionDep,
table_view: TableViewRequestDep
) -> ListResponse[Character]:
return await Character.get_with_count(session, table_view=table_view)
```
"""
# 提取时间筛选参数(用于 count
time_filter: TimeFilterRequest | None = None
if table_view is not None:
time_filter = TimeFilterRequest(
created_after_datetime=table_view.created_after_datetime,
created_before_datetime=table_view.created_before_datetime,
updated_after_datetime=table_view.updated_after_datetime,
updated_before_datetime=table_view.updated_before_datetime,
)
# 获取总数(不带分页限制)
total_count = await cls.count(session, condition, time_filter=time_filter)
# 获取分页数据
items = await cls.get(
session,
condition,
fetch_mode="all",
join=join,
options=options,
load=load,
order_by=order_by,
filter=filter,
table_view=table_view,
load_polymorphic=load_polymorphic,
)
return ListResponse(count=total_count, items=items)
@classmethod
async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | None = None) -> T:
"""
根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常.
这个方法是对 `get` 方法的封装,专门用于处理那种"记录必须存在"的业务场景。
如果记录未找到,它会直接引发 FastAPI 的 `HTTPException`, 而不是返回 `None`.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
id (int): 要查找的记录的主键 ID.
load (Relationship | None): 可选的,用于预加载的关联属性.
Returns:
T: 找到的模型实例.
Raises:
HTTPException: 如果 ID 对应的记录不存在,则抛出状态码为 404 的异常.
"""
instance = await cls.get(session, cls.id == id, load=load)
if not instance:
raise HTTPException(status_code=404, detail="Not found")
return instance
class UUIDTableBaseMixin(TableBaseMixin):
"""
一个使用 UUID 作为主键的异步 CRUD 操作基础模型类 Mixin.
此类继承自 `TableBaseMixin`, 将主键 `id` 的类型覆盖为 `uuid.UUID`
并为新记录自动生成 UUID. 它继承了 `TableBaseMixin` 的所有 CRUD 方法.
Attributes:
id (uuid.UUID): UUID 类型的主键, 在创建时自动生成.
"""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
"""覆盖 `TableBaseMixin` 的 id 字段,使用 UUID 作为主键."""
@override
@classmethod
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T:
"""
根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
此方法覆盖了父类的同名方法,以确保 `id` 参数的类型注解为 `uuid.UUID`,
从而提供更好的类型安全和代码提示.
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
id (uuid.UUID): 要查找的记录的 UUID 主键.
load (Relationship | None): 可选的,用于预加载的关联属性.
Returns:
T: 找到的模型实例.
Raises:
HTTPException: 如果 UUID 对应的记录不存在,则抛出状态码为 404 的异常.
"""
# 类型检查器可能会警告这里的 `id` 类型不匹配超类方法,
# 但在运行时这是正确的,因为超类方法内部的比较 (cls.id == id)
# 会正确处理 UUID 类型。`type: ignore` 用于抑制此警告。
return await super().get_exist_one(session, id, load) # type: ignore

View File

@@ -1,23 +1,98 @@
from enum import StrEnum
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, text
from .base import TableBase
from sqlmodel import Field, Relationship, text, Index
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .download import Download
class Node(TableBase, table=True):
class NodeStatus(StrEnum):
"""节点状态枚举"""
ONLINE = "online"
"""正常"""
OFFLINE = "offline"
"""离线"""
class NodeType(StrEnum):
"""节点类型枚举"""
MASTER = "master"
"""主节点"""
SLAVE = "slave"
"""从节点"""
class Aria2ConfigurationBase(SQLModelBase):
"""Aria2配置基础模型"""
rpc_url: str | None = Field(default=None, max_length=255)
"""RPC地址"""
rpc_secret: str | None = None
"""RPC密钥"""
temp_path: str | None = Field(default=None, max_length=255)
"""临时下载路径"""
max_concurrent: int = Field(default=5, ge=1, le=50)
"""最大并发数"""
timeout: int = Field(default=300, ge=1)
"""请求超时时间(秒)"""
class Aria2Configuration(Aria2ConfigurationBase, TableBaseMixin):
"""Aria2配置模型与Node一对一关联"""
node_id: int = Field(foreign_key="node.id", unique=True, index=True)
"""关联的节点ID"""
# 反向关系
node: "Node" = Relationship(back_populates="aria2_config")
"""关联的节点"""
class Node(SQLModelBase, TableBaseMixin):
"""节点模型"""
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="节点状态: 0=正常, 1=离线")
name: str = Field(max_length=255, unique=True, description="节点名称")
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="节点类型")
server: str = Field(max_length=255, description="节点地址IP或域名")
slave_key: str | None = Field(default=None, description="从机通讯密钥")
master_key: str | None = Field(default=None, description="主机通讯密钥")
aria2_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否启用Aria2")
aria2_options: str | None = Field(default=None, description="Aria2配置 (JSON格式)")
rank: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="节点排序权重")
__table_args__ = (
Index("ix_node_status", "status"),
)
status: NodeStatus = Field(default=NodeStatus.ONLINE, sa_column_kwargs={"server_default": "'online'"})
"""节点状态"""
name: str = Field(max_length=255, unique=True)
"""节点名称"""
type: NodeType
"""节点类型"""
server: str = Field(max_length=255)
"""节点地址IP或域名"""
slave_key: str | None = Field(default=None, max_length=255)
"""从机通讯密钥"""
master_key: str | None = Field(default=None, max_length=255)
"""主机通讯密钥"""
aria2_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否启用Aria2"""
rank: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""节点排序权重"""
# 关系
aria2_config: Aria2Configuration | None = Relationship(
back_populates="node",
sa_relationship_kwargs={"uselist": False},
)
"""Aria2配置"""
downloads: list["Download"] = Relationship(back_populates="node")
"""该节点的下载任务"""

View File

@@ -1,12 +1,13 @@
from datetime import datetime
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Literal
from uuid import UUID
from enum import StrEnum
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index
from .base import SQLModelBase, UUIDTableBase
from .base import SQLModelBase
from .mixin import UUIDTableBaseMixin
if TYPE_CHECKING:
from .user import User
@@ -20,6 +21,43 @@ class ObjectType(StrEnum):
FILE = "file"
FOLDER = "folder"
class StorageType(StrEnum):
"""存储类型枚举"""
LOCAL = "local"
QINIU = "qiniu"
TENCENT = "tencent"
ALIYUN = "aliyun"
ONEDRIVE = "onedrive"
GOOGLE_DRIVE = "google_drive"
DROPBOX = "dropbox"
WEBDAV = "webdav"
REMOTE = "remote"
class FileMetadataBase(SQLModelBase):
"""文件元数据基础模型"""
width: int | None = Field(default=None)
"""图片宽度(像素)"""
height: int | None = Field(default=None)
"""图片高度(像素)"""
duration: float | None = Field(default=None)
"""音视频时长(秒)"""
bitrate: int | None = Field(default=None)
"""比特率kbps"""
mime_type: str | None = Field(default=None, max_length=127)
"""MIME类型"""
checksum_md5: str | None = Field(default=None, max_length=32)
"""MD5校验和"""
checksum_sha256: str | None = Field(default=None, max_length=64)
"""SHA256校验和"""
# ==================== Base 模型 ====================
@@ -99,10 +137,10 @@ class PolicyResponse(SQLModelBase):
name: str
"""策略名称"""
type: Literal["local", "qiniu", "tencent", "aliyun", "onedrive", "google_drive", "dropbox", "webdav", "remote"]
type: StorageType
"""存储类型"""
max_size: int = 0
max_size: int = Field(ge=0, default=0)
"""单文件最大限制单位字节0表示不限制"""
file_type: list[str] | None = None
@@ -127,7 +165,18 @@ class DirectoryResponse(SQLModelBase):
# ==================== 数据库模型 ====================
class Object(ObjectBase, UUIDTableBase, table=True):
class FileMetadata(FileMetadataBase, UUIDTableBaseMixin):
"""文件元数据模型与Object一对一关联"""
object_id: UUID = Field(foreign_key="object.id", unique=True, index=True)
"""关联的对象UUID"""
# 反向关系
object: "Object" = Relationship(back_populates="file_metadata")
"""关联的对象"""
class Object(ObjectBase, UUIDTableBaseMixin):
"""
统一对象模型
@@ -143,7 +192,7 @@ class Object(ObjectBase, UUIDTableBase, table=True):
__table_args__ = (
# 同一父目录下名称唯一(包括 parent_id 为 NULL 的情况)
UniqueConstraint("owner_id", "parent_id", "name", name="uq_object_parent_name"),
# 名称不能包含斜杠
# 名称不能包含斜杠 ([TODO] 还有特殊字符)
CheckConstraint(
"name NOT LIKE '%/%' AND name NOT LIKE '%\\%'",
name="ck_object_name_no_slash",
@@ -168,7 +217,7 @@ class Object(ObjectBase, UUIDTableBase, table=True):
# ==================== 文件专属字段 ====================
source_name: str | None = None
source_name: str | None = Field(default=None, max_length=255)
"""源文件名(仅文件有效)"""
size: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
@@ -177,10 +226,6 @@ class Object(ObjectBase, UUIDTableBase, table=True):
upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
"""分块上传会话ID仅文件有效"""
# [TODO] 拆分
file_metadata: str | None = None
"""文件元数据 (JSON格式),仅文件有效"""
# ==================== 外键 ====================
parent_id: UUID | None = Field(default=None, foreign_key="object.id", index=True)
@@ -201,7 +246,7 @@ class Object(ObjectBase, UUIDTableBase, table=True):
"""存储策略"""
# 自引用关系
parent: Optional["Object"] = Relationship(
parent: "Object" = Relationship(
back_populates="children",
sa_relationship_kwargs={"remote_side": "Object.id"},
)
@@ -211,6 +256,12 @@ class Object(ObjectBase, UUIDTableBase, table=True):
"""子对象(文件和子目录)"""
# 仅文件有效的关系
file_metadata: FileMetadata | None = Relationship(
back_populates="object",
sa_relationship_kwargs={"uselist": False},
)
"""文件元数据(仅文件有效)"""
source_links: list["SourceLink"] = Relationship(back_populates="object")
"""源链接列表(仅文件有效)"""

View File

@@ -1,25 +1,58 @@
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship
from .base import TableBase
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .user import User
class Order(TableBase, table=True):
class OrderStatus(StrEnum):
"""订单状态枚举"""
PENDING = "pending"
"""待支付"""
COMPLETED = "completed"
"""已完成"""
CANCELLED = "cancelled"
"""已取消"""
class OrderType(StrEnum):
"""订单类型枚举"""
# [TODO] 补充具体订单类型
pass
class Order(SQLModelBase, TableBaseMixin):
"""订单模型"""
order_no: str = Field(max_length=255, unique=True, index=True, description="订单号,唯一")
type: int = Field(description="订单类型")
method: str | None = Field(default=None, max_length=255, description="支付方式")
product_id: int | None = Field(default=None, description="商品ID")
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="购买数量")
name: str = Field(max_length=255, description="商品名称")
price: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="订单价格(分)")
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="订单状态: 0=待支付, 1=已完成, 2=已取消")
order_no: str = Field(max_length=255, unique=True, index=True)
"""订单号,唯一"""
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""订单类型 [TODO] 待定义枚举"""
method: str | None = Field(default=None, max_length=255)
"""支付方式"""
product_id: int | None = Field(default=None)
"""商品ID"""
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
"""购买数量"""
name: str = Field(max_length=255)
"""商品名称"""
price: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""订单价格(分)"""
status: OrderStatus = Field(default=OrderStatus.PENDING, sa_column_kwargs={"server_default": "'pending'"})
"""订单状态"""
# 外键
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")

View File

@@ -4,7 +4,8 @@ from uuid import UUID
from enum import StrEnum
from sqlmodel import Field, Relationship, text
from .base import SQLModelBase, UUIDTableBase
from .base import SQLModelBase
from .mixin import UUIDTableBaseMixin
if TYPE_CHECKING:
from .object import Object
@@ -29,16 +30,16 @@ class PolicyType(StrEnum):
class PolicyOptionsBase(SQLModelBase):
"""存储策略选项的基础模型"""
token: str | None = None
token: str | None = Field(default=None)
"""访问令牌"""
file_type: str | None = None
file_type: str | None = Field(default=None)
"""允许的文件类型"""
mimetype: str | None = None
mimetype: str | None = Field(default=None, max_length=127)
"""MIME类型"""
od_redirect: str | None = None
od_redirect: str | None = Field(default=None, max_length=255)
"""OneDrive重定向地址"""
chunk_size: int = Field(default=52428800, sa_column_kwargs={"server_default": "52428800"})
@@ -48,7 +49,7 @@ class PolicyOptionsBase(SQLModelBase):
"""是否使用S3路径风格"""
class PolicyOptions(PolicyOptionsBase, UUIDTableBase, table=True):
class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin):
"""存储策略选项模型与Policy一对一关联"""
policy_id: UUID = Field(foreign_key="policy.id", unique=True)
@@ -59,7 +60,7 @@ class PolicyOptions(PolicyOptionsBase, UUIDTableBase, table=True):
"""关联的策略"""
class Policy(UUIDTableBase, table=True):
class Policy(SQLModelBase, UUIDTableBaseMixin):
"""存储策略模型"""
name: str = Field(max_length=255, unique=True)

View File

@@ -1,10 +1,22 @@
from sqlmodel import Field, text
from .base import TableBase
from enum import StrEnum
class Redeem(TableBase, table=True):
from sqlmodel import Field, text
from .base import SQLModelBase
from .mixin import TableBaseMixin
class RedeemType(StrEnum):
"""兑换码类型枚举"""
# [TODO] 补充具体兑换码类型
pass
class Redeem(SQLModelBase, TableBaseMixin):
"""兑换码模型"""
type: int = Field(description="兑换码类型")
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""兑换码类型 [TODO] 待定义枚举"""
product_id: int | None = Field(default=None, description="关联的商品/权益ID")
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="可兑换数量/时长等")
code: str = Field(unique=True, index=True, description="兑换码,唯一")

View File

@@ -1,15 +1,26 @@
from enum import StrEnum
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship
from .base import TableBase
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .share import Share
class Report(TableBase, table=True):
class ReportReason(StrEnum):
"""举报原因枚举"""
# [TODO] 补充具体举报原因
pass
class Report(SQLModelBase, TableBaseMixin):
"""举报模型"""
reason: int = Field(description="举报原因代码")
reason: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""举报原因 [TODO] 待定义枚举"""
description: str | None = Field(default=None, max_length=255, description="补充描述")
# 外键

View File

@@ -1,15 +1,12 @@
"""
通用响应模型定义
"""
from typing import Any
import uuid
from sqlmodel import Field
from .base import SQLModelBase
# [TODO] 未来把这拆了,直接按需返回状态码
class ResponseBase(SQLModelBase):
"""通用响应模型"""

View File

@@ -1,9 +1,10 @@
from typing import Literal
from enum import StrEnum
from sqlmodel import Field, UniqueConstraint
from .base import TableBase, SQLModelBase
from enum import StrEnum
from .base import SQLModelBase
from .mixin import TableBaseMixin
# ==================== DTO 模型 ====================
@@ -72,7 +73,7 @@ class SettingsType(StrEnum):
WOPI = "wopi"
# 数据库模型
class Setting(TableBase, table=True):
class Setting(SQLModelBase, TableBaseMixin):
"""设置模型"""
__table_args__ = (UniqueConstraint("type", "name", name="uq_setting_type_name"),)

View File

@@ -5,7 +5,8 @@ from uuid import UUID
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
from .base import TableBase
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .user import User
@@ -13,7 +14,7 @@ if TYPE_CHECKING:
from .object import Object
class Share(TableBase, table=True):
class Share(SQLModelBase, TableBaseMixin):
"""分享模型"""
__table_args__ = (
@@ -38,10 +39,10 @@ class Share(TableBase, table=True):
downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""下载次数"""
remain_downloads: int | None = None
remain_downloads: int | None = Field(default=None)
"""剩余下载次数 (NULL为不限制)"""
expires: datetime | None = None
expires: datetime | None = Field(default=None)
"""过期时间 (NULL为永不过期)"""
preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})

View File

@@ -4,13 +4,14 @@ from uuid import UUID
from sqlmodel import Field, Relationship, Index
from .base import TableBase
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .object import Object
class SourceLink(TableBase, table=True):
class SourceLink(SQLModelBase, TableBaseMixin):
"""链接模型"""
__table_args__ = (

View File

@@ -1,16 +1,17 @@
from typing import Optional, TYPE_CHECKING
from typing import TYPE_CHECKING
from datetime import datetime
from uuid import UUID
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import TableBase
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .user import User
class StoragePack(TableBase, table=True):
class StoragePack(SQLModelBase, TableBaseMixin):
"""容量包模型"""
name: str = Field(max_length=255, description="容量包名称")

View File

@@ -1,24 +1,41 @@
from typing import Optional, TYPE_CHECKING
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from datetime import datetime
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
from .base import TableBase
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .user import User
class Tag(TableBase, table=True):
class TagType(StrEnum):
"""标签类型枚举"""
MANUAL = "manual"
"""手动标签"""
AUTOMATIC = "automatic"
"""自动标签"""
class Tag(SQLModelBase, TableBaseMixin):
"""标签模型"""
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_tag_name_user"),)
name: str = Field(max_length=255, description="标签名称")
icon: str | None = Field(default=None, max_length=255, description="标签图标")
color: str | None = Field(default=None, max_length=255, description="标签颜色")
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="标签类型: 0=手动, 1=自动")
name: str = Field(max_length=255)
"""标签名称"""
icon: str | None = Field(default=None, max_length=255)
"""标签图标"""
color: str | None = Field(default=None, max_length=255)
"""标签颜色"""
type: TagType = Field(default=TagType.MANUAL, sa_column_kwargs={"server_default": "'manual'"})
"""标签类型"""
expression: str | None = Field(default=None, description="自动标签的匹配表达式")
# 外键

View File

@@ -1,32 +1,96 @@
from typing import Optional, TYPE_CHECKING
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from datetime import datetime
from sqlmodel import Field, Relationship, CheckConstraint
from sqlmodel import Field, Relationship, CheckConstraint, Index
from .base import TableBase
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .user import User
from .download import Download
class Task(TableBase, table=True):
class TaskStatus(StrEnum):
"""任务状态枚举"""
QUEUED = "queued"
"""排队中"""
RUNNING = "running"
"""处理中"""
COMPLETED = "completed"
"""已完成"""
ERROR = "error"
"""错误"""
class TaskType(StrEnum):
"""任务类型枚举"""
# [TODO] 补充具体任务类型
pass
class TaskPropsBase(SQLModelBase):
"""任务属性基础模型"""
source_path: str | None = None
"""源路径"""
dest_path: str | None = None
"""目标路径"""
file_ids: str | None = None
"""文件ID列表逗号分隔"""
# [TODO] 根据业务需求补充更多字段
class TaskProps(TaskPropsBase, TableBaseMixin):
"""任务属性模型与Task一对一关联"""
task_id: int = Field(foreign_key="task.id", primary_key=True)
"""关联的任务ID"""
# 反向关系
task: "Task" = Relationship(back_populates="props")
"""关联的任务"""
class Task(SQLModelBase, TableBaseMixin):
"""任务模型"""
__table_args__ = (
CheckConstraint("progress BETWEEN 0 AND 100", name="ck_task_progress_range"),
Index("ix_task_status", "status"),
Index("ix_task_user_status", "user_id", "status"),
)
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务状态: 0=排队中, 1=处理中, 2=完成, 3=错误")
type: int = Field(description="任务类型")
progress: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务进度 (0-100)")
error: str | None = Field(default=None, description="错误信息")
props: str | None = Field(default=None, description="任务属性 (JSON格式)")
status: TaskStatus = Field(default=TaskStatus.QUEUED, sa_column_kwargs={"server_default": "'queued'"})
"""任务状态"""
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""任务类型 [TODO] 待定义枚举"""
progress: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0, le=100)
"""任务进度0-100"""
error: str | None = Field(default=None)
"""错误信息"""
# 外键
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
user_id: UUID = Field(foreign_key="user.id", index=True)
"""所属用户UUID"""
# 关系
props: TaskProps | None = Relationship(
back_populates="task",
sa_relationship_kwargs={"uselist": False},
)
"""任务属性"""
user: "User" = Relationship(back_populates="tasks")
"""所属用户"""
downloads: list["Download"] = Relationship(back_populates="task")
"""关联的下载任务"""

View File

@@ -1,11 +1,12 @@
from datetime import datetime
from enum import StrEnum
from typing import Literal, Optional, TYPE_CHECKING
from typing import Literal, TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship
from .base import SQLModelBase, UUIDTableBase
from .base import SQLModelBase
from .mixin import UUIDTableBaseMixin
if TYPE_CHECKING:
from .group import Group
@@ -19,15 +20,6 @@ if TYPE_CHECKING:
from .user_authn import UserAuthn
from .webdav import WebDAV
"""
Option 需求
- 主题 跟随系统/浅色/深色
- 颜色方案 参考 ThemeResponse
- 语言
- 时区
- 切换到不同存储策略是否提醒
"""
class AvatarType(StrEnum):
"""头像类型枚举"""
@@ -42,6 +34,13 @@ class ThemeType(StrEnum):
DARK = "dark"
SYSTEM = "system"
class UserStatus(StrEnum):
"""用户状态枚举"""
ACTIVE = "active"
ADMIN_BANNED = "admin_banned"
SYSTEM_BANNED = "system_banned"
# ==================== Base 模型 ====================
@@ -51,8 +50,8 @@ class UserBase(SQLModelBase):
username: str
"""用户名"""
status: bool = True
"""用户状态: True=正常, False=封禁"""
status: UserStatus = UserStatus.ACTIVE
"""用户状态"""
score: int = 0
"""用户积分"""
@@ -72,7 +71,7 @@ class LoginRequest(SQLModelBase):
captcha: str | None = None
"""验证码"""
two_fa_code: str | None = None
two_fa_code: int | None = Field(min_length=6, max_length=6)
"""两步验证代码"""
@@ -192,9 +191,6 @@ class UserSettingResponse(SQLModelBase):
prefer_theme: str = "#5898d4"
"""用户首选主题"""
qq: str | None = None
"""QQ号"""
themes: dict[str, str] = {}
"""用户主题配置"""
@@ -216,7 +212,7 @@ UserSettingResponse.model_rebuild()
# ==================== 数据库模型 ====================
class User(UserBase, UUIDTableBase, table=True):
class User(UserBase, UUIDTableBaseMixin):
"""用户模型"""
username: str = Field(max_length=50, unique=True, index=True)
@@ -243,7 +239,7 @@ class User(UserBase, UUIDTableBase, table=True):
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
"""用户积分"""
group_expires: datetime | None = None
group_expires: datetime | None = Field(default=None)
"""当前用户组过期时间"""
# Option 相关字段
@@ -266,13 +262,13 @@ class User(UserBase, UUIDTableBase, table=True):
# 关系
group: "Group" = Relationship(
back_populates="user",
back_populates="users",
sa_relationship_kwargs={
"foreign_keys": "User.group_id"
}
)
previous_group: Optional["Group"] = Relationship(
back_populates="previous_user",
previous_group: "Group" = Relationship(
back_populates="previous_users",
sa_relationship_kwargs={
"foreign_keys": "User.previous_group_id"
}

View File

@@ -4,7 +4,8 @@ from uuid import UUID
from sqlalchemy import Column, Text
from sqlmodel import Field, Relationship
from .base import TableBase, SQLModelBase
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .user import User
@@ -24,7 +25,7 @@ class AuthnResponse(SQLModelBase):
# ==================== 数据库模型 ====================
class UserAuthn(TableBase, table=True):
class UserAuthn(SQLModelBase, TableBaseMixin):
"""用户 WebAuthn 凭证模型,与 User 为多对一关系"""
credential_id: str = Field(max_length=255, unique=True, index=True)

View File

@@ -4,12 +4,13 @@ from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime
from .base import TableBase
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .user import User
class WebDAV(TableBase, table=True):
class WebDAV(SQLModelBase, TableBaseMixin):
"""WebDAV账户模型"""
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),)

View File

@@ -6,7 +6,7 @@ from fastapi.security import OAuth2PasswordBearer
oauth2_scheme = OAuth2PasswordBearer(
scheme_name='获取 JWT Bearer 令牌',
description='用于获取 JWT Bearer 令牌,需要以表单的形式提交',
tokenUrl="/api/user/session",
tokenUrl="/api/v1/user/session",
)
SECRET_KEY = ''