diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..0282a6b
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,37 @@
+.git/
+.gitignore
+.github/
+.idea/
+.vscode/
+.venv/
+.env
+.env.*
+.run/
+.claude/
+
+__pycache__/
+*.py[cod]
+
+tests/
+htmlcov/
+.pytest_cache/
+.coverage
+coverage.xml
+
+*.db
+*.sqlite
+*.sqlite3
+*.log
+logs/
+data/
+
+Dockerfile
+.dockerignore
+
+# Cython 编译产物
+*.c
+build/
+
+# 许可证私钥和工具脚本
+license_private.pem
+scripts/
diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
index e961ba8..10fc525 100644
--- a/.github/copilot-instructions.md
+++ b/.github/copilot-instructions.md
@@ -449,13 +449,13 @@ return device # 此时device已过期
```python
import asyncio
from sqlmodel import Field
-from sqlmodels.base import UUIDTableBase, SQLModelBase
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
class CharacterBase(SQLModelBase):
name: str
"""角色名称"""
-class Character(CharacterBase, UUIDTableBase):
+class Character(CharacterBase, UUIDTableBaseMixin):
"""充血模型:包含数据和业务逻辑"""
# ==================== 运行时属性(在model_post_init初始化) ====================
@@ -570,11 +570,11 @@ async with character.init(session):
from abc import ABC, abstractmethod
from uuid import UUID
from sqlmodel import Field
-from sqlmodels.base import (
+from sqlmodel_ext import (
SQLModelBase,
- UUIDTableBase,
+ UUIDTableBaseMixin,
create_subclass_id_mixin,
- AutoPolymorphicIdentityMixin
+ AutoPolymorphicIdentityMixin,
)
# 1. 定义Base类(只有字段,无表)
@@ -591,7 +591,7 @@ class ASRBase(SQLModelBase):
# 2. 定义抽象父类(有表)
class ASR(
ASRBase,
- UUIDTableBase,
+ UUIDTableBaseMixin,
ABC,
polymorphic_on='__polymorphic_name',
polymorphic_abstract=True
@@ -1148,7 +1148,7 @@ from sqlmodel import Field
# 3. 本地应用导入(从项目根目录的包开始)
from dependencies import SessionDep
from sqlmodels.user import User
-from sqlmodels.base import UUIDTableBase
+from sqlmodel_ext import UUIDTableBaseMixin
# 4. 相对导入(同包内的模块)
from .base import BaseClass
diff --git a/.gitignore b/.gitignore
index 1acc9c6..20979a4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -69,3 +69,13 @@ data/
# JB 的运行配置(换设备用不了)
.run/
.xml
+
+# 前端构建产物(Docker 构建时复制)
+statics/
+
+# Cython 编译产物
+*.c
+
+# 许可证密钥(保密)
+license_private.pem
+license.key
diff --git a/AGENTS.md b/AGENTS.md
index 0a3b586..c790885 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -449,13 +449,13 @@ return device # 此时device已过期
```python
import asyncio
from sqlmodel import Field
-from sqlmodels.base import UUIDTableBase, SQLModelBase
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
class CharacterBase(SQLModelBase):
name: str
"""角色名称"""
-class Character(CharacterBase, UUIDTableBase):
+class Character(CharacterBase, UUIDTableBaseMixin):
"""充血模型:包含数据和业务逻辑"""
# ==================== 运行时属性(在model_post_init初始化) ====================
@@ -570,11 +570,11 @@ async with character.init(session):
from abc import ABC, abstractmethod
from uuid import UUID
from sqlmodel import Field
-from sqlmodels.base import (
+from sqlmodel_ext import (
SQLModelBase,
- UUIDTableBase,
+ UUIDTableBaseMixin,
create_subclass_id_mixin,
- AutoPolymorphicIdentityMixin
+ AutoPolymorphicIdentityMixin,
)
# 1. 定义Base类(只有字段,无表)
@@ -591,7 +591,7 @@ class ASRBase(SQLModelBase):
# 2. 定义抽象父类(有表)
class ASR(
ASRBase,
- UUIDTableBase,
+ UUIDTableBaseMixin,
ABC,
polymorphic_on='__polymorphic_name',
polymorphic_abstract=True
@@ -1148,7 +1148,7 @@ from sqlmodel import Field
# 3. 本地应用导入(从项目根目录的包开始)
from dependencies import SessionDep
from sqlmodels.user import User
-from sqlmodels.base import UUIDTableBase
+from sqlmodel_ext import UUIDTableBaseMixin
# 4. 相对导入(同包内的模块)
from .base import BaseClass
diff --git a/CLAUDE.md b/CLAUDE.md
index 0a3b586..c790885 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -449,13 +449,13 @@ return device # 此时device已过期
```python
import asyncio
from sqlmodel import Field
-from sqlmodels.base import UUIDTableBase, SQLModelBase
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
class CharacterBase(SQLModelBase):
name: str
"""角色名称"""
-class Character(CharacterBase, UUIDTableBase):
+class Character(CharacterBase, UUIDTableBaseMixin):
"""充血模型:包含数据和业务逻辑"""
# ==================== 运行时属性(在model_post_init初始化) ====================
@@ -570,11 +570,11 @@ async with character.init(session):
from abc import ABC, abstractmethod
from uuid import UUID
from sqlmodel import Field
-from sqlmodels.base import (
+from sqlmodel_ext import (
SQLModelBase,
- UUIDTableBase,
+ UUIDTableBaseMixin,
create_subclass_id_mixin,
- AutoPolymorphicIdentityMixin
+ AutoPolymorphicIdentityMixin,
)
# 1. 定义Base类(只有字段,无表)
@@ -591,7 +591,7 @@ class ASRBase(SQLModelBase):
# 2. 定义抽象父类(有表)
class ASR(
ASRBase,
- UUIDTableBase,
+ UUIDTableBaseMixin,
ABC,
polymorphic_on='__polymorphic_name',
polymorphic_abstract=True
@@ -1148,7 +1148,7 @@ from sqlmodel import Field
# 3. 本地应用导入(从项目根目录的包开始)
from dependencies import SessionDep
from sqlmodels.user import User
-from sqlmodels.base import UUIDTableBase
+from sqlmodel_ext import UUIDTableBaseMixin
# 4. 相对导入(同包内的模块)
from .base import BaseClass
diff --git a/Dockerfile b/Dockerfile
index af6ff71..30a2e9e 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,13 +1,52 @@
-FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim
+# ============================================================
+# 基础层:安装运行时依赖
+# ============================================================
+FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS base
WORKDIR /app
COPY pyproject.toml uv.lock ./
-
RUN uv sync --frozen --no-dev
COPY . .
-EXPOSE 5213
+# ============================================================
+# Community 版本:删除 ee/ 目录
+# ============================================================
+FROM base AS community
-CMD ["uv", "run", "fastapi", "run", "main.py", "--host", "0.0.0.0", "--port", "5213"]
\ No newline at end of file
+RUN rm -rf ee/
+COPY statics/ /app/statics/
+
+EXPOSE 5213
+CMD ["uv", "run", "fastapi", "run", "main.py", "--host", "0.0.0.0", "--port", "5213"]
+
+# ============================================================
+# Pro 编译层:Cython 编译 ee/ 模块
+# ============================================================
+FROM base AS pro-builder
+
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends gcc libc6-dev && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN uv sync --frozen --no-dev --extra build
+
+RUN uv run python setup_cython.py build_ext --inplace && \
+ uv run python setup_cython.py clean_artifacts
+
+# ============================================================
+# Pro 版本:包含编译后的 ee/ 模块(仅 __init__.py + .so)
+# ============================================================
+FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS pro
+
+WORKDIR /app
+
+COPY pyproject.toml uv.lock ./
+RUN uv sync --frozen --no-dev
+
+COPY --from=pro-builder /app/ /app/
+COPY statics/ /app/statics/
+
+EXPOSE 5213
+CMD ["uv", "run", "fastapi", "run", "main.py", "--host", "0.0.0.0", "--port", "5213"]
diff --git a/docs/file-viewer-api.md b/docs/file-viewer-api.md
new file mode 100644
index 0000000..240816e
--- /dev/null
+++ b/docs/file-viewer-api.md
@@ -0,0 +1,594 @@
+# 文件预览应用选择器 — 前端适配文档
+
+## 概述
+
+文件预览系统类似 Android 的"使用什么应用打开"机制:用户点击文件时,前端根据扩展名查询可用查看器列表,展示选择弹窗,用户可选"仅此一次"或"始终使用"。
+
+### 应用类型
+
+| type | 说明 | 前端处理方式 |
+|------|------|-------------|
+| `builtin` | 前端内置组件 | 根据 `app_key` 路由到内置组件(如 `pdfjs`、`monaco`) |
+| `iframe` | iframe 内嵌 | 将 `iframe_url_template` 中的 `{file_url}` 替换为文件下载 URL,嵌入 iframe |
+| `wopi` | WOPI 协议 | 调用 `/file/{id}/wopi-session` 获取 `editor_url`,嵌入 iframe |
+
+### 内置 app_key 映射
+
+前端需要为以下 `app_key` 实现对应的内置预览组件:
+
+| app_key | 组件 | 说明 |
+|---------|------|------|
+| `pdfjs` | PDF.js 阅读器 | pdf |
+| `monaco` | Monaco Editor | txt, md, json, py, js, ts, html, css, ... |
+| `markdown` | Markdown 渲染器 | md, markdown, mdx |
+| `image_viewer` | 图片查看器 | jpg, png, gif, webp, svg, ... |
+| `video_player` | HTML5 Video | mp4, webm, ogg, mov, mkv, m3u8 |
+| `audio_player` | HTML5 Audio | mp3, wav, flac, aac, m4a, opus |
+
+> `office_viewer`(iframe)、`collabora`(wopi)、`onlyoffice`(wopi)默认禁用,需管理员在后台启用和配置。
+
+---
+
+## 文件下载 URL 与 iframe 预览
+
+### 现有下载流程(两步式)
+
+```
+步骤1: POST /api/v1/file/download/{file_id} → { access_token, expires_in }
+步骤2: GET /api/v1/file/download/{access_token} → 文件二进制流
+```
+
+- 步骤 1 需要 JWT 认证,返回一个下载令牌(有效期 1 小时)
+- 步骤 2 **不需要认证**,用令牌直接下载,**令牌为一次性**,下载后失效
+
+### 各类型查看器获取文件内容的方式
+
+| type | 获取文件方式 | 说明 |
+|------|-------------|------|
+| `builtin` | 前端自行获取 | 前端用 JS 调用下载接口拿到 Blob/ArrayBuffer,传给内置组件渲染 |
+| `iframe` | 需要公开可访问的 URL | 第三方服务(如 Office Online)会**从服务端拉取文件** |
+| `wopi` | WOPI 协议自动处理 | 编辑器通过 `/wopi/files/{id}/contents` 获取,前端只需嵌入 `editor_url` |
+
+### builtin 类型 — 前端自行获取
+
+内置组件(pdfjs、monaco 等)运行在前端,直接用 JS 获取文件内容即可:
+
+```typescript
+// 方式 A:用下载令牌拼 URL(适用于 PDF.js 等需要 URL 的组件)
+const { access_token } = await api.post(`/file/download/${fileId}`)
+const fileUrl = `${baseUrl}/api/v1/file/download/${access_token}`
+// 传给 PDF.js: pdfjsLib.getDocument(fileUrl)
+
+// 方式 B:用 fetch + Authorization 头获取 Blob(适用于需要 ArrayBuffer 的组件)
+const { access_token } = await api.post(`/file/download/${fileId}`)
+const blob = await fetch(`${baseUrl}/api/v1/file/download/${access_token}`).then(r => r.blob())
+// 传给 Monaco: monaco.editor.create(el, { value: await blob.text() })
+```
+
+### iframe 类型 — `{file_url}` 替换规则
+
+`iframe_url_template` 中的 `{file_url}` 需要替换为一个**外部可访问的文件直链**。
+
+**问题**:当前下载令牌是一次性的,而 Office Online 等服务可能多次请求该 URL。
+
+**当前可行方案**:
+
+```typescript
+// 1. 创建下载令牌
+const { access_token } = await api.post(`/file/download/${fileId}`)
+
+// 2. 拼出完整的文件 URL(必须是公网可达的地址)
+const fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
+
+// 3. 替换模板
+const iframeSrc = viewer.iframe_url_template.replace(
+ '{file_url}',
+ encodeURIComponent(fileUrl)
+)
+
+// 4. 嵌入 iframe
+//
+```
+
+> **已知限制**:下载令牌为一次性使用。如果第三方服务多次拉取文件(如 Office Online 可能重试),
+> 第二次请求会 404。后续版本将实现 `/file/get/{id}/{name}` 外链端点(多次可用),届时
+> iframe 应改用外链 URL。目前建议:
+>
+> 1. **优先使用 WOPI 类型**(Collabora/OnlyOffice),不存在此限制
+> 2. Office Online 预览在**文件较小**时通常只拉取一次,大多数场景可用
+> 3. 如需稳定方案,可等待外链端点实现后再启用 iframe 类型应用
+
+### wopi 类型 — 无需关心文件 URL
+
+WOPI 类型的查看器完全由后端处理文件传输,前端只需:
+
+```typescript
+// 1. 创建 WOPI 会话
+const session = await api.post(`/file/${fileId}/wopi-session`)
+
+// 2. 直接嵌入编辑器
+//
+```
+
+编辑器(Collabora/OnlyOffice)会通过 WOPI 协议自动从 `/wopi/files/{id}/contents` 获取文件内容,使用 `access_token` 认证,前端无需干预。
+
+---
+
+## 用户端 API
+
+### 1. 查询可用查看器
+
+用户点击文件时调用,获取该扩展名的可用查看器列表。
+
+```
+GET /api/v1/file/viewers?ext={extension}
+Authorization: Bearer {token}
+```
+
+**Query 参数**
+
+| 参数 | 类型 | 必填 | 说明 |
+|------|------|------|------|
+| ext | string | 是 | 文件扩展名,最长 20 字符。支持带点号(`.pdf`)、大写(`PDF`),后端会自动规范化 |
+
+**响应 200**
+
+```json
+{
+ "viewers": [
+ {
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "name": "PDF 阅读器",
+ "app_key": "pdfjs",
+ "type": "builtin",
+ "icon": "file-pdf",
+ "description": "基于 pdf.js 的 PDF 在线阅读器",
+ "iframe_url_template": null,
+ "wopi_editor_url_template": null
+ }
+ ],
+ "default_viewer_id": null
+}
+```
+
+**字段说明**
+
+| 字段 | 类型 | 说明 |
+|------|------|------|
+| viewers | FileAppSummary[] | 可用查看器列表,已按优先级排序 |
+| default_viewer_id | string \| null | 用户设置的"始终使用"查看器 UUID,未设置则为 null |
+
+**FileAppSummary**
+
+| 字段 | 类型 | 说明 |
+|------|------|------|
+| id | UUID | 应用 UUID |
+| name | string | 应用显示名称 |
+| app_key | string | 应用唯一标识,前端路由用 |
+| type | `"builtin"` \| `"iframe"` \| `"wopi"` | 应用类型 |
+| icon | string \| null | 图标名称(可映射到 icon library) |
+| description | string \| null | 应用描述 |
+| iframe_url_template | string \| null | iframe 类型专用,URL 模板含 `{file_url}` 占位符 |
+| wopi_editor_url_template | string \| null | wopi 类型专用,编辑器 URL 模板 |
+
+---
+
+### 2. 设置默认查看器("始终使用")
+
+用户在选择弹窗中勾选"始终使用此应用"时调用。
+
+```
+PUT /api/v1/user/settings/file-viewers/default
+Authorization: Bearer {token}
+Content-Type: application/json
+```
+
+**请求体**
+
+```json
+{
+ "extension": "pdf",
+ "app_id": "550e8400-e29b-41d4-a716-446655440000"
+}
+```
+
+| 字段 | 类型 | 必填 | 说明 |
+|------|------|------|------|
+| extension | string | 是 | 文件扩展名(小写,无点号) |
+| app_id | UUID | 是 | 选择的查看器应用 UUID |
+
+**响应 200**
+
+```json
+{
+ "id": "660e8400-e29b-41d4-a716-446655440001",
+ "extension": "pdf",
+ "app": {
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "name": "PDF 阅读器",
+ "app_key": "pdfjs",
+ "type": "builtin",
+ "icon": "file-pdf",
+ "description": "基于 pdf.js 的 PDF 在线阅读器",
+ "iframe_url_template": null,
+ "wopi_editor_url_template": null
+ }
+}
+```
+
+**错误码**
+
+| 状态码 | 说明 |
+|--------|------|
+| 400 | 该应用不支持此扩展名 |
+| 404 | 应用不存在 |
+
+> 同一扩展名只允许一个默认值。重复 PUT 同一 extension 会更新(upsert),不会冲突。
+
+---
+
+### 3. 列出所有默认查看器设置
+
+用于用户设置页展示"已设为始终使用"的列表。
+
+```
+GET /api/v1/user/settings/file-viewers/defaults
+Authorization: Bearer {token}
+```
+
+**响应 200**
+
+```json
+[
+ {
+ "id": "660e8400-e29b-41d4-a716-446655440001",
+ "extension": "pdf",
+ "app": {
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "name": "PDF 阅读器",
+ "app_key": "pdfjs",
+ "type": "builtin",
+ "icon": "file-pdf",
+ "description": null,
+ "iframe_url_template": null,
+ "wopi_editor_url_template": null
+ }
+ }
+]
+```
+
+---
+
+### 4. 撤销默认查看器设置
+
+用户在设置页点击"取消始终使用"时调用。
+
+```
+DELETE /api/v1/user/settings/file-viewers/default/{id}
+Authorization: Bearer {token}
+```
+
+**响应** 204 No Content
+
+**错误码**
+
+| 状态码 | 说明 |
+|--------|------|
+| 404 | 记录不存在或不属于当前用户 |
+
+---
+
+### 5. 创建 WOPI 会话
+
+打开 WOPI 类型应用(如 Collabora、OnlyOffice)时调用。
+
+```
+POST /api/v1/file/{file_id}/wopi-session
+Authorization: Bearer {token}
+```
+
+**响应 200**
+
+```json
+{
+ "wopi_src": "http://localhost:8000/wopi/files/770e8400-e29b-41d4-a716-446655440002",
+ "access_token": "eyJhbGciOiJIUzI1NiIs...",
+ "access_token_ttl": 1739577600000,
+ "editor_url": "http://collabora:9980/loleaflet/dist/loleaflet.html?WOPISrc=...&access_token=...&access_token_ttl=..."
+}
+```
+
+| 字段 | 类型 | 说明 |
+|------|------|------|
+| wopi_src | string | WOPI 源 URL(传给编辑器) |
+| access_token | string | WOPI 访问令牌 |
+| access_token_ttl | int | 令牌过期毫秒时间戳 |
+| editor_url | string | 完整的编辑器 URL,**直接嵌入 iframe 即可** |
+
+**错误码**
+
+| 状态码 | 说明 |
+|--------|------|
+| 400 | 文件无扩展名 / WOPI 应用未配置编辑器 URL |
+| 403 | 用户组无权限 |
+| 404 | 文件不存在 / 无可用 WOPI 查看器 |
+
+---
+
+## 前端交互流程
+
+### 打开文件预览
+
+```
+用户点击文件
+ │
+ ▼
+GET /file/viewers?ext={扩展名}
+ │
+ ├── viewers 为空 → 提示"暂无可用的预览方式"
+ │
+ ├── default_viewer_id 不为空 → 直接用对应 viewer 打开(跳过选择弹窗)
+ │
+ └── viewers.length == 1 → 直接用唯一 viewer 打开(可选策略)
+ │
+ └── viewers.length > 1 → 展示选择弹窗
+ │
+ ├── 用户选择 + 不勾选"始终使用" → 仅此一次打开
+ │
+ └── 用户选择 + 勾选"始终使用" → PUT /user/settings/file-viewers/default
+ │
+ └── 然后打开
+```
+
+### 根据 type 打开查看器
+
+```
+获取到 viewer 对象
+ │
+ ├── type == "builtin"
+ │ └── 根据 app_key 路由到内置组件
+ │ switch(app_key):
+ │ "pdfjs" →
+ │ "monaco" →
+ │ "markdown" →
+ │ "image_viewer" →
+ │ "video_player" →
+ │ "audio_player" →
+ │
+ │ 获取文件内容:
+ │ POST /file/download/{file_id} → { access_token }
+ │ fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
+ │ → 传 URL 或 fetch Blob 给内置组件
+ │
+ ├── type == "iframe"
+ │ └── 1. POST /file/download/{file_id} → { access_token }
+ │ 2. fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
+ │ 3. iframeSrc = viewer.iframe_url_template
+ │ .replace("{file_url}", encodeURIComponent(fileUrl))
+ │ 4.
+ │
+ └── type == "wopi"
+ └── 1. POST /file/{file_id}/wopi-session → { editor_url }
+ 2.
+ (编辑器自动通过 WOPI 协议获取文件,前端无需处理)
+```
+
+---
+
+## 管理员 API
+
+所有管理端点需要管理员身份(JWT 中 group.admin == true)。
+
+### 1. 列出所有文件应用
+
+```
+GET /api/v1/admin/file-app/list?page=1&page_size=20
+Authorization: Bearer {admin_token}
+```
+
+**响应 200**
+
+```json
+{
+ "apps": [
+ {
+ "id": "...",
+ "name": "PDF 阅读器",
+ "app_key": "pdfjs",
+ "type": "builtin",
+ "icon": "file-pdf",
+ "description": "...",
+ "is_enabled": true,
+ "is_restricted": false,
+ "iframe_url_template": null,
+ "wopi_discovery_url": null,
+ "wopi_editor_url_template": null,
+ "extensions": ["pdf"],
+ "allowed_group_ids": []
+ }
+ ],
+ "total": 9
+}
+```
+
+### 2. 创建文件应用
+
+```
+POST /api/v1/admin/file-app/
+Authorization: Bearer {admin_token}
+Content-Type: application/json
+```
+
+```json
+{
+ "name": "自定义查看器",
+ "app_key": "my_viewer",
+ "type": "iframe",
+ "description": "自定义 iframe 查看器",
+ "is_enabled": true,
+ "is_restricted": false,
+ "iframe_url_template": "https://example.com/view?url={file_url}",
+ "extensions": ["pdf", "docx"],
+ "allowed_group_ids": []
+}
+```
+
+**响应** 201 — 返回 FileAppResponse(同列表中的单项)
+
+**错误码**: 409 — app_key 已存在
+
+### 3. 获取应用详情
+
+```
+GET /api/v1/admin/file-app/{id}
+```
+
+**响应** 200 — FileAppResponse
+
+### 4. 更新应用
+
+```
+PATCH /api/v1/admin/file-app/{id}
+```
+
+只传需要更新的字段:
+
+```json
+{
+ "name": "新名称",
+ "is_enabled": false
+}
+```
+
+**响应** 200 — FileAppResponse
+
+### 5. 删除应用
+
+```
+DELETE /api/v1/admin/file-app/{id}
+```
+
+**响应** 204 No Content(级联删除扩展名关联、用户偏好、用户组关联)
+
+### 6. 全量替换扩展名列表
+
+```
+PUT /api/v1/admin/file-app/{id}/extensions
+```
+
+```json
+{
+ "extensions": ["doc", "docx", "odt"]
+}
+```
+
+**响应** 200 — FileAppResponse
+
+### 7. 全量替换允许的用户组
+
+```
+PUT /api/v1/admin/file-app/{id}/groups
+```
+
+```json
+{
+ "group_ids": ["uuid-1", "uuid-2"]
+}
+```
+
+**响应** 200 — FileAppResponse
+
+> `is_restricted` 为 `true` 时,只有 `allowed_group_ids` 中的用户组成员能看到此应用。`is_restricted` 为 `false` 时所有用户可见,`allowed_group_ids` 不生效。
+
+---
+
+## TypeScript 类型参考
+
+```typescript
+type FileAppType = 'builtin' | 'iframe' | 'wopi'
+
+interface FileAppSummary {
+ id: string
+ name: string
+ app_key: string
+ type: FileAppType
+ icon: string | null
+ description: string | null
+ iframe_url_template: string | null
+ wopi_editor_url_template: string | null
+}
+
+interface FileViewersResponse {
+ viewers: FileAppSummary[]
+ default_viewer_id: string | null
+}
+
+interface SetDefaultViewerRequest {
+ extension: string
+ app_id: string
+}
+
+interface UserFileAppDefaultResponse {
+ id: string
+ extension: string
+ app: FileAppSummary
+}
+
+interface WopiSessionResponse {
+ wopi_src: string
+ access_token: string
+ access_token_ttl: number
+ editor_url: string
+}
+
+// ========== 管理员类型 ==========
+
+interface FileAppResponse {
+ id: string
+ name: string
+ app_key: string
+ type: FileAppType
+ icon: string | null
+ description: string | null
+ is_enabled: boolean
+ is_restricted: boolean
+ iframe_url_template: string | null
+ wopi_discovery_url: string | null
+ wopi_editor_url_template: string | null
+ extensions: string[]
+ allowed_group_ids: string[]
+}
+
+interface FileAppListResponse {
+ apps: FileAppResponse[]
+ total: number
+}
+
+interface FileAppCreateRequest {
+ name: string
+ app_key: string
+ type: FileAppType
+ icon?: string
+ description?: string
+ is_enabled?: boolean // default: true
+ is_restricted?: boolean // default: false
+ iframe_url_template?: string
+ wopi_discovery_url?: string
+ wopi_editor_url_template?: string
+ extensions?: string[] // default: []
+ allowed_group_ids?: string[] // default: []
+}
+
+interface FileAppUpdateRequest {
+ name?: string
+ app_key?: string
+ type?: FileAppType
+ icon?: string
+ description?: string
+ is_enabled?: boolean
+ is_restricted?: boolean
+ iframe_url_template?: string
+ wopi_discovery_url?: string
+ wopi_editor_url_template?: string
+}
+```
diff --git a/docs/text-editor-api.md b/docs/text-editor-api.md
new file mode 100644
index 0000000..45a974b
--- /dev/null
+++ b/docs/text-editor-api.md
@@ -0,0 +1,242 @@
+# 文本文件在线编辑 — 前端适配文档
+
+## 概述
+
+Monaco Editor 打开文本文件时,通过 GET 获取内容和哈希作为编辑基线;保存时用 jsdiff 计算 unified diff,仅发送差异部分,后端验证无并发冲突后应用 patch。
+
+```
+打开文件: GET /api/v1/file/content/{file_id} → { content, hash, size }
+保存文件: PATCH /api/v1/file/content/{file_id} ← { patch, base_hash }
+ → { new_hash, new_size }
+```
+
+---
+
+## 约定
+
+| 项目 | 约定 |
+|------|------|
+| 编码 | 全程 UTF-8 |
+| 换行符 | 后端 GET 时统一规范化为 `\n`,前端无需处理 `\r\n` |
+| hash 算法 | SHA-256,hex 编码(64 字符),基于 UTF-8 bytes 计算 |
+| diff 格式 | jsdiff `createPatch()` 输出的标准 unified diff |
+| 空 diff | 前端自行判断,内容未变时不发请求 |
+
+---
+
+## GET /api/v1/file/content/{file_id}
+
+获取文本文件内容。
+
+### 请求
+
+```
+GET /api/v1/file/content/{file_id}
+Authorization: Bearer
+```
+
+### 响应 200
+
+```json
+{
+ "content": "line1\nline2\nline3\n",
+ "hash": "a1b2c3d4...(64字符 SHA-256 hex)",
+ "size": 18
+}
+```
+
+| 字段 | 类型 | 说明 |
+|------|------|------|
+| `content` | string | 文件文本内容,换行符已规范化为 `\n` |
+| `hash` | string | 基于规范化内容 UTF-8 bytes 的 SHA-256 hex |
+| `size` | number | 规范化后的字节大小 |
+
+### 错误
+
+| 状态码 | 说明 |
+|--------|------|
+| 400 | 文件不是有效的 UTF-8 文本(二进制文件) |
+| 401 | 未认证 |
+| 404 | 文件不存在 |
+
+---
+
+## PATCH /api/v1/file/content/{file_id}
+
+增量保存文本文件。
+
+### 请求
+
+```
+PATCH /api/v1/file/content/{file_id}
+Authorization: Bearer
+Content-Type: application/json
+```
+
+```json
+{
+ "patch": "--- a\n+++ b\n@@ -1,3 +1,3 @@\n line1\n-line2\n+LINE2\n line3\n",
+ "base_hash": "a1b2c3d4...(GET 返回的 hash)"
+}
+```
+
+| 字段 | 类型 | 说明 |
+|------|------|------|
+| `patch` | string | jsdiff `createPatch()` 生成的 unified diff |
+| `base_hash` | string | 编辑前 GET 返回的 `hash` 值 |
+
+### 响应 200
+
+```json
+{
+ "new_hash": "e5f6a7b8...(64字符)",
+ "new_size": 18
+}
+```
+
+保存成功后,前端应将 `new_hash` 作为新的 `base_hash`,用于下次保存。
+
+### 错误
+
+| 状态码 | 说明 | 前端处理 |
+|--------|------|----------|
+| 401 | 未认证 | — |
+| 404 | 文件不存在 | — |
+| 409 | `base_hash` 不匹配(并发冲突) | 提示用户刷新,重新加载内容 |
+| 422 | patch 格式无效或应用失败 | 回退到全量保存或提示用户 |
+
+---
+
+## 前端实现参考
+
+### 依赖
+
+```bash
+npm install jsdiff
+```
+
+### 计算 hash
+
+```typescript
+async function sha256(text: string): Promise {
+ const bytes = new TextEncoder().encode(text);
+ const hashBuffer = await crypto.subtle.digest("SHA-256", bytes);
+ const hashArray = Array.from(new Uint8Array(hashBuffer));
+ return hashArray.map(b => b.toString(16).padStart(2, "0")).join("");
+}
+```
+
+### 打开文件
+
+```typescript
+interface TextContent {
+ content: string;
+ hash: string;
+ size: number;
+}
+
+async function openFile(fileId: string): Promise {
+ const resp = await fetch(`/api/v1/file/content/${fileId}`, {
+ headers: { Authorization: `Bearer ${token}` },
+ });
+
+ if (!resp.ok) {
+ if (resp.status === 400) throw new Error("该文件不是文本文件");
+ throw new Error("获取文件内容失败");
+ }
+
+ return resp.json();
+}
+```
+
+### 保存文件
+
+```typescript
+import { createPatch } from "diff";
+
+interface PatchResult {
+ new_hash: string;
+ new_size: number;
+}
+
+async function saveFile(
+ fileId: string,
+ originalContent: string,
+ currentContent: string,
+ baseHash: string,
+): Promise {
+ // 内容未变,不发请求
+ if (originalContent === currentContent) return null;
+
+ const patch = createPatch("file", originalContent, currentContent);
+
+ const resp = await fetch(`/api/v1/file/content/${fileId}`, {
+ method: "PATCH",
+ headers: {
+ Authorization: `Bearer ${token}`,
+ "Content-Type": "application/json",
+ },
+ body: JSON.stringify({ patch, base_hash: baseHash }),
+ });
+
+ if (resp.status === 409) {
+ // 并发冲突,需要用户决策
+ throw new Error("CONFLICT");
+ }
+
+ if (!resp.ok) throw new Error("保存失败");
+
+ return resp.json();
+}
+```
+
+### 完整编辑流程
+
+```typescript
+// 1. 打开
+const file = await openFile(fileId);
+let baseContent = file.content;
+let baseHash = file.hash;
+
+// 2. 用户在 Monaco 中编辑...
+editor.setValue(baseContent);
+
+// 3. 保存(Ctrl+S)
+const currentContent = editor.getValue();
+const result = await saveFile(fileId, baseContent, currentContent, baseHash);
+
+if (result) {
+ // 更新基线
+ baseContent = currentContent;
+ baseHash = result.new_hash;
+}
+```
+
+### 冲突处理建议
+
+当 PATCH 返回 409 时,说明文件已被其他会话修改:
+
+```typescript
+try {
+ await saveFile(fileId, baseContent, currentContent, baseHash);
+} catch (e) {
+ if (e.message === "CONFLICT") {
+ // 方案 A:提示用户,提供"覆盖"和"放弃"选项
+ // 方案 B:重新 GET 最新内容,展示 diff 让用户合并
+ const latest = await openFile(fileId);
+ // 展示合并 UI...
+ }
+}
+```
+
+---
+
+## hash 一致性验证
+
+前端可以在 GET 后本地验证 hash,确保传输无误:
+
+```typescript
+const file = await openFile(fileId);
+const localHash = await sha256(file.content);
+console.assert(localHash === file.hash, "hash 不一致,内容可能损坏");
+```
diff --git a/license_public.pem b/license_public.pem
new file mode 100644
index 0000000..81e7cdf
--- /dev/null
+++ b/license_public.pem
@@ -0,0 +1,14 @@
+-----BEGIN PUBLIC KEY-----
+MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAyNltXQ/Nuechx3kjj3T5
+oR6pZvTmpsDowqqxXJy7FXUI8d7XprhV+HrBQPsrT/Ngo9FwW3XyiK10m1WrzpGW
+eaf9990Z5Z2naEn5TzGrh71p/D7mZcNGVumo9uAuhtNEemm6xB3FoyGYZj7X0cwA
+VDvIiKAwYyRJX2LqVh1/tZM6tTO3oaGZXRMZzCNUPFSo4ZZudU3Boa5oQg08evu4
+vaOqeFrMX47R3MSUmO9hOh+NS53XNqO0f0zw5sv95CtyR5qvJ4gpkgYaRCSQFd19
+TnHU5saFVrH9jdADz1tdkMYcyYE+uJActZBapxCHSYB2tSCKWjDxeUFl/oY/ZFtY
+l4MNz1ovkjNhpmR3g+I5fbvN0cxDIjnZ9vJ84ozGqTGT9s1jHaLbpLri/vhuT4F2
+7kifXk8ImwtMZpZvzhmucH9/5VgcWKNuMATzEMif+YjFpuOGx8gc1XL1W/3q+dH0
+EFESp+/knjcVIfwpAkIKyV7XvDgFHsif1SeI0zZMW4utowVvGocP1ZzK5BGNTk2z
+CEtQDO7Rqo+UDckOJSG66VW3c2QO8o6uuy6fzx7q0MFEmUMwGf2iMVtR/KnXe99C
+enOT0BpU1EQvqssErUqivDss7jm98iD8M/TCE7pFboqZ+SC9G+QAqNIQNFWh8bWA
+R9hyXM/x5ysHd6MC4eEQnhMCAwEAAQ==
+-----END PUBLIC KEY-----
diff --git a/main.py b/main.py
index 8a72ca5..ef9941c 100644
--- a/main.py
+++ b/main.py
@@ -1,16 +1,40 @@
+from pathlib import Path
from typing import NoReturn
from fastapi import FastAPI, Request
+from loguru import logger as l
-from utils.conf import appmeta
-from utils.http.http_exceptions import raise_internal_error
-from utils.lifespan import lifespan
+from routers import router
+from service.redis import RedisManager
from sqlmodels.database_connection import DatabaseManager
from sqlmodels.migration import migration
from utils import JWT
-from routers import router
-from service.redis import RedisManager
-from loguru import logger as l
+from utils.conf import appmeta
+from utils.http.http_exceptions import raise_internal_error
+from utils.lifespan import lifespan
+
+# 尝试加载企业版功能
+try:
+ from ee import init_ee
+ from ee.license import LicenseError
+
+ async def _init_ee_and_routes() -> None:
+ try:
+ await init_ee()
+ except LicenseError as exc:
+ l.critical(f"许可证验证失败: {exc}")
+ raise SystemExit(1) from exc
+
+ from ee.routers import ee_router
+ from routers.api.v1 import router as v1_router
+ v1_router.include_router(ee_router)
+
+ lifespan.add_startup(_init_ee_and_routes)
+except ImportError:
+ l.info("以 Community 版本运行")
+
+STATICS_DIR: Path = (Path(__file__).parent / "statics").resolve()
+"""前端静态文件目录(由 Docker 构建时复制)"""
async def _init_db() -> None:
"""初始化数据库连接引擎"""
@@ -64,6 +88,31 @@ async def handle_unexpected_exceptions(
# 挂载路由
app.include_router(router)
+# 挂载前端静态文件(仅当 statics/ 目录存在时,即 Docker 部署环境)
+if STATICS_DIR.is_dir():
+ from starlette.staticfiles import StaticFiles
+ from fastapi.responses import FileResponse
+
+ _assets_dir: Path = STATICS_DIR / "assets"
+ if _assets_dir.is_dir():
+ app.mount("/assets", StaticFiles(directory=_assets_dir), name="assets")
+
+ @app.get("/{path:path}")
+ async def spa_fallback(path: str) -> FileResponse:
+ """
+ SPA fallback 路由
+
+ 优先级:API 路由 > /assets 静态挂载 > 此 catch-all 路由。
+ 若请求路径对应 statics/ 下的真实文件则直接返回,否则返回 index.html。
+ """
+ file_path: Path = (STATICS_DIR / path).resolve()
+ # 防止路径穿越
+ if file_path.is_relative_to(STATICS_DIR) and path and file_path.is_file():
+ return FileResponse(file_path)
+ return FileResponse(STATICS_DIR / "index.html")
+
+ l.info(f"前端静态文件已挂载: {STATICS_DIR}")
+
# 防止直接运行 main.py
if __name__ == "__main__":
l.error("请用 fastapi ['dev', 'run'] 命令启动服务")
diff --git a/middleware/dependencies.py b/middleware/dependencies.py
index 80b217c..ea0b9ec 100644
--- a/middleware/dependencies.py
+++ b/middleware/dependencies.py
@@ -17,7 +17,7 @@ from fastapi import Depends, Form, Query
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.database_connection import DatabaseManager
-from sqlmodels.mixin import TimeFilterRequest, TableViewRequest
+from sqlmodel_ext import TimeFilterRequest, TableViewRequest
from sqlmodels.user import UserFilterParams, UserStatus
diff --git a/pyproject.toml b/pyproject.toml
index 9689430..1a26fa4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,7 @@ dependencies = [
"asyncpg>=0.31.0",
"cachetools>=6.2.4",
"captcha>=0.7.1",
+ "cryptography>=46.0.3",
"fastapi[standard]>=0.122.0",
"httpx>=0.27.0",
"itsdangerous>=2.2.0",
@@ -28,8 +29,16 @@ dependencies = [
"redis[hiredis]>=7.1.0",
"sqlalchemy>=2.0.44",
"sqlmodel>=0.0.27",
+ "sqlmodel-ext[pgvector]>=0.1.1",
"uvicorn>=0.38.0",
"webauthn>=2.7.0",
+ "whatthepatch>=1.0.6",
+]
+
+[project.optional-dependencies]
+build = [
+ "cython>=3.0.11",
+ "setuptools>=75.0.0",
]
[tool.pytest.ini_options]
diff --git a/routers/__init__.py b/routers/__init__.py
index 0cdeb5e..8095487 100644
--- a/routers/__init__.py
+++ b/routers/__init__.py
@@ -1,6 +1,8 @@
from fastapi import APIRouter
from .api import router as api_router
+from .wopi import wopi_router
router = APIRouter()
-router.include_router(api_router)
\ No newline at end of file
+router.include_router(api_router)
+router.include_router(wopi_router)
\ No newline at end of file
diff --git a/routers/api/v1/admin/__init__.py b/routers/api/v1/admin/__init__.py
index 9d4e946..d0043e4 100644
--- a/routers/api/v1/admin/__init__.py
+++ b/routers/api/v1/admin/__init__.py
@@ -9,7 +9,7 @@ from sqlmodels import (
User, ResponseBase,
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
)
-from sqlmodels.base import SQLModelBase
+from sqlmodel_ext import SQLModelBase
from sqlmodels.setting import (
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
)
@@ -17,6 +17,7 @@ from sqlmodels.setting import SettingsType
from utils import http_exceptions
from utils.conf import appmeta
from .file import admin_file_router
+from .file_app import admin_file_app_router
from .group import admin_group_router
from .policy import admin_policy_router
from .share import admin_share_router
@@ -44,6 +45,7 @@ admin_router = APIRouter(
admin_router.include_router(admin_group_router)
admin_router.include_router(admin_user_router)
admin_router.include_router(admin_file_router)
+admin_router.include_router(admin_file_app_router)
admin_router.include_router(admin_policy_router)
admin_router.include_router(admin_share_router)
admin_router.include_router(admin_task_router)
diff --git a/routers/api/v1/admin/file_app/__init__.py b/routers/api/v1/admin/file_app/__init__.py
new file mode 100644
index 0000000..d319b47
--- /dev/null
+++ b/routers/api/v1/admin/file_app/__init__.py
@@ -0,0 +1,348 @@
+"""
+管理员文件应用管理端点
+
+提供文件查看器应用的 CRUD、扩展名管理和用户组权限管理。
+"""
+from uuid import UUID
+
+from fastapi import APIRouter, Depends, status
+from loguru import logger as l
+from sqlalchemy import select
+
+from middleware.auth import admin_required
+from middleware.dependencies import SessionDep, TableViewRequestDep
+from sqlmodels import (
+ FileApp,
+ FileAppCreateRequest,
+ FileAppExtension,
+ FileAppGroupLink,
+ FileAppListResponse,
+ FileAppResponse,
+ FileAppUpdateRequest,
+ ExtensionUpdateRequest,
+ GroupAccessUpdateRequest,
+)
+from utils import http_exceptions
+
+admin_file_app_router = APIRouter(
+ prefix="/file-app",
+ tags=["admin", "file_app"],
+ dependencies=[Depends(admin_required)],
+)
+
+
+@admin_file_app_router.get(
+ path='/list',
+ summary='列出所有文件应用',
+)
+async def list_file_apps(
+ session: SessionDep,
+ table_view: TableViewRequestDep,
+) -> FileAppListResponse:
+ """
+ 列出所有文件应用端点(分页)
+
+ 认证:管理员权限
+ """
+ result = await FileApp.get_with_count(
+ session,
+ table_view=table_view,
+ )
+
+ apps: list[FileAppResponse] = []
+ for app in result.items:
+ extensions = await FileAppExtension.get(
+ session,
+ FileAppExtension.app_id == app.id,
+ fetch_mode="all",
+ )
+ group_links_result = await session.exec(
+ select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
+ )
+ group_links: list[FileAppGroupLink] = list(group_links_result.all())
+ apps.append(FileAppResponse.from_app(app, extensions, group_links))
+
+ return FileAppListResponse(apps=apps, total=result.count)
+
+
+@admin_file_app_router.post(
+ path='/',
+ summary='创建文件应用',
+ status_code=status.HTTP_201_CREATED,
+)
+async def create_file_app(
+ session: SessionDep,
+ request: FileAppCreateRequest,
+) -> FileAppResponse:
+ """
+ 创建文件应用端点
+
+ 认证:管理员权限
+
+ 错误处理:
+ - 409: app_key 已存在
+ """
+ # 检查 app_key 唯一
+ existing = await FileApp.get(session, FileApp.app_key == request.app_key)
+ if existing:
+ http_exceptions.raise_conflict(f"应用标识 '{request.app_key}' 已存在")
+
+ # 创建应用
+ app = FileApp(
+ name=request.name,
+ app_key=request.app_key,
+ type=request.type,
+ icon=request.icon,
+ description=request.description,
+ is_enabled=request.is_enabled,
+ is_restricted=request.is_restricted,
+ iframe_url_template=request.iframe_url_template,
+ wopi_discovery_url=request.wopi_discovery_url,
+ wopi_editor_url_template=request.wopi_editor_url_template,
+ )
+ app = await app.save(session)
+ app_id = app.id
+
+ # 创建扩展名关联
+ extensions: list[FileAppExtension] = []
+ for i, ext in enumerate(request.extensions):
+ normalized = ext.lower().strip().lstrip('.')
+ ext_record = FileAppExtension(
+ app_id=app_id,
+ extension=normalized,
+ priority=i,
+ )
+ ext_record = await ext_record.save(session)
+ extensions.append(ext_record)
+
+ # 创建用户组关联
+ group_links: list[FileAppGroupLink] = []
+ for group_id in request.allowed_group_ids:
+ link = FileAppGroupLink(app_id=app_id, group_id=group_id)
+ session.add(link)
+ group_links.append(link)
+ if group_links:
+ await session.commit()
+
+ l.info(f"创建文件应用: {app.name} ({app.app_key})")
+
+ return FileAppResponse.from_app(app, extensions, group_links)
+
+
+@admin_file_app_router.get(
+ path='/{app_id}',
+ summary='获取文件应用详情',
+)
+async def get_file_app(
+ session: SessionDep,
+ app_id: UUID,
+) -> FileAppResponse:
+ """
+ 获取文件应用详情端点
+
+ 认证:管理员权限
+
+ 错误处理:
+ - 404: 应用不存在
+ """
+ app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
+ if not app:
+ http_exceptions.raise_not_found("应用不存在")
+
+ extensions = await FileAppExtension.get(
+ session,
+ FileAppExtension.app_id == app.id,
+ fetch_mode="all",
+ )
+ group_links_result = await session.exec(
+ select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
+ )
+ group_links: list[FileAppGroupLink] = list(group_links_result.all())
+
+ return FileAppResponse.from_app(app, extensions, group_links)
+
+
+@admin_file_app_router.patch(
+ path='/{app_id}',
+ summary='更新文件应用',
+)
+async def update_file_app(
+ session: SessionDep,
+ app_id: UUID,
+ request: FileAppUpdateRequest,
+) -> FileAppResponse:
+ """
+ 更新文件应用端点
+
+ 认证:管理员权限
+
+ 错误处理:
+ - 404: 应用不存在
+ - 409: 新 app_key 已被其他应用使用
+ """
+ app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
+ if not app:
+ http_exceptions.raise_not_found("应用不存在")
+
+ # 检查 app_key 唯一性
+ if request.app_key is not None and request.app_key != app.app_key:
+ existing = await FileApp.get(session, FileApp.app_key == request.app_key)
+ if existing:
+ http_exceptions.raise_conflict(f"应用标识 '{request.app_key}' 已存在")
+
+ # 更新非 None 字段
+ update_data = request.model_dump(exclude_unset=True)
+ for key, value in update_data.items():
+ setattr(app, key, value)
+
+ app = await app.save(session)
+
+ extensions = await FileAppExtension.get(
+ session,
+ FileAppExtension.app_id == app.id,
+ fetch_mode="all",
+ )
+ group_links_result = await session.exec(
+ select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
+ )
+ group_links: list[FileAppGroupLink] = list(group_links_result.all())
+
+ l.info(f"更新文件应用: {app.name} ({app.app_key})")
+
+ return FileAppResponse.from_app(app, extensions, group_links)
+
+
+@admin_file_app_router.delete(
+ path='/{app_id}',
+ summary='删除文件应用',
+ status_code=status.HTTP_204_NO_CONTENT,
+)
+async def delete_file_app(
+ session: SessionDep,
+ app_id: UUID,
+) -> None:
+ """
+ 删除文件应用端点(级联删除扩展名、用户偏好和用户组关联)
+
+ 认证:管理员权限
+
+ 错误处理:
+ - 404: 应用不存在
+ """
+ app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
+ if not app:
+ http_exceptions.raise_not_found("应用不存在")
+
+ app_name = app.app_key
+ await FileApp.delete(session, app)
+ l.info(f"删除文件应用: {app_name}")
+
+
+@admin_file_app_router.put(
+ path='/{app_id}/extensions',
+ summary='全量替换扩展名列表',
+)
+async def update_extensions(
+ session: SessionDep,
+ app_id: UUID,
+ request: ExtensionUpdateRequest,
+) -> FileAppResponse:
+ """
+ 全量替换扩展名列表端点
+
+ 先删除旧的扩展名关联,再创建新的。
+
+ 认证:管理员权限
+
+ 错误处理:
+ - 404: 应用不存在
+ """
+ app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
+ if not app:
+ http_exceptions.raise_not_found("应用不存在")
+
+ # 删除旧的扩展名
+ old_extensions: list[FileAppExtension] = await FileAppExtension.get(
+ session,
+ FileAppExtension.app_id == app_id,
+ fetch_mode="all",
+ )
+ for old_ext in old_extensions:
+ await FileAppExtension.delete(session, old_ext, commit=False)
+
+ # 创建新的扩展名
+ new_extensions: list[FileAppExtension] = []
+ for i, ext in enumerate(request.extensions):
+ normalized = ext.lower().strip().lstrip('.')
+ ext_record = FileAppExtension(
+ app_id=app_id,
+ extension=normalized,
+ priority=i,
+ )
+ session.add(ext_record)
+ new_extensions.append(ext_record)
+
+ await session.commit()
+ # refresh 新创建的记录
+ for ext_record in new_extensions:
+ await session.refresh(ext_record)
+
+ group_links_result = await session.exec(
+ select(FileAppGroupLink).where(FileAppGroupLink.app_id == app_id)
+ )
+ group_links: list[FileAppGroupLink] = list(group_links_result.all())
+
+ l.info(f"更新文件应用 {app.app_key} 的扩展名: {request.extensions}")
+
+ return FileAppResponse.from_app(app, new_extensions, group_links)
+
+
+@admin_file_app_router.put(
+ path='/{app_id}/groups',
+ summary='全量替换允许的用户组',
+)
+async def update_group_access(
+ session: SessionDep,
+ app_id: UUID,
+ request: GroupAccessUpdateRequest,
+) -> FileAppResponse:
+ """
+ 全量替换允许的用户组端点
+
+ 先删除旧的关联,再创建新的。
+
+ 认证:管理员权限
+
+ 错误处理:
+ - 404: 应用不存在
+ """
+ app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
+ if not app:
+ http_exceptions.raise_not_found("应用不存在")
+
+ # 删除旧的用户组关联
+ old_links_result = await session.exec(
+ select(FileAppGroupLink).where(FileAppGroupLink.app_id == app_id)
+ )
+ old_links: list[FileAppGroupLink] = list(old_links_result.all())
+ for old_link in old_links:
+ await session.delete(old_link)
+
+ # 创建新的用户组关联
+ new_links: list[FileAppGroupLink] = []
+ for group_id in request.group_ids:
+ link = FileAppGroupLink(app_id=app_id, group_id=group_id)
+ session.add(link)
+ new_links.append(link)
+
+ await session.commit()
+
+ extensions = await FileAppExtension.get(
+ session,
+ FileAppExtension.app_id == app_id,
+ fetch_mode="all",
+ )
+
+ l.info(f"更新文件应用 {app.app_key} 的用户组权限: {request.group_ids}")
+
+ return FileAppResponse.from_app(app, extensions, new_links)
diff --git a/routers/api/v1/admin/policy/__init__.py b/routers/api/v1/admin/policy/__init__.py
index f95678d..81772f1 100644
--- a/routers/api/v1/admin/policy/__init__.py
+++ b/routers/api/v1/admin/policy/__init__.py
@@ -9,7 +9,7 @@ from middleware.dependencies import SessionDep, TableViewRequestDep
from sqlmodels import (
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
ListResponse, Object, )
-from sqlmodels.base import SQLModelBase
+from sqlmodel_ext import SQLModelBase
from service.storage import DirectoryCreationError, LocalStorageService
admin_policy_router = APIRouter(
diff --git a/routers/api/v1/file/__init__.py b/routers/api/v1/file/__init__.py
index 23077c7..e675d85 100644
--- a/routers/api/v1/file/__init__.py
+++ b/routers/api/v1/file/__init__.py
@@ -8,47 +8,92 @@
- /file/upload - 上传相关操作
- /file/download - 下载相关操作
"""
+import hashlib
from datetime import datetime, timedelta
from typing import Annotated
from uuid import UUID
+import whatthepatch
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from fastapi.responses import FileResponse
from loguru import logger as l
+from sqlmodel_ext import SQLModelBase
+from whatthepatch.exceptions import HunkApplyException
from middleware.auth import auth_required, verify_download_token
from middleware.dependencies import SessionDep
from sqlmodels import (
CreateFileRequest,
CreateUploadSessionRequest,
+ FileApp,
+ FileAppExtension,
+ FileAppGroupLink,
+ FileAppType,
Object,
ObjectType,
PhysicalFile,
Policy,
PolicyType,
ResponseBase,
+ Setting,
+ SettingsType,
UploadChunkResponse,
UploadSession,
UploadSessionResponse,
User,
+ WopiSessionResponse,
)
from service.storage import LocalStorageService, adjust_user_storage
-from service.redis.token_store import TokenStore
from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL
+from utils.JWT.wopi_token import create_wopi_token
from utils import http_exceptions
+from .viewers import viewers_router
# DTO
class DownloadTokenModel(ResponseBase):
"""下载Token响应模型"""
-
+
access_token: str
"""JWT 令牌"""
-
+
expires_in: int
"""过期时间(秒)"""
+
+class TextContentResponse(ResponseBase):
+ """文本文件内容响应"""
+
+ content: str
+ """文件文本内容(UTF-8)"""
+
+ hash: str
+ """SHA-256 hex"""
+
+ size: int
+ """文件字节大小"""
+
+
+class PatchContentRequest(SQLModelBase):
+ """增量保存请求"""
+
+ patch: str
+ """unified diff 文本"""
+
+ base_hash: str
+ """原始内容的 SHA-256 hex(64字符)"""
+
+
+class PatchContentResponse(ResponseBase):
+ """增量保存响应"""
+
+ new_hash: str
+ """新内容的 SHA-256 hex"""
+
+ new_size: int
+ """新文件字节大小"""
+
# ==================== 主路由 ====================
router = APIRouter(prefix="/file", tags=["file"])
@@ -410,7 +455,7 @@ async def create_download_token_endpoint(
@_download_router.get(
path='/{token}',
summary='下载文件',
- description='使用下载令牌下载文件(一次性令牌,仅可使用一次)。',
+ description='使用下载令牌下载文件,令牌在有效期内可重复使用。',
)
async def download_file(
session: SessionDep,
@@ -420,19 +465,14 @@ async def download_file(
下载文件端点
验证 JWT 令牌后返回文件内容。
- 令牌为一次性使用,下载后即失效。
+ 令牌在有效期内可重复使用(支持浏览器 range 请求等场景)。
"""
# 验证令牌
result = verify_download_token(token)
if not result:
raise HTTPException(status_code=401, detail="下载令牌无效或已过期")
- jti, file_id, owner_id = result
-
- # 检查并标记令牌已使用(原子操作)
- is_first_use = await TokenStore.mark_used(jti, DOWNLOAD_TOKEN_TTL)
- if not is_first_use:
- raise HTTPException(status_code=404)
+ _, file_id, owner_id = result
# 获取文件对象(排除已删除的)
file_obj = await Object.get(
@@ -478,6 +518,7 @@ async def download_file(
router.include_router(_upload_router)
router.include_router(_download_router)
+router.include_router(viewers_router)
# ==================== 创建空白文件 ====================
@@ -574,6 +615,113 @@ async def create_empty_file(
})
+# ==================== WOPI 会话 ====================
+
+@router.post(
+ path='/{file_id}/wopi-session',
+ summary='创建 WOPI 会话',
+ description='为 WOPI 类型的查看器创建编辑会话,返回编辑器 URL 和访问令牌。',
+)
+async def create_wopi_session(
+ session: SessionDep,
+ user: Annotated[User, Depends(auth_required)],
+ file_id: UUID,
+) -> WopiSessionResponse:
+ """
+ 创建 WOPI 会话端点
+
+ 流程:
+ 1. 验证文件存在且属于当前用户
+ 2. 查找文件扩展名对应的 WOPI 类型应用
+ 3. 检查用户组权限
+ 4. 生成 WOPI access token
+ 5. 构建 editor URL
+
+ 认证:JWT token 必填
+
+ 错误处理:
+ - 404: 文件不存在 / 无可用 WOPI 应用
+ - 403: 用户组无权限
+ """
+ # 验证文件
+ file_obj: Object | None = await Object.get(
+ session,
+ Object.id == file_id,
+ )
+ if not file_obj or file_obj.owner_id != user.id:
+ http_exceptions.raise_not_found("文件不存在")
+
+ if not file_obj.is_file:
+ http_exceptions.raise_bad_request("对象不是文件")
+
+ # 获取文件扩展名
+ name_parts = file_obj.name.rsplit('.', 1)
+ if len(name_parts) < 2:
+ http_exceptions.raise_bad_request("文件无扩展名,无法使用 WOPI 查看器")
+ ext = name_parts[1].lower()
+
+ # 查找 WOPI 类型的应用
+ from sqlalchemy import and_, select
+ ext_records: list[FileAppExtension] = await FileAppExtension.get(
+ session,
+ FileAppExtension.extension == ext,
+ fetch_mode="all",
+ load=FileAppExtension.app,
+ )
+
+ wopi_app: FileApp | None = None
+ for ext_record in ext_records:
+ app = ext_record.app
+ if app.type == FileAppType.WOPI and app.is_enabled:
+ # 检查用户组权限(FileAppGroupLink 是纯关联表,使用 session 查询)
+ if app.is_restricted:
+ stmt = select(FileAppGroupLink).where(
+ and_(
+ FileAppGroupLink.app_id == app.id,
+ FileAppGroupLink.group_id == user.group_id,
+ )
+ )
+ result = await session.exec(stmt)
+ if not result.first():
+ continue
+ wopi_app = app
+ break
+
+ if not wopi_app:
+ http_exceptions.raise_not_found("无可用的 WOPI 查看器")
+
+ if not wopi_app.wopi_editor_url_template:
+ http_exceptions.raise_bad_request("WOPI 应用未配置编辑器 URL 模板")
+
+ # 获取站点 URL
+ site_url_setting: Setting | None = await Setting.get(
+ session,
+ (Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"),
+ )
+ site_url = site_url_setting.value if site_url_setting else "http://localhost"
+
+ # 生成 WOPI token
+ can_write = file_obj.owner_id == user.id
+ token, access_token_ttl = create_wopi_token(file_id, user.id, can_write)
+
+ # 构建 wopi_src
+ wopi_src = f"{site_url}/wopi/files/{file_id}"
+
+ # 构建 editor URL
+ editor_url = wopi_app.wopi_editor_url_template.format(
+ wopi_src=wopi_src,
+ access_token=token,
+ access_token_ttl=access_token_ttl,
+ )
+
+ return WopiSessionResponse(
+ wopi_src=wopi_src,
+ access_token=token,
+ access_token_ttl=access_token_ttl,
+ editor_url=editor_url,
+ )
+
+
# ==================== 文件外链(保留原有端点结构) ====================
@router.get(
@@ -612,36 +760,171 @@ async def file_update(id: str) -> ResponseBase:
@router.get(
- path='/preview/{id}',
- summary='预览文件',
- description='获取文件预览。',
- dependencies=[Depends(auth_required)]
-)
-async def file_preview(id: str) -> ResponseBase:
- """预览文件"""
- raise HTTPException(status_code=501, detail="预览功能暂未实现")
-
-
-@router.get(
- path='/content/{id}',
+ path='/content/{file_id}',
summary='获取文本文件内容',
- description='获取文本文件内容。',
- dependencies=[Depends(auth_required)]
+ description='获取文本文件的 UTF-8 内容和 SHA-256 哈希值。',
)
-async def file_content(id: str) -> ResponseBase:
- """获取文本文件内容"""
- raise HTTPException(status_code=501, detail="文本内容功能暂未实现")
+async def file_content(
+ session: SessionDep,
+ user: Annotated[User, Depends(auth_required)],
+ file_id: UUID,
+) -> TextContentResponse:
+ """
+ 获取文本文件内容端点
+
+ 返回文件的 UTF-8 文本内容和基于规范化内容的 SHA-256 哈希值。
+ 换行符统一规范化为 ``\\n``。
+
+ 认证:JWT token 必填
+
+ 错误处理:
+ - 400: 文件不是有效的 UTF-8 文本
+ - 404: 文件不存在
+ """
+ file_obj = await Object.get(
+ session,
+ (Object.id == file_id) & (Object.deleted_at == None)
+ )
+ if not file_obj or file_obj.owner_id != user.id:
+ http_exceptions.raise_not_found("文件不存在")
+
+ if not file_obj.is_file:
+ http_exceptions.raise_bad_request("对象不是文件")
+
+ physical_file = await file_obj.awaitable_attrs.physical_file
+ if not physical_file or not physical_file.storage_path:
+ http_exceptions.raise_internal_error("文件存储路径丢失")
+
+ policy = await Policy.get(session, Policy.id == file_obj.policy_id)
+ if not policy:
+ http_exceptions.raise_internal_error("存储策略不存在")
+
+ if policy.type != PolicyType.LOCAL:
+ http_exceptions.raise_not_implemented("S3 存储暂未实现")
+
+ storage_service = LocalStorageService(policy)
+ raw_bytes = await storage_service.read_file(physical_file.storage_path)
+
+ try:
+ content = raw_bytes.decode('utf-8')
+ except UnicodeDecodeError:
+ http_exceptions.raise_bad_request("文件不是有效的 UTF-8 文本")
+
+ # 换行符规范化
+ content = content.replace('\r\n', '\n').replace('\r', '\n')
+ normalized_bytes = content.encode('utf-8')
+ hash_hex = hashlib.sha256(normalized_bytes).hexdigest()
+
+ return TextContentResponse(
+ content=content,
+ hash=hash_hex,
+ size=len(normalized_bytes),
+ )
-@router.get(
- path='/doc/{id}',
- summary='获取Office文档预览地址',
- description='获取Office文档在线预览地址。',
- dependencies=[Depends(auth_required)]
+@router.patch(
+ path='/content/{file_id}',
+ summary='增量保存文本文件',
+ description='使用 unified diff 增量更新文本文件内容。',
)
-async def file_doc(id: str) -> ResponseBase:
- """获取Office文档预览地址"""
- raise HTTPException(status_code=501, detail="Office预览功能暂未实现")
+async def patch_file_content(
+ session: SessionDep,
+ user: Annotated[User, Depends(auth_required)],
+ file_id: UUID,
+ request: PatchContentRequest,
+) -> PatchContentResponse:
+ """
+ 增量保存文本文件端点
+
+ 接收 unified diff 和 base_hash,验证无并发冲突后应用 patch。
+
+ 认证:JWT token 必填
+
+ 错误处理:
+ - 400: 文件不是有效的 UTF-8 文本
+ - 404: 文件不存在
+ - 409: base_hash 不匹配(并发冲突)
+ - 422: 无效的 patch 格式或 patch 应用失败
+ """
+ file_obj = await Object.get(
+ session,
+ (Object.id == file_id) & (Object.deleted_at == None)
+ )
+ if not file_obj or file_obj.owner_id != user.id:
+ http_exceptions.raise_not_found("文件不存在")
+
+ if not file_obj.is_file:
+ http_exceptions.raise_bad_request("对象不是文件")
+
+ if file_obj.is_banned:
+ http_exceptions.raise_banned()
+
+ physical_file = await file_obj.awaitable_attrs.physical_file
+ if not physical_file or not physical_file.storage_path:
+ http_exceptions.raise_internal_error("文件存储路径丢失")
+
+ storage_path = physical_file.storage_path
+
+ policy = await Policy.get(session, Policy.id == file_obj.policy_id)
+ if not policy:
+ http_exceptions.raise_internal_error("存储策略不存在")
+
+ if policy.type != PolicyType.LOCAL:
+ http_exceptions.raise_not_implemented("S3 存储暂未实现")
+
+ storage_service = LocalStorageService(policy)
+ raw_bytes = await storage_service.read_file(storage_path)
+
+ # 解码 + 规范化
+ original_text = raw_bytes.decode('utf-8')
+ original_text = original_text.replace('\r\n', '\n').replace('\r', '\n')
+ normalized_bytes = original_text.encode('utf-8')
+
+ # 冲突检测(hash 基于规范化后的内容,与 GET 端点一致)
+ current_hash = hashlib.sha256(normalized_bytes).hexdigest()
+ if current_hash != request.base_hash:
+ http_exceptions.raise_conflict("文件内容已被修改,请刷新后重试")
+
+ # 解析并应用 patch
+ diffs = list(whatthepatch.parse_patch(request.patch))
+ if not diffs:
+ http_exceptions.raise_unprocessable_entity("无效的 patch 格式")
+
+ try:
+ result = whatthepatch.apply_diff(diffs[0], original_text)
+ except HunkApplyException:
+ http_exceptions.raise_unprocessable_entity("Patch 应用失败,差异内容与当前文件不匹配")
+
+ new_text = '\n'.join(result)
+
+ # 保持尾部换行符一致
+ if original_text.endswith('\n') and not new_text.endswith('\n'):
+ new_text += '\n'
+
+ new_bytes = new_text.encode('utf-8')
+
+ # 写入文件
+ await storage_service.write_file(storage_path, new_bytes)
+
+ # 更新数据库
+ owner_id = file_obj.owner_id
+ old_size = file_obj.size
+ new_size = len(new_bytes)
+ size_diff = new_size - old_size
+
+ file_obj.size = new_size
+ file_obj = await file_obj.save(session, commit=False)
+ physical_file.size = new_size
+ physical_file = await physical_file.save(session, commit=False)
+ if size_diff != 0:
+ await adjust_user_storage(session, owner_id, size_diff, commit=False)
+ await session.commit()
+
+ new_hash = hashlib.sha256(new_bytes).hexdigest()
+
+ l.info(f"文本文件增量保存: file_id={file_id}, size={old_size}->{new_size}")
+
+ return PatchContentResponse(new_hash=new_hash, new_size=new_size)
@router.get(
diff --git a/routers/api/v1/file/viewers/__init__.py b/routers/api/v1/file/viewers/__init__.py
new file mode 100644
index 0000000..194e666
--- /dev/null
+++ b/routers/api/v1/file/viewers/__init__.py
@@ -0,0 +1,106 @@
+"""
+文件查看器查询端点
+
+提供按文件扩展名查询可用查看器的功能,包含用户组访问控制过滤。
+"""
+from typing import Annotated
+from uuid import UUID
+
+from fastapi import APIRouter, Depends, Query
+from sqlalchemy import and_
+
+from sqlalchemy import select
+
+from middleware.auth import auth_required
+from middleware.dependencies import SessionDep
+from sqlmodels import (
+ FileApp,
+ FileAppExtension,
+ FileAppGroupLink,
+ FileAppSummary,
+ FileViewersResponse,
+ User,
+ UserFileAppDefault,
+)
+
+viewers_router = APIRouter(prefix="/viewers", tags=["file", "viewers"])
+
+
+@viewers_router.get(
+ path='',
+ summary='查询可用文件查看器',
+ description='根据文件扩展名查询可用的查看器应用列表。',
+)
+async def get_viewers(
+ session: SessionDep,
+ user: Annotated[User, Depends(auth_required)],
+ ext: Annotated[str, Query(max_length=20, description="文件扩展名")],
+) -> FileViewersResponse:
+ """
+ 查询可用文件查看器端点
+
+ 流程:
+ 1. 规范化扩展名(小写,去点号)
+ 2. 查询匹配的已启用应用
+ 3. 按用户组权限过滤
+ 4. 按 priority 排序
+ 5. 查询用户默认偏好
+
+ 认证:JWT token 必填
+
+ 错误处理:
+ - 401: 未授权
+ """
+ # 规范化扩展名
+ normalized_ext = ext.lower().strip().lstrip('.')
+
+ # 查询匹配扩展名的应用(已启用的)
+ ext_records: list[FileAppExtension] = await FileAppExtension.get(
+ session,
+ and_(
+ FileAppExtension.extension == normalized_ext,
+ ),
+ fetch_mode="all",
+ load=FileAppExtension.app,
+ )
+
+ # 过滤和收集可用应用
+ user_group_id = user.group_id
+ viewers: list[tuple[FileAppSummary, int]] = []
+
+ for ext_record in ext_records:
+ app: FileApp = ext_record.app
+ if not app.is_enabled:
+ continue
+
+ if app.is_restricted:
+ # 检查用户组权限(FileAppGroupLink 是纯关联表,使用 session 查询)
+ stmt = select(FileAppGroupLink).where(
+ and_(
+ FileAppGroupLink.app_id == app.id,
+ FileAppGroupLink.group_id == user_group_id,
+ )
+ )
+ result = await session.exec(stmt)
+ group_link = result.first()
+ if not group_link:
+ continue
+
+ viewers.append((app.to_summary(), ext_record.priority))
+
+ # 按 priority 排序
+ viewers.sort(key=lambda x: x[1])
+
+ # 查询用户默认偏好
+ user_default: UserFileAppDefault | None = await UserFileAppDefault.get(
+ session,
+ and_(
+ UserFileAppDefault.user_id == user.id,
+ UserFileAppDefault.extension == normalized_ext,
+ ),
+ )
+
+ return FileViewersResponse(
+ viewers=[v[0] for v in viewers],
+ default_viewer_id=user_default.app_id if user_default else None,
+ )
diff --git a/routers/api/v1/share/__init__.py b/routers/api/v1/share/__init__.py
index 02d0190..f6bf5e5 100644
--- a/routers/api/v1/share/__init__.py
+++ b/routers/api/v1/share/__init__.py
@@ -14,7 +14,7 @@ from sqlmodels.share import (
ShareDetailResponse, ShareOwnerInfo, ShareObjectItem,
)
from sqlmodels.object import Object, ObjectType
-from sqlmodels.mixin import ListResponse, TableViewRequest
+from sqlmodel_ext import ListResponse, TableViewRequest
from utils import http_exceptions
from utils.password.pwd import Password, PasswordStatus
diff --git a/routers/api/v1/user/settings/__init__.py b/routers/api/v1/user/settings/__init__.py
index 1883463..e4ea7da 100644
--- a/routers/api/v1/user/settings/__init__.py
+++ b/routers/api/v1/user/settings/__init__.py
@@ -17,12 +17,14 @@ from sqlmodels.color import ThemeColorsBase
from sqlmodels.user_authn import UserAuthn
from utils import JWT, Password, http_exceptions
from utils.password.pwd import PasswordStatus, TwoFactorResponse, TwoFactorVerifyRequest
+from .file_viewers import file_viewers_router
user_settings_router = APIRouter(
prefix='/settings',
tags=["user", "user_settings"],
dependencies=[Depends(auth_required)],
)
+user_settings_router.include_router(file_viewers_router)
@user_settings_router.get(
diff --git a/routers/api/v1/user/settings/file_viewers/__init__.py b/routers/api/v1/user/settings/file_viewers/__init__.py
new file mode 100644
index 0000000..031fa2f
--- /dev/null
+++ b/routers/api/v1/user/settings/file_viewers/__init__.py
@@ -0,0 +1,150 @@
+"""
+用户文件查看器偏好设置端点
+
+提供用户"始终使用"默认查看器的增删查功能。
+"""
+from typing import Annotated
+from uuid import UUID
+
+from fastapi import APIRouter, Depends, status
+from sqlalchemy import and_
+
+from middleware.auth import auth_required
+from middleware.dependencies import SessionDep
+from sqlmodels import (
+ FileApp,
+ FileAppExtension,
+ SetDefaultViewerRequest,
+ User,
+ UserFileAppDefault,
+ UserFileAppDefaultResponse,
+)
+from utils import http_exceptions
+
+file_viewers_router = APIRouter(
+ prefix='/file-viewers',
+ tags=["user", "user_settings", "file_viewers"],
+ dependencies=[Depends(auth_required)],
+)
+
+
+@file_viewers_router.put(
+ path='/default',
+ summary='设置默认查看器',
+ description='为指定扩展名设置"始终使用"的查看器。',
+)
+async def set_default_viewer(
+ session: SessionDep,
+ user: Annotated[User, Depends(auth_required)],
+ request: SetDefaultViewerRequest,
+) -> UserFileAppDefaultResponse:
+ """
+ 设置默认查看器端点
+
+ 如果用户已有该扩展名的默认设置,则更新;否则创建新记录。
+
+ 认证:JWT token 必填
+
+ 错误处理:
+ - 404: 应用不存在
+ - 400: 应用不支持该扩展名
+ """
+ # 规范化扩展名
+ normalized_ext = request.extension.lower().strip().lstrip('.')
+
+ # 验证应用存在
+ app: FileApp | None = await FileApp.get(session, FileApp.id == request.app_id)
+ if not app:
+ http_exceptions.raise_not_found("应用不存在")
+
+ # 验证应用支持该扩展名
+ ext_record: FileAppExtension | None = await FileAppExtension.get(
+ session,
+ and_(
+ FileAppExtension.app_id == app.id,
+ FileAppExtension.extension == normalized_ext,
+ ),
+ )
+ if not ext_record:
+ http_exceptions.raise_bad_request("该应用不支持此扩展名")
+
+ # 查找已有记录
+ existing: UserFileAppDefault | None = await UserFileAppDefault.get(
+ session,
+ and_(
+ UserFileAppDefault.user_id == user.id,
+ UserFileAppDefault.extension == normalized_ext,
+ ),
+ )
+
+ if existing:
+ existing.app_id = request.app_id
+ existing = await existing.save(session)
+ # 重新加载 app 关系
+ await session.refresh(existing, attribute_names=["app"])
+ return existing.to_response()
+ else:
+ new_default = UserFileAppDefault(
+ user_id=user.id,
+ extension=normalized_ext,
+ app_id=request.app_id,
+ )
+ new_default = await new_default.save(session)
+ # 重新加载 app 关系
+ await session.refresh(new_default, attribute_names=["app"])
+ return new_default.to_response()
+
+
+@file_viewers_router.get(
+ path='/defaults',
+ summary='列出所有默认查看器设置',
+ description='获取当前用户所有"始终使用"的查看器偏好。',
+)
+async def list_default_viewers(
+ session: SessionDep,
+ user: Annotated[User, Depends(auth_required)],
+) -> list[UserFileAppDefaultResponse]:
+ """
+ 列出所有默认查看器设置端点
+
+ 认证:JWT token 必填
+ """
+ defaults: list[UserFileAppDefault] = await UserFileAppDefault.get(
+ session,
+ UserFileAppDefault.user_id == user.id,
+ fetch_mode="all",
+ load=UserFileAppDefault.app,
+ )
+ return [d.to_response() for d in defaults]
+
+
+@file_viewers_router.delete(
+ path='/default/{default_id}',
+ summary='撤销默认查看器设置',
+ description='删除指定的"始终使用"偏好。',
+ status_code=status.HTTP_204_NO_CONTENT,
+)
+async def delete_default_viewer(
+ session: SessionDep,
+ user: Annotated[User, Depends(auth_required)],
+ default_id: UUID,
+) -> None:
+ """
+ 撤销默认查看器设置端点
+
+ 认证:JWT token 必填
+
+ 错误处理:
+ - 404: 记录不存在或不属于当前用户
+ """
+ existing: UserFileAppDefault | None = await UserFileAppDefault.get(
+ session,
+ and_(
+ UserFileAppDefault.id == default_id,
+ UserFileAppDefault.user_id == user.id,
+ ),
+ )
+ if not existing:
+ http_exceptions.raise_not_found("默认设置不存在")
+
+ await UserFileAppDefault.delete(session, existing)
diff --git a/routers/wopi/__init__.py b/routers/wopi/__init__.py
new file mode 100644
index 0000000..53bc2c1
--- /dev/null
+++ b/routers/wopi/__init__.py
@@ -0,0 +1,11 @@
+"""
+WOPI(Web Application Open Platform Interface)路由
+
+挂载在根级别 /wopi(非 /api/v1 下),因为 WOPI 客户端要求标准路径。
+"""
+from fastapi import APIRouter
+
+from .files import wopi_files_router
+
+wopi_router = APIRouter(prefix="/wopi", tags=["wopi"])
+wopi_router.include_router(wopi_files_router)
diff --git a/routers/wopi/files/__init__.py b/routers/wopi/files/__init__.py
new file mode 100644
index 0000000..b3720ad
--- /dev/null
+++ b/routers/wopi/files/__init__.py
@@ -0,0 +1,203 @@
+"""
+WOPI 文件操作端点
+
+实现 WOPI 协议的核心文件操作接口:
+- CheckFileInfo: 获取文件元数据
+- GetFile: 下载文件内容
+- PutFile: 上传/更新文件内容
+"""
+from uuid import UUID
+
+from fastapi import APIRouter, Query, Request, Response
+from fastapi.responses import JSONResponse
+from loguru import logger as l
+
+from middleware.dependencies import SessionDep
+from sqlmodels import Object, PhysicalFile, Policy, PolicyType, User, WopiFileInfo
+from service.storage import LocalStorageService
+from utils import http_exceptions
+from utils.JWT.wopi_token import verify_wopi_token
+
+wopi_files_router = APIRouter(prefix="/files", tags=["wopi"])
+
+
+@wopi_files_router.get(
+ path='/{file_id}',
+ summary='WOPI CheckFileInfo',
+ description='返回文件的元数据信息。',
+)
+async def check_file_info(
+ session: SessionDep,
+ file_id: UUID,
+ access_token: str = Query(...),
+) -> JSONResponse:
+ """
+ WOPI CheckFileInfo 端点
+
+ 认证:WOPI access_token(query 参数)
+
+ 返回 WOPI 规范的 PascalCase JSON。
+ """
+ # 验证令牌
+ payload = verify_wopi_token(access_token)
+ if not payload or payload.file_id != file_id:
+ http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")
+
+ # 获取文件
+ file_obj: Object | None = await Object.get(
+ session,
+ Object.id == file_id,
+ )
+ if not file_obj or not file_obj.is_file:
+ http_exceptions.raise_not_found("文件不存在")
+
+ # 获取用户信息
+ user: User | None = await User.get(session, User.id == payload.user_id)
+ user_name = user.nickname or user.email or str(payload.user_id) if user else str(payload.user_id)
+
+ # 构建响应
+ info = WopiFileInfo(
+ base_file_name=file_obj.name,
+ size=file_obj.size or 0,
+ owner_id=str(file_obj.owner_id),
+ user_id=str(payload.user_id),
+ user_friendly_name=user_name,
+ version=file_obj.updated_at.isoformat() if file_obj.updated_at else "",
+ user_can_write=payload.can_write,
+ read_only=not payload.can_write,
+ supports_update=payload.can_write,
+ )
+
+ return JSONResponse(content=info.to_wopi_dict())
+
+
+@wopi_files_router.get(
+ path='/{file_id}/contents',
+ summary='WOPI GetFile',
+ description='返回文件的二进制内容。',
+)
+async def get_file(
+ session: SessionDep,
+ file_id: UUID,
+ access_token: str = Query(...),
+) -> Response:
+ """
+ WOPI GetFile 端点
+
+ 认证:WOPI access_token(query 参数)
+
+ 返回文件的原始二进制内容。
+ """
+ # 验证令牌
+ payload = verify_wopi_token(access_token)
+ if not payload or payload.file_id != file_id:
+ http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")
+
+ # 获取文件
+ file_obj: Object | None = await Object.get(session, Object.id == file_id)
+ if not file_obj or not file_obj.is_file:
+ http_exceptions.raise_not_found("文件不存在")
+
+ # 获取物理文件
+ physical_file: PhysicalFile | None = await file_obj.awaitable_attrs.physical_file
+ if not physical_file or not physical_file.storage_path:
+ http_exceptions.raise_internal_error("文件存储路径丢失")
+
+ # 获取策略
+ policy: Policy | None = await Policy.get(session, Policy.id == file_obj.policy_id)
+ if not policy:
+ http_exceptions.raise_internal_error("存储策略不存在")
+
+ if policy.type == PolicyType.LOCAL:
+ storage_service = LocalStorageService(policy)
+ if not await storage_service.file_exists(physical_file.storage_path):
+ http_exceptions.raise_not_found("物理文件不存在")
+
+ import aiofiles
+ async with aiofiles.open(physical_file.storage_path, 'rb') as f:
+ content = await f.read()
+
+ return Response(
+ content=content,
+ media_type="application/octet-stream",
+ headers={"X-WOPI-ItemVersion": file_obj.updated_at.isoformat() if file_obj.updated_at else ""},
+ )
+ else:
+ http_exceptions.raise_not_implemented("S3 存储暂未实现")
+
+
+@wopi_files_router.post(
+ path='/{file_id}/contents',
+ summary='WOPI PutFile',
+ description='更新文件内容。',
+)
+async def put_file(
+ session: SessionDep,
+ request: Request,
+ file_id: UUID,
+ access_token: str = Query(...),
+) -> JSONResponse:
+ """
+ WOPI PutFile 端点
+
+ 认证:WOPI access_token(query 参数,需要写权限)
+
+ 接收请求体中的文件二进制内容并覆盖存储。
+ """
+ # 验证令牌
+ payload = verify_wopi_token(access_token)
+ if not payload or payload.file_id != file_id:
+ http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")
+
+ if not payload.can_write:
+ http_exceptions.raise_forbidden("没有写入权限")
+
+ # 获取文件
+ file_obj: Object | None = await Object.get(session, Object.id == file_id)
+ if not file_obj or not file_obj.is_file:
+ http_exceptions.raise_not_found("文件不存在")
+
+ # 获取物理文件
+ physical_file: PhysicalFile | None = await file_obj.awaitable_attrs.physical_file
+ if not physical_file or not physical_file.storage_path:
+ http_exceptions.raise_internal_error("文件存储路径丢失")
+
+ # 获取策略
+ policy: Policy | None = await Policy.get(session, Policy.id == file_obj.policy_id)
+ if not policy:
+ http_exceptions.raise_internal_error("存储策略不存在")
+
+ # 读取请求体
+ content = await request.body()
+
+ if policy.type == PolicyType.LOCAL:
+ import aiofiles
+ async with aiofiles.open(physical_file.storage_path, 'wb') as f:
+ await f.write(content)
+
+ # 更新文件大小
+ new_size = len(content)
+ old_size = file_obj.size or 0
+ file_obj.size = new_size
+ file_obj = await file_obj.save(session, commit=False)
+
+ # 更新物理文件大小
+ physical_file.size = new_size
+ await physical_file.save(session, commit=False)
+
+ # 更新用户存储配额
+ size_diff = new_size - old_size
+ if size_diff != 0:
+ from service.storage import adjust_user_storage
+ await adjust_user_storage(session, file_obj.owner_id, size_diff, commit=False)
+
+ await session.commit()
+
+ l.info(f"WOPI PutFile: file_id={file_id}, new_size={new_size}")
+
+ return JSONResponse(
+ content={"ItemVersion": file_obj.updated_at.isoformat() if file_obj.updated_at else ""},
+ status_code=200,
+ )
+ else:
+ http_exceptions.raise_not_implemented("S3 存储暂未实现")
diff --git a/service/storage/naming_rule.py b/service/storage/naming_rule.py
index dd7d873..8e19325 100644
--- a/service/storage/naming_rule.py
+++ b/service/storage/naming_rule.py
@@ -23,7 +23,7 @@ import string
from datetime import datetime
from uuid import UUID, uuid4
-from sqlmodels.base import SQLModelBase
+from sqlmodel_ext import SQLModelBase
class NamingContext(SQLModelBase):
diff --git a/setup_cython.py b/setup_cython.py
new file mode 100644
index 0000000..bcfa5a9
--- /dev/null
+++ b/setup_cython.py
@@ -0,0 +1,91 @@
+"""
+Cython 编译脚本 — 将 ee/ 下的纯逻辑文件编译为 .so
+
+用法:
+ uv run --extra build python setup_cython.py build_ext --inplace
+
+编译规则:
+ - 跳过 __init__.py(Python 包发现需要)
+ - 只编译 .py 文件(纯函数 / 服务逻辑)
+
+编译后清理(Pro Docker 构建用):
+ uv run --extra build python setup_cython.py clean_artifacts
+"""
+import shutil
+import sys
+from pathlib import Path
+
+EE_DIR = Path("ee")
+
+# 跳过 __init__.py —— 包发现需要原始 .py
+SKIP_NAMES = {"__init__.py"}
+
+
+def _collect_modules() -> list[str]:
+ """收集 ee/ 下需要编译的 .py 文件路径(点分模块名)。"""
+ modules: list[str] = []
+ for py_file in EE_DIR.rglob("*.py"):
+ if py_file.name in SKIP_NAMES:
+ continue
+ # ee/license.py → ee.license
+ module = str(py_file.with_suffix("")).replace("\\", "/").replace("/", ".")
+ modules.append(module)
+ return modules
+
+
+def clean_artifacts() -> None:
+ """删除已编译的 .py 源码、.c 中间文件和 build/ 目录。"""
+ for py_file in EE_DIR.rglob("*.py"):
+ if py_file.name in SKIP_NAMES:
+ continue
+ # 只删除有对应 .so / .pyd 的源文件
+ parent = py_file.parent
+ stem = py_file.stem
+ has_compiled = (
+ any(parent.glob(f"{stem}*.so")) or
+ any(parent.glob(f"{stem}*.pyd"))
+ )
+ if has_compiled:
+ py_file.unlink()
+ print(f"已删除源码: {py_file}")
+
+ # 删除 .c 中间文件
+ for c_file in EE_DIR.rglob("*.c"):
+ c_file.unlink()
+ print(f"已删除中间文件: {c_file}")
+
+ # 删除 build/ 目录
+ build_dir = Path("build")
+ if build_dir.exists():
+ shutil.rmtree(build_dir)
+ print(f"已删除: {build_dir}/")
+
+
+if __name__ == "__main__":
+ if len(sys.argv) > 1 and sys.argv[1] == "clean_artifacts":
+ clean_artifacts()
+ sys.exit(0)
+
+ # 动态导入(仅在编译时需要)
+ from Cython.Build import cythonize
+ from setuptools import Extension, setup
+
+ modules = _collect_modules()
+ if not modules:
+ print("未找到需要编译的模块")
+ sys.exit(0)
+
+ print(f"即将编译以下模块: {modules}")
+
+ extensions = [
+ Extension(mod, [mod.replace(".", "/") + ".py"])
+ for mod in modules
+ ]
+
+ setup(
+ name="disknext-ee",
+ ext_modules=cythonize(
+ extensions,
+ compiler_directives={'language_level': "3"},
+ ),
+ )
diff --git a/sqlmodels/__init__.py b/sqlmodels/__init__.py
index eb0a34e..e87d2a7 100644
--- a/sqlmodels/__init__.py
+++ b/sqlmodels/__init__.py
@@ -115,6 +115,14 @@ from .storage_pack import StoragePack
from .tag import Tag, TagType
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
from .webdav import WebDAV
+from .file_app import (
+ FileApp, FileAppType, FileAppExtension, FileAppGroupLink, UserFileAppDefault,
+ # DTO
+ FileAppSummary, FileViewersResponse, SetDefaultViewerRequest, UserFileAppDefaultResponse,
+ FileAppCreateRequest, FileAppUpdateRequest, FileAppResponse, FileAppListResponse,
+ ExtensionUpdateRequest, GroupAccessUpdateRequest, WopiSessionResponse,
+)
+from .wopi import WopiFileInfo, WopiAccessTokenPayload
from .database_connection import DatabaseManager
@@ -131,5 +139,5 @@ from .model_base import (
AdminSummaryResponse,
)
-# mixin 中的通用分页模型
-from .mixin import ListResponse
+# 通用分页模型
+from sqlmodel_ext import ListResponse
diff --git a/sqlmodels/auth_identity.py b/sqlmodels/auth_identity.py
index 5649f43..bab8741 100644
--- a/sqlmodels/auth_identity.py
+++ b/sqlmodels/auth_identity.py
@@ -10,8 +10,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint
-from .base import SQLModelBase
-from .mixin import UUIDTableBaseMixin
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
if TYPE_CHECKING:
from .user import User
diff --git a/sqlmodels/base/README.md b/sqlmodels/base/README.md
deleted file mode 100644
index 9ff36ce..0000000
--- a/sqlmodels/base/README.md
+++ /dev/null
@@ -1,657 +0,0 @@
-# SQLModels Base Module
-
-This module provides `SQLModelBase`, the root base class for all SQLModel models in this project. It includes a custom metaclass with automatic type injection and Python 3.14 compatibility.
-
-**Note**: Table base classes (`TableBaseMixin`, `UUIDTableBaseMixin`) and polymorphic utilities have been migrated to the [`sqlmodels.mixin`](../mixin/README.md) module. See the mixin documentation for CRUD operations, polymorphic inheritance patterns, and pagination utilities.
-
-## Table of Contents
-
-- [Overview](#overview)
-- [Migration Notice](#migration-notice)
-- [Python 3.14 Compatibility](#python-314-compatibility)
-- [Core Component](#core-component)
- - [SQLModelBase](#sqlmodelbase)
-- [Metaclass Features](#metaclass-features)
- - [Automatic sa_type Injection](#automatic-sa_type-injection)
- - [Table Configuration](#table-configuration)
- - [Polymorphic Support](#polymorphic-support)
-- [Custom Types Integration](#custom-types-integration)
-- [Best Practices](#best-practices)
-- [Troubleshooting](#troubleshooting)
-
-## Overview
-
-The `sqlmodels.base` module provides `SQLModelBase`, the foundational base class for all SQLModel models. It features:
-
-- **Smart metaclass** that automatically extracts and injects SQLAlchemy types from type annotations
-- **Python 3.14 compatibility** through comprehensive PEP 649/749 support
-- **Flexible configuration** through class parameters and automatic docstring support
-- **Type-safe annotations** with automatic validation
-
-All models in this project should directly or indirectly inherit from `SQLModelBase`.
-
----
-
-## Migration Notice
-
-As of the recent refactoring, the following components have been moved:
-
-| Component | Old Location | New Location |
-|-----------|-------------|--------------|
-| `TableBase` → `TableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
-| `UUIDTableBase` → `UUIDTableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
-| `PolymorphicBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
-| `create_subclass_id_mixin()` | `sqlmodels.base` | `sqlmodels.mixin` |
-| `AutoPolymorphicIdentityMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
-| `TableViewRequest` | `sqlmodels.base` | `sqlmodels.mixin` |
-| `now()`, `now_date()` | `sqlmodels.base` | `sqlmodels.mixin` |
-
-**Update your imports**:
-
-```python
-# ❌ Old (deprecated)
-from sqlmodels.base import TableBase, UUIDTableBase
-
-# ✅ New (correct)
-from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
-```
-
-For detailed documentation on table mixins, CRUD operations, and polymorphic patterns, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
-
----
-
-## Python 3.14 Compatibility
-
-### Overview
-
-This module provides full compatibility with **Python 3.14's PEP 649** (Deferred Evaluation of Annotations) and **PEP 749** (making it the default).
-
-**Key Changes in Python 3.14**:
-- Annotations are no longer evaluated at class definition time
-- Type hints are stored as deferred code objects
-- `__annotate__` function generates annotations on demand
-- Forward references become `ForwardRef` objects
-
-### Implementation Strategy
-
-We use **`typing.get_type_hints()`** as the universal annotations resolver:
-
-```python
-def _resolve_annotations(attrs: dict[str, Any]) -> tuple[...]:
- # Create temporary proxy class
- temp_cls = type('AnnotationProxy', (object,), dict(attrs))
-
- # Use get_type_hints with include_extras=True
- evaluated = get_type_hints(
- temp_cls,
- globalns=module_globals,
- localns=localns,
- include_extras=True # Preserve Annotated metadata
- )
-
- return dict(evaluated), {}, module_globals, localns
-```
-
-**Why `get_type_hints()`?**
-- ✅ Works across Python 3.10-3.14+
-- ✅ Handles PEP 649 automatically
-- ✅ Preserves `Annotated` metadata (with `include_extras=True`)
-- ✅ Resolves forward references
-- ✅ Recommended by Python documentation
-
-### SQLModel Compatibility Patch
-
-**Problem**: SQLModel's `get_sqlalchemy_type()` doesn't recognize custom types with `__sqlmodel_sa_type__` attribute.
-
-**Solution**: Global monkey-patch that checks for SQLAlchemy type before falling back to original logic:
-
-```python
-if sys.version_info >= (3, 14):
- def _patched_get_sqlalchemy_type(field):
- annotation = getattr(field, 'annotation', None)
- if annotation is not None:
- # Priority 1: Check __sqlmodel_sa_type__ attribute
- # Handles NumpyVector[dims, dtype] and similar custom types
- if hasattr(annotation, '__sqlmodel_sa_type__'):
- return annotation.__sqlmodel_sa_type__
-
- # Priority 2: Check Annotated metadata
- if get_origin(annotation) is Annotated:
- for metadata in get_args(annotation)[1:]:
- if hasattr(metadata, '__sqlmodel_sa_type__'):
- return metadata.__sqlmodel_sa_type__
-
- # ... handle ForwardRef, ClassVar, etc.
-
- return _original_get_sqlalchemy_type(field)
-```
-
-### Supported Patterns
-
-#### Pattern 1: Direct Custom Type Usage
-```python
-from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-class SpeakerInfo(UUIDTableBaseMixin, table=True):
- embedding: NumpyVector[256, np.float32]
- """Voice embedding - sa_type automatically extracted"""
-```
-
-#### Pattern 2: Annotated Wrapper
-```python
-from typing import Annotated
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
-
-class SpeakerInfo(UUIDTableBaseMixin, table=True):
- embedding: EmbeddingVector
-```
-
-#### Pattern 3: Array Type
-```python
-from sqlmodels.sqlmodel_types.dialects.postgresql import Array
-from sqlmodels.mixin import TableBaseMixin
-
-class ServerConfig(TableBaseMixin, table=True):
- protocols: Array[ProtocolEnum]
- """Allowed protocols - sa_type from Array handler"""
-```
-
-### Migration from Python 3.13
-
-**No code changes required!** The implementation is transparent:
-
-- Uses `typing.get_type_hints()` which works in both Python 3.13 and 3.14
-- Custom types already use `__sqlmodel_sa_type__` attribute
-- Monkey-patch only activates for Python 3.14+
-
----
-
-## Core Component
-
-### SQLModelBase
-
-`SQLModelBase` is the root base class for all SQLModel models. It uses a custom metaclass (`__DeclarativeMeta`) that provides advanced features beyond standard SQLModel capabilities.
-
-**Key Features**:
-- Automatic `use_attribute_docstrings` configuration (use docstrings instead of `Field(description=...)`)
-- Automatic `validate_by_name` configuration
-- Custom metaclass for sa_type injection and polymorphic setup
-- Integration with Pydantic v2
-- Python 3.14 PEP 649 compatibility
-
-**Usage**:
-
-```python
-from sqlmodels.base import SQLModelBase
-
-class UserBase(SQLModelBase):
- name: str
- """User's display name"""
-
- email: str
- """User's email address"""
-```
-
-**Important Notes**:
-- Use **docstrings** for field descriptions, not `Field(description=...)`
-- Do NOT override `model_config` in subclasses (it's already configured in SQLModelBase)
-- This class should be used for non-table models (DTOs, request/response models)
-
-**For table models**, use mixins from `sqlmodels.mixin`:
-- `TableBaseMixin` - Integer primary key with timestamps
-- `UUIDTableBaseMixin` - UUID primary key with timestamps
-
-See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete table mixin documentation.
-
----
-
-## Metaclass Features
-
-### Automatic sa_type Injection
-
-The metaclass automatically extracts SQLAlchemy types from custom type annotations, enabling clean syntax for complex database types.
-
-**Before** (verbose):
-```python
-from sqlmodels.sqlmodel_types.dialects.postgresql.numpy_vector import _NumpyVectorSQLAlchemyType
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-class SpeakerInfo(UUIDTableBaseMixin, table=True):
- embedding: np.ndarray = Field(
- sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
- )
-```
-
-**After** (clean):
-```python
-from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-class SpeakerInfo(UUIDTableBaseMixin, table=True):
- embedding: NumpyVector[256, np.float32]
- """Speaker voice embedding"""
-```
-
-**How It Works**:
-
-The metaclass uses a three-tier detection strategy:
-
-1. **Direct `__sqlmodel_sa_type__` attribute** (Priority 1)
- ```python
- if hasattr(annotation, '__sqlmodel_sa_type__'):
- return annotation.__sqlmodel_sa_type__
- ```
-
-2. **Annotated metadata** (Priority 2)
- ```python
- # For Annotated[np.ndarray, NumpyVector[256, np.float32]]
- if get_origin(annotation) is typing.Annotated:
- for item in metadata_items:
- if hasattr(item, '__sqlmodel_sa_type__'):
- return item.__sqlmodel_sa_type__
- ```
-
-3. **Pydantic Core Schema metadata** (Priority 3)
- ```python
- schema = annotation.__get_pydantic_core_schema__(...)
- if schema['metadata'].get('sa_type'):
- return schema['metadata']['sa_type']
- ```
-
-After extracting `sa_type`, the metaclass:
-- Creates `Field(sa_type=sa_type)` if no Field is defined
-- Injects `sa_type` into existing Field if not already set
-- Respects explicit `Field(sa_type=...)` (no override)
-
-**Supported Patterns**:
-
-```python
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-# Pattern 1: Direct usage (recommended)
-class Model(UUIDTableBaseMixin, table=True):
- embedding: NumpyVector[256, np.float32]
-
-# Pattern 2: With Field constraints
-class Model(UUIDTableBaseMixin, table=True):
- embedding: NumpyVector[256, np.float32] = Field(nullable=False)
-
-# Pattern 3: Annotated wrapper
-EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
-
-class Model(UUIDTableBaseMixin, table=True):
- embedding: EmbeddingVector
-
-# Pattern 4: Explicit sa_type (override)
-class Model(UUIDTableBaseMixin, table=True):
- embedding: NumpyVector[256, np.float32] = Field(
- sa_type=_NumpyVectorSQLAlchemyType(128, np.float16)
- )
-```
-
-### Table Configuration
-
-The metaclass provides smart defaults and flexible configuration:
-
-**Automatic `table=True`**:
-```python
-# Classes inheriting from TableBaseMixin automatically get table=True
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-class MyModel(UUIDTableBaseMixin): # table=True is automatic
- pass
-```
-
-**Convenient mapper arguments**:
-```python
-# Instead of verbose __mapper_args__
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-class MyModel(
- UUIDTableBaseMixin,
- polymorphic_on='_polymorphic_name',
- polymorphic_abstract=True
-):
- pass
-
-# Equivalent to:
-class MyModel(UUIDTableBaseMixin):
- __mapper_args__ = {
- 'polymorphic_on': '_polymorphic_name',
- 'polymorphic_abstract': True
- }
-```
-
-**Smart merging**:
-```python
-# Dictionary and keyword arguments are merged
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-class MyModel(
- UUIDTableBaseMixin,
- mapper_args={'version_id_col': 'version'},
- polymorphic_on='type' # Merged into __mapper_args__
-):
- pass
-```
-
-### Polymorphic Support
-
-The metaclass supports SQLAlchemy's joined table inheritance through convenient parameters:
-
-**Supported parameters**:
-- `polymorphic_on`: Discriminator column name
-- `polymorphic_identity`: Identity value for this class
-- `polymorphic_abstract`: Whether this is an abstract base
-- `table_args`: SQLAlchemy table arguments
-- `table_name`: Override table name (becomes `__tablename__`)
-
-**For complete polymorphic inheritance patterns**, including `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
-
----
-
-## Custom Types Integration
-
-### Using NumpyVector
-
-The `NumpyVector` type demonstrates automatic sa_type injection:
-
-```python
-from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
-from sqlmodels.mixin import UUIDTableBaseMixin
-import numpy as np
-
-class SpeakerInfo(UUIDTableBaseMixin, table=True):
- embedding: NumpyVector[256, np.float32]
- """Speaker voice embedding - sa_type automatically injected"""
-```
-
-**How NumpyVector works**:
-
-```python
-# NumpyVector[dims, dtype] returns a class with:
-class _NumpyVectorType:
- __sqlmodel_sa_type__ = _NumpyVectorSQLAlchemyType(dimensions, dtype)
-
- @classmethod
- def __get_pydantic_core_schema__(cls, source_type, handler):
- return handler.generate_schema(np.ndarray)
-```
-
-This dual approach ensures:
-1. Metaclass can extract `sa_type` via `__sqlmodel_sa_type__`
-2. Pydantic can validate as `np.ndarray`
-
-### Creating Custom SQLAlchemy Types
-
-To create types that work with automatic injection, provide one of:
-
-**Option 1: `__sqlmodel_sa_type__` attribute** (preferred):
-
-```python
-from sqlalchemy import TypeDecorator, String
-
-class UpperCaseString(TypeDecorator):
- impl = String
-
- def process_bind_param(self, value, dialect):
- return value.upper() if value else value
-
-class UpperCaseType:
- __sqlmodel_sa_type__ = UpperCaseString()
-
- @classmethod
- def __get_pydantic_core_schema__(cls, source_type, handler):
- return core_schema.str_schema()
-
-# Usage
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-class MyModel(UUIDTableBaseMixin, table=True):
- code: UpperCaseType # Automatically uses UpperCaseString()
-```
-
-**Option 2: Pydantic metadata with sa_type**:
-
-```python
-def __get_pydantic_core_schema__(self, source_type, handler):
- return core_schema.json_or_python_schema(
- json_schema=core_schema.str_schema(),
- python_schema=core_schema.str_schema(),
- metadata={'sa_type': UpperCaseString()}
- )
-```
-
-**Option 3: Using Annotated**:
-
-```python
-from typing import Annotated
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-UpperCase = Annotated[str, UpperCaseType()]
-
-class MyModel(UUIDTableBaseMixin, table=True):
- code: UpperCase
-```
-
----
-
-## Best Practices
-
-### 1. Inherit from correct base classes
-
-```python
-from sqlmodels.base import SQLModelBase
-from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
-
-# ✅ For non-table models (DTOs, requests, responses)
-class UserBase(SQLModelBase):
- name: str
-
-# ✅ For table models with UUID primary key
-class User(UserBase, UUIDTableBaseMixin, table=True):
- email: str
-
-# ✅ For table models with custom primary key
-class LegacyUser(TableBaseMixin, table=True):
- id: int = Field(primary_key=True)
- username: str
-```
-
-### 2. Use docstrings for field descriptions
-
-```python
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-# ✅ Recommended
-class User(UUIDTableBaseMixin, table=True):
- name: str
- """User's display name"""
-
-# ❌ Avoid
-class User(UUIDTableBaseMixin, table=True):
- name: str = Field(description="User's display name")
-```
-
-**Why?** SQLModelBase has `use_attribute_docstrings=True`, so docstrings automatically become field descriptions in API docs.
-
-### 3. Leverage automatic sa_type injection
-
-```python
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-# ✅ Clean and recommended
-class SpeakerInfo(UUIDTableBaseMixin, table=True):
- embedding: NumpyVector[256, np.float32]
- """Voice embedding"""
-
-# ❌ Verbose and unnecessary
-class SpeakerInfo(UUIDTableBaseMixin, table=True):
- embedding: np.ndarray = Field(
- sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
- )
-```
-
-### 4. Follow polymorphic naming conventions
-
-See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete polymorphic inheritance patterns using `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`.
-
-### 5. Separate Base, Parent, and Implementation classes
-
-```python
-from abc import ABC, abstractmethod
-from sqlmodels.base import SQLModelBase
-from sqlmodels.mixin import UUIDTableBaseMixin, PolymorphicBaseMixin
-
-# ✅ Recommended structure
-class ASRBase(SQLModelBase):
- """Pure data fields, no table"""
- name: str
- base_url: str
-
-class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
- """Abstract parent with table"""
- @abstractmethod
- async def transcribe(self, audio: bytes) -> str:
- pass
-
-class WhisperASR(ASR, table=True):
- """Concrete implementation"""
- model_size: str
-
- async def transcribe(self, audio: bytes) -> str:
- # Implementation
- pass
-```
-
-**Why?**
-- Base class can be reused for DTOs
-- Parent class defines the polymorphic hierarchy
-- Implementation classes are clean and focused
-
----
-
-## Troubleshooting
-
-### Issue: ValueError: X has no matching SQLAlchemy type
-
-**Solution**: Ensure your custom type provides `__sqlmodel_sa_type__` attribute or proper Pydantic metadata with `sa_type`.
-
-```python
-# ✅ Provide __sqlmodel_sa_type__
-class MyType:
- __sqlmodel_sa_type__ = MyCustomSQLAlchemyType()
-```
-
-### Issue: Can't generate DDL for NullType()
-
-**Symptoms**: Error during table creation saying a column has `NullType`.
-
-**Root Cause**: Custom type's `sa_type` not detected by SQLModel.
-
-**Solution**:
-1. Ensure your type has `__sqlmodel_sa_type__` class attribute
-2. Check that the monkey-patch is active (`sys.version_info >= (3, 14)`)
-3. Verify type annotation is correct (not a string forward reference)
-
-```python
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-# ✅ Correct
-class Model(UUIDTableBaseMixin, table=True):
- data: NumpyVector[256, np.float32] # __sqlmodel_sa_type__ detected
-
-# ❌ Wrong (string annotation)
-class Model(UUIDTableBaseMixin, table=True):
- data: 'NumpyVector[256, np.float32]' # sa_type lost
-```
-
-### Issue: Polymorphic identity conflicts
-
-**Symptoms**: SQLAlchemy raises errors about duplicate polymorphic identities.
-
-**Solution**:
-1. Check that each concrete class has a unique identity
-2. Use `AutoPolymorphicIdentityMixin` for automatic naming
-3. Manually specify identity if needed:
- ```python
- class MyClass(Parent, polymorphic_identity='unique.name', table=True):
- pass
- ```
-
-### Issue: Python 3.14 annotation errors
-
-**Symptoms**: Errors related to `__annotations__` or type resolution.
-
-**Solution**: The implementation uses `get_type_hints()` which handles PEP 649 automatically. If issues persist:
-1. Check for manual `__annotations__` manipulation (avoid it)
-2. Ensure all types are properly imported
-3. Avoid `from __future__ import annotations` (can cause SQLModel issues)
-
-### Issue: Polymorphic and CRUD-related errors
-
-For issues related to polymorphic inheritance, CRUD operations, or table mixins, see the troubleshooting section in [`sqlmodels/mixin/README.md`](../mixin/README.md).
-
----
-
-## Implementation Details
-
-For developers modifying this module:
-
-**Core files**:
-- `sqlmodel_base.py` - Contains `__DeclarativeMeta` and `SQLModelBase`
-- `../mixin/table.py` - Contains `TableBaseMixin` and `UUIDTableBaseMixin`
-- `../mixin/polymorphic.py` - Contains `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`
-
-**Key functions in this module**:
-
-1. **`_resolve_annotations(attrs: dict[str, Any])`**
- - Uses `typing.get_type_hints()` for Python 3.14 compatibility
- - Returns tuple: `(annotations, annotation_strings, globalns, localns)`
- - Preserves `Annotated` metadata with `include_extras=True`
-
-2. **`_extract_sa_type_from_annotation(annotation: Any) -> Any | None`**
- - Extracts SQLAlchemy type from type annotations
- - Supports `__sqlmodel_sa_type__`, `Annotated`, and Pydantic core schema
- - Called by metaclass during class creation
-
-3. **`_patched_get_sqlalchemy_type(field)`** (Python 3.14+)
- - Global monkey-patch for SQLModel
- - Checks `__sqlmodel_sa_type__` before falling back to original logic
- - Handles custom types like `NumpyVector` and `Array`
-
-4. **`__DeclarativeMeta.__new__()`**
- - Processes class definition parameters
- - Injects `sa_type` into field definitions
- - Sets up `__mapper_args__`, `__table_args__`, etc.
- - Handles Python 3.14 annotations via `get_type_hints()`
-
-**Metaclass processing order**:
-1. Check if class should be a table (`_has_table_mixin`)
-2. Collect `__mapper_args__` from kwargs and explicit dict
-3. Process `table_args`, `table_name`, `abstract` parameters
-4. Resolve annotations using `get_type_hints()`
-5. For each field, try to extract `sa_type` and inject into Field
-6. Call parent metaclass with cleaned kwargs
-
-For table mixin implementation details, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
-
----
-
-## See Also
-
-**Project Documentation**:
-- [SQLModel Mixin Documentation](../mixin/README.md) - Table mixins, CRUD operations, polymorphic patterns
-- [Project Coding Standards (CLAUDE.md)](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/CLAUDE.md)
-- [Custom SQLModel Types Guide](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/sqlmodels/sqlmodel_types/README.md)
-
-**External References**:
-- [SQLAlchemy Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
-- [Pydantic V2 Documentation](https://docs.pydantic.dev/latest/)
-- [SQLModel Documentation](https://sqlmodel.tiangolo.com/)
-- [PEP 649: Deferred Evaluation of Annotations](https://peps.python.org/pep-0649/)
-- [PEP 749: Implementing PEP 649](https://peps.python.org/pep-0749/)
-- [Python Annotations Best Practices](https://docs.python.org/3/howto/annotations.html)
diff --git a/sqlmodels/base/__init__.py b/sqlmodels/base/__init__.py
deleted file mode 100644
index 91e3cb6..0000000
--- a/sqlmodels/base/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""
-SQLModel 基础模块
-
-包含:
-- SQLModelBase: 所有 SQLModel 类的基类(真正的基类)
-
-注意:
- TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 sqlmodels.mixin
- 为了避免循环导入,此处不再重新导出它们
- 请直接从 sqlmodels.mixin 导入这些类
-"""
-from .sqlmodel_base import SQLModelBase
diff --git a/sqlmodels/base/sqlmodel_base.py b/sqlmodels/base/sqlmodel_base.py
deleted file mode 100644
index e07b90c..0000000
--- a/sqlmodels/base/sqlmodel_base.py
+++ /dev/null
@@ -1,846 +0,0 @@
-import sys
-import typing
-from typing import Any, Mapping, get_args, get_origin, get_type_hints
-
-from pydantic import ConfigDict
-from pydantic.fields import FieldInfo
-from pydantic_core import PydanticUndefined as Undefined
-from sqlalchemy.orm import Mapped
-from sqlmodel import Field, SQLModel
-from sqlmodel.main import SQLModelMetaclass
-
-# Python 3.14+ PEP 649支持
-if sys.version_info >= (3, 14):
- import annotationlib
-
- # 全局Monkey-patch: 修复SQLModel在Python 3.14上的兼容性问题
- import sqlmodel.main
- _original_get_sqlalchemy_type = sqlmodel.main.get_sqlalchemy_type
-
- def _patched_get_sqlalchemy_type(field):
- """
- 修复SQLModel的get_sqlalchemy_type函数,处理Python 3.14的类型问题。
-
- 问题:
- 1. ForwardRef对象(来自Relationship字段)会导致issubclass错误
- 2. typing._GenericAlias对象(如ClassVar[T])也会导致同样问题
- 3. list/dict等泛型类型在没有Field/Relationship时可能导致错误
- 4. Mapped类型在Python 3.14下可能出现在annotation中
- 5. Annotated类型可能包含sa_type metadata(如Array[T])
- 6. 自定义类型(如NumpyVector)有__sqlmodel_sa_type__属性
- 7. Pydantic已处理的Annotated类型会将metadata存储在field.metadata中
-
- 解决:
- - 优先检查field.metadata中的__get_pydantic_core_schema__(Pydantic已处理的情况)
- - 检测__sqlmodel_sa_type__属性(NumpyVector等)
- - 检测Relationship/ClassVar等返回None
- - 对于Annotated类型,尝试提取sa_type metadata
- - 其他情况调用原始函数
- """
- # 优先检查 field.metadata(Pydantic已处理Annotated类型的情况)
- # 当使用 Array[T] 或 Annotated[T, metadata] 时,Pydantic会将metadata存储在这里
- metadata = getattr(field, 'metadata', None)
- if metadata:
- # metadata是一个列表,包含所有Annotated的元数据项
- for metadata_item in metadata:
- # 检查metadata_item是否有__get_pydantic_core_schema__方法
- if hasattr(metadata_item, '__get_pydantic_core_schema__'):
- try:
- # 调用获取schema
- schema = metadata_item.__get_pydantic_core_schema__(None, None)
- # 检查schema的metadata中是否有sa_type
- if isinstance(schema, dict) and 'metadata' in schema:
- sa_type = schema['metadata'].get('sa_type')
- if sa_type is not None:
- return sa_type
- except (TypeError, AttributeError, KeyError):
- # Pydantic schema获取可能失败(类型不匹配、缺少属性等)
- # 这是正常情况,继续检查下一个metadata项
- pass
-
- annotation = getattr(field, 'annotation', None)
- if annotation is not None:
- # 优先检查 __sqlmodel_sa_type__ 属性
- # 这处理 NumpyVector[dims, dtype] 等自定义类型
- if hasattr(annotation, '__sqlmodel_sa_type__'):
- return annotation.__sqlmodel_sa_type__
-
- # 检查自定义类型(如JSON100K)的 __get_pydantic_core_schema__ 方法
- # 这些类型在schema的metadata中定义sa_type
- if hasattr(annotation, '__get_pydantic_core_schema__'):
- try:
- # 调用获取schema(传None作为handler,因为我们只需要metadata)
- schema = annotation.__get_pydantic_core_schema__(annotation, lambda x: None)
- # 检查schema的metadata中是否有sa_type
- if isinstance(schema, dict) and 'metadata' in schema:
- sa_type = schema['metadata'].get('sa_type')
- if sa_type is not None:
- return sa_type
- except (TypeError, AttributeError, KeyError):
- # Schema获取失败,继续其他检查
- pass
-
- anno_type_name = type(annotation).__name__
-
- # ForwardRef: Relationship字段的annotation
- if anno_type_name == 'ForwardRef':
- return None
-
- # AnnotatedAlias: 检查是否有sa_type metadata(如Array[T])
- if anno_type_name == 'AnnotatedAlias' or anno_type_name == '_AnnotatedAlias':
- from typing import get_origin, get_args
- import typing
-
- # 尝试提取Annotated的metadata
- if hasattr(typing, 'get_args'):
- args = get_args(annotation)
- # args[0]是实际类型,args[1:]是metadata
- for metadata in args[1:]:
- # 检查metadata是否有__get_pydantic_core_schema__方法
- if hasattr(metadata, '__get_pydantic_core_schema__'):
- try:
- # 调用获取schema
- schema = metadata.__get_pydantic_core_schema__(None, None)
- # 检查schema中是否有sa_type
- if isinstance(schema, dict) and 'metadata' in schema:
- sa_type = schema['metadata'].get('sa_type')
- if sa_type is not None:
- return sa_type
- except (TypeError, AttributeError, KeyError):
- # Annotated metadata的schema获取可能失败
- # 这是正常的类型检查过程,继续检查下一个metadata
- pass
-
- # _GenericAlias或GenericAlias: typing泛型类型
- if anno_type_name in ('_GenericAlias', 'GenericAlias'):
- from typing import get_origin
- import typing
- origin = get_origin(annotation)
-
- # ClassVar必须跳过
- if origin is typing.ClassVar:
- return None
-
- # list/dict/tuple/set等内置泛型,如果字段没有明确的Field或Relationship,也跳过
- # 这通常意味着它是Relationship字段或类变量
- if origin in (list, dict, tuple, set):
- # 检查field_info是否存在且有意义
- # Relationship字段会有特殊的field_info
- field_info = getattr(field, 'field_info', None)
- if field_info is None:
- return None
-
- # Mapped: SQLAlchemy 2.0的Mapped类型,SQLModel不应该处理
- # 这可能是从父类继承的字段或Python 3.14注解处理的副作用
- # 检查类型名称和annotation的字符串表示
- if 'Mapped' in anno_type_name or 'Mapped' in str(annotation):
- return None
-
- # 检查annotation是否是Mapped类或其实例
- try:
- from sqlalchemy.orm import Mapped as SAMapped
- # 检查origin(对于Mapped[T]这种泛型)
- from typing import get_origin
- if get_origin(annotation) is SAMapped:
- return None
- # 检查类型本身
- if annotation is SAMapped or isinstance(annotation, type) and issubclass(annotation, SAMapped):
- return None
- except (ImportError, TypeError):
- # 如果SQLAlchemy没有Mapped或检查失败,继续
- pass
-
- # 其他情况正常处理
- return _original_get_sqlalchemy_type(field)
-
- sqlmodel.main.get_sqlalchemy_type = _patched_get_sqlalchemy_type
-
- # 第二个Monkey-patch: 修复继承表类中InstrumentedAttribute作为默认值的问题
- # 在Python 3.14 + SQLModel组合下,当子类(如SMSBaoProvider)继承父类(如VerificationCodeProvider)时,
- # 父类的关系字段(如server_config)会在子类的model_fields中出现,
- # 但其default值错误地设置为InstrumentedAttribute对象,而不是None
- # 这导致实例化时尝试设置InstrumentedAttribute为字段值,触发SQLAlchemy内部错误
- import sqlmodel._compat as _compat
- from sqlalchemy.orm import attributes as _sa_attributes
-
- _original_sqlmodel_table_construct = _compat.sqlmodel_table_construct
-
- def _patched_sqlmodel_table_construct(self_instance, values):
- """
- 修复sqlmodel_table_construct,跳过InstrumentedAttribute默认值
-
- 问题:
- - 继承自polymorphic基类的表类(如FishAudioTTS, SMSBaoProvider)
- - 其model_fields中的继承字段default值为InstrumentedAttribute
- - 原函数尝试将InstrumentedAttribute设置为字段值
- - SQLAlchemy无法处理,抛出 '_sa_instance_state' 错误
-
- 解决:
- - 只设置用户提供的值和非InstrumentedAttribute默认值
- - InstrumentedAttribute默认值跳过(让SQLAlchemy自己处理)
- """
- cls = type(self_instance)
-
- # 收集要设置的字段值
- fields_to_set = {}
-
- for name, field in cls.model_fields.items():
- # 如果用户提供了值,直接使用
- if name in values:
- fields_to_set[name] = values[name]
- continue
-
- # 否则检查默认值
- # 跳过InstrumentedAttribute默认值 - 这些是继承字段的错误默认值
- if isinstance(field.default, _sa_attributes.InstrumentedAttribute):
- continue
-
- # 使用正常的默认值
- if field.default is not Undefined:
- fields_to_set[name] = field.default
- elif field.default_factory is not None:
- fields_to_set[name] = field.get_default(call_default_factory=True)
-
- # 设置属性 - 只设置非InstrumentedAttribute值
- for key, value in fields_to_set.items():
- if not isinstance(value, _sa_attributes.InstrumentedAttribute):
- setattr(self_instance, key, value)
-
- # 设置Pydantic内部属性
- object.__setattr__(self_instance, '__pydantic_fields_set__', set(values.keys()))
- if not cls.__pydantic_root_model__:
- _extra = None
- if cls.model_config.get('extra') == 'allow':
- _extra = {}
- for k, v in values.items():
- if k not in cls.model_fields:
- _extra[k] = v
- object.__setattr__(self_instance, '__pydantic_extra__', _extra)
-
- if cls.__pydantic_post_init__:
- self_instance.model_post_init(None)
- elif not cls.__pydantic_root_model__:
- object.__setattr__(self_instance, '__pydantic_private__', None)
-
- # 设置关系
- for key in self_instance.__sqlmodel_relationships__:
- value = values.get(key, Undefined)
- if value is not Undefined:
- setattr(self_instance, key, value)
-
- return self_instance
-
- _compat.sqlmodel_table_construct = _patched_sqlmodel_table_construct
-else:
- annotationlib = None
-
-
-def _extract_sa_type_from_annotation(annotation: Any) -> Any | None:
- """
- 从类型注解中提取SQLAlchemy类型。
-
- 支持以下形式:
- 1. NumpyVector[256, np.float32] - 直接使用类型(有__sqlmodel_sa_type__属性)
- 2. Annotated[np.ndarray, NumpyVector[256, np.float32]] - Annotated包装
- 3. 任何有__get_pydantic_core_schema__且返回metadata['sa_type']的类型
-
- Args:
- annotation: 字段的类型注解
-
- Returns:
- 提取到的SQLAlchemy类型,如果没有则返回None
- """
- # 方法1:直接检查类型本身是否有__sqlmodel_sa_type__属性
- # 这涵盖了 NumpyVector[256, np.float32] 这种直接使用的情况
- if hasattr(annotation, '__sqlmodel_sa_type__'):
- return annotation.__sqlmodel_sa_type__
-
- # 方法2:检查是否为Annotated类型
- if get_origin(annotation) is typing.Annotated:
- # 获取元数据项(跳过第一个实际类型参数)
- args = get_args(annotation)
- if len(args) >= 2:
- metadata_items = args[1:] # 第一个是实际类型,后面都是元数据
-
- # 遍历元数据,查找包含sa_type的项
- for item in metadata_items:
- # 检查元数据项是否有__sqlmodel_sa_type__属性
- if hasattr(item, '__sqlmodel_sa_type__'):
- return item.__sqlmodel_sa_type__
-
- # 检查是否有__get_pydantic_core_schema__方法
- if hasattr(item, '__get_pydantic_core_schema__'):
- try:
- # 调用该方法获取core schema
- schema = item.__get_pydantic_core_schema__(
- annotation,
- lambda x: None # 虚拟handler
- )
- # 检查schema的metadata中是否有sa_type
- if isinstance(schema, dict) and 'metadata' in schema:
- sa_type = schema['metadata'].get('sa_type')
- if sa_type is not None:
- return sa_type
- except (TypeError, AttributeError, KeyError, ValueError):
- # Pydantic core schema获取可能失败:
- # - TypeError: 参数不匹配
- # - AttributeError: metadata不存在
- # - KeyError: schema结构不符合预期
- # - ValueError: 无效的类型定义
- # 这是正常的类型探测过程,继续检查下一个metadata项
- pass
-
- # 方法3:检查类型本身是否有__get_pydantic_core_schema__
- # (虽然NumpyVector已经在方法1处理,但这是通用的fallback)
- if hasattr(annotation, '__get_pydantic_core_schema__'):
- try:
- schema = annotation.__get_pydantic_core_schema__(
- annotation,
- lambda x: None # 虚拟handler
- )
- if isinstance(schema, dict) and 'metadata' in schema:
- sa_type = schema['metadata'].get('sa_type')
- if sa_type is not None:
- return sa_type
- except (TypeError, AttributeError, KeyError, ValueError):
- # 类型本身的schema获取失败
- # 这是正常的fallback机制,annotation可能不支持此协议
- pass
-
- return None
-
-
-def _resolve_annotations(attrs: dict[str, Any]) -> tuple[
- dict[str, Any],
- dict[str, str],
- Mapping[str, Any],
- Mapping[str, Any],
-]:
- """
- Resolve annotations from a class namespace with Python 3.14 (PEP 649) support.
-
- This helper prefers evaluated annotations (Format.VALUE) so that `typing.Annotated`
- metadata and custom types remain accessible. Forward references that cannot be
- evaluated are replaced with typing.ForwardRef placeholders to avoid aborting the
- whole resolution process.
- """
- raw_annotations = attrs.get('__annotations__') or {}
- try:
- base_annotations = dict(raw_annotations)
- except TypeError:
- base_annotations = {}
-
- module_name = attrs.get('__module__')
- module_globals: dict[str, Any]
- if module_name and module_name in sys.modules:
- module_globals = dict(sys.modules[module_name].__dict__)
- else:
- module_globals = {}
-
- module_globals.setdefault('__builtins__', __builtins__)
- localns: dict[str, Any] = dict(attrs)
-
- try:
- temp_cls = type('AnnotationProxy', (object,), dict(attrs))
- temp_cls.__module__ = module_name
- extras_kw = {'include_extras': True} if sys.version_info >= (3, 10) else {}
- evaluated = get_type_hints(
- temp_cls,
- globalns=module_globals,
- localns=localns,
- **extras_kw,
- )
- except (NameError, AttributeError, TypeError, RecursionError):
- # get_type_hints可能失败的原因:
- # - NameError: 前向引用无法解析(类型尚未定义)
- # - AttributeError: 模块或类型不存在
- # - TypeError: 无效的类型注解
- # - RecursionError: 循环依赖的类型定义
- # 这是正常情况,回退到原始注解字符串
- evaluated = base_annotations
-
- return dict(evaluated), {}, module_globals, localns
-
-
-def _evaluate_annotation_from_string(
- field_name: str,
- annotation_strings: dict[str, str],
- current_type: Any,
- globalns: Mapping[str, Any],
- localns: Mapping[str, Any],
-) -> Any:
- """
- Attempt to re-evaluate the original annotation string for a field.
-
- This is used as a fallback when the resolved annotation lost its metadata
- (e.g., Annotated wrappers) and we need to recover custom sa_type data.
- """
- if not annotation_strings:
- return current_type
-
- expr = annotation_strings.get(field_name)
- if not expr or not isinstance(expr, str):
- return current_type
-
- try:
- return eval(expr, globalns, localns)
- except (NameError, SyntaxError, AttributeError, TypeError):
- # eval可能失败的原因:
- # - NameError: 类型名称在namespace中不存在
- # - SyntaxError: 注解字符串有语法错误
- # - AttributeError: 访问不存在的模块属性
- # - TypeError: 无效的类型表达式
- # 这是正常的fallback机制,返回当前已解析的类型
- return current_type
-
-
-class __DeclarativeMeta(SQLModelMetaclass):
- """
- 一个智能的混合模式元类,它提供了灵活性和清晰度:
-
- 1. **自动设置 `table=True`**: 如果一个类继承了 `TableBaseMixin`,则自动应用 `table=True`。
- 2. **明确的字典参数**: 支持 `mapper_args={...}`, `table_args={...}`, `table_name='...'`。
- 3. **便捷的关键字参数**: 支持最常见的 mapper 参数作为顶级关键字(如 `polymorphic_on`)。
- 4. **智能合并**: 当字典和关键字同时提供时,会自动合并,且关键字参数有更高优先级。
- """
-
- _KNOWN_MAPPER_KEYS = {
- "polymorphic_on",
- "polymorphic_identity",
- "polymorphic_abstract",
- "version_id_col",
- "concrete",
- }
-
- def __new__(cls, name, bases, attrs, **kwargs):
- # 1. 约定优于配置:自动设置 table=True
- is_intended_as_table = any(getattr(b, '_has_table_mixin', False) for b in bases)
- if is_intended_as_table and 'table' not in kwargs:
- kwargs['table'] = True
-
- # 2. 智能合并 __mapper_args__
- collected_mapper_args = {}
-
- # 首先,处理明确的 mapper_args 字典 (优先级较低)
- if 'mapper_args' in kwargs:
- collected_mapper_args.update(kwargs.pop('mapper_args'))
-
- # 其次,处理便捷的关键字参数 (优先级更高)
- for key in cls._KNOWN_MAPPER_KEYS:
- if key in kwargs:
- # .pop() 获取值并移除,避免传递给父类
- collected_mapper_args[key] = kwargs.pop(key)
-
- # 如果收集到了任何 mapper 参数,则更新到类的属性中
- if collected_mapper_args:
- existing = attrs.get('__mapper_args__', {}).copy()
- existing.update(collected_mapper_args)
- attrs['__mapper_args__'] = existing
-
- # 3. 处理其他明确的参数
- if 'table_args' in kwargs:
- attrs['__table_args__'] = kwargs.pop('table_args')
- if 'table_name' in kwargs:
- attrs['__tablename__'] = kwargs.pop('table_name')
- if 'abstract' in kwargs:
- attrs['__abstract__'] = kwargs.pop('abstract')
-
- # 4. 从Annotated元数据中提取sa_type并注入到Field
- # 重要:必须在调用父类__new__之前处理,因为SQLModel会消费annotations
- #
- # Python 3.14兼容性问题:
- # - SQLModel在Python 3.14上会因为ClassVar[T]类型而崩溃(issubclass错误)
- # - 我们必须在SQLModel看到annotations之前过滤掉ClassVar字段
- # - 虽然PEP 749建议不修改__annotations__,但这是修复SQLModel bug的必要措施
- #
- # 获取annotations的策略:
- # - Python 3.14+: 优先从__annotate__获取(如果存在)
- # - fallback: 从__annotations__读取(如果存在)
- # - 最终fallback: 空字典
- annotations, annotation_strings, eval_globals, eval_locals = _resolve_annotations(attrs)
-
- if annotations:
- attrs['__annotations__'] = annotations
- if annotationlib is not None:
- # 在Python 3.14中禁用descriptor,转为普通dict
- attrs['__annotate__'] = None
-
- for field_name, field_type in annotations.items():
- field_type = _evaluate_annotation_from_string(
- field_name,
- annotation_strings,
- field_type,
- eval_globals,
- eval_locals,
- )
-
- # 跳过字符串或ForwardRef类型注解,让SQLModel自己处理
- if isinstance(field_type, str) or isinstance(field_type, typing.ForwardRef):
- continue
-
- # 跳过特殊类型的字段
- origin = get_origin(field_type)
-
- # 跳过 ClassVar 字段 - 它们不是数据库字段
- if origin is typing.ClassVar:
- continue
-
- # 跳过 Mapped 字段 - SQLAlchemy 2.0+ 的声明式字段,已经有 mapped_column
- if origin is Mapped:
- continue
-
- # 尝试从注解中提取sa_type
- sa_type = _extract_sa_type_from_annotation(field_type)
-
- if sa_type is not None:
- # 检查字段是否已有Field定义
- field_value = attrs.get(field_name, Undefined)
-
- if field_value is Undefined:
- # 没有Field定义,创建一个新的Field并注入sa_type
- attrs[field_name] = Field(sa_type=sa_type)
- elif isinstance(field_value, FieldInfo):
- # 已有Field定义,检查是否已设置sa_type
- # 注意:只有在未设置时才注入,尊重显式配置
- # SQLModel使用Undefined作为"未设置"的标记
- if not hasattr(field_value, 'sa_type') or field_value.sa_type is Undefined:
- field_value.sa_type = sa_type
- # 如果field_value是其他类型(如默认值),不处理
- # SQLModel会在后续处理中将其转换为Field
-
- # 5. 调用父类的 __new__ 方法,传入被清理过的 kwargs
- result = super().__new__(cls, name, bases, attrs, **kwargs)
-
- # 6. 修复:在联表继承场景下,继承父类的 __sqlmodel_relationships__
- # SQLModel 为每个 table=True 的类创建新的空 __sqlmodel_relationships__
- # 这导致子类丢失父类的关系定义,触发错误的 Column 创建
- # 必须在 super().__new__() 之后修复,因为 SQLModel 会覆盖我们预设的值
- if kwargs.get('table', False):
- for base in bases:
- if hasattr(base, '__sqlmodel_relationships__'):
- for rel_name, rel_info in base.__sqlmodel_relationships__.items():
- # 只继承子类没有重新定义的关系
- if rel_name not in result.__sqlmodel_relationships__:
- result.__sqlmodel_relationships__[rel_name] = rel_info
- # 同时修复被错误创建的 Column - 恢复为父类的 relationship
- if hasattr(base, rel_name):
- base_attr = getattr(base, rel_name)
- setattr(result, rel_name, base_attr)
-
- # 7. 检测:禁止子类重定义父类的 Relationship 字段
- # 子类重定义同名的 Relationship 字段会导致 SQLAlchemy 关系映射混乱,
- # 应该在类定义时立即报错,而不是在运行时出现难以调试的问题。
- for base in bases:
- parent_relationships = getattr(base, '__sqlmodel_relationships__', {})
- for rel_name in parent_relationships:
- # 检查当前类是否在 attrs 中重新定义了这个关系字段
- if rel_name in attrs:
- raise TypeError(
- f"类 {name} 不允许重定义父类 {base.__name__} 的 Relationship 字段 '{rel_name}'。"
- f"如需修改关系配置,请在父类中修改。"
- )
-
- # 8. 修复:从 model_fields/__pydantic_fields__ 中移除 Relationship 字段
- # SQLModel 0.0.27 bug:子类会错误地继承父类的 Relationship 字段到 model_fields
- # 这导致 Pydantic 尝试为 Relationship 字段生成 schema,因为类型是
- # Mapped[list['Character']] 这种前向引用,Pydantic 无法解析,
- # 导致 __pydantic_complete__ = False
- #
- # 修复策略:
- # - 检查类的 __sqlmodel_relationships__ 属性
- # - 从 model_fields 和 __pydantic_fields__ 中移除这些字段
- # - Relationship 字段由 SQLAlchemy 管理,不需要 Pydantic 参与
- relationships = getattr(result, '__sqlmodel_relationships__', {})
- if relationships:
- model_fields = getattr(result, 'model_fields', {})
- pydantic_fields = getattr(result, '__pydantic_fields__', {})
-
- fields_removed = False
- for rel_name in relationships:
- if rel_name in model_fields:
- del model_fields[rel_name]
- fields_removed = True
- if rel_name in pydantic_fields:
- del pydantic_fields[rel_name]
- fields_removed = True
-
- # 如果移除了字段,重新构建 Pydantic 模式
- # 注意:只在有字段被移除时才 rebuild,避免不必要的开销
- if fields_removed and hasattr(result, 'model_rebuild'):
- result.model_rebuild(force=True)
-
- return result
-
- def __init__(
- cls,
- classname: str,
- bases: tuple[type, ...],
- dict_: dict[str, typing.Any],
- **kw: typing.Any,
- ) -> None:
- """
- 重写 SQLModel 的 __init__ 以支持联表继承(Joined Table Inheritance)
-
- SQLModel 原始行为:
- - 如果任何基类是表模型,则不调用 DeclarativeMeta.__init__
- - 这阻止了子类创建自己的表
-
- 修复逻辑:
- - 检测联表继承场景(子类有自己的 __tablename__ 且有外键指向父表)
- - 强制调用 DeclarativeMeta.__init__ 来创建子表
- """
- from sqlmodel.main import is_table_model_class, DeclarativeMeta, ModelMetaclass
-
- # 检查是否是表模型
- if not is_table_model_class(cls):
- ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
- return
-
- # 检查是否有基类是表模型
- base_is_table = any(is_table_model_class(base) for base in bases)
-
- if not base_is_table:
- # 没有基类是表模型,走正常的 SQLModel 流程
- # 处理关系字段
- cls._setup_relationships()
- DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
- return
-
- # 关键:检测联表继承场景
- # 条件:
- # 1. 当前类的 __tablename__ 与父类不同(表示需要新表)
- # 2. 当前类有字段带有 foreign_key 指向父表
- current_tablename = getattr(cls, '__tablename__', None)
-
- # 查找父表信息
- parent_table = None
- parent_tablename = None
- for base in bases:
- if is_table_model_class(base) and hasattr(base, '__tablename__'):
- parent_tablename = base.__tablename__
- break
-
- # 检查是否有不同的 tablename
- has_different_tablename = (
- current_tablename is not None
- and parent_tablename is not None
- and current_tablename != parent_tablename
- )
-
- # 检查是否有外键字段指向父表的主键
- # 注意:由于字段合并,我们需要检查直接基类的 model_fields
- # 而不是当前类的合并后的 model_fields
- has_fk_to_parent = False
-
- def _normalize_tablename(name: str) -> str:
- """标准化表名以进行比较(移除下划线,转小写)"""
- return name.replace('_', '').lower()
-
- def _fk_matches_parent(fk_str: str, parent_table: str) -> bool:
- """检查 FK 字符串是否指向父表"""
- if not fk_str or not parent_table:
- return False
- # FK 格式: "tablename.column" 或 "schema.tablename.column"
- parts = fk_str.split('.')
- if len(parts) >= 2:
- fk_table = parts[-2] # 取倒数第二个作为表名
- # 标准化比较(处理下划线差异)
- return _normalize_tablename(fk_table) == _normalize_tablename(parent_table)
- return False
-
- if has_different_tablename and parent_tablename:
- # 首先检查当前类的 model_fields
- for field_name, field_info in cls.model_fields.items():
- fk = getattr(field_info, 'foreign_key', None)
- if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
- has_fk_to_parent = True
- break
-
- # 如果没找到,检查直接基类的 model_fields(解决 mixin 字段被覆盖的问题)
- if not has_fk_to_parent:
- for base in bases:
- if hasattr(base, 'model_fields'):
- for field_name, field_info in base.model_fields.items():
- fk = getattr(field_info, 'foreign_key', None)
- if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
- has_fk_to_parent = True
- break
- if has_fk_to_parent:
- break
-
- is_joined_inheritance = has_different_tablename and has_fk_to_parent
-
- if is_joined_inheritance:
- # 联表继承:需要创建子表
-
- # 修复外键字段:由于字段合并,外键信息可能丢失
- # 需要从基类的 mixin 中找回外键信息,并重建列
- from sqlalchemy import Column, ForeignKey, inspect as sa_inspect
- from sqlalchemy.dialects.postgresql import UUID as SA_UUID
- from sqlalchemy.exc import NoInspectionAvailable
- from sqlalchemy.orm.attributes import InstrumentedAttribute
-
- # 联表继承:子表只应该有 id(FK 到父表)+ 子类特有的字段
- # 所有继承自祖先表的列都不应该在子表中重复创建
-
- # 收集整个继承链中所有祖先表的列名(这些列不应该在子表中重复)
- # 需要遍历整个 MRO,因为可能是多级继承(如 Tool -> Function -> GetWeatherFunction)
- ancestor_column_names: set[str] = set()
- for ancestor in cls.__mro__:
- if ancestor is cls:
- continue # 跳过当前类
- if is_table_model_class(ancestor):
- try:
- # 使用 inspect() 获取 mapper 的公开属性
- # 源码确认: mapper.local_table 是公开属性 (mapper.py:979-998)
- mapper = sa_inspect(ancestor)
- for col in mapper.local_table.columns:
- # 跳过 _polymorphic_name 列(鉴别器,由根父表管理)
- if col.name.startswith('_polymorphic'):
- continue
- ancestor_column_names.add(col.name)
- except NoInspectionAvailable:
- continue
-
- # 找到子类自己定义的字段(不在父类中的)
- child_own_fields: set[str] = set()
- for field_name in cls.model_fields:
- # 检查这个字段是否是在当前类直接定义的(不是继承的)
- # 通过检查父类是否有这个字段来判断
- is_inherited = False
- for base in bases:
- if hasattr(base, 'model_fields') and field_name in base.model_fields:
- is_inherited = True
- break
- if not is_inherited:
- child_own_fields.add(field_name)
-
- # 从子类类属性中移除父表已有的列定义
- # 这样 SQLAlchemy 就不会在子表中创建这些列
- fk_field_name = None
- for base in bases:
- if hasattr(base, 'model_fields'):
- for field_name, field_info in base.model_fields.items():
- fk = getattr(field_info, 'foreign_key', None)
- pk = getattr(field_info, 'primary_key', False)
- if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
- fk_field_name = field_name
- # 找到了外键字段,重建它
- # 创建一个新的 Column 对象包含外键约束
- new_col = Column(
- field_name,
- SA_UUID(as_uuid=True),
- ForeignKey(fk),
- primary_key=pk if pk else False
- )
- setattr(cls, field_name, new_col)
- break
- else:
- continue
- break
-
- # 移除继承自祖先表的列属性(除了 FK/PK 和子类自己的字段)
- # 这防止 SQLAlchemy 在子表中创建重复列
- # 注意:在 __init__ 阶段,列是 Column 对象,不是 InstrumentedAttribute
- for col_name in ancestor_column_names:
- if col_name == fk_field_name:
- continue # 保留 FK/PK 列(子表的主键,同时是父表的外键)
- if col_name == 'id':
- continue # id 会被 FK 字段覆盖
- if col_name in child_own_fields:
- continue # 保留子类自己定义的字段
-
- # 检查类属性是否是 Column 或 InstrumentedAttribute
- if col_name in cls.__dict__:
- attr = cls.__dict__[col_name]
- # Column 对象或 InstrumentedAttribute 都需要删除
- if isinstance(attr, (Column, InstrumentedAttribute)):
- try:
- delattr(cls, col_name)
- except AttributeError:
- pass
-
- # 找到子类自己定义的关系(不在父类中的)
- # 继承的关系会从父类自动获取,只需要设置子类新增的关系
- child_own_relationships: set[str] = set()
- for rel_name in cls.__sqlmodel_relationships__:
- is_inherited = False
- for base in bases:
- if hasattr(base, '__sqlmodel_relationships__') and rel_name in base.__sqlmodel_relationships__:
- is_inherited = True
- break
- if not is_inherited:
- child_own_relationships.add(rel_name)
-
- # 只为子类自己定义的新关系调用关系设置
- if child_own_relationships:
- cls._setup_relationships(only_these=child_own_relationships)
-
- # 强制调用 DeclarativeMeta.__init__
- DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
- else:
- # 非联表继承:单表继承或正常 Pydantic 模型
- ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
-
- def _setup_relationships(cls, only_these: set[str] | None = None) -> None:
- """
- 设置 SQLAlchemy 关系字段(从 SQLModel 源码复制)
-
- Args:
- only_these: 如果提供,只设置这些关系(用于 joined table inheritance 子类)
- 如果为 None,设置所有关系(默认行为)
- """
- from sqlalchemy.orm import relationship, Mapped
- from sqlalchemy import inspect
- from sqlmodel.main import get_relationship_to
- from typing import get_origin
-
- for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
- # 如果指定了 only_these,只设置这些关系
- if only_these is not None and rel_name not in only_these:
- continue
- if rel_info.sa_relationship:
- setattr(cls, rel_name, rel_info.sa_relationship)
- continue
-
- raw_ann = cls.__annotations__[rel_name]
- origin: typing.Any = get_origin(raw_ann)
- if origin is Mapped:
- ann = raw_ann.__args__[0]
- else:
- ann = raw_ann
- cls.__annotations__[rel_name] = Mapped[ann]
-
- relationship_to = get_relationship_to(
- name=rel_name, rel_info=rel_info, annotation=ann
- )
- rel_kwargs: dict[str, typing.Any] = {}
- if rel_info.back_populates:
- rel_kwargs["back_populates"] = rel_info.back_populates
- if rel_info.cascade_delete:
- rel_kwargs["cascade"] = "all, delete-orphan"
- if rel_info.passive_deletes:
- rel_kwargs["passive_deletes"] = rel_info.passive_deletes
- if rel_info.link_model:
- ins = inspect(rel_info.link_model)
- local_table = getattr(ins, "local_table")
- if local_table is None:
- raise RuntimeError(
- f"Couldn't find secondary table for {rel_info.link_model}"
- )
- rel_kwargs["secondary"] = local_table
-
- rel_args: list[typing.Any] = []
- if rel_info.sa_relationship_args:
- rel_args.extend(rel_info.sa_relationship_args)
- if rel_info.sa_relationship_kwargs:
- rel_kwargs.update(rel_info.sa_relationship_kwargs)
-
- rel_value = relationship(relationship_to, *rel_args, **rel_kwargs)
- setattr(cls, rel_name, rel_value)
-
-
-class SQLModelBase(SQLModel, metaclass=__DeclarativeMeta):
- """此类必须和TableBase系列类搭配使用"""
-
- model_config = ConfigDict(use_attribute_docstrings=True, validate_by_name=True)
diff --git a/sqlmodels/color.py b/sqlmodels/color.py
index bc88dec..be56884 100644
--- a/sqlmodels/color.py
+++ b/sqlmodels/color.py
@@ -1,6 +1,6 @@
from enum import StrEnum
-from .base import SQLModelBase
+from sqlmodel_ext import SQLModelBase
class ChromaticColor(StrEnum):
diff --git a/sqlmodels/database.py b/sqlmodels/database.py
deleted file mode 100644
index cef7602..0000000
--- a/sqlmodels/database.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from sqlmodel import SQLModel
-from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
-from sqlmodel.ext.asyncio.session import AsyncSession
-from utils.conf import appmeta
-from sqlalchemy.orm import sessionmaker
-from typing import AsyncGenerator
-
-ASYNC_DATABASE_URL = appmeta.database_url
-
-engine: AsyncEngine = create_async_engine(
- ASYNC_DATABASE_URL,
- echo=appmeta.debug,
- connect_args={
- "check_same_thread": False
- } if ASYNC_DATABASE_URL.startswith("sqlite") else {},
- future=True,
- # pool_size=POOL_SIZE,
- # max_overflow=64,
-)
-
-_async_session_factory = sessionmaker(engine, class_=AsyncSession)
-
-async def get_session() -> AsyncGenerator[AsyncSession, None]:
- async with _async_session_factory() as session:
- yield session
-
-async def init_db(
- url: str = ASYNC_DATABASE_URL
-):
- """创建数据库结构"""
- async with engine.begin() as conn:
- await conn.run_sync(SQLModel.metadata.create_all)
-
\ No newline at end of file
diff --git a/sqlmodels/download.py b/sqlmodels/download.py
index 1440ad4..671a663 100644
--- a/sqlmodels/download.py
+++ b/sqlmodels/download.py
@@ -4,8 +4,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint, Index
-from .base import SQLModelBase
-from .mixin import UUIDTableBaseMixin, TableBaseMixin
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableBaseMixin
if TYPE_CHECKING:
from .user import User
diff --git a/sqlmodels/file_app.py b/sqlmodels/file_app.py
new file mode 100644
index 0000000..f31daf0
--- /dev/null
+++ b/sqlmodels/file_app.py
@@ -0,0 +1,409 @@
+"""
+文件查看器应用模块
+
+提供文件预览应用选择器系统的数据模型和 DTO。
+类似 Android 的"使用什么应用打开"机制:
+- 管理员注册应用(内置/iframe/WOPI)
+- 用户按扩展名查询可用查看器
+- 用户可设置"始终使用"偏好
+- 支持用户组级别的访问控制
+
+架构:
+ FileApp (应用注册表)
+ ├── FileAppExtension (扩展名关联)
+ ├── FileAppGroupLink (用户组访问控制)
+ └── UserFileAppDefault (用户默认偏好)
+"""
+from enum import StrEnum
+from typing import TYPE_CHECKING
+from uuid import UUID
+
+from sqlmodel import Field, Relationship, UniqueConstraint
+
+from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin
+
+if TYPE_CHECKING:
+ from .group import Group
+
+
+# ==================== 枚举 ====================
+
+class FileAppType(StrEnum):
+ """文件应用类型"""
+
+ BUILTIN = "builtin"
+ """前端内置查看器(如 pdf.js, Monaco)"""
+
+ IFRAME = "iframe"
+ """iframe 内嵌第三方服务"""
+
+ WOPI = "wopi"
+ """WOPI 协议(OnlyOffice / Collabora)"""
+
+
+# ==================== Link 表 ====================
+
+class FileAppGroupLink(SQLModelBase, table=True):
+ """应用-用户组访问控制关联表"""
+
+ app_id: UUID = Field(foreign_key="fileapp.id", primary_key=True, ondelete="CASCADE")
+ """关联的应用UUID"""
+
+ group_id: UUID = Field(foreign_key="group.id", primary_key=True, ondelete="CASCADE")
+ """关联的用户组UUID"""
+
+
+# ==================== DTO 模型 ====================
+
+class FileAppSummary(SQLModelBase):
+ """查看器列表项 DTO,用于选择器弹窗"""
+
+ id: UUID
+ """应用UUID"""
+
+ name: str
+ """应用名称"""
+
+ app_key: str
+ """应用唯一标识"""
+
+ type: FileAppType
+ """应用类型"""
+
+ icon: str | None = None
+ """图标名称/URL"""
+
+ description: str | None = None
+ """应用描述"""
+
+ iframe_url_template: str | None = None
+ """iframe URL 模板"""
+
+ wopi_editor_url_template: str | None = None
+ """WOPI 编辑器 URL 模板"""
+
+
+class FileViewersResponse(SQLModelBase):
+ """查看器查询响应 DTO"""
+
+ viewers: list[FileAppSummary] = []
+ """可用查看器列表(已按 priority 排序)"""
+
+ default_viewer_id: UUID | None = None
+ """用户默认查看器UUID(如果已设置"始终使用")"""
+
+
+class SetDefaultViewerRequest(SQLModelBase):
+ """设置默认查看器请求 DTO"""
+
+ extension: str = Field(max_length=20)
+ """文件扩展名(小写,无点号)"""
+
+ app_id: UUID
+ """应用UUID"""
+
+
+class UserFileAppDefaultResponse(SQLModelBase):
+ """用户默认查看器响应 DTO"""
+
+ id: UUID
+ """记录UUID"""
+
+ extension: str
+ """扩展名"""
+
+ app: FileAppSummary
+ """关联的应用摘要"""
+
+
+class FileAppCreateRequest(SQLModelBase):
+ """管理员创建应用请求 DTO"""
+
+ name: str = Field(max_length=100)
+ """应用名称"""
+
+ app_key: str = Field(max_length=50)
+ """应用唯一标识"""
+
+ type: FileAppType
+ """应用类型"""
+
+ icon: str | None = Field(default=None, max_length=255)
+ """图标名称/URL"""
+
+ description: str | None = Field(default=None, max_length=500)
+ """应用描述"""
+
+ is_enabled: bool = True
+ """是否启用"""
+
+ is_restricted: bool = False
+ """是否限制用户组访问"""
+
+ iframe_url_template: str | None = Field(default=None, max_length=1024)
+ """iframe URL 模板"""
+
+ wopi_discovery_url: str | None = Field(default=None, max_length=512)
+ """WOPI 发现端点 URL"""
+
+ wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
+ """WOPI 编辑器 URL 模板"""
+
+ extensions: list[str] = []
+ """关联的扩展名列表"""
+
+ allowed_group_ids: list[UUID] = []
+ """允许访问的用户组UUID列表"""
+
+
+class FileAppUpdateRequest(SQLModelBase):
+ """管理员更新应用请求 DTO(所有字段可选)"""
+
+ name: str | None = Field(default=None, max_length=100)
+ """应用名称"""
+
+ app_key: str | None = Field(default=None, max_length=50)
+ """应用唯一标识"""
+
+ type: FileAppType | None = None
+ """应用类型"""
+
+ icon: str | None = Field(default=None, max_length=255)
+ """图标名称/URL"""
+
+ description: str | None = Field(default=None, max_length=500)
+ """应用描述"""
+
+ is_enabled: bool | None = None
+ """是否启用"""
+
+ is_restricted: bool | None = None
+ """是否限制用户组访问"""
+
+ iframe_url_template: str | None = Field(default=None, max_length=1024)
+ """iframe URL 模板"""
+
+ wopi_discovery_url: str | None = Field(default=None, max_length=512)
+ """WOPI 发现端点 URL"""
+
+ wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
+ """WOPI 编辑器 URL 模板"""
+
+
+class FileAppResponse(SQLModelBase):
+ """管理员应用详情响应 DTO"""
+
+ id: UUID
+ """应用UUID"""
+
+ name: str
+ """应用名称"""
+
+ app_key: str
+ """应用唯一标识"""
+
+ type: FileAppType
+ """应用类型"""
+
+ icon: str | None = None
+ """图标名称/URL"""
+
+ description: str | None = None
+ """应用描述"""
+
+ is_enabled: bool = True
+ """是否启用"""
+
+ is_restricted: bool = False
+ """是否限制用户组访问"""
+
+ iframe_url_template: str | None = None
+ """iframe URL 模板"""
+
+ wopi_discovery_url: str | None = None
+ """WOPI 发现端点 URL"""
+
+ wopi_editor_url_template: str | None = None
+ """WOPI 编辑器 URL 模板"""
+
+ extensions: list[str] = []
+ """关联的扩展名列表"""
+
+ allowed_group_ids: list[UUID] = []
+ """允许访问的用户组UUID列表"""
+
+ @classmethod
+ def from_app(
+ cls,
+ app: "FileApp",
+ extensions: list["FileAppExtension"],
+ group_links: list[FileAppGroupLink],
+ ) -> "FileAppResponse":
+ """从 ORM 对象构建 DTO"""
+ return cls(
+ id=app.id,
+ name=app.name,
+ app_key=app.app_key,
+ type=app.type,
+ icon=app.icon,
+ description=app.description,
+ is_enabled=app.is_enabled,
+ is_restricted=app.is_restricted,
+ iframe_url_template=app.iframe_url_template,
+ wopi_discovery_url=app.wopi_discovery_url,
+ wopi_editor_url_template=app.wopi_editor_url_template,
+ extensions=[ext.extension for ext in extensions],
+ allowed_group_ids=[link.group_id for link in group_links],
+ )
+
+
+class FileAppListResponse(SQLModelBase):
+ """管理员应用列表响应 DTO"""
+
+ apps: list[FileAppResponse] = []
+ """应用列表"""
+
+ total: int = 0
+ """总数"""
+
+
+class ExtensionUpdateRequest(SQLModelBase):
+ """扩展名全量替换请求 DTO"""
+
+ extensions: list[str]
+ """扩展名列表(小写,无点号)"""
+
+
+class GroupAccessUpdateRequest(SQLModelBase):
+ """用户组权限全量替换请求 DTO"""
+
+ group_ids: list[UUID]
+ """允许访问的用户组UUID列表"""
+
+
+class WopiSessionResponse(SQLModelBase):
+ """WOPI 会话响应 DTO"""
+
+ wopi_src: str
+ """WOPI 源 URL"""
+
+ access_token: str
+ """WOPI 访问令牌"""
+
+ access_token_ttl: int
+ """令牌过期时间戳(毫秒,WOPI 规范要求)"""
+
+ editor_url: str
+ """完整的编辑器 URL"""
+
+
+# ==================== 数据库模型 ====================
+
+class FileApp(SQLModelBase, UUIDTableBaseMixin):
+ """文件查看器应用注册表"""
+
+ name: str = Field(max_length=100)
+ """应用名称"""
+
+ app_key: str = Field(max_length=50, unique=True, index=True)
+ """应用唯一标识,前端路由用"""
+
+ type: FileAppType
+ """应用类型"""
+
+ icon: str | None = Field(default=None, max_length=255)
+ """图标名称/URL"""
+
+ description: str | None = Field(default=None, max_length=500)
+ """应用描述"""
+
+ is_enabled: bool = True
+ """是否启用"""
+
+ is_restricted: bool = False
+ """是否限制用户组访问"""
+
+ iframe_url_template: str | None = Field(default=None, max_length=1024)
+ """iframe URL 模板,支持 {file_url} 占位符"""
+
+ wopi_discovery_url: str | None = Field(default=None, max_length=512)
+ """WOPI 客户端发现端点 URL"""
+
+ wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
+ """WOPI 编辑器 URL 模板,支持 {wopi_src} {access_token} {access_token_ttl}"""
+
+ # 关系
+ extensions: list["FileAppExtension"] = Relationship(
+ back_populates="app",
+ sa_relationship_kwargs={"cascade": "all, delete-orphan"}
+ )
+
+ user_defaults: list["UserFileAppDefault"] = Relationship(
+ back_populates="app",
+ sa_relationship_kwargs={"cascade": "all, delete-orphan"}
+ )
+
+ allowed_groups: list["Group"] = Relationship(
+ link_model=FileAppGroupLink,
+ )
+
+ def to_summary(self) -> FileAppSummary:
+ """转换为摘要 DTO"""
+ return FileAppSummary(
+ id=self.id,
+ name=self.name,
+ app_key=self.app_key,
+ type=self.type,
+ icon=self.icon,
+ description=self.description,
+ iframe_url_template=self.iframe_url_template,
+ wopi_editor_url_template=self.wopi_editor_url_template,
+ )
+
+
+class FileAppExtension(SQLModelBase, TableBaseMixin):
+ """扩展名关联表"""
+
+ __table_args__ = (
+ UniqueConstraint("app_id", "extension", name="uq_fileappextension_app_extension"),
+ )
+
+ app_id: UUID = Field(foreign_key="fileapp.id", index=True, ondelete="CASCADE")
+ """关联的应用UUID"""
+
+ extension: str = Field(max_length=20, index=True)
+ """扩展名(小写,无点号)"""
+
+ priority: int = Field(default=0, ge=0)
+ """排序优先级(越小越优先)"""
+
+ # 关系
+ app: FileApp = Relationship(back_populates="extensions")
+
+
+class UserFileAppDefault(SQLModelBase, UUIDTableBaseMixin):
+ """用户"始终使用"偏好"""
+
+ __table_args__ = (
+ UniqueConstraint("user_id", "extension", name="uq_userfileappdefault_user_extension"),
+ )
+
+ user_id: UUID = Field(foreign_key="user.id", index=True, ondelete="CASCADE")
+ """用户UUID"""
+
+ extension: str = Field(max_length=20)
+ """扩展名(小写,无点号)"""
+
+ app_id: UUID = Field(foreign_key="fileapp.id", index=True, ondelete="CASCADE")
+ """关联的应用UUID"""
+
+ # 关系
+ app: FileApp = Relationship(back_populates="user_defaults")
+
+ def to_response(self) -> UserFileAppDefaultResponse:
+ """转换为响应 DTO(需预加载 app 关系)"""
+ return UserFileAppDefaultResponse(
+ id=self.id,
+ extension=self.extension,
+ app=self.app.to_summary(),
+ )
diff --git a/sqlmodels/group.py b/sqlmodels/group.py
index 3e92e8a..8bea70d 100644
--- a/sqlmodels/group.py
+++ b/sqlmodels/group.py
@@ -4,8 +4,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, text
-from .base import SQLModelBase
-from .mixin import TableBaseMixin, UUIDTableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin
if TYPE_CHECKING:
from .user import User
diff --git a/sqlmodels/migration.py b/sqlmodels/migration.py
index 7715265..59aa43c 100644
--- a/sqlmodels/migration.py
+++ b/sqlmodels/migration.py
@@ -17,6 +17,7 @@ async def migration() -> None:
await init_default_group()
await init_default_user()
await init_default_theme_presets()
+ await init_default_file_apps()
log.info('数据库初始化结束')
@@ -372,3 +373,146 @@ async def init_default_theme_presets() -> None:
)
await default_preset.save(session)
log.info('已创建默认主题预设')
+
+
+# ==================== 默认文件查看器应用种子数据 ====================
+
+_DEFAULT_FILE_APPS: list[dict] = [
+ # 内置应用(type=builtin,默认启用)
+ {
+ "name": "PDF 阅读器",
+ "app_key": "pdfjs",
+ "type": "builtin",
+ "icon": "file-pdf",
+ "description": "基于 pdf.js 的 PDF 在线阅读器",
+ "is_enabled": True,
+ "extensions": ["pdf"],
+ },
+ {
+ "name": "代码编辑器",
+ "app_key": "monaco",
+ "type": "builtin",
+ "icon": "code",
+ "description": "基于 Monaco Editor 的代码编辑器",
+ "is_enabled": True,
+ "extensions": [
+ "txt", "md", "json", "xml", "yaml", "yml",
+ "py", "js", "ts", "jsx", "tsx",
+ "html", "css", "scss", "less",
+ "sh", "bash", "zsh",
+ "c", "cpp", "h", "hpp",
+ "java", "kt", "go", "rs", "rb",
+ "sql", "graphql",
+ "toml", "ini", "cfg", "conf",
+ "env", "gitignore", "dockerfile",
+ "vue", "svelte",
+ ],
+ },
+ {
+ "name": "Markdown 预览",
+ "app_key": "markdown",
+ "type": "builtin",
+ "icon": "markdown",
+ "description": "Markdown 实时预览",
+ "is_enabled": True,
+ "extensions": ["md", "markdown", "mdx"],
+ },
+ {
+ "name": "图片查看器",
+ "app_key": "image_viewer",
+ "type": "builtin",
+ "icon": "image",
+ "description": "图片在线查看器",
+ "is_enabled": True,
+ "extensions": ["jpg", "jpeg", "png", "gif", "bmp", "webp", "svg", "ico", "avif"],
+ },
+ {
+ "name": "视频播放器",
+ "app_key": "video_player",
+ "type": "builtin",
+ "icon": "video",
+ "description": "HTML5 视频播放器",
+ "is_enabled": True,
+ "extensions": ["mp4", "webm", "ogg", "mov", "mkv", "m3u8"],
+ },
+ {
+ "name": "音频播放器",
+ "app_key": "audio_player",
+ "type": "builtin",
+ "icon": "audio",
+ "description": "HTML5 音频播放器",
+ "is_enabled": True,
+ "extensions": ["mp3", "wav", "ogg", "flac", "aac", "m4a", "opus"],
+ },
+ # iframe 应用(默认禁用)
+ {
+ "name": "Office 在线预览",
+ "app_key": "office_viewer",
+ "type": "iframe",
+ "icon": "file-word",
+ "description": "使用 Microsoft Office Online 预览文档",
+ "is_enabled": False,
+ "iframe_url_template": "https://view.officeapps.live.com/op/embed.aspx?src={file_url}",
+ "extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"],
+ },
+ # WOPI 应用(默认禁用)
+ {
+ "name": "Collabora Online",
+ "app_key": "collabora",
+ "type": "wopi",
+ "icon": "file-text",
+ "description": "Collabora Online 文档编辑器(需自行部署)",
+ "is_enabled": False,
+ "extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx", "odt", "ods", "odp"],
+ },
+ {
+ "name": "OnlyOffice",
+ "app_key": "onlyoffice",
+ "type": "wopi",
+ "icon": "file-text",
+ "description": "OnlyOffice 文档编辑器(需自行部署)",
+ "is_enabled": False,
+ "extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"],
+ },
+]
+
+
+async def init_default_file_apps() -> None:
+ """初始化默认文件查看器应用"""
+ from .file_app import FileApp, FileAppExtension, FileAppType
+ from .database_connection import DatabaseManager
+
+ log.info('初始化文件查看器应用...')
+
+ async for session in DatabaseManager.get_session():
+ # 已存在应用则跳过
+ existing_count = await FileApp.count(session)
+ if existing_count > 0:
+ return
+
+ for app_data in _DEFAULT_FILE_APPS:
+ extensions = app_data.pop("extensions")
+
+ app = FileApp(
+ name=app_data["name"],
+ app_key=app_data["app_key"],
+ type=FileAppType(app_data["type"]),
+ icon=app_data.get("icon"),
+ description=app_data.get("description"),
+ is_enabled=app_data.get("is_enabled", True),
+ iframe_url_template=app_data.get("iframe_url_template"),
+ wopi_discovery_url=app_data.get("wopi_discovery_url"),
+ wopi_editor_url_template=app_data.get("wopi_editor_url_template"),
+ )
+ app = await app.save(session)
+ app_id = app.id
+
+ for i, ext in enumerate(extensions):
+ ext_record = FileAppExtension(
+ app_id=app_id,
+ extension=ext.lower(),
+ priority=i,
+ )
+ await ext_record.save(session)
+
+ log.info(f'已创建 {len(_DEFAULT_FILE_APPS)} 个默认文件查看器应用')
diff --git a/sqlmodels/mixin/README.md b/sqlmodels/mixin/README.md
deleted file mode 100644
index de03841..0000000
--- a/sqlmodels/mixin/README.md
+++ /dev/null
@@ -1,543 +0,0 @@
-# SQLModel Mixin Module
-
-This module provides composable Mixin classes for SQLModel entities, enabling reusable functionality such as CRUD operations, polymorphic inheritance, JWT authentication, and standardized response DTOs.
-
-## Module Overview
-
-The `sqlmodels.mixin` module contains various Mixin classes that follow the "Composition over Inheritance" design philosophy. These mixins provide:
-
-- **CRUD Operations**: Async database operations (add, save, update, delete, get, count)
-- **Polymorphic Inheritance**: Tools for joined table inheritance patterns
-- **JWT Authentication**: Token generation and validation
-- **Pagination & Sorting**: Standardized table view parameters
-- **Response DTOs**: Consistent id/timestamp fields for API responses
-
-## Module Structure
-
-```
-sqlmodels/mixin/
-├── __init__.py # Module exports
-├── polymorphic.py # PolymorphicBaseMixin, create_subclass_id_mixin, AutoPolymorphicIdentityMixin
-├── table.py # TableBaseMixin, UUIDTableBaseMixin, TableViewRequest
-├── info_response.py # Response DTO Mixins (IntIdInfoMixin, UUIDIdInfoMixin, etc.)
-└── jwt/ # JWT authentication
- ├── __init__.py
- ├── key.py # JWTKey database model
- ├── payload.py # JWTPayloadBase
- ├── manager.py # JWTManager singleton
- ├── auth.py # JWTAuthMixin
- ├── exceptions.py # JWT-related exceptions
- └── responses.py # TokenResponse DTO
-```
-
-## Dependency Hierarchy
-
-The module has a strict import order to avoid circular dependencies:
-
-1. **polymorphic.py** - Only depends on `SQLModelBase`
-2. **table.py** - Depends on `polymorphic.py`
-3. **jwt/** - May depend on both `polymorphic.py` and `table.py`
-4. **info_response.py** - Only depends on `SQLModelBase`
-
-## Core Components
-
-### 1. TableBaseMixin
-
-Base mixin for database table models with integer primary keys.
-
-**Features:**
-- Provides CRUD methods: `add()`, `save()`, `update()`, `delete()`, `get()`, `count()`, `get_exist_one()`
-- Automatic timestamp management (`created_at`, `updated_at`)
-- Async relationship loading support (via `AsyncAttrs`)
-- Pagination and sorting via `TableViewRequest`
-- Polymorphic subclass loading support
-
-**Fields:**
-- `id: int | None` - Integer primary key (auto-increment)
-- `created_at: datetime` - Record creation timestamp
-- `updated_at: datetime` - Record update timestamp (auto-updated)
-
-**Usage:**
-```python
-from sqlmodels.mixin import TableBaseMixin
-from sqlmodels.base import SQLModelBase
-
-class User(SQLModelBase, TableBaseMixin, table=True):
- name: str
- email: str
- """User email"""
-
-# CRUD operations
-async def example(session: AsyncSession):
- # Add
- user = User(name="Alice", email="alice@example.com")
- user = await user.save(session)
-
- # Get
- user = await User.get(session, User.id == 1)
-
- # Update
- update_data = UserUpdateRequest(name="Alice Smith")
- user = await user.update(session, update_data)
-
- # Delete
- await User.delete(session, user)
-
- # Count
- count = await User.count(session, User.is_active == True)
-```
-
-**Important Notes:**
-- `save()` and `update()` return refreshed instances - **always use the return value**:
- ```python
- # ✅ Correct
- device = await device.save(session)
- return device
-
- # ❌ Wrong - device is expired after commit
- await device.save(session)
- return device
- ```
-
-### 2. UUIDTableBaseMixin
-
-Extends `TableBaseMixin` with UUID primary keys instead of integers.
-
-**Differences from TableBaseMixin:**
-- `id: UUID` - UUID primary key (auto-generated via `uuid.uuid4()`)
-- `get_exist_one()` accepts `UUID` instead of `int`
-
-**Usage:**
-```python
-from sqlmodels.mixin import UUIDTableBaseMixin
-
-class Character(SQLModelBase, UUIDTableBaseMixin, table=True):
- name: str
- description: str | None = None
- """Character description"""
-```
-
-**Recommendation:** Use `UUIDTableBaseMixin` for most new models, as UUIDs provide better scalability and avoid ID collisions.
-
-### 3. TableViewRequest
-
-Standardized pagination and sorting parameters for LIST endpoints.
-
-**Fields:**
-- `offset: int | None` - Skip first N records (default: 0)
-- `limit: int | None` - Return max N records (default: 50, max: 100)
-- `desc: bool | None` - Sort descending (default: True)
-- `order: Literal["created_at", "updated_at"] | None` - Sort field (default: "created_at")
-
-**Usage with TableBaseMixin.get():**
-```python
-from dependencies import TableViewRequestDep
-
-@router.get("/list")
-async def list_characters(
- session: SessionDep,
- table_view: TableViewRequestDep
-) -> list[Character]:
- """List characters with pagination and sorting"""
- return await Character.get(
- session,
- fetch_mode="all",
- table_view=table_view # Automatically handles pagination and sorting
- )
-```
-
-**Manual usage:**
-```python
-table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
-characters = await Character.get(session, fetch_mode="all", table_view=table_view)
-```
-
-**Backward Compatibility:**
-The traditional `offset`, `limit`, `order_by` parameters still work, but `table_view` is recommended for new code.
-
-### 4. PolymorphicBaseMixin
-
-Base mixin for joined table inheritance, automatically configuring polymorphic settings.
-
-**Automatic Configuration:**
-- Defines `_polymorphic_name: str` field (indexed)
-- Sets `polymorphic_on='_polymorphic_name'`
-- Detects abstract classes (via ABC and abstract methods) and sets `polymorphic_abstract=True`
-
-**Methods:**
-- `get_concrete_subclasses()` - Get all non-abstract subclasses (for `selectin_polymorphic`)
-- `get_polymorphic_discriminator()` - Get the polymorphic discriminator field name
-- `get_identity_to_class_map()` - Map `polymorphic_identity` to subclass types
-
-**Usage:**
-```python
-from abc import ABC, abstractmethod
-from sqlmodels.mixin import PolymorphicBaseMixin, UUIDTableBaseMixin
-
-class Tool(PolymorphicBaseMixin, UUIDTableBaseMixin, ABC):
- """Abstract base class for all tools"""
- name: str
- description: str
- """Tool description"""
-
- @abstractmethod
- async def execute(self, params: dict) -> dict:
- """Execute the tool"""
- pass
-```
-
-**Why Single Underscore Prefix?**
-- SQLAlchemy maps single-underscore fields to database columns
-- Pydantic treats them as private (excluded from serialization)
-- Double-underscore fields would be excluded by SQLAlchemy (not mapped to database)
-
-### 5. create_subclass_id_mixin()
-
-Factory function to create ID mixins for subclasses in joined table inheritance.
-
-**Purpose:** In joined table inheritance, subclasses need a foreign key pointing to the parent table's primary key. This function generates a mixin class providing that foreign key field.
-
-**Signature:**
-```python
-def create_subclass_id_mixin(parent_table_name: str) -> type[SQLModelBase]:
- """
- Args:
- parent_table_name: Parent table name (e.g., 'asr', 'tts', 'tool', 'function')
-
- Returns:
- A mixin class containing id field (foreign key + primary key)
- """
-```
-
-**Usage:**
-```python
-from sqlmodels.mixin import create_subclass_id_mixin
-
-# Create mixin for ASR subclasses
-ASRSubclassIdMixin = create_subclass_id_mixin('asr')
-
-class FunASR(ASRSubclassIdMixin, ASR, AutoPolymorphicIdentityMixin, table=True):
- """FunASR implementation"""
- pass
-```
-
-**Important:** The ID mixin **must be first in the inheritance list** to ensure MRO (Method Resolution Order) correctly overrides the parent's `id` field.
-
-### 6. AutoPolymorphicIdentityMixin
-
-Automatically generates `polymorphic_identity` based on class name.
-
-**Naming Convention:**
-- Format: `{parent_identity}.{classname_lowercase}`
-- If no parent identity exists, uses `{classname_lowercase}`
-
-**Usage:**
-```python
-from sqlmodels.mixin import AutoPolymorphicIdentityMixin
-
-class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
- """Base class for function-type tools"""
- pass
- # polymorphic_identity = 'function'
-
-class GetWeatherFunction(Function, table=True):
- """Weather query function"""
- pass
- # polymorphic_identity = 'function.getweatherfunction'
-```
-
-**Manual Override:**
-```python
-class CustomTool(
- Tool,
- AutoPolymorphicIdentityMixin,
- polymorphic_identity='custom_name', # Override auto-generated name
- table=True
-):
- pass
-```
-
-### 7. JWTAuthMixin
-
-Provides JWT token generation and validation for entity classes (User, Client).
-
-**Methods:**
-- `async issue_jwt(session: AsyncSession) -> str` - Generate JWT token for current instance
-- `@classmethod async from_jwt(session: AsyncSession, token: str) -> Self` - Validate token and retrieve entity
-
-**Requirements:**
-Subclasses must define:
-- `JWTPayload` - Payload model (inherits from `JWTPayloadBase`)
-- `jwt_key_purpose` - ClassVar specifying the JWT key purpose enum value
-
-**Usage:**
-```python
-from sqlmodels.mixin import JWTAuthMixin, UUIDTableBaseMixin
-
-class User(SQLModelBase, UUIDTableBaseMixin, JWTAuthMixin, table=True):
- JWTPayload = UserJWTPayload # Define payload model
- jwt_key_purpose: ClassVar[JWTKeyPurposeEnum] = JWTKeyPurposeEnum.user
-
- email: str
- is_admin: bool = False
- is_active: bool = True
- """User active status"""
-
-# Generate token
-async def login(session: AsyncSession, user: User) -> str:
- token = await user.issue_jwt(session)
- return token
-
-# Validate token
-async def verify(session: AsyncSession, token: str) -> User:
- user = await User.from_jwt(session, token)
- return user
-```
-
-### 8. Response DTO Mixins
-
-Mixins for standardized InfoResponse DTOs, defining id and timestamp fields.
-
-**Available Mixins:**
-- `IntIdInfoMixin` - Integer ID field
-- `UUIDIdInfoMixin` - UUID ID field
-- `DatetimeInfoMixin` - `created_at` and `updated_at` fields
-- `IntIdDatetimeInfoMixin` - Integer ID + timestamps
-- `UUIDIdDatetimeInfoMixin` - UUID ID + timestamps
-
-**Design Note:** These fields are non-nullable in DTOs because database records always have these values when returned.
-
-**Usage:**
-```python
-from sqlmodels.mixin import UUIDIdDatetimeInfoMixin
-
-class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
- """Character response DTO with id and timestamps"""
- pass # Inherits id, created_at, updated_at from mixin
-```
-
-## Complete Joined Table Inheritance Example
-
-Here's a complete example demonstrating polymorphic inheritance:
-
-```python
-from abc import ABC, abstractmethod
-from sqlmodels.base import SQLModelBase
-from sqlmodels.mixin import (
- UUIDTableBaseMixin,
- PolymorphicBaseMixin,
- create_subclass_id_mixin,
- AutoPolymorphicIdentityMixin,
-)
-
-# 1. Define Base class (fields only, no table)
-class ASRBase(SQLModelBase):
- name: str
- """Configuration name"""
-
- base_url: str
- """Service URL"""
-
-# 2. Define abstract parent class (with table)
-class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
- """Abstract base class for ASR configurations"""
- # PolymorphicBaseMixin automatically provides:
- # - _polymorphic_name field
- # - polymorphic_on='_polymorphic_name'
- # - polymorphic_abstract=True (when ABC with abstract methods)
-
- @abstractmethod
- async def transcribe(self, pcm_data: bytes) -> str:
- """Transcribe audio to text"""
- pass
-
-# 3. Create ID Mixin for second-level subclasses
-ASRSubclassIdMixin = create_subclass_id_mixin('asr')
-
-# 4. Create second-level abstract class (if needed)
-class FunASR(
- ASRSubclassIdMixin,
- ASR,
- AutoPolymorphicIdentityMixin,
- polymorphic_abstract=True
-):
- """FunASR abstract base (may have multiple implementations)"""
- pass
- # polymorphic_identity = 'funasr'
-
-# 5. Create concrete implementation classes
-class FunASRLocal(FunASR, table=True):
- """FunASR local deployment"""
- # polymorphic_identity = 'funasr.funasrlocal'
-
- async def transcribe(self, pcm_data: bytes) -> str:
- # Implementation...
- return "transcribed text"
-
-# 6. Get all concrete subclasses (for selectin_polymorphic)
-concrete_asrs = ASR.get_concrete_subclasses()
-# Returns: [FunASRLocal, ...]
-```
-
-## Import Guidelines
-
-**Standard Import:**
-```python
-from sqlmodels.mixin import (
- TableBaseMixin,
- UUIDTableBaseMixin,
- PolymorphicBaseMixin,
- TableViewRequest,
- create_subclass_id_mixin,
- AutoPolymorphicIdentityMixin,
- JWTAuthMixin,
- UUIDIdDatetimeInfoMixin,
- now,
- now_date,
-)
-```
-
-**Backward Compatibility:**
-Some exports are also available from `sqlmodels.base` for backward compatibility:
-```python
-# Legacy import path (still works)
-from sqlmodels.base import UUIDTableBase, TableViewRequest
-
-# Recommended new import path
-from sqlmodels.mixin import UUIDTableBaseMixin, TableViewRequest
-```
-
-## Best Practices
-
-### 1. Mixin Order Matters
-
-**Correct Order:**
-```python
-# ✅ ID Mixin first, then parent, then AutoPolymorphicIdentityMixin
-class SubTool(ToolSubclassIdMixin, Tool, AutoPolymorphicIdentityMixin, table=True):
- pass
-```
-
-**Wrong Order:**
-```python
-# ❌ ID Mixin not first - won't override parent's id field
-class SubTool(Tool, ToolSubclassIdMixin, AutoPolymorphicIdentityMixin, table=True):
- pass
-```
-
-### 2. Always Use Return Values from save() and update()
-
-```python
-# ✅ Correct - use returned instance
-device = await device.save(session)
-return device
-
-# ❌ Wrong - device is expired after commit
-await device.save(session)
-return device # AttributeError when accessing fields
-```
-
-### 3. Prefer table_view Over Manual Pagination
-
-```python
-# ✅ Recommended - consistent across all endpoints
-characters = await Character.get(
- session,
- fetch_mode="all",
- table_view=table_view
-)
-
-# ⚠️ Works but not recommended - manual parameter management
-characters = await Character.get(
- session,
- fetch_mode="all",
- offset=0,
- limit=20,
- order_by=[desc(Character.created_at)]
-)
-```
-
-### 4. Polymorphic Loading for Many Subclasses
-
-```python
-# When loading relationships with > 10 polymorphic subclasses, use load_polymorphic='all'
-tool_set = await ToolSet.get(
- session,
- ToolSet.id == tool_set_id,
- load=ToolSet.tools,
- load_polymorphic='all' # Two-phase query - only loads actual related subclasses
-)
-
-# For fewer subclasses, specify the list explicitly
-tool_set = await ToolSet.get(
- session,
- ToolSet.id == tool_set_id,
- load=ToolSet.tools,
- load_polymorphic=[GetWeatherFunction, CodeInterpreterFunction]
-)
-```
-
-### 5. Response DTOs Should Inherit Base Classes
-
-```python
-# ✅ Correct - inherits from CharacterBase
-class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
- pass
-
-# ❌ Wrong - doesn't inherit from CharacterBase
-class CharacterInfoResponse(SQLModelBase, UUIDIdDatetimeInfoMixin):
- name: str # Duplicated field definition
- description: str | None = None
-```
-
-**Reason:** Inheriting from Base classes ensures:
-- Type checking via `isinstance(obj, XxxBase)`
-- Consistency across related DTOs
-- Future field additions automatically propagate
-
-### 6. Use Specific Types, Not Containers
-
-```python
-# ✅ Correct - specific DTO for config updates
-class GetWeatherFunctionUpdateRequest(GetWeatherFunctionConfigBase):
- weather_api_key: str | None = None
- default_location: str | None = None
- """Default location"""
-
-# ❌ Wrong - lose type safety
-class ToolUpdateRequest(SQLModelBase):
- config: dict[str, Any] # No field validation
-```
-
-## Type Variables
-
-```python
-from sqlmodels.mixin import T, M
-
-T = TypeVar("T", bound="TableBaseMixin") # For CRUD methods
-M = TypeVar("M", bound="SQLModel") # For update() method
-```
-
-## Utility Functions
-
-```python
-from sqlmodels.mixin import now, now_date
-
-# Lambda functions for default factories
-now = lambda: datetime.now()
-now_date = lambda: datetime.now().date()
-```
-
-## Related Modules
-
-- **sqlmodels.base** - Base classes (`SQLModelBase`, backward-compatible exports)
-- **dependencies** - FastAPI dependencies (`SessionDep`, `TableViewRequestDep`)
-- **sqlmodels.user** - User model with JWT authentication
-- **sqlmodels.client** - Client model with JWT authentication
-- **sqlmodels.character.llm.openai_compatibles.tools** - Polymorphic tool hierarchy
-
-## Additional Resources
-
-- `POLYMORPHIC_NAME_DESIGN.md` - Design rationale for `_polymorphic_name` field
-- `CLAUDE.md` - Project coding standards and design philosophy
-- SQLAlchemy Documentation - [Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
diff --git a/sqlmodels/mixin/__init__.py b/sqlmodels/mixin/__init__.py
deleted file mode 100644
index 832828a..0000000
--- a/sqlmodels/mixin/__init__.py
+++ /dev/null
@@ -1,62 +0,0 @@
-"""
-SQLModel Mixin模块
-
-提供各种Mixin类供SQLModel实体使用。
-
-包含:
-- polymorphic: 联表继承工具(create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin)
-- optimistic_lock: 乐观锁(OptimisticLockMixin, OptimisticLockError)
-- table: 表基类(TableBaseMixin, UUIDTableBaseMixin)
-- table: 查询参数类(TimeFilterRequest, PaginationRequest, TableViewRequest)
-- relation_preload: 关系预加载(RelationPreloadMixin, requires_relations)
-- jwt/: JWT认证相关(JWTAuthMixin, JWTManager, JWTKey等)- 需要时直接从 .jwt 导入
-- info_response: InfoResponse DTO的id/时间戳Mixin
-
-导入顺序很重要,避免循环导入:
-1. polymorphic(只依赖 SQLModelBase)
-2. optimistic_lock(只依赖 SQLAlchemy)
-3. table(依赖 polymorphic 和 optimistic_lock)
-4. relation_preload(只依赖 SQLModelBase)
-
-注意:jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig,
-而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
-"""
-# polymorphic 必须先导入
-from .polymorphic import (
- AutoPolymorphicIdentityMixin,
- PolymorphicBaseMixin,
- create_subclass_id_mixin,
- register_sti_column_properties_for_all_subclasses,
- register_sti_columns_for_all_subclasses,
-)
-# optimistic_lock 只依赖 SQLAlchemy,必须在 table 之前
-from .optimistic_lock import (
- OptimisticLockError,
- OptimisticLockMixin,
-)
-# table 依赖 polymorphic 和 optimistic_lock
-from .table import (
- ListResponse,
- PaginationRequest,
- T,
- TableBaseMixin,
- TableViewRequest,
- TimeFilterRequest,
- UUIDTableBaseMixin,
- now,
- now_date,
-)
-# relation_preload 只依赖 SQLModelBase
-from .relation_preload import (
- RelationPreloadMixin,
- requires_relations,
-)
-# jwt 不在此处导入(避免循环:jwt/manager.py → ServerConfig → mixin → jwt)
-# 需要时直接从 sqlmodels.mixin.jwt 导入
-from .info_response import (
- DatetimeInfoMixin,
- IntIdDatetimeInfoMixin,
- IntIdInfoMixin,
- UUIDIdDatetimeInfoMixin,
- UUIDIdInfoMixin,
-)
diff --git a/sqlmodels/mixin/info_response.py b/sqlmodels/mixin/info_response.py
deleted file mode 100644
index f1e053e..0000000
--- a/sqlmodels/mixin/info_response.py
+++ /dev/null
@@ -1,46 +0,0 @@
-"""
-InfoResponse DTO Mixin模块
-
-提供用于InfoResponse类型DTO的Mixin,统一定义id/created_at/updated_at字段。
-
-设计说明:
-- 这些Mixin用于**响应DTO**,不是数据库表
-- 从数据库返回时这些字段永远不为空,所以定义为必填字段
-- TableBase中的id=None和default_factory=now是正确的(入库前为None,数据库生成)
-- 这些Mixin让DTO明确表示"返回给客户端时这些字段必定有值"
-"""
-from datetime import datetime
-from uuid import UUID
-
-from sqlmodels.base import SQLModelBase
-
-
-class IntIdInfoMixin(SQLModelBase):
- """整数ID响应mixin - 用于InfoResponse DTO"""
- id: int
- """记录ID"""
-
-
-class UUIDIdInfoMixin(SQLModelBase):
- """UUID ID响应mixin - 用于InfoResponse DTO"""
- id: UUID
- """记录ID"""
-
-
-class DatetimeInfoMixin(SQLModelBase):
- """时间戳响应mixin - 用于InfoResponse DTO"""
- created_at: datetime
- """创建时间"""
-
- updated_at: datetime
- """更新时间"""
-
-
-class IntIdDatetimeInfoMixin(IntIdInfoMixin, DatetimeInfoMixin):
- """整数ID + 时间戳响应mixin"""
- pass
-
-
-class UUIDIdDatetimeInfoMixin(UUIDIdInfoMixin, DatetimeInfoMixin):
- """UUID ID + 时间戳响应mixin"""
- pass
diff --git a/sqlmodels/mixin/optimistic_lock.py b/sqlmodels/mixin/optimistic_lock.py
deleted file mode 100644
index c9b7da5..0000000
--- a/sqlmodels/mixin/optimistic_lock.py
+++ /dev/null
@@ -1,90 +0,0 @@
-"""
-乐观锁 Mixin
-
-提供基于 SQLAlchemy version_id_col 机制的乐观锁支持。
-
-乐观锁适用场景:
-- 涉及"状态转换"的表(如:待支付 -> 已支付)
-- 涉及"数值变动"的表(如:余额、库存)
-
-不适用场景:
-- 日志表、纯插入表、低价值统计表
-- 能用 UPDATE table SET col = col + 1 解决的简单计数问题
-
-使用示例:
- class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
- status: OrderStatusEnum
- amount: Decimal
-
- # save/update 时自动检查版本号
- # 如果版本号不匹配(其他事务已修改),会抛出 OptimisticLockError
- try:
- order = await order.save(session)
- except OptimisticLockError as e:
- # 处理冲突:重新查询并重试,或报错给用户
- l.warning(f"乐观锁冲突: {e}")
-"""
-from typing import ClassVar
-
-from sqlalchemy.orm.exc import StaleDataError
-
-
-class OptimisticLockError(Exception):
- """
- 乐观锁冲突异常
-
- 当 save/update 操作检测到版本号不匹配时抛出。
- 这意味着在读取和写入之间,其他事务已经修改了该记录。
-
- Attributes:
- model_class: 发生冲突的模型类名
- record_id: 记录 ID(如果可用)
- expected_version: 期望的版本号(如果可用)
- original_error: 原始的 StaleDataError
- """
-
- def __init__(
- self,
- message: str,
- model_class: str | None = None,
- record_id: str | None = None,
- expected_version: int | None = None,
- original_error: StaleDataError | None = None,
- ):
- super().__init__(message)
- self.model_class = model_class
- self.record_id = record_id
- self.expected_version = expected_version
- self.original_error = original_error
-
-
-class OptimisticLockMixin:
- """
- 乐观锁 Mixin
-
- 使用 SQLAlchemy 的 version_id_col 机制实现乐观锁。
- 每次 UPDATE 时自动检查并增加版本号,如果版本号不匹配(即其他事务已修改),
- session.commit() 会抛出 StaleDataError,被 save/update 方法捕获并转换为 OptimisticLockError。
-
- 原理:
- 1. 每条记录有一个 version 字段,初始值为 0
- 2. 每次 UPDATE 时,SQLAlchemy 生成的 SQL 类似:
- UPDATE table SET ..., version = version + 1 WHERE id = ? AND version = ?
- 3. 如果 WHERE 条件不匹配(version 已被其他事务修改),
- UPDATE 影响 0 行,SQLAlchemy 抛出 StaleDataError
-
- 继承顺序:
- OptimisticLockMixin 必须放在 TableBaseMixin/UUIDTableBaseMixin 之前:
- class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
- ...
-
- 配套重试:
- 如果加了乐观锁,业务层需要处理 OptimisticLockError:
- - 报错给用户:"数据已被修改,请刷新后重试"
- - 自动重试:重新查询最新数据并再次尝试
- """
- _has_optimistic_lock: ClassVar[bool] = True
- """标记此类启用了乐观锁"""
-
- version: int = 0
- """乐观锁版本号,每次更新自动递增"""
diff --git a/sqlmodels/mixin/polymorphic.py b/sqlmodels/mixin/polymorphic.py
deleted file mode 100644
index ba67275..0000000
--- a/sqlmodels/mixin/polymorphic.py
+++ /dev/null
@@ -1,710 +0,0 @@
-"""
-联表继承(Joined Table Inheritance)的通用工具
-
-提供用于简化SQLModel多态表设计的辅助函数和Mixin。
-
-Usage Example:
-
- from sqlmodels.base import SQLModelBase
- from sqlmodels.mixin import UUIDTableBaseMixin
- from sqlmodels.mixin.polymorphic import (
- PolymorphicBaseMixin,
- create_subclass_id_mixin,
- AutoPolymorphicIdentityMixin
- )
-
- # 1. 定义Base类(只有字段,无表)
- class ASRBase(SQLModelBase):
- name: str
- \"\"\"配置名称\"\"\"
-
- base_url: str
- \"\"\"服务地址\"\"\"
-
- # 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
- class ASR(
- ASRBase,
- UUIDTableBaseMixin,
- PolymorphicBaseMixin,
- ABC
- ):
- \"\"\"ASR配置的抽象基类\"\"\"
- # PolymorphicBaseMixin 自动提供:
- # - _polymorphic_name 字段
- # - polymorphic_on='_polymorphic_name'
- # - polymorphic_abstract=True(当有抽象方法时)
-
- # 3. 为第二层子类创建ID Mixin
- ASRSubclassIdMixin = create_subclass_id_mixin('asr')
-
- # 4. 创建第二层抽象类(如果需要)
- class FunASR(
- ASRSubclassIdMixin,
- ASR,
- AutoPolymorphicIdentityMixin,
- polymorphic_abstract=True
- ):
- \"\"\"FunASR的抽象基类,可能有多个实现\"\"\"
- pass
-
- # 5. 创建具体实现类
- class FunASRLocal(FunASR, table=True):
- \"\"\"FunASR本地部署版本\"\"\"
- # polymorphic_identity 会自动设置为 'asr.funasrlocal'
- pass
-
- # 6. 获取所有具体子类(用于 selectin_polymorphic)
- concrete_asrs = ASR.get_concrete_subclasses()
- # 返回 [FunASRLocal, ...]
-"""
-import uuid
-from abc import ABC
-from uuid import UUID
-
-from loguru import logger as l
-from pydantic.fields import FieldInfo
-from pydantic_core import PydanticUndefined
-from sqlalchemy import Column, String, inspect
-from sqlalchemy.orm import ColumnProperty, Mapped, mapped_column
-from sqlalchemy.orm.attributes import InstrumentedAttribute
-from sqlmodel import Field
-from sqlmodel.main import get_column_from_field
-
-from sqlmodels.base.sqlmodel_base import SQLModelBase
-
-# 用于延迟注册 STI 子类列的队列
-# 在所有模型加载完成后,调用 register_sti_columns_for_all_subclasses() 处理
-_sti_subclasses_to_register: list[type] = []
-
-
-def register_sti_columns_for_all_subclasses() -> None:
- """
- 为所有已注册的 STI 子类执行列注册(第一阶段:添加列到表)
-
- 此函数应在 configure_mappers() 之前调用。
- 将 STI 子类的字段添加到父表的 metadata 中。
- 同时修复被 Column 对象污染的 model_fields。
- """
- for cls in _sti_subclasses_to_register:
- try:
- cls._register_sti_columns()
- except Exception as e:
- l.warning(f"注册 STI 子类 {cls.__name__} 的列时出错: {e}")
-
- # 修复被 Column 对象污染的 model_fields
- # 必须在列注册后立即修复,因为 Column 污染在类定义时就已发生
- try:
- _fix_polluted_model_fields(cls)
- except Exception as e:
- l.warning(f"修复 STI 子类 {cls.__name__} 的 model_fields 时出错: {e}")
-
-
-def register_sti_column_properties_for_all_subclasses() -> None:
- """
- 为所有已注册的 STI 子类添加列属性到 mapper(第二阶段)
-
- 此函数应在 configure_mappers() 之后调用。
- 将 STI 子类的字段作为 ColumnProperty 添加到 mapper 中。
- """
- for cls in _sti_subclasses_to_register:
- try:
- cls._register_sti_column_properties()
- except Exception as e:
- l.warning(f"注册 STI 子类 {cls.__name__} 的列属性时出错: {e}")
-
- # 清空队列
- _sti_subclasses_to_register.clear()
-
-
-def _fix_polluted_model_fields(cls: type) -> None:
- """
- 修复被 SQLAlchemy InstrumentedAttribute 或 Column 污染的 model_fields
-
- 当 SQLModel 类继承有表的父类时,SQLAlchemy 会在类上创建 InstrumentedAttribute
- 或 Column 对象替换原始的字段默认值。这会导致 Pydantic 在构建子类 model_fields
- 时错误地使用这些 SQLAlchemy 对象作为默认值。
-
- 此函数从 MRO 中查找原始的字段定义,并修复被污染的 model_fields。
-
- :param cls: 要修复的类
- """
- if not hasattr(cls, 'model_fields'):
- return
-
- def find_original_field_info(field_name: str) -> FieldInfo | None:
- """从 MRO 中查找字段的原始定义(未被污染的)"""
- for base in cls.__mro__[1:]: # 跳过自己
- if hasattr(base, 'model_fields') and field_name in base.model_fields:
- field_info = base.model_fields[field_name]
- # 跳过被 InstrumentedAttribute 或 Column 污染的
- if not isinstance(field_info.default, (InstrumentedAttribute, Column)):
- return field_info
- return None
-
- for field_name, current_field in cls.model_fields.items():
- # 检查是否被污染(default 是 InstrumentedAttribute 或 Column)
- # Column 污染发生在 STI 继承链中:当 FunctionBase.show_arguments = True
- # 被继承到有表的子类时,SQLModel 会创建一个 Column 对象替换原始默认值
- if not isinstance(current_field.default, (InstrumentedAttribute, Column)):
- continue # 未被污染,跳过
-
- # 从父类查找原始定义
- original = find_original_field_info(field_name)
- if original is None:
- continue # 找不到原始定义,跳过
-
- # 根据原始定义的 default/default_factory 来修复
- if original.default_factory:
- # 有 default_factory(如 uuid.uuid4, now)
- new_field = FieldInfo(
- default_factory=original.default_factory,
- annotation=current_field.annotation,
- json_schema_extra=current_field.json_schema_extra,
- )
- elif original.default is not PydanticUndefined:
- # 有明确的 default 值(如 None, 0, True),且不是 PydanticUndefined
- # PydanticUndefined 表示字段没有默认值(必填)
- new_field = FieldInfo(
- default=original.default,
- annotation=current_field.annotation,
- json_schema_extra=current_field.json_schema_extra,
- )
- else:
- continue # 既没有 default_factory 也没有有效的 default,跳过
-
- # 复制 SQLModel 特有的属性
- if hasattr(current_field, 'foreign_key'):
- new_field.foreign_key = current_field.foreign_key
- if hasattr(current_field, 'primary_key'):
- new_field.primary_key = current_field.primary_key
-
- cls.model_fields[field_name] = new_field
-
-
-def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
- """
- 动态创建SubclassIdMixin类
-
- 在联表继承中,子类需要一个外键指向父表的主键。
- 此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。
-
- Args:
- parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function')
-
- Returns:
- 一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4)
-
- Example:
- >>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
- >>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
- ... pass
-
- Note:
- - 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id
- - 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase)
- - 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin)
- """
- if not parent_table_name:
- raise ValueError("parent_table_name 不能为空")
-
- # 转换为PascalCase作为类名
- class_name_parts = parent_table_name.split('_')
- class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
-
- # 使用闭包捕获parent_table_name
- _parent_table_name = parent_table_name
-
- # 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields
- class SubclassIdMixin(SQLModelBase):
- # 定义id字段
- id: UUID = Field(
- default_factory=uuid.uuid4,
- foreign_key=f'{_parent_table_name}.id',
- primary_key=True,
- )
-
- @classmethod
- def __pydantic_init_subclass__(cls, **kwargs):
- """
- Pydantic v2 的子类初始化钩子,在模型完全构建后调用
-
- 修复联表继承中子类字段的 default_factory 丢失问题。
- SQLAlchemy 的 InstrumentedAttribute 或 Column 会污染从父类继承的字段,
- 导致 INSERT 语句中出现 `table.column` 引用而非实际值。
- """
- super().__pydantic_init_subclass__(**kwargs)
- _fix_polluted_model_fields(cls)
-
- # 设置类名和文档
- SubclassIdMixin.__name__ = class_name
- SubclassIdMixin.__qualname__ = class_name
- SubclassIdMixin.__doc__ = f"""
- {parent_table_name}子类的ID Mixin
-
- 用于{parent_table_name}的子类,提供外键指向父表。
- 通过MRO确保此id字段覆盖继承的id字段。
- """
-
- return SubclassIdMixin
-
-
-class AutoPolymorphicIdentityMixin:
- """
- 自动生成polymorphic_identity的Mixin,并支持STI子类列注册
-
- 使用此Mixin的类会自动根据类名生成polymorphic_identity。
- 格式:{parent_polymorphic_identity}.{classname_lowercase}
-
- 如果没有父类的polymorphic_identity,则直接使用类名小写。
-
- **重要:数据库迁移注意事项**
-
- 编写数据迁移脚本时,必须使用完整的 polymorphic_identity 格式,包括父类前缀!
-
- 例如,对于以下继承链::
-
- LLM (polymorphic_on='_polymorphic_name')
- └── AnthropicCompatibleLLM (polymorphic_identity='anthropiccompatiblellm')
- └── TuziAnthropicLLM (polymorphic_identity='anthropiccompatiblellm.tuzianthropicllm')
-
- 迁移脚本中设置 _polymorphic_name 时::
-
- # ❌ 错误:缺少父类前缀
- UPDATE llm SET _polymorphic_name = 'tuzianthropicllm' WHERE id = :id
-
- # ✅ 正确:包含完整的继承链前缀
- UPDATE llm SET _polymorphic_name = 'anthropiccompatiblellm.tuzianthropicllm' WHERE id = :id
-
- **STI(单表继承)支持**:
- 当子类与父类共用同一张表(STI模式)时,此Mixin会自动将子类的新字段
- 添加到父表的列定义中。这解决了SQLModel在STI模式下子类字段不被
- 注册到父表的问题。
-
- Example (JTI):
- >>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
- ... __polymorphic_name: str
- ...
- >>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
- ... pass
- ... # polymorphic_identity 会自动设置为 'function'
- ...
- >>> class CodeInterpreterFunction(Function, table=True):
- ... pass
- ... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
-
- Example (STI):
- >>> class UserFile(UUIDTableBaseMixin, PolymorphicBaseMixin, table=True, polymorphic_abstract=True):
- ... user_id: UUID
- ...
- >>> class PendingFile(UserFile, AutoPolymorphicIdentityMixin, table=True):
- ... upload_deadline: datetime | None = None # 自动添加到 userfile 表
- ... # polymorphic_identity 会自动设置为 'pendingfile'
-
- Note:
- - 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留
- - 此Mixin应该在继承列表中靠后的位置(在表基类之前)
- - STI模式下,新字段会在类定义时自动添加到父表的metadata中
- """
-
- def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
- """
- 子类化钩子,自动生成polymorphic_identity并处理STI列注册
-
- Args:
- polymorphic_identity: 如果手动指定,则使用指定的值
- **kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True)
- """
- super().__init_subclass__(**kwargs)
-
- # 如果手动指定了polymorphic_identity,使用指定的值
- if polymorphic_identity is not None:
- identity = polymorphic_identity
- else:
- # 自动生成polymorphic_identity
- class_name = cls.__name__.lower()
-
- # 尝试从父类获取polymorphic_identity作为前缀
- parent_identity = None
- for base in cls.__mro__[1:]: # 跳过自己
- if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
- parent_identity = base.__mapper_args__.get('polymorphic_identity')
- if parent_identity:
- break
-
- # 构建identity
- if parent_identity:
- identity = f'{parent_identity}.{class_name}'
- else:
- identity = class_name
-
- # 设置到__mapper_args__
- if '__mapper_args__' not in cls.__dict__:
- cls.__mapper_args__ = {}
-
- # 只在尚未设置polymorphic_identity时设置
- if 'polymorphic_identity' not in cls.__mapper_args__:
- cls.__mapper_args__['polymorphic_identity'] = identity
-
- # 注册 STI 子类列的延迟执行
- # 由于 __init_subclass__ 在类定义过程中被调用,此时 model_fields 还不完整
- # 需要在模块加载完成后调用 register_sti_columns_for_all_subclasses()
- _sti_subclasses_to_register.append(cls)
-
- @classmethod
- def __pydantic_init_subclass__(cls, **kwargs):
- """
- Pydantic v2 的子类初始化钩子,在模型完全构建后调用
-
- 修复 STI 继承中子类字段被 Column 对象污染的问题。
- 当 FunctionBase.show_arguments = True 等字段被继承到有表的子类时,
- SQLModel 会创建一个 Column 对象替换原始默认值,导致实例化时字段值不正确。
- """
- super().__pydantic_init_subclass__(**kwargs)
- _fix_polluted_model_fields(cls)
-
- @classmethod
- def _register_sti_columns(cls) -> None:
- """
- 将STI子类的新字段注册到父表的列定义中
-
- 检测当前类是否是STI子类(与父类共用同一张表),
- 如果是,则将子类定义的新字段添加到父表的metadata中。
-
- JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
- """
- # 查找父表(在 MRO 中找到第一个有 __table__ 的父类)
- parent_table = None
- parent_fields: set[str] = set()
-
- for base in cls.__mro__[1:]:
- if hasattr(base, '__table__') and base.__table__ is not None:
- parent_table = base.__table__
- # 收集父类的所有字段名
- if hasattr(base, 'model_fields'):
- parent_fields.update(base.model_fields.keys())
- break
-
- if parent_table is None:
- return # 没有找到父表,可能是根类
-
- # JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
- # JTI 类有自己独立的表,不需要将列注册到父表
- if hasattr(cls, '__table__') and cls.__table__ is not None:
- if cls.__table__.name != parent_table.name:
- return # JTI,跳过 STI 列注册
-
- # 获取当前类的新字段(不在父类中的字段)
- if not hasattr(cls, 'model_fields'):
- return
-
- existing_columns = {col.name for col in parent_table.columns}
-
- for field_name, field_info in cls.model_fields.items():
- # 跳过从父类继承的字段
- if field_name in parent_fields:
- continue
-
- # 跳过私有字段和ClassVar
- if field_name.startswith('_'):
- continue
-
- # 跳过已存在的列
- if field_name in existing_columns:
- continue
-
- # 使用 SQLModel 的内置 API 创建列
- try:
- column = get_column_from_field(field_info)
- column.name = field_name
- column.key = field_name
- # STI子类字段在数据库层面必须可空,因为其他子类的行不会有这些字段的值
- # Pydantic层面的约束仍然有效(创建特定子类时会验证必填字段)
- column.nullable = True
-
- # 将列添加到父表
- parent_table.append_column(column)
- except Exception as e:
- l.warning(f"为 {cls.__name__} 创建列 {field_name} 失败: {e}")
-
- @classmethod
- def _register_sti_column_properties(cls) -> None:
- """
- 将 STI 子类的列作为 ColumnProperty 添加到 mapper
-
- 此方法在 configure_mappers() 之后调用,将已添加到表中的列
- 注册为 mapper 的属性,使 ORM 查询能正确识别这些列。
-
- **重要**:子类的列属性会同时注册到子类和父类的 mapper 上。
- 这确保了查询父类时,SELECT 语句包含所有 STI 子类的列,
- 避免在响应序列化时触发懒加载(MissingGreenlet 错误)。
-
- JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
- """
- # 查找父表和父类(在 MRO 中找到第一个有 __table__ 的父类)
- parent_table = None
- parent_class = None
- for base in cls.__mro__[1:]:
- if hasattr(base, '__table__') and base.__table__ is not None:
- parent_table = base.__table__
- parent_class = base
- break
-
- if parent_table is None:
- return # 没有找到父表,可能是根类
-
- # JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
- # JTI 类有自己独立的表,不需要将列属性注册到 mapper
- if hasattr(cls, '__table__') and cls.__table__ is not None:
- if cls.__table__.name != parent_table.name:
- return # JTI,跳过 STI 列属性注册
-
- # 获取子类和父类的 mapper
- child_mapper = inspect(cls).mapper
- parent_mapper = inspect(parent_class).mapper
- local_table = child_mapper.local_table
-
- # 查找父类的所有字段名
- parent_fields: set[str] = set()
- if hasattr(parent_class, 'model_fields'):
- parent_fields.update(parent_class.model_fields.keys())
-
- if not hasattr(cls, 'model_fields'):
- return
-
- # 获取两个 mapper 已有的列属性
- child_existing_props = {p.key for p in child_mapper.column_attrs}
- parent_existing_props = {p.key for p in parent_mapper.column_attrs}
-
- for field_name in cls.model_fields:
- # 跳过从父类继承的字段
- if field_name in parent_fields:
- continue
-
- # 跳过私有字段
- if field_name.startswith('_'):
- continue
-
- # 检查表中是否有这个列
- if field_name not in local_table.columns:
- continue
-
- column = local_table.columns[field_name]
-
- # 添加到子类的 mapper(如果尚不存在)
- if field_name not in child_existing_props:
- try:
- prop = ColumnProperty(column)
- child_mapper.add_property(field_name, prop)
- except Exception as e:
- l.warning(f"为 {cls.__name__} 添加列属性 {field_name} 失败: {e}")
-
- # 同时添加到父类的 mapper(确保查询父类时 SELECT 包含所有 STI 子类的列)
- if field_name not in parent_existing_props:
- try:
- prop = ColumnProperty(column)
- parent_mapper.add_property(field_name, prop)
- except Exception as e:
- l.warning(f"为父类 {parent_class.__name__} 添加子类 {cls.__name__} 的列属性 {field_name} 失败: {e}")
-
-
-class PolymorphicBaseMixin:
- """
- 为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
-
- 此 Mixin 自动设置以下内容:
- - `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
- - `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
- - `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
-
- 使用场景:
- 适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
-
- 用法示例:
- ```python
- from abc import ABC
- from sqlmodels.mixin import UUIDTableBaseMixin
- from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
-
- # 定义基类
- class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
- __tablename__ = 'mytool'
-
- # 不需要手动定义 _polymorphic_name
- # 不需要手动设置 polymorphic_on
- # 不需要手动设置 polymorphic_abstract
-
- # 定义子类
- class SpecificTool(MyTool):
- __tablename__ = 'specifictool'
-
- # 会自动继承 polymorphic 配置
- ```
-
- 自动行为:
- 1. 定义 `_polymorphic_name: str` 字段(带索引)
- 2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
- 3. 自动检测抽象类:
- - 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
- - 否则设置为 False
-
- 手动覆盖:
- 可以在类定义时手动指定参数来覆盖自动行为:
- ```python
- class MyTool(
- UUIDTableBaseMixin,
- PolymorphicBaseMixin,
- ABC,
- polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
- polymorphic_abstract=False # 强制不设为抽象类
- ):
- pass
- ```
-
- 注意事项:
- - 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
- - 适用于联表继承(joined table inheritance)场景
- - 子类会自动继承 _polymorphic_name 字段定义
- - 使用单下划线前缀是因为:
- * SQLAlchemy 会映射单下划线字段为数据库列
- * Pydantic 将其视为私有属性,不参与序列化
- * 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
- """
-
- # 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
- #
- # 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
- #
- # 为什么这样做:
- # 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
- # 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
- # 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
- # 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema)
- # 5. 内部代码仍然可以正常访问和修改此字段
- #
- # 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
- _polymorphic_name: Mapped[str] = mapped_column(String, index=True)
- """
- 多态鉴别器字段,用于标识具体的子类类型
-
- 注意:此字段使用单下划线前缀,表示内部使用。
- - ✅ 存储到数据库
- - ✅ 不出现在 API 序列化中
- - ✅ 防止外部直接修改
- """
-
- def __init_subclass__(
- cls,
- polymorphic_on: str | None = None,
- polymorphic_abstract: bool | None = None,
- **kwargs
- ):
- """
- 在子类定义时自动配置 polymorphic 设置
-
- Args:
- polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。
- 设置为其他值可以使用不同的字段作为多态鉴别器。
- polymorphic_abstract: 是否为抽象类。
- - None: 自动检测(默认)
- - True: 强制设为抽象类
- - False: 强制设为非抽象类
- **kwargs: 传递给父类的其他参数
- """
- super().__init_subclass__(**kwargs)
-
- # 初始化 __mapper_args__(如果还没有)
- if '__mapper_args__' not in cls.__dict__:
- cls.__mapper_args__ = {}
-
- # 设置 polymorphic_on(默认为 _polymorphic_name)
- if 'polymorphic_on' not in cls.__mapper_args__:
- cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
-
- # 自动检测或设置 polymorphic_abstract
- if 'polymorphic_abstract' not in cls.__mapper_args__:
- if polymorphic_abstract is None:
- # 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
- has_abc = ABC in cls.__mro__
- has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
- polymorphic_abstract = has_abc and has_abstract_methods
-
- cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
-
- @classmethod
- def _is_joined_table_inheritance(cls) -> bool:
- """
- 检测当前类是否使用联表继承(Joined Table Inheritance)
-
- 通过检查子类是否有独立的表来判断:
- - JTI: 子类有独立的 local_table(与父类不同)
- - STI: 子类与父类共用同一个 local_table
-
- :return: True 表示 JTI,False 表示 STI 或无子类
- """
- mapper = inspect(cls)
- base_table_name = mapper.local_table.name
-
- # 检查所有直接子类
- for subclass in cls.__subclasses__():
- sub_mapper = inspect(subclass)
- # 如果任何子类有不同的表名,说明是 JTI
- if sub_mapper.local_table.name != base_table_name:
- return True
-
- return False
-
- @classmethod
- def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
- """
- 递归获取当前类的所有具体(非抽象)子类
-
- 用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
- 可在任意多态基类上调用,返回该类的所有非抽象子类。
-
- :return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
- """
- result: list[type[PolymorphicBaseMixin]] = []
- for subclass in cls.__subclasses__():
- # 使用 inspect() 获取 mapper 的公开属性
- # 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
- mapper = inspect(subclass)
- if not mapper.polymorphic_abstract:
- result.append(subclass)
- # 无论是否抽象,都需要递归(抽象类可能有具体子类)
- if hasattr(subclass, 'get_concrete_subclasses'):
- result.extend(subclass.get_concrete_subclasses())
- return result
-
- @classmethod
- def get_polymorphic_discriminator(cls) -> str:
- """
- 获取多态鉴别字段名
-
- 使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
-
- :return: 多态鉴别字段名(如 '_polymorphic_name')
- :raises ValueError: 如果类未配置 polymorphic_on
- """
- polymorphic_on = inspect(cls).polymorphic_on
- if polymorphic_on is None:
- raise ValueError(
- f"{cls.__name__} 未配置 polymorphic_on,"
- f"请确保正确继承 PolymorphicBaseMixin"
- )
- return polymorphic_on.key
-
- @classmethod
- def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
- """
- 获取 polymorphic_identity 到具体子类的映射
-
- 包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
-
- :return: identity 到子类的映射字典
- """
- result: dict[str, type[PolymorphicBaseMixin]] = {}
- for subclass in cls.get_concrete_subclasses():
- identity = inspect(subclass).polymorphic_identity
- if identity:
- result[identity] = subclass
- return result
diff --git a/sqlmodels/mixin/relation_preload.py b/sqlmodels/mixin/relation_preload.py
deleted file mode 100644
index 624018f..0000000
--- a/sqlmodels/mixin/relation_preload.py
+++ /dev/null
@@ -1,470 +0,0 @@
-"""
-关系预加载 Mixin
-
-提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。
-
-设计原则:
-- 按需加载:只加载被调用方法需要的关系
-- 增量加载:已加载的关系不重复加载
-- 查询最优:相同关系只查询一次,不同关系增量查询
-- 零侵入:调用方无需任何改动
-- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire
-
-使用方式:
- from sqlmodels.mixin import RelationPreloadMixin, requires_relations
-
- class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True):
- kling_video_generator: KlingO1Generator = Relationship(...)
-
- @requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
- async def cost(self, params, context, session) -> ToolCost:
- # 自动加载,可以安全访问
- price = self.kling_video_generator.kling_o1.pro_price_per_second
- ...
-
- # 调用方 - 无需任何改动
- await tool.cost(params, context, session) # 自动加载 cost 需要的关系
- await tool._call(...) # 关系相同则跳过,否则增量加载
-
-支持 AsyncGenerator:
- @requires_relations('twitter_api')
- async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
- yield ToolResponse(...) # 装饰器正确处理 async generator
-"""
-import inspect as python_inspect
-from functools import wraps
-from typing import Callable, TypeVar, ParamSpec, Any
-
-from loguru import logger as l
-from sqlalchemy import inspect as sa_inspect
-from sqlmodel.ext.asyncio.session import AsyncSession
-from sqlmodel.main import RelationshipInfo
-
-P = ParamSpec('P')
-R = TypeVar('R')
-
-
-def _extract_session(
- func: Callable,
- args: tuple[Any, ...],
- kwargs: dict[str, Any],
-) -> AsyncSession | None:
- """
- 从方法参数中提取 AsyncSession
-
- 按以下顺序查找:
- 1. kwargs 中名为 'session' 的参数
- 2. 根据函数签名定位 'session' 参数的位置,从 args 提取
- 3. kwargs 中类型为 AsyncSession 的参数
- """
- # 1. 优先从 kwargs 查找
- if 'session' in kwargs:
- return kwargs['session']
-
- # 2. 从函数签名定位位置参数
- try:
- sig = python_inspect.signature(func)
- param_names = list(sig.parameters.keys())
-
- if 'session' in param_names:
- # 计算位置(减去 self)
- idx = param_names.index('session') - 1
- if 0 <= idx < len(args):
- return args[idx]
- except (ValueError, TypeError):
- pass
-
- # 3. 遍历 kwargs 找 AsyncSession 类型
- for value in kwargs.values():
- if isinstance(value, AsyncSession):
- return value
-
- return None
-
-
-def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool:
- """
- 检查对象的关系是否已加载(独立函数版本)
-
- Args:
- obj: 要检查的对象
- rel_name: 关系属性名
-
- Returns:
- True 如果关系已加载,False 如果未加载或已过期
- """
- try:
- state = sa_inspect(obj)
- return rel_name not in state.unloaded
- except Exception:
- return False
-
-
-def _find_relation_to_class(from_class: type, to_class: type) -> str | None:
- """
- 在类中查找指向目标类的关系属性名
-
- Args:
- from_class: 源类
- to_class: 目标类
-
- Returns:
- 关系属性名,如果找不到则返回 None
-
- Example:
- _find_relation_to_class(KlingO1VideoFunction, KlingO1Generator)
- # 返回 'kling_video_generator'
- """
- for attr_name in dir(from_class):
- try:
- attr = getattr(from_class, attr_name, None)
- if attr is None:
- continue
- # 检查是否是 SQLAlchemy InstrumentedAttribute(关系属性)
- # parent.class_ 是关系所在的类,property.mapper.class_ 是关系指向的目标类
- if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'):
- target_class = attr.property.mapper.class_
- if target_class == to_class:
- return attr_name
- except AttributeError:
- continue
- return None
-
-
-def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]:
- """
- 装饰器:声明方法需要的关系,自动按需增量加载
-
- 参数格式:
- - 字符串:本类属性名,如 'kling_video_generator'
- - RelationshipInfo:外部类属性,如 KlingO1Generator.kling_o1
-
- 行为:
- - 方法调用时自动检查关系是否已加载
- - 未加载的关系会被增量加载(单次查询)
- - 已加载的关系直接跳过
-
- 支持:
- - 普通 async 方法:`async def cost(...) -> ToolCost`
- - AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]`
-
- Example:
- @requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
- async def cost(self, params, context, session) -> ToolCost:
- # self.kling_video_generator.kling_o1 已自动加载
- ...
-
- @requires_relations('twitter_api')
- async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
- yield ToolResponse(...) # AsyncGenerator 正确处理
-
- 验证:
- - 字符串格式的关系名在类创建时(__init_subclass__)验证
- - 拼写错误会在导入时抛出 AttributeError
- """
- def decorator(func: Callable[P, R]) -> Callable[P, R]:
- # 检测是否是 async generator 函数
- is_async_gen = python_inspect.isasyncgenfunction(func)
-
- if is_async_gen:
- # AsyncGenerator 需要特殊处理:wrapper 也必须是 async generator
- @wraps(func)
- async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
- session = _extract_session(func, args, kwargs)
- if session is not None:
- await self._ensure_relations_loaded(session, relations)
- # 委托给原始 async generator,逐个 yield 值
- async for item in func(self, *args, **kwargs):
- yield item # type: ignore
- else:
- # 普通 async 函数:await 并返回结果
- @wraps(func)
- async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
- session = _extract_session(func, args, kwargs)
- if session is not None:
- await self._ensure_relations_loaded(session, relations)
- return await func(self, *args, **kwargs)
-
- # 保存关系声明供验证和内省使用
- wrapper._required_relations = relations # type: ignore
- return wrapper
-
- return decorator
-
-
-class RelationPreloadMixin:
- """
- 关系预加载 Mixin
-
- 提供按需增量加载能力,确保 SQL 查询数理论最优。
-
- 特性:
- - 按需加载:只加载被调用方法需要的关系
- - 增量加载:已加载的关系不重复加载
- - 原地更新:直接修改 self,无需替换实例
- - 导入时验证:字符串关系名在类创建时验证
- - Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire
- """
-
- def __init_subclass__(cls, **kwargs) -> None:
- """类创建时验证所有 @requires_relations 声明"""
- super().__init_subclass__(**kwargs)
-
- # 收集类及其父类的所有注解(包含普通字段)
- all_annotations: set[str] = set()
- for klass in cls.__mro__:
- if hasattr(klass, '__annotations__'):
- all_annotations.update(klass.__annotations__.keys())
-
- # 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__)
- sqlmodel_relationships: set[str] = set()
- for klass in cls.__mro__:
- if hasattr(klass, '__sqlmodel_relationships__'):
- sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys())
-
- # 合并所有可用的属性名
- all_available_names = all_annotations | sqlmodel_relationships
-
- for method_name in dir(cls):
- if method_name.startswith('__'):
- continue
-
- try:
- method = getattr(cls, method_name, None)
- except AttributeError:
- continue
-
- if method is None or not hasattr(method, '_required_relations'):
- continue
-
- # 验证字符串格式的关系名
- for spec in method._required_relations:
- if isinstance(spec, str):
- # 检查注解、Relationship 或已有属性
- if spec not in all_available_names and not hasattr(cls, spec):
- raise AttributeError(
- f"{cls.__name__}.{method_name} 声明了关系 '{spec}',"
- f"但 {cls.__name__} 没有此属性"
- )
-
- def _is_relation_loaded(self, rel_name: str) -> bool:
- """
- 检查关系是否真正已加载(基于 SQLAlchemy inspect)
-
- 使用 SQLAlchemy 的 inspect 检测真实加载状态,
- 自动处理 commit 导致的 expire 问题。
-
- Args:
- rel_name: 关系属性名
-
- Returns:
- True 如果关系已加载,False 如果未加载或已过期
- """
- try:
- state = sa_inspect(self)
- # unloaded 包含未加载的关系属性名
- return rel_name not in state.unloaded
- except Exception:
- # 对象可能未被 SQLAlchemy 管理
- return False
-
- async def _ensure_relations_loaded(
- self,
- session: AsyncSession,
- relations: tuple[str | RelationshipInfo, ...],
- ) -> None:
- """
- 确保指定关系已加载,只加载未加载的部分
-
- 基于 SQLAlchemy inspect 检测真实状态,自动处理:
- - 首次访问的关系
- - commit 后 expire 的关系
- - 嵌套关系(如 KlingO1Generator.kling_o1)
-
- Args:
- session: 数据库会话
- relations: 需要的关系规格
- """
- # 找出真正未加载的关系(基于 SQLAlchemy inspect)
- to_load: list[str | RelationshipInfo] = []
- # 区分直接关系和嵌套关系的 key
- direct_keys: set[str] = set() # 本类的直接关系属性名
- nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名
-
- for rel in relations:
- if isinstance(rel, str):
- # 直接关系:检查本类的关系是否已加载
- if not self._is_relation_loaded(rel):
- to_load.append(rel)
- direct_keys.add(rel)
- else:
- # 嵌套关系(InstrumentedAttribute):如 KlingO1Generator.kling_o1
- # 1. 查找指向父类的关系属性
- parent_class = rel.parent.class_
- parent_attr = _find_relation_to_class(self.__class__, parent_class)
-
- if parent_attr is None:
- # 找不到路径,可能是配置错误,但仍尝试加载
- l.warning(
- f"无法找到从 {self.__class__.__name__} 到 {parent_class.__name__} 的关系路径,"
- f"无法检查 {rel.key} 是否已加载"
- )
- to_load.append(rel)
- continue
-
- # 2. 检查父对象是否已加载
- if not self._is_relation_loaded(parent_attr):
- # 父对象未加载,需要同时加载父对象和嵌套关系
- if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
- to_load.append(parent_attr)
- nested_parent_keys.add(parent_attr)
- to_load.append(rel)
- else:
- # 3. 父对象已加载,检查嵌套关系是否已加载
- parent_obj = getattr(self, parent_attr)
- if not _is_obj_relation_loaded(parent_obj, rel.key):
- # 嵌套关系未加载:需要同时传递父关系和嵌套关系
- # 因为 _build_load_chains 需要完整的链来构建 selectinload
- if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
- to_load.append(parent_attr)
- nested_parent_keys.add(parent_attr)
- to_load.append(rel)
-
- if not to_load:
- return # 全部已加载,跳过
-
- # 构建 load 参数
- load_options = self._specs_to_load_options(to_load)
- if not load_options:
- return
-
- # 安全地获取主键值(避免触发懒加载)
- state = sa_inspect(self)
- pk_tuple = state.key[1] if state.key else None
- if pk_tuple is None:
- l.warning(f"无法获取 {self.__class__.__name__} 的主键值")
- return
- # 主键是元组,取第一个值(假设单列主键)
- pk_value = pk_tuple[0]
-
- # 单次查询加载缺失的关系
- fresh = await self.__class__.get(
- session,
- self.__class__.id == pk_value,
- load=load_options,
- )
-
- if fresh is None:
- l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在")
- return
-
- # 原地复制到 self(只复制直接关系,嵌套关系通过父关系自动可访问)
- all_direct_keys = direct_keys | nested_parent_keys
- for key in all_direct_keys:
- value = getattr(fresh, key, None)
- object.__setattr__(self, key, value)
-
- def _specs_to_load_options(
- self,
- specs: list[str | RelationshipInfo],
- ) -> list[RelationshipInfo]:
- """
- 将关系规格转换为 load 参数
-
- - 字符串 → cls.{name}
- - RelationshipInfo → 直接使用
- """
- result: list[RelationshipInfo] = []
-
- for spec in specs:
- if isinstance(spec, str):
- rel = getattr(self.__class__, spec, None)
- if rel is not None:
- result.append(rel)
- else:
- l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在")
- else:
- result.append(spec)
-
- return result
-
- # ==================== 可选的手动预加载 API ====================
-
- @classmethod
- def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]:
- """
- 获取指定方法声明的关系(用于外部预加载场景)
-
- Args:
- method_name: 方法名
-
- Returns:
- RelationshipInfo 列表
- """
- method = getattr(cls, method_name, None)
- if method is None or not hasattr(method, '_required_relations'):
- return []
-
- result: list[RelationshipInfo] = []
- for spec in method._required_relations:
- if isinstance(spec, str):
- rel = getattr(cls, spec, None)
- if rel:
- result.append(rel)
- else:
- result.append(spec)
-
- return result
-
- @classmethod
- def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]:
- """
- 获取多个方法的关系并去重(用于批量预加载场景)
-
- Args:
- method_names: 方法名列表
-
- Returns:
- 去重后的 RelationshipInfo 列表
- """
- seen: set[str] = set()
- result: list[RelationshipInfo] = []
-
- for method_name in method_names:
- for rel in cls.get_relations_for_method(method_name):
- key = rel.key
- if key not in seen:
- seen.add(key)
- result.append(rel)
-
- return result
-
- async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin':
- """
- 手动预加载指定方法的关系(可选优化 API)
-
- 当需要确保在调用方法前完成所有加载时使用。
- 通常情况下不需要调用此方法,装饰器会自动处理。
-
- Args:
- session: 数据库会话
- method_names: 方法名列表
-
- Returns:
- self(支持链式调用)
-
- Example:
- # 可选:显式预加载(通常不需要)
- tool = await tool.preload_for(session, 'cost', '_call')
- """
- all_relations: list[str | RelationshipInfo] = []
-
- for method_name in method_names:
- method = getattr(self.__class__, method_name, None)
- if method and hasattr(method, '_required_relations'):
- all_relations.extend(method._required_relations)
-
- if all_relations:
- await self._ensure_relations_loaded(session, tuple(all_relations))
-
- return self
diff --git a/sqlmodels/mixin/table.py b/sqlmodels/mixin/table.py
deleted file mode 100644
index cd6c830..0000000
--- a/sqlmodels/mixin/table.py
+++ /dev/null
@@ -1,1247 +0,0 @@
-"""
-表基类 Mixin
-
-提供 TableBaseMixin、UUIDTableBaseMixin 和 TableViewRequest。
-这些类实际上是 Mixin,为 SQLModel 模型提供 CRUD 操作和时间戳字段。
-
-依赖关系:
- base/sqlmodel_base.py ← 最底层
- ↓
- mixin/polymorphic.py ← 定义 PolymorphicBaseMixin
- ↓
- mixin/table.py ← 当前文件,导入 PolymorphicBaseMixin
- ↓
- base/__init__.py ← 从 mixin 重新导出(保持向后兼容)
-
-维护须知:
- 增删功能时必须更新 __version__ 字段(遵循语义化版本)
-
-版本历史:
- 0.1.0 - delete() 方法支持条件删除(condition 参数)
-"""
-__version__ = "0.1.0"
-import uuid
-from datetime import datetime
-from typing import TypeVar, Literal, override, Any, ClassVar, Generic
-
-# TODO(ListResponse泛型问题): SQLModel泛型类型JSON Schema生成bug
-# 已知问题: https://github.com/fastapi/sqlmodel/discussions/1002
-# 修复PR: https://github.com/fastapi/sqlmodel/pull/1275 (尚未合并)
-# 现象: SQLModel + Generic[T] 的 __pydantic_generic_metadata__ = {origin: None, args: ()}
-# 导致OpenAPI schema中泛型字段显示为{}而非正确的$ref
-# 当前方案: ListResponse继承BaseModel而非SQLModel (Discussion #1002推荐的workaround)
-# 未来: PR #1275合并后可改回继承SQLModelBase
-from pydantic import BaseModel, ConfigDict
-from fastapi import HTTPException
-from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct, delete as sql_delete, inspect
-from sqlalchemy.orm import selectinload, Relationship, with_polymorphic
-from sqlalchemy.orm.exc import StaleDataError
-from sqlmodel import Field, select
-
-from .optimistic_lock import OptimisticLockError
-from sqlmodel.ext.asyncio.session import AsyncSession
-from sqlalchemy.sql._typing import _OnClauseArgument
-from sqlalchemy.ext.asyncio import AsyncAttrs
-from sqlmodel.main import RelationshipInfo
-
-from .polymorphic import PolymorphicBaseMixin
-from sqlmodels.base.sqlmodel_base import SQLModelBase
-
-# Type variables for generic type hints, improving code completion and analysis.
-T = TypeVar("T", bound="TableBaseMixin")
-M = TypeVar("M", bound="SQLModelBase")
-ItemT = TypeVar("ItemT")
-
-
-class ListResponse(BaseModel, Generic[ItemT]):
- """
- 泛型分页响应
-
- 用于所有LIST端点的标准化响应格式,包含记录总数和项目列表。
- 与 TableBaseMixin.get_with_count() 配合使用。
-
- 使用示例:
- ```python
- @router.get("", response_model=ListResponse[CharacterInfoResponse])
- async def list_characters(...) -> ListResponse[Character]:
- return await Character.get_with_count(session, table_view=table_view)
- ```
-
- Attributes:
- count: 符合条件的记录总数(用于分页计算)
- items: 当前页的记录列表
-
- Note:
- 继承BaseModel而非SQLModelBase,因为SQLModel的metaclass与Generic冲突。
- 详见文件顶部TODO注释。
- """
- # 与SQLModelBase保持一致的配置
- model_config = ConfigDict(use_attribute_docstrings=True)
-
- count: int
- """符合条件的记录总数"""
-
- items: list[ItemT]
- """当前页的记录列表"""
-
-
-# Lambda functions to get the current time, used as default factories in model fields.
-now = lambda: datetime.now()
-now_date = lambda: datetime.now().date()
-
-
-# ==================== 查询参数请求类 ====================
-
-class TimeFilterRequest(SQLModelBase):
- """
- 时间筛选请求参数
-
- 用于 count() 等只需要时间筛选的场景。
- 纯数据类,只负责参数校验和携带,SQL子句构建由 TableBaseMixin 负责。
-
- Raises:
- ValueError: 时间范围无效
- """
- created_after_datetime: datetime | None = None
- """创建时间起始筛选(created_at >= datetime),如果为None则不限制"""
-
- created_before_datetime: datetime | None = None
- """创建时间结束筛选(created_at < datetime),如果为None则不限制"""
-
- updated_after_datetime: datetime | None = None
- """更新时间起始筛选(updated_at >= datetime),如果为None则不限制"""
-
- updated_before_datetime: datetime | None = None
- """更新时间结束筛选(updated_at < datetime),如果为None则不限制"""
-
- def model_post_init(self, __context: Any) -> None:
- """
- 验证时间范围有效性
-
- 验证规则:
- 1. 同类型:after 必须小于 before
- 2. 跨类型:created_after 不能大于 updated_before(记录不可能在创建前被更新)
- """
- # 同类型矛盾验证
- if self.created_after_datetime and self.created_before_datetime:
- if self.created_after_datetime >= self.created_before_datetime:
- raise ValueError("created_after_datetime 必须小于 created_before_datetime")
- if self.updated_after_datetime and self.updated_before_datetime:
- if self.updated_after_datetime >= self.updated_before_datetime:
- raise ValueError("updated_after_datetime 必须小于 updated_before_datetime")
-
- # 跨类型矛盾验证:created_after >= updated_before 意味着要求创建时间晚于或等于更新时间上界,逻辑矛盾
- if self.created_after_datetime and self.updated_before_datetime:
- if self.created_after_datetime >= self.updated_before_datetime:
- raise ValueError(
- "created_after_datetime 不能大于或等于 updated_before_datetime"
- "(记录的更新时间不可能早于或等于创建时间)"
- )
-
-
-class PaginationRequest(SQLModelBase):
- """
- 分页排序请求参数
-
- 用于需要分页和排序的场景。
- 纯数据类,只负责携带参数,SQL子句构建由 TableBaseMixin 负责。
- """
- offset: int | None = Field(default=0, ge=0)
- """偏移量(跳过前N条记录),必须为非负整数"""
-
- limit: int | None = Field(default=50, le=100)
- """每页数量(返回最多N条记录),默认50,最大100"""
-
- desc: bool | None = True
- """是否降序排序(True: 降序, False: 升序)"""
-
- order: Literal["created_at", "updated_at"] | None = "created_at"
- """排序字段(created_at: 创建时间, updated_at: 更新时间)"""
-
-
-class TableViewRequest(TimeFilterRequest, PaginationRequest):
- """
- 表格视图请求参数(分页、排序和时间筛选)
-
- 组合继承 TimeFilterRequest 和 PaginationRequest,用于 get() 等需要完整查询参数的场景。
- 纯数据类,SQL子句构建由 TableBaseMixin 负责。
-
- 使用示例:
- ```python
- # 在端点中使用依赖注入
- @router.get("/list")
- async def list_items(
- session: SessionDep,
- table_view: TableViewRequestDep
- ):
- items = await Item.get(
- session,
- fetch_mode="all",
- table_view=table_view
- )
- return items
-
- # 直接使用
- table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
- items = await Item.get(session, fetch_mode="all", table_view=table_view)
- ```
- """
- pass
-
-
-# ==================== TableBaseMixin ====================
-
-class TableBaseMixin(AsyncAttrs):
- """
- 一个异步 CRUD 操作的基础模型类 Mixin.
-
- 此类必须搭配SQLModelBase使用
-
- 此类为所有继承它的 SQLModel 模型提供了通用的数据库操作方法,
- 例如 add, save, update, delete, 和 get. 它还包括自动管理
- 的 `created_at` 和 `updated_at` 时间戳字段.
-
- Attributes:
- id (int | None): 整数主键, 自动递增.
- created_at (datetime): 记录创建时的时间戳, 自动设置.
- updated_at (datetime): 记录每次更新时的时间戳, 自动更新.
- """
- _has_table_mixin: ClassVar[bool] = True
- """标记此类继承了表混入类的内部属性"""
-
- def __init_subclass__(cls, **kwargs):
- """
- 接受并传递子类定义时的关键字参数
-
- 这允许元类 __DeclarativeMeta 处理的参数(如 table_args)
- 能够正确传递,而不会在 __init_subclass__ 阶段报错。
- """
- super().__init_subclass__(**kwargs)
-
- id: int | None = Field(default=None, primary_key=True)
-
- created_at: datetime = Field(default_factory=now)
- updated_at: datetime = Field(
- sa_type=DateTime,
- sa_column_kwargs={'default': now, 'onupdate': now},
- default_factory=now
- )
-
- @classmethod
- async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
- """
- 向数据库中添加一个新的或多个新的记录.
-
- 这个类方法可以接受单个模型实例或一个实例列表,并将它们
- 一次性提交到数据库中。执行后,可以选择性地刷新这些实例以获取
- 数据库生成的值(例如,自动递增的 ID).
-
- Args:
- session (AsyncSession): 用于数据库操作的异步会话对象.
- instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
- refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
-
- Returns:
- T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例.
-
- Usage:
- item1 = Item(name="Apple")
- item2 = Item(name="Banana")
-
- # 添加多个实例
- added_items = await Item.add(session, [item1, item2])
-
- # 添加单个实例
- item3 = Item(name="Cherry")
- added_item = await Item.add(session, item3)
- """
- is_list = False
- if isinstance(instances, list):
- is_list = True
- session.add_all(instances)
- else:
- session.add(instances)
-
- await session.commit()
-
- if refresh:
- if is_list:
- for instance in instances:
- await session.refresh(instance)
- else:
- await session.refresh(instances)
-
- return instances
-
- async def save(
- self: T,
- session: AsyncSession,
- load: RelationshipInfo | list[RelationshipInfo] | None = None,
- refresh: bool = True,
- commit: bool = True,
- jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
- optimistic_retry_count: int = 0,
- ) -> T:
- """
- 保存(插入或更新)当前模型实例到数据库.
-
- 这是一个实例方法,它将当前对象添加到会话中并提交更改。
- 可以用于创建新记录或更新现有记录。还可以选择在保存后
- 预加载(eager load)一个关联关系.
-
- **重要**:调用此方法后,session中的所有对象都会过期(expired)。
- 如果需要继续使用该对象,必须使用返回值:
-
- ```python
- # ✅ 正确:需要返回值时
- client = await client.save(session)
- return client
-
- # ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
- await client.save(session, refresh=False)
-
- # ✅ 正确:批量操作时延迟提交
- for item in items:
- item = await item.save(session, commit=False)
- await session.commit()
-
- # ✅ 正确:保存后需要访问多态关系时
- tool_set = await tool_set.save(session, load=ToolSet.tools, jti_subclasses='all')
- return tool_set # tools 关系已正确加载子类数据
-
- # ✅ 正确:启用乐观锁自动重试
- order = await order.save(session, optimistic_retry_count=3)
-
- # ❌ 错误:需要返回值但未使用
- await client.save(session)
- return client # client 对象已过期
- ```
-
- Args:
- session (AsyncSession): 用于数据库操作的异步会话对象.
- load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性.
- 例如 `User.posts`.
- refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值,
- 设为 False 可节省一次数据库查询。默认为 True.
- commit (bool): 是否在保存后提交事务。如果为 False,只会 flush 获取 ID
- 但不提交,适用于批量操作场景。默认为 True.
- jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。
- - list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
- - 'all': 两阶段查询,只加载实际关联的子类
- - None(默认): 不使用多态加载
- optimistic_retry_count (int): 乐观锁冲突时的自动重试次数。默认为 0(不重试)。
- 重试时会重新查询最新数据,将当前修改合并后再次保存。
-
- Returns:
- T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
-
- Raises:
- OptimisticLockError: 如果启用了乐观锁且版本号不匹配,且重试次数已耗尽
- """
- cls = type(self)
- instance = self
- retries_remaining = optimistic_retry_count
- current_data: dict[str, Any] | None = None # 延迟计算,仅在需要重试时
-
- while True:
- session.add(instance)
- try:
- if commit:
- await session.commit()
- else:
- await session.flush()
- break # 成功,退出循环
- except StaleDataError as e:
- await session.rollback()
- if retries_remaining <= 0:
- raise OptimisticLockError(
- message=f"{cls.__name__} 乐观锁冲突:记录已被其他事务修改",
- model_class=cls.__name__,
- record_id=str(getattr(instance, 'id', None)),
- expected_version=getattr(instance, 'version', None),
- original_error=e,
- ) from e
-
- # 失败后重试:重新查询最新数据并合并修改
- retries_remaining -= 1
- if current_data is None:
- current_data = self.model_dump(exclude={'id', 'version', 'created_at', 'updated_at'})
-
- fresh = await cls.get(session, cls.id == self.id)
- if fresh is None:
- raise OptimisticLockError(
- message=f"{cls.__name__} 重试失败:记录已被删除",
- model_class=cls.__name__,
- record_id=str(getattr(self, 'id', None)),
- original_error=e,
- ) from e
-
- for key, value in current_data.items():
- if hasattr(fresh, key):
- setattr(fresh, key, value)
- instance = fresh
-
- if not refresh:
- return instance
-
- if load is not None:
- await session.refresh(instance)
- return await cls.get(session, cls.id == instance.id, load=load, jti_subclasses=jti_subclasses)
- else:
- await session.refresh(instance)
- return instance
-
- async def update(
- self: T,
- session: AsyncSession,
- other: M,
- extra_data: dict[str, Any] | None = None,
- exclude_unset: bool = True,
- exclude: set[str] | None = None,
- load: RelationshipInfo | list[RelationshipInfo] | None = None,
- refresh: bool = True,
- commit: bool = True,
- jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
- optimistic_retry_count: int = 0,
- ) -> T:
- """
- 使用另一个模型实例或字典中的数据来更新当前实例.
-
- 此方法将 `other` 对象中的数据合并到当前实例中。默认情况下,
- 它只会更新 `other` 中被显式设置的字段.
-
- **重要**:调用此方法后,session中的所有对象都会过期(expired)。
- 如果需要继续使用该对象,必须使用返回值:
-
- ```python
- # ✅ 正确:需要返回值时
- client = await client.update(session, update_data)
- return client
-
- # ✅ 正确:需要返回值且需要加载关系时
- user = await user.update(session, update_data, load=User.permission)
- return user
-
- # ✅ 正确:更新后需要访问多态关系时
- tool_set = await tool_set.update(session, data, load=ToolSet.tools, jti_subclasses='all')
- return tool_set # tools 关系已正确加载子类数据
-
- # ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
- await client.update(session, update_data, refresh=False)
-
- # ✅ 正确:批量操作时延迟提交
- for item in items:
- item = await item.update(session, data, commit=False)
- await session.commit()
-
- # ✅ 正确:启用乐观锁自动重试
- order = await order.update(session, update_data, optimistic_retry_count=3)
-
- # ❌ 错误:需要返回值但未使用
- await client.update(session, update_data)
- return client # client 对象已过期
- ```
-
- Args:
- session (AsyncSession): 用于数据库操作的异步会话对象.
- other (M): 一个 SQLModel 或 Pydantic 模型实例,其数据将用于更新当前实例.
- extra_data (dict, optional): 一个额外的字典,用于更新当前实例的特定字段.
- exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供)
- 的字段将被忽略. 默认为 True.
- exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}.
- load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性.
- 例如 `User.permission`.
- refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值,
- 设为 False 可节省一次数据库查询。默认为 True.
- commit (bool): 是否在更新后提交事务。如果为 False,只会 flush
- 但不提交,适用于批量操作场景。默认为 True.
- jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。
- - list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
- - 'all': 两阶段查询,只加载实际关联的子类
- - None(默认): 不使用多态加载
- optimistic_retry_count (int): 乐观锁冲突时的自动重试次数。默认为 0(不重试)。
- 重试时会重新查询最新数据,将 other 的更新重新应用后再次保存。
-
- Returns:
- T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
-
- Raises:
- OptimisticLockError: 如果启用了乐观锁且版本号不匹配,且重试次数已耗尽
- """
- cls = type(self)
- update_data = other.model_dump(exclude_unset=exclude_unset, exclude=exclude)
- instance = self
- retries_remaining = optimistic_retry_count
-
- while True:
- instance.sqlmodel_update(update_data, update=extra_data)
- session.add(instance)
-
- try:
- if commit:
- await session.commit()
- else:
- await session.flush()
- break # 成功,退出循环
- except StaleDataError as e:
- await session.rollback()
- if retries_remaining <= 0:
- raise OptimisticLockError(
- message=f"{cls.__name__} 乐观锁冲突:记录已被其他事务修改",
- model_class=cls.__name__,
- record_id=str(getattr(instance, 'id', None)),
- expected_version=getattr(instance, 'version', None),
- original_error=e,
- ) from e
-
- # 失败后重试:重新查询最新数据并重新应用更新
- retries_remaining -= 1
- fresh = await cls.get(session, cls.id == self.id)
- if fresh is None:
- raise OptimisticLockError(
- message=f"{cls.__name__} 重试失败:记录已被删除",
- model_class=cls.__name__,
- record_id=str(getattr(self, 'id', None)),
- original_error=e,
- ) from e
- instance = fresh
-
- if not refresh:
- return instance
-
- if load is not None:
- await session.refresh(instance)
- return await cls.get(session, cls.id == instance.id, load=load, jti_subclasses=jti_subclasses)
- else:
- await session.refresh(instance)
- return instance
-
- @classmethod
- async def delete(
- cls: type[T],
- session: AsyncSession,
- instances: T | list[T] | None = None,
- *,
- condition: BinaryExpression | ClauseElement | None = None,
- commit: bool = True,
- ) -> int:
- """
- 从数据库中删除记录,支持实例删除和条件删除两种模式。
-
- Args:
- session: 用于数据库操作的异步会话对象
- instances: 要删除的单个模型实例或模型实例列表(实例删除模式)
- condition: WHERE 条件表达式(条件删除模式,直接执行 SQL DELETE)
- commit: 是否在删除后提交事务。默认为 True
-
- Returns:
- 删除的记录数(条件删除模式返回实际删除数,实例删除模式返回实例数)
-
- Raises:
- ValueError: 同时提供 instances 和 condition,或两者都未提供
-
- Usage:
- # 实例删除模式
- item = await Item.get(session, Item.id == 1)
- if item:
- await Item.delete(session, item)
-
- items = await Item.get(session, Item.name.in_(["A", "B"]), fetch_mode="all")
- if items:
- await Item.delete(session, items)
-
- # 条件删除模式(高效批量删除,不加载实例到内存)
- deleted_count = await Item.delete(
- session,
- condition=(Item.user_id == user_id) & (Item.status == "expired"),
- )
- """
- if instances is not None and condition is not None:
- raise ValueError("不能同时提供 instances 和 condition 参数")
- if instances is None and condition is None:
- raise ValueError("必须提供 instances 或 condition 参数之一")
-
- deleted_count = 0
-
- if condition is not None:
- # 条件删除模式:直接执行 SQL DELETE
- stmt = sql_delete(cls).where(condition)
- result = await session.execute(stmt)
- deleted_count = result.rowcount
- else:
- # 实例删除模式
- if isinstance(instances, list):
- for instance in instances:
- await session.delete(instance)
- deleted_count = len(instances)
- else:
- await session.delete(instances)
- deleted_count = 1
-
- if commit:
- await session.commit()
-
- return deleted_count
-
- @classmethod
- def _build_time_filters(
- cls: type[T],
- created_before_datetime: datetime | None = None,
- created_after_datetime: datetime | None = None,
- updated_before_datetime: datetime | None = None,
- updated_after_datetime: datetime | None = None,
- ) -> list[BinaryExpression]:
- """
- 构建时间筛选条件列表
-
- Args:
- created_before_datetime: 筛选 created_at < datetime 的记录
- created_after_datetime: 筛选 created_at >= datetime 的记录
- updated_before_datetime: 筛选 updated_at < datetime 的记录
- updated_after_datetime: 筛选 updated_at >= datetime 的记录
-
- Returns:
- BinaryExpression 条件列表
- """
- filters: list[BinaryExpression] = []
- if created_after_datetime is not None:
- filters.append(cls.created_at >= created_after_datetime)
- if created_before_datetime is not None:
- filters.append(cls.created_at < created_before_datetime)
- if updated_after_datetime is not None:
- filters.append(cls.updated_at >= updated_after_datetime)
- if updated_before_datetime is not None:
- filters.append(cls.updated_at < updated_before_datetime)
- return filters
-
- @classmethod
- async def get(
- cls: type[T],
- session: AsyncSession,
- condition: BinaryExpression | ClauseElement | None = None,
- *,
- offset: int | None = None,
- limit: int | None = None,
- fetch_mode: Literal["one", "first", "all"] = "first",
- join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
- options: list | None = None,
- load: RelationshipInfo | list[RelationshipInfo] | None = None,
- order_by: list[ClauseElement] | None = None,
- filter: BinaryExpression | ClauseElement | None = None,
- with_for_update: bool = False,
- table_view: TableViewRequest | None = None,
- jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
- populate_existing: bool = False,
- created_before_datetime: datetime | None = None,
- created_after_datetime: datetime | None = None,
- updated_before_datetime: datetime | None = None,
- updated_after_datetime: datetime | None = None,
- ) -> T | list[T] | None:
- """
- 根据指定的条件异步地从数据库中获取一个或多个模型实例.
-
- 这是一个功能强大的通用查询方法,支持过滤、排序、分页、连接查询和关联关系预加载.
-
- Args:
- session (AsyncSession): 用于数据库操作的异步会话对象.
- condition (BinaryExpression | ClauseElement | None): 主要的查询过滤条件,
- 例如 `User.id == 1`。
- 当为 `None` 时,表示无条件查询(查询所有记录)。
- offset (int | None): 查询结果的起始偏移量, 用于分页.
- limit (int | None): 返回记录的最大数量, 用于分页.
- fetch_mode (Literal["one", "first", "all"]):
- - "one": 获取唯一的一条记录. 如果找不到或找到多条,会引发异常.
- - "first": 获取查询结果的第一条记录. 如果找不到,返回 `None`.
- - "all": 获取所有匹配的记录,返回一个列表.
- 默认为 "first".
- join (type[T] | tuple[type[T], _OnClauseArgument] | None):
- 要 JOIN 的模型类或一个包含模型类和 ON 子句的元组.
- 例如 `User` 或 `(Profile, User.id == Profile.user_id)`.
- options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
- 例如 `[selectinload(User.posts)]`.
- load (Relationship | list[Relationship] | None): `selectinload` 的快捷方式,用于预加载关联关系.
- 可以是单个关系或关系列表。支持嵌套关系预加载:
- 当传入多个关系时,会自动检测依赖关系并构建链式 selectinload。
- 例如 `[NodeGroupNode.element_links, NodeGroupElementLink.node]`
- 会自动构建 `selectinload(element_links).selectinload(node)`。
- order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
- 例如 `[User.name.asc(), User.created_at.desc()]`.
- filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
-
- with_for_update (bool): 如果为 True, 在查询中使用 `FOR UPDATE` 锁定选定的行. 默认为 False.
-
- table_view (TableViewRequest | None): TableViewRequest对象,如果提供则自动处理分页、排序和时间筛选。
- 会覆盖offset、limit、order_by及时间筛选参数。
- 这是推荐的分页排序方式,统一了所有LIST端点的参数格式。
-
- jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。
- - list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
- - 'all': 两阶段查询,只加载实际关联的子类(对于 > 10 个子类的场景有明显性能收益)
- - None(默认): 不使用多态加载
-
- populate_existing (bool): 如果为 True,强制用数据库数据覆盖 session 中已存在的对象(identity map)。
- 用于批量刷新对象,避免循环调用 session.refresh() 导致的 N 次查询。
- 注意:只刷新标量字段,不影响运行时属性(_开头的属性)。
- 对于 STI(单表继承)对象,推荐按子类分组查询以包含子类字段。默认为 False。
-
- created_before_datetime (datetime | None): 筛选 created_at < datetime 的记录
- created_after_datetime (datetime | None): 筛选 created_at >= datetime 的记录
- updated_before_datetime (datetime | None): 筛选 updated_at < datetime 的记录
- updated_after_datetime (datetime | None): 筛选 updated_at >= datetime 的记录
-
- Returns:
- T | list[T] | None: 根据 `fetch_mode` 的设置,返回单个实例、实例列表或 `None`.
-
- Raises:
- ValueError: 如果提供了无效的 `fetch_mode` 值,或 jti_subclasses 未与 load 配合使用.
-
- Examples:
- # 使用table_view参数(推荐)
- users = await User.get(session, fetch_mode="all", table_view=table_view_args)
-
- # 传统方式(向后兼容)
- users = await User.get(session, fetch_mode="all", offset=0, limit=20, order_by=[desc(User.created_at)])
-
- # 使用多态加载(加载联表继承的子类数据)
- tool_set = await ToolSet.get(
- session,
- ToolSet.id == tool_set_id,
- load=ToolSet.tools,
- jti_subclasses='all' # 只加载实际关联的子类
- )
- """
- # 参数验证:jti_subclasses 需要与 load 配合使用
- if jti_subclasses is not None and load is None:
- raise ValueError(
- "jti_subclasses 参数需要与 load 参数配合使用,"
- "请同时指定要加载的关系"
- )
-
- # 如果提供table_view,作为默认值使用(单独传入的参数优先级更高)
- if table_view:
- # 处理时间筛选(TimeFilterRequest 及其子类)
- if isinstance(table_view, TimeFilterRequest):
- if created_after_datetime is None and table_view.created_after_datetime is not None:
- created_after_datetime = table_view.created_after_datetime
- if created_before_datetime is None and table_view.created_before_datetime is not None:
- created_before_datetime = table_view.created_before_datetime
- if updated_after_datetime is None and table_view.updated_after_datetime is not None:
- updated_after_datetime = table_view.updated_after_datetime
- if updated_before_datetime is None and table_view.updated_before_datetime is not None:
- updated_before_datetime = table_view.updated_before_datetime
- # 处理分页排序(PaginationRequest 及其子类,包括 TableViewRequest)
- if isinstance(table_view, PaginationRequest):
- if offset is None:
- offset = table_view.offset
- if limit is None:
- limit = table_view.limit
- # 仅在未显式传入order_by时,从table_view构建排序子句
- if order_by is None:
- order_column = cls.created_at if table_view.order == "created_at" else cls.updated_at
- order_by = [desc(order_column) if table_view.desc else asc(order_column)]
-
- # 对于多态基类,使用 with_polymorphic 预加载所有子类的列
- # 这避免了在响应序列化时的延迟加载问题(MissingGreenlet 错误)
- polymorphic_cls = None # 保存多态实体,用于子类关系预加载
- is_polymorphic = issubclass(cls, PolymorphicBaseMixin)
- is_jti = is_polymorphic and cls._is_joined_table_inheritance()
- is_sti = is_polymorphic and not cls._is_joined_table_inheritance()
-
- # JTI 模式:总是使用 with_polymorphic(避免 N+1 查询)
- # STI 模式:不使用 with_polymorphic(批量刷新时请按子类分组查询)
- if is_jti:
- # '*' 表示加载所有子类
- polymorphic_cls = with_polymorphic(cls, '*')
- statement = select(polymorphic_cls)
- else:
- statement = select(cls)
-
- # 对于 STI(单表继承)子类,自动添加多态过滤条件
- # SQLAlchemy/SQLModel 在 STI 模式下不会自动添加 WHERE discriminator = 'identity' 过滤
- # 这是已知行为,参考:
- # - https://github.com/sqlalchemy/sqlalchemy/issues/5018 (bulk operations 不自动添加多态过滤)
- # - https://github.com/fastapi/sqlmodel/issues/488 (SQLModel STI 支持不完整)
- # 社区最佳实践是显式添加多态过滤条件
- if issubclass(cls, PolymorphicBaseMixin) and not cls._is_joined_table_inheritance():
- mapper = inspect(cls)
- # 检查是否有 polymorphic_identity 且不是抽象类
- if mapper.polymorphic_identity is not None and not mapper.polymorphic_abstract:
- poly_on = mapper.polymorphic_on
- if poly_on is not None:
- statement = statement.where(poly_on == mapper.polymorphic_identity)
-
- if condition is not None:
- statement = statement.where(condition)
-
- # 应用时间筛选
- for time_filter in cls._build_time_filters(
- created_before_datetime, created_after_datetime,
- updated_before_datetime, updated_after_datetime
- ):
- statement = statement.where(time_filter)
-
- if join is not None:
- # 如果 join 是一个元组,解包它;否则直接使用
- if isinstance(join, tuple):
- statement = statement.join(*join)
- else:
- statement = statement.join(join)
-
-
- if options:
- statement = statement.options(*options)
-
- if load:
- # 标准化为列表
- load_list = load if isinstance(load, list) else [load]
-
- # 构建链式 selectinload(支持嵌套关系预加载)
- # 例如:load=[NodeGroupNode.element_links, NodeGroupElementLink.node]
- # 会构建:selectinload(element_links).selectinload(node)
- load_chains = cls._build_load_chains(load_list)
-
- # 处理多态加载(仅支持单链且只有一个关系)
- if jti_subclasses is not None:
- if len(load_chains) > 1 or len(load_chains[0]) > 1:
- raise ValueError(
- "jti_subclasses 仅支持单个关系(无嵌套链),请不要传入多个关系"
- )
- single_load = load_chains[0][0]
- target_class = single_load.property.mapper.class_
-
- # 检查目标类是否继承自 PolymorphicBaseMixin
- if not issubclass(target_class, PolymorphicBaseMixin):
- raise ValueError(
- f"目标类 {target_class.__name__} 不是多态类,"
- f"请确保其继承自 PolymorphicBaseMixin"
- )
-
- if jti_subclasses == 'all':
- # 两阶段查询:获取实际关联的多态类型
- subclasses_to_load = await cls._resolve_polymorphic_subclasses(
- session, condition, single_load, target_class
- )
- else:
- subclasses_to_load = jti_subclasses
-
- if subclasses_to_load:
- # 关键:selectin_polymorphic 必须作为 selectinload 的链式子选项
- # 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading
- statement = statement.options(
- selectinload(single_load).selectin_polymorphic(subclasses_to_load)
- )
- else:
- statement = statement.options(selectinload(single_load))
- else:
- # 为每条链构建链式 selectinload
- for chain in load_chains:
- # 获取第一个关系并检查是否需要通过多态实体访问
- first_rel = chain[0]
- first_rel_parent = first_rel.property.parent.class_
-
- # 如果关系的 parent_class 是当前类的子类(不是 cls 本身),
- # 且当前是多态查询,则需要通过 polymorphic_cls.SubclassName 访问
- if (
- polymorphic_cls is not None
- and first_rel_parent is not cls
- and issubclass(first_rel_parent, cls)
- ):
- # 通过多态实体访问子类的关系属性
- # 例如:polymorphic_cls.NodeGroupNode.element_links
- subclass_alias = getattr(polymorphic_cls, first_rel_parent.__name__)
- rel_name = first_rel.key
- first_rel_via_poly = getattr(subclass_alias, rel_name)
- loader = selectinload(first_rel_via_poly)
- else:
- loader = selectinload(first_rel)
-
- for rel in chain[1:]:
- loader = loader.selectinload(rel)
- statement = statement.options(loader)
-
- if order_by is not None:
- statement = statement.order_by(*order_by)
-
- if offset:
- statement = statement.offset(offset)
-
- if limit:
- statement = statement.limit(limit)
-
- if filter:
- statement = statement.filter(filter)
-
- if with_for_update:
- # 对于联表继承的多态模型,使用 FOR UPDATE OF <主表> 来避免 PostgreSQL 的限制
- # PostgreSQL 不支持在 LEFT OUTER JOIN 的可空侧使用 FOR UPDATE
- if issubclass(cls, PolymorphicBaseMixin):
- statement = statement.with_for_update(of=cls)
- else:
- statement = statement.with_for_update()
-
- if populate_existing:
- # 强制用数据库数据覆盖 identity map 中的对象
- # 用于批量刷新,避免循环 refresh() 的 N 次查询
- statement = statement.execution_options(populate_existing=True)
-
- result = await session.exec(statement)
-
- if fetch_mode == "one":
- return result.one()
- elif fetch_mode == "first":
- return result.first()
- elif fetch_mode == "all":
- return list(result.all())
- else:
- raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
-
- @staticmethod
- def _build_load_chains(load_list: list[RelationshipInfo]) -> list[list[RelationshipInfo]]:
- """
- 将关系列表构建为链式加载结构
-
- 自动检测关系之间的依赖关系,构建嵌套预加载链。
- 例如:[NodeGroupNode.element_links, NodeGroupElementLink.node]
- 会构建:[[element_links, node]](一条链)
-
- 算法:
- 1. 获取每个关系的 parent class 和 target class
- 2. 如果关系 B 的 parent class 等于关系 A 的 target class,则 B 链在 A 后面
- 3. 独立的关系各自成为一条链
-
- Args:
- load_list: 关系属性列表
-
- Returns:
- 链式关系列表,每条链是一个关系列表
- """
- if not load_list:
- return []
-
- # 构建关系信息:{关系: (parent_class, target_class)}
- rel_info: dict[RelationshipInfo, tuple[type, type]] = {}
- for rel in load_list:
- parent_class = rel.property.parent.class_
- target_class = rel.property.mapper.class_
- rel_info[rel] = (parent_class, target_class)
-
- # 构建依赖图:{关系: 其前置关系}
- predecessors: dict[RelationshipInfo, RelationshipInfo | None] = {rel: None for rel in load_list}
- for rel_b in load_list:
- parent_b, _ = rel_info[rel_b]
- for rel_a in load_list:
- if rel_a is rel_b:
- continue
- _, target_a = rel_info[rel_a]
- # 如果 B 的 parent 精确等于 A 的 target,则 B 链在 A 后面
- # 使用精确匹配避免继承关系导致的误判(如 NodeGroupNode 是 CanvasNode 子类)
- if parent_b is target_a:
- predecessors[rel_b] = rel_a
- break
-
- # 找出所有链的起点(没有前置关系的)
- roots = [rel for rel, pred in predecessors.items() if pred is None]
-
- # 构建链
- chains: list[list[RelationshipInfo]] = []
- used: set[RelationshipInfo] = set()
-
- for root in roots:
- chain = [root]
- used.add(root)
- # 找后续节点
- current = root
- while True:
- # 找以 current 的 target 为 parent 的关系
- _, current_target = rel_info[current]
- next_rel = None
- for rel, (parent, _) in rel_info.items():
- if rel not in used and parent is current_target:
- next_rel = rel
- break
- if next_rel is None:
- break
- chain.append(next_rel)
- used.add(next_rel)
- current = next_rel
- chains.append(chain)
-
- return chains
-
- @classmethod
- async def _resolve_polymorphic_subclasses(
- cls: type[T],
- session: AsyncSession,
- condition: BinaryExpression | ClauseElement | None,
- load: RelationshipInfo,
- target_class: type[PolymorphicBaseMixin]
- ) -> list[type[PolymorphicBaseMixin]]:
- """
- 查询实际关联的多态子类类型
-
- 通过查询多态鉴别字段确定实际存在的子类类型,
- 避免加载所有可能的子类表(对于 > 10 个子类的场景有明显收益)。
-
- :param session: 数据库会话
- :param condition: 主查询的条件
- :param load: 关系属性
- :param target_class: 多态基类
- :return: 实际关联的子类列表
- """
- # 获取多态鉴别字段(会抛出 ValueError 如果未配置)
- discriminator = target_class.get_polymorphic_discriminator()
- poly_name_col = getattr(target_class, discriminator)
-
- # 获取关系属性
- relationship_property = load.property
-
- # 构建查询获取实际的多态类型名称
- if relationship_property.secondary is not None:
- # 多对多关系:通过中间表查询
- secondary = relationship_property.secondary
- local_cols = list(relationship_property.local_columns)
-
- type_query = (
- select(distinct(poly_name_col))
- .select_from(target_class)
- .join(secondary)
- .where(secondary.c[local_cols[0].name].in_(
- select(cls.id).where(condition) if condition is not None else select(cls.id)
- ))
- )
- else:
- # 多对一/一对多关系:通过外键查询
- # local_remote_pairs[0] = (local_fk_col, remote_pk_col)
- # 对于多对一:local 是当前类的外键,remote 是目标类的主键
- local_fk_col = relationship_property.local_remote_pairs[0][0]
- remote_pk_col = relationship_property.local_remote_pairs[0][1]
- type_query = (
- select(distinct(poly_name_col))
- .where(remote_pk_col.in_(
- select(local_fk_col).where(condition) if condition is not None else select(local_fk_col)
- ))
- )
-
- type_result = await session.exec(type_query)
- poly_names = list(type_result.all())
-
- if not poly_names:
- return []
-
- # 映射到子类(包含所有层级的具体子类)
- identity_map = target_class.get_identity_to_class_map()
- return [identity_map[name] for name in poly_names if name in identity_map]
-
- @classmethod
- async def count(
- cls: type[T],
- session: AsyncSession,
- condition: BinaryExpression | ClauseElement | None = None,
- *,
- time_filter: TimeFilterRequest | None = None,
- created_before_datetime: datetime | None = None,
- created_after_datetime: datetime | None = None,
- updated_before_datetime: datetime | None = None,
- updated_after_datetime: datetime | None = None,
- ) -> int:
- """
- 根据条件统计记录数量(支持时间筛选)
-
- 使用数据库层面的 COUNT() 聚合函数,比 get() + len() 更高效。
-
- Args:
- session: 数据库会话
- condition: 查询条件,例如 `User.is_active == True`
- time_filter: TimeFilterRequest 对象(优先级更高)
- created_before_datetime: 筛选 created_at < datetime 的记录
- created_after_datetime: 筛选 created_at >= datetime 的记录
- updated_before_datetime: 筛选 updated_at < datetime 的记录
- updated_after_datetime: 筛选 updated_at >= datetime 的记录
-
- Returns:
- 符合条件的记录数量
-
- Examples:
- # 统计所有用户
- total = await User.count(session)
-
- # 统计激活的虚拟客户端
- count = await Client.count(
- session,
- (Client.user_id == user_id) & (Client.type != ClientTypeEnum.physical) & (Client.is_active == True)
- )
-
- # 使用 TimeFilterRequest 进行时间筛选
- count = await User.count(session, time_filter=time_filter_request)
-
- # 使用独立时间参数
- count = await User.count(
- session,
- created_after_datetime=datetime(2025, 1, 1),
- created_before_datetime=datetime(2025, 2, 1),
- )
- """
- # time_filter 的时间筛选优先级更高
- if isinstance(time_filter, TimeFilterRequest):
- if time_filter.created_after_datetime is not None:
- created_after_datetime = time_filter.created_after_datetime
- if time_filter.created_before_datetime is not None:
- created_before_datetime = time_filter.created_before_datetime
- if time_filter.updated_after_datetime is not None:
- updated_after_datetime = time_filter.updated_after_datetime
- if time_filter.updated_before_datetime is not None:
- updated_before_datetime = time_filter.updated_before_datetime
-
- statement = select(func.count()).select_from(cls)
-
- # 应用查询条件
- if condition is not None:
- statement = statement.where(condition)
-
- # 应用时间筛选
- for time_condition in cls._build_time_filters(
- created_before_datetime, created_after_datetime,
- updated_before_datetime, updated_after_datetime
- ):
- statement = statement.where(time_condition)
-
- result = await session.scalar(statement)
- return result or 0
-
- @classmethod
- async def get_with_count(
- cls: type[T],
- session: AsyncSession,
- condition: BinaryExpression | ClauseElement | None = None,
- *,
- join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
- options: list | None = None,
- load: RelationshipInfo | list[RelationshipInfo] | None = None,
- order_by: list[ClauseElement] | None = None,
- filter: BinaryExpression | ClauseElement | None = None,
- table_view: TableViewRequest | None = None,
- jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
- ) -> 'ListResponse[T]':
- """
- 获取分页列表及总数,直接返回 ListResponse
-
- 同时返回符合条件的记录列表和总数,用于分页场景。
- 与 get() 方法类似,但固定 fetch_mode="all" 并返回 ListResponse。
-
- 注意:如果子类的 get() 方法支持额外参数(如 filter_params),
- 子类应该覆盖此方法以确保 count 和 items 使用相同的过滤条件。
-
- Args:
- session: 数据库会话
- condition: 查询条件
- join: JOIN 的模型类或元组
- options: SQLAlchemy 查询选项
- load: selectinload 预加载关系
- order_by: 排序子句
- filter: 附加过滤条件
- table_view: 分页排序参数(推荐使用)
- jti_subclasses: 多态子类加载选项
-
- Returns:
- ListResponse[T]: 包含 count 和 items 的分页响应
-
- Examples:
- ```python
- @router.get("", response_model=ListResponse[CharacterInfoResponse])
- async def list_characters(
- session: SessionDep,
- table_view: TableViewRequestDep
- ) -> ListResponse[Character]:
- return await Character.get_with_count(session, table_view=table_view)
- ```
- """
- # 提取时间筛选参数(用于 count)
- time_filter: TimeFilterRequest | None = None
- if table_view is not None:
- time_filter = TimeFilterRequest(
- created_after_datetime=table_view.created_after_datetime,
- created_before_datetime=table_view.created_before_datetime,
- updated_after_datetime=table_view.updated_after_datetime,
- updated_before_datetime=table_view.updated_before_datetime,
- )
-
- # 获取总数(不带分页限制)
- total_count = await cls.count(session, condition, time_filter=time_filter)
-
- # 获取分页数据
- items = await cls.get(
- session,
- condition,
- fetch_mode="all",
- join=join,
- options=options,
- load=load,
- order_by=order_by,
- filter=filter,
- table_view=table_view,
- jti_subclasses=jti_subclasses,
- )
-
- return ListResponse(count=total_count, items=items)
-
- @classmethod
- async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | list[RelationshipInfo] | None = None) -> T:
- """
- 根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常.
-
- 这个方法是对 `get` 方法的封装,专门用于处理那种"记录必须存在"的业务场景。
- 如果记录未找到,它会直接引发 FastAPI 的 `HTTPException`, 而不是返回 `None`.
-
- Args:
- session (AsyncSession): 用于数据库操作的异步会话对象.
- id (int): 要查找的记录的主键 ID.
- load (Relationship | None): 可选的,用于预加载的关联属性.
-
- Returns:
- T: 找到的模型实例.
-
- Raises:
- HTTPException: 如果 ID 对应的记录不存在,则抛出状态码为 404 的异常.
- """
- instance = await cls.get(session, cls.id == id, load=load)
- if not instance:
- raise HTTPException(status_code=404, detail="Not found")
- return instance
-
-class UUIDTableBaseMixin(TableBaseMixin):
- """
- 一个使用 UUID 作为主键的异步 CRUD 操作基础模型类 Mixin.
-
- 此类继承自 `TableBaseMixin`, 将主键 `id` 的类型覆盖为 `uuid.UUID`,
- 并为新记录自动生成 UUID. 它继承了 `TableBaseMixin` 的所有 CRUD 方法.
-
- Attributes:
- id (uuid.UUID): UUID 类型的主键, 在创建时自动生成.
- """
- id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
- """覆盖 `TableBaseMixin` 的 id 字段,使用 UUID 作为主键."""
-
- @override
- @classmethod
- async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T:
- """
- 根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
-
- 此方法覆盖了父类的同名方法,以确保 `id` 参数的类型注解为 `uuid.UUID`,
- 从而提供更好的类型安全和代码提示.
-
- Args:
- session (AsyncSession): 用于数据库操作的异步会话对象.
- id (uuid.UUID): 要查找的记录的 UUID 主键.
- load (Relationship | None): 可选的,用于预加载的关联属性.
-
- Returns:
- T: 找到的模型实例.
-
- Raises:
- HTTPException: 如果 UUID 对应的记录不存在,则抛出状态码为 404 的异常.
- """
- # 类型检查器可能会警告这里的 `id` 类型不匹配超类方法,
- # 但在运行时这是正确的,因为超类方法内部的比较 (cls.id == id)
- # 会正确处理 UUID 类型。`type: ignore` 用于抑制此警告。
- return await super().get_exist_one(session, id, load) # type: ignore
diff --git a/sqlmodels/model_base.py b/sqlmodels/model_base.py
index 0851f8e..3230ad6 100644
--- a/sqlmodels/model_base.py
+++ b/sqlmodels/model_base.py
@@ -4,7 +4,7 @@ from enum import StrEnum
from sqlmodel import Field
-from .base import SQLModelBase
+from sqlmodel_ext import SQLModelBase
class ResponseBase(SQLModelBase):
diff --git a/sqlmodels/node.py b/sqlmodels/node.py
index 26679f5..96fa787 100644
--- a/sqlmodels/node.py
+++ b/sqlmodels/node.py
@@ -3,8 +3,7 @@ from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, text, Index
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .download import Download
diff --git a/sqlmodels/object.py b/sqlmodels/object.py
index 5752596..ada4f8a 100644
--- a/sqlmodels/object.py
+++ b/sqlmodels/object.py
@@ -7,8 +7,7 @@ from enum import StrEnum
from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, CheckConstraint, Index, text
-from .base import SQLModelBase
-from .mixin import UUIDTableBaseMixin
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
if TYPE_CHECKING:
from .user import User
diff --git a/sqlmodels/order.py b/sqlmodels/order.py
index 4af80e5..d93ff3e 100644
--- a/sqlmodels/order.py
+++ b/sqlmodels/order.py
@@ -4,8 +4,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .user import User
diff --git a/sqlmodels/physical_file.py b/sqlmodels/physical_file.py
index 187039b..455d2bd 100644
--- a/sqlmodels/physical_file.py
+++ b/sqlmodels/physical_file.py
@@ -15,8 +15,7 @@ from uuid import UUID
from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, Index
-from .base import SQLModelBase
-from .mixin import UUIDTableBaseMixin
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
if TYPE_CHECKING:
from .object import Object
diff --git a/sqlmodels/policy.py b/sqlmodels/policy.py
index 9893d4c..c65953f 100644
--- a/sqlmodels/policy.py
+++ b/sqlmodels/policy.py
@@ -4,8 +4,7 @@ from uuid import UUID
from enum import StrEnum
from sqlmodel import Field, Relationship, text
-from .base import SQLModelBase
-from .mixin import UUIDTableBaseMixin
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
if TYPE_CHECKING:
from .object import Object
diff --git a/sqlmodels/redeem.py b/sqlmodels/redeem.py
index 574eec6..e672b6c 100644
--- a/sqlmodels/redeem.py
+++ b/sqlmodels/redeem.py
@@ -2,8 +2,7 @@ from enum import StrEnum
from sqlmodel import Field, text
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin
class RedeemType(StrEnum):
diff --git a/sqlmodels/report.py b/sqlmodels/report.py
index c928bdf..885d720 100644
--- a/sqlmodels/report.py
+++ b/sqlmodels/report.py
@@ -4,8 +4,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .share import Share
diff --git a/sqlmodels/setting.py b/sqlmodels/setting.py
index b4375c9..035f8da 100644
--- a/sqlmodels/setting.py
+++ b/sqlmodels/setting.py
@@ -2,9 +2,9 @@ from enum import StrEnum
from sqlmodel import UniqueConstraint
+from sqlmodel_ext import SQLModelBase, TableBaseMixin
+
from .auth_identity import AuthProviderType
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
from .user import UserResponse
class CaptchaType(StrEnum):
diff --git a/sqlmodels/share.py b/sqlmodels/share.py
index 8e757be..7045c67 100644
--- a/sqlmodels/share.py
+++ b/sqlmodels/share.py
@@ -5,9 +5,9 @@ from uuid import UUID
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
-from .base import SQLModelBase
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
+
from .model_base import ResponseBase
-from .mixin import UUIDTableBaseMixin
from .object import ObjectType
if TYPE_CHECKING:
diff --git a/sqlmodels/source_link.py b/sqlmodels/source_link.py
index 879b9f7..e121ff9 100644
--- a/sqlmodels/source_link.py
+++ b/sqlmodels/source_link.py
@@ -4,8 +4,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, Index
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .object import Object
diff --git a/sqlmodels/storage_pack.py b/sqlmodels/storage_pack.py
index 01edd6f..c21830c 100644
--- a/sqlmodels/storage_pack.py
+++ b/sqlmodels/storage_pack.py
@@ -5,8 +5,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, Column, func, DateTime
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .user import User
diff --git a/sqlmodels/tag.py b/sqlmodels/tag.py
index 2b1b792..b83e485 100644
--- a/sqlmodels/tag.py
+++ b/sqlmodels/tag.py
@@ -5,8 +5,7 @@ from datetime import datetime
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .user import User
diff --git a/sqlmodels/task.py b/sqlmodels/task.py
index c1cf261..980c3f8 100644
--- a/sqlmodels/task.py
+++ b/sqlmodels/task.py
@@ -5,8 +5,7 @@ from datetime import datetime
from sqlmodel import Field, Relationship, CheckConstraint, Index
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .download import Download
diff --git a/sqlmodels/theme_preset.py b/sqlmodels/theme_preset.py
index bfec1fc..9d8475e 100644
--- a/sqlmodels/theme_preset.py
+++ b/sqlmodels/theme_preset.py
@@ -3,9 +3,9 @@ from uuid import UUID
from sqlmodel import Field
-from .base import SQLModelBase
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
+
from .color import ChromaticColor, NeutralColor, ThemeColorsBase
-from .mixin import UUIDTableBaseMixin
class ThemePresetBase(SQLModelBase):
diff --git a/sqlmodels/uri.py b/sqlmodels/uri.py
index 3d4075c..75d3e9a 100644
--- a/sqlmodels/uri.py
+++ b/sqlmodels/uri.py
@@ -2,7 +2,7 @@
from enum import StrEnum
from urllib.parse import urlparse, parse_qs, urlencode, quote, unquote
-from .base import SQLModelBase
+from sqlmodel_ext import SQLModelBase
class FileSystemNamespace(StrEnum):
diff --git a/sqlmodels/user.py b/sqlmodels/user.py
index 77fdb86..198f7f2 100644
--- a/sqlmodels/user.py
+++ b/sqlmodels/user.py
@@ -9,11 +9,11 @@ from sqlmodel import Field, Relationship
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import RelationshipInfo
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableViewRequest, ListResponse
+
from .auth_identity import AuthProviderType
-from .base import SQLModelBase
from .color import ChromaticColor, NeutralColor, ThemeColorsBase
from .model_base import ResponseBase
-from .mixin import UUIDTableBaseMixin, TableViewRequest, ListResponse
T = TypeVar("T", bound="User")
diff --git a/sqlmodels/user_authn.py b/sqlmodels/user_authn.py
index eac498e..0997193 100644
--- a/sqlmodels/user_authn.py
+++ b/sqlmodels/user_authn.py
@@ -5,8 +5,7 @@ from uuid import UUID
from sqlalchemy import Column, Text
from sqlmodel import Field, Relationship
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .user import User
diff --git a/sqlmodels/webdav.py b/sqlmodels/webdav.py
index 8d7f9ce..b73bce2 100644
--- a/sqlmodels/webdav.py
+++ b/sqlmodels/webdav.py
@@ -4,8 +4,7 @@ from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin
if TYPE_CHECKING:
from .user import User
diff --git a/sqlmodels/wopi.py b/sqlmodels/wopi.py
new file mode 100644
index 0000000..5fceabf
--- /dev/null
+++ b/sqlmodels/wopi.py
@@ -0,0 +1,83 @@
+"""
+WOPI(Web Application Open Platform Interface)协议模型
+
+提供 WOPI CheckFileInfo 响应模型和 WOPI 访问令牌 Payload 定义。
+"""
+from uuid import UUID
+
+from sqlmodel_ext import SQLModelBase
+
+
+class WopiFileInfo(SQLModelBase):
+ """
+ WOPI CheckFileInfo 响应模型。
+
+ 字段命名遵循 WOPI 规范(PascalCase),通过 alias 映射。
+ 参考: https://learn.microsoft.com/en-us/microsoft-365/cloud-storage-partner-program/rest/files/checkfileinfo
+ """
+
+ base_file_name: str
+ """文件名(含扩展名)"""
+
+ size: int
+ """文件大小(字节)"""
+
+ owner_id: str
+ """文件所有者标识"""
+
+ user_id: str
+ """当前用户标识"""
+
+ user_friendly_name: str
+ """用户显示名"""
+
+ version: str
+ """文件版本标识(使用 updated_at)"""
+
+ sha256: str = ""
+ """文件 SHA256 哈希(如果可用)"""
+
+ user_can_write: bool = False
+ """用户是否可写"""
+
+ user_can_not_write_relative: bool = True
+ """是否禁止创建关联文件"""
+
+ read_only: bool = True
+ """文件是否只读"""
+
+ supports_locks: bool = False
+ """是否支持锁(v1 不实现)"""
+
+ supports_update: bool = True
+ """是否支持更新"""
+
+ def to_wopi_dict(self) -> dict[str, str | int | bool]:
+ """转换为 WOPI 规范的 PascalCase 字典"""
+ return {
+ "BaseFileName": self.base_file_name,
+ "Size": self.size,
+ "OwnerId": self.owner_id,
+ "UserId": self.user_id,
+ "UserFriendlyName": self.user_friendly_name,
+ "Version": self.version,
+ "SHA256": self.sha256,
+ "UserCanWrite": self.user_can_write,
+ "UserCanNotWriteRelative": self.user_can_not_write_relative,
+ "ReadOnly": self.read_only,
+ "SupportsLocks": self.supports_locks,
+ "SupportsUpdate": self.supports_update,
+ }
+
+
+class WopiAccessTokenPayload(SQLModelBase):
+ """WOPI 访问令牌内部 Payload"""
+
+ file_id: UUID
+ """文件UUID"""
+
+ user_id: UUID
+ """用户UUID"""
+
+ can_write: bool = False
+ """是否可写"""
diff --git a/tests/integration/api/test_admin_file_app.py b/tests/integration/api/test_admin_file_app.py
new file mode 100644
index 0000000..3c3cd0c
--- /dev/null
+++ b/tests/integration/api/test_admin_file_app.py
@@ -0,0 +1,253 @@
+"""
+管理员文件应用管理集成测试
+
+测试管理员 CRUD、扩展名更新、用户组权限更新和权限校验。
+"""
+from uuid import UUID, uuid4
+
+import pytest
+import pytest_asyncio
+from httpx import AsyncClient
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+from sqlmodels.file_app import FileApp, FileAppExtension, FileAppType
+from sqlmodels.group import Group
+from sqlmodels.user import User
+
+
+# ==================== Fixtures ====================
+
+@pytest_asyncio.fixture
+async def setup_admin_app(
+ initialized_db: AsyncSession,
+) -> dict[str, UUID]:
+ """创建测试用管理员文件应用"""
+ app = FileApp(
+ name="管理员测试应用",
+ app_key="admin_test_app",
+ type=FileAppType.BUILTIN,
+ is_enabled=True,
+ )
+ app = await app.save(initialized_db)
+
+ ext = FileAppExtension(app_id=app.id, extension="test", priority=0)
+ await ext.save(initialized_db)
+
+ return {"app_id": app.id}
+
+
+# ==================== Admin CRUD ====================
+
+class TestAdminFileAppCRUD:
+ """管理员文件应用 CRUD 测试"""
+
+ @pytest.mark.asyncio
+ async def test_create_file_app(
+ self,
+ async_client: AsyncClient,
+ admin_headers: dict[str, str],
+ ) -> None:
+ """管理员创建文件应用"""
+ response = await async_client.post(
+ "/api/v1/admin/file-app/",
+ headers=admin_headers,
+ json={
+ "name": "新建应用",
+ "app_key": "new_app",
+ "type": "builtin",
+ "description": "测试新建",
+ "extensions": ["pdf", "txt"],
+ "allowed_group_ids": [],
+ },
+ )
+ assert response.status_code == 201
+ data = response.json()
+ assert data["name"] == "新建应用"
+ assert data["app_key"] == "new_app"
+ assert "pdf" in data["extensions"]
+ assert "txt" in data["extensions"]
+
+ @pytest.mark.asyncio
+ async def test_create_duplicate_app_key(
+ self,
+ async_client: AsyncClient,
+ admin_headers: dict[str, str],
+ setup_admin_app: dict[str, UUID],
+ ) -> None:
+ """创建重复 app_key 返回 409"""
+ response = await async_client.post(
+ "/api/v1/admin/file-app/",
+ headers=admin_headers,
+ json={
+ "name": "重复应用",
+ "app_key": "admin_test_app",
+ "type": "builtin",
+ },
+ )
+ assert response.status_code == 409
+
+ @pytest.mark.asyncio
+ async def test_list_file_apps(
+ self,
+ async_client: AsyncClient,
+ admin_headers: dict[str, str],
+ setup_admin_app: dict[str, UUID],
+ ) -> None:
+ """管理员列出文件应用"""
+ response = await async_client.get(
+ "/api/v1/admin/file-app/list",
+ headers=admin_headers,
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert "apps" in data
+ assert data["total"] >= 1
+
+ @pytest.mark.asyncio
+ async def test_get_file_app_detail(
+ self,
+ async_client: AsyncClient,
+ admin_headers: dict[str, str],
+ setup_admin_app: dict[str, UUID],
+ ) -> None:
+ """管理员获取应用详情"""
+ app_id = setup_admin_app["app_id"]
+ response = await async_client.get(
+ f"/api/v1/admin/file-app/{app_id}",
+ headers=admin_headers,
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data["app_key"] == "admin_test_app"
+ assert "test" in data["extensions"]
+
+ @pytest.mark.asyncio
+ async def test_get_nonexistent_app(
+ self,
+ async_client: AsyncClient,
+ admin_headers: dict[str, str],
+ ) -> None:
+ """获取不存在的应用返回 404"""
+ response = await async_client.get(
+ f"/api/v1/admin/file-app/{uuid4()}",
+ headers=admin_headers,
+ )
+ assert response.status_code == 404
+
+ @pytest.mark.asyncio
+ async def test_update_file_app(
+ self,
+ async_client: AsyncClient,
+ admin_headers: dict[str, str],
+ setup_admin_app: dict[str, UUID],
+ ) -> None:
+ """管理员更新应用"""
+ app_id = setup_admin_app["app_id"]
+ response = await async_client.patch(
+ f"/api/v1/admin/file-app/{app_id}",
+ headers=admin_headers,
+ json={
+ "name": "更新后的名称",
+ "is_enabled": False,
+ },
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data["name"] == "更新后的名称"
+ assert data["is_enabled"] is False
+
+ @pytest.mark.asyncio
+ async def test_delete_file_app(
+ self,
+ async_client: AsyncClient,
+ initialized_db: AsyncSession,
+ admin_headers: dict[str, str],
+ ) -> None:
+ """管理员删除应用"""
+ # 先创建一个应用
+ app = FileApp(
+ name="待删除应用", app_key="to_delete_admin", type=FileAppType.BUILTIN
+ )
+ app = await app.save(initialized_db)
+ app_id = app.id
+
+ response = await async_client.delete(
+ f"/api/v1/admin/file-app/{app_id}",
+ headers=admin_headers,
+ )
+ assert response.status_code == 204
+
+ # 确认已删除
+ found = await FileApp.get(initialized_db, FileApp.id == app_id)
+ assert found is None
+
+
+# ==================== Extensions Management ====================
+
+class TestAdminExtensionManagement:
+ """管理员扩展名管理测试"""
+
+ @pytest.mark.asyncio
+ async def test_update_extensions(
+ self,
+ async_client: AsyncClient,
+ admin_headers: dict[str, str],
+ setup_admin_app: dict[str, UUID],
+ ) -> None:
+ """全量替换扩展名列表"""
+ app_id = setup_admin_app["app_id"]
+ response = await async_client.put(
+ f"/api/v1/admin/file-app/{app_id}/extensions",
+ headers=admin_headers,
+ json={"extensions": ["doc", "docx", "odt"]},
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert sorted(data["extensions"]) == ["doc", "docx", "odt"]
+
+
+# ==================== Group Access Management ====================
+
+class TestAdminGroupAccessManagement:
+ """管理员用户组权限管理测试"""
+
+ @pytest.mark.asyncio
+ async def test_update_group_access(
+ self,
+ async_client: AsyncClient,
+ initialized_db: AsyncSession,
+ admin_headers: dict[str, str],
+ setup_admin_app: dict[str, UUID],
+ ) -> None:
+ """全量替换用户组权限"""
+ app_id = setup_admin_app["app_id"]
+ admin_user = await User.get(initialized_db, User.email == "admin@disknext.local")
+ group_id = admin_user.group_id
+
+ response = await async_client.put(
+ f"/api/v1/admin/file-app/{app_id}/groups",
+ headers=admin_headers,
+ json={"group_ids": [str(group_id)]},
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert str(group_id) in data["allowed_group_ids"]
+
+
+# ==================== Permission Tests ====================
+
+class TestAdminPermission:
+ """权限校验测试"""
+
+ @pytest.mark.asyncio
+ async def test_non_admin_forbidden(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ ) -> None:
+ """普通用户访问管理端点返回 403"""
+ response = await async_client.get(
+ "/api/v1/admin/file-app/list",
+ headers=auth_headers,
+ )
+ assert response.status_code == 403
diff --git a/tests/integration/api/test_file_content.py b/tests/integration/api/test_file_content.py
new file mode 100644
index 0000000..1dc9da6
--- /dev/null
+++ b/tests/integration/api/test_file_content.py
@@ -0,0 +1,466 @@
+"""
+文本文件内容 GET/PATCH 集成测试
+
+测试 GET /file/content/{file_id} 和 PATCH /file/content/{file_id} 端点。
+"""
+import hashlib
+from pathlib import Path
+from uuid import uuid4
+
+import pytest
+import pytest_asyncio
+from httpx import AsyncClient
+from sqlalchemy import event
+from sqlalchemy.engine import Engine
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+from sqlmodels import Object, ObjectType, PhysicalFile, Policy, User
+
+
+@pytest.fixture(autouse=True)
+def _register_sqlite_greatest():
+ """注册 SQLite 的 greatest 函数以兼容 PostgreSQL 语法"""
+
+ def _on_connect(dbapi_connection, connection_record):
+ if hasattr(dbapi_connection, 'create_function'):
+ dbapi_connection.create_function("greatest", 2, max)
+
+ event.listen(Engine, "connect", _on_connect)
+ yield
+ event.remove(Engine, "connect", _on_connect)
+
+
+# ==================== Fixtures ====================
+
+@pytest_asyncio.fixture
+async def local_policy(
+ initialized_db: AsyncSession,
+ tmp_path: Path,
+) -> Policy:
+ """创建指向临时目录的本地存储策略"""
+ from sqlmodels import PolicyType
+
+ policy = Policy(
+ id=uuid4(),
+ name="测试本地存储",
+ type=PolicyType.LOCAL,
+ server=str(tmp_path),
+ )
+ initialized_db.add(policy)
+ await initialized_db.commit()
+ await initialized_db.refresh(policy)
+ return policy
+
+
+@pytest_asyncio.fixture
+async def text_file(
+ initialized_db: AsyncSession,
+ tmp_path: Path,
+ local_policy: Policy,
+) -> dict[str, str | int]:
+ """创建包含 UTF-8 文本内容的测试文件"""
+ user = await User.get(initialized_db, User.email == "testuser@test.local")
+ root = await Object.get_root(initialized_db, user.id)
+
+ content = "line1\nline2\nline3\n"
+ content_bytes = content.encode('utf-8')
+ content_hash = hashlib.sha256(content_bytes).hexdigest()
+
+ file_path = tmp_path / "test.txt"
+ file_path.write_bytes(content_bytes)
+
+ physical_file = PhysicalFile(
+ id=uuid4(),
+ storage_path=str(file_path),
+ size=len(content_bytes),
+ policy_id=local_policy.id,
+ reference_count=1,
+ )
+ initialized_db.add(physical_file)
+
+ file_obj = Object(
+ id=uuid4(),
+ name="test.txt",
+ type=ObjectType.FILE,
+ size=len(content_bytes),
+ physical_file_id=physical_file.id,
+ parent_id=root.id,
+ owner_id=user.id,
+ policy_id=local_policy.id,
+ )
+ initialized_db.add(file_obj)
+ await initialized_db.commit()
+
+ return {
+ "id": str(file_obj.id),
+ "content": content,
+ "hash": content_hash,
+ "size": len(content_bytes),
+ "path": str(file_path),
+ }
+
+
+@pytest_asyncio.fixture
+async def binary_file(
+ initialized_db: AsyncSession,
+ tmp_path: Path,
+ local_policy: Policy,
+) -> dict[str, str | int]:
+ """创建非 UTF-8 的二进制测试文件"""
+ user = await User.get(initialized_db, User.email == "testuser@test.local")
+ root = await Object.get_root(initialized_db, user.id)
+
+ # 包含无效 UTF-8 字节序列
+ content_bytes = b'\x80\x81\x82\xff\xfe\xfd'
+
+ file_path = tmp_path / "binary.dat"
+ file_path.write_bytes(content_bytes)
+
+ physical_file = PhysicalFile(
+ id=uuid4(),
+ storage_path=str(file_path),
+ size=len(content_bytes),
+ policy_id=local_policy.id,
+ reference_count=1,
+ )
+ initialized_db.add(physical_file)
+
+ file_obj = Object(
+ id=uuid4(),
+ name="binary.dat",
+ type=ObjectType.FILE,
+ size=len(content_bytes),
+ physical_file_id=physical_file.id,
+ parent_id=root.id,
+ owner_id=user.id,
+ policy_id=local_policy.id,
+ )
+ initialized_db.add(file_obj)
+ await initialized_db.commit()
+
+ return {
+ "id": str(file_obj.id),
+ "path": str(file_path),
+ }
+
+
+# ==================== GET /file/content/{file_id} ====================
+
+class TestGetFileContent:
+ """GET /file/content/{file_id} 端点测试"""
+
+ @pytest.mark.asyncio
+ async def test_get_content_success(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ text_file: dict[str, str | int],
+ ) -> None:
+ """成功获取文本文件内容和哈希"""
+ response = await async_client.get(
+ f"/api/v1/file/content/{text_file['id']}",
+ headers=auth_headers,
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ assert data["content"] == text_file["content"]
+ assert data["hash"] == text_file["hash"]
+ assert data["size"] == text_file["size"]
+
+ @pytest.mark.asyncio
+ async def test_get_content_non_utf8_returns_400(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ binary_file: dict[str, str | int],
+ ) -> None:
+ """非 UTF-8 文件返回 400"""
+ response = await async_client.get(
+ f"/api/v1/file/content/{binary_file['id']}",
+ headers=auth_headers,
+ )
+
+ assert response.status_code == 400
+ assert "UTF-8" in response.json()["detail"]
+
+ @pytest.mark.asyncio
+ async def test_get_content_not_found(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ ) -> None:
+ """文件不存在返回 404"""
+ fake_id = uuid4()
+ response = await async_client.get(
+ f"/api/v1/file/content/{fake_id}",
+ headers=auth_headers,
+ )
+
+ assert response.status_code == 404
+
+ @pytest.mark.asyncio
+ async def test_get_content_unauthenticated(
+ self,
+ async_client: AsyncClient,
+ text_file: dict[str, str | int],
+ ) -> None:
+ """未认证返回 401"""
+ response = await async_client.get(
+ f"/api/v1/file/content/{text_file['id']}",
+ )
+
+ assert response.status_code == 401
+
+ @pytest.mark.asyncio
+ async def test_get_content_normalizes_crlf(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ initialized_db: AsyncSession,
+ tmp_path: Path,
+ local_policy: Policy,
+ ) -> None:
+ """CRLF 换行符被规范化为 LF"""
+ user = await User.get(initialized_db, User.email == "testuser@test.local")
+ root = await Object.get_root(initialized_db, user.id)
+
+ crlf_content = b"line1\r\nline2\r\n"
+ file_path = tmp_path / "crlf.txt"
+ file_path.write_bytes(crlf_content)
+
+ physical_file = PhysicalFile(
+ id=uuid4(),
+ storage_path=str(file_path),
+ size=len(crlf_content),
+ policy_id=local_policy.id,
+ reference_count=1,
+ )
+ initialized_db.add(physical_file)
+
+ file_obj = Object(
+ id=uuid4(),
+ name="crlf.txt",
+ type=ObjectType.FILE,
+ size=len(crlf_content),
+ physical_file_id=physical_file.id,
+ parent_id=root.id,
+ owner_id=user.id,
+ policy_id=local_policy.id,
+ )
+ initialized_db.add(file_obj)
+ await initialized_db.commit()
+
+ response = await async_client.get(
+ f"/api/v1/file/content/{file_obj.id}",
+ headers=auth_headers,
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ # 内容应该被规范化为 LF
+ assert data["content"] == "line1\nline2\n"
+ # 哈希基于规范化后的内容
+ expected_hash = hashlib.sha256("line1\nline2\n".encode('utf-8')).hexdigest()
+ assert data["hash"] == expected_hash
+
+
+# ==================== PATCH /file/content/{file_id} ====================
+
+class TestPatchFileContent:
+ """PATCH /file/content/{file_id} 端点测试"""
+
+ @pytest.mark.asyncio
+ async def test_patch_content_success(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ text_file: dict[str, str | int],
+ ) -> None:
+ """正常增量保存"""
+ patch_text = (
+ "--- a\n"
+ "+++ b\n"
+ "@@ -1,3 +1,3 @@\n"
+ " line1\n"
+ "-line2\n"
+ "+LINE2_MODIFIED\n"
+ " line3\n"
+ )
+
+ response = await async_client.patch(
+ f"/api/v1/file/content/{text_file['id']}",
+ headers=auth_headers,
+ json={
+ "patch": patch_text,
+ "base_hash": text_file["hash"],
+ },
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ assert "new_hash" in data
+ assert "new_size" in data
+ assert data["new_hash"] != text_file["hash"]
+
+ # 验证文件实际被修改
+ file_path = Path(text_file["path"])
+ new_content = file_path.read_text(encoding='utf-8')
+ assert "LINE2_MODIFIED" in new_content
+ assert "line2" not in new_content
+
+ @pytest.mark.asyncio
+ async def test_patch_content_hash_mismatch_returns_409(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ text_file: dict[str, str | int],
+ ) -> None:
+ """base_hash 不匹配返回 409"""
+ patch_text = (
+ "--- a\n"
+ "+++ b\n"
+ "@@ -1,3 +1,3 @@\n"
+ " line1\n"
+ "-line2\n"
+ "+changed\n"
+ " line3\n"
+ )
+
+ response = await async_client.patch(
+ f"/api/v1/file/content/{text_file['id']}",
+ headers=auth_headers,
+ json={
+ "patch": patch_text,
+ "base_hash": "0" * 64, # 错误的哈希
+ },
+ )
+
+ assert response.status_code == 409
+
+ @pytest.mark.asyncio
+ async def test_patch_content_invalid_patch_returns_422(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ text_file: dict[str, str | int],
+ ) -> None:
+ """无效的 patch 格式返回 422"""
+ response = await async_client.patch(
+ f"/api/v1/file/content/{text_file['id']}",
+ headers=auth_headers,
+ json={
+ "patch": "this is not a valid patch",
+ "base_hash": text_file["hash"],
+ },
+ )
+
+ assert response.status_code == 422
+
+ @pytest.mark.asyncio
+ async def test_patch_content_context_mismatch_returns_422(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ text_file: dict[str, str | int],
+ ) -> None:
+ """patch 上下文行不匹配返回 422"""
+ patch_text = (
+ "--- a\n"
+ "+++ b\n"
+ "@@ -1,3 +1,3 @@\n"
+ " WRONG_CONTEXT_LINE\n"
+ "-line2\n"
+ "+replaced\n"
+ " line3\n"
+ )
+
+ response = await async_client.patch(
+ f"/api/v1/file/content/{text_file['id']}",
+ headers=auth_headers,
+ json={
+ "patch": patch_text,
+ "base_hash": text_file["hash"],
+ },
+ )
+
+ assert response.status_code == 422
+
+ @pytest.mark.asyncio
+ async def test_patch_content_unauthenticated(
+ self,
+ async_client: AsyncClient,
+ text_file: dict[str, str | int],
+ ) -> None:
+ """未认证返回 401"""
+ response = await async_client.patch(
+ f"/api/v1/file/content/{text_file['id']}",
+ json={
+ "patch": "--- a\n+++ b\n",
+ "base_hash": text_file["hash"],
+ },
+ )
+
+ assert response.status_code == 401
+
+ @pytest.mark.asyncio
+ async def test_patch_content_not_found(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ ) -> None:
+ """文件不存在返回 404"""
+ fake_id = uuid4()
+ response = await async_client.patch(
+ f"/api/v1/file/content/{fake_id}",
+ headers=auth_headers,
+ json={
+ "patch": "--- a\n+++ b\n",
+ "base_hash": "0" * 64,
+ },
+ )
+
+ assert response.status_code == 404
+
+ @pytest.mark.asyncio
+ async def test_patch_then_get_consistency(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ text_file: dict[str, str | int],
+ ) -> None:
+ """PATCH 后 GET 返回一致的内容和哈希"""
+ patch_text = (
+ "--- a\n"
+ "+++ b\n"
+ "@@ -1,3 +1,3 @@\n"
+ " line1\n"
+ "-line2\n"
+ "+PATCHED\n"
+ " line3\n"
+ )
+
+ # PATCH
+ patch_resp = await async_client.patch(
+ f"/api/v1/file/content/{text_file['id']}",
+ headers=auth_headers,
+ json={
+ "patch": patch_text,
+ "base_hash": text_file["hash"],
+ },
+ )
+ assert patch_resp.status_code == 200
+ patch_data = patch_resp.json()
+
+ # GET
+ get_resp = await async_client.get(
+ f"/api/v1/file/content/{text_file['id']}",
+ headers=auth_headers,
+ )
+ assert get_resp.status_code == 200
+ get_data = get_resp.json()
+
+ # 一致性验证
+ assert get_data["hash"] == patch_data["new_hash"]
+ assert get_data["size"] == patch_data["new_size"]
+ assert "PATCHED" in get_data["content"]
diff --git a/tests/integration/api/test_file_viewers.py b/tests/integration/api/test_file_viewers.py
new file mode 100644
index 0000000..9734756
--- /dev/null
+++ b/tests/integration/api/test_file_viewers.py
@@ -0,0 +1,305 @@
+"""
+文件查看器集成测试
+
+测试查看器查询、用户默认设置、用户组过滤等端点。
+"""
+from uuid import UUID
+
+import pytest
+import pytest_asyncio
+from httpx import AsyncClient
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+from sqlmodels.file_app import (
+ FileApp,
+ FileAppExtension,
+ FileAppGroupLink,
+ FileAppType,
+ UserFileAppDefault,
+)
+from sqlmodels.user import User
+
+
+# ==================== Fixtures ====================
+
+@pytest_asyncio.fixture
+async def setup_file_apps(
+ initialized_db: AsyncSession,
+) -> dict[str, UUID]:
+ """创建测试用文件查看器应用"""
+ # PDF 阅读器(不限制用户组)
+ pdf_app = FileApp(
+ name="PDF 阅读器",
+ app_key="pdfjs",
+ type=FileAppType.BUILTIN,
+ is_enabled=True,
+ is_restricted=False,
+ )
+ pdf_app = await pdf_app.save(initialized_db)
+
+ # Monaco 编辑器(不限制用户组)
+ monaco_app = FileApp(
+ name="代码编辑器",
+ app_key="monaco",
+ type=FileAppType.BUILTIN,
+ is_enabled=True,
+ is_restricted=False,
+ )
+ monaco_app = await monaco_app.save(initialized_db)
+
+ # Collabora(限制用户组)
+ collabora_app = FileApp(
+ name="Collabora",
+ app_key="collabora",
+ type=FileAppType.WOPI,
+ is_enabled=True,
+ is_restricted=True,
+ )
+ collabora_app = await collabora_app.save(initialized_db)
+
+ # 已禁用的应用
+ disabled_app = FileApp(
+ name="禁用的应用",
+ app_key="disabled_app",
+ type=FileAppType.BUILTIN,
+ is_enabled=False,
+ is_restricted=False,
+ )
+ disabled_app = await disabled_app.save(initialized_db)
+
+ # 创建扩展名
+ for ext in ["pdf"]:
+ await FileAppExtension(app_id=pdf_app.id, extension=ext, priority=0).save(initialized_db)
+
+ for ext in ["txt", "md", "json"]:
+ await FileAppExtension(app_id=monaco_app.id, extension=ext, priority=0).save(initialized_db)
+
+ for ext in ["docx", "xlsx", "pptx"]:
+ await FileAppExtension(app_id=collabora_app.id, extension=ext, priority=0).save(initialized_db)
+
+ for ext in ["pdf"]:
+ await FileAppExtension(app_id=disabled_app.id, extension=ext, priority=10).save(initialized_db)
+
+ return {
+ "pdf_app_id": pdf_app.id,
+ "monaco_app_id": monaco_app.id,
+ "collabora_app_id": collabora_app.id,
+ "disabled_app_id": disabled_app.id,
+ }
+
+
+# ==================== GET /file/viewers ====================
+
+class TestGetViewers:
+ """查询可用查看器测试"""
+
+ @pytest.mark.asyncio
+ async def test_get_viewers_for_pdf(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ setup_file_apps: dict[str, UUID],
+ ) -> None:
+ """查询 PDF 查看器:返回已启用的,排除已禁用的"""
+ response = await async_client.get(
+ "/api/v1/file/viewers?ext=pdf",
+ headers=auth_headers,
+ )
+ assert response.status_code == 200
+
+ data = response.json()
+ assert "viewers" in data
+ viewer_keys = [v["app_key"] for v in data["viewers"]]
+
+ # pdfjs 应该在列表中
+ assert "pdfjs" in viewer_keys
+ # 禁用的应用不应出现
+ assert "disabled_app" not in viewer_keys
+ # 默认值应为 None
+ assert data["default_viewer_id"] is None
+
+ @pytest.mark.asyncio
+ async def test_get_viewers_normalizes_extension(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ setup_file_apps: dict[str, UUID],
+ ) -> None:
+ """扩展名规范化:.PDF → pdf"""
+ response = await async_client.get(
+ "/api/v1/file/viewers?ext=.PDF",
+ headers=auth_headers,
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert len(data["viewers"]) >= 1
+
+ @pytest.mark.asyncio
+ async def test_get_viewers_empty_for_unknown_ext(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ setup_file_apps: dict[str, UUID],
+ ) -> None:
+ """未知扩展名返回空列表"""
+ response = await async_client.get(
+ "/api/v1/file/viewers?ext=xyz_unknown",
+ headers=auth_headers,
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data["viewers"] == []
+
+ @pytest.mark.asyncio
+ async def test_group_restriction_filters_app(
+ self,
+ async_client: AsyncClient,
+ initialized_db: AsyncSession,
+ auth_headers: dict[str, str],
+ setup_file_apps: dict[str, UUID],
+ ) -> None:
+ """用户组限制:collabora 限制了用户组,用户不在白名单内则不可见"""
+ # collabora 是受限的,用户组不在白名单中
+ response = await async_client.get(
+ "/api/v1/file/viewers?ext=docx",
+ headers=auth_headers,
+ )
+ assert response.status_code == 200
+ data = response.json()
+ viewer_keys = [v["app_key"] for v in data["viewers"]]
+ assert "collabora" not in viewer_keys
+
+ # 将用户组加入白名单
+ test_user = await User.get(initialized_db, User.email == "testuser@test.local")
+ link = FileAppGroupLink(
+ app_id=setup_file_apps["collabora_app_id"],
+ group_id=test_user.group_id,
+ )
+ initialized_db.add(link)
+ await initialized_db.commit()
+
+ # 再次查询
+ response = await async_client.get(
+ "/api/v1/file/viewers?ext=docx",
+ headers=auth_headers,
+ )
+ assert response.status_code == 200
+ data = response.json()
+ viewer_keys = [v["app_key"] for v in data["viewers"]]
+ assert "collabora" in viewer_keys
+
+ @pytest.mark.asyncio
+ async def test_unauthorized_without_token(
+ self,
+ async_client: AsyncClient,
+ ) -> None:
+ """未认证请求返回 401"""
+ response = await async_client.get("/api/v1/file/viewers?ext=pdf")
+ assert response.status_code in (401, 403)
+
+
+# ==================== User File Viewer Defaults ====================
+
+class TestUserFileViewerDefaults:
+ """用户默认查看器设置测试"""
+
+ @pytest.mark.asyncio
+ async def test_set_default_viewer(
+ self,
+ async_client: AsyncClient,
+ auth_headers: dict[str, str],
+ setup_file_apps: dict[str, UUID],
+ ) -> None:
+ """设置默认查看器"""
+ response = await async_client.put(
+ "/api/v1/user/settings/file-viewers/default",
+ headers=auth_headers,
+ json={
+ "extension": "pdf",
+ "app_id": str(setup_file_apps["pdf_app_id"]),
+ },
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data["extension"] == "pdf"
+ assert data["app"]["app_key"] == "pdfjs"
+
+ @pytest.mark.asyncio
+ async def test_list_default_viewers(
+ self,
+ async_client: AsyncClient,
+ initialized_db: AsyncSession,
+ auth_headers: dict[str, str],
+ setup_file_apps: dict[str, UUID],
+ ) -> None:
+ """列出默认查看器"""
+ # 先创建一个默认
+ test_user = await User.get(initialized_db, User.email == "testuser@test.local")
+ await UserFileAppDefault(
+ user_id=test_user.id,
+ extension="pdf",
+ app_id=setup_file_apps["pdf_app_id"],
+ ).save(initialized_db)
+
+ response = await async_client.get(
+ "/api/v1/user/settings/file-viewers/defaults",
+ headers=auth_headers,
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert isinstance(data, list)
+ assert len(data) >= 1
+
+ @pytest.mark.asyncio
+ async def test_delete_default_viewer(
+ self,
+ async_client: AsyncClient,
+ initialized_db: AsyncSession,
+ auth_headers: dict[str, str],
+ setup_file_apps: dict[str, UUID],
+ ) -> None:
+ """撤销默认查看器"""
+ # 创建一个默认
+ test_user = await User.get(initialized_db, User.email == "testuser@test.local")
+ default = await UserFileAppDefault(
+ user_id=test_user.id,
+ extension="txt",
+ app_id=setup_file_apps["monaco_app_id"],
+ ).save(initialized_db)
+
+ response = await async_client.delete(
+ f"/api/v1/user/settings/file-viewers/default/{default.id}",
+ headers=auth_headers,
+ )
+ assert response.status_code == 204
+
+ # 验证已删除
+ found = await UserFileAppDefault.get(
+ initialized_db, UserFileAppDefault.id == default.id
+ )
+ assert found is None
+
+ @pytest.mark.asyncio
+ async def test_get_viewers_includes_default(
+ self,
+ async_client: AsyncClient,
+ initialized_db: AsyncSession,
+ auth_headers: dict[str, str],
+ setup_file_apps: dict[str, UUID],
+ ) -> None:
+ """查看器查询应包含用户默认选择"""
+ # 设置默认
+ test_user = await User.get(initialized_db, User.email == "testuser@test.local")
+ await UserFileAppDefault(
+ user_id=test_user.id,
+ extension="pdf",
+ app_id=setup_file_apps["pdf_app_id"],
+ ).save(initialized_db)
+
+ response = await async_client.get(
+ "/api/v1/file/viewers?ext=pdf",
+ headers=auth_headers,
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data["default_viewer_id"] == str(setup_file_apps["pdf_app_id"])
diff --git a/tests/unit/models/test_file_app.py b/tests/unit/models/test_file_app.py
new file mode 100644
index 0000000..77c4f21
--- /dev/null
+++ b/tests/unit/models/test_file_app.py
@@ -0,0 +1,386 @@
+"""
+FileApp 模型单元测试
+
+测试 FileApp、FileAppExtension、UserFileAppDefault 的 CRUD 和约束。
+"""
+from uuid import UUID
+
+import pytest
+import pytest_asyncio
+from sqlalchemy import select
+from sqlalchemy.exc import IntegrityError
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+from sqlmodels.file_app import (
+ FileApp,
+ FileAppExtension,
+ FileAppGroupLink,
+ FileAppType,
+ UserFileAppDefault,
+)
+from sqlmodels.group import Group
+from sqlmodels.user import User, UserStatus
+from sqlmodels.policy import Policy, PolicyType
+
+
+# ==================== Fixtures ====================
+
+@pytest_asyncio.fixture
+async def sample_group(db_session: AsyncSession) -> Group:
+ """创建测试用户组"""
+ group = Group(name="测试组", max_storage=0, admin=False)
+ return await group.save(db_session)
+
+
+@pytest_asyncio.fixture
+async def sample_user(db_session: AsyncSession, sample_group: Group) -> User:
+ """创建测试用户"""
+ user = User(
+ email="fileapp_test@test.local",
+ nickname="文件应用测试用户",
+ status=UserStatus.ACTIVE,
+ group_id=sample_group.id,
+ )
+ return await user.save(db_session)
+
+
+@pytest_asyncio.fixture
+async def sample_app(db_session: AsyncSession) -> FileApp:
+ """创建测试文件应用"""
+ app = FileApp(
+ name="测试PDF阅读器",
+ app_key="test_pdfjs",
+ type=FileAppType.BUILTIN,
+ icon="file-pdf",
+ description="测试用 PDF 阅读器",
+ is_enabled=True,
+ is_restricted=False,
+ )
+ return await app.save(db_session)
+
+
+@pytest_asyncio.fixture
+async def sample_app_with_extensions(db_session: AsyncSession, sample_app: FileApp) -> FileApp:
+ """创建带扩展名的文件应用"""
+ ext1 = FileAppExtension(app_id=sample_app.id, extension="pdf", priority=0)
+ ext2 = FileAppExtension(app_id=sample_app.id, extension="djvu", priority=1)
+ await ext1.save(db_session)
+ await ext2.save(db_session)
+ return sample_app
+
+
+# ==================== FileApp CRUD ====================
+
+class TestFileAppCRUD:
+ """FileApp 基础 CRUD 测试"""
+
+ async def test_create_file_app(self, db_session: AsyncSession) -> None:
+ """测试创建文件应用"""
+ app = FileApp(
+ name="Monaco 编辑器",
+ app_key="monaco",
+ type=FileAppType.BUILTIN,
+ description="代码编辑器",
+ is_enabled=True,
+ )
+ app = await app.save(db_session)
+
+ assert app.id is not None
+ assert app.name == "Monaco 编辑器"
+ assert app.app_key == "monaco"
+ assert app.type == FileAppType.BUILTIN
+ assert app.is_enabled is True
+ assert app.is_restricted is False
+
+ async def test_get_file_app_by_key(self, db_session: AsyncSession, sample_app: FileApp) -> None:
+ """测试按 app_key 查询"""
+ found = await FileApp.get(db_session, FileApp.app_key == "test_pdfjs")
+ assert found is not None
+ assert found.id == sample_app.id
+
+ async def test_unique_app_key(self, db_session: AsyncSession, sample_app: FileApp) -> None:
+ """测试 app_key 唯一约束"""
+ dup = FileApp(
+ name="重复应用",
+ app_key="test_pdfjs",
+ type=FileAppType.BUILTIN,
+ )
+ with pytest.raises(IntegrityError):
+ await dup.save(db_session)
+
+ async def test_update_file_app(self, db_session: AsyncSession, sample_app: FileApp) -> None:
+ """测试更新文件应用"""
+ sample_app.name = "更新后的名称"
+ sample_app.is_enabled = False
+ sample_app = await sample_app.save(db_session)
+
+ found = await FileApp.get(db_session, FileApp.id == sample_app.id)
+ assert found.name == "更新后的名称"
+ assert found.is_enabled is False
+
+ async def test_delete_file_app(self, db_session: AsyncSession) -> None:
+ """测试删除文件应用"""
+ app = FileApp(
+ name="待删除应用",
+ app_key="to_delete",
+ type=FileAppType.IFRAME,
+ )
+ app = await app.save(db_session)
+ app_id = app.id
+
+ await FileApp.delete(db_session, app)
+
+ found = await FileApp.get(db_session, FileApp.id == app_id)
+ assert found is None
+
+ async def test_create_wopi_app(self, db_session: AsyncSession) -> None:
+ """测试创建 WOPI 类型应用"""
+ app = FileApp(
+ name="Collabora",
+ app_key="collabora",
+ type=FileAppType.WOPI,
+ wopi_discovery_url="http://collabora:9980/hosting/discovery",
+ wopi_editor_url_template="http://collabora:9980/loleaflet/dist/loleaflet.html?WOPISrc={wopi_src}&access_token={access_token}",
+ is_enabled=True,
+ )
+ app = await app.save(db_session)
+
+ assert app.type == FileAppType.WOPI
+ assert app.wopi_discovery_url is not None
+ assert app.wopi_editor_url_template is not None
+
+ async def test_create_iframe_app(self, db_session: AsyncSession) -> None:
+ """测试创建 iframe 类型应用"""
+ app = FileApp(
+ name="Office 在线预览",
+ app_key="office_viewer",
+ type=FileAppType.IFRAME,
+ iframe_url_template="https://view.officeapps.live.com/op/embed.aspx?src={file_url}",
+ is_enabled=False,
+ )
+ app = await app.save(db_session)
+
+ assert app.type == FileAppType.IFRAME
+ assert "{file_url}" in app.iframe_url_template
+
+ async def test_to_summary(self, db_session: AsyncSession, sample_app: FileApp) -> None:
+ """测试转换为摘要 DTO"""
+ summary = sample_app.to_summary()
+ assert summary.id == sample_app.id
+ assert summary.name == sample_app.name
+ assert summary.app_key == sample_app.app_key
+ assert summary.type == sample_app.type
+
+
+# ==================== FileAppExtension ====================
+
+class TestFileAppExtension:
+ """FileAppExtension 测试"""
+
+ async def test_create_extension(self, db_session: AsyncSession, sample_app: FileApp) -> None:
+ """测试创建扩展名关联"""
+ ext = FileAppExtension(
+ app_id=sample_app.id,
+ extension="pdf",
+ priority=0,
+ )
+ ext = await ext.save(db_session)
+
+ assert ext.id is not None
+ assert ext.extension == "pdf"
+ assert ext.priority == 0
+
+ async def test_query_by_extension(
+ self, db_session: AsyncSession, sample_app_with_extensions: FileApp
+ ) -> None:
+ """测试按扩展名查询"""
+ results: list[FileAppExtension] = await FileAppExtension.get(
+ db_session,
+ FileAppExtension.extension == "pdf",
+ fetch_mode="all",
+ )
+ assert len(results) >= 1
+ assert any(r.app_id == sample_app_with_extensions.id for r in results)
+
+ async def test_unique_app_extension(self, db_session: AsyncSession, sample_app: FileApp) -> None:
+ """测试 (app_id, extension) 唯一约束"""
+ ext1 = FileAppExtension(app_id=sample_app.id, extension="txt", priority=0)
+ await ext1.save(db_session)
+
+ ext2 = FileAppExtension(app_id=sample_app.id, extension="txt", priority=1)
+ with pytest.raises(IntegrityError):
+ await ext2.save(db_session)
+
+ async def test_cascade_delete(
+ self, db_session: AsyncSession, sample_app_with_extensions: FileApp
+ ) -> None:
+ """测试级联删除:删除应用时扩展名也被删除"""
+ app_id = sample_app_with_extensions.id
+
+ # 确认扩展名存在
+ exts = await FileAppExtension.get(
+ db_session,
+ FileAppExtension.app_id == app_id,
+ fetch_mode="all",
+ )
+ assert len(exts) == 2
+
+ # 删除应用
+ await FileApp.delete(db_session, sample_app_with_extensions)
+
+ # 确认扩展名也被删除
+ exts = await FileAppExtension.get(
+ db_session,
+ FileAppExtension.app_id == app_id,
+ fetch_mode="all",
+ )
+ assert len(exts) == 0
+
+
+# ==================== FileAppGroupLink ====================
+
+class TestFileAppGroupLink:
+ """FileAppGroupLink 用户组访问控制测试"""
+
+ async def test_create_group_link(
+ self, db_session: AsyncSession, sample_app: FileApp, sample_group: Group
+ ) -> None:
+ """测试创建用户组关联"""
+ link = FileAppGroupLink(app_id=sample_app.id, group_id=sample_group.id)
+ db_session.add(link)
+ await db_session.commit()
+
+ result = await db_session.exec(
+ select(FileAppGroupLink).where(
+ FileAppGroupLink.app_id == sample_app.id,
+ FileAppGroupLink.group_id == sample_group.id,
+ )
+ )
+ found = result.first()
+ assert found is not None
+
+ async def test_multiple_groups(self, db_session: AsyncSession, sample_app: FileApp) -> None:
+ """测试一个应用关联多个用户组"""
+ group1 = Group(name="组A", admin=False)
+ group1 = await group1.save(db_session)
+ group2 = Group(name="组B", admin=False)
+ group2 = await group2.save(db_session)
+
+ db_session.add(FileAppGroupLink(app_id=sample_app.id, group_id=group1.id))
+ db_session.add(FileAppGroupLink(app_id=sample_app.id, group_id=group2.id))
+ await db_session.commit()
+
+ result = await db_session.exec(
+ select(FileAppGroupLink).where(FileAppGroupLink.app_id == sample_app.id)
+ )
+ links = result.all()
+ assert len(links) == 2
+
+
+# ==================== UserFileAppDefault ====================
+
+class TestUserFileAppDefault:
+ """UserFileAppDefault 用户偏好测试"""
+
+ async def test_create_default(
+ self, db_session: AsyncSession, sample_app: FileApp, sample_user: User
+ ) -> None:
+ """测试创建用户默认偏好"""
+ default = UserFileAppDefault(
+ user_id=sample_user.id,
+ extension="pdf",
+ app_id=sample_app.id,
+ )
+ default = await default.save(db_session)
+
+ assert default.id is not None
+ assert default.extension == "pdf"
+
+ async def test_unique_user_extension(
+ self, db_session: AsyncSession, sample_app: FileApp, sample_user: User
+ ) -> None:
+ """测试 (user_id, extension) 唯一约束"""
+ default1 = UserFileAppDefault(
+ user_id=sample_user.id, extension="pdf", app_id=sample_app.id
+ )
+ await default1.save(db_session)
+
+ # 创建另一个应用
+ app2 = FileApp(
+ name="另一个阅读器",
+ app_key="pdf_alt",
+ type=FileAppType.BUILTIN,
+ )
+ app2 = await app2.save(db_session)
+
+ default2 = UserFileAppDefault(
+ user_id=sample_user.id, extension="pdf", app_id=app2.id
+ )
+ with pytest.raises(IntegrityError):
+ await default2.save(db_session)
+
+ async def test_cascade_delete_on_app(
+ self, db_session: AsyncSession, sample_user: User
+ ) -> None:
+ """测试级联删除:删除应用时用户偏好也被删除"""
+ app = FileApp(
+ name="待删除应用2",
+ app_key="to_delete_2",
+ type=FileAppType.BUILTIN,
+ )
+ app = await app.save(db_session)
+ app_id = app.id
+
+ default = UserFileAppDefault(
+ user_id=sample_user.id, extension="xyz", app_id=app_id
+ )
+ await default.save(db_session)
+
+ # 确认存在
+ found = await UserFileAppDefault.get(
+ db_session, UserFileAppDefault.app_id == app_id
+ )
+ assert found is not None
+
+ # 删除应用
+ await FileApp.delete(db_session, app)
+
+ # 确认用户偏好也被删除
+ found = await UserFileAppDefault.get(
+ db_session, UserFileAppDefault.app_id == app_id
+ )
+ assert found is None
+
+
+# ==================== DTO ====================
+
+class TestFileAppDTO:
+ """DTO 模型测试"""
+
+ async def test_file_app_response_from_app(
+ self, db_session: AsyncSession, sample_app_with_extensions: FileApp, sample_group: Group
+ ) -> None:
+ """测试 FileAppResponse.from_app()"""
+ from sqlmodels.file_app import FileAppResponse
+
+ extensions = await FileAppExtension.get(
+ db_session,
+ FileAppExtension.app_id == sample_app_with_extensions.id,
+ fetch_mode="all",
+ )
+
+ # 直接构造 link 对象用于 DTO 测试,无需持久化
+ link = FileAppGroupLink(
+ app_id=sample_app_with_extensions.id,
+ group_id=sample_group.id,
+ )
+
+ response = FileAppResponse.from_app(
+ sample_app_with_extensions, extensions, [link]
+ )
+
+ assert response.id == sample_app_with_extensions.id
+ assert response.app_key == "test_pdfjs"
+ assert "pdf" in response.extensions
+ assert "djvu" in response.extensions
+ assert sample_group.id in response.allowed_group_ids
diff --git a/tests/unit/models/test_setting.py b/tests/unit/models/test_setting.py
index 2a3aa26..5b860d7 100644
--- a/tests/unit/models/test_setting.py
+++ b/tests/unit/models/test_setting.py
@@ -113,7 +113,7 @@ async def test_setting_update_value(db_session: AsyncSession):
setting = await setting.save(db_session)
# 更新值
- from sqlmodels.base import SQLModelBase
+ from sqlmodel_ext import SQLModelBase
class SettingUpdate(SQLModelBase):
value: str | None = None
diff --git a/tests/unit/utils/test_patch.py b/tests/unit/utils/test_patch.py
new file mode 100644
index 0000000..29b5f80
--- /dev/null
+++ b/tests/unit/utils/test_patch.py
@@ -0,0 +1,178 @@
+"""
+文本文件 patch 逻辑单元测试
+
+测试 whatthepatch 库的 patch 解析与应用,
+以及换行符规范化和 SHA-256 哈希计算。
+"""
+import hashlib
+
+import pytest
+import whatthepatch
+from whatthepatch.exceptions import HunkApplyException
+
+
+class TestPatchApply:
+ """测试 patch 解析与应用"""
+
+ def test_normal_patch(self) -> None:
+ """正常 patch 应用"""
+ original = "line1\nline2\nline3"
+ patch_text = (
+ "--- a\n"
+ "+++ b\n"
+ "@@ -1,3 +1,3 @@\n"
+ " line1\n"
+ "-line2\n"
+ "+LINE2\n"
+ " line3\n"
+ )
+
+ diffs = list(whatthepatch.parse_patch(patch_text))
+ assert len(diffs) == 1
+
+ result = whatthepatch.apply_diff(diffs[0], original)
+ new_text = '\n'.join(result)
+
+ assert "LINE2" in new_text
+ assert "line2" not in new_text
+
+ def test_add_lines_patch(self) -> None:
+ """添加行的 patch"""
+ original = "line1\nline2"
+ patch_text = (
+ "--- a\n"
+ "+++ b\n"
+ "@@ -1,2 +1,3 @@\n"
+ " line1\n"
+ " line2\n"
+ "+line3\n"
+ )
+
+ diffs = list(whatthepatch.parse_patch(patch_text))
+ result = whatthepatch.apply_diff(diffs[0], original)
+ new_text = '\n'.join(result)
+
+ assert "line3" in new_text
+
+ def test_delete_lines_patch(self) -> None:
+ """删除行的 patch"""
+ original = "line1\nline2\nline3"
+ patch_text = (
+ "--- a\n"
+ "+++ b\n"
+ "@@ -1,3 +1,2 @@\n"
+ " line1\n"
+ "-line2\n"
+ " line3\n"
+ )
+
+ diffs = list(whatthepatch.parse_patch(patch_text))
+ result = whatthepatch.apply_diff(diffs[0], original)
+ new_text = '\n'.join(result)
+
+ assert "line2" not in new_text
+ assert "line1" in new_text
+ assert "line3" in new_text
+
+ def test_invalid_patch_format(self) -> None:
+ """无效的 patch 格式返回空列表"""
+ diffs = list(whatthepatch.parse_patch("this is not a patch"))
+ assert len(diffs) == 0
+
+ def test_patch_context_mismatch(self) -> None:
+ """patch 上下文不匹配时抛出异常"""
+ original = "line1\nline2\nline3\n"
+ patch_text = (
+ "--- a\n"
+ "+++ b\n"
+ "@@ -1,3 +1,3 @@\n"
+ " line1\n"
+ "-WRONG\n"
+ "+REPLACED\n"
+ " line3\n"
+ )
+
+ diffs = list(whatthepatch.parse_patch(patch_text))
+ with pytest.raises(HunkApplyException):
+ whatthepatch.apply_diff(diffs[0], original)
+
+ def test_empty_file_patch(self) -> None:
+ """空文件应用 patch"""
+ original = ""
+ patch_text = (
+ "--- a\n"
+ "+++ b\n"
+ "@@ -0,0 +1,2 @@\n"
+ "+line1\n"
+ "+line2\n"
+ )
+
+ diffs = list(whatthepatch.parse_patch(patch_text))
+ result = whatthepatch.apply_diff(diffs[0], original)
+ new_text = '\n'.join(result)
+
+ assert "line1" in new_text
+ assert "line2" in new_text
+
+
+class TestHashComputation:
+ """测试 SHA-256 哈希计算"""
+
+ def test_hash_consistency(self) -> None:
+ """相同内容产生相同哈希"""
+ content = "hello world\n"
+ content_bytes = content.encode('utf-8')
+ hash1 = hashlib.sha256(content_bytes).hexdigest()
+ hash2 = hashlib.sha256(content_bytes).hexdigest()
+
+ assert hash1 == hash2
+ assert len(hash1) == 64
+
+ def test_hash_differs_for_different_content(self) -> None:
+ """不同内容产生不同哈希"""
+ hash1 = hashlib.sha256(b"content A").hexdigest()
+ hash2 = hashlib.sha256(b"content B").hexdigest()
+
+ assert hash1 != hash2
+
+ def test_hash_after_normalization(self) -> None:
+ """换行符规范化后的哈希一致性"""
+ content_crlf = "line1\r\nline2\r\n"
+ content_lf = "line1\nline2\n"
+
+ # 规范化后应相同
+ normalized = content_crlf.replace('\r\n', '\n').replace('\r', '\n')
+ assert normalized == content_lf
+
+ hash_normalized = hashlib.sha256(normalized.encode('utf-8')).hexdigest()
+ hash_lf = hashlib.sha256(content_lf.encode('utf-8')).hexdigest()
+
+ assert hash_normalized == hash_lf
+
+
+class TestLineEndingNormalization:
+ """测试换行符规范化"""
+
+ def test_crlf_to_lf(self) -> None:
+ """CRLF 转换为 LF"""
+ content = "line1\r\nline2\r\n"
+ normalized = content.replace('\r\n', '\n').replace('\r', '\n')
+ assert normalized == "line1\nline2\n"
+
+ def test_cr_to_lf(self) -> None:
+ """CR 转换为 LF"""
+ content = "line1\rline2\r"
+ normalized = content.replace('\r\n', '\n').replace('\r', '\n')
+ assert normalized == "line1\nline2\n"
+
+ def test_lf_unchanged(self) -> None:
+ """LF 保持不变"""
+ content = "line1\nline2\n"
+ normalized = content.replace('\r\n', '\n').replace('\r', '\n')
+ assert normalized == content
+
+ def test_mixed_line_endings(self) -> None:
+ """混合换行符统一为 LF"""
+ content = "line1\r\nline2\rline3\n"
+ normalized = content.replace('\r\n', '\n').replace('\r', '\n')
+ assert normalized == "line1\nline2\nline3\n"
diff --git a/tests/unit/utils/test_wopi_token.py b/tests/unit/utils/test_wopi_token.py
new file mode 100644
index 0000000..b370c42
--- /dev/null
+++ b/tests/unit/utils/test_wopi_token.py
@@ -0,0 +1,77 @@
+"""
+WOPI Token 单元测试
+
+测试 WOPI 访问令牌的生成和验证。
+"""
+from uuid import uuid4
+
+import pytest
+
+import utils.JWT as JWT
+from utils.JWT.wopi_token import create_wopi_token, verify_wopi_token
+
+# 确保测试 secret key
+JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation"
+
+
+class TestWopiToken:
+ """WOPI Token 测试"""
+
+ def test_create_and_verify_token(self) -> None:
+ """创建和验证令牌"""
+ file_id = uuid4()
+ user_id = uuid4()
+
+ token, ttl = create_wopi_token(file_id, user_id, can_write=True)
+
+ assert isinstance(token, str)
+ assert isinstance(ttl, int)
+ assert ttl > 0
+
+ payload = verify_wopi_token(token)
+ assert payload is not None
+ assert payload.file_id == file_id
+ assert payload.user_id == user_id
+ assert payload.can_write is True
+
+ def test_verify_read_only_token(self) -> None:
+ """验证只读令牌"""
+ file_id = uuid4()
+ user_id = uuid4()
+
+ token, ttl = create_wopi_token(file_id, user_id, can_write=False)
+
+ payload = verify_wopi_token(token)
+ assert payload is not None
+ assert payload.can_write is False
+
+ def test_verify_invalid_token(self) -> None:
+ """验证无效令牌返回 None"""
+ payload = verify_wopi_token("invalid_token_string")
+ assert payload is None
+
+ def test_verify_non_wopi_token(self) -> None:
+ """验证非 WOPI 类型令牌返回 None"""
+ import jwt as pyjwt
+ # 创建一个不含 type=wopi 的令牌
+ token = pyjwt.encode(
+ {"file_id": str(uuid4()), "user_id": str(uuid4()), "type": "download"},
+ JWT.SECRET_KEY,
+ algorithm="HS256",
+ )
+ payload = verify_wopi_token(token)
+ assert payload is None
+
+ def test_ttl_is_future_milliseconds(self) -> None:
+ """TTL 应为未来的毫秒时间戳"""
+ import time
+
+ file_id = uuid4()
+ user_id = uuid4()
+ token, ttl = create_wopi_token(file_id, user_id)
+
+ current_ms = int(time.time() * 1000)
+ # TTL 应大于当前时间
+ assert ttl > current_ms
+ # TTL 不应超过 11 小时后(10h + 余量)
+ assert ttl < current_ms + 11 * 3600 * 1000
diff --git a/utils/JWT/__init__.py b/utils/JWT/__init__.py
index 1f8b49c..a91377d 100644
--- a/utils/JWT/__init__.py
+++ b/utils/JWT/__init__.py
@@ -25,11 +25,11 @@ async def load_secret_key() -> None:
从数据库读取 JWT 的密钥。
"""
# 延迟导入以避免循环依赖
- from sqlmodels.database import get_session
+ from sqlmodels.database_connection import DatabaseManager
from sqlmodels.setting import Setting
global SECRET_KEY
- async for session in get_session():
+ async for session in DatabaseManager.get_session():
setting: Setting = await Setting.get(
session,
(Setting.type == "auth") & (Setting.name == "secret_key")
diff --git a/utils/JWT/wopi_token.py b/utils/JWT/wopi_token.py
new file mode 100644
index 0000000..846379e
--- /dev/null
+++ b/utils/JWT/wopi_token.py
@@ -0,0 +1,67 @@
+"""
+WOPI 访问令牌生成与验证。
+
+使用 JWT 签名,payload 包含 file_id, user_id, can_write, exp。
+TTL 默认 10 小时(WOPI 规范推荐长 TTL)。
+"""
+from datetime import datetime, timedelta, timezone
+from uuid import UUID, uuid4
+
+import jwt
+
+from sqlmodels.wopi import WopiAccessTokenPayload
+
+WOPI_TOKEN_TTL = timedelta(hours=10)
+"""WOPI 令牌有效期"""
+
+
+def create_wopi_token(
+ file_id: UUID,
+ user_id: UUID,
+ can_write: bool = False,
+) -> tuple[str, int]:
+ """
+ 创建 WOPI 访问令牌。
+
+ :param file_id: 文件UUID
+ :param user_id: 用户UUID
+ :param can_write: 是否可写
+ :return: (token_string, access_token_ttl_ms)
+ """
+ from utils.JWT import SECRET_KEY
+
+ expire = datetime.now(timezone.utc) + WOPI_TOKEN_TTL
+ payload = {
+ "jti": str(uuid4()),
+ "file_id": str(file_id),
+ "user_id": str(user_id),
+ "can_write": can_write,
+ "exp": expire,
+ "type": "wopi",
+ }
+ token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")
+ # WOPI 规范要求 access_token_ttl 是毫秒级的 UNIX 时间戳
+ access_token_ttl = int(expire.timestamp() * 1000)
+ return token, access_token_ttl
+
+
+def verify_wopi_token(token: str) -> WopiAccessTokenPayload | None:
+ """
+ 验证 WOPI 访问令牌并返回 payload。
+
+ :param token: JWT 令牌字符串
+ :return: WopiAccessTokenPayload 或 None(验证失败)
+ """
+ from utils.JWT import SECRET_KEY
+
+ try:
+ payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
+ if payload.get("type") != "wopi":
+ return None
+ return WopiAccessTokenPayload(
+ file_id=UUID(payload["file_id"]),
+ user_id=UUID(payload["user_id"]),
+ can_write=payload.get("can_write", False),
+ )
+ except (jwt.InvalidTokenError, KeyError, ValueError):
+ return None
diff --git a/utils/http/http_exceptions.py b/utils/http/http_exceptions.py
index d7a594f..891656c 100644
--- a/utils/http/http_exceptions.py
+++ b/utils/http/http_exceptions.py
@@ -40,6 +40,10 @@ def raise_conflict(detail: str | None = None) -> NoReturn:
"""Raises an HTTP 409 Conflict exception."""
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=detail)
+def raise_unprocessable_entity(detail: str | None = None) -> NoReturn:
+ """Raises an HTTP 422 Unprocessable Content exception."""
+ raise HTTPException(status_code=422, detail=detail)
+
def raise_precondition_required(detail: str | None = None) -> NoReturn:
"""Raises an HTTP 428 Precondition required exception."""
raise HTTPException(status_code=status.HTTP_428_PRECONDITION_REQUIRED, detail=detail)
diff --git a/uv.lock b/uv.lock
index eab5695..d6d2623 100644
--- a/uv.lock
+++ b/uv.lock
@@ -474,6 +474,32 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e8/cb/2da4cc83f5edb9c3257d09e1e7ab7b23f049c7962cae8d842bbef0a9cec9/cryptography-46.0.3-cp38-abi3-win_arm64.whl", hash = "sha256:d89c3468de4cdc4f08a57e214384d0471911a3830fcdaf7a8cc587e42a866372", size = 2918740, upload-time = "2025-10-15T23:18:12.277Z" },
]
+[[package]]
+name = "cython"
+version = "3.2.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/91/85/7574c9cd44b69a27210444b6650f6477f56c75fee1b70d7672d3e4166167/cython-3.2.4.tar.gz", hash = "sha256:84226ecd313b233da27dc2eb3601b4f222b8209c3a7216d8733b031da1dc64e6", size = 3280291, upload-time = "2026-01-04T14:14:14.473Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/18/b5/1cfca43b7d20a0fdb1eac67313d6bb6b18d18897f82dd0f17436bdd2ba7f/cython-3.2.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:28e8075087a59756f2d059273184b8b639fe0f16cf17470bd91c39921bc154e0", size = 2960506, upload-time = "2026-01-04T14:15:16.733Z" },
+ { url = "https://files.pythonhosted.org/packages/71/bb/8f28c39c342621047fea349a82fac712a5e2b37546d2f737bbde48d5143d/cython-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:03893c88299a2c868bb741ba6513357acd104e7c42265809fd58dce1456a36fc", size = 3213148, upload-time = "2026-01-04T14:15:18.804Z" },
+ { url = "https://files.pythonhosted.org/packages/7a/d2/16fa02f129ed2b627e88d9d9ebd5ade3eeb66392ae5ba85b259d2d52b047/cython-3.2.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f81eda419b5ada7b197bbc3c5f4494090e3884521ffd75a3876c93fbf66c9ca8", size = 3375764, upload-time = "2026-01-04T14:15:20.817Z" },
+ { url = "https://files.pythonhosted.org/packages/91/3f/deb8f023a5c10c0649eb81332a58c180fad27c7533bb4aae138b5bc34d92/cython-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:83266c356c13c68ffe658b4905279c993d8a5337bb0160fa90c8a3e297ea9a2e", size = 2754238, upload-time = "2026-01-04T14:15:23.001Z" },
+ { url = "https://files.pythonhosted.org/packages/ee/d7/3bda3efce0c5c6ce79cc21285dbe6f60369c20364e112f5a506ee8a1b067/cython-3.2.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d4b4fd5332ab093131fa6172e8362f16adef3eac3179fd24bbdc392531cb82fa", size = 2971496, upload-time = "2026-01-04T14:15:25.038Z" },
+ { url = "https://files.pythonhosted.org/packages/89/ed/1021ffc80b9c4720b7ba869aea8422c82c84245ef117ebe47a556bdc00c3/cython-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e3b5ac54e95f034bc7fb07313996d27cbf71abc17b229b186c1540942d2dc28e", size = 3256146, upload-time = "2026-01-04T14:15:26.741Z" },
+ { url = "https://files.pythonhosted.org/packages/0c/51/ca221ec7e94b3c5dc4138dcdcbd41178df1729c1e88c5dfb25f9d30ba3da/cython-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:90f43be4eaa6afd58ce20d970bb1657a3627c44e1760630b82aa256ba74b4acb", size = 3383458, upload-time = "2026-01-04T14:15:28.425Z" },
+ { url = "https://files.pythonhosted.org/packages/79/2e/1388fc0243240cd54994bb74f26aaaf3b2e22f89d3a2cf8da06d75d46ca2/cython-3.2.4-cp314-cp314-win_amd64.whl", hash = "sha256:983f9d2bb8a896e16fa68f2b37866ded35fa980195eefe62f764ddc5f9f5ef8e", size = 2791241, upload-time = "2026-01-04T14:15:30.448Z" },
+ { url = "https://files.pythonhosted.org/packages/0a/8b/fd393f0923c82be4ec0db712fffb2ff0a7a131707b842c99bf24b549274d/cython-3.2.4-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:36bf3f5eb56d5281aafabecbaa6ed288bc11db87547bba4e1e52943ae6961ccf", size = 2875622, upload-time = "2026-01-04T14:15:39.749Z" },
+ { url = "https://files.pythonhosted.org/packages/73/48/48530d9b9d64ec11dbe0dd3178a5fe1e0b27977c1054ecffb82be81e9b6a/cython-3.2.4-cp39-abi3-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6d5267f22b6451eb1e2e1b88f6f78a2c9c8733a6ddefd4520d3968d26b824581", size = 3210669, upload-time = "2026-01-04T14:15:41.911Z" },
+ { url = "https://files.pythonhosted.org/packages/5e/91/4865fbfef1f6bb4f21d79c46104a53d1a3fa4348286237e15eafb26e0828/cython-3.2.4-cp39-abi3-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3b6e58f73a69230218d5381817850ce6d0da5bb7e87eb7d528c7027cbba40b06", size = 2856835, upload-time = "2026-01-04T14:15:43.815Z" },
+ { url = "https://files.pythonhosted.org/packages/fa/39/60317957dbef179572398253f29d28f75f94ab82d6d39ea3237fb6c89268/cython-3.2.4-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e71efb20048358a6b8ec604a0532961c50c067b5e63e345e2e359fff72feaee8", size = 2994408, upload-time = "2026-01-04T14:15:45.422Z" },
+ { url = "https://files.pythonhosted.org/packages/8d/30/7c24d9292650db4abebce98abc9b49c820d40fa7c87921c0a84c32f4efe7/cython-3.2.4-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:28b1e363b024c4b8dcf52ff68125e635cb9cb4b0ba997d628f25e32543a71103", size = 2891478, upload-time = "2026-01-04T14:15:47.394Z" },
+ { url = "https://files.pythonhosted.org/packages/86/70/03dc3c962cde9da37a93cca8360e576f904d5f9beecfc9d70b1f820d2e5f/cython-3.2.4-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:31a90b4a2c47bb6d56baeb926948348ec968e932c1ae2c53239164e3e8880ccf", size = 3225663, upload-time = "2026-01-04T14:15:49.446Z" },
+ { url = "https://files.pythonhosted.org/packages/b1/97/10b50c38313c37b1300325e2e53f48ea9a2c078a85c0c9572057135e31d5/cython-3.2.4-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e65e4773021f8dc8532010b4fbebe782c77f9a0817e93886e518c93bd6a44e9d", size = 3115628, upload-time = "2026-01-04T14:15:51.323Z" },
+ { url = "https://files.pythonhosted.org/packages/8f/b1/d6a353c9b147848122a0db370863601fdf56de2d983b5c4a6a11e6ee3cd7/cython-3.2.4-cp39-abi3-win32.whl", hash = "sha256:2b1f12c0e4798293d2754e73cd6f35fa5bbdf072bdc14bc6fc442c059ef2d290", size = 2437463, upload-time = "2026-01-04T14:15:53.787Z" },
+ { url = "https://files.pythonhosted.org/packages/2d/d8/319a1263b9c33b71343adfd407e5daffd453daef47ebc7b642820a8b68ed/cython-3.2.4-cp39-abi3-win_arm64.whl", hash = "sha256:3b8e62049afef9da931d55de82d8f46c9a147313b69d5ff6af6e9121d545ce7a", size = 2442754, upload-time = "2026-01-04T14:15:55.382Z" },
+ { url = "https://files.pythonhosted.org/packages/ff/fa/d3c15189f7c52aaefbaea76fb012119b04b9013f4bf446cb4eb4c26c4e6b/cython-3.2.4-py3-none-any.whl", hash = "sha256:732fc93bc33ae4b14f6afaca663b916c2fdd5dcbfad7114e17fb2434eeaea45c", size = 1257078, upload-time = "2026-01-04T14:14:12.373Z" },
+]
+
[[package]]
name = "disknext-server"
version = "0.0.1"
@@ -486,6 +512,7 @@ dependencies = [
{ name = "asyncpg" },
{ name = "cachetools" },
{ name = "captcha" },
+ { name = "cryptography" },
{ name = "fastapi", extra = ["standard"] },
{ name = "httpx" },
{ name = "itsdangerous" },
@@ -502,8 +529,16 @@ dependencies = [
{ name = "redis", extra = ["hiredis"] },
{ name = "sqlalchemy" },
{ name = "sqlmodel" },
+ { name = "sqlmodel-ext", extra = ["pgvector"] },
{ name = "uvicorn" },
{ name = "webauthn" },
+ { name = "whatthepatch" },
+]
+
+[package.optional-dependencies]
+build = [
+ { name = "cython" },
+ { name = "setuptools" },
]
[package.metadata]
@@ -515,6 +550,8 @@ requires-dist = [
{ name = "asyncpg", specifier = ">=0.31.0" },
{ name = "cachetools", specifier = ">=6.2.4" },
{ name = "captcha", specifier = ">=0.7.1" },
+ { name = "cryptography", specifier = ">=46.0.3" },
+ { name = "cython", marker = "extra == 'build'", specifier = ">=3.0.11" },
{ name = "fastapi", extras = ["standard"], specifier = ">=0.122.0" },
{ name = "httpx", specifier = ">=0.27.0" },
{ name = "itsdangerous", specifier = ">=2.2.0" },
@@ -529,11 +566,15 @@ requires-dist = [
{ name = "python-dotenv", specifier = ">=1.2.1" },
{ name = "python-multipart", specifier = ">=0.0.20" },
{ name = "redis", extras = ["hiredis"], specifier = ">=7.1.0" },
+ { name = "setuptools", marker = "extra == 'build'", specifier = ">=75.0.0" },
{ name = "sqlalchemy", specifier = ">=2.0.44" },
{ name = "sqlmodel", specifier = ">=0.0.27" },
+ { name = "sqlmodel-ext", extras = ["pgvector"], specifier = ">=0.1.1" },
{ name = "uvicorn", specifier = ">=0.38.0" },
{ name = "webauthn", specifier = ">=2.7.0" },
+ { name = "whatthepatch", specifier = ">=1.0.6" },
]
+provides-extras = ["build"]
[[package]]
name = "dnspython"
@@ -1101,6 +1142,56 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/da/7d22601b625e241d4f23ef1ebff8acfc60da633c9e7e7922e24d10f592b3/multidict-6.7.0-py3-none-any.whl", hash = "sha256:394fc5c42a333c9ffc3e421a4c85e08580d990e08b99f6bf35b4132114c5dcb3", size = 12317, upload-time = "2025-10-06T14:52:29.272Z" },
]
+[[package]]
+name = "numpy"
+version = "2.4.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/57/fd/0005efbd0af48e55eb3c7208af93f2862d4b1a56cd78e84309a2d959208d/numpy-2.4.2.tar.gz", hash = "sha256:659a6107e31a83c4e33f763942275fd278b21d095094044eb35569e86a21ddae", size = 20723651, upload-time = "2026-01-31T23:13:10.135Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a1/22/815b9fe25d1d7ae7d492152adbc7226d3eff731dffc38fe970589fcaaa38/numpy-2.4.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:25f2059807faea4b077a2b6837391b5d830864b3543627f381821c646f31a63c", size = 16663696, upload-time = "2026-01-31T23:11:17.516Z" },
+ { url = "https://files.pythonhosted.org/packages/09/f0/817d03a03f93ba9c6c8993de509277d84e69f9453601915e4a69554102a1/numpy-2.4.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bd3a7a9f5847d2fb8c2c6d1c862fa109c31a9abeca1a3c2bd5a64572955b2979", size = 14688322, upload-time = "2026-01-31T23:11:19.883Z" },
+ { url = "https://files.pythonhosted.org/packages/da/b4/f805ab79293c728b9a99438775ce51885fd4f31b76178767cfc718701a39/numpy-2.4.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:8e4549f8a3c6d13d55041925e912bfd834285ef1dd64d6bc7d542583355e2e98", size = 5198157, upload-time = "2026-01-31T23:11:22.375Z" },
+ { url = "https://files.pythonhosted.org/packages/74/09/826e4289844eccdcd64aac27d13b0fd3f32039915dd5b9ba01baae1f436c/numpy-2.4.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:aea4f66ff44dfddf8c2cffd66ba6538c5ec67d389285292fe428cb2c738c8aef", size = 6546330, upload-time = "2026-01-31T23:11:23.958Z" },
+ { url = "https://files.pythonhosted.org/packages/19/fb/cbfdbfa3057a10aea5422c558ac57538e6acc87ec1669e666d32ac198da7/numpy-2.4.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3cd545784805de05aafe1dde61752ea49a359ccba9760c1e5d1c88a93bbf2b7", size = 15660968, upload-time = "2026-01-31T23:11:25.713Z" },
+ { url = "https://files.pythonhosted.org/packages/04/dc/46066ce18d01645541f0186877377b9371b8fa8017fa8262002b4ef22612/numpy-2.4.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d0d9b7c93578baafcbc5f0b83eaf17b79d345c6f36917ba0c67f45226911d499", size = 16607311, upload-time = "2026-01-31T23:11:28.117Z" },
+ { url = "https://files.pythonhosted.org/packages/14/d9/4b5adfc39a43fa6bf918c6d544bc60c05236cc2f6339847fc5b35e6cb5b0/numpy-2.4.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f74f0f7779cc7ae07d1810aab8ac6b1464c3eafb9e283a40da7309d5e6e48fbb", size = 17012850, upload-time = "2026-01-31T23:11:30.888Z" },
+ { url = "https://files.pythonhosted.org/packages/b7/20/adb6e6adde6d0130046e6fdfb7675cc62bc2f6b7b02239a09eb58435753d/numpy-2.4.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c7ac672d699bf36275c035e16b65539931347d68b70667d28984c9fb34e07fa7", size = 18334210, upload-time = "2026-01-31T23:11:33.214Z" },
+ { url = "https://files.pythonhosted.org/packages/78/0e/0a73b3dff26803a8c02baa76398015ea2a5434d9b8265a7898a6028c1591/numpy-2.4.2-cp313-cp313-win32.whl", hash = "sha256:8e9afaeb0beff068b4d9cd20d322ba0ee1cecfb0b08db145e4ab4dd44a6b5110", size = 5958199, upload-time = "2026-01-31T23:11:35.385Z" },
+ { url = "https://files.pythonhosted.org/packages/43/bc/6352f343522fcb2c04dbaf94cb30cca6fd32c1a750c06ad6231b4293708c/numpy-2.4.2-cp313-cp313-win_amd64.whl", hash = "sha256:7df2de1e4fba69a51c06c28f5a3de36731eb9639feb8e1cf7e4a7b0daf4cf622", size = 12310848, upload-time = "2026-01-31T23:11:38.001Z" },
+ { url = "https://files.pythonhosted.org/packages/6e/8d/6da186483e308da5da1cc6918ce913dcfe14ffde98e710bfeff2a6158d4e/numpy-2.4.2-cp313-cp313-win_arm64.whl", hash = "sha256:0fece1d1f0a89c16b03442eae5c56dc0be0c7883b5d388e0c03f53019a4bfd71", size = 10221082, upload-time = "2026-01-31T23:11:40.392Z" },
+ { url = "https://files.pythonhosted.org/packages/25/a1/9510aa43555b44781968935c7548a8926274f815de42ad3997e9e83680dd/numpy-2.4.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5633c0da313330fd20c484c78cdd3f9b175b55e1a766c4a174230c6b70ad8262", size = 14815866, upload-time = "2026-01-31T23:11:42.495Z" },
+ { url = "https://files.pythonhosted.org/packages/36/30/6bbb5e76631a5ae46e7923dd16ca9d3f1c93cfa8d4ed79a129814a9d8db3/numpy-2.4.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:d9f64d786b3b1dd742c946c42d15b07497ed14af1a1f3ce840cce27daa0ce913", size = 5325631, upload-time = "2026-01-31T23:11:44.7Z" },
+ { url = "https://files.pythonhosted.org/packages/46/00/3a490938800c1923b567b3a15cd17896e68052e2145d8662aaf3e1ffc58f/numpy-2.4.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:b21041e8cb6a1eb5312dd1d2f80a94d91efffb7a06b70597d44f1bd2dfc315ab", size = 6646254, upload-time = "2026-01-31T23:11:46.341Z" },
+ { url = "https://files.pythonhosted.org/packages/d3/e9/fac0890149898a9b609caa5af7455a948b544746e4b8fe7c212c8edd71f8/numpy-2.4.2-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:00ab83c56211a1d7c07c25e3217ea6695e50a3e2f255053686b081dc0b091a82", size = 15720138, upload-time = "2026-01-31T23:11:48.082Z" },
+ { url = "https://files.pythonhosted.org/packages/ea/5c/08887c54e68e1e28df53709f1893ce92932cc6f01f7c3d4dc952f61ffd4e/numpy-2.4.2-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2fb882da679409066b4603579619341c6d6898fc83a8995199d5249f986e8e8f", size = 16655398, upload-time = "2026-01-31T23:11:50.293Z" },
+ { url = "https://files.pythonhosted.org/packages/4d/89/253db0fa0e66e9129c745e4ef25631dc37d5f1314dad2b53e907b8538e6d/numpy-2.4.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:66cb9422236317f9d44b67b4d18f44efe6e9c7f8794ac0462978513359461554", size = 17079064, upload-time = "2026-01-31T23:11:52.927Z" },
+ { url = "https://files.pythonhosted.org/packages/2a/d5/cbade46ce97c59c6c3da525e8d95b7abe8a42974a1dc5c1d489c10433e88/numpy-2.4.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0f01dcf33e73d80bd8dc0f20a71303abbafa26a19e23f6b68d1aa9990af90257", size = 18379680, upload-time = "2026-01-31T23:11:55.22Z" },
+ { url = "https://files.pythonhosted.org/packages/40/62/48f99ae172a4b63d981babe683685030e8a3df4f246c893ea5c6ef99f018/numpy-2.4.2-cp313-cp313t-win32.whl", hash = "sha256:52b913ec40ff7ae845687b0b34d8d93b60cb66dcee06996dd5c99f2fc9328657", size = 6082433, upload-time = "2026-01-31T23:11:58.096Z" },
+ { url = "https://files.pythonhosted.org/packages/07/38/e054a61cfe48ad9f1ed0d188e78b7e26859d0b60ef21cd9de4897cdb5326/numpy-2.4.2-cp313-cp313t-win_amd64.whl", hash = "sha256:5eea80d908b2c1f91486eb95b3fb6fab187e569ec9752ab7d9333d2e66bf2d6b", size = 12451181, upload-time = "2026-01-31T23:11:59.782Z" },
+ { url = "https://files.pythonhosted.org/packages/6e/a4/a05c3a6418575e185dd84d0b9680b6bb2e2dc3e4202f036b7b4e22d6e9dc/numpy-2.4.2-cp313-cp313t-win_arm64.whl", hash = "sha256:fd49860271d52127d61197bb50b64f58454e9f578cb4b2c001a6de8b1f50b0b1", size = 10290756, upload-time = "2026-01-31T23:12:02.438Z" },
+ { url = "https://files.pythonhosted.org/packages/18/88/b7df6050bf18fdcfb7046286c6535cabbdd2064a3440fca3f069d319c16e/numpy-2.4.2-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:444be170853f1f9d528428eceb55f12918e4fda5d8805480f36a002f1415e09b", size = 16663092, upload-time = "2026-01-31T23:12:04.521Z" },
+ { url = "https://files.pythonhosted.org/packages/25/7a/1fee4329abc705a469a4afe6e69b1ef7e915117747886327104a8493a955/numpy-2.4.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d1240d50adff70c2a88217698ca844723068533f3f5c5fa6ee2e3220e3bdb000", size = 14698770, upload-time = "2026-01-31T23:12:06.96Z" },
+ { url = "https://files.pythonhosted.org/packages/fb/0b/f9e49ba6c923678ad5bc38181c08ac5e53b7a5754dbca8e581aa1a56b1ff/numpy-2.4.2-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:7cdde6de52fb6664b00b056341265441192d1291c130e99183ec0d4b110ff8b1", size = 5208562, upload-time = "2026-01-31T23:12:09.632Z" },
+ { url = "https://files.pythonhosted.org/packages/7d/12/d7de8f6f53f9bb76997e5e4c069eda2051e3fe134e9181671c4391677bb2/numpy-2.4.2-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:cda077c2e5b780200b6b3e09d0b42205a3d1c68f30c6dceb90401c13bff8fe74", size = 6543710, upload-time = "2026-01-31T23:12:11.969Z" },
+ { url = "https://files.pythonhosted.org/packages/09/63/c66418c2e0268a31a4cf8a8b512685748200f8e8e8ec6c507ce14e773529/numpy-2.4.2-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d30291931c915b2ab5717c2974bb95ee891a1cf22ebc16a8006bd59cd210d40a", size = 15677205, upload-time = "2026-01-31T23:12:14.33Z" },
+ { url = "https://files.pythonhosted.org/packages/5d/6c/7f237821c9642fb2a04d2f1e88b4295677144ca93285fd76eff3bcba858d/numpy-2.4.2-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bba37bc29d4d85761deed3954a1bc62be7cf462b9510b51d367b769a8c8df325", size = 16611738, upload-time = "2026-01-31T23:12:16.525Z" },
+ { url = "https://files.pythonhosted.org/packages/c2/a7/39c4cdda9f019b609b5c473899d87abff092fc908cfe4d1ecb2fcff453b0/numpy-2.4.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b2f0073ed0868db1dcd86e052d37279eef185b9c8db5bf61f30f46adac63c909", size = 17028888, upload-time = "2026-01-31T23:12:19.306Z" },
+ { url = "https://files.pythonhosted.org/packages/da/b3/e84bb64bdfea967cc10950d71090ec2d84b49bc691df0025dddb7c26e8e3/numpy-2.4.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:7f54844851cdb630ceb623dcec4db3240d1ac13d4990532446761baede94996a", size = 18339556, upload-time = "2026-01-31T23:12:21.816Z" },
+ { url = "https://files.pythonhosted.org/packages/88/f5/954a291bc1192a27081706862ac62bb5920fbecfbaa302f64682aa90beed/numpy-2.4.2-cp314-cp314-win32.whl", hash = "sha256:12e26134a0331d8dbd9351620f037ec470b7c75929cb8a1537f6bfe411152a1a", size = 6006899, upload-time = "2026-01-31T23:12:24.14Z" },
+ { url = "https://files.pythonhosted.org/packages/05/cb/eff72a91b2efdd1bc98b3b8759f6a1654aa87612fc86e3d87d6fe4f948c4/numpy-2.4.2-cp314-cp314-win_amd64.whl", hash = "sha256:068cdb2d0d644cdb45670810894f6a0600797a69c05f1ac478e8d31670b8ee75", size = 12443072, upload-time = "2026-01-31T23:12:26.33Z" },
+ { url = "https://files.pythonhosted.org/packages/37/75/62726948db36a56428fce4ba80a115716dc4fad6a3a4352487f8bb950966/numpy-2.4.2-cp314-cp314-win_arm64.whl", hash = "sha256:6ed0be1ee58eef41231a5c943d7d1375f093142702d5723ca2eb07db9b934b05", size = 10494886, upload-time = "2026-01-31T23:12:28.488Z" },
+ { url = "https://files.pythonhosted.org/packages/36/2f/ee93744f1e0661dc267e4b21940870cabfae187c092e1433b77b09b50ac4/numpy-2.4.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:98f16a80e917003a12c0580f97b5f875853ebc33e2eaa4bccfc8201ac6869308", size = 14818567, upload-time = "2026-01-31T23:12:30.709Z" },
+ { url = "https://files.pythonhosted.org/packages/a7/24/6535212add7d76ff938d8bdc654f53f88d35cddedf807a599e180dcb8e66/numpy-2.4.2-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:20abd069b9cda45874498b245c8015b18ace6de8546bf50dfa8cea1696ed06ef", size = 5328372, upload-time = "2026-01-31T23:12:32.962Z" },
+ { url = "https://files.pythonhosted.org/packages/5e/9d/c48f0a035725f925634bf6b8994253b43f2047f6778a54147d7e213bc5a7/numpy-2.4.2-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:e98c97502435b53741540a5717a6749ac2ada901056c7db951d33e11c885cc7d", size = 6649306, upload-time = "2026-01-31T23:12:34.797Z" },
+ { url = "https://files.pythonhosted.org/packages/81/05/7c73a9574cd4a53a25907bad38b59ac83919c0ddc8234ec157f344d57d9a/numpy-2.4.2-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:da6cad4e82cb893db4b69105c604d805e0c3ce11501a55b5e9f9083b47d2ffe8", size = 15722394, upload-time = "2026-01-31T23:12:36.565Z" },
+ { url = "https://files.pythonhosted.org/packages/35/fa/4de10089f21fc7d18442c4a767ab156b25c2a6eaf187c0db6d9ecdaeb43f/numpy-2.4.2-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9e4424677ce4b47fe73c8b5556d876571f7c6945d264201180db2dc34f676ab5", size = 16653343, upload-time = "2026-01-31T23:12:39.188Z" },
+ { url = "https://files.pythonhosted.org/packages/b8/f9/d33e4ffc857f3763a57aa85650f2e82486832d7492280ac21ba9efda80da/numpy-2.4.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:2b8f157c8a6f20eb657e240f8985cc135598b2b46985c5bccbde7616dc9c6b1e", size = 17078045, upload-time = "2026-01-31T23:12:42.041Z" },
+ { url = "https://files.pythonhosted.org/packages/c8/b8/54bdb43b6225badbea6389fa038c4ef868c44f5890f95dd530a218706da3/numpy-2.4.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5daf6f3914a733336dab21a05cdec343144600e964d2fcdabaac0c0269874b2a", size = 18380024, upload-time = "2026-01-31T23:12:44.331Z" },
+ { url = "https://files.pythonhosted.org/packages/a5/55/6e1a61ded7af8df04016d81b5b02daa59f2ea9252ee0397cb9f631efe9e5/numpy-2.4.2-cp314-cp314t-win32.whl", hash = "sha256:8c50dd1fc8826f5b26a5ee4d77ca55d88a895f4e4819c7ecc2a9f5905047a443", size = 6153937, upload-time = "2026-01-31T23:12:47.229Z" },
+ { url = "https://files.pythonhosted.org/packages/45/aa/fa6118d1ed6d776b0983f3ceac9b1a5558e80df9365b1c3aa6d42bf9eee4/numpy-2.4.2-cp314-cp314t-win_amd64.whl", hash = "sha256:fcf92bee92742edd401ba41135185866f7026c502617f422eb432cfeca4fe236", size = 12631844, upload-time = "2026-01-31T23:12:48.997Z" },
+ { url = "https://files.pythonhosted.org/packages/32/0a/2ec5deea6dcd158f254a7b372fb09cfba5719419c8d66343bab35237b3fb/numpy-2.4.2-cp314-cp314t-win_arm64.whl", hash = "sha256:1f92f53998a17265194018d1cc321b2e96e900ca52d54c7c77837b71b9465181", size = 10565379, upload-time = "2026-01-31T23:12:51.345Z" },
+]
+
[[package]]
name = "orjson"
version = "3.11.7"
@@ -1148,6 +1239,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
]
+[[package]]
+name = "pgvector"
+version = "0.4.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "numpy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/25/6c/6d8b4b03b958c02fa8687ec6063c49d952a189f8c91ebbe51e877dfab8f7/pgvector-0.4.2.tar.gz", hash = "sha256:322cac0c1dc5d41c9ecf782bd9991b7966685dee3a00bc873631391ed949513a", size = 31354, upload-time = "2025-12-05T01:07:17.87Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/5a/26/6cee8a1ce8c43625ec561aff19df07f9776b7525d9002c86bceb3e0ac970/pgvector-0.4.2-py3-none-any.whl", hash = "sha256:549d45f7a18593783d5eec609ea1684a724ba8405c4cb182a0b2b08aeff04e08", size = 27441, upload-time = "2025-12-05T01:07:16.536Z" },
+]
+
[[package]]
name = "pillow"
version = "12.1.1"
@@ -1648,6 +1751,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/4d/19/8d77f9992e5cbfcaa9133c3bf63b4fbbb051248802e1e803fed5c552fbb2/sentry_sdk-2.48.0-py2.py3-none-any.whl", hash = "sha256:6b12ac256769d41825d9b7518444e57fa35b5642df4c7c5e322af4d2c8721172", size = 414555, upload-time = "2025-12-16T14:55:40.152Z" },
]
+[[package]]
+name = "setuptools"
+version = "82.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/82/f3/748f4d6f65d1756b9ae577f329c951cda23fb900e4de9f70900ced962085/setuptools-82.0.0.tar.gz", hash = "sha256:22e0a2d69474c6ae4feb01951cb69d515ed23728cf96d05513d36e42b62b37cb", size = 1144893, upload-time = "2026-02-08T15:08:40.206Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e1/c6/76dc613121b793286a3f91621d7b75a2b493e0390ddca50f11993eadf192/setuptools-82.0.0-py3-none-any.whl", hash = "sha256:70b18734b607bd1da571d097d236cfcfacaf01de45717d59e6e04b96877532e0", size = 1003468, upload-time = "2026-02-08T15:08:38.723Z" },
+]
+
[[package]]
name = "shellingham"
version = "1.5.4"
@@ -1699,6 +1811,27 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c2/3f/f71798ed4cc6784122af8d404df187cfc683db795b3eee542e2df23f94a9/sqlmodel-0.0.29-py3-none-any.whl", hash = "sha256:56e4b3e9ab5825f6cbdd40bc46f90f01c0f73c73f44843a03d494c0845a948cd", size = 29355, upload-time = "2025-12-23T20:59:45.368Z" },
]
+[[package]]
+name = "sqlmodel-ext"
+version = "0.1.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pydantic" },
+ { name = "sqlalchemy" },
+ { name = "sqlmodel" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ff/01/c2b5c4939a0d466fde14fd82fd6bbefa39f69b7b9c63b01368862773ce04/sqlmodel_ext-0.1.1.tar.gz", hash = "sha256:2faa65cd8130f8a2384d2dadfa566de1d6f372ce9b39c46ebd8d2955529870fc", size = 70519, upload-time = "2026-02-07T13:08:14.249Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7d/93/f424feac8ad208bfa541a8e2fe0bb53bcd272735e40b18ee1ad7185cfd1f/sqlmodel_ext-0.1.1-py3-none-any.whl", hash = "sha256:f572abe3070c247ddc6dfbfe83b7b6a70567fbfca0225bef60df167632e6d836", size = 59033, upload-time = "2026-02-07T13:08:19.197Z" },
+]
+
+[package.optional-dependencies]
+pgvector = [
+ { name = "numpy" },
+ { name = "orjson" },
+ { name = "pgvector" },
+]
+
[[package]]
name = "starlette"
version = "0.50.0"
@@ -1898,6 +2031,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" },
]
+[[package]]
+name = "whatthepatch"
+version = "1.0.7"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/06/28/55bc3e107a56fdcf7d5022cb32b8c21d98a9cc2df5cd9f3b93e10419099e/whatthepatch-1.0.7.tar.gz", hash = "sha256:9eefb4ebea5200408e02d413d2b4bc28daea6b78bb4b4d53431af7245f7d7edf", size = 34612, upload-time = "2024-11-16T17:21:22.153Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8e/93/af1d6ccb69ab6b5a00e03fa0cefa563f9862412667776ea15dd4eece3a90/whatthepatch-1.0.7-py3-none-any.whl", hash = "sha256:1b6f655fd31091c001c209529dfaabbabdbad438f5de14e3951266ea0fc6e7ed", size = 11964, upload-time = "2024-11-16T17:21:20.761Z" },
+]
+
[[package]]
name = "win32-setctime"
version = "1.2.0"