diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 35410ca..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -# 默认忽略的文件 -/shelf/ -/workspace.xml -# 基于编辑器的 HTTP 客户端请求 -/httpRequests/ -# Datasource local storage ignored files -/dataSources/ -/dataSources.local.xml diff --git a/.idea/Server.iml b/.idea/Server.iml deleted file mode 100644 index c6352eb..0000000 --- a/.idea/Server.iml +++ /dev/null @@ -1,17 +0,0 @@ - - - - - - - - - - - - - - \ No newline at end of file diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml deleted file mode 100644 index a55e7a1..0000000 --- a/.idea/codeStyles/codeStyleConfig.xml +++ /dev/null @@ -1,5 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/copilot.data.migration.agent.xml b/.idea/copilot.data.migration.agent.xml deleted file mode 100644 index 4ea72a9..0000000 --- a/.idea/copilot.data.migration.agent.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - \ No newline at end of file diff --git a/.idea/copilot.data.migration.ask.xml b/.idea/copilot.data.migration.ask.xml deleted file mode 100644 index 7ef04e2..0000000 --- a/.idea/copilot.data.migration.ask.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - \ No newline at end of file diff --git a/.idea/copilot.data.migration.ask2agent.xml b/.idea/copilot.data.migration.ask2agent.xml deleted file mode 100644 index 1f2ea11..0000000 --- a/.idea/copilot.data.migration.ask2agent.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - \ No newline at end of file diff --git a/.idea/copilot.data.migration.edit.xml b/.idea/copilot.data.migration.edit.xml deleted file mode 100644 index 8648f94..0000000 --- a/.idea/copilot.data.migration.edit.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - \ No newline at end of file diff --git a/.idea/dictionaries/project.xml b/.idea/dictionaries/project.xml deleted file mode 100644 index 4787784..0000000 --- a/.idea/dictionaries/project.xml +++ /dev/null @@ -1,3 +0,0 @@ - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index 105ce2d..0000000 --- a/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/material_theme_project_new.xml b/.idea/material_theme_project_new.xml deleted file mode 100644 index dbb3a47..0000000 --- a/.idea/material_theme_project_new.xml +++ /dev/null @@ -1,17 +0,0 @@ - - - - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index caad0e7..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,7 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index 8f3a104..0000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 35eb1dd..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/main.py b/main.py index f53c2a4..72c4886 100644 --- a/main.py +++ b/main.py @@ -25,7 +25,7 @@ app = FastAPI( ) # 挂载路由 -app.include_router(router, prefix='/api') +app.include_router(router) # 防止直接运行 main.py if __name__ == "__main__": diff --git a/models/README.md b/models/README.md index 2f192f1..2732592 100644 --- a/models/README.md +++ b/models/README.md @@ -228,7 +228,6 @@ models/ | `source_name` | `str?` | 源文件名(仅文件) | | `size` | `int` | 文件大小(字节),目录为 0 | | `upload_session_id` | `str?` | 分块上传会话 ID | -| `file_metadata` | `str?` | 文件元数据(JSON 格式) | | `parent_id` | `UUID?` | 父目录(外键,NULL 表示根目录) | | `owner_id` | `UUID` | 所有者用户(外键) | | `policy_id` | `UUID` | 存储策略(外键) | @@ -237,9 +236,35 @@ models/ - 同一父目录下名称唯一 - 名称不能包含斜杠 +**关系**: +- `metadata`: 一对一关联 FileMetadata + --- -### 9. SourceLink(源链接) +### 9. FileMetadata(文件元数据) + +**表名**: `filemetadata` +**基类**: `UUIDTableBase` + +| 字段 | 类型 | 说明 | +|------|------|------| +| `id` | `UUID` | 主键 | +| `object_id` | `UUID` | 关联的对象(外键,唯一) | +| `width` | `int?` | 图片/视频宽度 | +| `height` | `int?` | 图片/视频高度 | +| `duration` | `float?` | 音视频时长(秒) | +| `mime_type` | `str?` | MIME类型 | +| `bit_rate` | `int?` | 比特率 | +| `sample_rate` | `int?` | 采样率 | +| `channels` | `int?` | 音频通道数 | +| `codec` | `str?` | 编解码器 | +| `frame_rate` | `float?` | 视频帧率 | +| `orientation` | `int?` | 图片方向 | +| `has_thumbnail` | `bool` | 是否有缩略图 | + +--- + +### 10. SourceLink(源链接) **表名**: `sourcelink` **基类**: `TableBase` @@ -253,7 +278,7 @@ models/ --- -### 10. Share(分享) +### 11. Share(分享) **表名**: `share` **基类**: `TableBase` @@ -275,7 +300,7 @@ models/ --- -### 11. Report(举报) +### 12. Report(举报) **表名**: `report` **基类**: `TableBase` @@ -289,7 +314,7 @@ models/ --- -### 12. Tag(标签) +### 13. Tag(标签) **表名**: `tag` **基类**: `TableBase` @@ -300,7 +325,7 @@ models/ | `name` | `str` | 标签名称 | | `icon` | `str?` | 标签图标 | | `color` | `str?` | 标签颜色 | -| `type` | `int` | 标签类型:0=手动,1=自动 | +| `type` | `TagType` | 标签类型:manual/automatic | | `expression` | `str?` | 自动标签的匹配表达式 | | `user_id` | `UUID` | 所属用户(外键) | @@ -308,7 +333,7 @@ models/ --- -### 13. Task(任务) +### 14. Task(任务) **表名**: `task` **基类**: `TableBase` @@ -316,16 +341,35 @@ models/ | 字段 | 类型 | 说明 | |------|------|------| | `id` | `int` | 主键 | -| `status` | `int` | 任务状态:0=排队中,1=处理中,2=完成,3=错误 | -| `type` | `int` | 任务类型 | +| `status` | `TaskStatus` | 任务状态:queued/running/completed/error | +| `type` | `int` | 任务类型([TODO] 待定义枚举) | | `progress` | `int` | 任务进度(0-100) | | `error` | `str?` | 错误信息 | -| `props` | `str?` | 任务属性(JSON 格式) | | `user_id` | `UUID` | 所属用户(外键) | +**索引**: `ix_task_status`, `ix_task_user_status` + +**关系**: +- `props`: 一对一关联 TaskProps +- `downloads`: 一对多关联 Download + --- -### 14. Download(离线下载) +### 15. TaskProps(任务属性) + +**表名**: `taskprops` +**基类**: `SQLModelBase` + +| 字段 | 类型 | 说明 | +|------|------|------| +| `task_id` | `int` | 关联的任务(外键,主键) | +| `source_path` | `str?` | 源路径 | +| `dest_path` | `str?` | 目标路径 | +| `file_ids` | `str?` | 文件ID列表(逗号分隔) | + +--- + +### 16. Download(离线下载) **表名**: `download` **基类**: `UUIDTableBase` @@ -333,15 +377,14 @@ models/ | 字段 | 类型 | 说明 | |------|------|------| | `id` | `UUID` | 主键 | -| `status` | `int` | 下载状态:0=进行中,1=完成,2=错误 | -| `type` | `int` | 任务类型 | +| `status` | `DownloadStatus` | 下载状态:running/completed/error | +| `type` | `int` | 任务类型([TODO] 待定义枚举) | | `source` | `str` | 来源 URL 或标识 | | `total_size` | `int` | 总大小(字节) | | `downloaded_size` | `int` | 已下载大小(字节) | | `g_id` | `str?` | Aria2 GID | | `speed` | `int` | 下载速度(bytes/s) | | `parent` | `str?` | 父任务标识 | -| `attrs` | `str?` | 额外属性(JSON 格式) | | `error` | `str?` | 错误信息 | | `dst` | `str` | 目标存储路径 | | `user_id` | `UUID` | 所属用户(外键) | @@ -350,9 +393,52 @@ models/ **约束**: 同一节点下 g_id 唯一 +**索引**: `ix_download_status`, `ix_download_user_status` + +**关系**: +- `aria2_info`: 一对一关联 DownloadAria2Info +- `aria2_files`: 一对多关联 DownloadAria2File + --- -### 15. Node(节点) +### 17. DownloadAria2Info(Aria2下载信息) + +**表名**: `downloadaria2info` +**基类**: `SQLModelBase` + +| 字段 | 类型 | 说明 | +|------|------|------| +| `download_id` | `UUID` | 关联的下载任务(外键,主键) | +| `info_hash` | `str?` | InfoHash(BT种子) | +| `piece_length` | `int` | 分片大小 | +| `num_pieces` | `int` | 分片数量 | +| `num_seeders` | `int` | 做种人数 | +| `connections` | `int` | 连接数 | +| `upload_speed` | `int` | 上传速度(bytes/s) | +| `upload_length` | `int` | 已上传大小(字节) | +| `error_code` | `str?` | 错误代码 | +| `error_message` | `str?` | 错误信息 | + +--- + +### 18. DownloadAria2File(Aria2下载文件) + +**表名**: `downloadaria2file` +**基类**: `TableBase` + +| 字段 | 类型 | 说明 | +|------|------|------| +| `id` | `int` | 主键 | +| `download_id` | `UUID` | 关联的下载任务(外键) | +| `file_index` | `int` | 文件索引(从1开始) | +| `path` | `str` | 文件路径 | +| `length` | `int` | 文件大小(字节) | +| `completed_length` | `int` | 已完成大小(字节) | +| `is_selected` | `bool` | 是否选中下载 | + +--- + +### 19. Node(节点) **表名**: `node` **基类**: `TableBase` @@ -360,19 +446,41 @@ models/ | 字段 | 类型 | 说明 | |------|------|------| | `id` | `int` | 主键 | -| `status` | `int` | 节点状态:0=正常,1=离线 | +| `status` | `NodeStatus` | 节点状态:online/offline | | `name` | `str` | 节点名称,唯一 | -| `type` | `int` | 节点类型 | +| `type` | `int` | 节点类型([TODO] 待定义枚举) | | `server` | `str` | 节点地址(IP 或域名) | | `slave_key` | `str?` | 从机通讯密钥 | | `master_key` | `str?` | 主机通讯密钥 | | `aria2_enabled` | `bool` | 是否启用 Aria2 | -| `aria2_options` | `str?` | Aria2 配置(JSON 格式) | | `rank` | `int` | 节点排序权重 | +**索引**: `ix_node_status` + +**关系**: +- `aria2_config`: 一对一关联 Aria2Configuration +- `downloads`: 一对多关联 Download + --- -### 16. Order(订单) +### 20. Aria2Configuration(Aria2配置) + +**表名**: `aria2configuration` +**基类**: `TableBase` + +| 字段 | 类型 | 说明 | +|------|------|------| +| `id` | `int` | 主键 | +| `node_id` | `int` | 关联的节点(外键,唯一) | +| `rpc_url` | `str?` | RPC地址 | +| `rpc_secret` | `str?` | RPC密钥 | +| `temp_path` | `str?` | 临时下载路径 | +| `max_concurrent` | `int` | 最大并发数(1-50,默认5) | +| `timeout` | `int` | 请求超时时间(秒,默认300) | + +--- + +### 21. Order(订单) **表名**: `order` **基类**: `TableBase` @@ -381,18 +489,18 @@ models/ |------|------|------| | `id` | `int` | 主键 | | `order_no` | `str` | 订单号,唯一 | -| `type` | `int` | 订单类型 | +| `type` | `int` | 订单类型([TODO] 待定义枚举) | | `method` | `str?` | 支付方式 | | `product_id` | `int?` | 商品 ID | | `num` | `int` | 购买数量 | | `name` | `str` | 商品名称 | | `price` | `int` | 订单价格(分) | -| `status` | `int` | 订单状态:0=待支付,1=已完成,2=已取消 | +| `status` | `OrderStatus` | 订单状态:pending/completed/cancelled | | `user_id` | `UUID` | 所属用户(外键) | --- -### 17. Redeem(兑换码) +### 22. Redeem(兑换码) **表名**: `redeem` **基类**: `TableBase` @@ -400,7 +508,7 @@ models/ | 字段 | 类型 | 说明 | |------|------|------| | `id` | `int` | 主键 | -| `type` | `int` | 兑换码类型 | +| `type` | `int` | 兑换码类型([TODO] 待定义枚举) | | `product_id` | `int?` | 关联的商品/权益 ID | | `num` | `int` | 可兑换数量/时长等 | | `code` | `str` | 兑换码,唯一 | @@ -408,7 +516,7 @@ models/ --- -### 18. StoragePack(容量包) +### 23. StoragePack(容量包) **表名**: `storagepack` **基类**: `TableBase` @@ -424,7 +532,7 @@ models/ --- -### 19. WebDAV(WebDAV 账户) +### 24. WebDAV(WebDAV 账户) **表名**: `webdav` **基类**: `TableBase` @@ -443,7 +551,7 @@ models/ --- -### 20. Setting(系统设置) +### 25. Setting(系统设置) **表名**: `setting` **基类**: `TableBase` @@ -467,23 +575,39 @@ models/ ### 一对一关系 ``` -┌─────────────────────────────────────────────────────────┐ -│ 一对一关系 │ -├─────────────────────────────────────────────────────────┤ -│ │ -│ Group ◄────────────────────────> GroupOptions │ -│ group_id (unique FK) │ -│ │ -│ Policy ◄───────────────────────> PolicyOptions │ -│ policy_id (unique FK) │ -│ │ -└─────────────────────────────────────────────────────────┘ +┌───────────────────────────────────────────────────────────────────┐ +│ 一对一关系 │ +├───────────────────────────────────────────────────────────────────┤ +│ │ +│ Group ◄─────────────────────────> GroupOptions │ +│ group_id (unique FK) │ +│ │ +│ Policy ◄────────────────────────> PolicyOptions │ +│ policy_id (unique FK) │ +│ │ +│ Object ◄────────────────────────> FileMetadata │ +│ object_id (unique FK) │ +│ │ +│ Node ◄──────────────────────────> Aria2Configuration │ +│ node_id (unique FK) │ +│ │ +│ Task ◄──────────────────────────> TaskProps │ +│ task_id (PK/FK) │ +│ │ +│ Download ◄──────────────────────> DownloadAria2Info │ +│ download_id (PK/FK) │ +│ │ +└───────────────────────────────────────────────────────────────────┘ ``` | 主表 | 从表 | 外键 | 说明 | |------|------|------|------| | Group | GroupOptions | `group_id` (unique) | 每个用户组有且仅有一个选项配置 | | Policy | PolicyOptions | `policy_id` (unique) | 每个存储策略有且仅有一个扩展选项 | +| Object | FileMetadata | `object_id` (unique) | 每个文件对象有且仅有一个元数据 | +| Node | Aria2Configuration | `node_id` (unique) | 每个节点有且仅有一个 Aria2 配置 | +| Task | TaskProps | `task_id` (PK) | 每个任务有且仅有一个属性配置 | +| Download | DownloadAria2Info | `download_id` (PK) | 每个下载任务有且仅有一个 Aria2 信息 | --- @@ -540,6 +664,7 @@ models/ | **Share** | Report | `share_id` | 分享的举报 | | **Task** | Download | `task_id` | 任务关联的下载 | | **Node** | Download | `node_id` | 节点执行的下载任务 | +| **Download** | DownloadAria2File | `download_id` | 下载任务的文件列表 | --- @@ -666,6 +791,56 @@ class AvatarType(StrEnum): FILE = "file" # 自定义文件 ``` +### TagType +```python +class TagType(StrEnum): + MANUAL = "manual" # 手动标签 + AUTOMATIC = "automatic" # 自动标签 +``` + +### TaskStatus +```python +class TaskStatus(StrEnum): + QUEUED = "queued" # 排队中 + RUNNING = "running" # 处理中 + COMPLETED = "completed" # 已完成 + ERROR = "error" # 错误 +``` + +### DownloadStatus +```python +class DownloadStatus(StrEnum): + RUNNING = "running" # 进行中 + COMPLETED = "completed" # 已完成 + ERROR = "error" # 错误 +``` + +### NodeStatus +```python +class NodeStatus(StrEnum): + ONLINE = "online" # 正常 + OFFLINE = "offline" # 离线 +``` + +### OrderStatus +```python +class OrderStatus(StrEnum): + PENDING = "pending" # 待支付 + COMPLETED = "completed" # 已完成 + CANCELLED = "cancelled" # 已取消 +``` + +### 待定义枚举([TODO]) + +以下枚举已定义框架,具体值待业务需求确定: + +- `TaskType` - 任务类型 +- `DownloadType` - 下载类型 +- `NodeType` - 节点类型 +- `OrderType` - 订单类型 +- `RedeemType` - 兑换码类型 +- `ReportReason` - 举报原因 + --- ## DTO 模型 diff --git a/models/__init__.py b/models/__init__.py index d4f1d78..3579b57 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -14,12 +14,27 @@ from .user import ( from .user_authn import AuthnResponse, UserAuthn from .color import ThemeResponse -from .download import Download +from .download import ( + Download, + DownloadAria2File, + DownloadAria2Info, + DownloadAria2InfoBase, + DownloadStatus, + DownloadType, +) +from .node import ( + Aria2Configuration, + Aria2ConfigurationBase, + Node, + NodeStatus, + NodeType, +) from .group import Group, GroupBase, GroupOptions, GroupOptionsBase, GroupResponse -from .node import Node from .object import ( DirectoryCreateRequest, DirectoryResponse, + FileMetadata, + FileMetadataBase, Object, ObjectBase, ObjectDeleteRequest, @@ -28,16 +43,16 @@ from .object import ( ObjectType, PolicyResponse, ) -from .order import Order +from .order import Order, OrderStatus, OrderType from .policy import Policy, PolicyOptions, PolicyOptionsBase, PolicyType -from .redeem import Redeem -from .report import Report +from .redeem import Redeem, RedeemType +from .report import Report, ReportReason from .setting import Setting, SettingsType, SiteConfigResponse from .share import Share from .source_link import SourceLink from .storage_pack import StoragePack -from .tag import Tag -from .task import Task +from .tag import Tag, TagType +from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType from .webdav import WebDAV from .database import engine, get_session diff --git a/models/base/README.md b/models/base/README.md new file mode 100644 index 0000000..346771f --- /dev/null +++ b/models/base/README.md @@ -0,0 +1,657 @@ +# SQLModels Base Module + +This module provides `SQLModelBase`, the root base class for all SQLModel models in this project. It includes a custom metaclass with automatic type injection and Python 3.14 compatibility. + +**Note**: Table base classes (`TableBaseMixin`, `UUIDTableBaseMixin`) and polymorphic utilities have been migrated to the [`sqlmodels.mixin`](../mixin/README.md) module. See the mixin documentation for CRUD operations, polymorphic inheritance patterns, and pagination utilities. + +## Table of Contents + +- [Overview](#overview) +- [Migration Notice](#migration-notice) +- [Python 3.14 Compatibility](#python-314-compatibility) +- [Core Component](#core-component) + - [SQLModelBase](#sqlmodelbase) +- [Metaclass Features](#metaclass-features) + - [Automatic sa_type Injection](#automatic-sa_type-injection) + - [Table Configuration](#table-configuration) + - [Polymorphic Support](#polymorphic-support) +- [Custom Types Integration](#custom-types-integration) +- [Best Practices](#best-practices) +- [Troubleshooting](#troubleshooting) + +## Overview + +The `sqlmodels.base` module provides `SQLModelBase`, the foundational base class for all SQLModel models. It features: + +- **Smart metaclass** that automatically extracts and injects SQLAlchemy types from type annotations +- **Python 3.14 compatibility** through comprehensive PEP 649/749 support +- **Flexible configuration** through class parameters and automatic docstring support +- **Type-safe annotations** with automatic validation + +All models in this project should directly or indirectly inherit from `SQLModelBase`. + +--- + +## Migration Notice + +As of the recent refactoring, the following components have been moved: + +| Component | Old Location | New Location | +|-----------|-------------|--------------| +| `TableBase` → `TableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` | +| `UUIDTableBase` → `UUIDTableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` | +| `PolymorphicBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` | +| `create_subclass_id_mixin()` | `sqlmodels.base` | `sqlmodels.mixin` | +| `AutoPolymorphicIdentityMixin` | `sqlmodels.base` | `sqlmodels.mixin` | +| `TableViewRequest` | `sqlmodels.base` | `sqlmodels.mixin` | +| `now()`, `now_date()` | `sqlmodels.base` | `sqlmodels.mixin` | + +**Update your imports**: + +```python +# ❌ Old (deprecated) +from sqlmodels.base import TableBase, UUIDTableBase + +# ✅ New (correct) +from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin +``` + +For detailed documentation on table mixins, CRUD operations, and polymorphic patterns, see [`sqlmodels/mixin/README.md`](../mixin/README.md). + +--- + +## Python 3.14 Compatibility + +### Overview + +This module provides full compatibility with **Python 3.14's PEP 649** (Deferred Evaluation of Annotations) and **PEP 749** (making it the default). + +**Key Changes in Python 3.14**: +- Annotations are no longer evaluated at class definition time +- Type hints are stored as deferred code objects +- `__annotate__` function generates annotations on demand +- Forward references become `ForwardRef` objects + +### Implementation Strategy + +We use **`typing.get_type_hints()`** as the universal annotations resolver: + +```python +def _resolve_annotations(attrs: dict[str, Any]) -> tuple[...]: + # Create temporary proxy class + temp_cls = type('AnnotationProxy', (object,), dict(attrs)) + + # Use get_type_hints with include_extras=True + evaluated = get_type_hints( + temp_cls, + globalns=module_globals, + localns=localns, + include_extras=True # Preserve Annotated metadata + ) + + return dict(evaluated), {}, module_globals, localns +``` + +**Why `get_type_hints()`?** +- ✅ Works across Python 3.10-3.14+ +- ✅ Handles PEP 649 automatically +- ✅ Preserves `Annotated` metadata (with `include_extras=True`) +- ✅ Resolves forward references +- ✅ Recommended by Python documentation + +### SQLModel Compatibility Patch + +**Problem**: SQLModel's `get_sqlalchemy_type()` doesn't recognize custom types with `__sqlmodel_sa_type__` attribute. + +**Solution**: Global monkey-patch that checks for SQLAlchemy type before falling back to original logic: + +```python +if sys.version_info >= (3, 14): + def _patched_get_sqlalchemy_type(field): + annotation = getattr(field, 'annotation', None) + if annotation is not None: + # Priority 1: Check __sqlmodel_sa_type__ attribute + # Handles NumpyVector[dims, dtype] and similar custom types + if hasattr(annotation, '__sqlmodel_sa_type__'): + return annotation.__sqlmodel_sa_type__ + + # Priority 2: Check Annotated metadata + if get_origin(annotation) is Annotated: + for metadata in get_args(annotation)[1:]: + if hasattr(metadata, '__sqlmodel_sa_type__'): + return metadata.__sqlmodel_sa_type__ + + # ... handle ForwardRef, ClassVar, etc. + + return _original_get_sqlalchemy_type(field) +``` + +### Supported Patterns + +#### Pattern 1: Direct Custom Type Usage +```python +from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector +from sqlmodels.mixin import UUIDTableBaseMixin + +class SpeakerInfo(UUIDTableBaseMixin, table=True): + embedding: NumpyVector[256, np.float32] + """Voice embedding - sa_type automatically extracted""" +``` + +#### Pattern 2: Annotated Wrapper +```python +from typing import Annotated +from sqlmodels.mixin import UUIDTableBaseMixin + +EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]] + +class SpeakerInfo(UUIDTableBaseMixin, table=True): + embedding: EmbeddingVector +``` + +#### Pattern 3: Array Type +```python +from sqlmodels.sqlmodel_types.dialects.postgresql import Array +from sqlmodels.mixin import TableBaseMixin + +class ServerConfig(TableBaseMixin, table=True): + protocols: Array[ProtocolEnum] + """Allowed protocols - sa_type from Array handler""" +``` + +### Migration from Python 3.13 + +**No code changes required!** The implementation is transparent: + +- Uses `typing.get_type_hints()` which works in both Python 3.13 and 3.14 +- Custom types already use `__sqlmodel_sa_type__` attribute +- Monkey-patch only activates for Python 3.14+ + +--- + +## Core Component + +### SQLModelBase + +`SQLModelBase` is the root base class for all SQLModel models. It uses a custom metaclass (`__DeclarativeMeta`) that provides advanced features beyond standard SQLModel capabilities. + +**Key Features**: +- Automatic `use_attribute_docstrings` configuration (use docstrings instead of `Field(description=...)`) +- Automatic `validate_by_name` configuration +- Custom metaclass for sa_type injection and polymorphic setup +- Integration with Pydantic v2 +- Python 3.14 PEP 649 compatibility + +**Usage**: + +```python +from sqlmodels.base import SQLModelBase + +class UserBase(SQLModelBase): + name: str + """User's display name""" + + email: str + """User's email address""" +``` + +**Important Notes**: +- Use **docstrings** for field descriptions, not `Field(description=...)` +- Do NOT override `model_config` in subclasses (it's already configured in SQLModelBase) +- This class should be used for non-table models (DTOs, request/response models) + +**For table models**, use mixins from `sqlmodels.mixin`: +- `TableBaseMixin` - Integer primary key with timestamps +- `UUIDTableBaseMixin` - UUID primary key with timestamps + +See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete table mixin documentation. + +--- + +## Metaclass Features + +### Automatic sa_type Injection + +The metaclass automatically extracts SQLAlchemy types from custom type annotations, enabling clean syntax for complex database types. + +**Before** (verbose): +```python +from sqlmodels.sqlmodel_types.dialects.postgresql.numpy_vector import _NumpyVectorSQLAlchemyType +from sqlmodels.mixin import UUIDTableBaseMixin + +class SpeakerInfo(UUIDTableBaseMixin, table=True): + embedding: np.ndarray = Field( + sa_type=_NumpyVectorSQLAlchemyType(256, np.float32) + ) +``` + +**After** (clean): +```python +from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector +from sqlmodels.mixin import UUIDTableBaseMixin + +class SpeakerInfo(UUIDTableBaseMixin, table=True): + embedding: NumpyVector[256, np.float32] + """Speaker voice embedding""" +``` + +**How It Works**: + +The metaclass uses a three-tier detection strategy: + +1. **Direct `__sqlmodel_sa_type__` attribute** (Priority 1) + ```python + if hasattr(annotation, '__sqlmodel_sa_type__'): + return annotation.__sqlmodel_sa_type__ + ``` + +2. **Annotated metadata** (Priority 2) + ```python + # For Annotated[np.ndarray, NumpyVector[256, np.float32]] + if get_origin(annotation) is typing.Annotated: + for item in metadata_items: + if hasattr(item, '__sqlmodel_sa_type__'): + return item.__sqlmodel_sa_type__ + ``` + +3. **Pydantic Core Schema metadata** (Priority 3) + ```python + schema = annotation.__get_pydantic_core_schema__(...) + if schema['metadata'].get('sa_type'): + return schema['metadata']['sa_type'] + ``` + +After extracting `sa_type`, the metaclass: +- Creates `Field(sa_type=sa_type)` if no Field is defined +- Injects `sa_type` into existing Field if not already set +- Respects explicit `Field(sa_type=...)` (no override) + +**Supported Patterns**: + +```python +from sqlmodels.mixin import UUIDTableBaseMixin + +# Pattern 1: Direct usage (recommended) +class Model(UUIDTableBaseMixin, table=True): + embedding: NumpyVector[256, np.float32] + +# Pattern 2: With Field constraints +class Model(UUIDTableBaseMixin, table=True): + embedding: NumpyVector[256, np.float32] = Field(nullable=False) + +# Pattern 3: Annotated wrapper +EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]] + +class Model(UUIDTableBaseMixin, table=True): + embedding: EmbeddingVector + +# Pattern 4: Explicit sa_type (override) +class Model(UUIDTableBaseMixin, table=True): + embedding: NumpyVector[256, np.float32] = Field( + sa_type=_NumpyVectorSQLAlchemyType(128, np.float16) + ) +``` + +### Table Configuration + +The metaclass provides smart defaults and flexible configuration: + +**Automatic `table=True`**: +```python +# Classes inheriting from TableBaseMixin automatically get table=True +from sqlmodels.mixin import UUIDTableBaseMixin + +class MyModel(UUIDTableBaseMixin): # table=True is automatic + pass +``` + +**Convenient mapper arguments**: +```python +# Instead of verbose __mapper_args__ +from sqlmodels.mixin import UUIDTableBaseMixin + +class MyModel( + UUIDTableBaseMixin, + polymorphic_on='_polymorphic_name', + polymorphic_abstract=True +): + pass + +# Equivalent to: +class MyModel(UUIDTableBaseMixin): + __mapper_args__ = { + 'polymorphic_on': '_polymorphic_name', + 'polymorphic_abstract': True + } +``` + +**Smart merging**: +```python +# Dictionary and keyword arguments are merged +from sqlmodels.mixin import UUIDTableBaseMixin + +class MyModel( + UUIDTableBaseMixin, + mapper_args={'version_id_col': 'version'}, + polymorphic_on='type' # Merged into __mapper_args__ +): + pass +``` + +### Polymorphic Support + +The metaclass supports SQLAlchemy's joined table inheritance through convenient parameters: + +**Supported parameters**: +- `polymorphic_on`: Discriminator column name +- `polymorphic_identity`: Identity value for this class +- `polymorphic_abstract`: Whether this is an abstract base +- `table_args`: SQLAlchemy table arguments +- `table_name`: Override table name (becomes `__tablename__`) + +**For complete polymorphic inheritance patterns**, including `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`, see [`sqlmodels/mixin/README.md`](../mixin/README.md). + +--- + +## Custom Types Integration + +### Using NumpyVector + +The `NumpyVector` type demonstrates automatic sa_type injection: + +```python +from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector +from sqlmodels.mixin import UUIDTableBaseMixin +import numpy as np + +class SpeakerInfo(UUIDTableBaseMixin, table=True): + embedding: NumpyVector[256, np.float32] + """Speaker voice embedding - sa_type automatically injected""" +``` + +**How NumpyVector works**: + +```python +# NumpyVector[dims, dtype] returns a class with: +class _NumpyVectorType: + __sqlmodel_sa_type__ = _NumpyVectorSQLAlchemyType(dimensions, dtype) + + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): + return handler.generate_schema(np.ndarray) +``` + +This dual approach ensures: +1. Metaclass can extract `sa_type` via `__sqlmodel_sa_type__` +2. Pydantic can validate as `np.ndarray` + +### Creating Custom SQLAlchemy Types + +To create types that work with automatic injection, provide one of: + +**Option 1: `__sqlmodel_sa_type__` attribute** (preferred): + +```python +from sqlalchemy import TypeDecorator, String + +class UpperCaseString(TypeDecorator): + impl = String + + def process_bind_param(self, value, dialect): + return value.upper() if value else value + +class UpperCaseType: + __sqlmodel_sa_type__ = UpperCaseString() + + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): + return core_schema.str_schema() + +# Usage +from sqlmodels.mixin import UUIDTableBaseMixin + +class MyModel(UUIDTableBaseMixin, table=True): + code: UpperCaseType # Automatically uses UpperCaseString() +``` + +**Option 2: Pydantic metadata with sa_type**: + +```python +def __get_pydantic_core_schema__(self, source_type, handler): + return core_schema.json_or_python_schema( + json_schema=core_schema.str_schema(), + python_schema=core_schema.str_schema(), + metadata={'sa_type': UpperCaseString()} + ) +``` + +**Option 3: Using Annotated**: + +```python +from typing import Annotated +from sqlmodels.mixin import UUIDTableBaseMixin + +UpperCase = Annotated[str, UpperCaseType()] + +class MyModel(UUIDTableBaseMixin, table=True): + code: UpperCase +``` + +--- + +## Best Practices + +### 1. Inherit from correct base classes + +```python +from sqlmodels.base import SQLModelBase +from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin + +# ✅ For non-table models (DTOs, requests, responses) +class UserBase(SQLModelBase): + name: str + +# ✅ For table models with UUID primary key +class User(UserBase, UUIDTableBaseMixin, table=True): + email: str + +# ✅ For table models with custom primary key +class LegacyUser(TableBaseMixin, table=True): + id: int = Field(primary_key=True) + username: str +``` + +### 2. Use docstrings for field descriptions + +```python +from sqlmodels.mixin import UUIDTableBaseMixin + +# ✅ Recommended +class User(UUIDTableBaseMixin, table=True): + name: str + """User's display name""" + +# ❌ Avoid +class User(UUIDTableBaseMixin, table=True): + name: str = Field(description="User's display name") +``` + +**Why?** SQLModelBase has `use_attribute_docstrings=True`, so docstrings automatically become field descriptions in API docs. + +### 3. Leverage automatic sa_type injection + +```python +from sqlmodels.mixin import UUIDTableBaseMixin + +# ✅ Clean and recommended +class SpeakerInfo(UUIDTableBaseMixin, table=True): + embedding: NumpyVector[256, np.float32] + """Voice embedding""" + +# ❌ Verbose and unnecessary +class SpeakerInfo(UUIDTableBaseMixin, table=True): + embedding: np.ndarray = Field( + sa_type=_NumpyVectorSQLAlchemyType(256, np.float32) + ) +``` + +### 4. Follow polymorphic naming conventions + +See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete polymorphic inheritance patterns using `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`. + +### 5. Separate Base, Parent, and Implementation classes + +```python +from abc import ABC, abstractmethod +from sqlmodels.base import SQLModelBase +from sqlmodels.mixin import UUIDTableBaseMixin, PolymorphicBaseMixin + +# ✅ Recommended structure +class ASRBase(SQLModelBase): + """Pure data fields, no table""" + name: str + base_url: str + +class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC): + """Abstract parent with table""" + @abstractmethod + async def transcribe(self, audio: bytes) -> str: + pass + +class WhisperASR(ASR, table=True): + """Concrete implementation""" + model_size: str + + async def transcribe(self, audio: bytes) -> str: + # Implementation + pass +``` + +**Why?** +- Base class can be reused for DTOs +- Parent class defines the polymorphic hierarchy +- Implementation classes are clean and focused + +--- + +## Troubleshooting + +### Issue: ValueError: X has no matching SQLAlchemy type + +**Solution**: Ensure your custom type provides `__sqlmodel_sa_type__` attribute or proper Pydantic metadata with `sa_type`. + +```python +# ✅ Provide __sqlmodel_sa_type__ +class MyType: + __sqlmodel_sa_type__ = MyCustomSQLAlchemyType() +``` + +### Issue: Can't generate DDL for NullType() + +**Symptoms**: Error during table creation saying a column has `NullType`. + +**Root Cause**: Custom type's `sa_type` not detected by SQLModel. + +**Solution**: +1. Ensure your type has `__sqlmodel_sa_type__` class attribute +2. Check that the monkey-patch is active (`sys.version_info >= (3, 14)`) +3. Verify type annotation is correct (not a string forward reference) + +```python +from sqlmodels.mixin import UUIDTableBaseMixin + +# ✅ Correct +class Model(UUIDTableBaseMixin, table=True): + data: NumpyVector[256, np.float32] # __sqlmodel_sa_type__ detected + +# ❌ Wrong (string annotation) +class Model(UUIDTableBaseMixin, table=True): + data: 'NumpyVector[256, np.float32]' # sa_type lost +``` + +### Issue: Polymorphic identity conflicts + +**Symptoms**: SQLAlchemy raises errors about duplicate polymorphic identities. + +**Solution**: +1. Check that each concrete class has a unique identity +2. Use `AutoPolymorphicIdentityMixin` for automatic naming +3. Manually specify identity if needed: + ```python + class MyClass(Parent, polymorphic_identity='unique.name', table=True): + pass + ``` + +### Issue: Python 3.14 annotation errors + +**Symptoms**: Errors related to `__annotations__` or type resolution. + +**Solution**: The implementation uses `get_type_hints()` which handles PEP 649 automatically. If issues persist: +1. Check for manual `__annotations__` manipulation (avoid it) +2. Ensure all types are properly imported +3. Avoid `from __future__ import annotations` (can cause SQLModel issues) + +### Issue: Polymorphic and CRUD-related errors + +For issues related to polymorphic inheritance, CRUD operations, or table mixins, see the troubleshooting section in [`sqlmodels/mixin/README.md`](../mixin/README.md). + +--- + +## Implementation Details + +For developers modifying this module: + +**Core files**: +- `sqlmodel_base.py` - Contains `__DeclarativeMeta` and `SQLModelBase` +- `../mixin/table.py` - Contains `TableBaseMixin` and `UUIDTableBaseMixin` +- `../mixin/polymorphic.py` - Contains `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin` + +**Key functions in this module**: + +1. **`_resolve_annotations(attrs: dict[str, Any])`** + - Uses `typing.get_type_hints()` for Python 3.14 compatibility + - Returns tuple: `(annotations, annotation_strings, globalns, localns)` + - Preserves `Annotated` metadata with `include_extras=True` + +2. **`_extract_sa_type_from_annotation(annotation: Any) -> Any | None`** + - Extracts SQLAlchemy type from type annotations + - Supports `__sqlmodel_sa_type__`, `Annotated`, and Pydantic core schema + - Called by metaclass during class creation + +3. **`_patched_get_sqlalchemy_type(field)`** (Python 3.14+) + - Global monkey-patch for SQLModel + - Checks `__sqlmodel_sa_type__` before falling back to original logic + - Handles custom types like `NumpyVector` and `Array` + +4. **`__DeclarativeMeta.__new__()`** + - Processes class definition parameters + - Injects `sa_type` into field definitions + - Sets up `__mapper_args__`, `__table_args__`, etc. + - Handles Python 3.14 annotations via `get_type_hints()` + +**Metaclass processing order**: +1. Check if class should be a table (`_is_table_mixin`) +2. Collect `__mapper_args__` from kwargs and explicit dict +3. Process `table_args`, `table_name`, `abstract` parameters +4. Resolve annotations using `get_type_hints()` +5. For each field, try to extract `sa_type` and inject into Field +6. Call parent metaclass with cleaned kwargs + +For table mixin implementation details, see [`sqlmodels/mixin/README.md`](../mixin/README.md). + +--- + +## See Also + +**Project Documentation**: +- [SQLModel Mixin Documentation](../mixin/README.md) - Table mixins, CRUD operations, polymorphic patterns +- [Project Coding Standards (CLAUDE.md)](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/CLAUDE.md) +- [Custom SQLModel Types Guide](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/sqlmodels/sqlmodel_types/README.md) + +**External References**: +- [SQLAlchemy Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance) +- [Pydantic V2 Documentation](https://docs.pydantic.dev/latest/) +- [SQLModel Documentation](https://sqlmodel.tiangolo.com/) +- [PEP 649: Deferred Evaluation of Annotations](https://peps.python.org/pep-0649/) +- [PEP 749: Implementing PEP 649](https://peps.python.org/pep-0749/) +- [Python Annotations Best Practices](https://docs.python.org/3/howto/annotations.html) diff --git a/models/base/__init__.py b/models/base/__init__.py index d5463ec..4744778 100644 --- a/models/base/__init__.py +++ b/models/base/__init__.py @@ -1,2 +1,12 @@ +""" +SQLModel 基础模块 + +包含: +- SQLModelBase: 所有 SQLModel 类的基类(真正的基类) + +注意: + TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 models.mixin + 为了避免循环导入,此处不再重新导出它们 + 请直接从 models.mixin 导入这些类 +""" from .sqlmodel_base import SQLModelBase -from .table_base import TableBase, UUIDTableBase, now, now_date diff --git a/models/base/sqlmodel_base.py b/models/base/sqlmodel_base.py index 96b7f09..22bee25 100644 --- a/models/base/sqlmodel_base.py +++ b/models/base/sqlmodel_base.py @@ -1,5 +1,846 @@ -from pydantic import ConfigDict -from sqlmodel import SQLModel +import sys +import typing +from typing import Any, Mapping, get_args, get_origin, get_type_hints + +from pydantic import ConfigDict +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined as Undefined +from sqlalchemy.orm import Mapped +from sqlmodel import Field, SQLModel +from sqlmodel.main import SQLModelMetaclass + +# Python 3.14+ PEP 649支持 +if sys.version_info >= (3, 14): + import annotationlib + + # 全局Monkey-patch: 修复SQLModel在Python 3.14上的兼容性问题 + import sqlmodel.main + _original_get_sqlalchemy_type = sqlmodel.main.get_sqlalchemy_type + + def _patched_get_sqlalchemy_type(field): + """ + 修复SQLModel的get_sqlalchemy_type函数,处理Python 3.14的类型问题。 + + 问题: + 1. ForwardRef对象(来自Relationship字段)会导致issubclass错误 + 2. typing._GenericAlias对象(如ClassVar[T])也会导致同样问题 + 3. list/dict等泛型类型在没有Field/Relationship时可能导致错误 + 4. Mapped类型在Python 3.14下可能出现在annotation中 + 5. Annotated类型可能包含sa_type metadata(如Array[T]) + 6. 自定义类型(如NumpyVector)有__sqlmodel_sa_type__属性 + 7. Pydantic已处理的Annotated类型会将metadata存储在field.metadata中 + + 解决: + - 优先检查field.metadata中的__get_pydantic_core_schema__(Pydantic已处理的情况) + - 检测__sqlmodel_sa_type__属性(NumpyVector等) + - 检测Relationship/ClassVar等返回None + - 对于Annotated类型,尝试提取sa_type metadata + - 其他情况调用原始函数 + """ + # 优先检查 field.metadata(Pydantic已处理Annotated类型的情况) + # 当使用 Array[T] 或 Annotated[T, metadata] 时,Pydantic会将metadata存储在这里 + metadata = getattr(field, 'metadata', None) + if metadata: + # metadata是一个列表,包含所有Annotated的元数据项 + for metadata_item in metadata: + # 检查metadata_item是否有__get_pydantic_core_schema__方法 + if hasattr(metadata_item, '__get_pydantic_core_schema__'): + try: + # 调用获取schema + schema = metadata_item.__get_pydantic_core_schema__(None, None) + # 检查schema的metadata中是否有sa_type + if isinstance(schema, dict) and 'metadata' in schema: + sa_type = schema['metadata'].get('sa_type') + if sa_type is not None: + return sa_type + except (TypeError, AttributeError, KeyError): + # Pydantic schema获取可能失败(类型不匹配、缺少属性等) + # 这是正常情况,继续检查下一个metadata项 + pass + + annotation = getattr(field, 'annotation', None) + if annotation is not None: + # 优先检查 __sqlmodel_sa_type__ 属性 + # 这处理 NumpyVector[dims, dtype] 等自定义类型 + if hasattr(annotation, '__sqlmodel_sa_type__'): + return annotation.__sqlmodel_sa_type__ + + # 检查自定义类型(如JSON100K)的 __get_pydantic_core_schema__ 方法 + # 这些类型在schema的metadata中定义sa_type + if hasattr(annotation, '__get_pydantic_core_schema__'): + try: + # 调用获取schema(传None作为handler,因为我们只需要metadata) + schema = annotation.__get_pydantic_core_schema__(annotation, lambda x: None) + # 检查schema的metadata中是否有sa_type + if isinstance(schema, dict) and 'metadata' in schema: + sa_type = schema['metadata'].get('sa_type') + if sa_type is not None: + return sa_type + except (TypeError, AttributeError, KeyError): + # Schema获取失败,继续其他检查 + pass + + anno_type_name = type(annotation).__name__ + + # ForwardRef: Relationship字段的annotation + if anno_type_name == 'ForwardRef': + return None + + # AnnotatedAlias: 检查是否有sa_type metadata(如Array[T]) + if anno_type_name == 'AnnotatedAlias' or anno_type_name == '_AnnotatedAlias': + from typing import get_origin, get_args + import typing + + # 尝试提取Annotated的metadata + if hasattr(typing, 'get_args'): + args = get_args(annotation) + # args[0]是实际类型,args[1:]是metadata + for metadata in args[1:]: + # 检查metadata是否有__get_pydantic_core_schema__方法 + if hasattr(metadata, '__get_pydantic_core_schema__'): + try: + # 调用获取schema + schema = metadata.__get_pydantic_core_schema__(None, None) + # 检查schema中是否有sa_type + if isinstance(schema, dict) and 'metadata' in schema: + sa_type = schema['metadata'].get('sa_type') + if sa_type is not None: + return sa_type + except (TypeError, AttributeError, KeyError): + # Annotated metadata的schema获取可能失败 + # 这是正常的类型检查过程,继续检查下一个metadata + pass + + # _GenericAlias或GenericAlias: typing泛型类型 + if anno_type_name in ('_GenericAlias', 'GenericAlias'): + from typing import get_origin + import typing + origin = get_origin(annotation) + + # ClassVar必须跳过 + if origin is typing.ClassVar: + return None + + # list/dict/tuple/set等内置泛型,如果字段没有明确的Field或Relationship,也跳过 + # 这通常意味着它是Relationship字段或类变量 + if origin in (list, dict, tuple, set): + # 检查field_info是否存在且有意义 + # Relationship字段会有特殊的field_info + field_info = getattr(field, 'field_info', None) + if field_info is None: + return None + + # Mapped: SQLAlchemy 2.0的Mapped类型,SQLModel不应该处理 + # 这可能是从父类继承的字段或Python 3.14注解处理的副作用 + # 检查类型名称和annotation的字符串表示 + if 'Mapped' in anno_type_name or 'Mapped' in str(annotation): + return None + + # 检查annotation是否是Mapped类或其实例 + try: + from sqlalchemy.orm import Mapped as SAMapped + # 检查origin(对于Mapped[T]这种泛型) + from typing import get_origin + if get_origin(annotation) is SAMapped: + return None + # 检查类型本身 + if annotation is SAMapped or isinstance(annotation, type) and issubclass(annotation, SAMapped): + return None + except (ImportError, TypeError): + # 如果SQLAlchemy没有Mapped或检查失败,继续 + pass + + # 其他情况正常处理 + return _original_get_sqlalchemy_type(field) + + sqlmodel.main.get_sqlalchemy_type = _patched_get_sqlalchemy_type + + # 第二个Monkey-patch: 修复继承表类中InstrumentedAttribute作为默认值的问题 + # 在Python 3.14 + SQLModel组合下,当子类(如SMSBaoProvider)继承父类(如VerificationCodeProvider)时, + # 父类的关系字段(如server_config)会在子类的model_fields中出现, + # 但其default值错误地设置为InstrumentedAttribute对象,而不是None + # 这导致实例化时尝试设置InstrumentedAttribute为字段值,触发SQLAlchemy内部错误 + import sqlmodel._compat as _compat + from sqlalchemy.orm import attributes as _sa_attributes + + _original_sqlmodel_table_construct = _compat.sqlmodel_table_construct + + def _patched_sqlmodel_table_construct(self_instance, values): + """ + 修复sqlmodel_table_construct,跳过InstrumentedAttribute默认值 + + 问题: + - 继承自polymorphic基类的表类(如FishAudioTTS, SMSBaoProvider) + - 其model_fields中的继承字段default值为InstrumentedAttribute + - 原函数尝试将InstrumentedAttribute设置为字段值 + - SQLAlchemy无法处理,抛出 '_sa_instance_state' 错误 + + 解决: + - 只设置用户提供的值和非InstrumentedAttribute默认值 + - InstrumentedAttribute默认值跳过(让SQLAlchemy自己处理) + """ + cls = type(self_instance) + + # 收集要设置的字段值 + fields_to_set = {} + + for name, field in cls.model_fields.items(): + # 如果用户提供了值,直接使用 + if name in values: + fields_to_set[name] = values[name] + continue + + # 否则检查默认值 + # 跳过InstrumentedAttribute默认值 - 这些是继承字段的错误默认值 + if isinstance(field.default, _sa_attributes.InstrumentedAttribute): + continue + + # 使用正常的默认值 + if field.default is not Undefined: + fields_to_set[name] = field.default + elif field.default_factory is not None: + fields_to_set[name] = field.get_default(call_default_factory=True) + + # 设置属性 - 只设置非InstrumentedAttribute值 + for key, value in fields_to_set.items(): + if not isinstance(value, _sa_attributes.InstrumentedAttribute): + setattr(self_instance, key, value) + + # 设置Pydantic内部属性 + object.__setattr__(self_instance, '__pydantic_fields_set__', set(values.keys())) + if not cls.__pydantic_root_model__: + _extra = None + if cls.model_config.get('extra') == 'allow': + _extra = {} + for k, v in values.items(): + if k not in cls.model_fields: + _extra[k] = v + object.__setattr__(self_instance, '__pydantic_extra__', _extra) + + if cls.__pydantic_post_init__: + self_instance.model_post_init(None) + elif not cls.__pydantic_root_model__: + object.__setattr__(self_instance, '__pydantic_private__', None) + + # 设置关系 + for key in self_instance.__sqlmodel_relationships__: + value = values.get(key, Undefined) + if value is not Undefined: + setattr(self_instance, key, value) + + return self_instance + + _compat.sqlmodel_table_construct = _patched_sqlmodel_table_construct +else: + annotationlib = None + + +def _extract_sa_type_from_annotation(annotation: Any) -> Any | None: + """ + 从类型注解中提取SQLAlchemy类型。 + + 支持以下形式: + 1. NumpyVector[256, np.float32] - 直接使用类型(有__sqlmodel_sa_type__属性) + 2. Annotated[np.ndarray, NumpyVector[256, np.float32]] - Annotated包装 + 3. 任何有__get_pydantic_core_schema__且返回metadata['sa_type']的类型 + + Args: + annotation: 字段的类型注解 + + Returns: + 提取到的SQLAlchemy类型,如果没有则返回None + """ + # 方法1:直接检查类型本身是否有__sqlmodel_sa_type__属性 + # 这涵盖了 NumpyVector[256, np.float32] 这种直接使用的情况 + if hasattr(annotation, '__sqlmodel_sa_type__'): + return annotation.__sqlmodel_sa_type__ + + # 方法2:检查是否为Annotated类型 + if get_origin(annotation) is typing.Annotated: + # 获取元数据项(跳过第一个实际类型参数) + args = get_args(annotation) + if len(args) >= 2: + metadata_items = args[1:] # 第一个是实际类型,后面都是元数据 + + # 遍历元数据,查找包含sa_type的项 + for item in metadata_items: + # 检查元数据项是否有__sqlmodel_sa_type__属性 + if hasattr(item, '__sqlmodel_sa_type__'): + return item.__sqlmodel_sa_type__ + + # 检查是否有__get_pydantic_core_schema__方法 + if hasattr(item, '__get_pydantic_core_schema__'): + try: + # 调用该方法获取core schema + schema = item.__get_pydantic_core_schema__( + annotation, + lambda x: None # 虚拟handler + ) + # 检查schema的metadata中是否有sa_type + if isinstance(schema, dict) and 'metadata' in schema: + sa_type = schema['metadata'].get('sa_type') + if sa_type is not None: + return sa_type + except (TypeError, AttributeError, KeyError, ValueError): + # Pydantic core schema获取可能失败: + # - TypeError: 参数不匹配 + # - AttributeError: metadata不存在 + # - KeyError: schema结构不符合预期 + # - ValueError: 无效的类型定义 + # 这是正常的类型探测过程,继续检查下一个metadata项 + pass + + # 方法3:检查类型本身是否有__get_pydantic_core_schema__ + # (虽然NumpyVector已经在方法1处理,但这是通用的fallback) + if hasattr(annotation, '__get_pydantic_core_schema__'): + try: + schema = annotation.__get_pydantic_core_schema__( + annotation, + lambda x: None # 虚拟handler + ) + if isinstance(schema, dict) and 'metadata' in schema: + sa_type = schema['metadata'].get('sa_type') + if sa_type is not None: + return sa_type + except (TypeError, AttributeError, KeyError, ValueError): + # 类型本身的schema获取失败 + # 这是正常的fallback机制,annotation可能不支持此协议 + pass + + return None + + +def _resolve_annotations(attrs: dict[str, Any]) -> tuple[ + dict[str, Any], + dict[str, str], + Mapping[str, Any], + Mapping[str, Any], +]: + """ + Resolve annotations from a class namespace with Python 3.14 (PEP 649) support. + + This helper prefers evaluated annotations (Format.VALUE) so that `typing.Annotated` + metadata and custom types remain accessible. Forward references that cannot be + evaluated are replaced with typing.ForwardRef placeholders to avoid aborting the + whole resolution process. + """ + raw_annotations = attrs.get('__annotations__') or {} + try: + base_annotations = dict(raw_annotations) + except TypeError: + base_annotations = {} + + module_name = attrs.get('__module__') + module_globals: dict[str, Any] + if module_name and module_name in sys.modules: + module_globals = dict(sys.modules[module_name].__dict__) + else: + module_globals = {} + + module_globals.setdefault('__builtins__', __builtins__) + localns: dict[str, Any] = dict(attrs) + + try: + temp_cls = type('AnnotationProxy', (object,), dict(attrs)) + temp_cls.__module__ = module_name + extras_kw = {'include_extras': True} if sys.version_info >= (3, 10) else {} + evaluated = get_type_hints( + temp_cls, + globalns=module_globals, + localns=localns, + **extras_kw, + ) + except (NameError, AttributeError, TypeError, RecursionError): + # get_type_hints可能失败的原因: + # - NameError: 前向引用无法解析(类型尚未定义) + # - AttributeError: 模块或类型不存在 + # - TypeError: 无效的类型注解 + # - RecursionError: 循环依赖的类型定义 + # 这是正常情况,回退到原始注解字符串 + evaluated = base_annotations + + return dict(evaluated), {}, module_globals, localns + + +def _evaluate_annotation_from_string( + field_name: str, + annotation_strings: dict[str, str], + current_type: Any, + globalns: Mapping[str, Any], + localns: Mapping[str, Any], +) -> Any: + """ + Attempt to re-evaluate the original annotation string for a field. + + This is used as a fallback when the resolved annotation lost its metadata + (e.g., Annotated wrappers) and we need to recover custom sa_type data. + """ + if not annotation_strings: + return current_type + + expr = annotation_strings.get(field_name) + if not expr or not isinstance(expr, str): + return current_type + + try: + return eval(expr, globalns, localns) + except (NameError, SyntaxError, AttributeError, TypeError): + # eval可能失败的原因: + # - NameError: 类型名称在namespace中不存在 + # - SyntaxError: 注解字符串有语法错误 + # - AttributeError: 访问不存在的模块属性 + # - TypeError: 无效的类型表达式 + # 这是正常的fallback机制,返回当前已解析的类型 + return current_type + + +class __DeclarativeMeta(SQLModelMetaclass): + """ + 一个智能的混合模式元类,它提供了灵活性和清晰度: + + 1. **自动设置 `table=True`**: 如果一个类继承了 `TableBaseMixin`,则自动应用 `table=True`。 + 2. **明确的字典参数**: 支持 `mapper_args={...}`, `table_args={...}`, `table_name='...'`。 + 3. **便捷的关键字参数**: 支持最常见的 mapper 参数作为顶级关键字(如 `polymorphic_on`)。 + 4. **智能合并**: 当字典和关键字同时提供时,会自动合并,且关键字参数有更高优先级。 + """ + + _KNOWN_MAPPER_KEYS = { + "polymorphic_on", + "polymorphic_identity", + "polymorphic_abstract", + "version_id_col", + "concrete", + } + + def __new__(cls, name, bases, attrs, **kwargs): + # 1. 约定优于配置:自动设置 table=True + is_intended_as_table = any(getattr(b, '_is_table_mixin', False) for b in bases) + if is_intended_as_table and 'table' not in kwargs: + kwargs['table'] = True + + # 2. 智能合并 __mapper_args__ + collected_mapper_args = {} + + # 首先,处理明确的 mapper_args 字典 (优先级较低) + if 'mapper_args' in kwargs: + collected_mapper_args.update(kwargs.pop('mapper_args')) + + # 其次,处理便捷的关键字参数 (优先级更高) + for key in cls._KNOWN_MAPPER_KEYS: + if key in kwargs: + # .pop() 获取值并移除,避免传递给父类 + collected_mapper_args[key] = kwargs.pop(key) + + # 如果收集到了任何 mapper 参数,则更新到类的属性中 + if collected_mapper_args: + existing = attrs.get('__mapper_args__', {}).copy() + existing.update(collected_mapper_args) + attrs['__mapper_args__'] = existing + + # 3. 处理其他明确的参数 + if 'table_args' in kwargs: + attrs['__table_args__'] = kwargs.pop('table_args') + if 'table_name' in kwargs: + attrs['__tablename__'] = kwargs.pop('table_name') + if 'abstract' in kwargs: + attrs['__abstract__'] = kwargs.pop('abstract') + + # 4. 从Annotated元数据中提取sa_type并注入到Field + # 重要:必须在调用父类__new__之前处理,因为SQLModel会消费annotations + # + # Python 3.14兼容性问题: + # - SQLModel在Python 3.14上会因为ClassVar[T]类型而崩溃(issubclass错误) + # - 我们必须在SQLModel看到annotations之前过滤掉ClassVar字段 + # - 虽然PEP 749建议不修改__annotations__,但这是修复SQLModel bug的必要措施 + # + # 获取annotations的策略: + # - Python 3.14+: 优先从__annotate__获取(如果存在) + # - fallback: 从__annotations__读取(如果存在) + # - 最终fallback: 空字典 + annotations, annotation_strings, eval_globals, eval_locals = _resolve_annotations(attrs) + + if annotations: + attrs['__annotations__'] = annotations + if annotationlib is not None: + # 在Python 3.14中禁用descriptor,转为普通dict + attrs['__annotate__'] = None + + for field_name, field_type in annotations.items(): + field_type = _evaluate_annotation_from_string( + field_name, + annotation_strings, + field_type, + eval_globals, + eval_locals, + ) + + # 跳过字符串或ForwardRef类型注解,让SQLModel自己处理 + if isinstance(field_type, str) or isinstance(field_type, typing.ForwardRef): + continue + + # 跳过特殊类型的字段 + origin = get_origin(field_type) + + # 跳过 ClassVar 字段 - 它们不是数据库字段 + if origin is typing.ClassVar: + continue + + # 跳过 Mapped 字段 - SQLAlchemy 2.0+ 的声明式字段,已经有 mapped_column + if origin is Mapped: + continue + + # 尝试从注解中提取sa_type + sa_type = _extract_sa_type_from_annotation(field_type) + + if sa_type is not None: + # 检查字段是否已有Field定义 + field_value = attrs.get(field_name, Undefined) + + if field_value is Undefined: + # 没有Field定义,创建一个新的Field并注入sa_type + attrs[field_name] = Field(sa_type=sa_type) + elif isinstance(field_value, FieldInfo): + # 已有Field定义,检查是否已设置sa_type + # 注意:只有在未设置时才注入,尊重显式配置 + # SQLModel使用Undefined作为"未设置"的标记 + if not hasattr(field_value, 'sa_type') or field_value.sa_type is Undefined: + field_value.sa_type = sa_type + # 如果field_value是其他类型(如默认值),不处理 + # SQLModel会在后续处理中将其转换为Field + + # 5. 调用父类的 __new__ 方法,传入被清理过的 kwargs + result = super().__new__(cls, name, bases, attrs, **kwargs) + + # 6. 修复:在联表继承场景下,继承父类的 __sqlmodel_relationships__ + # SQLModel 为每个 table=True 的类创建新的空 __sqlmodel_relationships__ + # 这导致子类丢失父类的关系定义,触发错误的 Column 创建 + # 必须在 super().__new__() 之后修复,因为 SQLModel 会覆盖我们预设的值 + if kwargs.get('table', False): + for base in bases: + if hasattr(base, '__sqlmodel_relationships__'): + for rel_name, rel_info in base.__sqlmodel_relationships__.items(): + # 只继承子类没有重新定义的关系 + if rel_name not in result.__sqlmodel_relationships__: + result.__sqlmodel_relationships__[rel_name] = rel_info + # 同时修复被错误创建的 Column - 恢复为父类的 relationship + if hasattr(base, rel_name): + base_attr = getattr(base, rel_name) + setattr(result, rel_name, base_attr) + + # 7. 检测:禁止子类重定义父类的 Relationship 字段 + # 子类重定义同名的 Relationship 字段会导致 SQLAlchemy 关系映射混乱, + # 应该在类定义时立即报错,而不是在运行时出现难以调试的问题。 + for base in bases: + parent_relationships = getattr(base, '__sqlmodel_relationships__', {}) + for rel_name in parent_relationships: + # 检查当前类是否在 attrs 中重新定义了这个关系字段 + if rel_name in attrs: + raise TypeError( + f"类 {name} 不允许重定义父类 {base.__name__} 的 Relationship 字段 '{rel_name}'。" + f"如需修改关系配置,请在父类中修改。" + ) + + # 8. 修复:从 model_fields/__pydantic_fields__ 中移除 Relationship 字段 + # SQLModel 0.0.27 bug:子类会错误地继承父类的 Relationship 字段到 model_fields + # 这导致 Pydantic 尝试为 Relationship 字段生成 schema,因为类型是 + # Mapped[list['Character']] 这种前向引用,Pydantic 无法解析, + # 导致 __pydantic_complete__ = False + # + # 修复策略: + # - 检查类的 __sqlmodel_relationships__ 属性 + # - 从 model_fields 和 __pydantic_fields__ 中移除这些字段 + # - Relationship 字段由 SQLAlchemy 管理,不需要 Pydantic 参与 + relationships = getattr(result, '__sqlmodel_relationships__', {}) + if relationships: + model_fields = getattr(result, 'model_fields', {}) + pydantic_fields = getattr(result, '__pydantic_fields__', {}) + + fields_removed = False + for rel_name in relationships: + if rel_name in model_fields: + del model_fields[rel_name] + fields_removed = True + if rel_name in pydantic_fields: + del pydantic_fields[rel_name] + fields_removed = True + + # 如果移除了字段,重新构建 Pydantic 模式 + # 注意:只在有字段被移除时才 rebuild,避免不必要的开销 + if fields_removed and hasattr(result, 'model_rebuild'): + result.model_rebuild(force=True) + + return result + + def __init__( + cls, + classname: str, + bases: tuple[type, ...], + dict_: dict[str, typing.Any], + **kw: typing.Any, + ) -> None: + """ + 重写 SQLModel 的 __init__ 以支持联表继承(Joined Table Inheritance) + + SQLModel 原始行为: + - 如果任何基类是表模型,则不调用 DeclarativeMeta.__init__ + - 这阻止了子类创建自己的表 + + 修复逻辑: + - 检测联表继承场景(子类有自己的 __tablename__ 且有外键指向父表) + - 强制调用 DeclarativeMeta.__init__ 来创建子表 + """ + from sqlmodel.main import is_table_model_class, DeclarativeMeta, ModelMetaclass + + # 检查是否是表模型 + if not is_table_model_class(cls): + ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) + return + + # 检查是否有基类是表模型 + base_is_table = any(is_table_model_class(base) for base in bases) + + if not base_is_table: + # 没有基类是表模型,走正常的 SQLModel 流程 + # 处理关系字段 + cls._setup_relationships() + DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw) + return + + # 关键:检测联表继承场景 + # 条件: + # 1. 当前类的 __tablename__ 与父类不同(表示需要新表) + # 2. 当前类有字段带有 foreign_key 指向父表 + current_tablename = getattr(cls, '__tablename__', None) + + # 查找父表信息 + parent_table = None + parent_tablename = None + for base in bases: + if is_table_model_class(base) and hasattr(base, '__tablename__'): + parent_tablename = base.__tablename__ + break + + # 检查是否有不同的 tablename + has_different_tablename = ( + current_tablename is not None + and parent_tablename is not None + and current_tablename != parent_tablename + ) + + # 检查是否有外键字段指向父表的主键 + # 注意:由于字段合并,我们需要检查直接基类的 model_fields + # 而不是当前类的合并后的 model_fields + has_fk_to_parent = False + + def _normalize_tablename(name: str) -> str: + """标准化表名以进行比较(移除下划线,转小写)""" + return name.replace('_', '').lower() + + def _fk_matches_parent(fk_str: str, parent_table: str) -> bool: + """检查 FK 字符串是否指向父表""" + if not fk_str or not parent_table: + return False + # FK 格式: "tablename.column" 或 "schema.tablename.column" + parts = fk_str.split('.') + if len(parts) >= 2: + fk_table = parts[-2] # 取倒数第二个作为表名 + # 标准化比较(处理下划线差异) + return _normalize_tablename(fk_table) == _normalize_tablename(parent_table) + return False + + if has_different_tablename and parent_tablename: + # 首先检查当前类的 model_fields + for field_name, field_info in cls.model_fields.items(): + fk = getattr(field_info, 'foreign_key', None) + if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename): + has_fk_to_parent = True + break + + # 如果没找到,检查直接基类的 model_fields(解决 mixin 字段被覆盖的问题) + if not has_fk_to_parent: + for base in bases: + if hasattr(base, 'model_fields'): + for field_name, field_info in base.model_fields.items(): + fk = getattr(field_info, 'foreign_key', None) + if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename): + has_fk_to_parent = True + break + if has_fk_to_parent: + break + + is_joined_inheritance = has_different_tablename and has_fk_to_parent + + if is_joined_inheritance: + # 联表继承:需要创建子表 + + # 修复外键字段:由于字段合并,外键信息可能丢失 + # 需要从基类的 mixin 中找回外键信息,并重建列 + from sqlalchemy import Column, ForeignKey, inspect as sa_inspect + from sqlalchemy.dialects.postgresql import UUID as SA_UUID + from sqlalchemy.exc import NoInspectionAvailable + from sqlalchemy.orm.attributes import InstrumentedAttribute + + # 联表继承:子表只应该有 id(FK 到父表)+ 子类特有的字段 + # 所有继承自祖先表的列都不应该在子表中重复创建 + + # 收集整个继承链中所有祖先表的列名(这些列不应该在子表中重复) + # 需要遍历整个 MRO,因为可能是多级继承(如 Tool -> Function -> GetWeatherFunction) + ancestor_column_names: set[str] = set() + for ancestor in cls.__mro__: + if ancestor is cls: + continue # 跳过当前类 + if is_table_model_class(ancestor): + try: + # 使用 inspect() 获取 mapper 的公开属性 + # 源码确认: mapper.local_table 是公开属性 (mapper.py:979-998) + mapper = sa_inspect(ancestor) + for col in mapper.local_table.columns: + # 跳过 _polymorphic_name 列(鉴别器,由根父表管理) + if col.name.startswith('_polymorphic'): + continue + ancestor_column_names.add(col.name) + except NoInspectionAvailable: + continue + + # 找到子类自己定义的字段(不在父类中的) + child_own_fields: set[str] = set() + for field_name in cls.model_fields: + # 检查这个字段是否是在当前类直接定义的(不是继承的) + # 通过检查父类是否有这个字段来判断 + is_inherited = False + for base in bases: + if hasattr(base, 'model_fields') and field_name in base.model_fields: + is_inherited = True + break + if not is_inherited: + child_own_fields.add(field_name) + + # 从子类类属性中移除父表已有的列定义 + # 这样 SQLAlchemy 就不会在子表中创建这些列 + fk_field_name = None + for base in bases: + if hasattr(base, 'model_fields'): + for field_name, field_info in base.model_fields.items(): + fk = getattr(field_info, 'foreign_key', None) + pk = getattr(field_info, 'primary_key', False) + if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename): + fk_field_name = field_name + # 找到了外键字段,重建它 + # 创建一个新的 Column 对象包含外键约束 + new_col = Column( + field_name, + SA_UUID(as_uuid=True), + ForeignKey(fk), + primary_key=pk if pk else False + ) + setattr(cls, field_name, new_col) + break + else: + continue + break + + # 移除继承自祖先表的列属性(除了 FK/PK 和子类自己的字段) + # 这防止 SQLAlchemy 在子表中创建重复列 + # 注意:在 __init__ 阶段,列是 Column 对象,不是 InstrumentedAttribute + for col_name in ancestor_column_names: + if col_name == fk_field_name: + continue # 保留 FK/PK 列(子表的主键,同时是父表的外键) + if col_name == 'id': + continue # id 会被 FK 字段覆盖 + if col_name in child_own_fields: + continue # 保留子类自己定义的字段 + + # 检查类属性是否是 Column 或 InstrumentedAttribute + if col_name in cls.__dict__: + attr = cls.__dict__[col_name] + # Column 对象或 InstrumentedAttribute 都需要删除 + if isinstance(attr, (Column, InstrumentedAttribute)): + try: + delattr(cls, col_name) + except AttributeError: + pass + + # 找到子类自己定义的关系(不在父类中的) + # 继承的关系会从父类自动获取,只需要设置子类新增的关系 + child_own_relationships: set[str] = set() + for rel_name in cls.__sqlmodel_relationships__: + is_inherited = False + for base in bases: + if hasattr(base, '__sqlmodel_relationships__') and rel_name in base.__sqlmodel_relationships__: + is_inherited = True + break + if not is_inherited: + child_own_relationships.add(rel_name) + + # 只为子类自己定义的新关系调用关系设置 + if child_own_relationships: + cls._setup_relationships(only_these=child_own_relationships) + + # 强制调用 DeclarativeMeta.__init__ + DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw) + else: + # 非联表继承:单表继承或正常 Pydantic 模型 + ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) + + def _setup_relationships(cls, only_these: set[str] | None = None) -> None: + """ + 设置 SQLAlchemy 关系字段(从 SQLModel 源码复制) + + Args: + only_these: 如果提供,只设置这些关系(用于 joined table inheritance 子类) + 如果为 None,设置所有关系(默认行为) + """ + from sqlalchemy.orm import relationship, Mapped + from sqlalchemy import inspect + from sqlmodel.main import get_relationship_to + from typing import get_origin + + for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): + # 如果指定了 only_these,只设置这些关系 + if only_these is not None and rel_name not in only_these: + continue + if rel_info.sa_relationship: + setattr(cls, rel_name, rel_info.sa_relationship) + continue + + raw_ann = cls.__annotations__[rel_name] + origin: typing.Any = get_origin(raw_ann) + if origin is Mapped: + ann = raw_ann.__args__[0] + else: + ann = raw_ann + cls.__annotations__[rel_name] = Mapped[ann] + + relationship_to = get_relationship_to( + name=rel_name, rel_info=rel_info, annotation=ann + ) + rel_kwargs: dict[str, typing.Any] = {} + if rel_info.back_populates: + rel_kwargs["back_populates"] = rel_info.back_populates + if rel_info.cascade_delete: + rel_kwargs["cascade"] = "all, delete-orphan" + if rel_info.passive_deletes: + rel_kwargs["passive_deletes"] = rel_info.passive_deletes + if rel_info.link_model: + ins = inspect(rel_info.link_model) + local_table = getattr(ins, "local_table") + if local_table is None: + raise RuntimeError( + f"Couldn't find secondary table for {rel_info.link_model}" + ) + rel_kwargs["secondary"] = local_table + + rel_args: list[typing.Any] = [] + if rel_info.sa_relationship_args: + rel_args.extend(rel_info.sa_relationship_args) + if rel_info.sa_relationship_kwargs: + rel_kwargs.update(rel_info.sa_relationship_kwargs) + + rel_value = relationship(relationship_to, *rel_args, **rel_kwargs) + setattr(cls, rel_name, rel_value) + + +class SQLModelBase(SQLModel, metaclass=__DeclarativeMeta): + """此类必须和TableBase系列类搭配使用""" -class SQLModelBase(SQLModel): model_config = ConfigDict(use_attribute_docstrings=True, validate_by_name=True) diff --git a/models/base/table_base.py b/models/base/table_base.py deleted file mode 100644 index 4df9873..0000000 --- a/models/base/table_base.py +++ /dev/null @@ -1,203 +0,0 @@ -import uuid -from datetime import datetime -from typing import Union, List, TypeVar, Type, Literal, override, Optional - -from fastapi import HTTPException -from sqlalchemy import DateTime, BinaryExpression, ClauseElement -from sqlalchemy.orm import selectinload -from sqlmodel import Field, select, Relationship -from sqlmodel.ext.asyncio.session import AsyncSession -from sqlalchemy.sql._typing import _OnClauseArgument -from sqlalchemy.ext.asyncio import AsyncAttrs - -from .sqlmodel_base import SQLModelBase - -T = TypeVar("T", bound="TableBase") -M = TypeVar("M", bound="SQLModel") - -now = lambda: datetime.now() -now_date = lambda: datetime.now().date() - -class TableBase(SQLModelBase, AsyncAttrs): - id: int | None = Field(default=None, primary_key=True) - - created_at: datetime = Field(default_factory=now) - updated_at: datetime = Field( - sa_type=DateTime, - sa_column_kwargs={"default": now, "onupdate": now}, - default_factory=now - ) - - @classmethod - async def add(cls: Type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | List[T]: - """ - 新增一条记录 - :param session: 数据库会话 - :param instances: - :param refresh: - :return: 新增的实例对象 - - usage: - item1 = Item(...) - item2 = Item(...) - - Item.add(session, [item1, item2]) - - item1_id = item1.id - """ - is_list = False - if isinstance(instances, list): - is_list = True - session.add_all(instances) - else: - session.add(instances) - - await session.commit() - - if refresh: - if is_list: - for instance in instances: - await session.refresh(instance) - else: - await session.refresh(instances) - - return instances - - async def save(self: T, session: AsyncSession, load: Optional[Relationship] = None) -> T: - session.add(self) - await session.commit() - - if load is not None: - cls = type(self) - return await cls.get(session, cls.id == self.id, load=load) - else: - await session.refresh(self) - return self - - async def update( - self: T, - session: AsyncSession, - other: M, - extra_data: dict = None, - exclude_unset: bool = True - ) -> T: - """ - 更新记录 - :param session: 数据库会话 - :param other: - :param extra_data: - :param exclude_unset: - :return: - """ - self.sqlmodel_update(other.model_dump(exclude_unset=exclude_unset), update=extra_data) - - session.add(self) - - await session.commit() - await session.refresh(self) - - return self - - @classmethod - async def delete(cls: Type[T], session: AsyncSession, instances: T | list[T]) -> None: - """ - 删除一些记录 - :param session: 数据库会话 - :param instances: - :return: None - - usage: - item1 = Item.get(...) - item2 = Item.get(...) - - Item.delete(session, [item1, item2]) - - """ - if isinstance(instances, list): - for instance in instances: - await session.delete(instance) - else: - await session.delete(instances) - - await session.commit() - - @classmethod - async def get( - cls: Type[T], - session: AsyncSession, - condition: BinaryExpression | ClauseElement | None, - *, - offset: int | None = None, - limit: int | None = None, - fetch_mode: Literal["one", "first", "all"] = "first", - join: Type[T] | tuple[Type[T], _OnClauseArgument] | None = None, - options: list | None = None, - load: Union[Relationship, None] = None, - order_by: list[ClauseElement] | None = None - ) -> T | List[T] | None: - """ - 异步获取模型实例 - - 参数: - session: 异步数据库会话 - condition: SQLAlchemy查询条件,如Model.id == 1 - offset: 结果偏移量 - limit: 结果数量限制 - options: 查询选项,如selectinload(Model.relation),异步访问关系属性必备,不然会报错 - fetch_mode: 获取模式 - "one"/"all"/"first" - join: 要联接的模型类 - - 返回: - 根据fetch_mode返回相应的查询结果 - """ - statement = select(cls) - - if condition is not None: - statement = statement.where(condition) - - if join is not None: - statement = statement.join(*join) - - if options: - statement = statement.options(*options) - - if load: - statement = statement.options(selectinload(load)) - - if order_by is not None: - statement = statement.order_by(*order_by) - - if offset: - statement = statement.offset(offset) - - if limit: - statement = statement.limit(limit) - - result = await session.exec(statement) - - if fetch_mode == "one": - return result.one() - elif fetch_mode == "first": - return result.first() - elif fetch_mode == "all": - return list(result.all()) - else: - raise ValueError(f"无效的 fetch_mode: {fetch_mode}") - - @classmethod - async def get_exist_one(cls: Type[T], session: AsyncSession, id: int, load: Union[Relationship, None] = None) -> T: - """此方法和 await session.get(cls, 主键)的区别就是当不存在时不返回None, - 而是会抛出fastapi 404 异常""" - instance = await cls.get(session, cls.id == id, load=load) - if not instance: - raise HTTPException(status_code=404, detail="Not found") - return instance - -class UUIDTableBase(TableBase): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) - """override""" - - @classmethod - @override - async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Union[Relationship, None] = None) -> T: - return await super().get_exist_one(session, id, load) # type: ignore diff --git a/models/download.py b/models/download.py index 571f12c..4f08e3c 100644 --- a/models/download.py +++ b/models/download.py @@ -1,46 +1,176 @@ -from typing import Optional, TYPE_CHECKING +from enum import StrEnum +from typing import TYPE_CHECKING from uuid import UUID -from sqlmodel import Field, Relationship, UniqueConstraint +from sqlmodel import Field, Relationship, UniqueConstraint, Index + +from .base import SQLModelBase +from .mixin import UUIDTableBaseMixin, TableBaseMixin + + +class DownloadStatus(StrEnum): + """下载状态枚举""" + RUNNING = "running" + """进行中""" + COMPLETED = "completed" + """已完成""" + ERROR = "error" + """错误""" + + +class DownloadType(StrEnum): + """下载类型枚举""" + # [TODO] 补充具体下载类型 + pass -from .base import SQLModelBase, UUIDTableBase if TYPE_CHECKING: from .user import User from .task import Task from .node import Node + +# ==================== Aria2 信息模型 ==================== + +class DownloadAria2InfoBase(SQLModelBase): + """Aria2下载信息基础模型""" + + info_hash: str | None = Field(default=None, max_length=40) + """InfoHash(BT种子)""" + + piece_length: int = 0 + """分片大小""" + + num_pieces: int = 0 + """分片数量""" + + num_seeders: int = 0 + """做种人数""" + + connections: int = 0 + """连接数""" + + upload_speed: int = 0 + """上传速度(bytes/s)""" + + upload_length: int = 0 + """已上传大小(字节)""" + + error_code: str | None = None + """错误代码""" + + error_message: str | None = None + """错误信息""" + + +class DownloadAria2Info(DownloadAria2InfoBase, SQLModelBase, table=True): + """Aria2下载信息模型(与Download一对一关联)""" + + download_id: UUID = Field(foreign_key="download.id", primary_key=True) + """关联的下载任务UUID""" + + # 反向关系 + download: "Download" = Relationship(back_populates="aria2_info") + """关联的下载任务""" + + +class DownloadAria2File(SQLModelBase, TableBaseMixin): + """Aria2下载文件列表(与Download一对多关联)""" + + download_id: UUID = Field(foreign_key="download.id", index=True) + """关联的下载任务UUID""" + + file_index: int = Field(ge=1) + """文件索引(从1开始)""" + + path: str + """文件路径""" + + length: int = 0 + """文件大小(字节)""" + + completed_length: int = 0 + """已完成大小(字节)""" + + is_selected: bool = True + """是否选中下载""" + + # 反向关系 + download: "Download" = Relationship(back_populates="aria2_files") + """关联的下载任务""" + + +# ==================== 主模型 ==================== + class DownloadBase(SQLModelBase): pass -class Download(DownloadBase, UUIDTableBase, table=True): +class Download(DownloadBase, UUIDTableBaseMixin): """离线下载任务模型""" __table_args__ = ( UniqueConstraint("node_id", "g_id", name="uq_download_node_gid"), + Index("ix_download_status", "status"), + Index("ix_download_user_status", "user_id", "status"), ) - status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="下载状态: 0=进行中, 1=完成, 2=错误") - type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务类型") - source: str = Field(description="来源URL或标识") - total_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="总大小(字节)") - downloaded_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="已下载大小(字节)") - g_id: str | None = Field(default=None, index=True, description="Aria2 GID") - speed: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="下载速度 (bytes/s)") - parent: str | None = Field(default=None, description="父任务标识") - attrs: str | None = Field(default=None, description="额外属性 (JSON格式)") - # attrs 示例: {"gid":"65c5faf38374cc63","status":"removed","totalLength":"0","completedLength":"0","uploadLength":"0","bitfield":"","downloadSpeed":"0","uploadSpeed":"0","infoHash":"ca159db2b1e78f6e95fd972be72251f967f639d4","numSeeders":"0","seeder":"","pieceLength":"16384","numPieces":"0","connections":"0","errorCode":"31","errorMessage":"","followedBy":null,"belongsTo":"","dir":"/data/ccaaDown/aria2/7a208304-9126-46d2-ba47-a6959f236a07","files":[{"index":"1","path":"[METADATA]zh-cn_windows_11_consumer_editions_version_21h2_updated_aug_2022_x64_dvd_a29983d5.iso","length":"0","completedLength":"0","selected":"true","uris":[]}],"bittorrent":{"announceList":[["udp://tracker.opentrackr.org:1337/announce"],["udp://9.rarbg.com:2810/announce"],["udp://tracker.openbittorrent.com:6969/announce"],["https://opentracker.i2p.rocks:443/announce"],["http://tracker.openbittorrent.com:80/announce"],["udp://open.stealth.si:80/announce"],["udp://tracker.torrent.eu.org:451/announce"],["udp://exodus.desync.com:6969/announce"],["udp://tracker.tiny-vps.com:6969/announce"],["udp://tracker.pomf.se:80/announce"],["udp://tracker.moeking.me:6969/announce"],["udp://tracker.dler.org:6969/announce"],["udp://open.demonii.com:1337/announce"],["udp://explodie.org:6969/announce"],["udp://chouchou.top:8080/announce"],["udp://bt.oiyo.tk:6969/announce"],["https://tracker.nanoha.org:443/announce"],["https://tracker.lilithraws.org:443/announce"],["http://tracker3.ctix.cn:8080/announce"],["http://tracker.nucozer-tracker.ml:2710/announce"]],"comment":"","creationDate":0,"mode":"","info":{"name":""}}} - error: str | None = Field(default=None, description="错误信息") - dst: str = Field(description="目标存储路径") - + status: DownloadStatus = Field(default=DownloadStatus.RUNNING, sa_column_kwargs={"server_default": "'running'"}) + """下载状态""" + + type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) + """任务类型 [TODO] 待定义枚举""" + + source: str + """来源URL或标识""" + + total_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) + """总大小(字节)""" + + downloaded_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) + """已下载大小(字节)""" + + g_id: str | None = Field(default=None, index=True) + """Aria2 GID""" + + speed: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) + """下载速度(bytes/s)""" + + parent: str | None = Field(default=None, max_length=255) + """父任务标识""" + + error: str | None = Field(default=None) + """错误信息""" + + dst: str + """目标存储路径""" + # 外键 - user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID") - task_id: int | None = Field(default=None, foreign_key="task.id", index=True, description="关联的任务ID") - node_id: int = Field(foreign_key="node.id", index=True, description="执行下载的节点ID") - + user_id: UUID = Field(foreign_key="user.id", index=True) + """所属用户UUID""" + + task_id: int | None = Field(default=None, foreign_key="task.id", index=True) + """关联的任务ID""" + + node_id: int = Field(foreign_key="node.id", index=True) + """执行下载的节点ID""" + # 关系 + aria2_info: DownloadAria2Info | None = Relationship( + back_populates="download", + sa_relationship_kwargs={"uselist": False}, + ) + """Aria2下载信息""" + + aria2_files: list[DownloadAria2File] = Relationship(back_populates="download") + """Aria2文件列表""" + user: "User" = Relationship(back_populates="downloads") - task: Optional["Task"] = Relationship(back_populates="downloads") + """所属用户""" + + task: "Task" = Relationship(back_populates="downloads") + """关联的任务""" + node: "Node" = Relationship(back_populates="downloads") + """执行下载的节点""" \ No newline at end of file diff --git a/models/group.py b/models/group.py index 28f9d4b..72e46c0 100644 --- a/models/group.py +++ b/models/group.py @@ -4,7 +4,8 @@ from uuid import UUID from sqlmodel import Field, Relationship, text -from .base import TableBase, SQLModelBase, UUIDTableBase +from .base import SQLModelBase +from .mixin import TableBaseMixin, UUIDTableBaseMixin if TYPE_CHECKING: from .user import User @@ -75,7 +76,7 @@ class GroupResponse(GroupBase, GroupOptionsBase): from .policy import GroupPolicyLink -class GroupOptions(GroupOptionsBase, TableBase, table=True): +class GroupOptions(GroupOptionsBase, TableBaseMixin): """用户组选项模型""" group_id: UUID = Field(foreign_key="group.id", unique=True) @@ -100,7 +101,7 @@ class GroupOptions(GroupOptionsBase, TableBase, table=True): group: "Group" = Relationship(back_populates="options") -class Group(GroupBase, UUIDTableBase, table=True): +class Group(GroupBase, UUIDTableBaseMixin): """用户组模型""" name: str = Field(max_length=255, unique=True) @@ -134,14 +135,17 @@ class Group(GroupBase, UUIDTableBase, table=True): ) # 关系:一个组可以有多个用户 - user: list["User"] = Relationship( + users: list["User"] = Relationship( back_populates="group", sa_relationship_kwargs={"foreign_keys": "User.group_id"} ) - previous_user: list["User"] = Relationship( + """当前属于该组的用户列表""" + + previous_users: list["User"] = Relationship( back_populates="previous_group", sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"} ) + """之前属于该组的用户列表(用于过期后恢复)""" def to_response(self) -> "GroupResponse": """转换为响应 DTO""" diff --git a/models/mixin/README.md b/models/mixin/README.md new file mode 100644 index 0000000..de03841 --- /dev/null +++ b/models/mixin/README.md @@ -0,0 +1,543 @@ +# SQLModel Mixin Module + +This module provides composable Mixin classes for SQLModel entities, enabling reusable functionality such as CRUD operations, polymorphic inheritance, JWT authentication, and standardized response DTOs. + +## Module Overview + +The `sqlmodels.mixin` module contains various Mixin classes that follow the "Composition over Inheritance" design philosophy. These mixins provide: + +- **CRUD Operations**: Async database operations (add, save, update, delete, get, count) +- **Polymorphic Inheritance**: Tools for joined table inheritance patterns +- **JWT Authentication**: Token generation and validation +- **Pagination & Sorting**: Standardized table view parameters +- **Response DTOs**: Consistent id/timestamp fields for API responses + +## Module Structure + +``` +sqlmodels/mixin/ +├── __init__.py # Module exports +├── polymorphic.py # PolymorphicBaseMixin, create_subclass_id_mixin, AutoPolymorphicIdentityMixin +├── table.py # TableBaseMixin, UUIDTableBaseMixin, TableViewRequest +├── info_response.py # Response DTO Mixins (IntIdInfoMixin, UUIDIdInfoMixin, etc.) +└── jwt/ # JWT authentication + ├── __init__.py + ├── key.py # JWTKey database model + ├── payload.py # JWTPayloadBase + ├── manager.py # JWTManager singleton + ├── auth.py # JWTAuthMixin + ├── exceptions.py # JWT-related exceptions + └── responses.py # TokenResponse DTO +``` + +## Dependency Hierarchy + +The module has a strict import order to avoid circular dependencies: + +1. **polymorphic.py** - Only depends on `SQLModelBase` +2. **table.py** - Depends on `polymorphic.py` +3. **jwt/** - May depend on both `polymorphic.py` and `table.py` +4. **info_response.py** - Only depends on `SQLModelBase` + +## Core Components + +### 1. TableBaseMixin + +Base mixin for database table models with integer primary keys. + +**Features:** +- Provides CRUD methods: `add()`, `save()`, `update()`, `delete()`, `get()`, `count()`, `get_exist_one()` +- Automatic timestamp management (`created_at`, `updated_at`) +- Async relationship loading support (via `AsyncAttrs`) +- Pagination and sorting via `TableViewRequest` +- Polymorphic subclass loading support + +**Fields:** +- `id: int | None` - Integer primary key (auto-increment) +- `created_at: datetime` - Record creation timestamp +- `updated_at: datetime` - Record update timestamp (auto-updated) + +**Usage:** +```python +from sqlmodels.mixin import TableBaseMixin +from sqlmodels.base import SQLModelBase + +class User(SQLModelBase, TableBaseMixin, table=True): + name: str + email: str + """User email""" + +# CRUD operations +async def example(session: AsyncSession): + # Add + user = User(name="Alice", email="alice@example.com") + user = await user.save(session) + + # Get + user = await User.get(session, User.id == 1) + + # Update + update_data = UserUpdateRequest(name="Alice Smith") + user = await user.update(session, update_data) + + # Delete + await User.delete(session, user) + + # Count + count = await User.count(session, User.is_active == True) +``` + +**Important Notes:** +- `save()` and `update()` return refreshed instances - **always use the return value**: + ```python + # ✅ Correct + device = await device.save(session) + return device + + # ❌ Wrong - device is expired after commit + await device.save(session) + return device + ``` + +### 2. UUIDTableBaseMixin + +Extends `TableBaseMixin` with UUID primary keys instead of integers. + +**Differences from TableBaseMixin:** +- `id: UUID` - UUID primary key (auto-generated via `uuid.uuid4()`) +- `get_exist_one()` accepts `UUID` instead of `int` + +**Usage:** +```python +from sqlmodels.mixin import UUIDTableBaseMixin + +class Character(SQLModelBase, UUIDTableBaseMixin, table=True): + name: str + description: str | None = None + """Character description""" +``` + +**Recommendation:** Use `UUIDTableBaseMixin` for most new models, as UUIDs provide better scalability and avoid ID collisions. + +### 3. TableViewRequest + +Standardized pagination and sorting parameters for LIST endpoints. + +**Fields:** +- `offset: int | None` - Skip first N records (default: 0) +- `limit: int | None` - Return max N records (default: 50, max: 100) +- `desc: bool | None` - Sort descending (default: True) +- `order: Literal["created_at", "updated_at"] | None` - Sort field (default: "created_at") + +**Usage with TableBaseMixin.get():** +```python +from dependencies import TableViewRequestDep + +@router.get("/list") +async def list_characters( + session: SessionDep, + table_view: TableViewRequestDep +) -> list[Character]: + """List characters with pagination and sorting""" + return await Character.get( + session, + fetch_mode="all", + table_view=table_view # Automatically handles pagination and sorting + ) +``` + +**Manual usage:** +```python +table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at") +characters = await Character.get(session, fetch_mode="all", table_view=table_view) +``` + +**Backward Compatibility:** +The traditional `offset`, `limit`, `order_by` parameters still work, but `table_view` is recommended for new code. + +### 4. PolymorphicBaseMixin + +Base mixin for joined table inheritance, automatically configuring polymorphic settings. + +**Automatic Configuration:** +- Defines `_polymorphic_name: str` field (indexed) +- Sets `polymorphic_on='_polymorphic_name'` +- Detects abstract classes (via ABC and abstract methods) and sets `polymorphic_abstract=True` + +**Methods:** +- `get_concrete_subclasses()` - Get all non-abstract subclasses (for `selectin_polymorphic`) +- `get_polymorphic_discriminator()` - Get the polymorphic discriminator field name +- `get_identity_to_class_map()` - Map `polymorphic_identity` to subclass types + +**Usage:** +```python +from abc import ABC, abstractmethod +from sqlmodels.mixin import PolymorphicBaseMixin, UUIDTableBaseMixin + +class Tool(PolymorphicBaseMixin, UUIDTableBaseMixin, ABC): + """Abstract base class for all tools""" + name: str + description: str + """Tool description""" + + @abstractmethod + async def execute(self, params: dict) -> dict: + """Execute the tool""" + pass +``` + +**Why Single Underscore Prefix?** +- SQLAlchemy maps single-underscore fields to database columns +- Pydantic treats them as private (excluded from serialization) +- Double-underscore fields would be excluded by SQLAlchemy (not mapped to database) + +### 5. create_subclass_id_mixin() + +Factory function to create ID mixins for subclasses in joined table inheritance. + +**Purpose:** In joined table inheritance, subclasses need a foreign key pointing to the parent table's primary key. This function generates a mixin class providing that foreign key field. + +**Signature:** +```python +def create_subclass_id_mixin(parent_table_name: str) -> type[SQLModelBase]: + """ + Args: + parent_table_name: Parent table name (e.g., 'asr', 'tts', 'tool', 'function') + + Returns: + A mixin class containing id field (foreign key + primary key) + """ +``` + +**Usage:** +```python +from sqlmodels.mixin import create_subclass_id_mixin + +# Create mixin for ASR subclasses +ASRSubclassIdMixin = create_subclass_id_mixin('asr') + +class FunASR(ASRSubclassIdMixin, ASR, AutoPolymorphicIdentityMixin, table=True): + """FunASR implementation""" + pass +``` + +**Important:** The ID mixin **must be first in the inheritance list** to ensure MRO (Method Resolution Order) correctly overrides the parent's `id` field. + +### 6. AutoPolymorphicIdentityMixin + +Automatically generates `polymorphic_identity` based on class name. + +**Naming Convention:** +- Format: `{parent_identity}.{classname_lowercase}` +- If no parent identity exists, uses `{classname_lowercase}` + +**Usage:** +```python +from sqlmodels.mixin import AutoPolymorphicIdentityMixin + +class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True): + """Base class for function-type tools""" + pass + # polymorphic_identity = 'function' + +class GetWeatherFunction(Function, table=True): + """Weather query function""" + pass + # polymorphic_identity = 'function.getweatherfunction' +``` + +**Manual Override:** +```python +class CustomTool( + Tool, + AutoPolymorphicIdentityMixin, + polymorphic_identity='custom_name', # Override auto-generated name + table=True +): + pass +``` + +### 7. JWTAuthMixin + +Provides JWT token generation and validation for entity classes (User, Client). + +**Methods:** +- `async issue_jwt(session: AsyncSession) -> str` - Generate JWT token for current instance +- `@classmethod async from_jwt(session: AsyncSession, token: str) -> Self` - Validate token and retrieve entity + +**Requirements:** +Subclasses must define: +- `JWTPayload` - Payload model (inherits from `JWTPayloadBase`) +- `jwt_key_purpose` - ClassVar specifying the JWT key purpose enum value + +**Usage:** +```python +from sqlmodels.mixin import JWTAuthMixin, UUIDTableBaseMixin + +class User(SQLModelBase, UUIDTableBaseMixin, JWTAuthMixin, table=True): + JWTPayload = UserJWTPayload # Define payload model + jwt_key_purpose: ClassVar[JWTKeyPurposeEnum] = JWTKeyPurposeEnum.user + + email: str + is_admin: bool = False + is_active: bool = True + """User active status""" + +# Generate token +async def login(session: AsyncSession, user: User) -> str: + token = await user.issue_jwt(session) + return token + +# Validate token +async def verify(session: AsyncSession, token: str) -> User: + user = await User.from_jwt(session, token) + return user +``` + +### 8. Response DTO Mixins + +Mixins for standardized InfoResponse DTOs, defining id and timestamp fields. + +**Available Mixins:** +- `IntIdInfoMixin` - Integer ID field +- `UUIDIdInfoMixin` - UUID ID field +- `DatetimeInfoMixin` - `created_at` and `updated_at` fields +- `IntIdDatetimeInfoMixin` - Integer ID + timestamps +- `UUIDIdDatetimeInfoMixin` - UUID ID + timestamps + +**Design Note:** These fields are non-nullable in DTOs because database records always have these values when returned. + +**Usage:** +```python +from sqlmodels.mixin import UUIDIdDatetimeInfoMixin + +class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin): + """Character response DTO with id and timestamps""" + pass # Inherits id, created_at, updated_at from mixin +``` + +## Complete Joined Table Inheritance Example + +Here's a complete example demonstrating polymorphic inheritance: + +```python +from abc import ABC, abstractmethod +from sqlmodels.base import SQLModelBase +from sqlmodels.mixin import ( + UUIDTableBaseMixin, + PolymorphicBaseMixin, + create_subclass_id_mixin, + AutoPolymorphicIdentityMixin, +) + +# 1. Define Base class (fields only, no table) +class ASRBase(SQLModelBase): + name: str + """Configuration name""" + + base_url: str + """Service URL""" + +# 2. Define abstract parent class (with table) +class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC): + """Abstract base class for ASR configurations""" + # PolymorphicBaseMixin automatically provides: + # - _polymorphic_name field + # - polymorphic_on='_polymorphic_name' + # - polymorphic_abstract=True (when ABC with abstract methods) + + @abstractmethod + async def transcribe(self, pcm_data: bytes) -> str: + """Transcribe audio to text""" + pass + +# 3. Create ID Mixin for second-level subclasses +ASRSubclassIdMixin = create_subclass_id_mixin('asr') + +# 4. Create second-level abstract class (if needed) +class FunASR( + ASRSubclassIdMixin, + ASR, + AutoPolymorphicIdentityMixin, + polymorphic_abstract=True +): + """FunASR abstract base (may have multiple implementations)""" + pass + # polymorphic_identity = 'funasr' + +# 5. Create concrete implementation classes +class FunASRLocal(FunASR, table=True): + """FunASR local deployment""" + # polymorphic_identity = 'funasr.funasrlocal' + + async def transcribe(self, pcm_data: bytes) -> str: + # Implementation... + return "transcribed text" + +# 6. Get all concrete subclasses (for selectin_polymorphic) +concrete_asrs = ASR.get_concrete_subclasses() +# Returns: [FunASRLocal, ...] +``` + +## Import Guidelines + +**Standard Import:** +```python +from sqlmodels.mixin import ( + TableBaseMixin, + UUIDTableBaseMixin, + PolymorphicBaseMixin, + TableViewRequest, + create_subclass_id_mixin, + AutoPolymorphicIdentityMixin, + JWTAuthMixin, + UUIDIdDatetimeInfoMixin, + now, + now_date, +) +``` + +**Backward Compatibility:** +Some exports are also available from `sqlmodels.base` for backward compatibility: +```python +# Legacy import path (still works) +from sqlmodels.base import UUIDTableBase, TableViewRequest + +# Recommended new import path +from sqlmodels.mixin import UUIDTableBaseMixin, TableViewRequest +``` + +## Best Practices + +### 1. Mixin Order Matters + +**Correct Order:** +```python +# ✅ ID Mixin first, then parent, then AutoPolymorphicIdentityMixin +class SubTool(ToolSubclassIdMixin, Tool, AutoPolymorphicIdentityMixin, table=True): + pass +``` + +**Wrong Order:** +```python +# ❌ ID Mixin not first - won't override parent's id field +class SubTool(Tool, ToolSubclassIdMixin, AutoPolymorphicIdentityMixin, table=True): + pass +``` + +### 2. Always Use Return Values from save() and update() + +```python +# ✅ Correct - use returned instance +device = await device.save(session) +return device + +# ❌ Wrong - device is expired after commit +await device.save(session) +return device # AttributeError when accessing fields +``` + +### 3. Prefer table_view Over Manual Pagination + +```python +# ✅ Recommended - consistent across all endpoints +characters = await Character.get( + session, + fetch_mode="all", + table_view=table_view +) + +# ⚠️ Works but not recommended - manual parameter management +characters = await Character.get( + session, + fetch_mode="all", + offset=0, + limit=20, + order_by=[desc(Character.created_at)] +) +``` + +### 4. Polymorphic Loading for Many Subclasses + +```python +# When loading relationships with > 10 polymorphic subclasses, use load_polymorphic='all' +tool_set = await ToolSet.get( + session, + ToolSet.id == tool_set_id, + load=ToolSet.tools, + load_polymorphic='all' # Two-phase query - only loads actual related subclasses +) + +# For fewer subclasses, specify the list explicitly +tool_set = await ToolSet.get( + session, + ToolSet.id == tool_set_id, + load=ToolSet.tools, + load_polymorphic=[GetWeatherFunction, CodeInterpreterFunction] +) +``` + +### 5. Response DTOs Should Inherit Base Classes + +```python +# ✅ Correct - inherits from CharacterBase +class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin): + pass + +# ❌ Wrong - doesn't inherit from CharacterBase +class CharacterInfoResponse(SQLModelBase, UUIDIdDatetimeInfoMixin): + name: str # Duplicated field definition + description: str | None = None +``` + +**Reason:** Inheriting from Base classes ensures: +- Type checking via `isinstance(obj, XxxBase)` +- Consistency across related DTOs +- Future field additions automatically propagate + +### 6. Use Specific Types, Not Containers + +```python +# ✅ Correct - specific DTO for config updates +class GetWeatherFunctionUpdateRequest(GetWeatherFunctionConfigBase): + weather_api_key: str | None = None + default_location: str | None = None + """Default location""" + +# ❌ Wrong - lose type safety +class ToolUpdateRequest(SQLModelBase): + config: dict[str, Any] # No field validation +``` + +## Type Variables + +```python +from sqlmodels.mixin import T, M + +T = TypeVar("T", bound="TableBaseMixin") # For CRUD methods +M = TypeVar("M", bound="SQLModel") # For update() method +``` + +## Utility Functions + +```python +from sqlmodels.mixin import now, now_date + +# Lambda functions for default factories +now = lambda: datetime.now() +now_date = lambda: datetime.now().date() +``` + +## Related Modules + +- **sqlmodels.base** - Base classes (`SQLModelBase`, backward-compatible exports) +- **dependencies** - FastAPI dependencies (`SessionDep`, `TableViewRequestDep`) +- **sqlmodels.user** - User model with JWT authentication +- **sqlmodels.client** - Client model with JWT authentication +- **sqlmodels.character.llm.openai_compatibles.tools** - Polymorphic tool hierarchy + +## Additional Resources + +- `POLYMORPHIC_NAME_DESIGN.md` - Design rationale for `_polymorphic_name` field +- `CLAUDE.md` - Project coding standards and design philosophy +- SQLAlchemy Documentation - [Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance) diff --git a/models/mixin/__init__.py b/models/mixin/__init__.py new file mode 100644 index 0000000..1ad01e7 --- /dev/null +++ b/models/mixin/__init__.py @@ -0,0 +1,46 @@ +""" +SQLModel Mixin模块 + +提供各种Mixin类供SQLModel实体使用。 + +包含: +- polymorphic: 联表继承工具(create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin) +- table: 表基类(TableBaseMixin, UUIDTableBaseMixin) +- table: 查询参数类(TimeFilterRequest, PaginationRequest, TableViewRequest) +- jwt/: JWT认证相关(JWTAuthMixin, JWTManager, JWTKey等)- 需要时直接从 .jwt 导入 +- info_response: InfoResponse DTO的id/时间戳Mixin + +导入顺序很重要,避免循环导入: +1. polymorphic(只依赖 SQLModelBase) +2. table(依赖 polymorphic) + +注意:jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig, +而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。 +""" +# polymorphic 必须先导入 +from .polymorphic import ( + create_subclass_id_mixin, + AutoPolymorphicIdentityMixin, + PolymorphicBaseMixin, +) +# table 依赖 polymorphic +from .table import ( + TableBaseMixin, + UUIDTableBaseMixin, + TimeFilterRequest, + PaginationRequest, + TableViewRequest, + ListResponse, + T, + now, + now_date, +) +# jwt 不在此处导入(避免循环:jwt/manager.py → ServerConfig → mixin → jwt) +# 需要时直接从 sqlmodels.mixin.jwt 导入 +from .info_response import ( + IntIdInfoMixin, + UUIDIdInfoMixin, + DatetimeInfoMixin, + IntIdDatetimeInfoMixin, + UUIDIdDatetimeInfoMixin, +) diff --git a/models/mixin/info_response.py b/models/mixin/info_response.py new file mode 100644 index 0000000..647b9a3 --- /dev/null +++ b/models/mixin/info_response.py @@ -0,0 +1,46 @@ +""" +InfoResponse DTO Mixin模块 + +提供用于InfoResponse类型DTO的Mixin,统一定义id/created_at/updated_at字段。 + +设计说明: +- 这些Mixin用于**响应DTO**,不是数据库表 +- 从数据库返回时这些字段永远不为空,所以定义为必填字段 +- TableBase中的id=None和default_factory=now是正确的(入库前为None,数据库生成) +- 这些Mixin让DTO明确表示"返回给客户端时这些字段必定有值" +""" +from datetime import datetime +from uuid import UUID + +from models.base import SQLModelBase + + +class IntIdInfoMixin(SQLModelBase): + """整数ID响应mixin - 用于InfoResponse DTO""" + id: int + """记录ID""" + + +class UUIDIdInfoMixin(SQLModelBase): + """UUID ID响应mixin - 用于InfoResponse DTO""" + id: UUID + """记录ID""" + + +class DatetimeInfoMixin(SQLModelBase): + """时间戳响应mixin - 用于InfoResponse DTO""" + created_at: datetime + """创建时间""" + + updated_at: datetime + """更新时间""" + + +class IntIdDatetimeInfoMixin(IntIdInfoMixin, DatetimeInfoMixin): + """整数ID + 时间戳响应mixin""" + pass + + +class UUIDIdDatetimeInfoMixin(UUIDIdInfoMixin, DatetimeInfoMixin): + """UUID ID + 时间戳响应mixin""" + pass diff --git a/models/mixin/polymorphic.py b/models/mixin/polymorphic.py new file mode 100644 index 0000000..2f7f9c6 --- /dev/null +++ b/models/mixin/polymorphic.py @@ -0,0 +1,456 @@ +""" +联表继承(Joined Table Inheritance)的通用工具 + +提供用于简化SQLModel多态表设计的辅助函数和Mixin。 + +Usage Example: + + from sqlmodels.base import SQLModelBase + from sqlmodels.mixin import UUIDTableBaseMixin + from sqlmodels.mixin.polymorphic import ( + PolymorphicBaseMixin, + create_subclass_id_mixin, + AutoPolymorphicIdentityMixin + ) + + # 1. 定义Base类(只有字段,无表) + class ASRBase(SQLModelBase): + name: str + \"\"\"配置名称\"\"\" + + base_url: str + \"\"\"服务地址\"\"\" + + # 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin + class ASR( + ASRBase, + UUIDTableBaseMixin, + PolymorphicBaseMixin, + ABC + ): + \"\"\"ASR配置的抽象基类\"\"\" + # PolymorphicBaseMixin 自动提供: + # - _polymorphic_name 字段 + # - polymorphic_on='_polymorphic_name' + # - polymorphic_abstract=True(当有抽象方法时) + + # 3. 为第二层子类创建ID Mixin + ASRSubclassIdMixin = create_subclass_id_mixin('asr') + + # 4. 创建第二层抽象类(如果需要) + class FunASR( + ASRSubclassIdMixin, + ASR, + AutoPolymorphicIdentityMixin, + polymorphic_abstract=True + ): + \"\"\"FunASR的抽象基类,可能有多个实现\"\"\" + pass + + # 5. 创建具体实现类 + class FunASRLocal(FunASR, table=True): + \"\"\"FunASR本地部署版本\"\"\" + # polymorphic_identity 会自动设置为 'asr.funasrlocal' + pass + + # 6. 获取所有具体子类(用于 selectin_polymorphic) + concrete_asrs = ASR.get_concrete_subclasses() + # 返回 [FunASRLocal, ...] +""" +import uuid +from abc import ABC +from uuid import UUID + +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined +from sqlalchemy import String, inspect +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlmodel import Field + +from models.base.sqlmodel_base import SQLModelBase + + +def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']: + """ + 动态创建SubclassIdMixin类 + + 在联表继承中,子类需要一个外键指向父表的主键。 + 此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。 + + Args: + parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function') + + Returns: + 一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4) + + Example: + >>> ASRSubclassIdMixin = create_subclass_id_mixin('asr') + >>> class FunASR(ASRSubclassIdMixin, ASR, table=True): + ... pass + + Note: + - 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id + - 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase) + - 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin) + """ + if not parent_table_name: + raise ValueError("parent_table_name 不能为空") + + # 转换为PascalCase作为类名 + class_name_parts = parent_table_name.split('_') + class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin' + + # 使用闭包捕获parent_table_name + _parent_table_name = parent_table_name + + # 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields + class SubclassIdMixin(SQLModelBase): + # 定义id字段 + id: UUID = Field( + default_factory=uuid.uuid4, + foreign_key=f'{_parent_table_name}.id', + primary_key=True, + ) + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs): + """ + Pydantic v2 的子类初始化钩子,在模型完全构建后调用 + + 修复联表继承中子类字段的default_factory丢失问题。 + SQLAlchemy 的 InstrumentedAttribute 会污染从父类继承的字段, + 导致 INSERT 语句中出现 `table.column` 引用而非实际值。 + + 通过从 MRO 中查找父类的原始字段定义来获取正确的 default_factory, + 遵循单一真相原则(不硬编码 default_factory)。 + + 需要修复的字段: + - id: 主键(从父类获取 default_factory) + - created_at: 创建时间戳(从父类获取 default_factory) + - updated_at: 更新时间戳(从父类获取 default_factory) + """ + super().__pydantic_init_subclass__(**kwargs) + + if not hasattr(cls, 'model_fields'): + return + + def find_original_field_info(field_name: str) -> FieldInfo | None: + """从 MRO 中查找字段的原始定义(未被 InstrumentedAttribute 污染的)""" + for base in cls.__mro__[1:]: # 跳过自己 + if hasattr(base, 'model_fields') and field_name in base.model_fields: + field_info = base.model_fields[field_name] + # 跳过被 InstrumentedAttribute 污染的 + if not isinstance(field_info.default, InstrumentedAttribute): + return field_info + return None + + # 动态检测所有需要修复的字段 + # 遵循单一真相原则:不硬编码字段列表,而是通过以下条件判断: + # 1. default 是 InstrumentedAttribute(被 SQLAlchemy 污染) + # 2. 原始定义有 default_factory 或明确的 default 值 + # + # 覆盖场景: + # - UUID主键(UUIDTableBaseMixin):id 有 default_factory=uuid.uuid4,需要修复 + # - int主键(TableBaseMixin):id 用 default=None,不需要修复(数据库自增) + # - created_at/updated_at:有 default_factory=now,需要修复 + # - 外键字段(created_by_id等):有 default=None,需要修复 + # - 普通字段(name, temperature等):无 default_factory,不需要修复 + # + # MRO 查找保证: + # - 在多重继承场景下,MRO 顺序是确定性的 + # - find_original_field_info 会找到第一个未被污染且有该字段的父类 + for field_name, current_field in cls.model_fields.items(): + # 检查是否被污染(default 是 InstrumentedAttribute) + if not isinstance(current_field.default, InstrumentedAttribute): + continue # 未被污染,跳过 + + # 从父类查找原始定义 + original = find_original_field_info(field_name) + if original is None: + continue # 找不到原始定义,跳过 + + # 根据原始定义的 default/default_factory 来修复 + if original.default_factory: + # 有 default_factory(如 uuid.uuid4, now) + new_field = FieldInfo( + default_factory=original.default_factory, + annotation=current_field.annotation, + json_schema_extra=current_field.json_schema_extra, + ) + elif original.default is not PydanticUndefined: + # 有明确的 default 值(如 None, 0, ""),且不是 PydanticUndefined + # PydanticUndefined 表示字段没有默认值(必填) + new_field = FieldInfo( + default=original.default, + annotation=current_field.annotation, + json_schema_extra=current_field.json_schema_extra, + ) + else: + continue # 既没有 default_factory 也没有有效的 default,跳过 + + # 复制SQLModel特有的属性 + if hasattr(current_field, 'foreign_key'): + new_field.foreign_key = current_field.foreign_key + if hasattr(current_field, 'primary_key'): + new_field.primary_key = current_field.primary_key + + cls.model_fields[field_name] = new_field + + # 设置类名和文档 + SubclassIdMixin.__name__ = class_name + SubclassIdMixin.__qualname__ = class_name + SubclassIdMixin.__doc__ = f""" + {parent_table_name}子类的ID Mixin + + 用于{parent_table_name}的子类,提供外键指向父表。 + 通过MRO确保此id字段覆盖继承的id字段。 + """ + + return SubclassIdMixin + + +class AutoPolymorphicIdentityMixin: + """ + 自动生成polymorphic_identity的Mixin + + 使用此Mixin的类会自动根据类名生成polymorphic_identity。 + 格式:{parent_polymorphic_identity}.{classname_lowercase} + + 如果没有父类的polymorphic_identity,则直接使用类名小写。 + + Example: + >>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True): + ... __polymorphic_name: str + ... + >>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True): + ... pass + ... # polymorphic_identity 会自动设置为 'function' + ... + >>> class CodeInterpreterFunction(Function, table=True): + ... pass + ... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction' + + Note: + - 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留 + - 此Mixin应该在继承列表中靠后的位置(在表基类之前) + """ + + def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs): + """ + 子类化钩子,自动生成polymorphic_identity + + Args: + polymorphic_identity: 如果手动指定,则使用指定的值 + **kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True) + """ + super().__init_subclass__(**kwargs) + + # 如果手动指定了polymorphic_identity,使用指定的值 + if polymorphic_identity is not None: + identity = polymorphic_identity + else: + # 自动生成polymorphic_identity + class_name = cls.__name__.lower() + + # 尝试从父类获取polymorphic_identity作为前缀 + parent_identity = None + for base in cls.__mro__[1:]: # 跳过自己 + if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict): + parent_identity = base.__mapper_args__.get('polymorphic_identity') + if parent_identity: + break + + # 构建identity + if parent_identity: + identity = f'{parent_identity}.{class_name}' + else: + identity = class_name + + # 设置到__mapper_args__ + if '__mapper_args__' not in cls.__dict__: + cls.__mapper_args__ = {} + + # 只在尚未设置polymorphic_identity时设置 + if 'polymorphic_identity' not in cls.__mapper_args__: + cls.__mapper_args__['polymorphic_identity'] = identity + + +class PolymorphicBaseMixin: + """ + 为联表继承链中的基类自动配置 polymorphic 设置的 Mixin + + 此 Mixin 自动设置以下内容: + - `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器 + - `_polymorphic_name: str`: 定义多态鉴别器字段(带索引) + - `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类 + + 使用场景: + 适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。 + + 用法示例: + ```python + from abc import ABC + from sqlmodels.mixin import UUIDTableBaseMixin + from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin + + # 定义基类 + class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC): + __tablename__ = 'mytool' + + # 不需要手动定义 _polymorphic_name + # 不需要手动设置 polymorphic_on + # 不需要手动设置 polymorphic_abstract + + # 定义子类 + class SpecificTool(MyTool): + __tablename__ = 'specifictool' + + # 会自动继承 polymorphic 配置 + ``` + + 自动行为: + 1. 定义 `_polymorphic_name: str` 字段(带索引) + 2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'` + 3. 自动检测抽象类: + - 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True + - 否则设置为 False + + 手动覆盖: + 可以在类定义时手动指定参数来覆盖自动行为: + ```python + class MyTool( + UUIDTableBaseMixin, + PolymorphicBaseMixin, + ABC, + polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name + polymorphic_abstract=False # 强制不设为抽象类 + ): + pass + ``` + + 注意事项: + - 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用 + - 适用于联表继承(joined table inheritance)场景 + - 子类会自动继承 _polymorphic_name 字段定义 + - 使用单下划线前缀是因为: + * SQLAlchemy 会映射单下划线字段为数据库列 + * Pydantic 将其视为私有属性,不参与序列化 + * 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列 + """ + + # 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段 + # + # 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column + # + # 为什么这样做: + # 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改 + # 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀) + # 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用 + # 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema) + # 5. 内部代码仍然可以正常访问和修改此字段 + # + # 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md + _polymorphic_name: Mapped[str] = mapped_column(String, index=True) + """ + 多态鉴别器字段,用于标识具体的子类类型 + + 注意:此字段使用单下划线前缀,表示内部使用。 + - ✅ 存储到数据库 + - ✅ 不出现在 API 序列化中 + - ✅ 防止外部直接修改 + """ + + def __init_subclass__( + cls, + polymorphic_on: str | None = None, + polymorphic_abstract: bool | None = None, + **kwargs + ): + """ + 在子类定义时自动配置 polymorphic 设置 + + Args: + polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。 + 设置为其他值可以使用不同的字段作为多态鉴别器。 + polymorphic_abstract: 是否为抽象类。 + - None: 自动检测(默认) + - True: 强制设为抽象类 + - False: 强制设为非抽象类 + **kwargs: 传递给父类的其他参数 + """ + super().__init_subclass__(**kwargs) + + # 初始化 __mapper_args__(如果还没有) + if '__mapper_args__' not in cls.__dict__: + cls.__mapper_args__ = {} + + # 设置 polymorphic_on(默认为 _polymorphic_name) + if 'polymorphic_on' not in cls.__mapper_args__: + cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name' + + # 自动检测或设置 polymorphic_abstract + if 'polymorphic_abstract' not in cls.__mapper_args__: + if polymorphic_abstract is None: + # 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类 + has_abc = ABC in cls.__mro__ + has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set())) + polymorphic_abstract = has_abc and has_abstract_methods + + cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract + + @classmethod + def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]: + """ + 递归获取当前类的所有具体(非抽象)子类 + + 用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。 + 可在任意多态基类上调用,返回该类的所有非抽象子类。 + + :return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类) + """ + result: list[type[PolymorphicBaseMixin]] = [] + for subclass in cls.__subclasses__(): + # 使用 inspect() 获取 mapper 的公开属性 + # 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811) + mapper = inspect(subclass) + if not mapper.polymorphic_abstract: + result.append(subclass) + # 无论是否抽象,都需要递归(抽象类可能有具体子类) + if hasattr(subclass, 'get_concrete_subclasses'): + result.extend(subclass.get_concrete_subclasses()) + return result + + @classmethod + def get_polymorphic_discriminator(cls) -> str: + """ + 获取多态鉴别字段名 + + 使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。 + + :return: 多态鉴别字段名(如 '_polymorphic_name') + :raises ValueError: 如果类未配置 polymorphic_on + """ + polymorphic_on = inspect(cls).polymorphic_on + if polymorphic_on is None: + raise ValueError( + f"{cls.__name__} 未配置 polymorphic_on," + f"请确保正确继承 PolymorphicBaseMixin" + ) + return polymorphic_on.key + + @classmethod + def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]: + """ + 获取 polymorphic_identity 到具体子类的映射 + + 包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。 + + :return: identity 到子类的映射字典 + """ + result: dict[str, type[PolymorphicBaseMixin]] = {} + for subclass in cls.get_concrete_subclasses(): + identity = inspect(subclass).polymorphic_identity + if identity: + result[identity] = subclass + return result diff --git a/models/mixin/table.py b/models/mixin/table.py new file mode 100644 index 0000000..c133ff5 --- /dev/null +++ b/models/mixin/table.py @@ -0,0 +1,927 @@ +""" +表基类 Mixin + +提供 TableBaseMixin、UUIDTableBaseMixin 和 TableViewRequest。 +这些类实际上是 Mixin,为 SQLModel 模型提供 CRUD 操作和时间戳字段。 + +依赖关系: + base/sqlmodel_base.py ← 最底层 + ↓ + mixin/polymorphic.py ← 定义 PolymorphicBaseMixin + ↓ + mixin/table.py ← 当前文件,导入 PolymorphicBaseMixin + ↓ + base/__init__.py ← 从 mixin 重新导出(保持向后兼容) +""" +import uuid +from datetime import datetime +from typing import TypeVar, Literal, override, Any, ClassVar, Generic + +# TODO(ListResponse泛型问题): SQLModel泛型类型JSON Schema生成bug +# 已知问题: https://github.com/fastapi/sqlmodel/discussions/1002 +# 修复PR: https://github.com/fastapi/sqlmodel/pull/1275 (尚未合并) +# 现象: SQLModel + Generic[T] 的 __pydantic_generic_metadata__ = {origin: None, args: ()} +# 导致OpenAPI schema中泛型字段显示为{}而非正确的$ref +# 当前方案: ListResponse继承BaseModel而非SQLModel (Discussion #1002推荐的workaround) +# 未来: PR #1275合并后可改回继承SQLModelBase +from pydantic import BaseModel, ConfigDict +from fastapi import HTTPException +from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct +from sqlalchemy.orm import selectinload, Relationship, with_polymorphic +from sqlmodel import Field, select +from sqlmodel.ext.asyncio.session import AsyncSession +from sqlalchemy.sql._typing import _OnClauseArgument +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlmodel.main import RelationshipInfo + +from .polymorphic import PolymorphicBaseMixin +from models.base.sqlmodel_base import SQLModelBase + +# Type variables for generic type hints, improving code completion and analysis. +T = TypeVar("T", bound="TableBaseMixin") +M = TypeVar("M", bound="SQLModelBase") +ItemT = TypeVar("ItemT") + + +class ListResponse(BaseModel, Generic[ItemT]): + """ + 泛型分页响应 + + 用于所有LIST端点的标准化响应格式,包含记录总数和项目列表。 + 与 TableBaseMixin.get_with_count() 配合使用。 + + 使用示例: + ```python + @router.get("", response_model=ListResponse[CharacterInfoResponse]) + async def list_characters(...) -> ListResponse[Character]: + return await Character.get_with_count(session, table_view=table_view) + ``` + + Attributes: + count: 符合条件的记录总数(用于分页计算) + items: 当前页的记录列表 + + Note: + 继承BaseModel而非SQLModelBase,因为SQLModel的metaclass与Generic冲突。 + 详见文件顶部TODO注释。 + """ + # 与SQLModelBase保持一致的配置 + model_config = ConfigDict(use_attribute_docstrings=True) + + count: int + """符合条件的记录总数""" + + items: list[ItemT] + """当前页的记录列表""" + + +# Lambda functions to get the current time, used as default factories in model fields. +now = lambda: datetime.now() +now_date = lambda: datetime.now().date() + + +# ==================== 查询参数请求类 ==================== + +class TimeFilterRequest(SQLModelBase): + """ + 时间筛选请求参数 + + 用于 count() 等只需要时间筛选的场景。 + 纯数据类,只负责参数校验和携带,SQL子句构建由 TableBaseMixin 负责。 + + Raises: + ValueError: 时间范围无效 + """ + created_after_datetime: datetime | None = None + """创建时间起始筛选(created_at >= datetime),如果为None则不限制""" + + created_before_datetime: datetime | None = None + """创建时间结束筛选(created_at < datetime),如果为None则不限制""" + + updated_after_datetime: datetime | None = None + """更新时间起始筛选(updated_at >= datetime),如果为None则不限制""" + + updated_before_datetime: datetime | None = None + """更新时间结束筛选(updated_at < datetime),如果为None则不限制""" + + def model_post_init(self, __context: Any) -> None: + """ + 验证时间范围有效性 + + 验证规则: + 1. 同类型:after 必须小于 before + 2. 跨类型:created_after 不能大于 updated_before(记录不可能在创建前被更新) + """ + # 同类型矛盾验证 + if self.created_after_datetime and self.created_before_datetime: + if self.created_after_datetime >= self.created_before_datetime: + raise ValueError("created_after_datetime 必须小于 created_before_datetime") + if self.updated_after_datetime and self.updated_before_datetime: + if self.updated_after_datetime >= self.updated_before_datetime: + raise ValueError("updated_after_datetime 必须小于 updated_before_datetime") + + # 跨类型矛盾验证:created_after >= updated_before 意味着要求创建时间晚于或等于更新时间上界,逻辑矛盾 + if self.created_after_datetime and self.updated_before_datetime: + if self.created_after_datetime >= self.updated_before_datetime: + raise ValueError( + "created_after_datetime 不能大于或等于 updated_before_datetime" + "(记录的更新时间不可能早于或等于创建时间)" + ) + + +class PaginationRequest(SQLModelBase): + """ + 分页排序请求参数 + + 用于需要分页和排序的场景。 + 纯数据类,只负责携带参数,SQL子句构建由 TableBaseMixin 负责。 + """ + offset: int | None = Field(default=0, ge=0) + """偏移量(跳过前N条记录),必须为非负整数""" + + limit: int | None = Field(default=50, le=100) + """每页数量(返回最多N条记录),默认50,最大100""" + + desc: bool | None = True + """是否降序排序(True: 降序, False: 升序)""" + + order: Literal["created_at", "updated_at"] | None = "created_at" + """排序字段(created_at: 创建时间, updated_at: 更新时间)""" + + +class TableViewRequest(TimeFilterRequest, PaginationRequest): + """ + 表格视图请求参数(分页、排序和时间筛选) + + 组合继承 TimeFilterRequest 和 PaginationRequest,用于 get() 等需要完整查询参数的场景。 + 纯数据类,SQL子句构建由 TableBaseMixin 负责。 + + 使用示例: + ```python + # 在端点中使用依赖注入 + @router.get("/list") + async def list_items( + session: SessionDep, + table_view: TableViewRequestDep + ): + items = await Item.get( + session, + fetch_mode="all", + table_view=table_view + ) + return items + + # 直接使用 + table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at") + items = await Item.get(session, fetch_mode="all", table_view=table_view) + ``` + """ + pass + + +# ==================== TableBaseMixin ==================== + +class TableBaseMixin(AsyncAttrs): + """ + 一个异步 CRUD 操作的基础模型类 Mixin. + + 此类必须搭配SQLModelBase使用 + + 此类为所有继承它的 SQLModel 模型提供了通用的数据库操作方法, + 例如 add, save, update, delete, 和 get. 它还包括自动管理 + 的 `created_at` 和 `updated_at` 时间戳字段. + + Attributes: + id (int | None): 整数主键, 自动递增. + created_at (datetime): 记录创建时的时间戳, 自动设置. + updated_at (datetime): 记录每次更新时的时间戳, 自动更新. + """ + _is_table_mixin: ClassVar[bool] = True + """标记此类为表混入类的内部属性""" + + def __init_subclass__(cls, **kwargs): + """ + 接受并传递子类定义时的关键字参数 + + 这允许元类 __DeclarativeMeta 处理的参数(如 table_args) + 能够正确传递,而不会在 __init_subclass__ 阶段报错。 + """ + super().__init_subclass__(**kwargs) + + id: int | None = Field(default=None, primary_key=True) + + created_at: datetime = Field(default_factory=now) + updated_at: datetime = Field( + sa_type=DateTime, + sa_column_kwargs={'default': now, 'onupdate': now}, + default_factory=now + ) + + @classmethod + async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]: + """ + 向数据库中添加一个新的或多个新的记录. + + 这个类方法可以接受单个模型实例或一个实例列表,并将它们 + 一次性提交到数据库中。执行后,可以选择性地刷新这些实例以获取 + 数据库生成的值(例如,自动递增的 ID). + + Args: + session (AsyncSession): 用于数据库操作的异步会话对象. + instances (T | list[T]): 要添加的单个模型实例或模型实例列表. + refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True. + + Returns: + T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例. + + Usage: + item1 = Item(name="Apple") + item2 = Item(name="Banana") + + # 添加多个实例 + added_items = await Item.add(session, [item1, item2]) + + # 添加单个实例 + item3 = Item(name="Cherry") + added_item = await Item.add(session, item3) + """ + is_list = False + if isinstance(instances, list): + is_list = True + session.add_all(instances) + else: + session.add(instances) + + await session.commit() + + if refresh: + if is_list: + for instance in instances: + await session.refresh(instance) + else: + await session.refresh(instances) + + return instances + + async def save( + self: T, + session: AsyncSession, + load: RelationshipInfo | None = None, + refresh: bool = True + ) -> T: + """ + 保存(插入或更新)当前模型实例到数据库. + + 这是一个实例方法,它将当前对象添加到会话中并提交更改。 + 可以用于创建新记录或更新现有记录。还可以选择在保存后 + 预加载(eager load)一个关联关系. + + **重要**:调用此方法后,session中的所有对象都会过期(expired)。 + 如果需要继续使用该对象,必须使用返回值: + + ```python + # ✅ 正确:需要返回值时 + client = await client.save(session) + return client + + # ✅ 正确:不需要返回值时,指定 refresh=False 节省性能 + await client.save(session, refresh=False) + + # ❌ 错误:需要返回值但未使用 + await client.save(session) + return client # client 对象已过期 + ``` + + Args: + session (AsyncSession): 用于数据库操作的异步会话对象. + load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性. + 例如 `User.posts`. + refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值, + 设为 False 可节省一次数据库查询。默认为 True. + + Returns: + T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self. + """ + session.add(self) + await session.commit() + + if not refresh: + return self + + if load is not None: + cls = type(self) + await session.refresh(self) + # 如果指定了 load, 重新获取实例并加载关联关系 + return await cls.get(session, cls.id == self.id, load=load) + else: + await session.refresh(self) + return self + + async def update( + self: T, + session: AsyncSession, + other: M, + extra_data: dict[str, Any] | None = None, + exclude_unset: bool = True, + exclude: set[str] | None = None, + load: RelationshipInfo | None = None, + refresh: bool = True + ) -> T: + """ + 使用另一个模型实例或字典中的数据来更新当前实例. + + 此方法将 `other` 对象中的数据合并到当前实例中。默认情况下, + 它只会更新 `other` 中被显式设置的字段. + + **重要**:调用此方法后,session中的所有对象都会过期(expired)。 + 如果需要继续使用该对象,必须使用返回值: + + ```python + # ✅ 正确:需要返回值时 + client = await client.update(session, update_data) + return client + + # ✅ 正确:需要返回值且需要加载关系时 + user = await user.update(session, update_data, load=User.permission) + return user + + # ✅ 正确:不需要返回值时,指定 refresh=False 节省性能 + await client.update(session, update_data, refresh=False) + + # ❌ 错误:需要返回值但未使用 + await client.update(session, update_data) + return client # client 对象已过期 + ``` + + Args: + session (AsyncSession): 用于数据库操作的异步会话对象. + other (M): 一个 SQLModel 或 Pydantic 模型实例,其数据将用于更新当前实例. + extra_data (dict, optional): 一个额外的字典,用于更新当前实例的特定字段. + exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供) + 的字段将被忽略. 默认为 True. + exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}. + load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性. + 例如 `User.permission`. + refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值, + 设为 False 可节省一次数据库查询。默认为 True. + + Returns: + T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self. + """ + self.sqlmodel_update( + other.model_dump(exclude_unset=exclude_unset, exclude=exclude), + update=extra_data + ) + + session.add(self) + await session.commit() + + if not refresh: + return self + + if load is not None: + cls = type(self) + await session.refresh(self) + return await cls.get(session, cls.id == self.id, load=load) + else: + await session.refresh(self) + return self + + @classmethod + async def delete(cls: type[T], session: AsyncSession, instances: T | list[T]) -> None: + """ + 从数据库中删除一个或多个记录. + + Args: + session (AsyncSession): 用于数据库操作的异步会话对象. + instances (T | list[T]): 要删除的单个模型实例或模型实例列表. + + Returns: + None + + Usage: + item_to_delete = await Item.get(session, Item.id == 1) + if item_to_delete: + await Item.delete(session, item_to_delete) + + items_to_delete = await Item.get(session, Item.name.in_(["Apple", "Banana"]), fetch_mode="all") + if items_to_delete: + await Item.delete(session, items_to_delete) + """ + if isinstance(instances, list): + for instance in instances: + await session.delete(instance) + else: + await session.delete(instances) + + await session.commit() + + @classmethod + def _build_time_filters( + cls: type[T], + created_before_datetime: datetime | None = None, + created_after_datetime: datetime | None = None, + updated_before_datetime: datetime | None = None, + updated_after_datetime: datetime | None = None, + ) -> list[BinaryExpression]: + """ + 构建时间筛选条件列表 + + Args: + created_before_datetime: 筛选 created_at < datetime 的记录 + created_after_datetime: 筛选 created_at >= datetime 的记录 + updated_before_datetime: 筛选 updated_at < datetime 的记录 + updated_after_datetime: 筛选 updated_at >= datetime 的记录 + + Returns: + BinaryExpression 条件列表 + """ + filters: list[BinaryExpression] = [] + if created_after_datetime is not None: + filters.append(cls.created_at >= created_after_datetime) + if created_before_datetime is not None: + filters.append(cls.created_at < created_before_datetime) + if updated_after_datetime is not None: + filters.append(cls.updated_at >= updated_after_datetime) + if updated_before_datetime is not None: + filters.append(cls.updated_at < updated_before_datetime) + return filters + + @classmethod + async def get( + cls: type[T], + session: AsyncSession, + condition: BinaryExpression | ClauseElement | None = None, + *, + offset: int | None = None, + limit: int | None = None, + fetch_mode: Literal["one", "first", "all"] = "first", + join: type[T] | tuple[type[T], _OnClauseArgument] | None = None, + options: list | None = None, + load: RelationshipInfo | None = None, + order_by: list[ClauseElement] | None = None, + filter: BinaryExpression | ClauseElement | None = None, + with_for_update: bool = False, + table_view: TableViewRequest | None = None, + load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None, + created_before_datetime: datetime | None = None, + created_after_datetime: datetime | None = None, + updated_before_datetime: datetime | None = None, + updated_after_datetime: datetime | None = None, + ) -> T | list[T] | None: + """ + 根据指定的条件异步地从数据库中获取一个或多个模型实例. + + 这是一个功能强大的通用查询方法,支持过滤、排序、分页、连接查询和关联关系预加载. + + Args: + session (AsyncSession): 用于数据库操作的异步会话对象. + condition (BinaryExpression | ClauseElement | None): 主要的查询过滤条件, + 例如 `User.id == 1`。 + 当为 `None` 时,表示无条件查询(查询所有记录)。 + offset (int | None): 查询结果的起始偏移量, 用于分页. + limit (int | None): 返回记录的最大数量, 用于分页. + fetch_mode (Literal["one", "first", "all"]): + - "one": 获取唯一的一条记录. 如果找不到或找到多条,会引发异常. + - "first": 获取查询结果的第一条记录. 如果找不到,返回 `None`. + - "all": 获取所有匹配的记录,返回一个列表. + 默认为 "first". + join (type[T] | tuple[type[T], _OnClauseArgument] | None): + 要 JOIN 的模型类或一个包含模型类和 ON 子句的元组. + 例如 `User` 或 `(Profile, User.id == Profile.user_id)`. + options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据, + 例如 `[selectinload(User.posts)]`. + load (Relationship | None): `selectinload` 的快捷方式,用于预加载单个关联关系. + 例如 `User.profile`. + order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表. + 例如 `[User.name.asc(), User.created_at.desc()]`. + filter (BinaryExpression | ClauseElement | None): 附加的过滤条件. + + with_for_update (bool): 如果为 True, 在查询中使用 `FOR UPDATE` 锁定选定的行. 默认为 False. + + table_view (TableViewRequest | None): TableViewRequest对象,如果提供则自动处理分页、排序和时间筛选。 + 会覆盖offset、limit、order_by及时间筛选参数。 + 这是推荐的分页排序方式,统一了所有LIST端点的参数格式。 + + load_polymorphic: 多态子类加载选项,需要与 load 参数配合使用。 + - list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表 + - 'all': 两阶段查询,只加载实际关联的子类(对于 > 10 个子类的场景有明显性能收益) + - None(默认): 不使用多态加载 + + created_before_datetime (datetime | None): 筛选 created_at < datetime 的记录 + created_after_datetime (datetime | None): 筛选 created_at >= datetime 的记录 + updated_before_datetime (datetime | None): 筛选 updated_at < datetime 的记录 + updated_after_datetime (datetime | None): 筛选 updated_at >= datetime 的记录 + + Returns: + T | list[T] | None: 根据 `fetch_mode` 的设置,返回单个实例、实例列表或 `None`. + + Raises: + ValueError: 如果提供了无效的 `fetch_mode` 值,或 load_polymorphic 未与 load 配合使用. + + Examples: + # 使用table_view参数(推荐) + users = await User.get(session, fetch_mode="all", table_view=table_view_args) + + # 传统方式(向后兼容) + users = await User.get(session, fetch_mode="all", offset=0, limit=20, order_by=[desc(User.created_at)]) + + # 使用多态加载(加载联表继承的子类数据) + tool_set = await ToolSet.get( + session, + ToolSet.id == tool_set_id, + load=ToolSet.tools, + load_polymorphic='all' # 只加载实际关联的子类 + ) + """ + # 参数验证:load_polymorphic 需要与 load 配合使用 + if load_polymorphic is not None and load is None: + raise ValueError( + "load_polymorphic 参数需要与 load 参数配合使用," + "请同时指定要加载的关系" + ) + + # 如果提供table_view,作为默认值使用(单独传入的参数优先级更高) + if table_view: + # 处理时间筛选(TimeFilterRequest 及其子类) + if isinstance(table_view, TimeFilterRequest): + if created_after_datetime is None and table_view.created_after_datetime is not None: + created_after_datetime = table_view.created_after_datetime + if created_before_datetime is None and table_view.created_before_datetime is not None: + created_before_datetime = table_view.created_before_datetime + if updated_after_datetime is None and table_view.updated_after_datetime is not None: + updated_after_datetime = table_view.updated_after_datetime + if updated_before_datetime is None and table_view.updated_before_datetime is not None: + updated_before_datetime = table_view.updated_before_datetime + # 处理分页排序(PaginationRequest 及其子类,包括 TableViewRequest) + if isinstance(table_view, PaginationRequest): + if offset is None: + offset = table_view.offset + if limit is None: + limit = table_view.limit + # 仅在未显式传入order_by时,从table_view构建排序子句 + if order_by is None: + order_column = cls.created_at if table_view.order == "created_at" else cls.updated_at + order_by = [desc(order_column) if table_view.desc else asc(order_column)] + + # 对于多态基类,使用 with_polymorphic 预加载所有子类的列 + # 这避免了在响应序列化时的延迟加载问题(MissingGreenlet 错误) + if issubclass(cls, PolymorphicBaseMixin): + # '*' 表示加载所有子类 + polymorphic_cls = with_polymorphic(cls, '*') + statement = select(polymorphic_cls) + else: + statement = select(cls) + + if condition is not None: + statement = statement.where(condition) + + # 应用时间筛选 + for time_filter in cls._build_time_filters( + created_before_datetime, created_after_datetime, + updated_before_datetime, updated_after_datetime + ): + statement = statement.where(time_filter) + + if join is not None: + # 如果 join 是一个元组,解包它;否则直接使用 + if isinstance(join, tuple): + statement = statement.join(*join) + else: + statement = statement.join(join) + + + if options: + statement = statement.options(*options) + + if load: + # 处理多态加载 + if load_polymorphic is not None: + target_class = load.property.mapper.class_ + + # 检查目标类是否继承自 PolymorphicBaseMixin + if not issubclass(target_class, PolymorphicBaseMixin): + raise ValueError( + f"目标类 {target_class.__name__} 不是多态类," + f"请确保其继承自 PolymorphicBaseMixin" + ) + + if load_polymorphic == 'all': + # 两阶段查询:获取实际关联的多态类型 + subclasses_to_load = await cls._resolve_polymorphic_subclasses( + session, condition, load, target_class + ) + else: + subclasses_to_load = load_polymorphic + + if subclasses_to_load: + # 关键:selectin_polymorphic 必须作为 selectinload 的链式子选项 + # 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading + statement = statement.options( + selectinload(load).selectin_polymorphic(subclasses_to_load) + ) + else: + statement = statement.options(selectinload(load)) + else: + statement = statement.options(selectinload(load)) + + if order_by is not None: + statement = statement.order_by(*order_by) + + if offset: + statement = statement.offset(offset) + + if limit: + statement = statement.limit(limit) + + if filter: + statement = statement.filter(filter) + + if with_for_update: + statement = statement.with_for_update() + + result = await session.exec(statement) + + if fetch_mode == "one": + return result.one() + elif fetch_mode == "first": + return result.first() + elif fetch_mode == "all": + return list(result.all()) + else: + raise ValueError(f"无效的 fetch_mode: {fetch_mode}") + + @classmethod + async def _resolve_polymorphic_subclasses( + cls: type[T], + session: AsyncSession, + condition: BinaryExpression | ClauseElement | None, + load: RelationshipInfo, + target_class: type[PolymorphicBaseMixin] + ) -> list[type[PolymorphicBaseMixin]]: + """ + 查询实际关联的多态子类类型 + + 通过查询多态鉴别字段确定实际存在的子类类型, + 避免加载所有可能的子类表(对于 > 10 个子类的场景有明显收益)。 + + :param session: 数据库会话 + :param condition: 主查询的条件 + :param load: 关系属性 + :param target_class: 多态基类 + :return: 实际关联的子类列表 + """ + # 获取多态鉴别字段(会抛出 ValueError 如果未配置) + discriminator = target_class.get_polymorphic_discriminator() + poly_name_col = getattr(target_class, discriminator) + + # 获取关系属性 + relationship_property = load.property + + # 构建查询获取实际的多态类型名称 + if relationship_property.secondary is not None: + # 多对多关系:通过中间表查询 + secondary = relationship_property.secondary + local_cols = list(relationship_property.local_columns) + + type_query = ( + select(distinct(poly_name_col)) + .select_from(target_class) + .join(secondary) + .where(secondary.c[local_cols[0].name].in_( + select(cls.id).where(condition) if condition is not None else select(cls.id) + )) + ) + else: + # 一对多关系:通过外键查询 + foreign_key_col = relationship_property.local_remote_pairs[0][1] + type_query = ( + select(distinct(poly_name_col)) + .where(foreign_key_col.in_( + select(cls.id).where(condition) if condition is not None else select(cls.id) + )) + ) + + type_result = await session.exec(type_query) + poly_names = list(type_result.all()) + + if not poly_names: + return [] + + # 映射到子类(包含所有层级的具体子类) + identity_map = target_class.get_identity_to_class_map() + return [identity_map[name] for name in poly_names if name in identity_map] + + @classmethod + async def count( + cls: type[T], + session: AsyncSession, + condition: BinaryExpression | ClauseElement | None = None, + *, + time_filter: TimeFilterRequest | None = None, + created_before_datetime: datetime | None = None, + created_after_datetime: datetime | None = None, + updated_before_datetime: datetime | None = None, + updated_after_datetime: datetime | None = None, + ) -> int: + """ + 根据条件统计记录数量(支持时间筛选) + + 使用数据库层面的 COUNT() 聚合函数,比 get() + len() 更高效。 + + Args: + session: 数据库会话 + condition: 查询条件,例如 `User.is_active == True` + time_filter: TimeFilterRequest 对象(优先级更高) + created_before_datetime: 筛选 created_at < datetime 的记录 + created_after_datetime: 筛选 created_at >= datetime 的记录 + updated_before_datetime: 筛选 updated_at < datetime 的记录 + updated_after_datetime: 筛选 updated_at >= datetime 的记录 + + Returns: + 符合条件的记录数量 + + Examples: + # 统计所有用户 + total = await User.count(session) + + # 统计激活的虚拟客户端 + count = await Client.count( + session, + (Client.user_id == user_id) & (Client.type != ClientTypeEnum.physical) & (Client.is_active == True) + ) + + # 使用 TimeFilterRequest 进行时间筛选 + count = await User.count(session, time_filter=time_filter_request) + + # 使用独立时间参数 + count = await User.count( + session, + created_after_datetime=datetime(2025, 1, 1), + created_before_datetime=datetime(2025, 2, 1), + ) + """ + # time_filter 的时间筛选优先级更高 + if isinstance(time_filter, TimeFilterRequest): + if time_filter.created_after_datetime is not None: + created_after_datetime = time_filter.created_after_datetime + if time_filter.created_before_datetime is not None: + created_before_datetime = time_filter.created_before_datetime + if time_filter.updated_after_datetime is not None: + updated_after_datetime = time_filter.updated_after_datetime + if time_filter.updated_before_datetime is not None: + updated_before_datetime = time_filter.updated_before_datetime + + statement = select(func.count()).select_from(cls) + + # 应用查询条件 + if condition is not None: + statement = statement.where(condition) + + # 应用时间筛选 + for time_condition in cls._build_time_filters( + created_before_datetime, created_after_datetime, + updated_before_datetime, updated_after_datetime + ): + statement = statement.where(time_condition) + + result = await session.scalar(statement) + return result or 0 + + @classmethod + async def get_with_count( + cls: type[T], + session: AsyncSession, + condition: BinaryExpression | ClauseElement | None = None, + *, + join: type[T] | tuple[type[T], _OnClauseArgument] | None = None, + options: list | None = None, + load: RelationshipInfo | None = None, + order_by: list[ClauseElement] | None = None, + filter: BinaryExpression | ClauseElement | None = None, + table_view: TableViewRequest | None = None, + load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None, + ) -> 'ListResponse[T]': + """ + 获取分页列表及总数,直接返回 ListResponse + + 同时返回符合条件的记录列表和总数,用于分页场景。 + 与 get() 方法类似,但固定 fetch_mode="all" 并返回 ListResponse。 + + 注意:如果子类的 get() 方法支持额外参数(如 filter_params), + 子类应该覆盖此方法以确保 count 和 items 使用相同的过滤条件。 + + Args: + session: 数据库会话 + condition: 查询条件 + join: JOIN 的模型类或元组 + options: SQLAlchemy 查询选项 + load: selectinload 预加载关系 + order_by: 排序子句 + filter: 附加过滤条件 + table_view: 分页排序参数(推荐使用) + load_polymorphic: 多态子类加载选项 + + Returns: + ListResponse[T]: 包含 count 和 items 的分页响应 + + Examples: + ```python + @router.get("", response_model=ListResponse[CharacterInfoResponse]) + async def list_characters( + session: SessionDep, + table_view: TableViewRequestDep + ) -> ListResponse[Character]: + return await Character.get_with_count(session, table_view=table_view) + ``` + """ + # 提取时间筛选参数(用于 count) + time_filter: TimeFilterRequest | None = None + if table_view is not None: + time_filter = TimeFilterRequest( + created_after_datetime=table_view.created_after_datetime, + created_before_datetime=table_view.created_before_datetime, + updated_after_datetime=table_view.updated_after_datetime, + updated_before_datetime=table_view.updated_before_datetime, + ) + + # 获取总数(不带分页限制) + total_count = await cls.count(session, condition, time_filter=time_filter) + + # 获取分页数据 + items = await cls.get( + session, + condition, + fetch_mode="all", + join=join, + options=options, + load=load, + order_by=order_by, + filter=filter, + table_view=table_view, + load_polymorphic=load_polymorphic, + ) + + return ListResponse(count=total_count, items=items) + + @classmethod + async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | None = None) -> T: + """ + 根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常. + + 这个方法是对 `get` 方法的封装,专门用于处理那种"记录必须存在"的业务场景。 + 如果记录未找到,它会直接引发 FastAPI 的 `HTTPException`, 而不是返回 `None`. + + Args: + session (AsyncSession): 用于数据库操作的异步会话对象. + id (int): 要查找的记录的主键 ID. + load (Relationship | None): 可选的,用于预加载的关联属性. + + Returns: + T: 找到的模型实例. + + Raises: + HTTPException: 如果 ID 对应的记录不存在,则抛出状态码为 404 的异常. + """ + instance = await cls.get(session, cls.id == id, load=load) + if not instance: + raise HTTPException(status_code=404, detail="Not found") + return instance + +class UUIDTableBaseMixin(TableBaseMixin): + """ + 一个使用 UUID 作为主键的异步 CRUD 操作基础模型类 Mixin. + + 此类继承自 `TableBaseMixin`, 将主键 `id` 的类型覆盖为 `uuid.UUID`, + 并为新记录自动生成 UUID. 它继承了 `TableBaseMixin` 的所有 CRUD 方法. + + Attributes: + id (uuid.UUID): UUID 类型的主键, 在创建时自动生成. + """ + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + """覆盖 `TableBaseMixin` 的 id 字段,使用 UUID 作为主键.""" + + @override + @classmethod + async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T: + """ + 根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常. + + 此方法覆盖了父类的同名方法,以确保 `id` 参数的类型注解为 `uuid.UUID`, + 从而提供更好的类型安全和代码提示. + + Args: + session (AsyncSession): 用于数据库操作的异步会话对象. + id (uuid.UUID): 要查找的记录的 UUID 主键. + load (Relationship | None): 可选的,用于预加载的关联属性. + + Returns: + T: 找到的模型实例. + + Raises: + HTTPException: 如果 UUID 对应的记录不存在,则抛出状态码为 404 的异常. + """ + # 类型检查器可能会警告这里的 `id` 类型不匹配超类方法, + # 但在运行时这是正确的,因为超类方法内部的比较 (cls.id == id) + # 会正确处理 UUID 类型。`type: ignore` 用于抑制此警告。 + return await super().get_exist_one(session, id, load) # type: ignore diff --git a/models/node.py b/models/node.py index b644e82..74615a3 100644 --- a/models/node.py +++ b/models/node.py @@ -1,23 +1,98 @@ - +from enum import StrEnum from typing import TYPE_CHECKING -from sqlmodel import Field, Relationship, text -from .base import TableBase + +from sqlmodel import Field, Relationship, text, Index + +from .base import SQLModelBase +from .mixin import TableBaseMixin if TYPE_CHECKING: from .download import Download -class Node(TableBase, table=True): + +class NodeStatus(StrEnum): + """节点状态枚举""" + ONLINE = "online" + """正常""" + OFFLINE = "offline" + """离线""" + + +class NodeType(StrEnum): + """节点类型枚举""" + MASTER = "master" + """主节点""" + SLAVE = "slave" + """从节点""" + + +class Aria2ConfigurationBase(SQLModelBase): + """Aria2配置基础模型""" + + rpc_url: str | None = Field(default=None, max_length=255) + """RPC地址""" + + rpc_secret: str | None = None + """RPC密钥""" + + temp_path: str | None = Field(default=None, max_length=255) + """临时下载路径""" + + max_concurrent: int = Field(default=5, ge=1, le=50) + """最大并发数""" + + timeout: int = Field(default=300, ge=1) + """请求超时时间(秒)""" + + +class Aria2Configuration(Aria2ConfigurationBase, TableBaseMixin): + """Aria2配置模型(与Node一对一关联)""" + + node_id: int = Field(foreign_key="node.id", unique=True, index=True) + """关联的节点ID""" + + # 反向关系 + node: "Node" = Relationship(back_populates="aria2_config") + """关联的节点""" + + +class Node(SQLModelBase, TableBaseMixin): """节点模型""" - - status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="节点状态: 0=正常, 1=离线") - name: str = Field(max_length=255, unique=True, description="节点名称") - type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="节点类型") - server: str = Field(max_length=255, description="节点地址(IP或域名)") - slave_key: str | None = Field(default=None, description="从机通讯密钥") - master_key: str | None = Field(default=None, description="主机通讯密钥") - aria2_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否启用Aria2") - aria2_options: str | None = Field(default=None, description="Aria2配置 (JSON格式)") - rank: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="节点排序权重") + + __table_args__ = ( + Index("ix_node_status", "status"), + ) + + status: NodeStatus = Field(default=NodeStatus.ONLINE, sa_column_kwargs={"server_default": "'online'"}) + """节点状态""" + + name: str = Field(max_length=255, unique=True) + """节点名称""" + + type: NodeType + """节点类型""" + + server: str = Field(max_length=255) + """节点地址(IP或域名)""" + + slave_key: str | None = Field(default=None, max_length=255) + """从机通讯密钥""" + + master_key: str | None = Field(default=None, max_length=255) + """主机通讯密钥""" + + aria2_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}) + """是否启用Aria2""" + + rank: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) + """节点排序权重""" # 关系 - downloads: list["Download"] = Relationship(back_populates="node") \ No newline at end of file + aria2_config: Aria2Configuration | None = Relationship( + back_populates="node", + sa_relationship_kwargs={"uselist": False}, + ) + """Aria2配置""" + + downloads: list["Download"] = Relationship(back_populates="node") + """该节点的下载任务""" \ No newline at end of file diff --git a/models/object.py b/models/object.py index 6ced3d5..08ec6b5 100644 --- a/models/object.py +++ b/models/object.py @@ -1,12 +1,13 @@ from datetime import datetime -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Literal from uuid import UUID from enum import StrEnum from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index -from .base import SQLModelBase, UUIDTableBase +from .base import SQLModelBase +from .mixin import UUIDTableBaseMixin if TYPE_CHECKING: from .user import User @@ -19,6 +20,43 @@ class ObjectType(StrEnum): """对象类型枚举""" FILE = "file" FOLDER = "folder" + +class StorageType(StrEnum): + """存储类型枚举""" + LOCAL = "local" + QINIU = "qiniu" + TENCENT = "tencent" + ALIYUN = "aliyun" + ONEDRIVE = "onedrive" + GOOGLE_DRIVE = "google_drive" + DROPBOX = "dropbox" + WEBDAV = "webdav" + REMOTE = "remote" + + +class FileMetadataBase(SQLModelBase): + """文件元数据基础模型""" + + width: int | None = Field(default=None) + """图片宽度(像素)""" + + height: int | None = Field(default=None) + """图片高度(像素)""" + + duration: float | None = Field(default=None) + """音视频时长(秒)""" + + bitrate: int | None = Field(default=None) + """比特率(kbps)""" + + mime_type: str | None = Field(default=None, max_length=127) + """MIME类型""" + + checksum_md5: str | None = Field(default=None, max_length=32) + """MD5校验和""" + + checksum_sha256: str | None = Field(default=None, max_length=64) + """SHA256校验和""" # ==================== Base 模型 ==================== @@ -99,10 +137,10 @@ class PolicyResponse(SQLModelBase): name: str """策略名称""" - type: Literal["local", "qiniu", "tencent", "aliyun", "onedrive", "google_drive", "dropbox", "webdav", "remote"] + type: StorageType """存储类型""" - max_size: int = 0 + max_size: int = Field(ge=0, default=0) """单文件最大限制,单位字节,0表示不限制""" file_type: list[str] | None = None @@ -127,7 +165,18 @@ class DirectoryResponse(SQLModelBase): # ==================== 数据库模型 ==================== -class Object(ObjectBase, UUIDTableBase, table=True): +class FileMetadata(FileMetadataBase, UUIDTableBaseMixin): + """文件元数据模型(与Object一对一关联)""" + + object_id: UUID = Field(foreign_key="object.id", unique=True, index=True) + """关联的对象UUID""" + + # 反向关系 + object: "Object" = Relationship(back_populates="file_metadata") + """关联的对象""" + + +class Object(ObjectBase, UUIDTableBaseMixin): """ 统一对象模型 @@ -143,7 +192,7 @@ class Object(ObjectBase, UUIDTableBase, table=True): __table_args__ = ( # 同一父目录下名称唯一(包括 parent_id 为 NULL 的情况) UniqueConstraint("owner_id", "parent_id", "name", name="uq_object_parent_name"), - # 名称不能包含斜杠 + # 名称不能包含斜杠 ([TODO] 还有特殊字符) CheckConstraint( "name NOT LIKE '%/%' AND name NOT LIKE '%\\%'", name="ck_object_name_no_slash", @@ -168,7 +217,7 @@ class Object(ObjectBase, UUIDTableBase, table=True): # ==================== 文件专属字段 ==================== - source_name: str | None = None + source_name: str | None = Field(default=None, max_length=255) """源文件名(仅文件有效)""" size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) @@ -177,10 +226,6 @@ class Object(ObjectBase, UUIDTableBase, table=True): upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True) """分块上传会话ID(仅文件有效)""" - # [TODO] 拆分 - file_metadata: str | None = None - """文件元数据 (JSON格式),仅文件有效""" - # ==================== 外键 ==================== parent_id: UUID | None = Field(default=None, foreign_key="object.id", index=True) @@ -201,7 +246,7 @@ class Object(ObjectBase, UUIDTableBase, table=True): """存储策略""" # 自引用关系 - parent: Optional["Object"] = Relationship( + parent: "Object" = Relationship( back_populates="children", sa_relationship_kwargs={"remote_side": "Object.id"}, ) @@ -211,6 +256,12 @@ class Object(ObjectBase, UUIDTableBase, table=True): """子对象(文件和子目录)""" # 仅文件有效的关系 + file_metadata: FileMetadata | None = Relationship( + back_populates="object", + sa_relationship_kwargs={"uselist": False}, + ) + """文件元数据(仅文件有效)""" + source_links: list["SourceLink"] = Relationship(back_populates="object") """源链接列表(仅文件有效)""" diff --git a/models/order.py b/models/order.py index e2dd003..7b88d3a 100644 --- a/models/order.py +++ b/models/order.py @@ -1,25 +1,58 @@ - +from enum import StrEnum from typing import TYPE_CHECKING from uuid import UUID from sqlmodel import Field, Relationship -from .base import TableBase +from .base import SQLModelBase +from .mixin import TableBaseMixin if TYPE_CHECKING: from .user import User -class Order(TableBase, table=True): + +class OrderStatus(StrEnum): + """订单状态枚举""" + PENDING = "pending" + """待支付""" + COMPLETED = "completed" + """已完成""" + CANCELLED = "cancelled" + """已取消""" + + +class OrderType(StrEnum): + """订单类型枚举""" + # [TODO] 补充具体订单类型 + pass + + +class Order(SQLModelBase, TableBaseMixin): """订单模型""" - order_no: str = Field(max_length=255, unique=True, index=True, description="订单号,唯一") - type: int = Field(description="订单类型") - method: str | None = Field(default=None, max_length=255, description="支付方式") - product_id: int | None = Field(default=None, description="商品ID") - num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="购买数量") - name: str = Field(max_length=255, description="商品名称") - price: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="订单价格(分)") - status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="订单状态: 0=待支付, 1=已完成, 2=已取消") + order_no: str = Field(max_length=255, unique=True, index=True) + """订单号,唯一""" + + type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) + """订单类型 [TODO] 待定义枚举""" + + method: str | None = Field(default=None, max_length=255) + """支付方式""" + + product_id: int | None = Field(default=None) + """商品ID""" + + num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}) + """购买数量""" + + name: str = Field(max_length=255) + """商品名称""" + + price: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) + """订单价格(分)""" + + status: OrderStatus = Field(default=OrderStatus.PENDING, sa_column_kwargs={"server_default": "'pending'"}) + """订单状态""" # 外键 user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID") diff --git a/models/policy.py b/models/policy.py index 8f25e2c..084e1ae 100644 --- a/models/policy.py +++ b/models/policy.py @@ -4,7 +4,8 @@ from uuid import UUID from enum import StrEnum from sqlmodel import Field, Relationship, text -from .base import SQLModelBase, UUIDTableBase +from .base import SQLModelBase +from .mixin import UUIDTableBaseMixin if TYPE_CHECKING: from .object import Object @@ -29,16 +30,16 @@ class PolicyType(StrEnum): class PolicyOptionsBase(SQLModelBase): """存储策略选项的基础模型""" - token: str | None = None + token: str | None = Field(default=None) """访问令牌""" - file_type: str | None = None + file_type: str | None = Field(default=None) """允许的文件类型""" - mimetype: str | None = None + mimetype: str | None = Field(default=None, max_length=127) """MIME类型""" - od_redirect: str | None = None + od_redirect: str | None = Field(default=None, max_length=255) """OneDrive重定向地址""" chunk_size: int = Field(default=52428800, sa_column_kwargs={"server_default": "52428800"}) @@ -48,7 +49,7 @@ class PolicyOptionsBase(SQLModelBase): """是否使用S3路径风格""" -class PolicyOptions(PolicyOptionsBase, UUIDTableBase, table=True): +class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin): """存储策略选项模型(与Policy一对一关联)""" policy_id: UUID = Field(foreign_key="policy.id", unique=True) @@ -59,7 +60,7 @@ class PolicyOptions(PolicyOptionsBase, UUIDTableBase, table=True): """关联的策略""" -class Policy(UUIDTableBase, table=True): +class Policy(SQLModelBase, UUIDTableBaseMixin): """存储策略模型""" name: str = Field(max_length=255, unique=True) diff --git a/models/redeem.py b/models/redeem.py index d00bfab..574eec6 100644 --- a/models/redeem.py +++ b/models/redeem.py @@ -1,10 +1,22 @@ -from sqlmodel import Field, text -from .base import TableBase +from enum import StrEnum -class Redeem(TableBase, table=True): +from sqlmodel import Field, text + +from .base import SQLModelBase +from .mixin import TableBaseMixin + + +class RedeemType(StrEnum): + """兑换码类型枚举""" + # [TODO] 补充具体兑换码类型 + pass + + +class Redeem(SQLModelBase, TableBaseMixin): """兑换码模型""" - type: int = Field(description="兑换码类型") + type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) + """兑换码类型 [TODO] 待定义枚举""" product_id: int | None = Field(default=None, description="关联的商品/权益ID") num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="可兑换数量/时长等") code: str = Field(unique=True, index=True, description="兑换码,唯一") diff --git a/models/report.py b/models/report.py index 161ca4c..e39a450 100644 --- a/models/report.py +++ b/models/report.py @@ -1,15 +1,26 @@ - +from enum import StrEnum from typing import TYPE_CHECKING + from sqlmodel import Field, Relationship -from .base import TableBase + +from .base import SQLModelBase +from .mixin import TableBaseMixin if TYPE_CHECKING: from .share import Share -class Report(TableBase, table=True): + +class ReportReason(StrEnum): + """举报原因枚举""" + # [TODO] 补充具体举报原因 + pass + + +class Report(SQLModelBase, TableBaseMixin): """举报模型""" - reason: int = Field(description="举报原因代码") + reason: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) + """举报原因 [TODO] 待定义枚举""" description: str | None = Field(default=None, max_length=255, description="补充描述") # 外键 diff --git a/models/response.py b/models/response.py index 5782284..11bc235 100644 --- a/models/response.py +++ b/models/response.py @@ -1,15 +1,12 @@ """ 通用响应模型定义 """ - -from typing import Any import uuid from sqlmodel import Field from .base import SQLModelBase -# [TODO] 未来把这拆了,直接按需返回状态码 class ResponseBase(SQLModelBase): """通用响应模型""" diff --git a/models/setting.py b/models/setting.py index 531bf33..9d20858 100644 --- a/models/setting.py +++ b/models/setting.py @@ -1,9 +1,10 @@ from typing import Literal +from enum import StrEnum from sqlmodel import Field, UniqueConstraint -from .base import TableBase, SQLModelBase -from enum import StrEnum +from .base import SQLModelBase +from .mixin import TableBaseMixin # ==================== DTO 模型 ==================== @@ -72,7 +73,7 @@ class SettingsType(StrEnum): WOPI = "wopi" # 数据库模型 -class Setting(TableBase, table=True): +class Setting(SQLModelBase, TableBaseMixin): """设置模型""" __table_args__ = (UniqueConstraint("type", "name", name="uq_setting_type_name"),) diff --git a/models/share.py b/models/share.py index 3505587..e235d72 100644 --- a/models/share.py +++ b/models/share.py @@ -5,7 +5,8 @@ from uuid import UUID from sqlmodel import Field, Relationship, text, UniqueConstraint, Index -from .base import TableBase +from .base import SQLModelBase +from .mixin import TableBaseMixin if TYPE_CHECKING: from .user import User @@ -13,7 +14,7 @@ if TYPE_CHECKING: from .object import Object -class Share(TableBase, table=True): +class Share(SQLModelBase, TableBaseMixin): """分享模型""" __table_args__ = ( @@ -38,10 +39,10 @@ class Share(TableBase, table=True): downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) """下载次数""" - remain_downloads: int | None = None + remain_downloads: int | None = Field(default=None) """剩余下载次数 (NULL为不限制)""" - expires: datetime | None = None + expires: datetime | None = Field(default=None) """过期时间 (NULL为永不过期)""" preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")}) diff --git a/models/source_link.py b/models/source_link.py index fa73ab2..fdc5a0a 100644 --- a/models/source_link.py +++ b/models/source_link.py @@ -4,13 +4,14 @@ from uuid import UUID from sqlmodel import Field, Relationship, Index -from .base import TableBase +from .base import SQLModelBase +from .mixin import TableBaseMixin if TYPE_CHECKING: from .object import Object -class SourceLink(TableBase, table=True): +class SourceLink(SQLModelBase, TableBaseMixin): """链接模型""" __table_args__ = ( diff --git a/models/storage_pack.py b/models/storage_pack.py index 5ef89d6..019a10a 100644 --- a/models/storage_pack.py +++ b/models/storage_pack.py @@ -1,16 +1,17 @@ -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from datetime import datetime from uuid import UUID from sqlmodel import Field, Relationship, Column, func, DateTime -from .base import TableBase +from .base import SQLModelBase +from .mixin import TableBaseMixin if TYPE_CHECKING: from .user import User -class StoragePack(TableBase, table=True): +class StoragePack(SQLModelBase, TableBaseMixin): """容量包模型""" name: str = Field(max_length=255, description="容量包名称") diff --git a/models/tag.py b/models/tag.py index c977490..b2a0649 100644 --- a/models/tag.py +++ b/models/tag.py @@ -1,24 +1,41 @@ - -from typing import Optional, TYPE_CHECKING +from enum import StrEnum +from typing import TYPE_CHECKING from uuid import UUID from datetime import datetime from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime -from .base import TableBase +from .base import SQLModelBase +from .mixin import TableBaseMixin if TYPE_CHECKING: from .user import User -class Tag(TableBase, table=True): + +class TagType(StrEnum): + """标签类型枚举""" + MANUAL = "manual" + """手动标签""" + AUTOMATIC = "automatic" + """自动标签""" + + +class Tag(SQLModelBase, TableBaseMixin): """标签模型""" __table_args__ = (UniqueConstraint("name", "user_id", name="uq_tag_name_user"),) - name: str = Field(max_length=255, description="标签名称") - icon: str | None = Field(default=None, max_length=255, description="标签图标") - color: str | None = Field(default=None, max_length=255, description="标签颜色") - type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="标签类型: 0=手动, 1=自动") + name: str = Field(max_length=255) + """标签名称""" + + icon: str | None = Field(default=None, max_length=255) + """标签图标""" + + color: str | None = Field(default=None, max_length=255) + """标签颜色""" + + type: TagType = Field(default=TagType.MANUAL, sa_column_kwargs={"server_default": "'manual'"}) + """标签类型""" expression: str | None = Field(default=None, description="自动标签的匹配表达式") # 外键 diff --git a/models/task.py b/models/task.py index 9ce3873..7585ba7 100644 --- a/models/task.py +++ b/models/task.py @@ -1,32 +1,96 @@ - -from typing import Optional, TYPE_CHECKING +from enum import StrEnum +from typing import TYPE_CHECKING from uuid import UUID from datetime import datetime -from sqlmodel import Field, Relationship, CheckConstraint +from sqlmodel import Field, Relationship, CheckConstraint, Index -from .base import TableBase +from .base import SQLModelBase +from .mixin import TableBaseMixin if TYPE_CHECKING: from .user import User from .download import Download -class Task(TableBase, table=True): + +class TaskStatus(StrEnum): + """任务状态枚举""" + QUEUED = "queued" + """排队中""" + RUNNING = "running" + """处理中""" + COMPLETED = "completed" + """已完成""" + ERROR = "error" + """错误""" + + +class TaskType(StrEnum): + """任务类型枚举""" + # [TODO] 补充具体任务类型 + pass + + +class TaskPropsBase(SQLModelBase): + """任务属性基础模型""" + + source_path: str | None = None + """源路径""" + + dest_path: str | None = None + """目标路径""" + + file_ids: str | None = None + """文件ID列表(逗号分隔)""" + + # [TODO] 根据业务需求补充更多字段 + + +class TaskProps(TaskPropsBase, TableBaseMixin): + """任务属性模型(与Task一对一关联)""" + + task_id: int = Field(foreign_key="task.id", primary_key=True) + """关联的任务ID""" + + # 反向关系 + task: "Task" = Relationship(back_populates="props") + """关联的任务""" + + +class Task(SQLModelBase, TableBaseMixin): """任务模型""" __table_args__ = ( CheckConstraint("progress BETWEEN 0 AND 100", name="ck_task_progress_range"), + Index("ix_task_status", "status"), + Index("ix_task_user_status", "user_id", "status"), ) - status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务状态: 0=排队中, 1=处理中, 2=完成, 3=错误") - type: int = Field(description="任务类型") - progress: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务进度 (0-100)") - error: str | None = Field(default=None, description="错误信息") - props: str | None = Field(default=None, description="任务属性 (JSON格式)") - + status: TaskStatus = Field(default=TaskStatus.QUEUED, sa_column_kwargs={"server_default": "'queued'"}) + """任务状态""" + + type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) + """任务类型 [TODO] 待定义枚举""" + + progress: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0, le=100) + """任务进度(0-100)""" + + error: str | None = Field(default=None) + """错误信息""" + # 外键 - user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID") - + user_id: UUID = Field(foreign_key="user.id", index=True) + """所属用户UUID""" + # 关系 + props: TaskProps | None = Relationship( + back_populates="task", + sa_relationship_kwargs={"uselist": False}, + ) + """任务属性""" + user: "User" = Relationship(back_populates="tasks") - downloads: list["Download"] = Relationship(back_populates="task") \ No newline at end of file + """所属用户""" + + downloads: list["Download"] = Relationship(back_populates="task") + """关联的下载任务""" \ No newline at end of file diff --git a/models/user.py b/models/user.py index bc87432..b7db76a 100644 --- a/models/user.py +++ b/models/user.py @@ -1,11 +1,12 @@ from datetime import datetime from enum import StrEnum -from typing import Literal, Optional, TYPE_CHECKING +from typing import Literal, TYPE_CHECKING from uuid import UUID from sqlmodel import Field, Relationship -from .base import SQLModelBase, UUIDTableBase +from .base import SQLModelBase +from .mixin import UUIDTableBaseMixin if TYPE_CHECKING: from .group import Group @@ -19,15 +20,6 @@ if TYPE_CHECKING: from .user_authn import UserAuthn from .webdav import WebDAV -""" -Option 需求 -- 主题 跟随系统/浅色/深色 -- 颜色方案 参考 ThemeResponse -- 语言 -- 时区 -- 切换到不同存储策略是否提醒 -""" - class AvatarType(StrEnum): """头像类型枚举""" @@ -42,6 +34,13 @@ class ThemeType(StrEnum): DARK = "dark" SYSTEM = "system" +class UserStatus(StrEnum): + """用户状态枚举""" + + ACTIVE = "active" + ADMIN_BANNED = "admin_banned" + SYSTEM_BANNED = "system_banned" + # ==================== Base 模型 ==================== @@ -51,8 +50,8 @@ class UserBase(SQLModelBase): username: str """用户名""" - status: bool = True - """用户状态: True=正常, False=封禁""" + status: UserStatus = UserStatus.ACTIVE + """用户状态""" score: int = 0 """用户积分""" @@ -72,7 +71,7 @@ class LoginRequest(SQLModelBase): captcha: str | None = None """验证码""" - two_fa_code: str | None = None + two_fa_code: int | None = Field(min_length=6, max_length=6) """两步验证代码""" @@ -192,9 +191,6 @@ class UserSettingResponse(SQLModelBase): prefer_theme: str = "#5898d4" """用户首选主题""" - qq: str | None = None - """QQ号""" - themes: dict[str, str] = {} """用户主题配置""" @@ -216,7 +212,7 @@ UserSettingResponse.model_rebuild() # ==================== 数据库模型 ==================== -class User(UserBase, UUIDTableBase, table=True): +class User(UserBase, UUIDTableBaseMixin): """用户模型""" username: str = Field(max_length=50, unique=True, index=True) @@ -243,7 +239,7 @@ class User(UserBase, UUIDTableBase, table=True): score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0) """用户积分""" - group_expires: datetime | None = None + group_expires: datetime | None = Field(default=None) """当前用户组过期时间""" # Option 相关字段 @@ -266,13 +262,13 @@ class User(UserBase, UUIDTableBase, table=True): # 关系 group: "Group" = Relationship( - back_populates="user", + back_populates="users", sa_relationship_kwargs={ "foreign_keys": "User.group_id" } ) - previous_group: Optional["Group"] = Relationship( - back_populates="previous_user", + previous_group: "Group" = Relationship( + back_populates="previous_users", sa_relationship_kwargs={ "foreign_keys": "User.previous_group_id" } diff --git a/models/user_authn.py b/models/user_authn.py index cb0dfa3..3ce24be 100644 --- a/models/user_authn.py +++ b/models/user_authn.py @@ -4,7 +4,8 @@ from uuid import UUID from sqlalchemy import Column, Text from sqlmodel import Field, Relationship -from .base import TableBase, SQLModelBase +from .base import SQLModelBase +from .mixin import TableBaseMixin if TYPE_CHECKING: from .user import User @@ -24,7 +25,7 @@ class AuthnResponse(SQLModelBase): # ==================== 数据库模型 ==================== -class UserAuthn(TableBase, table=True): +class UserAuthn(SQLModelBase, TableBaseMixin): """用户 WebAuthn 凭证模型,与 User 为多对一关系""" credential_id: str = Field(max_length=255, unique=True, index=True) diff --git a/models/webdav.py b/models/webdav.py index 6b3eb06..e2f4d67 100644 --- a/models/webdav.py +++ b/models/webdav.py @@ -4,12 +4,13 @@ from uuid import UUID from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime -from .base import TableBase +from .base import SQLModelBase +from .mixin import TableBaseMixin if TYPE_CHECKING: from .user import User -class WebDAV(TableBase, table=True): +class WebDAV(SQLModelBase, TableBaseMixin): """WebDAV账户模型""" __table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),) diff --git a/utils/JWT/JWT.py b/utils/JWT/JWT.py index 97ec875..52543ea 100644 --- a/utils/JWT/JWT.py +++ b/utils/JWT/JWT.py @@ -6,7 +6,7 @@ from fastapi.security import OAuth2PasswordBearer oauth2_scheme = OAuth2PasswordBearer( scheme_name='获取 JWT Bearer 令牌', description='用于获取 JWT Bearer 令牌,需要以表单的形式提交', - tokenUrl="/api/user/session", + tokenUrl="/api/v1/user/session", ) SECRET_KEY = ''