feat: migrate ORM base to sqlmodel-ext, add file viewers and WOPI integration
- Migrate SQLModel base classes, mixins, and database management to external sqlmodel-ext package; remove sqlmodels/base/, sqlmodels/mixin/, and sqlmodels/database.py - Add file viewer/editor system with WOPI protocol support for collaborative editing (OnlyOffice, Collabora) - Add enterprise edition license verification module (ee/) - Add Dockerfile multi-stage build with Cython compilation support - Add new dependencies: sqlmodel-ext, cryptography, whatthepatch Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
37
.dockerignore
Normal file
37
.dockerignore
Normal file
@@ -0,0 +1,37 @@
|
||||
.git/
|
||||
.gitignore
|
||||
.github/
|
||||
.idea/
|
||||
.vscode/
|
||||
.venv/
|
||||
.env
|
||||
.env.*
|
||||
.run/
|
||||
.claude/
|
||||
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
tests/
|
||||
htmlcov/
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
coverage.xml
|
||||
|
||||
*.db
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
*.log
|
||||
logs/
|
||||
data/
|
||||
|
||||
Dockerfile
|
||||
.dockerignore
|
||||
|
||||
# Cython 编译产物
|
||||
*.c
|
||||
build/
|
||||
|
||||
# 许可证私钥和工具脚本
|
||||
license_private.pem
|
||||
scripts/
|
||||
14
.github/copilot-instructions.md
vendored
14
.github/copilot-instructions.md
vendored
@@ -449,13 +449,13 @@ return device # 此时device已过期
|
||||
```python
|
||||
import asyncio
|
||||
from sqlmodel import Field
|
||||
from sqlmodels.base import UUIDTableBase, SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
|
||||
class CharacterBase(SQLModelBase):
|
||||
name: str
|
||||
"""角色名称"""
|
||||
|
||||
class Character(CharacterBase, UUIDTableBase):
|
||||
class Character(CharacterBase, UUIDTableBaseMixin):
|
||||
"""充血模型:包含数据和业务逻辑"""
|
||||
|
||||
# ==================== 运行时属性(在model_post_init初始化) ====================
|
||||
@@ -570,11 +570,11 @@ async with character.init(session):
|
||||
from abc import ABC, abstractmethod
|
||||
from uuid import UUID
|
||||
from sqlmodel import Field
|
||||
from sqlmodels.base import (
|
||||
from sqlmodel_ext import (
|
||||
SQLModelBase,
|
||||
UUIDTableBase,
|
||||
UUIDTableBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
AutoPolymorphicIdentityMixin,
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
@@ -591,7 +591,7 @@ class ASRBase(SQLModelBase):
|
||||
# 2. 定义抽象父类(有表)
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBase,
|
||||
UUIDTableBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='__polymorphic_name',
|
||||
polymorphic_abstract=True
|
||||
@@ -1148,7 +1148,7 @@ from sqlmodel import Field
|
||||
# 3. 本地应用导入(从项目根目录的包开始)
|
||||
from dependencies import SessionDep
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.base import UUIDTableBase
|
||||
from sqlmodel_ext import UUIDTableBaseMixin
|
||||
|
||||
# 4. 相对导入(同包内的模块)
|
||||
from .base import BaseClass
|
||||
|
||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -69,3 +69,13 @@ data/
|
||||
# JB 的运行配置(换设备用不了)
|
||||
.run/
|
||||
.xml
|
||||
|
||||
# 前端构建产物(Docker 构建时复制)
|
||||
statics/
|
||||
|
||||
# Cython 编译产物
|
||||
*.c
|
||||
|
||||
# 许可证密钥(保密)
|
||||
license_private.pem
|
||||
license.key
|
||||
|
||||
14
AGENTS.md
14
AGENTS.md
@@ -449,13 +449,13 @@ return device # 此时device已过期
|
||||
```python
|
||||
import asyncio
|
||||
from sqlmodel import Field
|
||||
from sqlmodels.base import UUIDTableBase, SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
|
||||
class CharacterBase(SQLModelBase):
|
||||
name: str
|
||||
"""角色名称"""
|
||||
|
||||
class Character(CharacterBase, UUIDTableBase):
|
||||
class Character(CharacterBase, UUIDTableBaseMixin):
|
||||
"""充血模型:包含数据和业务逻辑"""
|
||||
|
||||
# ==================== 运行时属性(在model_post_init初始化) ====================
|
||||
@@ -570,11 +570,11 @@ async with character.init(session):
|
||||
from abc import ABC, abstractmethod
|
||||
from uuid import UUID
|
||||
from sqlmodel import Field
|
||||
from sqlmodels.base import (
|
||||
from sqlmodel_ext import (
|
||||
SQLModelBase,
|
||||
UUIDTableBase,
|
||||
UUIDTableBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
AutoPolymorphicIdentityMixin,
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
@@ -591,7 +591,7 @@ class ASRBase(SQLModelBase):
|
||||
# 2. 定义抽象父类(有表)
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBase,
|
||||
UUIDTableBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='__polymorphic_name',
|
||||
polymorphic_abstract=True
|
||||
@@ -1148,7 +1148,7 @@ from sqlmodel import Field
|
||||
# 3. 本地应用导入(从项目根目录的包开始)
|
||||
from dependencies import SessionDep
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.base import UUIDTableBase
|
||||
from sqlmodel_ext import UUIDTableBaseMixin
|
||||
|
||||
# 4. 相对导入(同包内的模块)
|
||||
from .base import BaseClass
|
||||
|
||||
14
CLAUDE.md
14
CLAUDE.md
@@ -449,13 +449,13 @@ return device # 此时device已过期
|
||||
```python
|
||||
import asyncio
|
||||
from sqlmodel import Field
|
||||
from sqlmodels.base import UUIDTableBase, SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
|
||||
class CharacterBase(SQLModelBase):
|
||||
name: str
|
||||
"""角色名称"""
|
||||
|
||||
class Character(CharacterBase, UUIDTableBase):
|
||||
class Character(CharacterBase, UUIDTableBaseMixin):
|
||||
"""充血模型:包含数据和业务逻辑"""
|
||||
|
||||
# ==================== 运行时属性(在model_post_init初始化) ====================
|
||||
@@ -570,11 +570,11 @@ async with character.init(session):
|
||||
from abc import ABC, abstractmethod
|
||||
from uuid import UUID
|
||||
from sqlmodel import Field
|
||||
from sqlmodels.base import (
|
||||
from sqlmodel_ext import (
|
||||
SQLModelBase,
|
||||
UUIDTableBase,
|
||||
UUIDTableBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
AutoPolymorphicIdentityMixin,
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
@@ -591,7 +591,7 @@ class ASRBase(SQLModelBase):
|
||||
# 2. 定义抽象父类(有表)
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBase,
|
||||
UUIDTableBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='__polymorphic_name',
|
||||
polymorphic_abstract=True
|
||||
@@ -1148,7 +1148,7 @@ from sqlmodel import Field
|
||||
# 3. 本地应用导入(从项目根目录的包开始)
|
||||
from dependencies import SessionDep
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.base import UUIDTableBase
|
||||
from sqlmodel_ext import UUIDTableBaseMixin
|
||||
|
||||
# 4. 相对导入(同包内的模块)
|
||||
from .base import BaseClass
|
||||
|
||||
47
Dockerfile
47
Dockerfile
@@ -1,13 +1,52 @@
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim
|
||||
# ============================================================
|
||||
# 基础层:安装运行时依赖
|
||||
# ============================================================
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS base
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 5213
|
||||
# ============================================================
|
||||
# Community 版本:删除 ee/ 目录
|
||||
# ============================================================
|
||||
FROM base AS community
|
||||
|
||||
CMD ["uv", "run", "fastapi", "run", "main.py", "--host", "0.0.0.0", "--port", "5213"]
|
||||
RUN rm -rf ee/
|
||||
COPY statics/ /app/statics/
|
||||
|
||||
EXPOSE 5213
|
||||
CMD ["uv", "run", "fastapi", "run", "main.py", "--host", "0.0.0.0", "--port", "5213"]
|
||||
|
||||
# ============================================================
|
||||
# Pro 编译层:Cython 编译 ee/ 模块
|
||||
# ============================================================
|
||||
FROM base AS pro-builder
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc libc6-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN uv sync --frozen --no-dev --extra build
|
||||
|
||||
RUN uv run python setup_cython.py build_ext --inplace && \
|
||||
uv run python setup_cython.py clean_artifacts
|
||||
|
||||
# ============================================================
|
||||
# Pro 版本:包含编译后的 ee/ 模块(仅 __init__.py + .so)
|
||||
# ============================================================
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS pro
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY pyproject.toml uv.lock ./
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
COPY --from=pro-builder /app/ /app/
|
||||
COPY statics/ /app/statics/
|
||||
|
||||
EXPOSE 5213
|
||||
CMD ["uv", "run", "fastapi", "run", "main.py", "--host", "0.0.0.0", "--port", "5213"]
|
||||
|
||||
594
docs/file-viewer-api.md
Normal file
594
docs/file-viewer-api.md
Normal file
@@ -0,0 +1,594 @@
|
||||
# 文件预览应用选择器 — 前端适配文档
|
||||
|
||||
## 概述
|
||||
|
||||
文件预览系统类似 Android 的"使用什么应用打开"机制:用户点击文件时,前端根据扩展名查询可用查看器列表,展示选择弹窗,用户可选"仅此一次"或"始终使用"。
|
||||
|
||||
### 应用类型
|
||||
|
||||
| type | 说明 | 前端处理方式 |
|
||||
|------|------|-------------|
|
||||
| `builtin` | 前端内置组件 | 根据 `app_key` 路由到内置组件(如 `pdfjs`、`monaco`) |
|
||||
| `iframe` | iframe 内嵌 | 将 `iframe_url_template` 中的 `{file_url}` 替换为文件下载 URL,嵌入 iframe |
|
||||
| `wopi` | WOPI 协议 | 调用 `/file/{id}/wopi-session` 获取 `editor_url`,嵌入 iframe |
|
||||
|
||||
### 内置 app_key 映射
|
||||
|
||||
前端需要为以下 `app_key` 实现对应的内置预览组件:
|
||||
|
||||
| app_key | 组件 | 说明 |
|
||||
|---------|------|------|
|
||||
| `pdfjs` | PDF.js 阅读器 | pdf |
|
||||
| `monaco` | Monaco Editor | txt, md, json, py, js, ts, html, css, ... |
|
||||
| `markdown` | Markdown 渲染器 | md, markdown, mdx |
|
||||
| `image_viewer` | 图片查看器 | jpg, png, gif, webp, svg, ... |
|
||||
| `video_player` | HTML5 Video | mp4, webm, ogg, mov, mkv, m3u8 |
|
||||
| `audio_player` | HTML5 Audio | mp3, wav, flac, aac, m4a, opus |
|
||||
|
||||
> `office_viewer`(iframe)、`collabora`(wopi)、`onlyoffice`(wopi)默认禁用,需管理员在后台启用和配置。
|
||||
|
||||
---
|
||||
|
||||
## 文件下载 URL 与 iframe 预览
|
||||
|
||||
### 现有下载流程(两步式)
|
||||
|
||||
```
|
||||
步骤1: POST /api/v1/file/download/{file_id} → { access_token, expires_in }
|
||||
步骤2: GET /api/v1/file/download/{access_token} → 文件二进制流
|
||||
```
|
||||
|
||||
- 步骤 1 需要 JWT 认证,返回一个下载令牌(有效期 1 小时)
|
||||
- 步骤 2 **不需要认证**,用令牌直接下载,**令牌为一次性**,下载后失效
|
||||
|
||||
### 各类型查看器获取文件内容的方式
|
||||
|
||||
| type | 获取文件方式 | 说明 |
|
||||
|------|-------------|------|
|
||||
| `builtin` | 前端自行获取 | 前端用 JS 调用下载接口拿到 Blob/ArrayBuffer,传给内置组件渲染 |
|
||||
| `iframe` | 需要公开可访问的 URL | 第三方服务(如 Office Online)会**从服务端拉取文件** |
|
||||
| `wopi` | WOPI 协议自动处理 | 编辑器通过 `/wopi/files/{id}/contents` 获取,前端只需嵌入 `editor_url` |
|
||||
|
||||
### builtin 类型 — 前端自行获取
|
||||
|
||||
内置组件(pdfjs、monaco 等)运行在前端,直接用 JS 获取文件内容即可:
|
||||
|
||||
```typescript
|
||||
// 方式 A:用下载令牌拼 URL(适用于 PDF.js 等需要 URL 的组件)
|
||||
const { access_token } = await api.post(`/file/download/${fileId}`)
|
||||
const fileUrl = `${baseUrl}/api/v1/file/download/${access_token}`
|
||||
// 传给 PDF.js: pdfjsLib.getDocument(fileUrl)
|
||||
|
||||
// 方式 B:用 fetch + Authorization 头获取 Blob(适用于需要 ArrayBuffer 的组件)
|
||||
const { access_token } = await api.post(`/file/download/${fileId}`)
|
||||
const blob = await fetch(`${baseUrl}/api/v1/file/download/${access_token}`).then(r => r.blob())
|
||||
// 传给 Monaco: monaco.editor.create(el, { value: await blob.text() })
|
||||
```
|
||||
|
||||
### iframe 类型 — `{file_url}` 替换规则
|
||||
|
||||
`iframe_url_template` 中的 `{file_url}` 需要替换为一个**外部可访问的文件直链**。
|
||||
|
||||
**问题**:当前下载令牌是一次性的,而 Office Online 等服务可能多次请求该 URL。
|
||||
|
||||
**当前可行方案**:
|
||||
|
||||
```typescript
|
||||
// 1. 创建下载令牌
|
||||
const { access_token } = await api.post(`/file/download/${fileId}`)
|
||||
|
||||
// 2. 拼出完整的文件 URL(必须是公网可达的地址)
|
||||
const fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
|
||||
|
||||
// 3. 替换模板
|
||||
const iframeSrc = viewer.iframe_url_template.replace(
|
||||
'{file_url}',
|
||||
encodeURIComponent(fileUrl)
|
||||
)
|
||||
|
||||
// 4. 嵌入 iframe
|
||||
// <iframe src={iframeSrc} />
|
||||
```
|
||||
|
||||
> **已知限制**:下载令牌为一次性使用。如果第三方服务多次拉取文件(如 Office Online 可能重试),
|
||||
> 第二次请求会 404。后续版本将实现 `/file/get/{id}/{name}` 外链端点(多次可用),届时
|
||||
> iframe 应改用外链 URL。目前建议:
|
||||
>
|
||||
> 1. **优先使用 WOPI 类型**(Collabora/OnlyOffice),不存在此限制
|
||||
> 2. Office Online 预览在**文件较小**时通常只拉取一次,大多数场景可用
|
||||
> 3. 如需稳定方案,可等待外链端点实现后再启用 iframe 类型应用
|
||||
|
||||
### wopi 类型 — 无需关心文件 URL
|
||||
|
||||
WOPI 类型的查看器完全由后端处理文件传输,前端只需:
|
||||
|
||||
```typescript
|
||||
// 1. 创建 WOPI 会话
|
||||
const session = await api.post(`/file/${fileId}/wopi-session`)
|
||||
|
||||
// 2. 直接嵌入编辑器
|
||||
// <iframe src={session.editor_url} />
|
||||
```
|
||||
|
||||
编辑器(Collabora/OnlyOffice)会通过 WOPI 协议自动从 `/wopi/files/{id}/contents` 获取文件内容,使用 `access_token` 认证,前端无需干预。
|
||||
|
||||
---
|
||||
|
||||
## 用户端 API
|
||||
|
||||
### 1. 查询可用查看器
|
||||
|
||||
用户点击文件时调用,获取该扩展名的可用查看器列表。
|
||||
|
||||
```
|
||||
GET /api/v1/file/viewers?ext={extension}
|
||||
Authorization: Bearer {token}
|
||||
```
|
||||
|
||||
**Query 参数**
|
||||
|
||||
| 参数 | 类型 | 必填 | 说明 |
|
||||
|------|------|------|------|
|
||||
| ext | string | 是 | 文件扩展名,最长 20 字符。支持带点号(`.pdf`)、大写(`PDF`),后端会自动规范化 |
|
||||
|
||||
**响应 200**
|
||||
|
||||
```json
|
||||
{
|
||||
"viewers": [
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "PDF 阅读器",
|
||||
"app_key": "pdfjs",
|
||||
"type": "builtin",
|
||||
"icon": "file-pdf",
|
||||
"description": "基于 pdf.js 的 PDF 在线阅读器",
|
||||
"iframe_url_template": null,
|
||||
"wopi_editor_url_template": null
|
||||
}
|
||||
],
|
||||
"default_viewer_id": null
|
||||
}
|
||||
```
|
||||
|
||||
**字段说明**
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| viewers | FileAppSummary[] | 可用查看器列表,已按优先级排序 |
|
||||
| default_viewer_id | string \| null | 用户设置的"始终使用"查看器 UUID,未设置则为 null |
|
||||
|
||||
**FileAppSummary**
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| id | UUID | 应用 UUID |
|
||||
| name | string | 应用显示名称 |
|
||||
| app_key | string | 应用唯一标识,前端路由用 |
|
||||
| type | `"builtin"` \| `"iframe"` \| `"wopi"` | 应用类型 |
|
||||
| icon | string \| null | 图标名称(可映射到 icon library) |
|
||||
| description | string \| null | 应用描述 |
|
||||
| iframe_url_template | string \| null | iframe 类型专用,URL 模板含 `{file_url}` 占位符 |
|
||||
| wopi_editor_url_template | string \| null | wopi 类型专用,编辑器 URL 模板 |
|
||||
|
||||
---
|
||||
|
||||
### 2. 设置默认查看器("始终使用")
|
||||
|
||||
用户在选择弹窗中勾选"始终使用此应用"时调用。
|
||||
|
||||
```
|
||||
PUT /api/v1/user/settings/file-viewers/default
|
||||
Authorization: Bearer {token}
|
||||
Content-Type: application/json
|
||||
```
|
||||
|
||||
**请求体**
|
||||
|
||||
```json
|
||||
{
|
||||
"extension": "pdf",
|
||||
"app_id": "550e8400-e29b-41d4-a716-446655440000"
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
|------|------|------|------|
|
||||
| extension | string | 是 | 文件扩展名(小写,无点号) |
|
||||
| app_id | UUID | 是 | 选择的查看器应用 UUID |
|
||||
|
||||
**响应 200**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "660e8400-e29b-41d4-a716-446655440001",
|
||||
"extension": "pdf",
|
||||
"app": {
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "PDF 阅读器",
|
||||
"app_key": "pdfjs",
|
||||
"type": "builtin",
|
||||
"icon": "file-pdf",
|
||||
"description": "基于 pdf.js 的 PDF 在线阅读器",
|
||||
"iframe_url_template": null,
|
||||
"wopi_editor_url_template": null
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**错误码**
|
||||
|
||||
| 状态码 | 说明 |
|
||||
|--------|------|
|
||||
| 400 | 该应用不支持此扩展名 |
|
||||
| 404 | 应用不存在 |
|
||||
|
||||
> 同一扩展名只允许一个默认值。重复 PUT 同一 extension 会更新(upsert),不会冲突。
|
||||
|
||||
---
|
||||
|
||||
### 3. 列出所有默认查看器设置
|
||||
|
||||
用于用户设置页展示"已设为始终使用"的列表。
|
||||
|
||||
```
|
||||
GET /api/v1/user/settings/file-viewers/defaults
|
||||
Authorization: Bearer {token}
|
||||
```
|
||||
|
||||
**响应 200**
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"id": "660e8400-e29b-41d4-a716-446655440001",
|
||||
"extension": "pdf",
|
||||
"app": {
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "PDF 阅读器",
|
||||
"app_key": "pdfjs",
|
||||
"type": "builtin",
|
||||
"icon": "file-pdf",
|
||||
"description": null,
|
||||
"iframe_url_template": null,
|
||||
"wopi_editor_url_template": null
|
||||
}
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. 撤销默认查看器设置
|
||||
|
||||
用户在设置页点击"取消始终使用"时调用。
|
||||
|
||||
```
|
||||
DELETE /api/v1/user/settings/file-viewers/default/{id}
|
||||
Authorization: Bearer {token}
|
||||
```
|
||||
|
||||
**响应** 204 No Content
|
||||
|
||||
**错误码**
|
||||
|
||||
| 状态码 | 说明 |
|
||||
|--------|------|
|
||||
| 404 | 记录不存在或不属于当前用户 |
|
||||
|
||||
---
|
||||
|
||||
### 5. 创建 WOPI 会话
|
||||
|
||||
打开 WOPI 类型应用(如 Collabora、OnlyOffice)时调用。
|
||||
|
||||
```
|
||||
POST /api/v1/file/{file_id}/wopi-session
|
||||
Authorization: Bearer {token}
|
||||
```
|
||||
|
||||
**响应 200**
|
||||
|
||||
```json
|
||||
{
|
||||
"wopi_src": "http://localhost:8000/wopi/files/770e8400-e29b-41d4-a716-446655440002",
|
||||
"access_token": "eyJhbGciOiJIUzI1NiIs...",
|
||||
"access_token_ttl": 1739577600000,
|
||||
"editor_url": "http://collabora:9980/loleaflet/dist/loleaflet.html?WOPISrc=...&access_token=...&access_token_ttl=..."
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| wopi_src | string | WOPI 源 URL(传给编辑器) |
|
||||
| access_token | string | WOPI 访问令牌 |
|
||||
| access_token_ttl | int | 令牌过期毫秒时间戳 |
|
||||
| editor_url | string | 完整的编辑器 URL,**直接嵌入 iframe 即可** |
|
||||
|
||||
**错误码**
|
||||
|
||||
| 状态码 | 说明 |
|
||||
|--------|------|
|
||||
| 400 | 文件无扩展名 / WOPI 应用未配置编辑器 URL |
|
||||
| 403 | 用户组无权限 |
|
||||
| 404 | 文件不存在 / 无可用 WOPI 查看器 |
|
||||
|
||||
---
|
||||
|
||||
## 前端交互流程
|
||||
|
||||
### 打开文件预览
|
||||
|
||||
```
|
||||
用户点击文件
|
||||
│
|
||||
▼
|
||||
GET /file/viewers?ext={扩展名}
|
||||
│
|
||||
├── viewers 为空 → 提示"暂无可用的预览方式"
|
||||
│
|
||||
├── default_viewer_id 不为空 → 直接用对应 viewer 打开(跳过选择弹窗)
|
||||
│
|
||||
└── viewers.length == 1 → 直接用唯一 viewer 打开(可选策略)
|
||||
│
|
||||
└── viewers.length > 1 → 展示选择弹窗
|
||||
│
|
||||
├── 用户选择 + 不勾选"始终使用" → 仅此一次打开
|
||||
│
|
||||
└── 用户选择 + 勾选"始终使用" → PUT /user/settings/file-viewers/default
|
||||
│
|
||||
└── 然后打开
|
||||
```
|
||||
|
||||
### 根据 type 打开查看器
|
||||
|
||||
```
|
||||
获取到 viewer 对象
|
||||
│
|
||||
├── type == "builtin"
|
||||
│ └── 根据 app_key 路由到内置组件
|
||||
│ switch(app_key):
|
||||
│ "pdfjs" → <PdfViewer />
|
||||
│ "monaco" → <CodeEditor />
|
||||
│ "markdown" → <MarkdownPreview />
|
||||
│ "image_viewer" → <ImageViewer />
|
||||
│ "video_player" → <VideoPlayer />
|
||||
│ "audio_player" → <AudioPlayer />
|
||||
│
|
||||
│ 获取文件内容:
|
||||
│ POST /file/download/{file_id} → { access_token }
|
||||
│ fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
|
||||
│ → 传 URL 或 fetch Blob 给内置组件
|
||||
│
|
||||
├── type == "iframe"
|
||||
│ └── 1. POST /file/download/{file_id} → { access_token }
|
||||
│ 2. fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
|
||||
│ 3. iframeSrc = viewer.iframe_url_template
|
||||
│ .replace("{file_url}", encodeURIComponent(fileUrl))
|
||||
│ 4. <iframe src={iframeSrc} />
|
||||
│
|
||||
└── type == "wopi"
|
||||
└── 1. POST /file/{file_id}/wopi-session → { editor_url }
|
||||
2. <iframe src={editor_url} />
|
||||
(编辑器自动通过 WOPI 协议获取文件,前端无需处理)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 管理员 API
|
||||
|
||||
所有管理端点需要管理员身份(JWT 中 group.admin == true)。
|
||||
|
||||
### 1. 列出所有文件应用
|
||||
|
||||
```
|
||||
GET /api/v1/admin/file-app/list?page=1&page_size=20
|
||||
Authorization: Bearer {admin_token}
|
||||
```
|
||||
|
||||
**响应 200**
|
||||
|
||||
```json
|
||||
{
|
||||
"apps": [
|
||||
{
|
||||
"id": "...",
|
||||
"name": "PDF 阅读器",
|
||||
"app_key": "pdfjs",
|
||||
"type": "builtin",
|
||||
"icon": "file-pdf",
|
||||
"description": "...",
|
||||
"is_enabled": true,
|
||||
"is_restricted": false,
|
||||
"iframe_url_template": null,
|
||||
"wopi_discovery_url": null,
|
||||
"wopi_editor_url_template": null,
|
||||
"extensions": ["pdf"],
|
||||
"allowed_group_ids": []
|
||||
}
|
||||
],
|
||||
"total": 9
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 创建文件应用
|
||||
|
||||
```
|
||||
POST /api/v1/admin/file-app/
|
||||
Authorization: Bearer {admin_token}
|
||||
Content-Type: application/json
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "自定义查看器",
|
||||
"app_key": "my_viewer",
|
||||
"type": "iframe",
|
||||
"description": "自定义 iframe 查看器",
|
||||
"is_enabled": true,
|
||||
"is_restricted": false,
|
||||
"iframe_url_template": "https://example.com/view?url={file_url}",
|
||||
"extensions": ["pdf", "docx"],
|
||||
"allowed_group_ids": []
|
||||
}
|
||||
```
|
||||
|
||||
**响应** 201 — 返回 FileAppResponse(同列表中的单项)
|
||||
|
||||
**错误码**: 409 — app_key 已存在
|
||||
|
||||
### 3. 获取应用详情
|
||||
|
||||
```
|
||||
GET /api/v1/admin/file-app/{id}
|
||||
```
|
||||
|
||||
**响应** 200 — FileAppResponse
|
||||
|
||||
### 4. 更新应用
|
||||
|
||||
```
|
||||
PATCH /api/v1/admin/file-app/{id}
|
||||
```
|
||||
|
||||
只传需要更新的字段:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "新名称",
|
||||
"is_enabled": false
|
||||
}
|
||||
```
|
||||
|
||||
**响应** 200 — FileAppResponse
|
||||
|
||||
### 5. 删除应用
|
||||
|
||||
```
|
||||
DELETE /api/v1/admin/file-app/{id}
|
||||
```
|
||||
|
||||
**响应** 204 No Content(级联删除扩展名关联、用户偏好、用户组关联)
|
||||
|
||||
### 6. 全量替换扩展名列表
|
||||
|
||||
```
|
||||
PUT /api/v1/admin/file-app/{id}/extensions
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"extensions": ["doc", "docx", "odt"]
|
||||
}
|
||||
```
|
||||
|
||||
**响应** 200 — FileAppResponse
|
||||
|
||||
### 7. 全量替换允许的用户组
|
||||
|
||||
```
|
||||
PUT /api/v1/admin/file-app/{id}/groups
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"group_ids": ["uuid-1", "uuid-2"]
|
||||
}
|
||||
```
|
||||
|
||||
**响应** 200 — FileAppResponse
|
||||
|
||||
> `is_restricted` 为 `true` 时,只有 `allowed_group_ids` 中的用户组成员能看到此应用。`is_restricted` 为 `false` 时所有用户可见,`allowed_group_ids` 不生效。
|
||||
|
||||
---
|
||||
|
||||
## TypeScript 类型参考
|
||||
|
||||
```typescript
|
||||
type FileAppType = 'builtin' | 'iframe' | 'wopi'
|
||||
|
||||
interface FileAppSummary {
|
||||
id: string
|
||||
name: string
|
||||
app_key: string
|
||||
type: FileAppType
|
||||
icon: string | null
|
||||
description: string | null
|
||||
iframe_url_template: string | null
|
||||
wopi_editor_url_template: string | null
|
||||
}
|
||||
|
||||
interface FileViewersResponse {
|
||||
viewers: FileAppSummary[]
|
||||
default_viewer_id: string | null
|
||||
}
|
||||
|
||||
interface SetDefaultViewerRequest {
|
||||
extension: string
|
||||
app_id: string
|
||||
}
|
||||
|
||||
interface UserFileAppDefaultResponse {
|
||||
id: string
|
||||
extension: string
|
||||
app: FileAppSummary
|
||||
}
|
||||
|
||||
interface WopiSessionResponse {
|
||||
wopi_src: string
|
||||
access_token: string
|
||||
access_token_ttl: number
|
||||
editor_url: string
|
||||
}
|
||||
|
||||
// ========== 管理员类型 ==========
|
||||
|
||||
interface FileAppResponse {
|
||||
id: string
|
||||
name: string
|
||||
app_key: string
|
||||
type: FileAppType
|
||||
icon: string | null
|
||||
description: string | null
|
||||
is_enabled: boolean
|
||||
is_restricted: boolean
|
||||
iframe_url_template: string | null
|
||||
wopi_discovery_url: string | null
|
||||
wopi_editor_url_template: string | null
|
||||
extensions: string[]
|
||||
allowed_group_ids: string[]
|
||||
}
|
||||
|
||||
interface FileAppListResponse {
|
||||
apps: FileAppResponse[]
|
||||
total: number
|
||||
}
|
||||
|
||||
interface FileAppCreateRequest {
|
||||
name: string
|
||||
app_key: string
|
||||
type: FileAppType
|
||||
icon?: string
|
||||
description?: string
|
||||
is_enabled?: boolean // default: true
|
||||
is_restricted?: boolean // default: false
|
||||
iframe_url_template?: string
|
||||
wopi_discovery_url?: string
|
||||
wopi_editor_url_template?: string
|
||||
extensions?: string[] // default: []
|
||||
allowed_group_ids?: string[] // default: []
|
||||
}
|
||||
|
||||
interface FileAppUpdateRequest {
|
||||
name?: string
|
||||
app_key?: string
|
||||
type?: FileAppType
|
||||
icon?: string
|
||||
description?: string
|
||||
is_enabled?: boolean
|
||||
is_restricted?: boolean
|
||||
iframe_url_template?: string
|
||||
wopi_discovery_url?: string
|
||||
wopi_editor_url_template?: string
|
||||
}
|
||||
```
|
||||
242
docs/text-editor-api.md
Normal file
242
docs/text-editor-api.md
Normal file
@@ -0,0 +1,242 @@
|
||||
# 文本文件在线编辑 — 前端适配文档
|
||||
|
||||
## 概述
|
||||
|
||||
Monaco Editor 打开文本文件时,通过 GET 获取内容和哈希作为编辑基线;保存时用 jsdiff 计算 unified diff,仅发送差异部分,后端验证无并发冲突后应用 patch。
|
||||
|
||||
```
|
||||
打开文件: GET /api/v1/file/content/{file_id} → { content, hash, size }
|
||||
保存文件: PATCH /api/v1/file/content/{file_id} ← { patch, base_hash }
|
||||
→ { new_hash, new_size }
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 约定
|
||||
|
||||
| 项目 | 约定 |
|
||||
|------|------|
|
||||
| 编码 | 全程 UTF-8 |
|
||||
| 换行符 | 后端 GET 时统一规范化为 `\n`,前端无需处理 `\r\n` |
|
||||
| hash 算法 | SHA-256,hex 编码(64 字符),基于 UTF-8 bytes 计算 |
|
||||
| diff 格式 | jsdiff `createPatch()` 输出的标准 unified diff |
|
||||
| 空 diff | 前端自行判断,内容未变时不发请求 |
|
||||
|
||||
---
|
||||
|
||||
## GET /api/v1/file/content/{file_id}
|
||||
|
||||
获取文本文件内容。
|
||||
|
||||
### 请求
|
||||
|
||||
```
|
||||
GET /api/v1/file/content/{file_id}
|
||||
Authorization: Bearer <token>
|
||||
```
|
||||
|
||||
### 响应 200
|
||||
|
||||
```json
|
||||
{
|
||||
"content": "line1\nline2\nline3\n",
|
||||
"hash": "a1b2c3d4...(64字符 SHA-256 hex)",
|
||||
"size": 18
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| `content` | string | 文件文本内容,换行符已规范化为 `\n` |
|
||||
| `hash` | string | 基于规范化内容 UTF-8 bytes 的 SHA-256 hex |
|
||||
| `size` | number | 规范化后的字节大小 |
|
||||
|
||||
### 错误
|
||||
|
||||
| 状态码 | 说明 |
|
||||
|--------|------|
|
||||
| 400 | 文件不是有效的 UTF-8 文本(二进制文件) |
|
||||
| 401 | 未认证 |
|
||||
| 404 | 文件不存在 |
|
||||
|
||||
---
|
||||
|
||||
## PATCH /api/v1/file/content/{file_id}
|
||||
|
||||
增量保存文本文件。
|
||||
|
||||
### 请求
|
||||
|
||||
```
|
||||
PATCH /api/v1/file/content/{file_id}
|
||||
Authorization: Bearer <token>
|
||||
Content-Type: application/json
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"patch": "--- a\n+++ b\n@@ -1,3 +1,3 @@\n line1\n-line2\n+LINE2\n line3\n",
|
||||
"base_hash": "a1b2c3d4...(GET 返回的 hash)"
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| `patch` | string | jsdiff `createPatch()` 生成的 unified diff |
|
||||
| `base_hash` | string | 编辑前 GET 返回的 `hash` 值 |
|
||||
|
||||
### 响应 200
|
||||
|
||||
```json
|
||||
{
|
||||
"new_hash": "e5f6a7b8...(64字符)",
|
||||
"new_size": 18
|
||||
}
|
||||
```
|
||||
|
||||
保存成功后,前端应将 `new_hash` 作为新的 `base_hash`,用于下次保存。
|
||||
|
||||
### 错误
|
||||
|
||||
| 状态码 | 说明 | 前端处理 |
|
||||
|--------|------|----------|
|
||||
| 401 | 未认证 | — |
|
||||
| 404 | 文件不存在 | — |
|
||||
| 409 | `base_hash` 不匹配(并发冲突) | 提示用户刷新,重新加载内容 |
|
||||
| 422 | patch 格式无效或应用失败 | 回退到全量保存或提示用户 |
|
||||
|
||||
---
|
||||
|
||||
## 前端实现参考
|
||||
|
||||
### 依赖
|
||||
|
||||
```bash
|
||||
npm install jsdiff
|
||||
```
|
||||
|
||||
### 计算 hash
|
||||
|
||||
```typescript
|
||||
async function sha256(text: string): Promise<string> {
|
||||
const bytes = new TextEncoder().encode(text);
|
||||
const hashBuffer = await crypto.subtle.digest("SHA-256", bytes);
|
||||
const hashArray = Array.from(new Uint8Array(hashBuffer));
|
||||
return hashArray.map(b => b.toString(16).padStart(2, "0")).join("");
|
||||
}
|
||||
```
|
||||
|
||||
### 打开文件
|
||||
|
||||
```typescript
|
||||
interface TextContent {
|
||||
content: string;
|
||||
hash: string;
|
||||
size: number;
|
||||
}
|
||||
|
||||
async function openFile(fileId: string): Promise<TextContent> {
|
||||
const resp = await fetch(`/api/v1/file/content/${fileId}`, {
|
||||
headers: { Authorization: `Bearer ${token}` },
|
||||
});
|
||||
|
||||
if (!resp.ok) {
|
||||
if (resp.status === 400) throw new Error("该文件不是文本文件");
|
||||
throw new Error("获取文件内容失败");
|
||||
}
|
||||
|
||||
return resp.json();
|
||||
}
|
||||
```
|
||||
|
||||
### 保存文件
|
||||
|
||||
```typescript
|
||||
import { createPatch } from "diff";
|
||||
|
||||
interface PatchResult {
|
||||
new_hash: string;
|
||||
new_size: number;
|
||||
}
|
||||
|
||||
async function saveFile(
|
||||
fileId: string,
|
||||
originalContent: string,
|
||||
currentContent: string,
|
||||
baseHash: string,
|
||||
): Promise<PatchResult | null> {
|
||||
// 内容未变,不发请求
|
||||
if (originalContent === currentContent) return null;
|
||||
|
||||
const patch = createPatch("file", originalContent, currentContent);
|
||||
|
||||
const resp = await fetch(`/api/v1/file/content/${fileId}`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ patch, base_hash: baseHash }),
|
||||
});
|
||||
|
||||
if (resp.status === 409) {
|
||||
// 并发冲突,需要用户决策
|
||||
throw new Error("CONFLICT");
|
||||
}
|
||||
|
||||
if (!resp.ok) throw new Error("保存失败");
|
||||
|
||||
return resp.json();
|
||||
}
|
||||
```
|
||||
|
||||
### 完整编辑流程
|
||||
|
||||
```typescript
|
||||
// 1. 打开
|
||||
const file = await openFile(fileId);
|
||||
let baseContent = file.content;
|
||||
let baseHash = file.hash;
|
||||
|
||||
// 2. 用户在 Monaco 中编辑...
|
||||
editor.setValue(baseContent);
|
||||
|
||||
// 3. 保存(Ctrl+S)
|
||||
const currentContent = editor.getValue();
|
||||
const result = await saveFile(fileId, baseContent, currentContent, baseHash);
|
||||
|
||||
if (result) {
|
||||
// 更新基线
|
||||
baseContent = currentContent;
|
||||
baseHash = result.new_hash;
|
||||
}
|
||||
```
|
||||
|
||||
### 冲突处理建议
|
||||
|
||||
当 PATCH 返回 409 时,说明文件已被其他会话修改:
|
||||
|
||||
```typescript
|
||||
try {
|
||||
await saveFile(fileId, baseContent, currentContent, baseHash);
|
||||
} catch (e) {
|
||||
if (e.message === "CONFLICT") {
|
||||
// 方案 A:提示用户,提供"覆盖"和"放弃"选项
|
||||
// 方案 B:重新 GET 最新内容,展示 diff 让用户合并
|
||||
const latest = await openFile(fileId);
|
||||
// 展示合并 UI...
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## hash 一致性验证
|
||||
|
||||
前端可以在 GET 后本地验证 hash,确保传输无误:
|
||||
|
||||
```typescript
|
||||
const file = await openFile(fileId);
|
||||
const localHash = await sha256(file.content);
|
||||
console.assert(localHash === file.hash, "hash 不一致,内容可能损坏");
|
||||
```
|
||||
14
license_public.pem
Normal file
14
license_public.pem
Normal file
@@ -0,0 +1,14 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAyNltXQ/Nuechx3kjj3T5
|
||||
oR6pZvTmpsDowqqxXJy7FXUI8d7XprhV+HrBQPsrT/Ngo9FwW3XyiK10m1WrzpGW
|
||||
eaf9990Z5Z2naEn5TzGrh71p/D7mZcNGVumo9uAuhtNEemm6xB3FoyGYZj7X0cwA
|
||||
VDvIiKAwYyRJX2LqVh1/tZM6tTO3oaGZXRMZzCNUPFSo4ZZudU3Boa5oQg08evu4
|
||||
vaOqeFrMX47R3MSUmO9hOh+NS53XNqO0f0zw5sv95CtyR5qvJ4gpkgYaRCSQFd19
|
||||
TnHU5saFVrH9jdADz1tdkMYcyYE+uJActZBapxCHSYB2tSCKWjDxeUFl/oY/ZFtY
|
||||
l4MNz1ovkjNhpmR3g+I5fbvN0cxDIjnZ9vJ84ozGqTGT9s1jHaLbpLri/vhuT4F2
|
||||
7kifXk8ImwtMZpZvzhmucH9/5VgcWKNuMATzEMif+YjFpuOGx8gc1XL1W/3q+dH0
|
||||
EFESp+/knjcVIfwpAkIKyV7XvDgFHsif1SeI0zZMW4utowVvGocP1ZzK5BGNTk2z
|
||||
CEtQDO7Rqo+UDckOJSG66VW3c2QO8o6uuy6fzx7q0MFEmUMwGf2iMVtR/KnXe99C
|
||||
enOT0BpU1EQvqssErUqivDss7jm98iD8M/TCE7pFboqZ+SC9G+QAqNIQNFWh8bWA
|
||||
R9hyXM/x5ysHd6MC4eEQnhMCAwEAAQ==
|
||||
-----END PUBLIC KEY-----
|
||||
61
main.py
61
main.py
@@ -1,16 +1,40 @@
|
||||
from pathlib import Path
|
||||
from typing import NoReturn
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from loguru import logger as l
|
||||
|
||||
from utils.conf import appmeta
|
||||
from utils.http.http_exceptions import raise_internal_error
|
||||
from utils.lifespan import lifespan
|
||||
from routers import router
|
||||
from service.redis import RedisManager
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodels.migration import migration
|
||||
from utils import JWT
|
||||
from routers import router
|
||||
from service.redis import RedisManager
|
||||
from loguru import logger as l
|
||||
from utils.conf import appmeta
|
||||
from utils.http.http_exceptions import raise_internal_error
|
||||
from utils.lifespan import lifespan
|
||||
|
||||
# 尝试加载企业版功能
|
||||
try:
|
||||
from ee import init_ee
|
||||
from ee.license import LicenseError
|
||||
|
||||
async def _init_ee_and_routes() -> None:
|
||||
try:
|
||||
await init_ee()
|
||||
except LicenseError as exc:
|
||||
l.critical(f"许可证验证失败: {exc}")
|
||||
raise SystemExit(1) from exc
|
||||
|
||||
from ee.routers import ee_router
|
||||
from routers.api.v1 import router as v1_router
|
||||
v1_router.include_router(ee_router)
|
||||
|
||||
lifespan.add_startup(_init_ee_and_routes)
|
||||
except ImportError:
|
||||
l.info("以 Community 版本运行")
|
||||
|
||||
STATICS_DIR: Path = (Path(__file__).parent / "statics").resolve()
|
||||
"""前端静态文件目录(由 Docker 构建时复制)"""
|
||||
|
||||
async def _init_db() -> None:
|
||||
"""初始化数据库连接引擎"""
|
||||
@@ -64,6 +88,31 @@ async def handle_unexpected_exceptions(
|
||||
# 挂载路由
|
||||
app.include_router(router)
|
||||
|
||||
# 挂载前端静态文件(仅当 statics/ 目录存在时,即 Docker 部署环境)
|
||||
if STATICS_DIR.is_dir():
|
||||
from starlette.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
_assets_dir: Path = STATICS_DIR / "assets"
|
||||
if _assets_dir.is_dir():
|
||||
app.mount("/assets", StaticFiles(directory=_assets_dir), name="assets")
|
||||
|
||||
@app.get("/{path:path}")
|
||||
async def spa_fallback(path: str) -> FileResponse:
|
||||
"""
|
||||
SPA fallback 路由
|
||||
|
||||
优先级:API 路由 > /assets 静态挂载 > 此 catch-all 路由。
|
||||
若请求路径对应 statics/ 下的真实文件则直接返回,否则返回 index.html。
|
||||
"""
|
||||
file_path: Path = (STATICS_DIR / path).resolve()
|
||||
# 防止路径穿越
|
||||
if file_path.is_relative_to(STATICS_DIR) and path and file_path.is_file():
|
||||
return FileResponse(file_path)
|
||||
return FileResponse(STATICS_DIR / "index.html")
|
||||
|
||||
l.info(f"前端静态文件已挂载: {STATICS_DIR}")
|
||||
|
||||
# 防止直接运行 main.py
|
||||
if __name__ == "__main__":
|
||||
l.error("请用 fastapi ['dev', 'run'] 命令启动服务")
|
||||
|
||||
@@ -17,7 +17,7 @@ from fastapi import Depends, Form, Query
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodels.mixin import TimeFilterRequest, TableViewRequest
|
||||
from sqlmodel_ext import TimeFilterRequest, TableViewRequest
|
||||
from sqlmodels.user import UserFilterParams, UserStatus
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ dependencies = [
|
||||
"asyncpg>=0.31.0",
|
||||
"cachetools>=6.2.4",
|
||||
"captcha>=0.7.1",
|
||||
"cryptography>=46.0.3",
|
||||
"fastapi[standard]>=0.122.0",
|
||||
"httpx>=0.27.0",
|
||||
"itsdangerous>=2.2.0",
|
||||
@@ -28,8 +29,16 @@ dependencies = [
|
||||
"redis[hiredis]>=7.1.0",
|
||||
"sqlalchemy>=2.0.44",
|
||||
"sqlmodel>=0.0.27",
|
||||
"sqlmodel-ext[pgvector]>=0.1.1",
|
||||
"uvicorn>=0.38.0",
|
||||
"webauthn>=2.7.0",
|
||||
"whatthepatch>=1.0.6",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
build = [
|
||||
"cython>=3.0.11",
|
||||
"setuptools>=75.0.0",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .api import router as api_router
|
||||
from .wopi import wopi_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(api_router)
|
||||
router.include_router(api_router)
|
||||
router.include_router(wopi_router)
|
||||
@@ -9,7 +9,7 @@ from sqlmodels import (
|
||||
User, ResponseBase,
|
||||
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
|
||||
)
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
from sqlmodels.setting import (
|
||||
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
|
||||
)
|
||||
@@ -17,6 +17,7 @@ from sqlmodels.setting import SettingsType
|
||||
from utils import http_exceptions
|
||||
from utils.conf import appmeta
|
||||
from .file import admin_file_router
|
||||
from .file_app import admin_file_app_router
|
||||
from .group import admin_group_router
|
||||
from .policy import admin_policy_router
|
||||
from .share import admin_share_router
|
||||
@@ -44,6 +45,7 @@ admin_router = APIRouter(
|
||||
admin_router.include_router(admin_group_router)
|
||||
admin_router.include_router(admin_user_router)
|
||||
admin_router.include_router(admin_file_router)
|
||||
admin_router.include_router(admin_file_app_router)
|
||||
admin_router.include_router(admin_policy_router)
|
||||
admin_router.include_router(admin_share_router)
|
||||
admin_router.include_router(admin_task_router)
|
||||
|
||||
348
routers/api/v1/admin/file_app/__init__.py
Normal file
348
routers/api/v1/admin/file_app/__init__.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
管理员文件应用管理端点
|
||||
|
||||
提供文件查看器应用的 CRUD、扩展名管理和用户组权限管理。
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import select
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from sqlmodels import (
|
||||
FileApp,
|
||||
FileAppCreateRequest,
|
||||
FileAppExtension,
|
||||
FileAppGroupLink,
|
||||
FileAppListResponse,
|
||||
FileAppResponse,
|
||||
FileAppUpdateRequest,
|
||||
ExtensionUpdateRequest,
|
||||
GroupAccessUpdateRequest,
|
||||
)
|
||||
from utils import http_exceptions
|
||||
|
||||
admin_file_app_router = APIRouter(
|
||||
prefix="/file-app",
|
||||
tags=["admin", "file_app"],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
|
||||
|
||||
@admin_file_app_router.get(
|
||||
path='/list',
|
||||
summary='列出所有文件应用',
|
||||
)
|
||||
async def list_file_apps(
|
||||
session: SessionDep,
|
||||
table_view: TableViewRequestDep,
|
||||
) -> FileAppListResponse:
|
||||
"""
|
||||
列出所有文件应用端点(分页)
|
||||
|
||||
认证:管理员权限
|
||||
"""
|
||||
result = await FileApp.get_with_count(
|
||||
session,
|
||||
table_view=table_view,
|
||||
)
|
||||
|
||||
apps: list[FileAppResponse] = []
|
||||
for app in result.items:
|
||||
extensions = await FileAppExtension.get(
|
||||
session,
|
||||
FileAppExtension.app_id == app.id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
group_links_result = await session.exec(
|
||||
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
|
||||
)
|
||||
group_links: list[FileAppGroupLink] = list(group_links_result.all())
|
||||
apps.append(FileAppResponse.from_app(app, extensions, group_links))
|
||||
|
||||
return FileAppListResponse(apps=apps, total=result.count)
|
||||
|
||||
|
||||
@admin_file_app_router.post(
|
||||
path='/',
|
||||
summary='创建文件应用',
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_file_app(
|
||||
session: SessionDep,
|
||||
request: FileAppCreateRequest,
|
||||
) -> FileAppResponse:
|
||||
"""
|
||||
创建文件应用端点
|
||||
|
||||
认证:管理员权限
|
||||
|
||||
错误处理:
|
||||
- 409: app_key 已存在
|
||||
"""
|
||||
# 检查 app_key 唯一
|
||||
existing = await FileApp.get(session, FileApp.app_key == request.app_key)
|
||||
if existing:
|
||||
http_exceptions.raise_conflict(f"应用标识 '{request.app_key}' 已存在")
|
||||
|
||||
# 创建应用
|
||||
app = FileApp(
|
||||
name=request.name,
|
||||
app_key=request.app_key,
|
||||
type=request.type,
|
||||
icon=request.icon,
|
||||
description=request.description,
|
||||
is_enabled=request.is_enabled,
|
||||
is_restricted=request.is_restricted,
|
||||
iframe_url_template=request.iframe_url_template,
|
||||
wopi_discovery_url=request.wopi_discovery_url,
|
||||
wopi_editor_url_template=request.wopi_editor_url_template,
|
||||
)
|
||||
app = await app.save(session)
|
||||
app_id = app.id
|
||||
|
||||
# 创建扩展名关联
|
||||
extensions: list[FileAppExtension] = []
|
||||
for i, ext in enumerate(request.extensions):
|
||||
normalized = ext.lower().strip().lstrip('.')
|
||||
ext_record = FileAppExtension(
|
||||
app_id=app_id,
|
||||
extension=normalized,
|
||||
priority=i,
|
||||
)
|
||||
ext_record = await ext_record.save(session)
|
||||
extensions.append(ext_record)
|
||||
|
||||
# 创建用户组关联
|
||||
group_links: list[FileAppGroupLink] = []
|
||||
for group_id in request.allowed_group_ids:
|
||||
link = FileAppGroupLink(app_id=app_id, group_id=group_id)
|
||||
session.add(link)
|
||||
group_links.append(link)
|
||||
if group_links:
|
||||
await session.commit()
|
||||
|
||||
l.info(f"创建文件应用: {app.name} ({app.app_key})")
|
||||
|
||||
return FileAppResponse.from_app(app, extensions, group_links)
|
||||
|
||||
|
||||
@admin_file_app_router.get(
|
||||
path='/{app_id}',
|
||||
summary='获取文件应用详情',
|
||||
)
|
||||
async def get_file_app(
|
||||
session: SessionDep,
|
||||
app_id: UUID,
|
||||
) -> FileAppResponse:
|
||||
"""
|
||||
获取文件应用详情端点
|
||||
|
||||
认证:管理员权限
|
||||
|
||||
错误处理:
|
||||
- 404: 应用不存在
|
||||
"""
|
||||
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||
if not app:
|
||||
http_exceptions.raise_not_found("应用不存在")
|
||||
|
||||
extensions = await FileAppExtension.get(
|
||||
session,
|
||||
FileAppExtension.app_id == app.id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
group_links_result = await session.exec(
|
||||
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
|
||||
)
|
||||
group_links: list[FileAppGroupLink] = list(group_links_result.all())
|
||||
|
||||
return FileAppResponse.from_app(app, extensions, group_links)
|
||||
|
||||
|
||||
@admin_file_app_router.patch(
|
||||
path='/{app_id}',
|
||||
summary='更新文件应用',
|
||||
)
|
||||
async def update_file_app(
|
||||
session: SessionDep,
|
||||
app_id: UUID,
|
||||
request: FileAppUpdateRequest,
|
||||
) -> FileAppResponse:
|
||||
"""
|
||||
更新文件应用端点
|
||||
|
||||
认证:管理员权限
|
||||
|
||||
错误处理:
|
||||
- 404: 应用不存在
|
||||
- 409: 新 app_key 已被其他应用使用
|
||||
"""
|
||||
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||
if not app:
|
||||
http_exceptions.raise_not_found("应用不存在")
|
||||
|
||||
# 检查 app_key 唯一性
|
||||
if request.app_key is not None and request.app_key != app.app_key:
|
||||
existing = await FileApp.get(session, FileApp.app_key == request.app_key)
|
||||
if existing:
|
||||
http_exceptions.raise_conflict(f"应用标识 '{request.app_key}' 已存在")
|
||||
|
||||
# 更新非 None 字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(app, key, value)
|
||||
|
||||
app = await app.save(session)
|
||||
|
||||
extensions = await FileAppExtension.get(
|
||||
session,
|
||||
FileAppExtension.app_id == app.id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
group_links_result = await session.exec(
|
||||
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
|
||||
)
|
||||
group_links: list[FileAppGroupLink] = list(group_links_result.all())
|
||||
|
||||
l.info(f"更新文件应用: {app.name} ({app.app_key})")
|
||||
|
||||
return FileAppResponse.from_app(app, extensions, group_links)
|
||||
|
||||
|
||||
@admin_file_app_router.delete(
|
||||
path='/{app_id}',
|
||||
summary='删除文件应用',
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def delete_file_app(
|
||||
session: SessionDep,
|
||||
app_id: UUID,
|
||||
) -> None:
|
||||
"""
|
||||
删除文件应用端点(级联删除扩展名、用户偏好和用户组关联)
|
||||
|
||||
认证:管理员权限
|
||||
|
||||
错误处理:
|
||||
- 404: 应用不存在
|
||||
"""
|
||||
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||
if not app:
|
||||
http_exceptions.raise_not_found("应用不存在")
|
||||
|
||||
app_name = app.app_key
|
||||
await FileApp.delete(session, app)
|
||||
l.info(f"删除文件应用: {app_name}")
|
||||
|
||||
|
||||
@admin_file_app_router.put(
|
||||
path='/{app_id}/extensions',
|
||||
summary='全量替换扩展名列表',
|
||||
)
|
||||
async def update_extensions(
|
||||
session: SessionDep,
|
||||
app_id: UUID,
|
||||
request: ExtensionUpdateRequest,
|
||||
) -> FileAppResponse:
|
||||
"""
|
||||
全量替换扩展名列表端点
|
||||
|
||||
先删除旧的扩展名关联,再创建新的。
|
||||
|
||||
认证:管理员权限
|
||||
|
||||
错误处理:
|
||||
- 404: 应用不存在
|
||||
"""
|
||||
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||
if not app:
|
||||
http_exceptions.raise_not_found("应用不存在")
|
||||
|
||||
# 删除旧的扩展名
|
||||
old_extensions: list[FileAppExtension] = await FileAppExtension.get(
|
||||
session,
|
||||
FileAppExtension.app_id == app_id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
for old_ext in old_extensions:
|
||||
await FileAppExtension.delete(session, old_ext, commit=False)
|
||||
|
||||
# 创建新的扩展名
|
||||
new_extensions: list[FileAppExtension] = []
|
||||
for i, ext in enumerate(request.extensions):
|
||||
normalized = ext.lower().strip().lstrip('.')
|
||||
ext_record = FileAppExtension(
|
||||
app_id=app_id,
|
||||
extension=normalized,
|
||||
priority=i,
|
||||
)
|
||||
session.add(ext_record)
|
||||
new_extensions.append(ext_record)
|
||||
|
||||
await session.commit()
|
||||
# refresh 新创建的记录
|
||||
for ext_record in new_extensions:
|
||||
await session.refresh(ext_record)
|
||||
|
||||
group_links_result = await session.exec(
|
||||
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app_id)
|
||||
)
|
||||
group_links: list[FileAppGroupLink] = list(group_links_result.all())
|
||||
|
||||
l.info(f"更新文件应用 {app.app_key} 的扩展名: {request.extensions}")
|
||||
|
||||
return FileAppResponse.from_app(app, new_extensions, group_links)
|
||||
|
||||
|
||||
@admin_file_app_router.put(
|
||||
path='/{app_id}/groups',
|
||||
summary='全量替换允许的用户组',
|
||||
)
|
||||
async def update_group_access(
|
||||
session: SessionDep,
|
||||
app_id: UUID,
|
||||
request: GroupAccessUpdateRequest,
|
||||
) -> FileAppResponse:
|
||||
"""
|
||||
全量替换允许的用户组端点
|
||||
|
||||
先删除旧的关联,再创建新的。
|
||||
|
||||
认证:管理员权限
|
||||
|
||||
错误处理:
|
||||
- 404: 应用不存在
|
||||
"""
|
||||
app: FileApp | None = await FileApp.get(session, FileApp.id == app_id)
|
||||
if not app:
|
||||
http_exceptions.raise_not_found("应用不存在")
|
||||
|
||||
# 删除旧的用户组关联
|
||||
old_links_result = await session.exec(
|
||||
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app_id)
|
||||
)
|
||||
old_links: list[FileAppGroupLink] = list(old_links_result.all())
|
||||
for old_link in old_links:
|
||||
await session.delete(old_link)
|
||||
|
||||
# 创建新的用户组关联
|
||||
new_links: list[FileAppGroupLink] = []
|
||||
for group_id in request.group_ids:
|
||||
link = FileAppGroupLink(app_id=app_id, group_id=group_id)
|
||||
session.add(link)
|
||||
new_links.append(link)
|
||||
|
||||
await session.commit()
|
||||
|
||||
extensions = await FileAppExtension.get(
|
||||
session,
|
||||
FileAppExtension.app_id == app_id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
|
||||
l.info(f"更新文件应用 {app.app_key} 的用户组权限: {request.group_ids}")
|
||||
|
||||
return FileAppResponse.from_app(app, extensions, new_links)
|
||||
@@ -9,7 +9,7 @@ from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from sqlmodels import (
|
||||
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
|
||||
ListResponse, Object, )
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
from service.storage import DirectoryCreationError, LocalStorageService
|
||||
|
||||
admin_policy_router = APIRouter(
|
||||
|
||||
@@ -8,47 +8,92 @@
|
||||
- /file/upload - 上传相关操作
|
||||
- /file/download - 下载相关操作
|
||||
"""
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
import whatthepatch
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger as l
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
from whatthepatch.exceptions import HunkApplyException
|
||||
|
||||
from middleware.auth import auth_required, verify_download_token
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
CreateFileRequest,
|
||||
CreateUploadSessionRequest,
|
||||
FileApp,
|
||||
FileAppExtension,
|
||||
FileAppGroupLink,
|
||||
FileAppType,
|
||||
Object,
|
||||
ObjectType,
|
||||
PhysicalFile,
|
||||
Policy,
|
||||
PolicyType,
|
||||
ResponseBase,
|
||||
Setting,
|
||||
SettingsType,
|
||||
UploadChunkResponse,
|
||||
UploadSession,
|
||||
UploadSessionResponse,
|
||||
User,
|
||||
WopiSessionResponse,
|
||||
)
|
||||
from service.storage import LocalStorageService, adjust_user_storage
|
||||
from service.redis.token_store import TokenStore
|
||||
from utils.JWT import create_download_token, DOWNLOAD_TOKEN_TTL
|
||||
from utils.JWT.wopi_token import create_wopi_token
|
||||
from utils import http_exceptions
|
||||
from .viewers import viewers_router
|
||||
|
||||
|
||||
# DTO
|
||||
|
||||
class DownloadTokenModel(ResponseBase):
|
||||
"""下载Token响应模型"""
|
||||
|
||||
|
||||
access_token: str
|
||||
"""JWT 令牌"""
|
||||
|
||||
|
||||
expires_in: int
|
||||
"""过期时间(秒)"""
|
||||
|
||||
|
||||
class TextContentResponse(ResponseBase):
|
||||
"""文本文件内容响应"""
|
||||
|
||||
content: str
|
||||
"""文件文本内容(UTF-8)"""
|
||||
|
||||
hash: str
|
||||
"""SHA-256 hex"""
|
||||
|
||||
size: int
|
||||
"""文件字节大小"""
|
||||
|
||||
|
||||
class PatchContentRequest(SQLModelBase):
|
||||
"""增量保存请求"""
|
||||
|
||||
patch: str
|
||||
"""unified diff 文本"""
|
||||
|
||||
base_hash: str
|
||||
"""原始内容的 SHA-256 hex(64字符)"""
|
||||
|
||||
|
||||
class PatchContentResponse(ResponseBase):
|
||||
"""增量保存响应"""
|
||||
|
||||
new_hash: str
|
||||
"""新内容的 SHA-256 hex"""
|
||||
|
||||
new_size: int
|
||||
"""新文件字节大小"""
|
||||
|
||||
# ==================== 主路由 ====================
|
||||
|
||||
router = APIRouter(prefix="/file", tags=["file"])
|
||||
@@ -410,7 +455,7 @@ async def create_download_token_endpoint(
|
||||
@_download_router.get(
|
||||
path='/{token}',
|
||||
summary='下载文件',
|
||||
description='使用下载令牌下载文件(一次性令牌,仅可使用一次)。',
|
||||
description='使用下载令牌下载文件,令牌在有效期内可重复使用。',
|
||||
)
|
||||
async def download_file(
|
||||
session: SessionDep,
|
||||
@@ -420,19 +465,14 @@ async def download_file(
|
||||
下载文件端点
|
||||
|
||||
验证 JWT 令牌后返回文件内容。
|
||||
令牌为一次性使用,下载后即失效。
|
||||
令牌在有效期内可重复使用(支持浏览器 range 请求等场景)。
|
||||
"""
|
||||
# 验证令牌
|
||||
result = verify_download_token(token)
|
||||
if not result:
|
||||
raise HTTPException(status_code=401, detail="下载令牌无效或已过期")
|
||||
|
||||
jti, file_id, owner_id = result
|
||||
|
||||
# 检查并标记令牌已使用(原子操作)
|
||||
is_first_use = await TokenStore.mark_used(jti, DOWNLOAD_TOKEN_TTL)
|
||||
if not is_first_use:
|
||||
raise HTTPException(status_code=404)
|
||||
_, file_id, owner_id = result
|
||||
|
||||
# 获取文件对象(排除已删除的)
|
||||
file_obj = await Object.get(
|
||||
@@ -478,6 +518,7 @@ async def download_file(
|
||||
|
||||
router.include_router(_upload_router)
|
||||
router.include_router(_download_router)
|
||||
router.include_router(viewers_router)
|
||||
|
||||
|
||||
# ==================== 创建空白文件 ====================
|
||||
@@ -574,6 +615,113 @@ async def create_empty_file(
|
||||
})
|
||||
|
||||
|
||||
# ==================== WOPI 会话 ====================
|
||||
|
||||
@router.post(
|
||||
path='/{file_id}/wopi-session',
|
||||
summary='创建 WOPI 会话',
|
||||
description='为 WOPI 类型的查看器创建编辑会话,返回编辑器 URL 和访问令牌。',
|
||||
)
|
||||
async def create_wopi_session(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
file_id: UUID,
|
||||
) -> WopiSessionResponse:
|
||||
"""
|
||||
创建 WOPI 会话端点
|
||||
|
||||
流程:
|
||||
1. 验证文件存在且属于当前用户
|
||||
2. 查找文件扩展名对应的 WOPI 类型应用
|
||||
3. 检查用户组权限
|
||||
4. 生成 WOPI access token
|
||||
5. 构建 editor URL
|
||||
|
||||
认证:JWT token 必填
|
||||
|
||||
错误处理:
|
||||
- 404: 文件不存在 / 无可用 WOPI 应用
|
||||
- 403: 用户组无权限
|
||||
"""
|
||||
# 验证文件
|
||||
file_obj: Object | None = await Object.get(
|
||||
session,
|
||||
Object.id == file_id,
|
||||
)
|
||||
if not file_obj or file_obj.owner_id != user.id:
|
||||
http_exceptions.raise_not_found("文件不存在")
|
||||
|
||||
if not file_obj.is_file:
|
||||
http_exceptions.raise_bad_request("对象不是文件")
|
||||
|
||||
# 获取文件扩展名
|
||||
name_parts = file_obj.name.rsplit('.', 1)
|
||||
if len(name_parts) < 2:
|
||||
http_exceptions.raise_bad_request("文件无扩展名,无法使用 WOPI 查看器")
|
||||
ext = name_parts[1].lower()
|
||||
|
||||
# 查找 WOPI 类型的应用
|
||||
from sqlalchemy import and_, select
|
||||
ext_records: list[FileAppExtension] = await FileAppExtension.get(
|
||||
session,
|
||||
FileAppExtension.extension == ext,
|
||||
fetch_mode="all",
|
||||
load=FileAppExtension.app,
|
||||
)
|
||||
|
||||
wopi_app: FileApp | None = None
|
||||
for ext_record in ext_records:
|
||||
app = ext_record.app
|
||||
if app.type == FileAppType.WOPI and app.is_enabled:
|
||||
# 检查用户组权限(FileAppGroupLink 是纯关联表,使用 session 查询)
|
||||
if app.is_restricted:
|
||||
stmt = select(FileAppGroupLink).where(
|
||||
and_(
|
||||
FileAppGroupLink.app_id == app.id,
|
||||
FileAppGroupLink.group_id == user.group_id,
|
||||
)
|
||||
)
|
||||
result = await session.exec(stmt)
|
||||
if not result.first():
|
||||
continue
|
||||
wopi_app = app
|
||||
break
|
||||
|
||||
if not wopi_app:
|
||||
http_exceptions.raise_not_found("无可用的 WOPI 查看器")
|
||||
|
||||
if not wopi_app.wopi_editor_url_template:
|
||||
http_exceptions.raise_bad_request("WOPI 应用未配置编辑器 URL 模板")
|
||||
|
||||
# 获取站点 URL
|
||||
site_url_setting: Setting | None = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"),
|
||||
)
|
||||
site_url = site_url_setting.value if site_url_setting else "http://localhost"
|
||||
|
||||
# 生成 WOPI token
|
||||
can_write = file_obj.owner_id == user.id
|
||||
token, access_token_ttl = create_wopi_token(file_id, user.id, can_write)
|
||||
|
||||
# 构建 wopi_src
|
||||
wopi_src = f"{site_url}/wopi/files/{file_id}"
|
||||
|
||||
# 构建 editor URL
|
||||
editor_url = wopi_app.wopi_editor_url_template.format(
|
||||
wopi_src=wopi_src,
|
||||
access_token=token,
|
||||
access_token_ttl=access_token_ttl,
|
||||
)
|
||||
|
||||
return WopiSessionResponse(
|
||||
wopi_src=wopi_src,
|
||||
access_token=token,
|
||||
access_token_ttl=access_token_ttl,
|
||||
editor_url=editor_url,
|
||||
)
|
||||
|
||||
|
||||
# ==================== 文件外链(保留原有端点结构) ====================
|
||||
|
||||
@router.get(
|
||||
@@ -612,36 +760,171 @@ async def file_update(id: str) -> ResponseBase:
|
||||
|
||||
|
||||
@router.get(
|
||||
path='/preview/{id}',
|
||||
summary='预览文件',
|
||||
description='获取文件预览。',
|
||||
dependencies=[Depends(auth_required)]
|
||||
)
|
||||
async def file_preview(id: str) -> ResponseBase:
|
||||
"""预览文件"""
|
||||
raise HTTPException(status_code=501, detail="预览功能暂未实现")
|
||||
|
||||
|
||||
@router.get(
|
||||
path='/content/{id}',
|
||||
path='/content/{file_id}',
|
||||
summary='获取文本文件内容',
|
||||
description='获取文本文件内容。',
|
||||
dependencies=[Depends(auth_required)]
|
||||
description='获取文本文件的 UTF-8 内容和 SHA-256 哈希值。',
|
||||
)
|
||||
async def file_content(id: str) -> ResponseBase:
|
||||
"""获取文本文件内容"""
|
||||
raise HTTPException(status_code=501, detail="文本内容功能暂未实现")
|
||||
async def file_content(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
file_id: UUID,
|
||||
) -> TextContentResponse:
|
||||
"""
|
||||
获取文本文件内容端点
|
||||
|
||||
返回文件的 UTF-8 文本内容和基于规范化内容的 SHA-256 哈希值。
|
||||
换行符统一规范化为 ``\\n``。
|
||||
|
||||
认证:JWT token 必填
|
||||
|
||||
错误处理:
|
||||
- 400: 文件不是有效的 UTF-8 文本
|
||||
- 404: 文件不存在
|
||||
"""
|
||||
file_obj = await Object.get(
|
||||
session,
|
||||
(Object.id == file_id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not file_obj or file_obj.owner_id != user.id:
|
||||
http_exceptions.raise_not_found("文件不存在")
|
||||
|
||||
if not file_obj.is_file:
|
||||
http_exceptions.raise_bad_request("对象不是文件")
|
||||
|
||||
physical_file = await file_obj.awaitable_attrs.physical_file
|
||||
if not physical_file or not physical_file.storage_path:
|
||||
http_exceptions.raise_internal_error("文件存储路径丢失")
|
||||
|
||||
policy = await Policy.get(session, Policy.id == file_obj.policy_id)
|
||||
if not policy:
|
||||
http_exceptions.raise_internal_error("存储策略不存在")
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(physical_file.storage_path)
|
||||
|
||||
try:
|
||||
content = raw_bytes.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
http_exceptions.raise_bad_request("文件不是有效的 UTF-8 文本")
|
||||
|
||||
# 换行符规范化
|
||||
content = content.replace('\r\n', '\n').replace('\r', '\n')
|
||||
normalized_bytes = content.encode('utf-8')
|
||||
hash_hex = hashlib.sha256(normalized_bytes).hexdigest()
|
||||
|
||||
return TextContentResponse(
|
||||
content=content,
|
||||
hash=hash_hex,
|
||||
size=len(normalized_bytes),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
path='/doc/{id}',
|
||||
summary='获取Office文档预览地址',
|
||||
description='获取Office文档在线预览地址。',
|
||||
dependencies=[Depends(auth_required)]
|
||||
@router.patch(
|
||||
path='/content/{file_id}',
|
||||
summary='增量保存文本文件',
|
||||
description='使用 unified diff 增量更新文本文件内容。',
|
||||
)
|
||||
async def file_doc(id: str) -> ResponseBase:
|
||||
"""获取Office文档预览地址"""
|
||||
raise HTTPException(status_code=501, detail="Office预览功能暂未实现")
|
||||
async def patch_file_content(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
file_id: UUID,
|
||||
request: PatchContentRequest,
|
||||
) -> PatchContentResponse:
|
||||
"""
|
||||
增量保存文本文件端点
|
||||
|
||||
接收 unified diff 和 base_hash,验证无并发冲突后应用 patch。
|
||||
|
||||
认证:JWT token 必填
|
||||
|
||||
错误处理:
|
||||
- 400: 文件不是有效的 UTF-8 文本
|
||||
- 404: 文件不存在
|
||||
- 409: base_hash 不匹配(并发冲突)
|
||||
- 422: 无效的 patch 格式或 patch 应用失败
|
||||
"""
|
||||
file_obj = await Object.get(
|
||||
session,
|
||||
(Object.id == file_id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not file_obj or file_obj.owner_id != user.id:
|
||||
http_exceptions.raise_not_found("文件不存在")
|
||||
|
||||
if not file_obj.is_file:
|
||||
http_exceptions.raise_bad_request("对象不是文件")
|
||||
|
||||
if file_obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
physical_file = await file_obj.awaitable_attrs.physical_file
|
||||
if not physical_file or not physical_file.storage_path:
|
||||
http_exceptions.raise_internal_error("文件存储路径丢失")
|
||||
|
||||
storage_path = physical_file.storage_path
|
||||
|
||||
policy = await Policy.get(session, Policy.id == file_obj.policy_id)
|
||||
if not policy:
|
||||
http_exceptions.raise_internal_error("存储策略不存在")
|
||||
|
||||
if policy.type != PolicyType.LOCAL:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
storage_service = LocalStorageService(policy)
|
||||
raw_bytes = await storage_service.read_file(storage_path)
|
||||
|
||||
# 解码 + 规范化
|
||||
original_text = raw_bytes.decode('utf-8')
|
||||
original_text = original_text.replace('\r\n', '\n').replace('\r', '\n')
|
||||
normalized_bytes = original_text.encode('utf-8')
|
||||
|
||||
# 冲突检测(hash 基于规范化后的内容,与 GET 端点一致)
|
||||
current_hash = hashlib.sha256(normalized_bytes).hexdigest()
|
||||
if current_hash != request.base_hash:
|
||||
http_exceptions.raise_conflict("文件内容已被修改,请刷新后重试")
|
||||
|
||||
# 解析并应用 patch
|
||||
diffs = list(whatthepatch.parse_patch(request.patch))
|
||||
if not diffs:
|
||||
http_exceptions.raise_unprocessable_entity("无效的 patch 格式")
|
||||
|
||||
try:
|
||||
result = whatthepatch.apply_diff(diffs[0], original_text)
|
||||
except HunkApplyException:
|
||||
http_exceptions.raise_unprocessable_entity("Patch 应用失败,差异内容与当前文件不匹配")
|
||||
|
||||
new_text = '\n'.join(result)
|
||||
|
||||
# 保持尾部换行符一致
|
||||
if original_text.endswith('\n') and not new_text.endswith('\n'):
|
||||
new_text += '\n'
|
||||
|
||||
new_bytes = new_text.encode('utf-8')
|
||||
|
||||
# 写入文件
|
||||
await storage_service.write_file(storage_path, new_bytes)
|
||||
|
||||
# 更新数据库
|
||||
owner_id = file_obj.owner_id
|
||||
old_size = file_obj.size
|
||||
new_size = len(new_bytes)
|
||||
size_diff = new_size - old_size
|
||||
|
||||
file_obj.size = new_size
|
||||
file_obj = await file_obj.save(session, commit=False)
|
||||
physical_file.size = new_size
|
||||
physical_file = await physical_file.save(session, commit=False)
|
||||
if size_diff != 0:
|
||||
await adjust_user_storage(session, owner_id, size_diff, commit=False)
|
||||
await session.commit()
|
||||
|
||||
new_hash = hashlib.sha256(new_bytes).hexdigest()
|
||||
|
||||
l.info(f"文本文件增量保存: file_id={file_id}, size={old_size}->{new_size}")
|
||||
|
||||
return PatchContentResponse(new_hash=new_hash, new_size=new_size)
|
||||
|
||||
|
||||
@router.get(
|
||||
|
||||
106
routers/api/v1/file/viewers/__init__.py
Normal file
106
routers/api/v1/file/viewers/__init__.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
文件查看器查询端点
|
||||
|
||||
提供按文件扩展名查询可用查看器的功能,包含用户组访问控制过滤。
|
||||
"""
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import and_
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
FileApp,
|
||||
FileAppExtension,
|
||||
FileAppGroupLink,
|
||||
FileAppSummary,
|
||||
FileViewersResponse,
|
||||
User,
|
||||
UserFileAppDefault,
|
||||
)
|
||||
|
||||
viewers_router = APIRouter(prefix="/viewers", tags=["file", "viewers"])
|
||||
|
||||
|
||||
@viewers_router.get(
|
||||
path='',
|
||||
summary='查询可用文件查看器',
|
||||
description='根据文件扩展名查询可用的查看器应用列表。',
|
||||
)
|
||||
async def get_viewers(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
ext: Annotated[str, Query(max_length=20, description="文件扩展名")],
|
||||
) -> FileViewersResponse:
|
||||
"""
|
||||
查询可用文件查看器端点
|
||||
|
||||
流程:
|
||||
1. 规范化扩展名(小写,去点号)
|
||||
2. 查询匹配的已启用应用
|
||||
3. 按用户组权限过滤
|
||||
4. 按 priority 排序
|
||||
5. 查询用户默认偏好
|
||||
|
||||
认证:JWT token 必填
|
||||
|
||||
错误处理:
|
||||
- 401: 未授权
|
||||
"""
|
||||
# 规范化扩展名
|
||||
normalized_ext = ext.lower().strip().lstrip('.')
|
||||
|
||||
# 查询匹配扩展名的应用(已启用的)
|
||||
ext_records: list[FileAppExtension] = await FileAppExtension.get(
|
||||
session,
|
||||
and_(
|
||||
FileAppExtension.extension == normalized_ext,
|
||||
),
|
||||
fetch_mode="all",
|
||||
load=FileAppExtension.app,
|
||||
)
|
||||
|
||||
# 过滤和收集可用应用
|
||||
user_group_id = user.group_id
|
||||
viewers: list[tuple[FileAppSummary, int]] = []
|
||||
|
||||
for ext_record in ext_records:
|
||||
app: FileApp = ext_record.app
|
||||
if not app.is_enabled:
|
||||
continue
|
||||
|
||||
if app.is_restricted:
|
||||
# 检查用户组权限(FileAppGroupLink 是纯关联表,使用 session 查询)
|
||||
stmt = select(FileAppGroupLink).where(
|
||||
and_(
|
||||
FileAppGroupLink.app_id == app.id,
|
||||
FileAppGroupLink.group_id == user_group_id,
|
||||
)
|
||||
)
|
||||
result = await session.exec(stmt)
|
||||
group_link = result.first()
|
||||
if not group_link:
|
||||
continue
|
||||
|
||||
viewers.append((app.to_summary(), ext_record.priority))
|
||||
|
||||
# 按 priority 排序
|
||||
viewers.sort(key=lambda x: x[1])
|
||||
|
||||
# 查询用户默认偏好
|
||||
user_default: UserFileAppDefault | None = await UserFileAppDefault.get(
|
||||
session,
|
||||
and_(
|
||||
UserFileAppDefault.user_id == user.id,
|
||||
UserFileAppDefault.extension == normalized_ext,
|
||||
),
|
||||
)
|
||||
|
||||
return FileViewersResponse(
|
||||
viewers=[v[0] for v in viewers],
|
||||
default_viewer_id=user_default.app_id if user_default else None,
|
||||
)
|
||||
@@ -14,7 +14,7 @@ from sqlmodels.share import (
|
||||
ShareDetailResponse, ShareOwnerInfo, ShareObjectItem,
|
||||
)
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.mixin import ListResponse, TableViewRequest
|
||||
from sqlmodel_ext import ListResponse, TableViewRequest
|
||||
from utils import http_exceptions
|
||||
from utils.password.pwd import Password, PasswordStatus
|
||||
|
||||
|
||||
@@ -17,12 +17,14 @@ from sqlmodels.color import ThemeColorsBase
|
||||
from sqlmodels.user_authn import UserAuthn
|
||||
from utils import JWT, Password, http_exceptions
|
||||
from utils.password.pwd import PasswordStatus, TwoFactorResponse, TwoFactorVerifyRequest
|
||||
from .file_viewers import file_viewers_router
|
||||
|
||||
user_settings_router = APIRouter(
|
||||
prefix='/settings',
|
||||
tags=["user", "user_settings"],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
user_settings_router.include_router(file_viewers_router)
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
|
||||
150
routers/api/v1/user/settings/file_viewers/__init__.py
Normal file
150
routers/api/v1/user/settings/file_viewers/__init__.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
用户文件查看器偏好设置端点
|
||||
|
||||
提供用户"始终使用"默认查看器的增删查功能。
|
||||
"""
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy import and_
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
FileApp,
|
||||
FileAppExtension,
|
||||
SetDefaultViewerRequest,
|
||||
User,
|
||||
UserFileAppDefault,
|
||||
UserFileAppDefaultResponse,
|
||||
)
|
||||
from utils import http_exceptions
|
||||
|
||||
file_viewers_router = APIRouter(
|
||||
prefix='/file-viewers',
|
||||
tags=["user", "user_settings", "file_viewers"],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
|
||||
|
||||
@file_viewers_router.put(
|
||||
path='/default',
|
||||
summary='设置默认查看器',
|
||||
description='为指定扩展名设置"始终使用"的查看器。',
|
||||
)
|
||||
async def set_default_viewer(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: SetDefaultViewerRequest,
|
||||
) -> UserFileAppDefaultResponse:
|
||||
"""
|
||||
设置默认查看器端点
|
||||
|
||||
如果用户已有该扩展名的默认设置,则更新;否则创建新记录。
|
||||
|
||||
认证:JWT token 必填
|
||||
|
||||
错误处理:
|
||||
- 404: 应用不存在
|
||||
- 400: 应用不支持该扩展名
|
||||
"""
|
||||
# 规范化扩展名
|
||||
normalized_ext = request.extension.lower().strip().lstrip('.')
|
||||
|
||||
# 验证应用存在
|
||||
app: FileApp | None = await FileApp.get(session, FileApp.id == request.app_id)
|
||||
if not app:
|
||||
http_exceptions.raise_not_found("应用不存在")
|
||||
|
||||
# 验证应用支持该扩展名
|
||||
ext_record: FileAppExtension | None = await FileAppExtension.get(
|
||||
session,
|
||||
and_(
|
||||
FileAppExtension.app_id == app.id,
|
||||
FileAppExtension.extension == normalized_ext,
|
||||
),
|
||||
)
|
||||
if not ext_record:
|
||||
http_exceptions.raise_bad_request("该应用不支持此扩展名")
|
||||
|
||||
# 查找已有记录
|
||||
existing: UserFileAppDefault | None = await UserFileAppDefault.get(
|
||||
session,
|
||||
and_(
|
||||
UserFileAppDefault.user_id == user.id,
|
||||
UserFileAppDefault.extension == normalized_ext,
|
||||
),
|
||||
)
|
||||
|
||||
if existing:
|
||||
existing.app_id = request.app_id
|
||||
existing = await existing.save(session)
|
||||
# 重新加载 app 关系
|
||||
await session.refresh(existing, attribute_names=["app"])
|
||||
return existing.to_response()
|
||||
else:
|
||||
new_default = UserFileAppDefault(
|
||||
user_id=user.id,
|
||||
extension=normalized_ext,
|
||||
app_id=request.app_id,
|
||||
)
|
||||
new_default = await new_default.save(session)
|
||||
# 重新加载 app 关系
|
||||
await session.refresh(new_default, attribute_names=["app"])
|
||||
return new_default.to_response()
|
||||
|
||||
|
||||
@file_viewers_router.get(
|
||||
path='/defaults',
|
||||
summary='列出所有默认查看器设置',
|
||||
description='获取当前用户所有"始终使用"的查看器偏好。',
|
||||
)
|
||||
async def list_default_viewers(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> list[UserFileAppDefaultResponse]:
|
||||
"""
|
||||
列出所有默认查看器设置端点
|
||||
|
||||
认证:JWT token 必填
|
||||
"""
|
||||
defaults: list[UserFileAppDefault] = await UserFileAppDefault.get(
|
||||
session,
|
||||
UserFileAppDefault.user_id == user.id,
|
||||
fetch_mode="all",
|
||||
load=UserFileAppDefault.app,
|
||||
)
|
||||
return [d.to_response() for d in defaults]
|
||||
|
||||
|
||||
@file_viewers_router.delete(
|
||||
path='/default/{default_id}',
|
||||
summary='撤销默认查看器设置',
|
||||
description='删除指定的"始终使用"偏好。',
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def delete_default_viewer(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
default_id: UUID,
|
||||
) -> None:
|
||||
"""
|
||||
撤销默认查看器设置端点
|
||||
|
||||
认证:JWT token 必填
|
||||
|
||||
错误处理:
|
||||
- 404: 记录不存在或不属于当前用户
|
||||
"""
|
||||
existing: UserFileAppDefault | None = await UserFileAppDefault.get(
|
||||
session,
|
||||
and_(
|
||||
UserFileAppDefault.id == default_id,
|
||||
UserFileAppDefault.user_id == user.id,
|
||||
),
|
||||
)
|
||||
if not existing:
|
||||
http_exceptions.raise_not_found("默认设置不存在")
|
||||
|
||||
await UserFileAppDefault.delete(session, existing)
|
||||
11
routers/wopi/__init__.py
Normal file
11
routers/wopi/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
WOPI(Web Application Open Platform Interface)路由
|
||||
|
||||
挂载在根级别 /wopi(非 /api/v1 下),因为 WOPI 客户端要求标准路径。
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .files import wopi_files_router
|
||||
|
||||
wopi_router = APIRouter(prefix="/wopi", tags=["wopi"])
|
||||
wopi_router.include_router(wopi_files_router)
|
||||
203
routers/wopi/files/__init__.py
Normal file
203
routers/wopi/files/__init__.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
WOPI 文件操作端点
|
||||
|
||||
实现 WOPI 协议的核心文件操作接口:
|
||||
- CheckFileInfo: 获取文件元数据
|
||||
- GetFile: 下载文件内容
|
||||
- PutFile: 上传/更新文件内容
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Query, Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger as l
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import Object, PhysicalFile, Policy, PolicyType, User, WopiFileInfo
|
||||
from service.storage import LocalStorageService
|
||||
from utils import http_exceptions
|
||||
from utils.JWT.wopi_token import verify_wopi_token
|
||||
|
||||
wopi_files_router = APIRouter(prefix="/files", tags=["wopi"])
|
||||
|
||||
|
||||
@wopi_files_router.get(
|
||||
path='/{file_id}',
|
||||
summary='WOPI CheckFileInfo',
|
||||
description='返回文件的元数据信息。',
|
||||
)
|
||||
async def check_file_info(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
access_token: str = Query(...),
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
WOPI CheckFileInfo 端点
|
||||
|
||||
认证:WOPI access_token(query 参数)
|
||||
|
||||
返回 WOPI 规范的 PascalCase JSON。
|
||||
"""
|
||||
# 验证令牌
|
||||
payload = verify_wopi_token(access_token)
|
||||
if not payload or payload.file_id != file_id:
|
||||
http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")
|
||||
|
||||
# 获取文件
|
||||
file_obj: Object | None = await Object.get(
|
||||
session,
|
||||
Object.id == file_id,
|
||||
)
|
||||
if not file_obj or not file_obj.is_file:
|
||||
http_exceptions.raise_not_found("文件不存在")
|
||||
|
||||
# 获取用户信息
|
||||
user: User | None = await User.get(session, User.id == payload.user_id)
|
||||
user_name = user.nickname or user.email or str(payload.user_id) if user else str(payload.user_id)
|
||||
|
||||
# 构建响应
|
||||
info = WopiFileInfo(
|
||||
base_file_name=file_obj.name,
|
||||
size=file_obj.size or 0,
|
||||
owner_id=str(file_obj.owner_id),
|
||||
user_id=str(payload.user_id),
|
||||
user_friendly_name=user_name,
|
||||
version=file_obj.updated_at.isoformat() if file_obj.updated_at else "",
|
||||
user_can_write=payload.can_write,
|
||||
read_only=not payload.can_write,
|
||||
supports_update=payload.can_write,
|
||||
)
|
||||
|
||||
return JSONResponse(content=info.to_wopi_dict())
|
||||
|
||||
|
||||
@wopi_files_router.get(
|
||||
path='/{file_id}/contents',
|
||||
summary='WOPI GetFile',
|
||||
description='返回文件的二进制内容。',
|
||||
)
|
||||
async def get_file(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
access_token: str = Query(...),
|
||||
) -> Response:
|
||||
"""
|
||||
WOPI GetFile 端点
|
||||
|
||||
认证:WOPI access_token(query 参数)
|
||||
|
||||
返回文件的原始二进制内容。
|
||||
"""
|
||||
# 验证令牌
|
||||
payload = verify_wopi_token(access_token)
|
||||
if not payload or payload.file_id != file_id:
|
||||
http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")
|
||||
|
||||
# 获取文件
|
||||
file_obj: Object | None = await Object.get(session, Object.id == file_id)
|
||||
if not file_obj or not file_obj.is_file:
|
||||
http_exceptions.raise_not_found("文件不存在")
|
||||
|
||||
# 获取物理文件
|
||||
physical_file: PhysicalFile | None = await file_obj.awaitable_attrs.physical_file
|
||||
if not physical_file or not physical_file.storage_path:
|
||||
http_exceptions.raise_internal_error("文件存储路径丢失")
|
||||
|
||||
# 获取策略
|
||||
policy: Policy | None = await Policy.get(session, Policy.id == file_obj.policy_id)
|
||||
if not policy:
|
||||
http_exceptions.raise_internal_error("存储策略不存在")
|
||||
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
if not await storage_service.file_exists(physical_file.storage_path):
|
||||
http_exceptions.raise_not_found("物理文件不存在")
|
||||
|
||||
import aiofiles
|
||||
async with aiofiles.open(physical_file.storage_path, 'rb') as f:
|
||||
content = await f.read()
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
media_type="application/octet-stream",
|
||||
headers={"X-WOPI-ItemVersion": file_obj.updated_at.isoformat() if file_obj.updated_at else ""},
|
||||
)
|
||||
else:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
|
||||
@wopi_files_router.post(
|
||||
path='/{file_id}/contents',
|
||||
summary='WOPI PutFile',
|
||||
description='更新文件内容。',
|
||||
)
|
||||
async def put_file(
|
||||
session: SessionDep,
|
||||
request: Request,
|
||||
file_id: UUID,
|
||||
access_token: str = Query(...),
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
WOPI PutFile 端点
|
||||
|
||||
认证:WOPI access_token(query 参数,需要写权限)
|
||||
|
||||
接收请求体中的文件二进制内容并覆盖存储。
|
||||
"""
|
||||
# 验证令牌
|
||||
payload = verify_wopi_token(access_token)
|
||||
if not payload or payload.file_id != file_id:
|
||||
http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")
|
||||
|
||||
if not payload.can_write:
|
||||
http_exceptions.raise_forbidden("没有写入权限")
|
||||
|
||||
# 获取文件
|
||||
file_obj: Object | None = await Object.get(session, Object.id == file_id)
|
||||
if not file_obj or not file_obj.is_file:
|
||||
http_exceptions.raise_not_found("文件不存在")
|
||||
|
||||
# 获取物理文件
|
||||
physical_file: PhysicalFile | None = await file_obj.awaitable_attrs.physical_file
|
||||
if not physical_file or not physical_file.storage_path:
|
||||
http_exceptions.raise_internal_error("文件存储路径丢失")
|
||||
|
||||
# 获取策略
|
||||
policy: Policy | None = await Policy.get(session, Policy.id == file_obj.policy_id)
|
||||
if not policy:
|
||||
http_exceptions.raise_internal_error("存储策略不存在")
|
||||
|
||||
# 读取请求体
|
||||
content = await request.body()
|
||||
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
import aiofiles
|
||||
async with aiofiles.open(physical_file.storage_path, 'wb') as f:
|
||||
await f.write(content)
|
||||
|
||||
# 更新文件大小
|
||||
new_size = len(content)
|
||||
old_size = file_obj.size or 0
|
||||
file_obj.size = new_size
|
||||
file_obj = await file_obj.save(session, commit=False)
|
||||
|
||||
# 更新物理文件大小
|
||||
physical_file.size = new_size
|
||||
await physical_file.save(session, commit=False)
|
||||
|
||||
# 更新用户存储配额
|
||||
size_diff = new_size - old_size
|
||||
if size_diff != 0:
|
||||
from service.storage import adjust_user_storage
|
||||
await adjust_user_storage(session, file_obj.owner_id, size_diff, commit=False)
|
||||
|
||||
await session.commit()
|
||||
|
||||
l.info(f"WOPI PutFile: file_id={file_id}, new_size={new_size}")
|
||||
|
||||
return JSONResponse(
|
||||
content={"ItemVersion": file_obj.updated_at.isoformat() if file_obj.updated_at else ""},
|
||||
status_code=200,
|
||||
)
|
||||
else:
|
||||
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
@@ -23,7 +23,7 @@ import string
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
|
||||
|
||||
class NamingContext(SQLModelBase):
|
||||
|
||||
91
setup_cython.py
Normal file
91
setup_cython.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Cython 编译脚本 — 将 ee/ 下的纯逻辑文件编译为 .so
|
||||
|
||||
用法:
|
||||
uv run --extra build python setup_cython.py build_ext --inplace
|
||||
|
||||
编译规则:
|
||||
- 跳过 __init__.py(Python 包发现需要)
|
||||
- 只编译 .py 文件(纯函数 / 服务逻辑)
|
||||
|
||||
编译后清理(Pro Docker 构建用):
|
||||
uv run --extra build python setup_cython.py clean_artifacts
|
||||
"""
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
EE_DIR = Path("ee")
|
||||
|
||||
# 跳过 __init__.py —— 包发现需要原始 .py
|
||||
SKIP_NAMES = {"__init__.py"}
|
||||
|
||||
|
||||
def _collect_modules() -> list[str]:
|
||||
"""收集 ee/ 下需要编译的 .py 文件路径(点分模块名)。"""
|
||||
modules: list[str] = []
|
||||
for py_file in EE_DIR.rglob("*.py"):
|
||||
if py_file.name in SKIP_NAMES:
|
||||
continue
|
||||
# ee/license.py → ee.license
|
||||
module = str(py_file.with_suffix("")).replace("\\", "/").replace("/", ".")
|
||||
modules.append(module)
|
||||
return modules
|
||||
|
||||
|
||||
def clean_artifacts() -> None:
|
||||
"""删除已编译的 .py 源码、.c 中间文件和 build/ 目录。"""
|
||||
for py_file in EE_DIR.rglob("*.py"):
|
||||
if py_file.name in SKIP_NAMES:
|
||||
continue
|
||||
# 只删除有对应 .so / .pyd 的源文件
|
||||
parent = py_file.parent
|
||||
stem = py_file.stem
|
||||
has_compiled = (
|
||||
any(parent.glob(f"{stem}*.so")) or
|
||||
any(parent.glob(f"{stem}*.pyd"))
|
||||
)
|
||||
if has_compiled:
|
||||
py_file.unlink()
|
||||
print(f"已删除源码: {py_file}")
|
||||
|
||||
# 删除 .c 中间文件
|
||||
for c_file in EE_DIR.rglob("*.c"):
|
||||
c_file.unlink()
|
||||
print(f"已删除中间文件: {c_file}")
|
||||
|
||||
# 删除 build/ 目录
|
||||
build_dir = Path("build")
|
||||
if build_dir.exists():
|
||||
shutil.rmtree(build_dir)
|
||||
print(f"已删除: {build_dir}/")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "clean_artifacts":
|
||||
clean_artifacts()
|
||||
sys.exit(0)
|
||||
|
||||
# 动态导入(仅在编译时需要)
|
||||
from Cython.Build import cythonize
|
||||
from setuptools import Extension, setup
|
||||
|
||||
modules = _collect_modules()
|
||||
if not modules:
|
||||
print("未找到需要编译的模块")
|
||||
sys.exit(0)
|
||||
|
||||
print(f"即将编译以下模块: {modules}")
|
||||
|
||||
extensions = [
|
||||
Extension(mod, [mod.replace(".", "/") + ".py"])
|
||||
for mod in modules
|
||||
]
|
||||
|
||||
setup(
|
||||
name="disknext-ee",
|
||||
ext_modules=cythonize(
|
||||
extensions,
|
||||
compiler_directives={'language_level': "3"},
|
||||
),
|
||||
)
|
||||
@@ -115,6 +115,14 @@ from .storage_pack import StoragePack
|
||||
from .tag import Tag, TagType
|
||||
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
|
||||
from .webdav import WebDAV
|
||||
from .file_app import (
|
||||
FileApp, FileAppType, FileAppExtension, FileAppGroupLink, UserFileAppDefault,
|
||||
# DTO
|
||||
FileAppSummary, FileViewersResponse, SetDefaultViewerRequest, UserFileAppDefaultResponse,
|
||||
FileAppCreateRequest, FileAppUpdateRequest, FileAppResponse, FileAppListResponse,
|
||||
ExtensionUpdateRequest, GroupAccessUpdateRequest, WopiSessionResponse,
|
||||
)
|
||||
from .wopi import WopiFileInfo, WopiAccessTokenPayload
|
||||
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
@@ -131,5 +139,5 @@ from .model_base import (
|
||||
AdminSummaryResponse,
|
||||
)
|
||||
|
||||
# mixin 中的通用分页模型
|
||||
from .mixin import ListResponse
|
||||
# 通用分页模型
|
||||
from sqlmodel_ext import ListResponse
|
||||
|
||||
@@ -10,8 +10,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
@@ -1,657 +0,0 @@
|
||||
# SQLModels Base Module
|
||||
|
||||
This module provides `SQLModelBase`, the root base class for all SQLModel models in this project. It includes a custom metaclass with automatic type injection and Python 3.14 compatibility.
|
||||
|
||||
**Note**: Table base classes (`TableBaseMixin`, `UUIDTableBaseMixin`) and polymorphic utilities have been migrated to the [`sqlmodels.mixin`](../mixin/README.md) module. See the mixin documentation for CRUD operations, polymorphic inheritance patterns, and pagination utilities.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Migration Notice](#migration-notice)
|
||||
- [Python 3.14 Compatibility](#python-314-compatibility)
|
||||
- [Core Component](#core-component)
|
||||
- [SQLModelBase](#sqlmodelbase)
|
||||
- [Metaclass Features](#metaclass-features)
|
||||
- [Automatic sa_type Injection](#automatic-sa_type-injection)
|
||||
- [Table Configuration](#table-configuration)
|
||||
- [Polymorphic Support](#polymorphic-support)
|
||||
- [Custom Types Integration](#custom-types-integration)
|
||||
- [Best Practices](#best-practices)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
## Overview
|
||||
|
||||
The `sqlmodels.base` module provides `SQLModelBase`, the foundational base class for all SQLModel models. It features:
|
||||
|
||||
- **Smart metaclass** that automatically extracts and injects SQLAlchemy types from type annotations
|
||||
- **Python 3.14 compatibility** through comprehensive PEP 649/749 support
|
||||
- **Flexible configuration** through class parameters and automatic docstring support
|
||||
- **Type-safe annotations** with automatic validation
|
||||
|
||||
All models in this project should directly or indirectly inherit from `SQLModelBase`.
|
||||
|
||||
---
|
||||
|
||||
## Migration Notice
|
||||
|
||||
As of the recent refactoring, the following components have been moved:
|
||||
|
||||
| Component | Old Location | New Location |
|
||||
|-----------|-------------|--------------|
|
||||
| `TableBase` → `TableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `UUIDTableBase` → `UUIDTableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `PolymorphicBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `create_subclass_id_mixin()` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `AutoPolymorphicIdentityMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `TableViewRequest` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `now()`, `now_date()` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
|
||||
**Update your imports**:
|
||||
|
||||
```python
|
||||
# ❌ Old (deprecated)
|
||||
from sqlmodels.base import TableBase, UUIDTableBase
|
||||
|
||||
# ✅ New (correct)
|
||||
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
|
||||
```
|
||||
|
||||
For detailed documentation on table mixins, CRUD operations, and polymorphic patterns, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
||||
|
||||
---
|
||||
|
||||
## Python 3.14 Compatibility
|
||||
|
||||
### Overview
|
||||
|
||||
This module provides full compatibility with **Python 3.14's PEP 649** (Deferred Evaluation of Annotations) and **PEP 749** (making it the default).
|
||||
|
||||
**Key Changes in Python 3.14**:
|
||||
- Annotations are no longer evaluated at class definition time
|
||||
- Type hints are stored as deferred code objects
|
||||
- `__annotate__` function generates annotations on demand
|
||||
- Forward references become `ForwardRef` objects
|
||||
|
||||
### Implementation Strategy
|
||||
|
||||
We use **`typing.get_type_hints()`** as the universal annotations resolver:
|
||||
|
||||
```python
|
||||
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[...]:
|
||||
# Create temporary proxy class
|
||||
temp_cls = type('AnnotationProxy', (object,), dict(attrs))
|
||||
|
||||
# Use get_type_hints with include_extras=True
|
||||
evaluated = get_type_hints(
|
||||
temp_cls,
|
||||
globalns=module_globals,
|
||||
localns=localns,
|
||||
include_extras=True # Preserve Annotated metadata
|
||||
)
|
||||
|
||||
return dict(evaluated), {}, module_globals, localns
|
||||
```
|
||||
|
||||
**Why `get_type_hints()`?**
|
||||
- ✅ Works across Python 3.10-3.14+
|
||||
- ✅ Handles PEP 649 automatically
|
||||
- ✅ Preserves `Annotated` metadata (with `include_extras=True`)
|
||||
- ✅ Resolves forward references
|
||||
- ✅ Recommended by Python documentation
|
||||
|
||||
### SQLModel Compatibility Patch
|
||||
|
||||
**Problem**: SQLModel's `get_sqlalchemy_type()` doesn't recognize custom types with `__sqlmodel_sa_type__` attribute.
|
||||
|
||||
**Solution**: Global monkey-patch that checks for SQLAlchemy type before falling back to original logic:
|
||||
|
||||
```python
|
||||
if sys.version_info >= (3, 14):
|
||||
def _patched_get_sqlalchemy_type(field):
|
||||
annotation = getattr(field, 'annotation', None)
|
||||
if annotation is not None:
|
||||
# Priority 1: Check __sqlmodel_sa_type__ attribute
|
||||
# Handles NumpyVector[dims, dtype] and similar custom types
|
||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
||||
return annotation.__sqlmodel_sa_type__
|
||||
|
||||
# Priority 2: Check Annotated metadata
|
||||
if get_origin(annotation) is Annotated:
|
||||
for metadata in get_args(annotation)[1:]:
|
||||
if hasattr(metadata, '__sqlmodel_sa_type__'):
|
||||
return metadata.__sqlmodel_sa_type__
|
||||
|
||||
# ... handle ForwardRef, ClassVar, etc.
|
||||
|
||||
return _original_get_sqlalchemy_type(field)
|
||||
```
|
||||
|
||||
### Supported Patterns
|
||||
|
||||
#### Pattern 1: Direct Custom Type Usage
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
"""Voice embedding - sa_type automatically extracted"""
|
||||
```
|
||||
|
||||
#### Pattern 2: Annotated Wrapper
|
||||
```python
|
||||
from typing import Annotated
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: EmbeddingVector
|
||||
```
|
||||
|
||||
#### Pattern 3: Array Type
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql import Array
|
||||
from sqlmodels.mixin import TableBaseMixin
|
||||
|
||||
class ServerConfig(TableBaseMixin, table=True):
|
||||
protocols: Array[ProtocolEnum]
|
||||
"""Allowed protocols - sa_type from Array handler"""
|
||||
```
|
||||
|
||||
### Migration from Python 3.13
|
||||
|
||||
**No code changes required!** The implementation is transparent:
|
||||
|
||||
- Uses `typing.get_type_hints()` which works in both Python 3.13 and 3.14
|
||||
- Custom types already use `__sqlmodel_sa_type__` attribute
|
||||
- Monkey-patch only activates for Python 3.14+
|
||||
|
||||
---
|
||||
|
||||
## Core Component
|
||||
|
||||
### SQLModelBase
|
||||
|
||||
`SQLModelBase` is the root base class for all SQLModel models. It uses a custom metaclass (`__DeclarativeMeta`) that provides advanced features beyond standard SQLModel capabilities.
|
||||
|
||||
**Key Features**:
|
||||
- Automatic `use_attribute_docstrings` configuration (use docstrings instead of `Field(description=...)`)
|
||||
- Automatic `validate_by_name` configuration
|
||||
- Custom metaclass for sa_type injection and polymorphic setup
|
||||
- Integration with Pydantic v2
|
||||
- Python 3.14 PEP 649 compatibility
|
||||
|
||||
**Usage**:
|
||||
|
||||
```python
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
class UserBase(SQLModelBase):
|
||||
name: str
|
||||
"""User's display name"""
|
||||
|
||||
email: str
|
||||
"""User's email address"""
|
||||
```
|
||||
|
||||
**Important Notes**:
|
||||
- Use **docstrings** for field descriptions, not `Field(description=...)`
|
||||
- Do NOT override `model_config` in subclasses (it's already configured in SQLModelBase)
|
||||
- This class should be used for non-table models (DTOs, request/response models)
|
||||
|
||||
**For table models**, use mixins from `sqlmodels.mixin`:
|
||||
- `TableBaseMixin` - Integer primary key with timestamps
|
||||
- `UUIDTableBaseMixin` - UUID primary key with timestamps
|
||||
|
||||
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete table mixin documentation.
|
||||
|
||||
---
|
||||
|
||||
## Metaclass Features
|
||||
|
||||
### Automatic sa_type Injection
|
||||
|
||||
The metaclass automatically extracts SQLAlchemy types from custom type annotations, enabling clean syntax for complex database types.
|
||||
|
||||
**Before** (verbose):
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql.numpy_vector import _NumpyVectorSQLAlchemyType
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: np.ndarray = Field(
|
||||
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
|
||||
)
|
||||
```
|
||||
|
||||
**After** (clean):
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
"""Speaker voice embedding"""
|
||||
```
|
||||
|
||||
**How It Works**:
|
||||
|
||||
The metaclass uses a three-tier detection strategy:
|
||||
|
||||
1. **Direct `__sqlmodel_sa_type__` attribute** (Priority 1)
|
||||
```python
|
||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
||||
return annotation.__sqlmodel_sa_type__
|
||||
```
|
||||
|
||||
2. **Annotated metadata** (Priority 2)
|
||||
```python
|
||||
# For Annotated[np.ndarray, NumpyVector[256, np.float32]]
|
||||
if get_origin(annotation) is typing.Annotated:
|
||||
for item in metadata_items:
|
||||
if hasattr(item, '__sqlmodel_sa_type__'):
|
||||
return item.__sqlmodel_sa_type__
|
||||
```
|
||||
|
||||
3. **Pydantic Core Schema metadata** (Priority 3)
|
||||
```python
|
||||
schema = annotation.__get_pydantic_core_schema__(...)
|
||||
if schema['metadata'].get('sa_type'):
|
||||
return schema['metadata']['sa_type']
|
||||
```
|
||||
|
||||
After extracting `sa_type`, the metaclass:
|
||||
- Creates `Field(sa_type=sa_type)` if no Field is defined
|
||||
- Injects `sa_type` into existing Field if not already set
|
||||
- Respects explicit `Field(sa_type=...)` (no override)
|
||||
|
||||
**Supported Patterns**:
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
# Pattern 1: Direct usage (recommended)
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
|
||||
# Pattern 2: With Field constraints
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32] = Field(nullable=False)
|
||||
|
||||
# Pattern 3: Annotated wrapper
|
||||
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
|
||||
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
embedding: EmbeddingVector
|
||||
|
||||
# Pattern 4: Explicit sa_type (override)
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32] = Field(
|
||||
sa_type=_NumpyVectorSQLAlchemyType(128, np.float16)
|
||||
)
|
||||
```
|
||||
|
||||
### Table Configuration
|
||||
|
||||
The metaclass provides smart defaults and flexible configuration:
|
||||
|
||||
**Automatic `table=True`**:
|
||||
```python
|
||||
# Classes inheriting from TableBaseMixin automatically get table=True
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class MyModel(UUIDTableBaseMixin): # table=True is automatic
|
||||
pass
|
||||
```
|
||||
|
||||
**Convenient mapper arguments**:
|
||||
```python
|
||||
# Instead of verbose __mapper_args__
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class MyModel(
|
||||
UUIDTableBaseMixin,
|
||||
polymorphic_on='_polymorphic_name',
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
pass
|
||||
|
||||
# Equivalent to:
|
||||
class MyModel(UUIDTableBaseMixin):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': '_polymorphic_name',
|
||||
'polymorphic_abstract': True
|
||||
}
|
||||
```
|
||||
|
||||
**Smart merging**:
|
||||
```python
|
||||
# Dictionary and keyword arguments are merged
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class MyModel(
|
||||
UUIDTableBaseMixin,
|
||||
mapper_args={'version_id_col': 'version'},
|
||||
polymorphic_on='type' # Merged into __mapper_args__
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
### Polymorphic Support
|
||||
|
||||
The metaclass supports SQLAlchemy's joined table inheritance through convenient parameters:
|
||||
|
||||
**Supported parameters**:
|
||||
- `polymorphic_on`: Discriminator column name
|
||||
- `polymorphic_identity`: Identity value for this class
|
||||
- `polymorphic_abstract`: Whether this is an abstract base
|
||||
- `table_args`: SQLAlchemy table arguments
|
||||
- `table_name`: Override table name (becomes `__tablename__`)
|
||||
|
||||
**For complete polymorphic inheritance patterns**, including `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
||||
|
||||
---
|
||||
|
||||
## Custom Types Integration
|
||||
|
||||
### Using NumpyVector
|
||||
|
||||
The `NumpyVector` type demonstrates automatic sa_type injection:
|
||||
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
import numpy as np
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
"""Speaker voice embedding - sa_type automatically injected"""
|
||||
```
|
||||
|
||||
**How NumpyVector works**:
|
||||
|
||||
```python
|
||||
# NumpyVector[dims, dtype] returns a class with:
|
||||
class _NumpyVectorType:
|
||||
__sqlmodel_sa_type__ = _NumpyVectorSQLAlchemyType(dimensions, dtype)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||
return handler.generate_schema(np.ndarray)
|
||||
```
|
||||
|
||||
This dual approach ensures:
|
||||
1. Metaclass can extract `sa_type` via `__sqlmodel_sa_type__`
|
||||
2. Pydantic can validate as `np.ndarray`
|
||||
|
||||
### Creating Custom SQLAlchemy Types
|
||||
|
||||
To create types that work with automatic injection, provide one of:
|
||||
|
||||
**Option 1: `__sqlmodel_sa_type__` attribute** (preferred):
|
||||
|
||||
```python
|
||||
from sqlalchemy import TypeDecorator, String
|
||||
|
||||
class UpperCaseString(TypeDecorator):
|
||||
impl = String
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
return value.upper() if value else value
|
||||
|
||||
class UpperCaseType:
|
||||
__sqlmodel_sa_type__ = UpperCaseString()
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||
return core_schema.str_schema()
|
||||
|
||||
# Usage
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class MyModel(UUIDTableBaseMixin, table=True):
|
||||
code: UpperCaseType # Automatically uses UpperCaseString()
|
||||
```
|
||||
|
||||
**Option 2: Pydantic metadata with sa_type**:
|
||||
|
||||
```python
|
||||
def __get_pydantic_core_schema__(self, source_type, handler):
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.str_schema(),
|
||||
python_schema=core_schema.str_schema(),
|
||||
metadata={'sa_type': UpperCaseString()}
|
||||
)
|
||||
```
|
||||
|
||||
**Option 3: Using Annotated**:
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
UpperCase = Annotated[str, UpperCaseType()]
|
||||
|
||||
class MyModel(UUIDTableBaseMixin, table=True):
|
||||
code: UpperCase
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Inherit from correct base classes
|
||||
|
||||
```python
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
|
||||
|
||||
# ✅ For non-table models (DTOs, requests, responses)
|
||||
class UserBase(SQLModelBase):
|
||||
name: str
|
||||
|
||||
# ✅ For table models with UUID primary key
|
||||
class User(UserBase, UUIDTableBaseMixin, table=True):
|
||||
email: str
|
||||
|
||||
# ✅ For table models with custom primary key
|
||||
class LegacyUser(TableBaseMixin, table=True):
|
||||
id: int = Field(primary_key=True)
|
||||
username: str
|
||||
```
|
||||
|
||||
### 2. Use docstrings for field descriptions
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
# ✅ Recommended
|
||||
class User(UUIDTableBaseMixin, table=True):
|
||||
name: str
|
||||
"""User's display name"""
|
||||
|
||||
# ❌ Avoid
|
||||
class User(UUIDTableBaseMixin, table=True):
|
||||
name: str = Field(description="User's display name")
|
||||
```
|
||||
|
||||
**Why?** SQLModelBase has `use_attribute_docstrings=True`, so docstrings automatically become field descriptions in API docs.
|
||||
|
||||
### 3. Leverage automatic sa_type injection
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
# ✅ Clean and recommended
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
"""Voice embedding"""
|
||||
|
||||
# ❌ Verbose and unnecessary
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: np.ndarray = Field(
|
||||
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Follow polymorphic naming conventions
|
||||
|
||||
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete polymorphic inheritance patterns using `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`.
|
||||
|
||||
### 5. Separate Base, Parent, and Implementation classes
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin, PolymorphicBaseMixin
|
||||
|
||||
# ✅ Recommended structure
|
||||
class ASRBase(SQLModelBase):
|
||||
"""Pure data fields, no table"""
|
||||
name: str
|
||||
base_url: str
|
||||
|
||||
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
"""Abstract parent with table"""
|
||||
@abstractmethod
|
||||
async def transcribe(self, audio: bytes) -> str:
|
||||
pass
|
||||
|
||||
class WhisperASR(ASR, table=True):
|
||||
"""Concrete implementation"""
|
||||
model_size: str
|
||||
|
||||
async def transcribe(self, audio: bytes) -> str:
|
||||
# Implementation
|
||||
pass
|
||||
```
|
||||
|
||||
**Why?**
|
||||
- Base class can be reused for DTOs
|
||||
- Parent class defines the polymorphic hierarchy
|
||||
- Implementation classes are clean and focused
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: ValueError: X has no matching SQLAlchemy type
|
||||
|
||||
**Solution**: Ensure your custom type provides `__sqlmodel_sa_type__` attribute or proper Pydantic metadata with `sa_type`.
|
||||
|
||||
```python
|
||||
# ✅ Provide __sqlmodel_sa_type__
|
||||
class MyType:
|
||||
__sqlmodel_sa_type__ = MyCustomSQLAlchemyType()
|
||||
```
|
||||
|
||||
### Issue: Can't generate DDL for NullType()
|
||||
|
||||
**Symptoms**: Error during table creation saying a column has `NullType`.
|
||||
|
||||
**Root Cause**: Custom type's `sa_type` not detected by SQLModel.
|
||||
|
||||
**Solution**:
|
||||
1. Ensure your type has `__sqlmodel_sa_type__` class attribute
|
||||
2. Check that the monkey-patch is active (`sys.version_info >= (3, 14)`)
|
||||
3. Verify type annotation is correct (not a string forward reference)
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
# ✅ Correct
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
data: NumpyVector[256, np.float32] # __sqlmodel_sa_type__ detected
|
||||
|
||||
# ❌ Wrong (string annotation)
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
data: 'NumpyVector[256, np.float32]' # sa_type lost
|
||||
```
|
||||
|
||||
### Issue: Polymorphic identity conflicts
|
||||
|
||||
**Symptoms**: SQLAlchemy raises errors about duplicate polymorphic identities.
|
||||
|
||||
**Solution**:
|
||||
1. Check that each concrete class has a unique identity
|
||||
2. Use `AutoPolymorphicIdentityMixin` for automatic naming
|
||||
3. Manually specify identity if needed:
|
||||
```python
|
||||
class MyClass(Parent, polymorphic_identity='unique.name', table=True):
|
||||
pass
|
||||
```
|
||||
|
||||
### Issue: Python 3.14 annotation errors
|
||||
|
||||
**Symptoms**: Errors related to `__annotations__` or type resolution.
|
||||
|
||||
**Solution**: The implementation uses `get_type_hints()` which handles PEP 649 automatically. If issues persist:
|
||||
1. Check for manual `__annotations__` manipulation (avoid it)
|
||||
2. Ensure all types are properly imported
|
||||
3. Avoid `from __future__ import annotations` (can cause SQLModel issues)
|
||||
|
||||
### Issue: Polymorphic and CRUD-related errors
|
||||
|
||||
For issues related to polymorphic inheritance, CRUD operations, or table mixins, see the troubleshooting section in [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
||||
|
||||
---
|
||||
|
||||
## Implementation Details
|
||||
|
||||
For developers modifying this module:
|
||||
|
||||
**Core files**:
|
||||
- `sqlmodel_base.py` - Contains `__DeclarativeMeta` and `SQLModelBase`
|
||||
- `../mixin/table.py` - Contains `TableBaseMixin` and `UUIDTableBaseMixin`
|
||||
- `../mixin/polymorphic.py` - Contains `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`
|
||||
|
||||
**Key functions in this module**:
|
||||
|
||||
1. **`_resolve_annotations(attrs: dict[str, Any])`**
|
||||
- Uses `typing.get_type_hints()` for Python 3.14 compatibility
|
||||
- Returns tuple: `(annotations, annotation_strings, globalns, localns)`
|
||||
- Preserves `Annotated` metadata with `include_extras=True`
|
||||
|
||||
2. **`_extract_sa_type_from_annotation(annotation: Any) -> Any | None`**
|
||||
- Extracts SQLAlchemy type from type annotations
|
||||
- Supports `__sqlmodel_sa_type__`, `Annotated`, and Pydantic core schema
|
||||
- Called by metaclass during class creation
|
||||
|
||||
3. **`_patched_get_sqlalchemy_type(field)`** (Python 3.14+)
|
||||
- Global monkey-patch for SQLModel
|
||||
- Checks `__sqlmodel_sa_type__` before falling back to original logic
|
||||
- Handles custom types like `NumpyVector` and `Array`
|
||||
|
||||
4. **`__DeclarativeMeta.__new__()`**
|
||||
- Processes class definition parameters
|
||||
- Injects `sa_type` into field definitions
|
||||
- Sets up `__mapper_args__`, `__table_args__`, etc.
|
||||
- Handles Python 3.14 annotations via `get_type_hints()`
|
||||
|
||||
**Metaclass processing order**:
|
||||
1. Check if class should be a table (`_has_table_mixin`)
|
||||
2. Collect `__mapper_args__` from kwargs and explicit dict
|
||||
3. Process `table_args`, `table_name`, `abstract` parameters
|
||||
4. Resolve annotations using `get_type_hints()`
|
||||
5. For each field, try to extract `sa_type` and inject into Field
|
||||
6. Call parent metaclass with cleaned kwargs
|
||||
|
||||
For table mixin implementation details, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
||||
|
||||
---
|
||||
|
||||
## See Also
|
||||
|
||||
**Project Documentation**:
|
||||
- [SQLModel Mixin Documentation](../mixin/README.md) - Table mixins, CRUD operations, polymorphic patterns
|
||||
- [Project Coding Standards (CLAUDE.md)](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/CLAUDE.md)
|
||||
- [Custom SQLModel Types Guide](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/sqlmodels/sqlmodel_types/README.md)
|
||||
|
||||
**External References**:
|
||||
- [SQLAlchemy Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
|
||||
- [Pydantic V2 Documentation](https://docs.pydantic.dev/latest/)
|
||||
- [SQLModel Documentation](https://sqlmodel.tiangolo.com/)
|
||||
- [PEP 649: Deferred Evaluation of Annotations](https://peps.python.org/pep-0649/)
|
||||
- [PEP 749: Implementing PEP 649](https://peps.python.org/pep-0749/)
|
||||
- [Python Annotations Best Practices](https://docs.python.org/3/howto/annotations.html)
|
||||
@@ -1,12 +0,0 @@
|
||||
"""
|
||||
SQLModel 基础模块
|
||||
|
||||
包含:
|
||||
- SQLModelBase: 所有 SQLModel 类的基类(真正的基类)
|
||||
|
||||
注意:
|
||||
TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 sqlmodels.mixin
|
||||
为了避免循环导入,此处不再重新导出它们
|
||||
请直接从 sqlmodels.mixin 导入这些类
|
||||
"""
|
||||
from .sqlmodel_base import SQLModelBase
|
||||
@@ -1,846 +0,0 @@
|
||||
import sys
|
||||
import typing
|
||||
from typing import Any, Mapping, get_args, get_origin, get_type_hints
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined as Undefined
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlmodel import Field, SQLModel
|
||||
from sqlmodel.main import SQLModelMetaclass
|
||||
|
||||
# Python 3.14+ PEP 649支持
|
||||
if sys.version_info >= (3, 14):
|
||||
import annotationlib
|
||||
|
||||
# 全局Monkey-patch: 修复SQLModel在Python 3.14上的兼容性问题
|
||||
import sqlmodel.main
|
||||
_original_get_sqlalchemy_type = sqlmodel.main.get_sqlalchemy_type
|
||||
|
||||
def _patched_get_sqlalchemy_type(field):
|
||||
"""
|
||||
修复SQLModel的get_sqlalchemy_type函数,处理Python 3.14的类型问题。
|
||||
|
||||
问题:
|
||||
1. ForwardRef对象(来自Relationship字段)会导致issubclass错误
|
||||
2. typing._GenericAlias对象(如ClassVar[T])也会导致同样问题
|
||||
3. list/dict等泛型类型在没有Field/Relationship时可能导致错误
|
||||
4. Mapped类型在Python 3.14下可能出现在annotation中
|
||||
5. Annotated类型可能包含sa_type metadata(如Array[T])
|
||||
6. 自定义类型(如NumpyVector)有__sqlmodel_sa_type__属性
|
||||
7. Pydantic已处理的Annotated类型会将metadata存储在field.metadata中
|
||||
|
||||
解决:
|
||||
- 优先检查field.metadata中的__get_pydantic_core_schema__(Pydantic已处理的情况)
|
||||
- 检测__sqlmodel_sa_type__属性(NumpyVector等)
|
||||
- 检测Relationship/ClassVar等返回None
|
||||
- 对于Annotated类型,尝试提取sa_type metadata
|
||||
- 其他情况调用原始函数
|
||||
"""
|
||||
# 优先检查 field.metadata(Pydantic已处理Annotated类型的情况)
|
||||
# 当使用 Array[T] 或 Annotated[T, metadata] 时,Pydantic会将metadata存储在这里
|
||||
metadata = getattr(field, 'metadata', None)
|
||||
if metadata:
|
||||
# metadata是一个列表,包含所有Annotated的元数据项
|
||||
for metadata_item in metadata:
|
||||
# 检查metadata_item是否有__get_pydantic_core_schema__方法
|
||||
if hasattr(metadata_item, '__get_pydantic_core_schema__'):
|
||||
try:
|
||||
# 调用获取schema
|
||||
schema = metadata_item.__get_pydantic_core_schema__(None, None)
|
||||
# 检查schema的metadata中是否有sa_type
|
||||
if isinstance(schema, dict) and 'metadata' in schema:
|
||||
sa_type = schema['metadata'].get('sa_type')
|
||||
if sa_type is not None:
|
||||
return sa_type
|
||||
except (TypeError, AttributeError, KeyError):
|
||||
# Pydantic schema获取可能失败(类型不匹配、缺少属性等)
|
||||
# 这是正常情况,继续检查下一个metadata项
|
||||
pass
|
||||
|
||||
annotation = getattr(field, 'annotation', None)
|
||||
if annotation is not None:
|
||||
# 优先检查 __sqlmodel_sa_type__ 属性
|
||||
# 这处理 NumpyVector[dims, dtype] 等自定义类型
|
||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
||||
return annotation.__sqlmodel_sa_type__
|
||||
|
||||
# 检查自定义类型(如JSON100K)的 __get_pydantic_core_schema__ 方法
|
||||
# 这些类型在schema的metadata中定义sa_type
|
||||
if hasattr(annotation, '__get_pydantic_core_schema__'):
|
||||
try:
|
||||
# 调用获取schema(传None作为handler,因为我们只需要metadata)
|
||||
schema = annotation.__get_pydantic_core_schema__(annotation, lambda x: None)
|
||||
# 检查schema的metadata中是否有sa_type
|
||||
if isinstance(schema, dict) and 'metadata' in schema:
|
||||
sa_type = schema['metadata'].get('sa_type')
|
||||
if sa_type is not None:
|
||||
return sa_type
|
||||
except (TypeError, AttributeError, KeyError):
|
||||
# Schema获取失败,继续其他检查
|
||||
pass
|
||||
|
||||
anno_type_name = type(annotation).__name__
|
||||
|
||||
# ForwardRef: Relationship字段的annotation
|
||||
if anno_type_name == 'ForwardRef':
|
||||
return None
|
||||
|
||||
# AnnotatedAlias: 检查是否有sa_type metadata(如Array[T])
|
||||
if anno_type_name == 'AnnotatedAlias' or anno_type_name == '_AnnotatedAlias':
|
||||
from typing import get_origin, get_args
|
||||
import typing
|
||||
|
||||
# 尝试提取Annotated的metadata
|
||||
if hasattr(typing, 'get_args'):
|
||||
args = get_args(annotation)
|
||||
# args[0]是实际类型,args[1:]是metadata
|
||||
for metadata in args[1:]:
|
||||
# 检查metadata是否有__get_pydantic_core_schema__方法
|
||||
if hasattr(metadata, '__get_pydantic_core_schema__'):
|
||||
try:
|
||||
# 调用获取schema
|
||||
schema = metadata.__get_pydantic_core_schema__(None, None)
|
||||
# 检查schema中是否有sa_type
|
||||
if isinstance(schema, dict) and 'metadata' in schema:
|
||||
sa_type = schema['metadata'].get('sa_type')
|
||||
if sa_type is not None:
|
||||
return sa_type
|
||||
except (TypeError, AttributeError, KeyError):
|
||||
# Annotated metadata的schema获取可能失败
|
||||
# 这是正常的类型检查过程,继续检查下一个metadata
|
||||
pass
|
||||
|
||||
# _GenericAlias或GenericAlias: typing泛型类型
|
||||
if anno_type_name in ('_GenericAlias', 'GenericAlias'):
|
||||
from typing import get_origin
|
||||
import typing
|
||||
origin = get_origin(annotation)
|
||||
|
||||
# ClassVar必须跳过
|
||||
if origin is typing.ClassVar:
|
||||
return None
|
||||
|
||||
# list/dict/tuple/set等内置泛型,如果字段没有明确的Field或Relationship,也跳过
|
||||
# 这通常意味着它是Relationship字段或类变量
|
||||
if origin in (list, dict, tuple, set):
|
||||
# 检查field_info是否存在且有意义
|
||||
# Relationship字段会有特殊的field_info
|
||||
field_info = getattr(field, 'field_info', None)
|
||||
if field_info is None:
|
||||
return None
|
||||
|
||||
# Mapped: SQLAlchemy 2.0的Mapped类型,SQLModel不应该处理
|
||||
# 这可能是从父类继承的字段或Python 3.14注解处理的副作用
|
||||
# 检查类型名称和annotation的字符串表示
|
||||
if 'Mapped' in anno_type_name or 'Mapped' in str(annotation):
|
||||
return None
|
||||
|
||||
# 检查annotation是否是Mapped类或其实例
|
||||
try:
|
||||
from sqlalchemy.orm import Mapped as SAMapped
|
||||
# 检查origin(对于Mapped[T]这种泛型)
|
||||
from typing import get_origin
|
||||
if get_origin(annotation) is SAMapped:
|
||||
return None
|
||||
# 检查类型本身
|
||||
if annotation is SAMapped or isinstance(annotation, type) and issubclass(annotation, SAMapped):
|
||||
return None
|
||||
except (ImportError, TypeError):
|
||||
# 如果SQLAlchemy没有Mapped或检查失败,继续
|
||||
pass
|
||||
|
||||
# 其他情况正常处理
|
||||
return _original_get_sqlalchemy_type(field)
|
||||
|
||||
sqlmodel.main.get_sqlalchemy_type = _patched_get_sqlalchemy_type
|
||||
|
||||
# 第二个Monkey-patch: 修复继承表类中InstrumentedAttribute作为默认值的问题
|
||||
# 在Python 3.14 + SQLModel组合下,当子类(如SMSBaoProvider)继承父类(如VerificationCodeProvider)时,
|
||||
# 父类的关系字段(如server_config)会在子类的model_fields中出现,
|
||||
# 但其default值错误地设置为InstrumentedAttribute对象,而不是None
|
||||
# 这导致实例化时尝试设置InstrumentedAttribute为字段值,触发SQLAlchemy内部错误
|
||||
import sqlmodel._compat as _compat
|
||||
from sqlalchemy.orm import attributes as _sa_attributes
|
||||
|
||||
_original_sqlmodel_table_construct = _compat.sqlmodel_table_construct
|
||||
|
||||
def _patched_sqlmodel_table_construct(self_instance, values):
|
||||
"""
|
||||
修复sqlmodel_table_construct,跳过InstrumentedAttribute默认值
|
||||
|
||||
问题:
|
||||
- 继承自polymorphic基类的表类(如FishAudioTTS, SMSBaoProvider)
|
||||
- 其model_fields中的继承字段default值为InstrumentedAttribute
|
||||
- 原函数尝试将InstrumentedAttribute设置为字段值
|
||||
- SQLAlchemy无法处理,抛出 '_sa_instance_state' 错误
|
||||
|
||||
解决:
|
||||
- 只设置用户提供的值和非InstrumentedAttribute默认值
|
||||
- InstrumentedAttribute默认值跳过(让SQLAlchemy自己处理)
|
||||
"""
|
||||
cls = type(self_instance)
|
||||
|
||||
# 收集要设置的字段值
|
||||
fields_to_set = {}
|
||||
|
||||
for name, field in cls.model_fields.items():
|
||||
# 如果用户提供了值,直接使用
|
||||
if name in values:
|
||||
fields_to_set[name] = values[name]
|
||||
continue
|
||||
|
||||
# 否则检查默认值
|
||||
# 跳过InstrumentedAttribute默认值 - 这些是继承字段的错误默认值
|
||||
if isinstance(field.default, _sa_attributes.InstrumentedAttribute):
|
||||
continue
|
||||
|
||||
# 使用正常的默认值
|
||||
if field.default is not Undefined:
|
||||
fields_to_set[name] = field.default
|
||||
elif field.default_factory is not None:
|
||||
fields_to_set[name] = field.get_default(call_default_factory=True)
|
||||
|
||||
# 设置属性 - 只设置非InstrumentedAttribute值
|
||||
for key, value in fields_to_set.items():
|
||||
if not isinstance(value, _sa_attributes.InstrumentedAttribute):
|
||||
setattr(self_instance, key, value)
|
||||
|
||||
# 设置Pydantic内部属性
|
||||
object.__setattr__(self_instance, '__pydantic_fields_set__', set(values.keys()))
|
||||
if not cls.__pydantic_root_model__:
|
||||
_extra = None
|
||||
if cls.model_config.get('extra') == 'allow':
|
||||
_extra = {}
|
||||
for k, v in values.items():
|
||||
if k not in cls.model_fields:
|
||||
_extra[k] = v
|
||||
object.__setattr__(self_instance, '__pydantic_extra__', _extra)
|
||||
|
||||
if cls.__pydantic_post_init__:
|
||||
self_instance.model_post_init(None)
|
||||
elif not cls.__pydantic_root_model__:
|
||||
object.__setattr__(self_instance, '__pydantic_private__', None)
|
||||
|
||||
# 设置关系
|
||||
for key in self_instance.__sqlmodel_relationships__:
|
||||
value = values.get(key, Undefined)
|
||||
if value is not Undefined:
|
||||
setattr(self_instance, key, value)
|
||||
|
||||
return self_instance
|
||||
|
||||
_compat.sqlmodel_table_construct = _patched_sqlmodel_table_construct
|
||||
else:
|
||||
annotationlib = None
|
||||
|
||||
|
||||
def _extract_sa_type_from_annotation(annotation: Any) -> Any | None:
|
||||
"""
|
||||
从类型注解中提取SQLAlchemy类型。
|
||||
|
||||
支持以下形式:
|
||||
1. NumpyVector[256, np.float32] - 直接使用类型(有__sqlmodel_sa_type__属性)
|
||||
2. Annotated[np.ndarray, NumpyVector[256, np.float32]] - Annotated包装
|
||||
3. 任何有__get_pydantic_core_schema__且返回metadata['sa_type']的类型
|
||||
|
||||
Args:
|
||||
annotation: 字段的类型注解
|
||||
|
||||
Returns:
|
||||
提取到的SQLAlchemy类型,如果没有则返回None
|
||||
"""
|
||||
# 方法1:直接检查类型本身是否有__sqlmodel_sa_type__属性
|
||||
# 这涵盖了 NumpyVector[256, np.float32] 这种直接使用的情况
|
||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
||||
return annotation.__sqlmodel_sa_type__
|
||||
|
||||
# 方法2:检查是否为Annotated类型
|
||||
if get_origin(annotation) is typing.Annotated:
|
||||
# 获取元数据项(跳过第一个实际类型参数)
|
||||
args = get_args(annotation)
|
||||
if len(args) >= 2:
|
||||
metadata_items = args[1:] # 第一个是实际类型,后面都是元数据
|
||||
|
||||
# 遍历元数据,查找包含sa_type的项
|
||||
for item in metadata_items:
|
||||
# 检查元数据项是否有__sqlmodel_sa_type__属性
|
||||
if hasattr(item, '__sqlmodel_sa_type__'):
|
||||
return item.__sqlmodel_sa_type__
|
||||
|
||||
# 检查是否有__get_pydantic_core_schema__方法
|
||||
if hasattr(item, '__get_pydantic_core_schema__'):
|
||||
try:
|
||||
# 调用该方法获取core schema
|
||||
schema = item.__get_pydantic_core_schema__(
|
||||
annotation,
|
||||
lambda x: None # 虚拟handler
|
||||
)
|
||||
# 检查schema的metadata中是否有sa_type
|
||||
if isinstance(schema, dict) and 'metadata' in schema:
|
||||
sa_type = schema['metadata'].get('sa_type')
|
||||
if sa_type is not None:
|
||||
return sa_type
|
||||
except (TypeError, AttributeError, KeyError, ValueError):
|
||||
# Pydantic core schema获取可能失败:
|
||||
# - TypeError: 参数不匹配
|
||||
# - AttributeError: metadata不存在
|
||||
# - KeyError: schema结构不符合预期
|
||||
# - ValueError: 无效的类型定义
|
||||
# 这是正常的类型探测过程,继续检查下一个metadata项
|
||||
pass
|
||||
|
||||
# 方法3:检查类型本身是否有__get_pydantic_core_schema__
|
||||
# (虽然NumpyVector已经在方法1处理,但这是通用的fallback)
|
||||
if hasattr(annotation, '__get_pydantic_core_schema__'):
|
||||
try:
|
||||
schema = annotation.__get_pydantic_core_schema__(
|
||||
annotation,
|
||||
lambda x: None # 虚拟handler
|
||||
)
|
||||
if isinstance(schema, dict) and 'metadata' in schema:
|
||||
sa_type = schema['metadata'].get('sa_type')
|
||||
if sa_type is not None:
|
||||
return sa_type
|
||||
except (TypeError, AttributeError, KeyError, ValueError):
|
||||
# 类型本身的schema获取失败
|
||||
# 这是正常的fallback机制,annotation可能不支持此协议
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[
|
||||
dict[str, Any],
|
||||
dict[str, str],
|
||||
Mapping[str, Any],
|
||||
Mapping[str, Any],
|
||||
]:
|
||||
"""
|
||||
Resolve annotations from a class namespace with Python 3.14 (PEP 649) support.
|
||||
|
||||
This helper prefers evaluated annotations (Format.VALUE) so that `typing.Annotated`
|
||||
metadata and custom types remain accessible. Forward references that cannot be
|
||||
evaluated are replaced with typing.ForwardRef placeholders to avoid aborting the
|
||||
whole resolution process.
|
||||
"""
|
||||
raw_annotations = attrs.get('__annotations__') or {}
|
||||
try:
|
||||
base_annotations = dict(raw_annotations)
|
||||
except TypeError:
|
||||
base_annotations = {}
|
||||
|
||||
module_name = attrs.get('__module__')
|
||||
module_globals: dict[str, Any]
|
||||
if module_name and module_name in sys.modules:
|
||||
module_globals = dict(sys.modules[module_name].__dict__)
|
||||
else:
|
||||
module_globals = {}
|
||||
|
||||
module_globals.setdefault('__builtins__', __builtins__)
|
||||
localns: dict[str, Any] = dict(attrs)
|
||||
|
||||
try:
|
||||
temp_cls = type('AnnotationProxy', (object,), dict(attrs))
|
||||
temp_cls.__module__ = module_name
|
||||
extras_kw = {'include_extras': True} if sys.version_info >= (3, 10) else {}
|
||||
evaluated = get_type_hints(
|
||||
temp_cls,
|
||||
globalns=module_globals,
|
||||
localns=localns,
|
||||
**extras_kw,
|
||||
)
|
||||
except (NameError, AttributeError, TypeError, RecursionError):
|
||||
# get_type_hints可能失败的原因:
|
||||
# - NameError: 前向引用无法解析(类型尚未定义)
|
||||
# - AttributeError: 模块或类型不存在
|
||||
# - TypeError: 无效的类型注解
|
||||
# - RecursionError: 循环依赖的类型定义
|
||||
# 这是正常情况,回退到原始注解字符串
|
||||
evaluated = base_annotations
|
||||
|
||||
return dict(evaluated), {}, module_globals, localns
|
||||
|
||||
|
||||
def _evaluate_annotation_from_string(
|
||||
field_name: str,
|
||||
annotation_strings: dict[str, str],
|
||||
current_type: Any,
|
||||
globalns: Mapping[str, Any],
|
||||
localns: Mapping[str, Any],
|
||||
) -> Any:
|
||||
"""
|
||||
Attempt to re-evaluate the original annotation string for a field.
|
||||
|
||||
This is used as a fallback when the resolved annotation lost its metadata
|
||||
(e.g., Annotated wrappers) and we need to recover custom sa_type data.
|
||||
"""
|
||||
if not annotation_strings:
|
||||
return current_type
|
||||
|
||||
expr = annotation_strings.get(field_name)
|
||||
if not expr or not isinstance(expr, str):
|
||||
return current_type
|
||||
|
||||
try:
|
||||
return eval(expr, globalns, localns)
|
||||
except (NameError, SyntaxError, AttributeError, TypeError):
|
||||
# eval可能失败的原因:
|
||||
# - NameError: 类型名称在namespace中不存在
|
||||
# - SyntaxError: 注解字符串有语法错误
|
||||
# - AttributeError: 访问不存在的模块属性
|
||||
# - TypeError: 无效的类型表达式
|
||||
# 这是正常的fallback机制,返回当前已解析的类型
|
||||
return current_type
|
||||
|
||||
|
||||
class __DeclarativeMeta(SQLModelMetaclass):
|
||||
"""
|
||||
一个智能的混合模式元类,它提供了灵活性和清晰度:
|
||||
|
||||
1. **自动设置 `table=True`**: 如果一个类继承了 `TableBaseMixin`,则自动应用 `table=True`。
|
||||
2. **明确的字典参数**: 支持 `mapper_args={...}`, `table_args={...}`, `table_name='...'`。
|
||||
3. **便捷的关键字参数**: 支持最常见的 mapper 参数作为顶级关键字(如 `polymorphic_on`)。
|
||||
4. **智能合并**: 当字典和关键字同时提供时,会自动合并,且关键字参数有更高优先级。
|
||||
"""
|
||||
|
||||
_KNOWN_MAPPER_KEYS = {
|
||||
"polymorphic_on",
|
||||
"polymorphic_identity",
|
||||
"polymorphic_abstract",
|
||||
"version_id_col",
|
||||
"concrete",
|
||||
}
|
||||
|
||||
def __new__(cls, name, bases, attrs, **kwargs):
|
||||
# 1. 约定优于配置:自动设置 table=True
|
||||
is_intended_as_table = any(getattr(b, '_has_table_mixin', False) for b in bases)
|
||||
if is_intended_as_table and 'table' not in kwargs:
|
||||
kwargs['table'] = True
|
||||
|
||||
# 2. 智能合并 __mapper_args__
|
||||
collected_mapper_args = {}
|
||||
|
||||
# 首先,处理明确的 mapper_args 字典 (优先级较低)
|
||||
if 'mapper_args' in kwargs:
|
||||
collected_mapper_args.update(kwargs.pop('mapper_args'))
|
||||
|
||||
# 其次,处理便捷的关键字参数 (优先级更高)
|
||||
for key in cls._KNOWN_MAPPER_KEYS:
|
||||
if key in kwargs:
|
||||
# .pop() 获取值并移除,避免传递给父类
|
||||
collected_mapper_args[key] = kwargs.pop(key)
|
||||
|
||||
# 如果收集到了任何 mapper 参数,则更新到类的属性中
|
||||
if collected_mapper_args:
|
||||
existing = attrs.get('__mapper_args__', {}).copy()
|
||||
existing.update(collected_mapper_args)
|
||||
attrs['__mapper_args__'] = existing
|
||||
|
||||
# 3. 处理其他明确的参数
|
||||
if 'table_args' in kwargs:
|
||||
attrs['__table_args__'] = kwargs.pop('table_args')
|
||||
if 'table_name' in kwargs:
|
||||
attrs['__tablename__'] = kwargs.pop('table_name')
|
||||
if 'abstract' in kwargs:
|
||||
attrs['__abstract__'] = kwargs.pop('abstract')
|
||||
|
||||
# 4. 从Annotated元数据中提取sa_type并注入到Field
|
||||
# 重要:必须在调用父类__new__之前处理,因为SQLModel会消费annotations
|
||||
#
|
||||
# Python 3.14兼容性问题:
|
||||
# - SQLModel在Python 3.14上会因为ClassVar[T]类型而崩溃(issubclass错误)
|
||||
# - 我们必须在SQLModel看到annotations之前过滤掉ClassVar字段
|
||||
# - 虽然PEP 749建议不修改__annotations__,但这是修复SQLModel bug的必要措施
|
||||
#
|
||||
# 获取annotations的策略:
|
||||
# - Python 3.14+: 优先从__annotate__获取(如果存在)
|
||||
# - fallback: 从__annotations__读取(如果存在)
|
||||
# - 最终fallback: 空字典
|
||||
annotations, annotation_strings, eval_globals, eval_locals = _resolve_annotations(attrs)
|
||||
|
||||
if annotations:
|
||||
attrs['__annotations__'] = annotations
|
||||
if annotationlib is not None:
|
||||
# 在Python 3.14中禁用descriptor,转为普通dict
|
||||
attrs['__annotate__'] = None
|
||||
|
||||
for field_name, field_type in annotations.items():
|
||||
field_type = _evaluate_annotation_from_string(
|
||||
field_name,
|
||||
annotation_strings,
|
||||
field_type,
|
||||
eval_globals,
|
||||
eval_locals,
|
||||
)
|
||||
|
||||
# 跳过字符串或ForwardRef类型注解,让SQLModel自己处理
|
||||
if isinstance(field_type, str) or isinstance(field_type, typing.ForwardRef):
|
||||
continue
|
||||
|
||||
# 跳过特殊类型的字段
|
||||
origin = get_origin(field_type)
|
||||
|
||||
# 跳过 ClassVar 字段 - 它们不是数据库字段
|
||||
if origin is typing.ClassVar:
|
||||
continue
|
||||
|
||||
# 跳过 Mapped 字段 - SQLAlchemy 2.0+ 的声明式字段,已经有 mapped_column
|
||||
if origin is Mapped:
|
||||
continue
|
||||
|
||||
# 尝试从注解中提取sa_type
|
||||
sa_type = _extract_sa_type_from_annotation(field_type)
|
||||
|
||||
if sa_type is not None:
|
||||
# 检查字段是否已有Field定义
|
||||
field_value = attrs.get(field_name, Undefined)
|
||||
|
||||
if field_value is Undefined:
|
||||
# 没有Field定义,创建一个新的Field并注入sa_type
|
||||
attrs[field_name] = Field(sa_type=sa_type)
|
||||
elif isinstance(field_value, FieldInfo):
|
||||
# 已有Field定义,检查是否已设置sa_type
|
||||
# 注意:只有在未设置时才注入,尊重显式配置
|
||||
# SQLModel使用Undefined作为"未设置"的标记
|
||||
if not hasattr(field_value, 'sa_type') or field_value.sa_type is Undefined:
|
||||
field_value.sa_type = sa_type
|
||||
# 如果field_value是其他类型(如默认值),不处理
|
||||
# SQLModel会在后续处理中将其转换为Field
|
||||
|
||||
# 5. 调用父类的 __new__ 方法,传入被清理过的 kwargs
|
||||
result = super().__new__(cls, name, bases, attrs, **kwargs)
|
||||
|
||||
# 6. 修复:在联表继承场景下,继承父类的 __sqlmodel_relationships__
|
||||
# SQLModel 为每个 table=True 的类创建新的空 __sqlmodel_relationships__
|
||||
# 这导致子类丢失父类的关系定义,触发错误的 Column 创建
|
||||
# 必须在 super().__new__() 之后修复,因为 SQLModel 会覆盖我们预设的值
|
||||
if kwargs.get('table', False):
|
||||
for base in bases:
|
||||
if hasattr(base, '__sqlmodel_relationships__'):
|
||||
for rel_name, rel_info in base.__sqlmodel_relationships__.items():
|
||||
# 只继承子类没有重新定义的关系
|
||||
if rel_name not in result.__sqlmodel_relationships__:
|
||||
result.__sqlmodel_relationships__[rel_name] = rel_info
|
||||
# 同时修复被错误创建的 Column - 恢复为父类的 relationship
|
||||
if hasattr(base, rel_name):
|
||||
base_attr = getattr(base, rel_name)
|
||||
setattr(result, rel_name, base_attr)
|
||||
|
||||
# 7. 检测:禁止子类重定义父类的 Relationship 字段
|
||||
# 子类重定义同名的 Relationship 字段会导致 SQLAlchemy 关系映射混乱,
|
||||
# 应该在类定义时立即报错,而不是在运行时出现难以调试的问题。
|
||||
for base in bases:
|
||||
parent_relationships = getattr(base, '__sqlmodel_relationships__', {})
|
||||
for rel_name in parent_relationships:
|
||||
# 检查当前类是否在 attrs 中重新定义了这个关系字段
|
||||
if rel_name in attrs:
|
||||
raise TypeError(
|
||||
f"类 {name} 不允许重定义父类 {base.__name__} 的 Relationship 字段 '{rel_name}'。"
|
||||
f"如需修改关系配置,请在父类中修改。"
|
||||
)
|
||||
|
||||
# 8. 修复:从 model_fields/__pydantic_fields__ 中移除 Relationship 字段
|
||||
# SQLModel 0.0.27 bug:子类会错误地继承父类的 Relationship 字段到 model_fields
|
||||
# 这导致 Pydantic 尝试为 Relationship 字段生成 schema,因为类型是
|
||||
# Mapped[list['Character']] 这种前向引用,Pydantic 无法解析,
|
||||
# 导致 __pydantic_complete__ = False
|
||||
#
|
||||
# 修复策略:
|
||||
# - 检查类的 __sqlmodel_relationships__ 属性
|
||||
# - 从 model_fields 和 __pydantic_fields__ 中移除这些字段
|
||||
# - Relationship 字段由 SQLAlchemy 管理,不需要 Pydantic 参与
|
||||
relationships = getattr(result, '__sqlmodel_relationships__', {})
|
||||
if relationships:
|
||||
model_fields = getattr(result, 'model_fields', {})
|
||||
pydantic_fields = getattr(result, '__pydantic_fields__', {})
|
||||
|
||||
fields_removed = False
|
||||
for rel_name in relationships:
|
||||
if rel_name in model_fields:
|
||||
del model_fields[rel_name]
|
||||
fields_removed = True
|
||||
if rel_name in pydantic_fields:
|
||||
del pydantic_fields[rel_name]
|
||||
fields_removed = True
|
||||
|
||||
# 如果移除了字段,重新构建 Pydantic 模式
|
||||
# 注意:只在有字段被移除时才 rebuild,避免不必要的开销
|
||||
if fields_removed and hasattr(result, 'model_rebuild'):
|
||||
result.model_rebuild(force=True)
|
||||
|
||||
return result
|
||||
|
||||
def __init__(
|
||||
cls,
|
||||
classname: str,
|
||||
bases: tuple[type, ...],
|
||||
dict_: dict[str, typing.Any],
|
||||
**kw: typing.Any,
|
||||
) -> None:
|
||||
"""
|
||||
重写 SQLModel 的 __init__ 以支持联表继承(Joined Table Inheritance)
|
||||
|
||||
SQLModel 原始行为:
|
||||
- 如果任何基类是表模型,则不调用 DeclarativeMeta.__init__
|
||||
- 这阻止了子类创建自己的表
|
||||
|
||||
修复逻辑:
|
||||
- 检测联表继承场景(子类有自己的 __tablename__ 且有外键指向父表)
|
||||
- 强制调用 DeclarativeMeta.__init__ 来创建子表
|
||||
"""
|
||||
from sqlmodel.main import is_table_model_class, DeclarativeMeta, ModelMetaclass
|
||||
|
||||
# 检查是否是表模型
|
||||
if not is_table_model_class(cls):
|
||||
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
||||
return
|
||||
|
||||
# 检查是否有基类是表模型
|
||||
base_is_table = any(is_table_model_class(base) for base in bases)
|
||||
|
||||
if not base_is_table:
|
||||
# 没有基类是表模型,走正常的 SQLModel 流程
|
||||
# 处理关系字段
|
||||
cls._setup_relationships()
|
||||
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
|
||||
return
|
||||
|
||||
# 关键:检测联表继承场景
|
||||
# 条件:
|
||||
# 1. 当前类的 __tablename__ 与父类不同(表示需要新表)
|
||||
# 2. 当前类有字段带有 foreign_key 指向父表
|
||||
current_tablename = getattr(cls, '__tablename__', None)
|
||||
|
||||
# 查找父表信息
|
||||
parent_table = None
|
||||
parent_tablename = None
|
||||
for base in bases:
|
||||
if is_table_model_class(base) and hasattr(base, '__tablename__'):
|
||||
parent_tablename = base.__tablename__
|
||||
break
|
||||
|
||||
# 检查是否有不同的 tablename
|
||||
has_different_tablename = (
|
||||
current_tablename is not None
|
||||
and parent_tablename is not None
|
||||
and current_tablename != parent_tablename
|
||||
)
|
||||
|
||||
# 检查是否有外键字段指向父表的主键
|
||||
# 注意:由于字段合并,我们需要检查直接基类的 model_fields
|
||||
# 而不是当前类的合并后的 model_fields
|
||||
has_fk_to_parent = False
|
||||
|
||||
def _normalize_tablename(name: str) -> str:
|
||||
"""标准化表名以进行比较(移除下划线,转小写)"""
|
||||
return name.replace('_', '').lower()
|
||||
|
||||
def _fk_matches_parent(fk_str: str, parent_table: str) -> bool:
|
||||
"""检查 FK 字符串是否指向父表"""
|
||||
if not fk_str or not parent_table:
|
||||
return False
|
||||
# FK 格式: "tablename.column" 或 "schema.tablename.column"
|
||||
parts = fk_str.split('.')
|
||||
if len(parts) >= 2:
|
||||
fk_table = parts[-2] # 取倒数第二个作为表名
|
||||
# 标准化比较(处理下划线差异)
|
||||
return _normalize_tablename(fk_table) == _normalize_tablename(parent_table)
|
||||
return False
|
||||
|
||||
if has_different_tablename and parent_tablename:
|
||||
# 首先检查当前类的 model_fields
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
fk = getattr(field_info, 'foreign_key', None)
|
||||
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
|
||||
has_fk_to_parent = True
|
||||
break
|
||||
|
||||
# 如果没找到,检查直接基类的 model_fields(解决 mixin 字段被覆盖的问题)
|
||||
if not has_fk_to_parent:
|
||||
for base in bases:
|
||||
if hasattr(base, 'model_fields'):
|
||||
for field_name, field_info in base.model_fields.items():
|
||||
fk = getattr(field_info, 'foreign_key', None)
|
||||
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
|
||||
has_fk_to_parent = True
|
||||
break
|
||||
if has_fk_to_parent:
|
||||
break
|
||||
|
||||
is_joined_inheritance = has_different_tablename and has_fk_to_parent
|
||||
|
||||
if is_joined_inheritance:
|
||||
# 联表继承:需要创建子表
|
||||
|
||||
# 修复外键字段:由于字段合并,外键信息可能丢失
|
||||
# 需要从基类的 mixin 中找回外键信息,并重建列
|
||||
from sqlalchemy import Column, ForeignKey, inspect as sa_inspect
|
||||
from sqlalchemy.dialects.postgresql import UUID as SA_UUID
|
||||
from sqlalchemy.exc import NoInspectionAvailable
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
|
||||
# 联表继承:子表只应该有 id(FK 到父表)+ 子类特有的字段
|
||||
# 所有继承自祖先表的列都不应该在子表中重复创建
|
||||
|
||||
# 收集整个继承链中所有祖先表的列名(这些列不应该在子表中重复)
|
||||
# 需要遍历整个 MRO,因为可能是多级继承(如 Tool -> Function -> GetWeatherFunction)
|
||||
ancestor_column_names: set[str] = set()
|
||||
for ancestor in cls.__mro__:
|
||||
if ancestor is cls:
|
||||
continue # 跳过当前类
|
||||
if is_table_model_class(ancestor):
|
||||
try:
|
||||
# 使用 inspect() 获取 mapper 的公开属性
|
||||
# 源码确认: mapper.local_table 是公开属性 (mapper.py:979-998)
|
||||
mapper = sa_inspect(ancestor)
|
||||
for col in mapper.local_table.columns:
|
||||
# 跳过 _polymorphic_name 列(鉴别器,由根父表管理)
|
||||
if col.name.startswith('_polymorphic'):
|
||||
continue
|
||||
ancestor_column_names.add(col.name)
|
||||
except NoInspectionAvailable:
|
||||
continue
|
||||
|
||||
# 找到子类自己定义的字段(不在父类中的)
|
||||
child_own_fields: set[str] = set()
|
||||
for field_name in cls.model_fields:
|
||||
# 检查这个字段是否是在当前类直接定义的(不是继承的)
|
||||
# 通过检查父类是否有这个字段来判断
|
||||
is_inherited = False
|
||||
for base in bases:
|
||||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
||||
is_inherited = True
|
||||
break
|
||||
if not is_inherited:
|
||||
child_own_fields.add(field_name)
|
||||
|
||||
# 从子类类属性中移除父表已有的列定义
|
||||
# 这样 SQLAlchemy 就不会在子表中创建这些列
|
||||
fk_field_name = None
|
||||
for base in bases:
|
||||
if hasattr(base, 'model_fields'):
|
||||
for field_name, field_info in base.model_fields.items():
|
||||
fk = getattr(field_info, 'foreign_key', None)
|
||||
pk = getattr(field_info, 'primary_key', False)
|
||||
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
|
||||
fk_field_name = field_name
|
||||
# 找到了外键字段,重建它
|
||||
# 创建一个新的 Column 对象包含外键约束
|
||||
new_col = Column(
|
||||
field_name,
|
||||
SA_UUID(as_uuid=True),
|
||||
ForeignKey(fk),
|
||||
primary_key=pk if pk else False
|
||||
)
|
||||
setattr(cls, field_name, new_col)
|
||||
break
|
||||
else:
|
||||
continue
|
||||
break
|
||||
|
||||
# 移除继承自祖先表的列属性(除了 FK/PK 和子类自己的字段)
|
||||
# 这防止 SQLAlchemy 在子表中创建重复列
|
||||
# 注意:在 __init__ 阶段,列是 Column 对象,不是 InstrumentedAttribute
|
||||
for col_name in ancestor_column_names:
|
||||
if col_name == fk_field_name:
|
||||
continue # 保留 FK/PK 列(子表的主键,同时是父表的外键)
|
||||
if col_name == 'id':
|
||||
continue # id 会被 FK 字段覆盖
|
||||
if col_name in child_own_fields:
|
||||
continue # 保留子类自己定义的字段
|
||||
|
||||
# 检查类属性是否是 Column 或 InstrumentedAttribute
|
||||
if col_name in cls.__dict__:
|
||||
attr = cls.__dict__[col_name]
|
||||
# Column 对象或 InstrumentedAttribute 都需要删除
|
||||
if isinstance(attr, (Column, InstrumentedAttribute)):
|
||||
try:
|
||||
delattr(cls, col_name)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# 找到子类自己定义的关系(不在父类中的)
|
||||
# 继承的关系会从父类自动获取,只需要设置子类新增的关系
|
||||
child_own_relationships: set[str] = set()
|
||||
for rel_name in cls.__sqlmodel_relationships__:
|
||||
is_inherited = False
|
||||
for base in bases:
|
||||
if hasattr(base, '__sqlmodel_relationships__') and rel_name in base.__sqlmodel_relationships__:
|
||||
is_inherited = True
|
||||
break
|
||||
if not is_inherited:
|
||||
child_own_relationships.add(rel_name)
|
||||
|
||||
# 只为子类自己定义的新关系调用关系设置
|
||||
if child_own_relationships:
|
||||
cls._setup_relationships(only_these=child_own_relationships)
|
||||
|
||||
# 强制调用 DeclarativeMeta.__init__
|
||||
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
|
||||
else:
|
||||
# 非联表继承:单表继承或正常 Pydantic 模型
|
||||
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
||||
|
||||
def _setup_relationships(cls, only_these: set[str] | None = None) -> None:
|
||||
"""
|
||||
设置 SQLAlchemy 关系字段(从 SQLModel 源码复制)
|
||||
|
||||
Args:
|
||||
only_these: 如果提供,只设置这些关系(用于 joined table inheritance 子类)
|
||||
如果为 None,设置所有关系(默认行为)
|
||||
"""
|
||||
from sqlalchemy.orm import relationship, Mapped
|
||||
from sqlalchemy import inspect
|
||||
from sqlmodel.main import get_relationship_to
|
||||
from typing import get_origin
|
||||
|
||||
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
|
||||
# 如果指定了 only_these,只设置这些关系
|
||||
if only_these is not None and rel_name not in only_these:
|
||||
continue
|
||||
if rel_info.sa_relationship:
|
||||
setattr(cls, rel_name, rel_info.sa_relationship)
|
||||
continue
|
||||
|
||||
raw_ann = cls.__annotations__[rel_name]
|
||||
origin: typing.Any = get_origin(raw_ann)
|
||||
if origin is Mapped:
|
||||
ann = raw_ann.__args__[0]
|
||||
else:
|
||||
ann = raw_ann
|
||||
cls.__annotations__[rel_name] = Mapped[ann]
|
||||
|
||||
relationship_to = get_relationship_to(
|
||||
name=rel_name, rel_info=rel_info, annotation=ann
|
||||
)
|
||||
rel_kwargs: dict[str, typing.Any] = {}
|
||||
if rel_info.back_populates:
|
||||
rel_kwargs["back_populates"] = rel_info.back_populates
|
||||
if rel_info.cascade_delete:
|
||||
rel_kwargs["cascade"] = "all, delete-orphan"
|
||||
if rel_info.passive_deletes:
|
||||
rel_kwargs["passive_deletes"] = rel_info.passive_deletes
|
||||
if rel_info.link_model:
|
||||
ins = inspect(rel_info.link_model)
|
||||
local_table = getattr(ins, "local_table")
|
||||
if local_table is None:
|
||||
raise RuntimeError(
|
||||
f"Couldn't find secondary table for {rel_info.link_model}"
|
||||
)
|
||||
rel_kwargs["secondary"] = local_table
|
||||
|
||||
rel_args: list[typing.Any] = []
|
||||
if rel_info.sa_relationship_args:
|
||||
rel_args.extend(rel_info.sa_relationship_args)
|
||||
if rel_info.sa_relationship_kwargs:
|
||||
rel_kwargs.update(rel_info.sa_relationship_kwargs)
|
||||
|
||||
rel_value = relationship(relationship_to, *rel_args, **rel_kwargs)
|
||||
setattr(cls, rel_name, rel_value)
|
||||
|
||||
|
||||
class SQLModelBase(SQLModel, metaclass=__DeclarativeMeta):
|
||||
"""此类必须和TableBase系列类搭配使用"""
|
||||
|
||||
model_config = ConfigDict(use_attribute_docstrings=True, validate_by_name=True)
|
||||
@@ -1,6 +1,6 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from .base import SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
|
||||
|
||||
class ChromaticColor(StrEnum):
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
from sqlmodel import SQLModel
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from utils.conf import appmeta
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from typing import AsyncGenerator
|
||||
|
||||
ASYNC_DATABASE_URL = appmeta.database_url
|
||||
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
ASYNC_DATABASE_URL,
|
||||
echo=appmeta.debug,
|
||||
connect_args={
|
||||
"check_same_thread": False
|
||||
} if ASYNC_DATABASE_URL.startswith("sqlite") else {},
|
||||
future=True,
|
||||
# pool_size=POOL_SIZE,
|
||||
# max_overflow=64,
|
||||
)
|
||||
|
||||
_async_session_factory = sessionmaker(engine, class_=AsyncSession)
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with _async_session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def init_db(
|
||||
url: str = ASYNC_DATABASE_URL
|
||||
):
|
||||
"""创建数据库结构"""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
@@ -4,8 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import UUIDTableBaseMixin, TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
409
sqlmodels/file_app.py
Normal file
409
sqlmodels/file_app.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""
|
||||
文件查看器应用模块
|
||||
|
||||
提供文件预览应用选择器系统的数据模型和 DTO。
|
||||
类似 Android 的"使用什么应用打开"机制:
|
||||
- 管理员注册应用(内置/iframe/WOPI)
|
||||
- 用户按扩展名查询可用查看器
|
||||
- 用户可设置"始终使用"偏好
|
||||
- 支持用户组级别的访问控制
|
||||
|
||||
架构:
|
||||
FileApp (应用注册表)
|
||||
├── FileAppExtension (扩展名关联)
|
||||
├── FileAppGroupLink (用户组访问控制)
|
||||
└── UserFileAppDefault (用户默认偏好)
|
||||
"""
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .group import Group
|
||||
|
||||
|
||||
# ==================== 枚举 ====================
|
||||
|
||||
class FileAppType(StrEnum):
|
||||
"""文件应用类型"""
|
||||
|
||||
BUILTIN = "builtin"
|
||||
"""前端内置查看器(如 pdf.js, Monaco)"""
|
||||
|
||||
IFRAME = "iframe"
|
||||
"""iframe 内嵌第三方服务"""
|
||||
|
||||
WOPI = "wopi"
|
||||
"""WOPI 协议(OnlyOffice / Collabora)"""
|
||||
|
||||
|
||||
# ==================== Link 表 ====================
|
||||
|
||||
class FileAppGroupLink(SQLModelBase, table=True):
|
||||
"""应用-用户组访问控制关联表"""
|
||||
|
||||
app_id: UUID = Field(foreign_key="fileapp.id", primary_key=True, ondelete="CASCADE")
|
||||
"""关联的应用UUID"""
|
||||
|
||||
group_id: UUID = Field(foreign_key="group.id", primary_key=True, ondelete="CASCADE")
|
||||
"""关联的用户组UUID"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class FileAppSummary(SQLModelBase):
|
||||
"""查看器列表项 DTO,用于选择器弹窗"""
|
||||
|
||||
id: UUID
|
||||
"""应用UUID"""
|
||||
|
||||
name: str
|
||||
"""应用名称"""
|
||||
|
||||
app_key: str
|
||||
"""应用唯一标识"""
|
||||
|
||||
type: FileAppType
|
||||
"""应用类型"""
|
||||
|
||||
icon: str | None = None
|
||||
"""图标名称/URL"""
|
||||
|
||||
description: str | None = None
|
||||
"""应用描述"""
|
||||
|
||||
iframe_url_template: str | None = None
|
||||
"""iframe URL 模板"""
|
||||
|
||||
wopi_editor_url_template: str | None = None
|
||||
"""WOPI 编辑器 URL 模板"""
|
||||
|
||||
|
||||
class FileViewersResponse(SQLModelBase):
|
||||
"""查看器查询响应 DTO"""
|
||||
|
||||
viewers: list[FileAppSummary] = []
|
||||
"""可用查看器列表(已按 priority 排序)"""
|
||||
|
||||
default_viewer_id: UUID | None = None
|
||||
"""用户默认查看器UUID(如果已设置"始终使用")"""
|
||||
|
||||
|
||||
class SetDefaultViewerRequest(SQLModelBase):
|
||||
"""设置默认查看器请求 DTO"""
|
||||
|
||||
extension: str = Field(max_length=20)
|
||||
"""文件扩展名(小写,无点号)"""
|
||||
|
||||
app_id: UUID
|
||||
"""应用UUID"""
|
||||
|
||||
|
||||
class UserFileAppDefaultResponse(SQLModelBase):
|
||||
"""用户默认查看器响应 DTO"""
|
||||
|
||||
id: UUID
|
||||
"""记录UUID"""
|
||||
|
||||
extension: str
|
||||
"""扩展名"""
|
||||
|
||||
app: FileAppSummary
|
||||
"""关联的应用摘要"""
|
||||
|
||||
|
||||
class FileAppCreateRequest(SQLModelBase):
|
||||
"""管理员创建应用请求 DTO"""
|
||||
|
||||
name: str = Field(max_length=100)
|
||||
"""应用名称"""
|
||||
|
||||
app_key: str = Field(max_length=50)
|
||||
"""应用唯一标识"""
|
||||
|
||||
type: FileAppType
|
||||
"""应用类型"""
|
||||
|
||||
icon: str | None = Field(default=None, max_length=255)
|
||||
"""图标名称/URL"""
|
||||
|
||||
description: str | None = Field(default=None, max_length=500)
|
||||
"""应用描述"""
|
||||
|
||||
is_enabled: bool = True
|
||||
"""是否启用"""
|
||||
|
||||
is_restricted: bool = False
|
||||
"""是否限制用户组访问"""
|
||||
|
||||
iframe_url_template: str | None = Field(default=None, max_length=1024)
|
||||
"""iframe URL 模板"""
|
||||
|
||||
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
||||
"""WOPI 发现端点 URL"""
|
||||
|
||||
wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
|
||||
"""WOPI 编辑器 URL 模板"""
|
||||
|
||||
extensions: list[str] = []
|
||||
"""关联的扩展名列表"""
|
||||
|
||||
allowed_group_ids: list[UUID] = []
|
||||
"""允许访问的用户组UUID列表"""
|
||||
|
||||
|
||||
class FileAppUpdateRequest(SQLModelBase):
|
||||
"""管理员更新应用请求 DTO(所有字段可选)"""
|
||||
|
||||
name: str | None = Field(default=None, max_length=100)
|
||||
"""应用名称"""
|
||||
|
||||
app_key: str | None = Field(default=None, max_length=50)
|
||||
"""应用唯一标识"""
|
||||
|
||||
type: FileAppType | None = None
|
||||
"""应用类型"""
|
||||
|
||||
icon: str | None = Field(default=None, max_length=255)
|
||||
"""图标名称/URL"""
|
||||
|
||||
description: str | None = Field(default=None, max_length=500)
|
||||
"""应用描述"""
|
||||
|
||||
is_enabled: bool | None = None
|
||||
"""是否启用"""
|
||||
|
||||
is_restricted: bool | None = None
|
||||
"""是否限制用户组访问"""
|
||||
|
||||
iframe_url_template: str | None = Field(default=None, max_length=1024)
|
||||
"""iframe URL 模板"""
|
||||
|
||||
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
||||
"""WOPI 发现端点 URL"""
|
||||
|
||||
wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
|
||||
"""WOPI 编辑器 URL 模板"""
|
||||
|
||||
|
||||
class FileAppResponse(SQLModelBase):
|
||||
"""管理员应用详情响应 DTO"""
|
||||
|
||||
id: UUID
|
||||
"""应用UUID"""
|
||||
|
||||
name: str
|
||||
"""应用名称"""
|
||||
|
||||
app_key: str
|
||||
"""应用唯一标识"""
|
||||
|
||||
type: FileAppType
|
||||
"""应用类型"""
|
||||
|
||||
icon: str | None = None
|
||||
"""图标名称/URL"""
|
||||
|
||||
description: str | None = None
|
||||
"""应用描述"""
|
||||
|
||||
is_enabled: bool = True
|
||||
"""是否启用"""
|
||||
|
||||
is_restricted: bool = False
|
||||
"""是否限制用户组访问"""
|
||||
|
||||
iframe_url_template: str | None = None
|
||||
"""iframe URL 模板"""
|
||||
|
||||
wopi_discovery_url: str | None = None
|
||||
"""WOPI 发现端点 URL"""
|
||||
|
||||
wopi_editor_url_template: str | None = None
|
||||
"""WOPI 编辑器 URL 模板"""
|
||||
|
||||
extensions: list[str] = []
|
||||
"""关联的扩展名列表"""
|
||||
|
||||
allowed_group_ids: list[UUID] = []
|
||||
"""允许访问的用户组UUID列表"""
|
||||
|
||||
@classmethod
|
||||
def from_app(
|
||||
cls,
|
||||
app: "FileApp",
|
||||
extensions: list["FileAppExtension"],
|
||||
group_links: list[FileAppGroupLink],
|
||||
) -> "FileAppResponse":
|
||||
"""从 ORM 对象构建 DTO"""
|
||||
return cls(
|
||||
id=app.id,
|
||||
name=app.name,
|
||||
app_key=app.app_key,
|
||||
type=app.type,
|
||||
icon=app.icon,
|
||||
description=app.description,
|
||||
is_enabled=app.is_enabled,
|
||||
is_restricted=app.is_restricted,
|
||||
iframe_url_template=app.iframe_url_template,
|
||||
wopi_discovery_url=app.wopi_discovery_url,
|
||||
wopi_editor_url_template=app.wopi_editor_url_template,
|
||||
extensions=[ext.extension for ext in extensions],
|
||||
allowed_group_ids=[link.group_id for link in group_links],
|
||||
)
|
||||
|
||||
|
||||
class FileAppListResponse(SQLModelBase):
|
||||
"""管理员应用列表响应 DTO"""
|
||||
|
||||
apps: list[FileAppResponse] = []
|
||||
"""应用列表"""
|
||||
|
||||
total: int = 0
|
||||
"""总数"""
|
||||
|
||||
|
||||
class ExtensionUpdateRequest(SQLModelBase):
|
||||
"""扩展名全量替换请求 DTO"""
|
||||
|
||||
extensions: list[str]
|
||||
"""扩展名列表(小写,无点号)"""
|
||||
|
||||
|
||||
class GroupAccessUpdateRequest(SQLModelBase):
|
||||
"""用户组权限全量替换请求 DTO"""
|
||||
|
||||
group_ids: list[UUID]
|
||||
"""允许访问的用户组UUID列表"""
|
||||
|
||||
|
||||
class WopiSessionResponse(SQLModelBase):
|
||||
"""WOPI 会话响应 DTO"""
|
||||
|
||||
wopi_src: str
|
||||
"""WOPI 源 URL"""
|
||||
|
||||
access_token: str
|
||||
"""WOPI 访问令牌"""
|
||||
|
||||
access_token_ttl: int
|
||||
"""令牌过期时间戳(毫秒,WOPI 规范要求)"""
|
||||
|
||||
editor_url: str
|
||||
"""完整的编辑器 URL"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class FileApp(SQLModelBase, UUIDTableBaseMixin):
|
||||
"""文件查看器应用注册表"""
|
||||
|
||||
name: str = Field(max_length=100)
|
||||
"""应用名称"""
|
||||
|
||||
app_key: str = Field(max_length=50, unique=True, index=True)
|
||||
"""应用唯一标识,前端路由用"""
|
||||
|
||||
type: FileAppType
|
||||
"""应用类型"""
|
||||
|
||||
icon: str | None = Field(default=None, max_length=255)
|
||||
"""图标名称/URL"""
|
||||
|
||||
description: str | None = Field(default=None, max_length=500)
|
||||
"""应用描述"""
|
||||
|
||||
is_enabled: bool = True
|
||||
"""是否启用"""
|
||||
|
||||
is_restricted: bool = False
|
||||
"""是否限制用户组访问"""
|
||||
|
||||
iframe_url_template: str | None = Field(default=None, max_length=1024)
|
||||
"""iframe URL 模板,支持 {file_url} 占位符"""
|
||||
|
||||
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
||||
"""WOPI 客户端发现端点 URL"""
|
||||
|
||||
wopi_editor_url_template: str | None = Field(default=None, max_length=1024)
|
||||
"""WOPI 编辑器 URL 模板,支持 {wopi_src} {access_token} {access_token_ttl}"""
|
||||
|
||||
# 关系
|
||||
extensions: list["FileAppExtension"] = Relationship(
|
||||
back_populates="app",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
|
||||
user_defaults: list["UserFileAppDefault"] = Relationship(
|
||||
back_populates="app",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
|
||||
allowed_groups: list["Group"] = Relationship(
|
||||
link_model=FileAppGroupLink,
|
||||
)
|
||||
|
||||
def to_summary(self) -> FileAppSummary:
|
||||
"""转换为摘要 DTO"""
|
||||
return FileAppSummary(
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
app_key=self.app_key,
|
||||
type=self.type,
|
||||
icon=self.icon,
|
||||
description=self.description,
|
||||
iframe_url_template=self.iframe_url_template,
|
||||
wopi_editor_url_template=self.wopi_editor_url_template,
|
||||
)
|
||||
|
||||
|
||||
class FileAppExtension(SQLModelBase, TableBaseMixin):
|
||||
"""扩展名关联表"""
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("app_id", "extension", name="uq_fileappextension_app_extension"),
|
||||
)
|
||||
|
||||
app_id: UUID = Field(foreign_key="fileapp.id", index=True, ondelete="CASCADE")
|
||||
"""关联的应用UUID"""
|
||||
|
||||
extension: str = Field(max_length=20, index=True)
|
||||
"""扩展名(小写,无点号)"""
|
||||
|
||||
priority: int = Field(default=0, ge=0)
|
||||
"""排序优先级(越小越优先)"""
|
||||
|
||||
# 关系
|
||||
app: FileApp = Relationship(back_populates="extensions")
|
||||
|
||||
|
||||
class UserFileAppDefault(SQLModelBase, UUIDTableBaseMixin):
|
||||
"""用户"始终使用"偏好"""
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("user_id", "extension", name="uq_userfileappdefault_user_extension"),
|
||||
)
|
||||
|
||||
user_id: UUID = Field(foreign_key="user.id", index=True, ondelete="CASCADE")
|
||||
"""用户UUID"""
|
||||
|
||||
extension: str = Field(max_length=20)
|
||||
"""扩展名(小写,无点号)"""
|
||||
|
||||
app_id: UUID = Field(foreign_key="fileapp.id", index=True, ondelete="CASCADE")
|
||||
"""关联的应用UUID"""
|
||||
|
||||
# 关系
|
||||
app: FileApp = Relationship(back_populates="user_defaults")
|
||||
|
||||
def to_response(self) -> UserFileAppDefaultResponse:
|
||||
"""转换为响应 DTO(需预加载 app 关系)"""
|
||||
return UserFileAppDefaultResponse(
|
||||
id=self.id,
|
||||
extension=self.extension,
|
||||
app=self.app.to_summary(),
|
||||
)
|
||||
@@ -4,8 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, text
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin, UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
@@ -17,6 +17,7 @@ async def migration() -> None:
|
||||
await init_default_group()
|
||||
await init_default_user()
|
||||
await init_default_theme_presets()
|
||||
await init_default_file_apps()
|
||||
|
||||
log.info('数据库初始化结束')
|
||||
|
||||
@@ -372,3 +373,146 @@ async def init_default_theme_presets() -> None:
|
||||
)
|
||||
await default_preset.save(session)
|
||||
log.info('已创建默认主题预设')
|
||||
|
||||
|
||||
# ==================== 默认文件查看器应用种子数据 ====================
|
||||
|
||||
_DEFAULT_FILE_APPS: list[dict] = [
|
||||
# 内置应用(type=builtin,默认启用)
|
||||
{
|
||||
"name": "PDF 阅读器",
|
||||
"app_key": "pdfjs",
|
||||
"type": "builtin",
|
||||
"icon": "file-pdf",
|
||||
"description": "基于 pdf.js 的 PDF 在线阅读器",
|
||||
"is_enabled": True,
|
||||
"extensions": ["pdf"],
|
||||
},
|
||||
{
|
||||
"name": "代码编辑器",
|
||||
"app_key": "monaco",
|
||||
"type": "builtin",
|
||||
"icon": "code",
|
||||
"description": "基于 Monaco Editor 的代码编辑器",
|
||||
"is_enabled": True,
|
||||
"extensions": [
|
||||
"txt", "md", "json", "xml", "yaml", "yml",
|
||||
"py", "js", "ts", "jsx", "tsx",
|
||||
"html", "css", "scss", "less",
|
||||
"sh", "bash", "zsh",
|
||||
"c", "cpp", "h", "hpp",
|
||||
"java", "kt", "go", "rs", "rb",
|
||||
"sql", "graphql",
|
||||
"toml", "ini", "cfg", "conf",
|
||||
"env", "gitignore", "dockerfile",
|
||||
"vue", "svelte",
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Markdown 预览",
|
||||
"app_key": "markdown",
|
||||
"type": "builtin",
|
||||
"icon": "markdown",
|
||||
"description": "Markdown 实时预览",
|
||||
"is_enabled": True,
|
||||
"extensions": ["md", "markdown", "mdx"],
|
||||
},
|
||||
{
|
||||
"name": "图片查看器",
|
||||
"app_key": "image_viewer",
|
||||
"type": "builtin",
|
||||
"icon": "image",
|
||||
"description": "图片在线查看器",
|
||||
"is_enabled": True,
|
||||
"extensions": ["jpg", "jpeg", "png", "gif", "bmp", "webp", "svg", "ico", "avif"],
|
||||
},
|
||||
{
|
||||
"name": "视频播放器",
|
||||
"app_key": "video_player",
|
||||
"type": "builtin",
|
||||
"icon": "video",
|
||||
"description": "HTML5 视频播放器",
|
||||
"is_enabled": True,
|
||||
"extensions": ["mp4", "webm", "ogg", "mov", "mkv", "m3u8"],
|
||||
},
|
||||
{
|
||||
"name": "音频播放器",
|
||||
"app_key": "audio_player",
|
||||
"type": "builtin",
|
||||
"icon": "audio",
|
||||
"description": "HTML5 音频播放器",
|
||||
"is_enabled": True,
|
||||
"extensions": ["mp3", "wav", "ogg", "flac", "aac", "m4a", "opus"],
|
||||
},
|
||||
# iframe 应用(默认禁用)
|
||||
{
|
||||
"name": "Office 在线预览",
|
||||
"app_key": "office_viewer",
|
||||
"type": "iframe",
|
||||
"icon": "file-word",
|
||||
"description": "使用 Microsoft Office Online 预览文档",
|
||||
"is_enabled": False,
|
||||
"iframe_url_template": "https://view.officeapps.live.com/op/embed.aspx?src={file_url}",
|
||||
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"],
|
||||
},
|
||||
# WOPI 应用(默认禁用)
|
||||
{
|
||||
"name": "Collabora Online",
|
||||
"app_key": "collabora",
|
||||
"type": "wopi",
|
||||
"icon": "file-text",
|
||||
"description": "Collabora Online 文档编辑器(需自行部署)",
|
||||
"is_enabled": False,
|
||||
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx", "odt", "ods", "odp"],
|
||||
},
|
||||
{
|
||||
"name": "OnlyOffice",
|
||||
"app_key": "onlyoffice",
|
||||
"type": "wopi",
|
||||
"icon": "file-text",
|
||||
"description": "OnlyOffice 文档编辑器(需自行部署)",
|
||||
"is_enabled": False,
|
||||
"extensions": ["doc", "docx", "xls", "xlsx", "ppt", "pptx"],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def init_default_file_apps() -> None:
|
||||
"""初始化默认文件查看器应用"""
|
||||
from .file_app import FileApp, FileAppExtension, FileAppType
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
log.info('初始化文件查看器应用...')
|
||||
|
||||
async for session in DatabaseManager.get_session():
|
||||
# 已存在应用则跳过
|
||||
existing_count = await FileApp.count(session)
|
||||
if existing_count > 0:
|
||||
return
|
||||
|
||||
for app_data in _DEFAULT_FILE_APPS:
|
||||
extensions = app_data.pop("extensions")
|
||||
|
||||
app = FileApp(
|
||||
name=app_data["name"],
|
||||
app_key=app_data["app_key"],
|
||||
type=FileAppType(app_data["type"]),
|
||||
icon=app_data.get("icon"),
|
||||
description=app_data.get("description"),
|
||||
is_enabled=app_data.get("is_enabled", True),
|
||||
iframe_url_template=app_data.get("iframe_url_template"),
|
||||
wopi_discovery_url=app_data.get("wopi_discovery_url"),
|
||||
wopi_editor_url_template=app_data.get("wopi_editor_url_template"),
|
||||
)
|
||||
app = await app.save(session)
|
||||
app_id = app.id
|
||||
|
||||
for i, ext in enumerate(extensions):
|
||||
ext_record = FileAppExtension(
|
||||
app_id=app_id,
|
||||
extension=ext.lower(),
|
||||
priority=i,
|
||||
)
|
||||
await ext_record.save(session)
|
||||
|
||||
log.info(f'已创建 {len(_DEFAULT_FILE_APPS)} 个默认文件查看器应用')
|
||||
|
||||
@@ -1,543 +0,0 @@
|
||||
# SQLModel Mixin Module
|
||||
|
||||
This module provides composable Mixin classes for SQLModel entities, enabling reusable functionality such as CRUD operations, polymorphic inheritance, JWT authentication, and standardized response DTOs.
|
||||
|
||||
## Module Overview
|
||||
|
||||
The `sqlmodels.mixin` module contains various Mixin classes that follow the "Composition over Inheritance" design philosophy. These mixins provide:
|
||||
|
||||
- **CRUD Operations**: Async database operations (add, save, update, delete, get, count)
|
||||
- **Polymorphic Inheritance**: Tools for joined table inheritance patterns
|
||||
- **JWT Authentication**: Token generation and validation
|
||||
- **Pagination & Sorting**: Standardized table view parameters
|
||||
- **Response DTOs**: Consistent id/timestamp fields for API responses
|
||||
|
||||
## Module Structure
|
||||
|
||||
```
|
||||
sqlmodels/mixin/
|
||||
├── __init__.py # Module exports
|
||||
├── polymorphic.py # PolymorphicBaseMixin, create_subclass_id_mixin, AutoPolymorphicIdentityMixin
|
||||
├── table.py # TableBaseMixin, UUIDTableBaseMixin, TableViewRequest
|
||||
├── info_response.py # Response DTO Mixins (IntIdInfoMixin, UUIDIdInfoMixin, etc.)
|
||||
└── jwt/ # JWT authentication
|
||||
├── __init__.py
|
||||
├── key.py # JWTKey database model
|
||||
├── payload.py # JWTPayloadBase
|
||||
├── manager.py # JWTManager singleton
|
||||
├── auth.py # JWTAuthMixin
|
||||
├── exceptions.py # JWT-related exceptions
|
||||
└── responses.py # TokenResponse DTO
|
||||
```
|
||||
|
||||
## Dependency Hierarchy
|
||||
|
||||
The module has a strict import order to avoid circular dependencies:
|
||||
|
||||
1. **polymorphic.py** - Only depends on `SQLModelBase`
|
||||
2. **table.py** - Depends on `polymorphic.py`
|
||||
3. **jwt/** - May depend on both `polymorphic.py` and `table.py`
|
||||
4. **info_response.py** - Only depends on `SQLModelBase`
|
||||
|
||||
## Core Components
|
||||
|
||||
### 1. TableBaseMixin
|
||||
|
||||
Base mixin for database table models with integer primary keys.
|
||||
|
||||
**Features:**
|
||||
- Provides CRUD methods: `add()`, `save()`, `update()`, `delete()`, `get()`, `count()`, `get_exist_one()`
|
||||
- Automatic timestamp management (`created_at`, `updated_at`)
|
||||
- Async relationship loading support (via `AsyncAttrs`)
|
||||
- Pagination and sorting via `TableViewRequest`
|
||||
- Polymorphic subclass loading support
|
||||
|
||||
**Fields:**
|
||||
- `id: int | None` - Integer primary key (auto-increment)
|
||||
- `created_at: datetime` - Record creation timestamp
|
||||
- `updated_at: datetime` - Record update timestamp (auto-updated)
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import TableBaseMixin
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
class User(SQLModelBase, TableBaseMixin, table=True):
|
||||
name: str
|
||||
email: str
|
||||
"""User email"""
|
||||
|
||||
# CRUD operations
|
||||
async def example(session: AsyncSession):
|
||||
# Add
|
||||
user = User(name="Alice", email="alice@example.com")
|
||||
user = await user.save(session)
|
||||
|
||||
# Get
|
||||
user = await User.get(session, User.id == 1)
|
||||
|
||||
# Update
|
||||
update_data = UserUpdateRequest(name="Alice Smith")
|
||||
user = await user.update(session, update_data)
|
||||
|
||||
# Delete
|
||||
await User.delete(session, user)
|
||||
|
||||
# Count
|
||||
count = await User.count(session, User.is_active == True)
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- `save()` and `update()` return refreshed instances - **always use the return value**:
|
||||
```python
|
||||
# ✅ Correct
|
||||
device = await device.save(session)
|
||||
return device
|
||||
|
||||
# ❌ Wrong - device is expired after commit
|
||||
await device.save(session)
|
||||
return device
|
||||
```
|
||||
|
||||
### 2. UUIDTableBaseMixin
|
||||
|
||||
Extends `TableBaseMixin` with UUID primary keys instead of integers.
|
||||
|
||||
**Differences from TableBaseMixin:**
|
||||
- `id: UUID` - UUID primary key (auto-generated via `uuid.uuid4()`)
|
||||
- `get_exist_one()` accepts `UUID` instead of `int`
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class Character(SQLModelBase, UUIDTableBaseMixin, table=True):
|
||||
name: str
|
||||
description: str | None = None
|
||||
"""Character description"""
|
||||
```
|
||||
|
||||
**Recommendation:** Use `UUIDTableBaseMixin` for most new models, as UUIDs provide better scalability and avoid ID collisions.
|
||||
|
||||
### 3. TableViewRequest
|
||||
|
||||
Standardized pagination and sorting parameters for LIST endpoints.
|
||||
|
||||
**Fields:**
|
||||
- `offset: int | None` - Skip first N records (default: 0)
|
||||
- `limit: int | None` - Return max N records (default: 50, max: 100)
|
||||
- `desc: bool | None` - Sort descending (default: True)
|
||||
- `order: Literal["created_at", "updated_at"] | None` - Sort field (default: "created_at")
|
||||
|
||||
**Usage with TableBaseMixin.get():**
|
||||
```python
|
||||
from dependencies import TableViewRequestDep
|
||||
|
||||
@router.get("/list")
|
||||
async def list_characters(
|
||||
session: SessionDep,
|
||||
table_view: TableViewRequestDep
|
||||
) -> list[Character]:
|
||||
"""List characters with pagination and sorting"""
|
||||
return await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
table_view=table_view # Automatically handles pagination and sorting
|
||||
)
|
||||
```
|
||||
|
||||
**Manual usage:**
|
||||
```python
|
||||
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
|
||||
characters = await Character.get(session, fetch_mode="all", table_view=table_view)
|
||||
```
|
||||
|
||||
**Backward Compatibility:**
|
||||
The traditional `offset`, `limit`, `order_by` parameters still work, but `table_view` is recommended for new code.
|
||||
|
||||
### 4. PolymorphicBaseMixin
|
||||
|
||||
Base mixin for joined table inheritance, automatically configuring polymorphic settings.
|
||||
|
||||
**Automatic Configuration:**
|
||||
- Defines `_polymorphic_name: str` field (indexed)
|
||||
- Sets `polymorphic_on='_polymorphic_name'`
|
||||
- Detects abstract classes (via ABC and abstract methods) and sets `polymorphic_abstract=True`
|
||||
|
||||
**Methods:**
|
||||
- `get_concrete_subclasses()` - Get all non-abstract subclasses (for `selectin_polymorphic`)
|
||||
- `get_polymorphic_discriminator()` - Get the polymorphic discriminator field name
|
||||
- `get_identity_to_class_map()` - Map `polymorphic_identity` to subclass types
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlmodels.mixin import PolymorphicBaseMixin, UUIDTableBaseMixin
|
||||
|
||||
class Tool(PolymorphicBaseMixin, UUIDTableBaseMixin, ABC):
|
||||
"""Abstract base class for all tools"""
|
||||
name: str
|
||||
description: str
|
||||
"""Tool description"""
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, params: dict) -> dict:
|
||||
"""Execute the tool"""
|
||||
pass
|
||||
```
|
||||
|
||||
**Why Single Underscore Prefix?**
|
||||
- SQLAlchemy maps single-underscore fields to database columns
|
||||
- Pydantic treats them as private (excluded from serialization)
|
||||
- Double-underscore fields would be excluded by SQLAlchemy (not mapped to database)
|
||||
|
||||
### 5. create_subclass_id_mixin()
|
||||
|
||||
Factory function to create ID mixins for subclasses in joined table inheritance.
|
||||
|
||||
**Purpose:** In joined table inheritance, subclasses need a foreign key pointing to the parent table's primary key. This function generates a mixin class providing that foreign key field.
|
||||
|
||||
**Signature:**
|
||||
```python
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type[SQLModelBase]:
|
||||
"""
|
||||
Args:
|
||||
parent_table_name: Parent table name (e.g., 'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
A mixin class containing id field (foreign key + primary key)
|
||||
"""
|
||||
```
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import create_subclass_id_mixin
|
||||
|
||||
# Create mixin for ASR subclasses
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
class FunASR(ASRSubclassIdMixin, ASR, AutoPolymorphicIdentityMixin, table=True):
|
||||
"""FunASR implementation"""
|
||||
pass
|
||||
```
|
||||
|
||||
**Important:** The ID mixin **must be first in the inheritance list** to ensure MRO (Method Resolution Order) correctly overrides the parent's `id` field.
|
||||
|
||||
### 6. AutoPolymorphicIdentityMixin
|
||||
|
||||
Automatically generates `polymorphic_identity` based on class name.
|
||||
|
||||
**Naming Convention:**
|
||||
- Format: `{parent_identity}.{classname_lowercase}`
|
||||
- If no parent identity exists, uses `{classname_lowercase}`
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import AutoPolymorphicIdentityMixin
|
||||
|
||||
class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
"""Base class for function-type tools"""
|
||||
pass
|
||||
# polymorphic_identity = 'function'
|
||||
|
||||
class GetWeatherFunction(Function, table=True):
|
||||
"""Weather query function"""
|
||||
pass
|
||||
# polymorphic_identity = 'function.getweatherfunction'
|
||||
```
|
||||
|
||||
**Manual Override:**
|
||||
```python
|
||||
class CustomTool(
|
||||
Tool,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_identity='custom_name', # Override auto-generated name
|
||||
table=True
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
### 7. JWTAuthMixin
|
||||
|
||||
Provides JWT token generation and validation for entity classes (User, Client).
|
||||
|
||||
**Methods:**
|
||||
- `async issue_jwt(session: AsyncSession) -> str` - Generate JWT token for current instance
|
||||
- `@classmethod async from_jwt(session: AsyncSession, token: str) -> Self` - Validate token and retrieve entity
|
||||
|
||||
**Requirements:**
|
||||
Subclasses must define:
|
||||
- `JWTPayload` - Payload model (inherits from `JWTPayloadBase`)
|
||||
- `jwt_key_purpose` - ClassVar specifying the JWT key purpose enum value
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import JWTAuthMixin, UUIDTableBaseMixin
|
||||
|
||||
class User(SQLModelBase, UUIDTableBaseMixin, JWTAuthMixin, table=True):
|
||||
JWTPayload = UserJWTPayload # Define payload model
|
||||
jwt_key_purpose: ClassVar[JWTKeyPurposeEnum] = JWTKeyPurposeEnum.user
|
||||
|
||||
email: str
|
||||
is_admin: bool = False
|
||||
is_active: bool = True
|
||||
"""User active status"""
|
||||
|
||||
# Generate token
|
||||
async def login(session: AsyncSession, user: User) -> str:
|
||||
token = await user.issue_jwt(session)
|
||||
return token
|
||||
|
||||
# Validate token
|
||||
async def verify(session: AsyncSession, token: str) -> User:
|
||||
user = await User.from_jwt(session, token)
|
||||
return user
|
||||
```
|
||||
|
||||
### 8. Response DTO Mixins
|
||||
|
||||
Mixins for standardized InfoResponse DTOs, defining id and timestamp fields.
|
||||
|
||||
**Available Mixins:**
|
||||
- `IntIdInfoMixin` - Integer ID field
|
||||
- `UUIDIdInfoMixin` - UUID ID field
|
||||
- `DatetimeInfoMixin` - `created_at` and `updated_at` fields
|
||||
- `IntIdDatetimeInfoMixin` - Integer ID + timestamps
|
||||
- `UUIDIdDatetimeInfoMixin` - UUID ID + timestamps
|
||||
|
||||
**Design Note:** These fields are non-nullable in DTOs because database records always have these values when returned.
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDIdDatetimeInfoMixin
|
||||
|
||||
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
|
||||
"""Character response DTO with id and timestamps"""
|
||||
pass # Inherits id, created_at, updated_at from mixin
|
||||
```
|
||||
|
||||
## Complete Joined Table Inheritance Example
|
||||
|
||||
Here's a complete example demonstrating polymorphic inheritance:
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import (
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
)
|
||||
|
||||
# 1. Define Base class (fields only, no table)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
"""Configuration name"""
|
||||
|
||||
base_url: str
|
||||
"""Service URL"""
|
||||
|
||||
# 2. Define abstract parent class (with table)
|
||||
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
"""Abstract base class for ASR configurations"""
|
||||
# PolymorphicBaseMixin automatically provides:
|
||||
# - _polymorphic_name field
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True (when ABC with abstract methods)
|
||||
|
||||
@abstractmethod
|
||||
async def transcribe(self, pcm_data: bytes) -> str:
|
||||
"""Transcribe audio to text"""
|
||||
pass
|
||||
|
||||
# 3. Create ID Mixin for second-level subclasses
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. Create second-level abstract class (if needed)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
"""FunASR abstract base (may have multiple implementations)"""
|
||||
pass
|
||||
# polymorphic_identity = 'funasr'
|
||||
|
||||
# 5. Create concrete implementation classes
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
"""FunASR local deployment"""
|
||||
# polymorphic_identity = 'funasr.funasrlocal'
|
||||
|
||||
async def transcribe(self, pcm_data: bytes) -> str:
|
||||
# Implementation...
|
||||
return "transcribed text"
|
||||
|
||||
# 6. Get all concrete subclasses (for selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# Returns: [FunASRLocal, ...]
|
||||
```
|
||||
|
||||
## Import Guidelines
|
||||
|
||||
**Standard Import:**
|
||||
```python
|
||||
from sqlmodels.mixin import (
|
||||
TableBaseMixin,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
TableViewRequest,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
JWTAuthMixin,
|
||||
UUIDIdDatetimeInfoMixin,
|
||||
now,
|
||||
now_date,
|
||||
)
|
||||
```
|
||||
|
||||
**Backward Compatibility:**
|
||||
Some exports are also available from `sqlmodels.base` for backward compatibility:
|
||||
```python
|
||||
# Legacy import path (still works)
|
||||
from sqlmodels.base import UUIDTableBase, TableViewRequest
|
||||
|
||||
# Recommended new import path
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin, TableViewRequest
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Mixin Order Matters
|
||||
|
||||
**Correct Order:**
|
||||
```python
|
||||
# ✅ ID Mixin first, then parent, then AutoPolymorphicIdentityMixin
|
||||
class SubTool(ToolSubclassIdMixin, Tool, AutoPolymorphicIdentityMixin, table=True):
|
||||
pass
|
||||
```
|
||||
|
||||
**Wrong Order:**
|
||||
```python
|
||||
# ❌ ID Mixin not first - won't override parent's id field
|
||||
class SubTool(Tool, ToolSubclassIdMixin, AutoPolymorphicIdentityMixin, table=True):
|
||||
pass
|
||||
```
|
||||
|
||||
### 2. Always Use Return Values from save() and update()
|
||||
|
||||
```python
|
||||
# ✅ Correct - use returned instance
|
||||
device = await device.save(session)
|
||||
return device
|
||||
|
||||
# ❌ Wrong - device is expired after commit
|
||||
await device.save(session)
|
||||
return device # AttributeError when accessing fields
|
||||
```
|
||||
|
||||
### 3. Prefer table_view Over Manual Pagination
|
||||
|
||||
```python
|
||||
# ✅ Recommended - consistent across all endpoints
|
||||
characters = await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
table_view=table_view
|
||||
)
|
||||
|
||||
# ⚠️ Works but not recommended - manual parameter management
|
||||
characters = await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
offset=0,
|
||||
limit=20,
|
||||
order_by=[desc(Character.created_at)]
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Polymorphic Loading for Many Subclasses
|
||||
|
||||
```python
|
||||
# When loading relationships with > 10 polymorphic subclasses, use load_polymorphic='all'
|
||||
tool_set = await ToolSet.get(
|
||||
session,
|
||||
ToolSet.id == tool_set_id,
|
||||
load=ToolSet.tools,
|
||||
load_polymorphic='all' # Two-phase query - only loads actual related subclasses
|
||||
)
|
||||
|
||||
# For fewer subclasses, specify the list explicitly
|
||||
tool_set = await ToolSet.get(
|
||||
session,
|
||||
ToolSet.id == tool_set_id,
|
||||
load=ToolSet.tools,
|
||||
load_polymorphic=[GetWeatherFunction, CodeInterpreterFunction]
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Response DTOs Should Inherit Base Classes
|
||||
|
||||
```python
|
||||
# ✅ Correct - inherits from CharacterBase
|
||||
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
|
||||
pass
|
||||
|
||||
# ❌ Wrong - doesn't inherit from CharacterBase
|
||||
class CharacterInfoResponse(SQLModelBase, UUIDIdDatetimeInfoMixin):
|
||||
name: str # Duplicated field definition
|
||||
description: str | None = None
|
||||
```
|
||||
|
||||
**Reason:** Inheriting from Base classes ensures:
|
||||
- Type checking via `isinstance(obj, XxxBase)`
|
||||
- Consistency across related DTOs
|
||||
- Future field additions automatically propagate
|
||||
|
||||
### 6. Use Specific Types, Not Containers
|
||||
|
||||
```python
|
||||
# ✅ Correct - specific DTO for config updates
|
||||
class GetWeatherFunctionUpdateRequest(GetWeatherFunctionConfigBase):
|
||||
weather_api_key: str | None = None
|
||||
default_location: str | None = None
|
||||
"""Default location"""
|
||||
|
||||
# ❌ Wrong - lose type safety
|
||||
class ToolUpdateRequest(SQLModelBase):
|
||||
config: dict[str, Any] # No field validation
|
||||
```
|
||||
|
||||
## Type Variables
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import T, M
|
||||
|
||||
T = TypeVar("T", bound="TableBaseMixin") # For CRUD methods
|
||||
M = TypeVar("M", bound="SQLModel") # For update() method
|
||||
```
|
||||
|
||||
## Utility Functions
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import now, now_date
|
||||
|
||||
# Lambda functions for default factories
|
||||
now = lambda: datetime.now()
|
||||
now_date = lambda: datetime.now().date()
|
||||
```
|
||||
|
||||
## Related Modules
|
||||
|
||||
- **sqlmodels.base** - Base classes (`SQLModelBase`, backward-compatible exports)
|
||||
- **dependencies** - FastAPI dependencies (`SessionDep`, `TableViewRequestDep`)
|
||||
- **sqlmodels.user** - User model with JWT authentication
|
||||
- **sqlmodels.client** - Client model with JWT authentication
|
||||
- **sqlmodels.character.llm.openai_compatibles.tools** - Polymorphic tool hierarchy
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- `POLYMORPHIC_NAME_DESIGN.md` - Design rationale for `_polymorphic_name` field
|
||||
- `CLAUDE.md` - Project coding standards and design philosophy
|
||||
- SQLAlchemy Documentation - [Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
|
||||
@@ -1,62 +0,0 @@
|
||||
"""
|
||||
SQLModel Mixin模块
|
||||
|
||||
提供各种Mixin类供SQLModel实体使用。
|
||||
|
||||
包含:
|
||||
- polymorphic: 联表继承工具(create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin)
|
||||
- optimistic_lock: 乐观锁(OptimisticLockMixin, OptimisticLockError)
|
||||
- table: 表基类(TableBaseMixin, UUIDTableBaseMixin)
|
||||
- table: 查询参数类(TimeFilterRequest, PaginationRequest, TableViewRequest)
|
||||
- relation_preload: 关系预加载(RelationPreloadMixin, requires_relations)
|
||||
- jwt/: JWT认证相关(JWTAuthMixin, JWTManager, JWTKey等)- 需要时直接从 .jwt 导入
|
||||
- info_response: InfoResponse DTO的id/时间戳Mixin
|
||||
|
||||
导入顺序很重要,避免循环导入:
|
||||
1. polymorphic(只依赖 SQLModelBase)
|
||||
2. optimistic_lock(只依赖 SQLAlchemy)
|
||||
3. table(依赖 polymorphic 和 optimistic_lock)
|
||||
4. relation_preload(只依赖 SQLModelBase)
|
||||
|
||||
注意:jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig,
|
||||
而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
|
||||
"""
|
||||
# polymorphic 必须先导入
|
||||
from .polymorphic import (
|
||||
AutoPolymorphicIdentityMixin,
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
register_sti_column_properties_for_all_subclasses,
|
||||
register_sti_columns_for_all_subclasses,
|
||||
)
|
||||
# optimistic_lock 只依赖 SQLAlchemy,必须在 table 之前
|
||||
from .optimistic_lock import (
|
||||
OptimisticLockError,
|
||||
OptimisticLockMixin,
|
||||
)
|
||||
# table 依赖 polymorphic 和 optimistic_lock
|
||||
from .table import (
|
||||
ListResponse,
|
||||
PaginationRequest,
|
||||
T,
|
||||
TableBaseMixin,
|
||||
TableViewRequest,
|
||||
TimeFilterRequest,
|
||||
UUIDTableBaseMixin,
|
||||
now,
|
||||
now_date,
|
||||
)
|
||||
# relation_preload 只依赖 SQLModelBase
|
||||
from .relation_preload import (
|
||||
RelationPreloadMixin,
|
||||
requires_relations,
|
||||
)
|
||||
# jwt 不在此处导入(避免循环:jwt/manager.py → ServerConfig → mixin → jwt)
|
||||
# 需要时直接从 sqlmodels.mixin.jwt 导入
|
||||
from .info_response import (
|
||||
DatetimeInfoMixin,
|
||||
IntIdDatetimeInfoMixin,
|
||||
IntIdInfoMixin,
|
||||
UUIDIdDatetimeInfoMixin,
|
||||
UUIDIdInfoMixin,
|
||||
)
|
||||
@@ -1,46 +0,0 @@
|
||||
"""
|
||||
InfoResponse DTO Mixin模块
|
||||
|
||||
提供用于InfoResponse类型DTO的Mixin,统一定义id/created_at/updated_at字段。
|
||||
|
||||
设计说明:
|
||||
- 这些Mixin用于**响应DTO**,不是数据库表
|
||||
- 从数据库返回时这些字段永远不为空,所以定义为必填字段
|
||||
- TableBase中的id=None和default_factory=now是正确的(入库前为None,数据库生成)
|
||||
- 这些Mixin让DTO明确表示"返回给客户端时这些字段必定有值"
|
||||
"""
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
|
||||
class IntIdInfoMixin(SQLModelBase):
|
||||
"""整数ID响应mixin - 用于InfoResponse DTO"""
|
||||
id: int
|
||||
"""记录ID"""
|
||||
|
||||
|
||||
class UUIDIdInfoMixin(SQLModelBase):
|
||||
"""UUID ID响应mixin - 用于InfoResponse DTO"""
|
||||
id: UUID
|
||||
"""记录ID"""
|
||||
|
||||
|
||||
class DatetimeInfoMixin(SQLModelBase):
|
||||
"""时间戳响应mixin - 用于InfoResponse DTO"""
|
||||
created_at: datetime
|
||||
"""创建时间"""
|
||||
|
||||
updated_at: datetime
|
||||
"""更新时间"""
|
||||
|
||||
|
||||
class IntIdDatetimeInfoMixin(IntIdInfoMixin, DatetimeInfoMixin):
|
||||
"""整数ID + 时间戳响应mixin"""
|
||||
pass
|
||||
|
||||
|
||||
class UUIDIdDatetimeInfoMixin(UUIDIdInfoMixin, DatetimeInfoMixin):
|
||||
"""UUID ID + 时间戳响应mixin"""
|
||||
pass
|
||||
@@ -1,90 +0,0 @@
|
||||
"""
|
||||
乐观锁 Mixin
|
||||
|
||||
提供基于 SQLAlchemy version_id_col 机制的乐观锁支持。
|
||||
|
||||
乐观锁适用场景:
|
||||
- 涉及"状态转换"的表(如:待支付 -> 已支付)
|
||||
- 涉及"数值变动"的表(如:余额、库存)
|
||||
|
||||
不适用场景:
|
||||
- 日志表、纯插入表、低价值统计表
|
||||
- 能用 UPDATE table SET col = col + 1 解决的简单计数问题
|
||||
|
||||
使用示例:
|
||||
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
|
||||
status: OrderStatusEnum
|
||||
amount: Decimal
|
||||
|
||||
# save/update 时自动检查版本号
|
||||
# 如果版本号不匹配(其他事务已修改),会抛出 OptimisticLockError
|
||||
try:
|
||||
order = await order.save(session)
|
||||
except OptimisticLockError as e:
|
||||
# 处理冲突:重新查询并重试,或报错给用户
|
||||
l.warning(f"乐观锁冲突: {e}")
|
||||
"""
|
||||
from typing import ClassVar
|
||||
|
||||
from sqlalchemy.orm.exc import StaleDataError
|
||||
|
||||
|
||||
class OptimisticLockError(Exception):
|
||||
"""
|
||||
乐观锁冲突异常
|
||||
|
||||
当 save/update 操作检测到版本号不匹配时抛出。
|
||||
这意味着在读取和写入之间,其他事务已经修改了该记录。
|
||||
|
||||
Attributes:
|
||||
model_class: 发生冲突的模型类名
|
||||
record_id: 记录 ID(如果可用)
|
||||
expected_version: 期望的版本号(如果可用)
|
||||
original_error: 原始的 StaleDataError
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
model_class: str | None = None,
|
||||
record_id: str | None = None,
|
||||
expected_version: int | None = None,
|
||||
original_error: StaleDataError | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.model_class = model_class
|
||||
self.record_id = record_id
|
||||
self.expected_version = expected_version
|
||||
self.original_error = original_error
|
||||
|
||||
|
||||
class OptimisticLockMixin:
|
||||
"""
|
||||
乐观锁 Mixin
|
||||
|
||||
使用 SQLAlchemy 的 version_id_col 机制实现乐观锁。
|
||||
每次 UPDATE 时自动检查并增加版本号,如果版本号不匹配(即其他事务已修改),
|
||||
session.commit() 会抛出 StaleDataError,被 save/update 方法捕获并转换为 OptimisticLockError。
|
||||
|
||||
原理:
|
||||
1. 每条记录有一个 version 字段,初始值为 0
|
||||
2. 每次 UPDATE 时,SQLAlchemy 生成的 SQL 类似:
|
||||
UPDATE table SET ..., version = version + 1 WHERE id = ? AND version = ?
|
||||
3. 如果 WHERE 条件不匹配(version 已被其他事务修改),
|
||||
UPDATE 影响 0 行,SQLAlchemy 抛出 StaleDataError
|
||||
|
||||
继承顺序:
|
||||
OptimisticLockMixin 必须放在 TableBaseMixin/UUIDTableBaseMixin 之前:
|
||||
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
|
||||
...
|
||||
|
||||
配套重试:
|
||||
如果加了乐观锁,业务层需要处理 OptimisticLockError:
|
||||
- 报错给用户:"数据已被修改,请刷新后重试"
|
||||
- 自动重试:重新查询最新数据并再次尝试
|
||||
"""
|
||||
_has_optimistic_lock: ClassVar[bool] = True
|
||||
"""标记此类启用了乐观锁"""
|
||||
|
||||
version: int = 0
|
||||
"""乐观锁版本号,每次更新自动递增"""
|
||||
@@ -1,710 +0,0 @@
|
||||
"""
|
||||
联表继承(Joined Table Inheritance)的通用工具
|
||||
|
||||
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
|
||||
|
||||
Usage Example:
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import (
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
\"\"\"配置名称\"\"\"
|
||||
|
||||
base_url: str
|
||||
\"\"\"服务地址\"\"\"
|
||||
|
||||
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC
|
||||
):
|
||||
\"\"\"ASR配置的抽象基类\"\"\"
|
||||
# PolymorphicBaseMixin 自动提供:
|
||||
# - _polymorphic_name 字段
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True(当有抽象方法时)
|
||||
|
||||
# 3. 为第二层子类创建ID Mixin
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. 创建第二层抽象类(如果需要)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
\"\"\"FunASR的抽象基类,可能有多个实现\"\"\"
|
||||
pass
|
||||
|
||||
# 5. 创建具体实现类
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
\"\"\"FunASR本地部署版本\"\"\"
|
||||
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
|
||||
pass
|
||||
|
||||
# 6. 获取所有具体子类(用于 selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# 返回 [FunASRLocal, ...]
|
||||
"""
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from sqlalchemy import Column, String, inspect
|
||||
from sqlalchemy.orm import ColumnProperty, Mapped, mapped_column
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlmodel import Field
|
||||
from sqlmodel.main import get_column_from_field
|
||||
|
||||
from sqlmodels.base.sqlmodel_base import SQLModelBase
|
||||
|
||||
# 用于延迟注册 STI 子类列的队列
|
||||
# 在所有模型加载完成后,调用 register_sti_columns_for_all_subclasses() 处理
|
||||
_sti_subclasses_to_register: list[type] = []
|
||||
|
||||
|
||||
def register_sti_columns_for_all_subclasses() -> None:
|
||||
"""
|
||||
为所有已注册的 STI 子类执行列注册(第一阶段:添加列到表)
|
||||
|
||||
此函数应在 configure_mappers() 之前调用。
|
||||
将 STI 子类的字段添加到父表的 metadata 中。
|
||||
同时修复被 Column 对象污染的 model_fields。
|
||||
"""
|
||||
for cls in _sti_subclasses_to_register:
|
||||
try:
|
||||
cls._register_sti_columns()
|
||||
except Exception as e:
|
||||
l.warning(f"注册 STI 子类 {cls.__name__} 的列时出错: {e}")
|
||||
|
||||
# 修复被 Column 对象污染的 model_fields
|
||||
# 必须在列注册后立即修复,因为 Column 污染在类定义时就已发生
|
||||
try:
|
||||
_fix_polluted_model_fields(cls)
|
||||
except Exception as e:
|
||||
l.warning(f"修复 STI 子类 {cls.__name__} 的 model_fields 时出错: {e}")
|
||||
|
||||
|
||||
def register_sti_column_properties_for_all_subclasses() -> None:
|
||||
"""
|
||||
为所有已注册的 STI 子类添加列属性到 mapper(第二阶段)
|
||||
|
||||
此函数应在 configure_mappers() 之后调用。
|
||||
将 STI 子类的字段作为 ColumnProperty 添加到 mapper 中。
|
||||
"""
|
||||
for cls in _sti_subclasses_to_register:
|
||||
try:
|
||||
cls._register_sti_column_properties()
|
||||
except Exception as e:
|
||||
l.warning(f"注册 STI 子类 {cls.__name__} 的列属性时出错: {e}")
|
||||
|
||||
# 清空队列
|
||||
_sti_subclasses_to_register.clear()
|
||||
|
||||
|
||||
def _fix_polluted_model_fields(cls: type) -> None:
|
||||
"""
|
||||
修复被 SQLAlchemy InstrumentedAttribute 或 Column 污染的 model_fields
|
||||
|
||||
当 SQLModel 类继承有表的父类时,SQLAlchemy 会在类上创建 InstrumentedAttribute
|
||||
或 Column 对象替换原始的字段默认值。这会导致 Pydantic 在构建子类 model_fields
|
||||
时错误地使用这些 SQLAlchemy 对象作为默认值。
|
||||
|
||||
此函数从 MRO 中查找原始的字段定义,并修复被污染的 model_fields。
|
||||
|
||||
:param cls: 要修复的类
|
||||
"""
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
def find_original_field_info(field_name: str) -> FieldInfo | None:
|
||||
"""从 MRO 中查找字段的原始定义(未被污染的)"""
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
||||
field_info = base.model_fields[field_name]
|
||||
# 跳过被 InstrumentedAttribute 或 Column 污染的
|
||||
if not isinstance(field_info.default, (InstrumentedAttribute, Column)):
|
||||
return field_info
|
||||
return None
|
||||
|
||||
for field_name, current_field in cls.model_fields.items():
|
||||
# 检查是否被污染(default 是 InstrumentedAttribute 或 Column)
|
||||
# Column 污染发生在 STI 继承链中:当 FunctionBase.show_arguments = True
|
||||
# 被继承到有表的子类时,SQLModel 会创建一个 Column 对象替换原始默认值
|
||||
if not isinstance(current_field.default, (InstrumentedAttribute, Column)):
|
||||
continue # 未被污染,跳过
|
||||
|
||||
# 从父类查找原始定义
|
||||
original = find_original_field_info(field_name)
|
||||
if original is None:
|
||||
continue # 找不到原始定义,跳过
|
||||
|
||||
# 根据原始定义的 default/default_factory 来修复
|
||||
if original.default_factory:
|
||||
# 有 default_factory(如 uuid.uuid4, now)
|
||||
new_field = FieldInfo(
|
||||
default_factory=original.default_factory,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
elif original.default is not PydanticUndefined:
|
||||
# 有明确的 default 值(如 None, 0, True),且不是 PydanticUndefined
|
||||
# PydanticUndefined 表示字段没有默认值(必填)
|
||||
new_field = FieldInfo(
|
||||
default=original.default,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
else:
|
||||
continue # 既没有 default_factory 也没有有效的 default,跳过
|
||||
|
||||
# 复制 SQLModel 特有的属性
|
||||
if hasattr(current_field, 'foreign_key'):
|
||||
new_field.foreign_key = current_field.foreign_key
|
||||
if hasattr(current_field, 'primary_key'):
|
||||
new_field.primary_key = current_field.primary_key
|
||||
|
||||
cls.model_fields[field_name] = new_field
|
||||
|
||||
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
|
||||
"""
|
||||
动态创建SubclassIdMixin类
|
||||
|
||||
在联表继承中,子类需要一个外键指向父表的主键。
|
||||
此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。
|
||||
|
||||
Args:
|
||||
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4)
|
||||
|
||||
Example:
|
||||
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
|
||||
... pass
|
||||
|
||||
Note:
|
||||
- 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id
|
||||
- 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase)
|
||||
- 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin)
|
||||
"""
|
||||
if not parent_table_name:
|
||||
raise ValueError("parent_table_name 不能为空")
|
||||
|
||||
# 转换为PascalCase作为类名
|
||||
class_name_parts = parent_table_name.split('_')
|
||||
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
|
||||
|
||||
# 使用闭包捕获parent_table_name
|
||||
_parent_table_name = parent_table_name
|
||||
|
||||
# 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields
|
||||
class SubclassIdMixin(SQLModelBase):
|
||||
# 定义id字段
|
||||
id: UUID = Field(
|
||||
default_factory=uuid.uuid4,
|
||||
foreign_key=f'{_parent_table_name}.id',
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复联表继承中子类字段的 default_factory 丢失问题。
|
||||
SQLAlchemy 的 InstrumentedAttribute 或 Column 会污染从父类继承的字段,
|
||||
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
_fix_polluted_model_fields(cls)
|
||||
|
||||
# 设置类名和文档
|
||||
SubclassIdMixin.__name__ = class_name
|
||||
SubclassIdMixin.__qualname__ = class_name
|
||||
SubclassIdMixin.__doc__ = f"""
|
||||
{parent_table_name}子类的ID Mixin
|
||||
|
||||
用于{parent_table_name}的子类,提供外键指向父表。
|
||||
通过MRO确保此id字段覆盖继承的id字段。
|
||||
"""
|
||||
|
||||
return SubclassIdMixin
|
||||
|
||||
|
||||
class AutoPolymorphicIdentityMixin:
|
||||
"""
|
||||
自动生成polymorphic_identity的Mixin,并支持STI子类列注册
|
||||
|
||||
使用此Mixin的类会自动根据类名生成polymorphic_identity。
|
||||
格式:{parent_polymorphic_identity}.{classname_lowercase}
|
||||
|
||||
如果没有父类的polymorphic_identity,则直接使用类名小写。
|
||||
|
||||
**重要:数据库迁移注意事项**
|
||||
|
||||
编写数据迁移脚本时,必须使用完整的 polymorphic_identity 格式,包括父类前缀!
|
||||
|
||||
例如,对于以下继承链::
|
||||
|
||||
LLM (polymorphic_on='_polymorphic_name')
|
||||
└── AnthropicCompatibleLLM (polymorphic_identity='anthropiccompatiblellm')
|
||||
└── TuziAnthropicLLM (polymorphic_identity='anthropiccompatiblellm.tuzianthropicllm')
|
||||
|
||||
迁移脚本中设置 _polymorphic_name 时::
|
||||
|
||||
# ❌ 错误:缺少父类前缀
|
||||
UPDATE llm SET _polymorphic_name = 'tuzianthropicllm' WHERE id = :id
|
||||
|
||||
# ✅ 正确:包含完整的继承链前缀
|
||||
UPDATE llm SET _polymorphic_name = 'anthropiccompatiblellm.tuzianthropicllm' WHERE id = :id
|
||||
|
||||
**STI(单表继承)支持**:
|
||||
当子类与父类共用同一张表(STI模式)时,此Mixin会自动将子类的新字段
|
||||
添加到父表的列定义中。这解决了SQLModel在STI模式下子类字段不被
|
||||
注册到父表的问题。
|
||||
|
||||
Example (JTI):
|
||||
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
|
||||
... __polymorphic_name: str
|
||||
...
|
||||
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function'
|
||||
...
|
||||
>>> class CodeInterpreterFunction(Function, table=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
|
||||
|
||||
Example (STI):
|
||||
>>> class UserFile(UUIDTableBaseMixin, PolymorphicBaseMixin, table=True, polymorphic_abstract=True):
|
||||
... user_id: UUID
|
||||
...
|
||||
>>> class PendingFile(UserFile, AutoPolymorphicIdentityMixin, table=True):
|
||||
... upload_deadline: datetime | None = None # 自动添加到 userfile 表
|
||||
... # polymorphic_identity 会自动设置为 'pendingfile'
|
||||
|
||||
Note:
|
||||
- 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留
|
||||
- 此Mixin应该在继承列表中靠后的位置(在表基类之前)
|
||||
- STI模式下,新字段会在类定义时自动添加到父表的metadata中
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
|
||||
"""
|
||||
子类化钩子,自动生成polymorphic_identity并处理STI列注册
|
||||
|
||||
Args:
|
||||
polymorphic_identity: 如果手动指定,则使用指定的值
|
||||
**kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True)
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 如果手动指定了polymorphic_identity,使用指定的值
|
||||
if polymorphic_identity is not None:
|
||||
identity = polymorphic_identity
|
||||
else:
|
||||
# 自动生成polymorphic_identity
|
||||
class_name = cls.__name__.lower()
|
||||
|
||||
# 尝试从父类获取polymorphic_identity作为前缀
|
||||
parent_identity = None
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
|
||||
parent_identity = base.__mapper_args__.get('polymorphic_identity')
|
||||
if parent_identity:
|
||||
break
|
||||
|
||||
# 构建identity
|
||||
if parent_identity:
|
||||
identity = f'{parent_identity}.{class_name}'
|
||||
else:
|
||||
identity = class_name
|
||||
|
||||
# 设置到__mapper_args__
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 只在尚未设置polymorphic_identity时设置
|
||||
if 'polymorphic_identity' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_identity'] = identity
|
||||
|
||||
# 注册 STI 子类列的延迟执行
|
||||
# 由于 __init_subclass__ 在类定义过程中被调用,此时 model_fields 还不完整
|
||||
# 需要在模块加载完成后调用 register_sti_columns_for_all_subclasses()
|
||||
_sti_subclasses_to_register.append(cls)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复 STI 继承中子类字段被 Column 对象污染的问题。
|
||||
当 FunctionBase.show_arguments = True 等字段被继承到有表的子类时,
|
||||
SQLModel 会创建一个 Column 对象替换原始默认值,导致实例化时字段值不正确。
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
_fix_polluted_model_fields(cls)
|
||||
|
||||
@classmethod
|
||||
def _register_sti_columns(cls) -> None:
|
||||
"""
|
||||
将STI子类的新字段注册到父表的列定义中
|
||||
|
||||
检测当前类是否是STI子类(与父类共用同一张表),
|
||||
如果是,则将子类定义的新字段添加到父表的metadata中。
|
||||
|
||||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
||||
"""
|
||||
# 查找父表(在 MRO 中找到第一个有 __table__ 的父类)
|
||||
parent_table = None
|
||||
parent_fields: set[str] = set()
|
||||
|
||||
for base in cls.__mro__[1:]:
|
||||
if hasattr(base, '__table__') and base.__table__ is not None:
|
||||
parent_table = base.__table__
|
||||
# 收集父类的所有字段名
|
||||
if hasattr(base, 'model_fields'):
|
||||
parent_fields.update(base.model_fields.keys())
|
||||
break
|
||||
|
||||
if parent_table is None:
|
||||
return # 没有找到父表,可能是根类
|
||||
|
||||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
||||
# JTI 类有自己独立的表,不需要将列注册到父表
|
||||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
||||
if cls.__table__.name != parent_table.name:
|
||||
return # JTI,跳过 STI 列注册
|
||||
|
||||
# 获取当前类的新字段(不在父类中的字段)
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
existing_columns = {col.name for col in parent_table.columns}
|
||||
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
# 跳过从父类继承的字段
|
||||
if field_name in parent_fields:
|
||||
continue
|
||||
|
||||
# 跳过私有字段和ClassVar
|
||||
if field_name.startswith('_'):
|
||||
continue
|
||||
|
||||
# 跳过已存在的列
|
||||
if field_name in existing_columns:
|
||||
continue
|
||||
|
||||
# 使用 SQLModel 的内置 API 创建列
|
||||
try:
|
||||
column = get_column_from_field(field_info)
|
||||
column.name = field_name
|
||||
column.key = field_name
|
||||
# STI子类字段在数据库层面必须可空,因为其他子类的行不会有这些字段的值
|
||||
# Pydantic层面的约束仍然有效(创建特定子类时会验证必填字段)
|
||||
column.nullable = True
|
||||
|
||||
# 将列添加到父表
|
||||
parent_table.append_column(column)
|
||||
except Exception as e:
|
||||
l.warning(f"为 {cls.__name__} 创建列 {field_name} 失败: {e}")
|
||||
|
||||
@classmethod
|
||||
def _register_sti_column_properties(cls) -> None:
|
||||
"""
|
||||
将 STI 子类的列作为 ColumnProperty 添加到 mapper
|
||||
|
||||
此方法在 configure_mappers() 之后调用,将已添加到表中的列
|
||||
注册为 mapper 的属性,使 ORM 查询能正确识别这些列。
|
||||
|
||||
**重要**:子类的列属性会同时注册到子类和父类的 mapper 上。
|
||||
这确保了查询父类时,SELECT 语句包含所有 STI 子类的列,
|
||||
避免在响应序列化时触发懒加载(MissingGreenlet 错误)。
|
||||
|
||||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
||||
"""
|
||||
# 查找父表和父类(在 MRO 中找到第一个有 __table__ 的父类)
|
||||
parent_table = None
|
||||
parent_class = None
|
||||
for base in cls.__mro__[1:]:
|
||||
if hasattr(base, '__table__') and base.__table__ is not None:
|
||||
parent_table = base.__table__
|
||||
parent_class = base
|
||||
break
|
||||
|
||||
if parent_table is None:
|
||||
return # 没有找到父表,可能是根类
|
||||
|
||||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
||||
# JTI 类有自己独立的表,不需要将列属性注册到 mapper
|
||||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
||||
if cls.__table__.name != parent_table.name:
|
||||
return # JTI,跳过 STI 列属性注册
|
||||
|
||||
# 获取子类和父类的 mapper
|
||||
child_mapper = inspect(cls).mapper
|
||||
parent_mapper = inspect(parent_class).mapper
|
||||
local_table = child_mapper.local_table
|
||||
|
||||
# 查找父类的所有字段名
|
||||
parent_fields: set[str] = set()
|
||||
if hasattr(parent_class, 'model_fields'):
|
||||
parent_fields.update(parent_class.model_fields.keys())
|
||||
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
# 获取两个 mapper 已有的列属性
|
||||
child_existing_props = {p.key for p in child_mapper.column_attrs}
|
||||
parent_existing_props = {p.key for p in parent_mapper.column_attrs}
|
||||
|
||||
for field_name in cls.model_fields:
|
||||
# 跳过从父类继承的字段
|
||||
if field_name in parent_fields:
|
||||
continue
|
||||
|
||||
# 跳过私有字段
|
||||
if field_name.startswith('_'):
|
||||
continue
|
||||
|
||||
# 检查表中是否有这个列
|
||||
if field_name not in local_table.columns:
|
||||
continue
|
||||
|
||||
column = local_table.columns[field_name]
|
||||
|
||||
# 添加到子类的 mapper(如果尚不存在)
|
||||
if field_name not in child_existing_props:
|
||||
try:
|
||||
prop = ColumnProperty(column)
|
||||
child_mapper.add_property(field_name, prop)
|
||||
except Exception as e:
|
||||
l.warning(f"为 {cls.__name__} 添加列属性 {field_name} 失败: {e}")
|
||||
|
||||
# 同时添加到父类的 mapper(确保查询父类时 SELECT 包含所有 STI 子类的列)
|
||||
if field_name not in parent_existing_props:
|
||||
try:
|
||||
prop = ColumnProperty(column)
|
||||
parent_mapper.add_property(field_name, prop)
|
||||
except Exception as e:
|
||||
l.warning(f"为父类 {parent_class.__name__} 添加子类 {cls.__name__} 的列属性 {field_name} 失败: {e}")
|
||||
|
||||
|
||||
class PolymorphicBaseMixin:
|
||||
"""
|
||||
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
|
||||
|
||||
此 Mixin 自动设置以下内容:
|
||||
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
|
||||
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
|
||||
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
|
||||
|
||||
使用场景:
|
||||
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
|
||||
|
||||
用法示例:
|
||||
```python
|
||||
from abc import ABC
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
|
||||
|
||||
# 定义基类
|
||||
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
__tablename__ = 'mytool'
|
||||
|
||||
# 不需要手动定义 _polymorphic_name
|
||||
# 不需要手动设置 polymorphic_on
|
||||
# 不需要手动设置 polymorphic_abstract
|
||||
|
||||
# 定义子类
|
||||
class SpecificTool(MyTool):
|
||||
__tablename__ = 'specifictool'
|
||||
|
||||
# 会自动继承 polymorphic 配置
|
||||
```
|
||||
|
||||
自动行为:
|
||||
1. 定义 `_polymorphic_name: str` 字段(带索引)
|
||||
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
|
||||
3. 自动检测抽象类:
|
||||
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
|
||||
- 否则设置为 False
|
||||
|
||||
手动覆盖:
|
||||
可以在类定义时手动指定参数来覆盖自动行为:
|
||||
```python
|
||||
class MyTool(
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
|
||||
polymorphic_abstract=False # 强制不设为抽象类
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
注意事项:
|
||||
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
|
||||
- 适用于联表继承(joined table inheritance)场景
|
||||
- 子类会自动继承 _polymorphic_name 字段定义
|
||||
- 使用单下划线前缀是因为:
|
||||
* SQLAlchemy 会映射单下划线字段为数据库列
|
||||
* Pydantic 将其视为私有属性,不参与序列化
|
||||
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
|
||||
"""
|
||||
|
||||
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
|
||||
#
|
||||
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
|
||||
#
|
||||
# 为什么这样做:
|
||||
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
|
||||
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
|
||||
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
|
||||
# 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema)
|
||||
# 5. 内部代码仍然可以正常访问和修改此字段
|
||||
#
|
||||
# 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
|
||||
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
|
||||
"""
|
||||
多态鉴别器字段,用于标识具体的子类类型
|
||||
|
||||
注意:此字段使用单下划线前缀,表示内部使用。
|
||||
- ✅ 存储到数据库
|
||||
- ✅ 不出现在 API 序列化中
|
||||
- ✅ 防止外部直接修改
|
||||
"""
|
||||
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
polymorphic_on: str | None = None,
|
||||
polymorphic_abstract: bool | None = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
在子类定义时自动配置 polymorphic 设置
|
||||
|
||||
Args:
|
||||
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。
|
||||
设置为其他值可以使用不同的字段作为多态鉴别器。
|
||||
polymorphic_abstract: 是否为抽象类。
|
||||
- None: 自动检测(默认)
|
||||
- True: 强制设为抽象类
|
||||
- False: 强制设为非抽象类
|
||||
**kwargs: 传递给父类的其他参数
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 初始化 __mapper_args__(如果还没有)
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 设置 polymorphic_on(默认为 _polymorphic_name)
|
||||
if 'polymorphic_on' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
|
||||
|
||||
# 自动检测或设置 polymorphic_abstract
|
||||
if 'polymorphic_abstract' not in cls.__mapper_args__:
|
||||
if polymorphic_abstract is None:
|
||||
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
|
||||
has_abc = ABC in cls.__mro__
|
||||
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
|
||||
polymorphic_abstract = has_abc and has_abstract_methods
|
||||
|
||||
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
|
||||
|
||||
@classmethod
|
||||
def _is_joined_table_inheritance(cls) -> bool:
|
||||
"""
|
||||
检测当前类是否使用联表继承(Joined Table Inheritance)
|
||||
|
||||
通过检查子类是否有独立的表来判断:
|
||||
- JTI: 子类有独立的 local_table(与父类不同)
|
||||
- STI: 子类与父类共用同一个 local_table
|
||||
|
||||
:return: True 表示 JTI,False 表示 STI 或无子类
|
||||
"""
|
||||
mapper = inspect(cls)
|
||||
base_table_name = mapper.local_table.name
|
||||
|
||||
# 检查所有直接子类
|
||||
for subclass in cls.__subclasses__():
|
||||
sub_mapper = inspect(subclass)
|
||||
# 如果任何子类有不同的表名,说明是 JTI
|
||||
if sub_mapper.local_table.name != base_table_name:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
递归获取当前类的所有具体(非抽象)子类
|
||||
|
||||
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
|
||||
可在任意多态基类上调用,返回该类的所有非抽象子类。
|
||||
|
||||
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
|
||||
"""
|
||||
result: list[type[PolymorphicBaseMixin]] = []
|
||||
for subclass in cls.__subclasses__():
|
||||
# 使用 inspect() 获取 mapper 的公开属性
|
||||
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
|
||||
mapper = inspect(subclass)
|
||||
if not mapper.polymorphic_abstract:
|
||||
result.append(subclass)
|
||||
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
|
||||
if hasattr(subclass, 'get_concrete_subclasses'):
|
||||
result.extend(subclass.get_concrete_subclasses())
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_polymorphic_discriminator(cls) -> str:
|
||||
"""
|
||||
获取多态鉴别字段名
|
||||
|
||||
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
|
||||
|
||||
:return: 多态鉴别字段名(如 '_polymorphic_name')
|
||||
:raises ValueError: 如果类未配置 polymorphic_on
|
||||
"""
|
||||
polymorphic_on = inspect(cls).polymorphic_on
|
||||
if polymorphic_on is None:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} 未配置 polymorphic_on,"
|
||||
f"请确保正确继承 PolymorphicBaseMixin"
|
||||
)
|
||||
return polymorphic_on.key
|
||||
|
||||
@classmethod
|
||||
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
获取 polymorphic_identity 到具体子类的映射
|
||||
|
||||
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
|
||||
|
||||
:return: identity 到子类的映射字典
|
||||
"""
|
||||
result: dict[str, type[PolymorphicBaseMixin]] = {}
|
||||
for subclass in cls.get_concrete_subclasses():
|
||||
identity = inspect(subclass).polymorphic_identity
|
||||
if identity:
|
||||
result[identity] = subclass
|
||||
return result
|
||||
@@ -1,470 +0,0 @@
|
||||
"""
|
||||
关系预加载 Mixin
|
||||
|
||||
提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。
|
||||
|
||||
设计原则:
|
||||
- 按需加载:只加载被调用方法需要的关系
|
||||
- 增量加载:已加载的关系不重复加载
|
||||
- 查询最优:相同关系只查询一次,不同关系增量查询
|
||||
- 零侵入:调用方无需任何改动
|
||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire
|
||||
|
||||
使用方式:
|
||||
from sqlmodels.mixin import RelationPreloadMixin, requires_relations
|
||||
|
||||
class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True):
|
||||
kling_video_generator: KlingO1Generator = Relationship(...)
|
||||
|
||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||||
async def cost(self, params, context, session) -> ToolCost:
|
||||
# 自动加载,可以安全访问
|
||||
price = self.kling_video_generator.kling_o1.pro_price_per_second
|
||||
...
|
||||
|
||||
# 调用方 - 无需任何改动
|
||||
await tool.cost(params, context, session) # 自动加载 cost 需要的关系
|
||||
await tool._call(...) # 关系相同则跳过,否则增量加载
|
||||
|
||||
支持 AsyncGenerator:
|
||||
@requires_relations('twitter_api')
|
||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||||
yield ToolResponse(...) # 装饰器正确处理 async generator
|
||||
"""
|
||||
import inspect as python_inspect
|
||||
from functools import wraps
|
||||
from typing import Callable, TypeVar, ParamSpec, Any
|
||||
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
P = ParamSpec('P')
|
||||
R = TypeVar('R')
|
||||
|
||||
|
||||
def _extract_session(
|
||||
func: Callable,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> AsyncSession | None:
|
||||
"""
|
||||
从方法参数中提取 AsyncSession
|
||||
|
||||
按以下顺序查找:
|
||||
1. kwargs 中名为 'session' 的参数
|
||||
2. 根据函数签名定位 'session' 参数的位置,从 args 提取
|
||||
3. kwargs 中类型为 AsyncSession 的参数
|
||||
"""
|
||||
# 1. 优先从 kwargs 查找
|
||||
if 'session' in kwargs:
|
||||
return kwargs['session']
|
||||
|
||||
# 2. 从函数签名定位位置参数
|
||||
try:
|
||||
sig = python_inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
|
||||
if 'session' in param_names:
|
||||
# 计算位置(减去 self)
|
||||
idx = param_names.index('session') - 1
|
||||
if 0 <= idx < len(args):
|
||||
return args[idx]
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# 3. 遍历 kwargs 找 AsyncSession 类型
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, AsyncSession):
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool:
|
||||
"""
|
||||
检查对象的关系是否已加载(独立函数版本)
|
||||
|
||||
Args:
|
||||
obj: 要检查的对象
|
||||
rel_name: 关系属性名
|
||||
|
||||
Returns:
|
||||
True 如果关系已加载,False 如果未加载或已过期
|
||||
"""
|
||||
try:
|
||||
state = sa_inspect(obj)
|
||||
return rel_name not in state.unloaded
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _find_relation_to_class(from_class: type, to_class: type) -> str | None:
|
||||
"""
|
||||
在类中查找指向目标类的关系属性名
|
||||
|
||||
Args:
|
||||
from_class: 源类
|
||||
to_class: 目标类
|
||||
|
||||
Returns:
|
||||
关系属性名,如果找不到则返回 None
|
||||
|
||||
Example:
|
||||
_find_relation_to_class(KlingO1VideoFunction, KlingO1Generator)
|
||||
# 返回 'kling_video_generator'
|
||||
"""
|
||||
for attr_name in dir(from_class):
|
||||
try:
|
||||
attr = getattr(from_class, attr_name, None)
|
||||
if attr is None:
|
||||
continue
|
||||
# 检查是否是 SQLAlchemy InstrumentedAttribute(关系属性)
|
||||
# parent.class_ 是关系所在的类,property.mapper.class_ 是关系指向的目标类
|
||||
if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'):
|
||||
target_class = attr.property.mapper.class_
|
||||
if target_class == to_class:
|
||||
return attr_name
|
||||
except AttributeError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""
|
||||
装饰器:声明方法需要的关系,自动按需增量加载
|
||||
|
||||
参数格式:
|
||||
- 字符串:本类属性名,如 'kling_video_generator'
|
||||
- RelationshipInfo:外部类属性,如 KlingO1Generator.kling_o1
|
||||
|
||||
行为:
|
||||
- 方法调用时自动检查关系是否已加载
|
||||
- 未加载的关系会被增量加载(单次查询)
|
||||
- 已加载的关系直接跳过
|
||||
|
||||
支持:
|
||||
- 普通 async 方法:`async def cost(...) -> ToolCost`
|
||||
- AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]`
|
||||
|
||||
Example:
|
||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||||
async def cost(self, params, context, session) -> ToolCost:
|
||||
# self.kling_video_generator.kling_o1 已自动加载
|
||||
...
|
||||
|
||||
@requires_relations('twitter_api')
|
||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||||
yield ToolResponse(...) # AsyncGenerator 正确处理
|
||||
|
||||
验证:
|
||||
- 字符串格式的关系名在类创建时(__init_subclass__)验证
|
||||
- 拼写错误会在导入时抛出 AttributeError
|
||||
"""
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
# 检测是否是 async generator 函数
|
||||
is_async_gen = python_inspect.isasyncgenfunction(func)
|
||||
|
||||
if is_async_gen:
|
||||
# AsyncGenerator 需要特殊处理:wrapper 也必须是 async generator
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
session = _extract_session(func, args, kwargs)
|
||||
if session is not None:
|
||||
await self._ensure_relations_loaded(session, relations)
|
||||
# 委托给原始 async generator,逐个 yield 值
|
||||
async for item in func(self, *args, **kwargs):
|
||||
yield item # type: ignore
|
||||
else:
|
||||
# 普通 async 函数:await 并返回结果
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
session = _extract_session(func, args, kwargs)
|
||||
if session is not None:
|
||||
await self._ensure_relations_loaded(session, relations)
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
# 保存关系声明供验证和内省使用
|
||||
wrapper._required_relations = relations # type: ignore
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RelationPreloadMixin:
|
||||
"""
|
||||
关系预加载 Mixin
|
||||
|
||||
提供按需增量加载能力,确保 SQL 查询数理论最优。
|
||||
|
||||
特性:
|
||||
- 按需加载:只加载被调用方法需要的关系
|
||||
- 增量加载:已加载的关系不重复加载
|
||||
- 原地更新:直接修改 self,无需替换实例
|
||||
- 导入时验证:字符串关系名在类创建时验证
|
||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
"""类创建时验证所有 @requires_relations 声明"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 收集类及其父类的所有注解(包含普通字段)
|
||||
all_annotations: set[str] = set()
|
||||
for klass in cls.__mro__:
|
||||
if hasattr(klass, '__annotations__'):
|
||||
all_annotations.update(klass.__annotations__.keys())
|
||||
|
||||
# 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__)
|
||||
sqlmodel_relationships: set[str] = set()
|
||||
for klass in cls.__mro__:
|
||||
if hasattr(klass, '__sqlmodel_relationships__'):
|
||||
sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys())
|
||||
|
||||
# 合并所有可用的属性名
|
||||
all_available_names = all_annotations | sqlmodel_relationships
|
||||
|
||||
for method_name in dir(cls):
|
||||
if method_name.startswith('__'):
|
||||
continue
|
||||
|
||||
try:
|
||||
method = getattr(cls, method_name, None)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if method is None or not hasattr(method, '_required_relations'):
|
||||
continue
|
||||
|
||||
# 验证字符串格式的关系名
|
||||
for spec in method._required_relations:
|
||||
if isinstance(spec, str):
|
||||
# 检查注解、Relationship 或已有属性
|
||||
if spec not in all_available_names and not hasattr(cls, spec):
|
||||
raise AttributeError(
|
||||
f"{cls.__name__}.{method_name} 声明了关系 '{spec}',"
|
||||
f"但 {cls.__name__} 没有此属性"
|
||||
)
|
||||
|
||||
def _is_relation_loaded(self, rel_name: str) -> bool:
|
||||
"""
|
||||
检查关系是否真正已加载(基于 SQLAlchemy inspect)
|
||||
|
||||
使用 SQLAlchemy 的 inspect 检测真实加载状态,
|
||||
自动处理 commit 导致的 expire 问题。
|
||||
|
||||
Args:
|
||||
rel_name: 关系属性名
|
||||
|
||||
Returns:
|
||||
True 如果关系已加载,False 如果未加载或已过期
|
||||
"""
|
||||
try:
|
||||
state = sa_inspect(self)
|
||||
# unloaded 包含未加载的关系属性名
|
||||
return rel_name not in state.unloaded
|
||||
except Exception:
|
||||
# 对象可能未被 SQLAlchemy 管理
|
||||
return False
|
||||
|
||||
async def _ensure_relations_loaded(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
relations: tuple[str | RelationshipInfo, ...],
|
||||
) -> None:
|
||||
"""
|
||||
确保指定关系已加载,只加载未加载的部分
|
||||
|
||||
基于 SQLAlchemy inspect 检测真实状态,自动处理:
|
||||
- 首次访问的关系
|
||||
- commit 后 expire 的关系
|
||||
- 嵌套关系(如 KlingO1Generator.kling_o1)
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
relations: 需要的关系规格
|
||||
"""
|
||||
# 找出真正未加载的关系(基于 SQLAlchemy inspect)
|
||||
to_load: list[str | RelationshipInfo] = []
|
||||
# 区分直接关系和嵌套关系的 key
|
||||
direct_keys: set[str] = set() # 本类的直接关系属性名
|
||||
nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名
|
||||
|
||||
for rel in relations:
|
||||
if isinstance(rel, str):
|
||||
# 直接关系:检查本类的关系是否已加载
|
||||
if not self._is_relation_loaded(rel):
|
||||
to_load.append(rel)
|
||||
direct_keys.add(rel)
|
||||
else:
|
||||
# 嵌套关系(InstrumentedAttribute):如 KlingO1Generator.kling_o1
|
||||
# 1. 查找指向父类的关系属性
|
||||
parent_class = rel.parent.class_
|
||||
parent_attr = _find_relation_to_class(self.__class__, parent_class)
|
||||
|
||||
if parent_attr is None:
|
||||
# 找不到路径,可能是配置错误,但仍尝试加载
|
||||
l.warning(
|
||||
f"无法找到从 {self.__class__.__name__} 到 {parent_class.__name__} 的关系路径,"
|
||||
f"无法检查 {rel.key} 是否已加载"
|
||||
)
|
||||
to_load.append(rel)
|
||||
continue
|
||||
|
||||
# 2. 检查父对象是否已加载
|
||||
if not self._is_relation_loaded(parent_attr):
|
||||
# 父对象未加载,需要同时加载父对象和嵌套关系
|
||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||||
to_load.append(parent_attr)
|
||||
nested_parent_keys.add(parent_attr)
|
||||
to_load.append(rel)
|
||||
else:
|
||||
# 3. 父对象已加载,检查嵌套关系是否已加载
|
||||
parent_obj = getattr(self, parent_attr)
|
||||
if not _is_obj_relation_loaded(parent_obj, rel.key):
|
||||
# 嵌套关系未加载:需要同时传递父关系和嵌套关系
|
||||
# 因为 _build_load_chains 需要完整的链来构建 selectinload
|
||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||||
to_load.append(parent_attr)
|
||||
nested_parent_keys.add(parent_attr)
|
||||
to_load.append(rel)
|
||||
|
||||
if not to_load:
|
||||
return # 全部已加载,跳过
|
||||
|
||||
# 构建 load 参数
|
||||
load_options = self._specs_to_load_options(to_load)
|
||||
if not load_options:
|
||||
return
|
||||
|
||||
# 安全地获取主键值(避免触发懒加载)
|
||||
state = sa_inspect(self)
|
||||
pk_tuple = state.key[1] if state.key else None
|
||||
if pk_tuple is None:
|
||||
l.warning(f"无法获取 {self.__class__.__name__} 的主键值")
|
||||
return
|
||||
# 主键是元组,取第一个值(假设单列主键)
|
||||
pk_value = pk_tuple[0]
|
||||
|
||||
# 单次查询加载缺失的关系
|
||||
fresh = await self.__class__.get(
|
||||
session,
|
||||
self.__class__.id == pk_value,
|
||||
load=load_options,
|
||||
)
|
||||
|
||||
if fresh is None:
|
||||
l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在")
|
||||
return
|
||||
|
||||
# 原地复制到 self(只复制直接关系,嵌套关系通过父关系自动可访问)
|
||||
all_direct_keys = direct_keys | nested_parent_keys
|
||||
for key in all_direct_keys:
|
||||
value = getattr(fresh, key, None)
|
||||
object.__setattr__(self, key, value)
|
||||
|
||||
def _specs_to_load_options(
|
||||
self,
|
||||
specs: list[str | RelationshipInfo],
|
||||
) -> list[RelationshipInfo]:
|
||||
"""
|
||||
将关系规格转换为 load 参数
|
||||
|
||||
- 字符串 → cls.{name}
|
||||
- RelationshipInfo → 直接使用
|
||||
"""
|
||||
result: list[RelationshipInfo] = []
|
||||
|
||||
for spec in specs:
|
||||
if isinstance(spec, str):
|
||||
rel = getattr(self.__class__, spec, None)
|
||||
if rel is not None:
|
||||
result.append(rel)
|
||||
else:
|
||||
l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在")
|
||||
else:
|
||||
result.append(spec)
|
||||
|
||||
return result
|
||||
|
||||
# ==================== 可选的手动预加载 API ====================
|
||||
|
||||
@classmethod
|
||||
def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]:
|
||||
"""
|
||||
获取指定方法声明的关系(用于外部预加载场景)
|
||||
|
||||
Args:
|
||||
method_name: 方法名
|
||||
|
||||
Returns:
|
||||
RelationshipInfo 列表
|
||||
"""
|
||||
method = getattr(cls, method_name, None)
|
||||
if method is None or not hasattr(method, '_required_relations'):
|
||||
return []
|
||||
|
||||
result: list[RelationshipInfo] = []
|
||||
for spec in method._required_relations:
|
||||
if isinstance(spec, str):
|
||||
rel = getattr(cls, spec, None)
|
||||
if rel:
|
||||
result.append(rel)
|
||||
else:
|
||||
result.append(spec)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]:
|
||||
"""
|
||||
获取多个方法的关系并去重(用于批量预加载场景)
|
||||
|
||||
Args:
|
||||
method_names: 方法名列表
|
||||
|
||||
Returns:
|
||||
去重后的 RelationshipInfo 列表
|
||||
"""
|
||||
seen: set[str] = set()
|
||||
result: list[RelationshipInfo] = []
|
||||
|
||||
for method_name in method_names:
|
||||
for rel in cls.get_relations_for_method(method_name):
|
||||
key = rel.key
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
result.append(rel)
|
||||
|
||||
return result
|
||||
|
||||
async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin':
|
||||
"""
|
||||
手动预加载指定方法的关系(可选优化 API)
|
||||
|
||||
当需要确保在调用方法前完成所有加载时使用。
|
||||
通常情况下不需要调用此方法,装饰器会自动处理。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
method_names: 方法名列表
|
||||
|
||||
Returns:
|
||||
self(支持链式调用)
|
||||
|
||||
Example:
|
||||
# 可选:显式预加载(通常不需要)
|
||||
tool = await tool.preload_for(session, 'cost', '_call')
|
||||
"""
|
||||
all_relations: list[str | RelationshipInfo] = []
|
||||
|
||||
for method_name in method_names:
|
||||
method = getattr(self.__class__, method_name, None)
|
||||
if method and hasattr(method, '_required_relations'):
|
||||
all_relations.extend(method._required_relations)
|
||||
|
||||
if all_relations:
|
||||
await self._ensure_relations_loaded(session, tuple(all_relations))
|
||||
|
||||
return self
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@ from enum import StrEnum
|
||||
|
||||
from sqlmodel import Field
|
||||
|
||||
from .base import SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
|
||||
|
||||
class ResponseBase(SQLModelBase):
|
||||
|
||||
@@ -3,8 +3,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from sqlmodel import Field, Relationship, text, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .download import Download
|
||||
|
||||
@@ -7,8 +7,7 @@ from enum import StrEnum
|
||||
from sqlalchemy import BigInteger
|
||||
from sqlmodel import Field, Relationship, CheckConstraint, Index, text
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
@@ -4,8 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
@@ -15,8 +15,7 @@ from uuid import UUID
|
||||
from sqlalchemy import BigInteger
|
||||
from sqlmodel import Field, Relationship, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .object import Object
|
||||
|
||||
@@ -4,8 +4,7 @@ from uuid import UUID
|
||||
from enum import StrEnum
|
||||
from sqlmodel import Field, Relationship, text
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .object import Object
|
||||
|
||||
@@ -2,8 +2,7 @@ from enum import StrEnum
|
||||
|
||||
from sqlmodel import Field, text
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
|
||||
|
||||
class RedeemType(StrEnum):
|
||||
|
||||
@@ -4,8 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .share import Share
|
||||
|
||||
@@ -2,9 +2,9 @@ from enum import StrEnum
|
||||
|
||||
from sqlmodel import UniqueConstraint
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
|
||||
from .auth_identity import AuthProviderType
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
from .user import UserResponse
|
||||
|
||||
class CaptchaType(StrEnum):
|
||||
|
||||
@@ -5,9 +5,9 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
|
||||
from .model_base import ResponseBase
|
||||
from .mixin import UUIDTableBaseMixin
|
||||
from .object import ObjectType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -4,8 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .object import Object
|
||||
|
||||
@@ -5,8 +5,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, Column, func, DateTime
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
@@ -5,8 +5,7 @@ from datetime import datetime
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
@@ -5,8 +5,7 @@ from datetime import datetime
|
||||
|
||||
from sqlmodel import Field, Relationship, CheckConstraint, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .download import Download
|
||||
|
||||
@@ -3,9 +3,9 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field
|
||||
|
||||
from .base import SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
|
||||
from .color import ChromaticColor, NeutralColor, ThemeColorsBase
|
||||
from .mixin import UUIDTableBaseMixin
|
||||
|
||||
|
||||
class ThemePresetBase(SQLModelBase):
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from enum import StrEnum
|
||||
from urllib.parse import urlparse, parse_qs, urlencode, quote, unquote
|
||||
|
||||
from .base import SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
|
||||
|
||||
class FileSystemNamespace(StrEnum):
|
||||
|
||||
@@ -9,11 +9,11 @@ from sqlmodel import Field, Relationship
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableViewRequest, ListResponse
|
||||
|
||||
from .auth_identity import AuthProviderType
|
||||
from .base import SQLModelBase
|
||||
from .color import ChromaticColor, NeutralColor, ThemeColorsBase
|
||||
from .model_base import ResponseBase
|
||||
from .mixin import UUIDTableBaseMixin, TableViewRequest, ListResponse
|
||||
|
||||
T = TypeVar("T", bound="User")
|
||||
|
||||
|
||||
@@ -5,8 +5,7 @@ from uuid import UUID
|
||||
from sqlalchemy import Column, Text
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
@@ -4,8 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
83
sqlmodels/wopi.py
Normal file
83
sqlmodels/wopi.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
WOPI(Web Application Open Platform Interface)协议模型
|
||||
|
||||
提供 WOPI CheckFileInfo 响应模型和 WOPI 访问令牌 Payload 定义。
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
|
||||
|
||||
class WopiFileInfo(SQLModelBase):
|
||||
"""
|
||||
WOPI CheckFileInfo 响应模型。
|
||||
|
||||
字段命名遵循 WOPI 规范(PascalCase),通过 alias 映射。
|
||||
参考: https://learn.microsoft.com/en-us/microsoft-365/cloud-storage-partner-program/rest/files/checkfileinfo
|
||||
"""
|
||||
|
||||
base_file_name: str
|
||||
"""文件名(含扩展名)"""
|
||||
|
||||
size: int
|
||||
"""文件大小(字节)"""
|
||||
|
||||
owner_id: str
|
||||
"""文件所有者标识"""
|
||||
|
||||
user_id: str
|
||||
"""当前用户标识"""
|
||||
|
||||
user_friendly_name: str
|
||||
"""用户显示名"""
|
||||
|
||||
version: str
|
||||
"""文件版本标识(使用 updated_at)"""
|
||||
|
||||
sha256: str = ""
|
||||
"""文件 SHA256 哈希(如果可用)"""
|
||||
|
||||
user_can_write: bool = False
|
||||
"""用户是否可写"""
|
||||
|
||||
user_can_not_write_relative: bool = True
|
||||
"""是否禁止创建关联文件"""
|
||||
|
||||
read_only: bool = True
|
||||
"""文件是否只读"""
|
||||
|
||||
supports_locks: bool = False
|
||||
"""是否支持锁(v1 不实现)"""
|
||||
|
||||
supports_update: bool = True
|
||||
"""是否支持更新"""
|
||||
|
||||
def to_wopi_dict(self) -> dict[str, str | int | bool]:
|
||||
"""转换为 WOPI 规范的 PascalCase 字典"""
|
||||
return {
|
||||
"BaseFileName": self.base_file_name,
|
||||
"Size": self.size,
|
||||
"OwnerId": self.owner_id,
|
||||
"UserId": self.user_id,
|
||||
"UserFriendlyName": self.user_friendly_name,
|
||||
"Version": self.version,
|
||||
"SHA256": self.sha256,
|
||||
"UserCanWrite": self.user_can_write,
|
||||
"UserCanNotWriteRelative": self.user_can_not_write_relative,
|
||||
"ReadOnly": self.read_only,
|
||||
"SupportsLocks": self.supports_locks,
|
||||
"SupportsUpdate": self.supports_update,
|
||||
}
|
||||
|
||||
|
||||
class WopiAccessTokenPayload(SQLModelBase):
|
||||
"""WOPI 访问令牌内部 Payload"""
|
||||
|
||||
file_id: UUID
|
||||
"""文件UUID"""
|
||||
|
||||
user_id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
can_write: bool = False
|
||||
"""是否可写"""
|
||||
253
tests/integration/api/test_admin_file_app.py
Normal file
253
tests/integration/api/test_admin_file_app.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
管理员文件应用管理集成测试
|
||||
|
||||
测试管理员 CRUD、扩展名更新、用户组权限更新和权限校验。
|
||||
"""
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.file_app import FileApp, FileAppExtension, FileAppType
|
||||
from sqlmodels.group import Group
|
||||
from sqlmodels.user import User
|
||||
|
||||
|
||||
# ==================== Fixtures ====================
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup_admin_app(
|
||||
initialized_db: AsyncSession,
|
||||
) -> dict[str, UUID]:
|
||||
"""创建测试用管理员文件应用"""
|
||||
app = FileApp(
|
||||
name="管理员测试应用",
|
||||
app_key="admin_test_app",
|
||||
type=FileAppType.BUILTIN,
|
||||
is_enabled=True,
|
||||
)
|
||||
app = await app.save(initialized_db)
|
||||
|
||||
ext = FileAppExtension(app_id=app.id, extension="test", priority=0)
|
||||
await ext.save(initialized_db)
|
||||
|
||||
return {"app_id": app.id}
|
||||
|
||||
|
||||
# ==================== Admin CRUD ====================
|
||||
|
||||
class TestAdminFileAppCRUD:
|
||||
"""管理员文件应用 CRUD 测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_file_app(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
admin_headers: dict[str, str],
|
||||
) -> None:
|
||||
"""管理员创建文件应用"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/admin/file-app/",
|
||||
headers=admin_headers,
|
||||
json={
|
||||
"name": "新建应用",
|
||||
"app_key": "new_app",
|
||||
"type": "builtin",
|
||||
"description": "测试新建",
|
||||
"extensions": ["pdf", "txt"],
|
||||
"allowed_group_ids": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == "新建应用"
|
||||
assert data["app_key"] == "new_app"
|
||||
assert "pdf" in data["extensions"]
|
||||
assert "txt" in data["extensions"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_app_key(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
admin_headers: dict[str, str],
|
||||
setup_admin_app: dict[str, UUID],
|
||||
) -> None:
|
||||
"""创建重复 app_key 返回 409"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/admin/file-app/",
|
||||
headers=admin_headers,
|
||||
json={
|
||||
"name": "重复应用",
|
||||
"app_key": "admin_test_app",
|
||||
"type": "builtin",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 409
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_file_apps(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
admin_headers: dict[str, str],
|
||||
setup_admin_app: dict[str, UUID],
|
||||
) -> None:
|
||||
"""管理员列出文件应用"""
|
||||
response = await async_client.get(
|
||||
"/api/v1/admin/file-app/list",
|
||||
headers=admin_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "apps" in data
|
||||
assert data["total"] >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_file_app_detail(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
admin_headers: dict[str, str],
|
||||
setup_admin_app: dict[str, UUID],
|
||||
) -> None:
|
||||
"""管理员获取应用详情"""
|
||||
app_id = setup_admin_app["app_id"]
|
||||
response = await async_client.get(
|
||||
f"/api/v1/admin/file-app/{app_id}",
|
||||
headers=admin_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["app_key"] == "admin_test_app"
|
||||
assert "test" in data["extensions"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_app(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
admin_headers: dict[str, str],
|
||||
) -> None:
|
||||
"""获取不存在的应用返回 404"""
|
||||
response = await async_client.get(
|
||||
f"/api/v1/admin/file-app/{uuid4()}",
|
||||
headers=admin_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_file_app(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
admin_headers: dict[str, str],
|
||||
setup_admin_app: dict[str, UUID],
|
||||
) -> None:
|
||||
"""管理员更新应用"""
|
||||
app_id = setup_admin_app["app_id"]
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/admin/file-app/{app_id}",
|
||||
headers=admin_headers,
|
||||
json={
|
||||
"name": "更新后的名称",
|
||||
"is_enabled": False,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "更新后的名称"
|
||||
assert data["is_enabled"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_file_app(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
initialized_db: AsyncSession,
|
||||
admin_headers: dict[str, str],
|
||||
) -> None:
|
||||
"""管理员删除应用"""
|
||||
# 先创建一个应用
|
||||
app = FileApp(
|
||||
name="待删除应用", app_key="to_delete_admin", type=FileAppType.BUILTIN
|
||||
)
|
||||
app = await app.save(initialized_db)
|
||||
app_id = app.id
|
||||
|
||||
response = await async_client.delete(
|
||||
f"/api/v1/admin/file-app/{app_id}",
|
||||
headers=admin_headers,
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
# 确认已删除
|
||||
found = await FileApp.get(initialized_db, FileApp.id == app_id)
|
||||
assert found is None
|
||||
|
||||
|
||||
# ==================== Extensions Management ====================
|
||||
|
||||
class TestAdminExtensionManagement:
|
||||
"""管理员扩展名管理测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_extensions(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
admin_headers: dict[str, str],
|
||||
setup_admin_app: dict[str, UUID],
|
||||
) -> None:
|
||||
"""全量替换扩展名列表"""
|
||||
app_id = setup_admin_app["app_id"]
|
||||
response = await async_client.put(
|
||||
f"/api/v1/admin/file-app/{app_id}/extensions",
|
||||
headers=admin_headers,
|
||||
json={"extensions": ["doc", "docx", "odt"]},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert sorted(data["extensions"]) == ["doc", "docx", "odt"]
|
||||
|
||||
|
||||
# ==================== Group Access Management ====================
|
||||
|
||||
class TestAdminGroupAccessManagement:
|
||||
"""管理员用户组权限管理测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_group_access(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
initialized_db: AsyncSession,
|
||||
admin_headers: dict[str, str],
|
||||
setup_admin_app: dict[str, UUID],
|
||||
) -> None:
|
||||
"""全量替换用户组权限"""
|
||||
app_id = setup_admin_app["app_id"]
|
||||
admin_user = await User.get(initialized_db, User.email == "admin@disknext.local")
|
||||
group_id = admin_user.group_id
|
||||
|
||||
response = await async_client.put(
|
||||
f"/api/v1/admin/file-app/{app_id}/groups",
|
||||
headers=admin_headers,
|
||||
json={"group_ids": [str(group_id)]},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert str(group_id) in data["allowed_group_ids"]
|
||||
|
||||
|
||||
# ==================== Permission Tests ====================
|
||||
|
||||
class TestAdminPermission:
|
||||
"""权限校验测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_admin_forbidden(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
) -> None:
|
||||
"""普通用户访问管理端点返回 403"""
|
||||
response = await async_client.get(
|
||||
"/api/v1/admin/file-app/list",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
466
tests/integration/api/test_file_content.py
Normal file
466
tests/integration/api/test_file_content.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
文本文件内容 GET/PATCH 集成测试
|
||||
|
||||
测试 GET /file/content/{file_id} 和 PATCH /file/content/{file_id} 端点。
|
||||
"""
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels import Object, ObjectType, PhysicalFile, Policy, User
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _register_sqlite_greatest():
|
||||
"""注册 SQLite 的 greatest 函数以兼容 PostgreSQL 语法"""
|
||||
|
||||
def _on_connect(dbapi_connection, connection_record):
|
||||
if hasattr(dbapi_connection, 'create_function'):
|
||||
dbapi_connection.create_function("greatest", 2, max)
|
||||
|
||||
event.listen(Engine, "connect", _on_connect)
|
||||
yield
|
||||
event.remove(Engine, "connect", _on_connect)
|
||||
|
||||
|
||||
# ==================== Fixtures ====================
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def local_policy(
|
||||
initialized_db: AsyncSession,
|
||||
tmp_path: Path,
|
||||
) -> Policy:
|
||||
"""创建指向临时目录的本地存储策略"""
|
||||
from sqlmodels import PolicyType
|
||||
|
||||
policy = Policy(
|
||||
id=uuid4(),
|
||||
name="测试本地存储",
|
||||
type=PolicyType.LOCAL,
|
||||
server=str(tmp_path),
|
||||
)
|
||||
initialized_db.add(policy)
|
||||
await initialized_db.commit()
|
||||
await initialized_db.refresh(policy)
|
||||
return policy
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def text_file(
|
||||
initialized_db: AsyncSession,
|
||||
tmp_path: Path,
|
||||
local_policy: Policy,
|
||||
) -> dict[str, str | int]:
|
||||
"""创建包含 UTF-8 文本内容的测试文件"""
|
||||
user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||
root = await Object.get_root(initialized_db, user.id)
|
||||
|
||||
content = "line1\nline2\nline3\n"
|
||||
content_bytes = content.encode('utf-8')
|
||||
content_hash = hashlib.sha256(content_bytes).hexdigest()
|
||||
|
||||
file_path = tmp_path / "test.txt"
|
||||
file_path.write_bytes(content_bytes)
|
||||
|
||||
physical_file = PhysicalFile(
|
||||
id=uuid4(),
|
||||
storage_path=str(file_path),
|
||||
size=len(content_bytes),
|
||||
policy_id=local_policy.id,
|
||||
reference_count=1,
|
||||
)
|
||||
initialized_db.add(physical_file)
|
||||
|
||||
file_obj = Object(
|
||||
id=uuid4(),
|
||||
name="test.txt",
|
||||
type=ObjectType.FILE,
|
||||
size=len(content_bytes),
|
||||
physical_file_id=physical_file.id,
|
||||
parent_id=root.id,
|
||||
owner_id=user.id,
|
||||
policy_id=local_policy.id,
|
||||
)
|
||||
initialized_db.add(file_obj)
|
||||
await initialized_db.commit()
|
||||
|
||||
return {
|
||||
"id": str(file_obj.id),
|
||||
"content": content,
|
||||
"hash": content_hash,
|
||||
"size": len(content_bytes),
|
||||
"path": str(file_path),
|
||||
}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def binary_file(
|
||||
initialized_db: AsyncSession,
|
||||
tmp_path: Path,
|
||||
local_policy: Policy,
|
||||
) -> dict[str, str | int]:
|
||||
"""创建非 UTF-8 的二进制测试文件"""
|
||||
user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||
root = await Object.get_root(initialized_db, user.id)
|
||||
|
||||
# 包含无效 UTF-8 字节序列
|
||||
content_bytes = b'\x80\x81\x82\xff\xfe\xfd'
|
||||
|
||||
file_path = tmp_path / "binary.dat"
|
||||
file_path.write_bytes(content_bytes)
|
||||
|
||||
physical_file = PhysicalFile(
|
||||
id=uuid4(),
|
||||
storage_path=str(file_path),
|
||||
size=len(content_bytes),
|
||||
policy_id=local_policy.id,
|
||||
reference_count=1,
|
||||
)
|
||||
initialized_db.add(physical_file)
|
||||
|
||||
file_obj = Object(
|
||||
id=uuid4(),
|
||||
name="binary.dat",
|
||||
type=ObjectType.FILE,
|
||||
size=len(content_bytes),
|
||||
physical_file_id=physical_file.id,
|
||||
parent_id=root.id,
|
||||
owner_id=user.id,
|
||||
policy_id=local_policy.id,
|
||||
)
|
||||
initialized_db.add(file_obj)
|
||||
await initialized_db.commit()
|
||||
|
||||
return {
|
||||
"id": str(file_obj.id),
|
||||
"path": str(file_path),
|
||||
}
|
||||
|
||||
|
||||
# ==================== GET /file/content/{file_id} ====================
|
||||
|
||||
class TestGetFileContent:
|
||||
"""GET /file/content/{file_id} 端点测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_content_success(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
text_file: dict[str, str | int],
|
||||
) -> None:
|
||||
"""成功获取文本文件内容和哈希"""
|
||||
response = await async_client.get(
|
||||
f"/api/v1/file/content/{text_file['id']}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["content"] == text_file["content"]
|
||||
assert data["hash"] == text_file["hash"]
|
||||
assert data["size"] == text_file["size"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_content_non_utf8_returns_400(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
binary_file: dict[str, str | int],
|
||||
) -> None:
|
||||
"""非 UTF-8 文件返回 400"""
|
||||
response = await async_client.get(
|
||||
f"/api/v1/file/content/{binary_file['id']}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "UTF-8" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_content_not_found(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
) -> None:
|
||||
"""文件不存在返回 404"""
|
||||
fake_id = uuid4()
|
||||
response = await async_client.get(
|
||||
f"/api/v1/file/content/{fake_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_content_unauthenticated(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
text_file: dict[str, str | int],
|
||||
) -> None:
|
||||
"""未认证返回 401"""
|
||||
response = await async_client.get(
|
||||
f"/api/v1/file/content/{text_file['id']}",
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_content_normalizes_crlf(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
initialized_db: AsyncSession,
|
||||
tmp_path: Path,
|
||||
local_policy: Policy,
|
||||
) -> None:
|
||||
"""CRLF 换行符被规范化为 LF"""
|
||||
user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||
root = await Object.get_root(initialized_db, user.id)
|
||||
|
||||
crlf_content = b"line1\r\nline2\r\n"
|
||||
file_path = tmp_path / "crlf.txt"
|
||||
file_path.write_bytes(crlf_content)
|
||||
|
||||
physical_file = PhysicalFile(
|
||||
id=uuid4(),
|
||||
storage_path=str(file_path),
|
||||
size=len(crlf_content),
|
||||
policy_id=local_policy.id,
|
||||
reference_count=1,
|
||||
)
|
||||
initialized_db.add(physical_file)
|
||||
|
||||
file_obj = Object(
|
||||
id=uuid4(),
|
||||
name="crlf.txt",
|
||||
type=ObjectType.FILE,
|
||||
size=len(crlf_content),
|
||||
physical_file_id=physical_file.id,
|
||||
parent_id=root.id,
|
||||
owner_id=user.id,
|
||||
policy_id=local_policy.id,
|
||||
)
|
||||
initialized_db.add(file_obj)
|
||||
await initialized_db.commit()
|
||||
|
||||
response = await async_client.get(
|
||||
f"/api/v1/file/content/{file_obj.id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# 内容应该被规范化为 LF
|
||||
assert data["content"] == "line1\nline2\n"
|
||||
# 哈希基于规范化后的内容
|
||||
expected_hash = hashlib.sha256("line1\nline2\n".encode('utf-8')).hexdigest()
|
||||
assert data["hash"] == expected_hash
|
||||
|
||||
|
||||
# ==================== PATCH /file/content/{file_id} ====================
|
||||
|
||||
class TestPatchFileContent:
|
||||
"""PATCH /file/content/{file_id} 端点测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_content_success(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
text_file: dict[str, str | int],
|
||||
) -> None:
|
||||
"""正常增量保存"""
|
||||
patch_text = (
|
||||
"--- a\n"
|
||||
"+++ b\n"
|
||||
"@@ -1,3 +1,3 @@\n"
|
||||
" line1\n"
|
||||
"-line2\n"
|
||||
"+LINE2_MODIFIED\n"
|
||||
" line3\n"
|
||||
)
|
||||
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/file/content/{text_file['id']}",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"patch": patch_text,
|
||||
"base_hash": text_file["hash"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "new_hash" in data
|
||||
assert "new_size" in data
|
||||
assert data["new_hash"] != text_file["hash"]
|
||||
|
||||
# 验证文件实际被修改
|
||||
file_path = Path(text_file["path"])
|
||||
new_content = file_path.read_text(encoding='utf-8')
|
||||
assert "LINE2_MODIFIED" in new_content
|
||||
assert "line2" not in new_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_content_hash_mismatch_returns_409(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
text_file: dict[str, str | int],
|
||||
) -> None:
|
||||
"""base_hash 不匹配返回 409"""
|
||||
patch_text = (
|
||||
"--- a\n"
|
||||
"+++ b\n"
|
||||
"@@ -1,3 +1,3 @@\n"
|
||||
" line1\n"
|
||||
"-line2\n"
|
||||
"+changed\n"
|
||||
" line3\n"
|
||||
)
|
||||
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/file/content/{text_file['id']}",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"patch": patch_text,
|
||||
"base_hash": "0" * 64, # 错误的哈希
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 409
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_content_invalid_patch_returns_422(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
text_file: dict[str, str | int],
|
||||
) -> None:
|
||||
"""无效的 patch 格式返回 422"""
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/file/content/{text_file['id']}",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"patch": "this is not a valid patch",
|
||||
"base_hash": text_file["hash"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_content_context_mismatch_returns_422(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
text_file: dict[str, str | int],
|
||||
) -> None:
|
||||
"""patch 上下文行不匹配返回 422"""
|
||||
patch_text = (
|
||||
"--- a\n"
|
||||
"+++ b\n"
|
||||
"@@ -1,3 +1,3 @@\n"
|
||||
" WRONG_CONTEXT_LINE\n"
|
||||
"-line2\n"
|
||||
"+replaced\n"
|
||||
" line3\n"
|
||||
)
|
||||
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/file/content/{text_file['id']}",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"patch": patch_text,
|
||||
"base_hash": text_file["hash"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_content_unauthenticated(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
text_file: dict[str, str | int],
|
||||
) -> None:
|
||||
"""未认证返回 401"""
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/file/content/{text_file['id']}",
|
||||
json={
|
||||
"patch": "--- a\n+++ b\n",
|
||||
"base_hash": text_file["hash"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_content_not_found(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
) -> None:
|
||||
"""文件不存在返回 404"""
|
||||
fake_id = uuid4()
|
||||
response = await async_client.patch(
|
||||
f"/api/v1/file/content/{fake_id}",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"patch": "--- a\n+++ b\n",
|
||||
"base_hash": "0" * 64,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_then_get_consistency(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
text_file: dict[str, str | int],
|
||||
) -> None:
|
||||
"""PATCH 后 GET 返回一致的内容和哈希"""
|
||||
patch_text = (
|
||||
"--- a\n"
|
||||
"+++ b\n"
|
||||
"@@ -1,3 +1,3 @@\n"
|
||||
" line1\n"
|
||||
"-line2\n"
|
||||
"+PATCHED\n"
|
||||
" line3\n"
|
||||
)
|
||||
|
||||
# PATCH
|
||||
patch_resp = await async_client.patch(
|
||||
f"/api/v1/file/content/{text_file['id']}",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"patch": patch_text,
|
||||
"base_hash": text_file["hash"],
|
||||
},
|
||||
)
|
||||
assert patch_resp.status_code == 200
|
||||
patch_data = patch_resp.json()
|
||||
|
||||
# GET
|
||||
get_resp = await async_client.get(
|
||||
f"/api/v1/file/content/{text_file['id']}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert get_resp.status_code == 200
|
||||
get_data = get_resp.json()
|
||||
|
||||
# 一致性验证
|
||||
assert get_data["hash"] == patch_data["new_hash"]
|
||||
assert get_data["size"] == patch_data["new_size"]
|
||||
assert "PATCHED" in get_data["content"]
|
||||
305
tests/integration/api/test_file_viewers.py
Normal file
305
tests/integration/api/test_file_viewers.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
文件查看器集成测试
|
||||
|
||||
测试查看器查询、用户默认设置、用户组过滤等端点。
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.file_app import (
|
||||
FileApp,
|
||||
FileAppExtension,
|
||||
FileAppGroupLink,
|
||||
FileAppType,
|
||||
UserFileAppDefault,
|
||||
)
|
||||
from sqlmodels.user import User
|
||||
|
||||
|
||||
# ==================== Fixtures ====================
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup_file_apps(
|
||||
initialized_db: AsyncSession,
|
||||
) -> dict[str, UUID]:
|
||||
"""创建测试用文件查看器应用"""
|
||||
# PDF 阅读器(不限制用户组)
|
||||
pdf_app = FileApp(
|
||||
name="PDF 阅读器",
|
||||
app_key="pdfjs",
|
||||
type=FileAppType.BUILTIN,
|
||||
is_enabled=True,
|
||||
is_restricted=False,
|
||||
)
|
||||
pdf_app = await pdf_app.save(initialized_db)
|
||||
|
||||
# Monaco 编辑器(不限制用户组)
|
||||
monaco_app = FileApp(
|
||||
name="代码编辑器",
|
||||
app_key="monaco",
|
||||
type=FileAppType.BUILTIN,
|
||||
is_enabled=True,
|
||||
is_restricted=False,
|
||||
)
|
||||
monaco_app = await monaco_app.save(initialized_db)
|
||||
|
||||
# Collabora(限制用户组)
|
||||
collabora_app = FileApp(
|
||||
name="Collabora",
|
||||
app_key="collabora",
|
||||
type=FileAppType.WOPI,
|
||||
is_enabled=True,
|
||||
is_restricted=True,
|
||||
)
|
||||
collabora_app = await collabora_app.save(initialized_db)
|
||||
|
||||
# 已禁用的应用
|
||||
disabled_app = FileApp(
|
||||
name="禁用的应用",
|
||||
app_key="disabled_app",
|
||||
type=FileAppType.BUILTIN,
|
||||
is_enabled=False,
|
||||
is_restricted=False,
|
||||
)
|
||||
disabled_app = await disabled_app.save(initialized_db)
|
||||
|
||||
# 创建扩展名
|
||||
for ext in ["pdf"]:
|
||||
await FileAppExtension(app_id=pdf_app.id, extension=ext, priority=0).save(initialized_db)
|
||||
|
||||
for ext in ["txt", "md", "json"]:
|
||||
await FileAppExtension(app_id=monaco_app.id, extension=ext, priority=0).save(initialized_db)
|
||||
|
||||
for ext in ["docx", "xlsx", "pptx"]:
|
||||
await FileAppExtension(app_id=collabora_app.id, extension=ext, priority=0).save(initialized_db)
|
||||
|
||||
for ext in ["pdf"]:
|
||||
await FileAppExtension(app_id=disabled_app.id, extension=ext, priority=10).save(initialized_db)
|
||||
|
||||
return {
|
||||
"pdf_app_id": pdf_app.id,
|
||||
"monaco_app_id": monaco_app.id,
|
||||
"collabora_app_id": collabora_app.id,
|
||||
"disabled_app_id": disabled_app.id,
|
||||
}
|
||||
|
||||
|
||||
# ==================== GET /file/viewers ====================
|
||||
|
||||
class TestGetViewers:
|
||||
"""查询可用查看器测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_viewers_for_pdf(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
setup_file_apps: dict[str, UUID],
|
||||
) -> None:
|
||||
"""查询 PDF 查看器:返回已启用的,排除已禁用的"""
|
||||
response = await async_client.get(
|
||||
"/api/v1/file/viewers?ext=pdf",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert "viewers" in data
|
||||
viewer_keys = [v["app_key"] for v in data["viewers"]]
|
||||
|
||||
# pdfjs 应该在列表中
|
||||
assert "pdfjs" in viewer_keys
|
||||
# 禁用的应用不应出现
|
||||
assert "disabled_app" not in viewer_keys
|
||||
# 默认值应为 None
|
||||
assert data["default_viewer_id"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_viewers_normalizes_extension(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
setup_file_apps: dict[str, UUID],
|
||||
) -> None:
|
||||
"""扩展名规范化:.PDF → pdf"""
|
||||
response = await async_client.get(
|
||||
"/api/v1/file/viewers?ext=.PDF",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["viewers"]) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_viewers_empty_for_unknown_ext(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
setup_file_apps: dict[str, UUID],
|
||||
) -> None:
|
||||
"""未知扩展名返回空列表"""
|
||||
response = await async_client.get(
|
||||
"/api/v1/file/viewers?ext=xyz_unknown",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["viewers"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_restriction_filters_app(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
initialized_db: AsyncSession,
|
||||
auth_headers: dict[str, str],
|
||||
setup_file_apps: dict[str, UUID],
|
||||
) -> None:
|
||||
"""用户组限制:collabora 限制了用户组,用户不在白名单内则不可见"""
|
||||
# collabora 是受限的,用户组不在白名单中
|
||||
response = await async_client.get(
|
||||
"/api/v1/file/viewers?ext=docx",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
viewer_keys = [v["app_key"] for v in data["viewers"]]
|
||||
assert "collabora" not in viewer_keys
|
||||
|
||||
# 将用户组加入白名单
|
||||
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||
link = FileAppGroupLink(
|
||||
app_id=setup_file_apps["collabora_app_id"],
|
||||
group_id=test_user.group_id,
|
||||
)
|
||||
initialized_db.add(link)
|
||||
await initialized_db.commit()
|
||||
|
||||
# 再次查询
|
||||
response = await async_client.get(
|
||||
"/api/v1/file/viewers?ext=docx",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
viewer_keys = [v["app_key"] for v in data["viewers"]]
|
||||
assert "collabora" in viewer_keys
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_without_token(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
) -> None:
|
||||
"""未认证请求返回 401"""
|
||||
response = await async_client.get("/api/v1/file/viewers?ext=pdf")
|
||||
assert response.status_code in (401, 403)
|
||||
|
||||
|
||||
# ==================== User File Viewer Defaults ====================
|
||||
|
||||
class TestUserFileViewerDefaults:
|
||||
"""用户默认查看器设置测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_default_viewer(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str],
|
||||
setup_file_apps: dict[str, UUID],
|
||||
) -> None:
|
||||
"""设置默认查看器"""
|
||||
response = await async_client.put(
|
||||
"/api/v1/user/settings/file-viewers/default",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"extension": "pdf",
|
||||
"app_id": str(setup_file_apps["pdf_app_id"]),
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["extension"] == "pdf"
|
||||
assert data["app"]["app_key"] == "pdfjs"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_default_viewers(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
initialized_db: AsyncSession,
|
||||
auth_headers: dict[str, str],
|
||||
setup_file_apps: dict[str, UUID],
|
||||
) -> None:
|
||||
"""列出默认查看器"""
|
||||
# 先创建一个默认
|
||||
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||
await UserFileAppDefault(
|
||||
user_id=test_user.id,
|
||||
extension="pdf",
|
||||
app_id=setup_file_apps["pdf_app_id"],
|
||||
).save(initialized_db)
|
||||
|
||||
response = await async_client.get(
|
||||
"/api/v1/user/settings/file-viewers/defaults",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_default_viewer(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
initialized_db: AsyncSession,
|
||||
auth_headers: dict[str, str],
|
||||
setup_file_apps: dict[str, UUID],
|
||||
) -> None:
|
||||
"""撤销默认查看器"""
|
||||
# 创建一个默认
|
||||
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||
default = await UserFileAppDefault(
|
||||
user_id=test_user.id,
|
||||
extension="txt",
|
||||
app_id=setup_file_apps["monaco_app_id"],
|
||||
).save(initialized_db)
|
||||
|
||||
response = await async_client.delete(
|
||||
f"/api/v1/user/settings/file-viewers/default/{default.id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
# 验证已删除
|
||||
found = await UserFileAppDefault.get(
|
||||
initialized_db, UserFileAppDefault.id == default.id
|
||||
)
|
||||
assert found is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_viewers_includes_default(
|
||||
self,
|
||||
async_client: AsyncClient,
|
||||
initialized_db: AsyncSession,
|
||||
auth_headers: dict[str, str],
|
||||
setup_file_apps: dict[str, UUID],
|
||||
) -> None:
|
||||
"""查看器查询应包含用户默认选择"""
|
||||
# 设置默认
|
||||
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||
await UserFileAppDefault(
|
||||
user_id=test_user.id,
|
||||
extension="pdf",
|
||||
app_id=setup_file_apps["pdf_app_id"],
|
||||
).save(initialized_db)
|
||||
|
||||
response = await async_client.get(
|
||||
"/api/v1/file/viewers?ext=pdf",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["default_viewer_id"] == str(setup_file_apps["pdf_app_id"])
|
||||
386
tests/unit/models/test_file_app.py
Normal file
386
tests/unit/models/test_file_app.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
FileApp 模型单元测试
|
||||
|
||||
测试 FileApp、FileAppExtension、UserFileAppDefault 的 CRUD 和约束。
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.file_app import (
|
||||
FileApp,
|
||||
FileAppExtension,
|
||||
FileAppGroupLink,
|
||||
FileAppType,
|
||||
UserFileAppDefault,
|
||||
)
|
||||
from sqlmodels.group import Group
|
||||
from sqlmodels.user import User, UserStatus
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
|
||||
# ==================== Fixtures ====================
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_group(db_session: AsyncSession) -> Group:
|
||||
"""创建测试用户组"""
|
||||
group = Group(name="测试组", max_storage=0, admin=False)
|
||||
return await group.save(db_session)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_user(db_session: AsyncSession, sample_group: Group) -> User:
|
||||
"""创建测试用户"""
|
||||
user = User(
|
||||
email="fileapp_test@test.local",
|
||||
nickname="文件应用测试用户",
|
||||
status=UserStatus.ACTIVE,
|
||||
group_id=sample_group.id,
|
||||
)
|
||||
return await user.save(db_session)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_app(db_session: AsyncSession) -> FileApp:
|
||||
"""创建测试文件应用"""
|
||||
app = FileApp(
|
||||
name="测试PDF阅读器",
|
||||
app_key="test_pdfjs",
|
||||
type=FileAppType.BUILTIN,
|
||||
icon="file-pdf",
|
||||
description="测试用 PDF 阅读器",
|
||||
is_enabled=True,
|
||||
is_restricted=False,
|
||||
)
|
||||
return await app.save(db_session)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_app_with_extensions(db_session: AsyncSession, sample_app: FileApp) -> FileApp:
|
||||
"""创建带扩展名的文件应用"""
|
||||
ext1 = FileAppExtension(app_id=sample_app.id, extension="pdf", priority=0)
|
||||
ext2 = FileAppExtension(app_id=sample_app.id, extension="djvu", priority=1)
|
||||
await ext1.save(db_session)
|
||||
await ext2.save(db_session)
|
||||
return sample_app
|
||||
|
||||
|
||||
# ==================== FileApp CRUD ====================
|
||||
|
||||
class TestFileAppCRUD:
|
||||
"""FileApp 基础 CRUD 测试"""
|
||||
|
||||
async def test_create_file_app(self, db_session: AsyncSession) -> None:
|
||||
"""测试创建文件应用"""
|
||||
app = FileApp(
|
||||
name="Monaco 编辑器",
|
||||
app_key="monaco",
|
||||
type=FileAppType.BUILTIN,
|
||||
description="代码编辑器",
|
||||
is_enabled=True,
|
||||
)
|
||||
app = await app.save(db_session)
|
||||
|
||||
assert app.id is not None
|
||||
assert app.name == "Monaco 编辑器"
|
||||
assert app.app_key == "monaco"
|
||||
assert app.type == FileAppType.BUILTIN
|
||||
assert app.is_enabled is True
|
||||
assert app.is_restricted is False
|
||||
|
||||
async def test_get_file_app_by_key(self, db_session: AsyncSession, sample_app: FileApp) -> None:
|
||||
"""测试按 app_key 查询"""
|
||||
found = await FileApp.get(db_session, FileApp.app_key == "test_pdfjs")
|
||||
assert found is not None
|
||||
assert found.id == sample_app.id
|
||||
|
||||
async def test_unique_app_key(self, db_session: AsyncSession, sample_app: FileApp) -> None:
|
||||
"""测试 app_key 唯一约束"""
|
||||
dup = FileApp(
|
||||
name="重复应用",
|
||||
app_key="test_pdfjs",
|
||||
type=FileAppType.BUILTIN,
|
||||
)
|
||||
with pytest.raises(IntegrityError):
|
||||
await dup.save(db_session)
|
||||
|
||||
async def test_update_file_app(self, db_session: AsyncSession, sample_app: FileApp) -> None:
|
||||
"""测试更新文件应用"""
|
||||
sample_app.name = "更新后的名称"
|
||||
sample_app.is_enabled = False
|
||||
sample_app = await sample_app.save(db_session)
|
||||
|
||||
found = await FileApp.get(db_session, FileApp.id == sample_app.id)
|
||||
assert found.name == "更新后的名称"
|
||||
assert found.is_enabled is False
|
||||
|
||||
async def test_delete_file_app(self, db_session: AsyncSession) -> None:
|
||||
"""测试删除文件应用"""
|
||||
app = FileApp(
|
||||
name="待删除应用",
|
||||
app_key="to_delete",
|
||||
type=FileAppType.IFRAME,
|
||||
)
|
||||
app = await app.save(db_session)
|
||||
app_id = app.id
|
||||
|
||||
await FileApp.delete(db_session, app)
|
||||
|
||||
found = await FileApp.get(db_session, FileApp.id == app_id)
|
||||
assert found is None
|
||||
|
||||
async def test_create_wopi_app(self, db_session: AsyncSession) -> None:
|
||||
"""测试创建 WOPI 类型应用"""
|
||||
app = FileApp(
|
||||
name="Collabora",
|
||||
app_key="collabora",
|
||||
type=FileAppType.WOPI,
|
||||
wopi_discovery_url="http://collabora:9980/hosting/discovery",
|
||||
wopi_editor_url_template="http://collabora:9980/loleaflet/dist/loleaflet.html?WOPISrc={wopi_src}&access_token={access_token}",
|
||||
is_enabled=True,
|
||||
)
|
||||
app = await app.save(db_session)
|
||||
|
||||
assert app.type == FileAppType.WOPI
|
||||
assert app.wopi_discovery_url is not None
|
||||
assert app.wopi_editor_url_template is not None
|
||||
|
||||
async def test_create_iframe_app(self, db_session: AsyncSession) -> None:
|
||||
"""测试创建 iframe 类型应用"""
|
||||
app = FileApp(
|
||||
name="Office 在线预览",
|
||||
app_key="office_viewer",
|
||||
type=FileAppType.IFRAME,
|
||||
iframe_url_template="https://view.officeapps.live.com/op/embed.aspx?src={file_url}",
|
||||
is_enabled=False,
|
||||
)
|
||||
app = await app.save(db_session)
|
||||
|
||||
assert app.type == FileAppType.IFRAME
|
||||
assert "{file_url}" in app.iframe_url_template
|
||||
|
||||
async def test_to_summary(self, db_session: AsyncSession, sample_app: FileApp) -> None:
|
||||
"""测试转换为摘要 DTO"""
|
||||
summary = sample_app.to_summary()
|
||||
assert summary.id == sample_app.id
|
||||
assert summary.name == sample_app.name
|
||||
assert summary.app_key == sample_app.app_key
|
||||
assert summary.type == sample_app.type
|
||||
|
||||
|
||||
# ==================== FileAppExtension ====================
|
||||
|
||||
class TestFileAppExtension:
|
||||
"""FileAppExtension 测试"""
|
||||
|
||||
async def test_create_extension(self, db_session: AsyncSession, sample_app: FileApp) -> None:
|
||||
"""测试创建扩展名关联"""
|
||||
ext = FileAppExtension(
|
||||
app_id=sample_app.id,
|
||||
extension="pdf",
|
||||
priority=0,
|
||||
)
|
||||
ext = await ext.save(db_session)
|
||||
|
||||
assert ext.id is not None
|
||||
assert ext.extension == "pdf"
|
||||
assert ext.priority == 0
|
||||
|
||||
async def test_query_by_extension(
|
||||
self, db_session: AsyncSession, sample_app_with_extensions: FileApp
|
||||
) -> None:
|
||||
"""测试按扩展名查询"""
|
||||
results: list[FileAppExtension] = await FileAppExtension.get(
|
||||
db_session,
|
||||
FileAppExtension.extension == "pdf",
|
||||
fetch_mode="all",
|
||||
)
|
||||
assert len(results) >= 1
|
||||
assert any(r.app_id == sample_app_with_extensions.id for r in results)
|
||||
|
||||
async def test_unique_app_extension(self, db_session: AsyncSession, sample_app: FileApp) -> None:
|
||||
"""测试 (app_id, extension) 唯一约束"""
|
||||
ext1 = FileAppExtension(app_id=sample_app.id, extension="txt", priority=0)
|
||||
await ext1.save(db_session)
|
||||
|
||||
ext2 = FileAppExtension(app_id=sample_app.id, extension="txt", priority=1)
|
||||
with pytest.raises(IntegrityError):
|
||||
await ext2.save(db_session)
|
||||
|
||||
async def test_cascade_delete(
|
||||
self, db_session: AsyncSession, sample_app_with_extensions: FileApp
|
||||
) -> None:
|
||||
"""测试级联删除:删除应用时扩展名也被删除"""
|
||||
app_id = sample_app_with_extensions.id
|
||||
|
||||
# 确认扩展名存在
|
||||
exts = await FileAppExtension.get(
|
||||
db_session,
|
||||
FileAppExtension.app_id == app_id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
assert len(exts) == 2
|
||||
|
||||
# 删除应用
|
||||
await FileApp.delete(db_session, sample_app_with_extensions)
|
||||
|
||||
# 确认扩展名也被删除
|
||||
exts = await FileAppExtension.get(
|
||||
db_session,
|
||||
FileAppExtension.app_id == app_id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
assert len(exts) == 0
|
||||
|
||||
|
||||
# ==================== FileAppGroupLink ====================
|
||||
|
||||
class TestFileAppGroupLink:
|
||||
"""FileAppGroupLink 用户组访问控制测试"""
|
||||
|
||||
async def test_create_group_link(
|
||||
self, db_session: AsyncSession, sample_app: FileApp, sample_group: Group
|
||||
) -> None:
|
||||
"""测试创建用户组关联"""
|
||||
link = FileAppGroupLink(app_id=sample_app.id, group_id=sample_group.id)
|
||||
db_session.add(link)
|
||||
await db_session.commit()
|
||||
|
||||
result = await db_session.exec(
|
||||
select(FileAppGroupLink).where(
|
||||
FileAppGroupLink.app_id == sample_app.id,
|
||||
FileAppGroupLink.group_id == sample_group.id,
|
||||
)
|
||||
)
|
||||
found = result.first()
|
||||
assert found is not None
|
||||
|
||||
async def test_multiple_groups(self, db_session: AsyncSession, sample_app: FileApp) -> None:
|
||||
"""测试一个应用关联多个用户组"""
|
||||
group1 = Group(name="组A", admin=False)
|
||||
group1 = await group1.save(db_session)
|
||||
group2 = Group(name="组B", admin=False)
|
||||
group2 = await group2.save(db_session)
|
||||
|
||||
db_session.add(FileAppGroupLink(app_id=sample_app.id, group_id=group1.id))
|
||||
db_session.add(FileAppGroupLink(app_id=sample_app.id, group_id=group2.id))
|
||||
await db_session.commit()
|
||||
|
||||
result = await db_session.exec(
|
||||
select(FileAppGroupLink).where(FileAppGroupLink.app_id == sample_app.id)
|
||||
)
|
||||
links = result.all()
|
||||
assert len(links) == 2
|
||||
|
||||
|
||||
# ==================== UserFileAppDefault ====================
|
||||
|
||||
class TestUserFileAppDefault:
|
||||
"""UserFileAppDefault 用户偏好测试"""
|
||||
|
||||
async def test_create_default(
|
||||
self, db_session: AsyncSession, sample_app: FileApp, sample_user: User
|
||||
) -> None:
|
||||
"""测试创建用户默认偏好"""
|
||||
default = UserFileAppDefault(
|
||||
user_id=sample_user.id,
|
||||
extension="pdf",
|
||||
app_id=sample_app.id,
|
||||
)
|
||||
default = await default.save(db_session)
|
||||
|
||||
assert default.id is not None
|
||||
assert default.extension == "pdf"
|
||||
|
||||
async def test_unique_user_extension(
|
||||
self, db_session: AsyncSession, sample_app: FileApp, sample_user: User
|
||||
) -> None:
|
||||
"""测试 (user_id, extension) 唯一约束"""
|
||||
default1 = UserFileAppDefault(
|
||||
user_id=sample_user.id, extension="pdf", app_id=sample_app.id
|
||||
)
|
||||
await default1.save(db_session)
|
||||
|
||||
# 创建另一个应用
|
||||
app2 = FileApp(
|
||||
name="另一个阅读器",
|
||||
app_key="pdf_alt",
|
||||
type=FileAppType.BUILTIN,
|
||||
)
|
||||
app2 = await app2.save(db_session)
|
||||
|
||||
default2 = UserFileAppDefault(
|
||||
user_id=sample_user.id, extension="pdf", app_id=app2.id
|
||||
)
|
||||
with pytest.raises(IntegrityError):
|
||||
await default2.save(db_session)
|
||||
|
||||
async def test_cascade_delete_on_app(
|
||||
self, db_session: AsyncSession, sample_user: User
|
||||
) -> None:
|
||||
"""测试级联删除:删除应用时用户偏好也被删除"""
|
||||
app = FileApp(
|
||||
name="待删除应用2",
|
||||
app_key="to_delete_2",
|
||||
type=FileAppType.BUILTIN,
|
||||
)
|
||||
app = await app.save(db_session)
|
||||
app_id = app.id
|
||||
|
||||
default = UserFileAppDefault(
|
||||
user_id=sample_user.id, extension="xyz", app_id=app_id
|
||||
)
|
||||
await default.save(db_session)
|
||||
|
||||
# 确认存在
|
||||
found = await UserFileAppDefault.get(
|
||||
db_session, UserFileAppDefault.app_id == app_id
|
||||
)
|
||||
assert found is not None
|
||||
|
||||
# 删除应用
|
||||
await FileApp.delete(db_session, app)
|
||||
|
||||
# 确认用户偏好也被删除
|
||||
found = await UserFileAppDefault.get(
|
||||
db_session, UserFileAppDefault.app_id == app_id
|
||||
)
|
||||
assert found is None
|
||||
|
||||
|
||||
# ==================== DTO ====================
|
||||
|
||||
class TestFileAppDTO:
|
||||
"""DTO 模型测试"""
|
||||
|
||||
async def test_file_app_response_from_app(
|
||||
self, db_session: AsyncSession, sample_app_with_extensions: FileApp, sample_group: Group
|
||||
) -> None:
|
||||
"""测试 FileAppResponse.from_app()"""
|
||||
from sqlmodels.file_app import FileAppResponse
|
||||
|
||||
extensions = await FileAppExtension.get(
|
||||
db_session,
|
||||
FileAppExtension.app_id == sample_app_with_extensions.id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
|
||||
# 直接构造 link 对象用于 DTO 测试,无需持久化
|
||||
link = FileAppGroupLink(
|
||||
app_id=sample_app_with_extensions.id,
|
||||
group_id=sample_group.id,
|
||||
)
|
||||
|
||||
response = FileAppResponse.from_app(
|
||||
sample_app_with_extensions, extensions, [link]
|
||||
)
|
||||
|
||||
assert response.id == sample_app_with_extensions.id
|
||||
assert response.app_key == "test_pdfjs"
|
||||
assert "pdf" in response.extensions
|
||||
assert "djvu" in response.extensions
|
||||
assert sample_group.id in response.allowed_group_ids
|
||||
@@ -113,7 +113,7 @@ async def test_setting_update_value(db_session: AsyncSession):
|
||||
setting = await setting.save(db_session)
|
||||
|
||||
# 更新值
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
|
||||
class SettingUpdate(SQLModelBase):
|
||||
value: str | None = None
|
||||
|
||||
178
tests/unit/utils/test_patch.py
Normal file
178
tests/unit/utils/test_patch.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
文本文件 patch 逻辑单元测试
|
||||
|
||||
测试 whatthepatch 库的 patch 解析与应用,
|
||||
以及换行符规范化和 SHA-256 哈希计算。
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
import pytest
|
||||
import whatthepatch
|
||||
from whatthepatch.exceptions import HunkApplyException
|
||||
|
||||
|
||||
class TestPatchApply:
|
||||
"""测试 patch 解析与应用"""
|
||||
|
||||
def test_normal_patch(self) -> None:
|
||||
"""正常 patch 应用"""
|
||||
original = "line1\nline2\nline3"
|
||||
patch_text = (
|
||||
"--- a\n"
|
||||
"+++ b\n"
|
||||
"@@ -1,3 +1,3 @@\n"
|
||||
" line1\n"
|
||||
"-line2\n"
|
||||
"+LINE2\n"
|
||||
" line3\n"
|
||||
)
|
||||
|
||||
diffs = list(whatthepatch.parse_patch(patch_text))
|
||||
assert len(diffs) == 1
|
||||
|
||||
result = whatthepatch.apply_diff(diffs[0], original)
|
||||
new_text = '\n'.join(result)
|
||||
|
||||
assert "LINE2" in new_text
|
||||
assert "line2" not in new_text
|
||||
|
||||
def test_add_lines_patch(self) -> None:
|
||||
"""添加行的 patch"""
|
||||
original = "line1\nline2"
|
||||
patch_text = (
|
||||
"--- a\n"
|
||||
"+++ b\n"
|
||||
"@@ -1,2 +1,3 @@\n"
|
||||
" line1\n"
|
||||
" line2\n"
|
||||
"+line3\n"
|
||||
)
|
||||
|
||||
diffs = list(whatthepatch.parse_patch(patch_text))
|
||||
result = whatthepatch.apply_diff(diffs[0], original)
|
||||
new_text = '\n'.join(result)
|
||||
|
||||
assert "line3" in new_text
|
||||
|
||||
def test_delete_lines_patch(self) -> None:
|
||||
"""删除行的 patch"""
|
||||
original = "line1\nline2\nline3"
|
||||
patch_text = (
|
||||
"--- a\n"
|
||||
"+++ b\n"
|
||||
"@@ -1,3 +1,2 @@\n"
|
||||
" line1\n"
|
||||
"-line2\n"
|
||||
" line3\n"
|
||||
)
|
||||
|
||||
diffs = list(whatthepatch.parse_patch(patch_text))
|
||||
result = whatthepatch.apply_diff(diffs[0], original)
|
||||
new_text = '\n'.join(result)
|
||||
|
||||
assert "line2" not in new_text
|
||||
assert "line1" in new_text
|
||||
assert "line3" in new_text
|
||||
|
||||
def test_invalid_patch_format(self) -> None:
|
||||
"""无效的 patch 格式返回空列表"""
|
||||
diffs = list(whatthepatch.parse_patch("this is not a patch"))
|
||||
assert len(diffs) == 0
|
||||
|
||||
def test_patch_context_mismatch(self) -> None:
|
||||
"""patch 上下文不匹配时抛出异常"""
|
||||
original = "line1\nline2\nline3\n"
|
||||
patch_text = (
|
||||
"--- a\n"
|
||||
"+++ b\n"
|
||||
"@@ -1,3 +1,3 @@\n"
|
||||
" line1\n"
|
||||
"-WRONG\n"
|
||||
"+REPLACED\n"
|
||||
" line3\n"
|
||||
)
|
||||
|
||||
diffs = list(whatthepatch.parse_patch(patch_text))
|
||||
with pytest.raises(HunkApplyException):
|
||||
whatthepatch.apply_diff(diffs[0], original)
|
||||
|
||||
def test_empty_file_patch(self) -> None:
|
||||
"""空文件应用 patch"""
|
||||
original = ""
|
||||
patch_text = (
|
||||
"--- a\n"
|
||||
"+++ b\n"
|
||||
"@@ -0,0 +1,2 @@\n"
|
||||
"+line1\n"
|
||||
"+line2\n"
|
||||
)
|
||||
|
||||
diffs = list(whatthepatch.parse_patch(patch_text))
|
||||
result = whatthepatch.apply_diff(diffs[0], original)
|
||||
new_text = '\n'.join(result)
|
||||
|
||||
assert "line1" in new_text
|
||||
assert "line2" in new_text
|
||||
|
||||
|
||||
class TestHashComputation:
|
||||
"""测试 SHA-256 哈希计算"""
|
||||
|
||||
def test_hash_consistency(self) -> None:
|
||||
"""相同内容产生相同哈希"""
|
||||
content = "hello world\n"
|
||||
content_bytes = content.encode('utf-8')
|
||||
hash1 = hashlib.sha256(content_bytes).hexdigest()
|
||||
hash2 = hashlib.sha256(content_bytes).hexdigest()
|
||||
|
||||
assert hash1 == hash2
|
||||
assert len(hash1) == 64
|
||||
|
||||
def test_hash_differs_for_different_content(self) -> None:
|
||||
"""不同内容产生不同哈希"""
|
||||
hash1 = hashlib.sha256(b"content A").hexdigest()
|
||||
hash2 = hashlib.sha256(b"content B").hexdigest()
|
||||
|
||||
assert hash1 != hash2
|
||||
|
||||
def test_hash_after_normalization(self) -> None:
|
||||
"""换行符规范化后的哈希一致性"""
|
||||
content_crlf = "line1\r\nline2\r\n"
|
||||
content_lf = "line1\nline2\n"
|
||||
|
||||
# 规范化后应相同
|
||||
normalized = content_crlf.replace('\r\n', '\n').replace('\r', '\n')
|
||||
assert normalized == content_lf
|
||||
|
||||
hash_normalized = hashlib.sha256(normalized.encode('utf-8')).hexdigest()
|
||||
hash_lf = hashlib.sha256(content_lf.encode('utf-8')).hexdigest()
|
||||
|
||||
assert hash_normalized == hash_lf
|
||||
|
||||
|
||||
class TestLineEndingNormalization:
|
||||
"""测试换行符规范化"""
|
||||
|
||||
def test_crlf_to_lf(self) -> None:
|
||||
"""CRLF 转换为 LF"""
|
||||
content = "line1\r\nline2\r\n"
|
||||
normalized = content.replace('\r\n', '\n').replace('\r', '\n')
|
||||
assert normalized == "line1\nline2\n"
|
||||
|
||||
def test_cr_to_lf(self) -> None:
|
||||
"""CR 转换为 LF"""
|
||||
content = "line1\rline2\r"
|
||||
normalized = content.replace('\r\n', '\n').replace('\r', '\n')
|
||||
assert normalized == "line1\nline2\n"
|
||||
|
||||
def test_lf_unchanged(self) -> None:
|
||||
"""LF 保持不变"""
|
||||
content = "line1\nline2\n"
|
||||
normalized = content.replace('\r\n', '\n').replace('\r', '\n')
|
||||
assert normalized == content
|
||||
|
||||
def test_mixed_line_endings(self) -> None:
|
||||
"""混合换行符统一为 LF"""
|
||||
content = "line1\r\nline2\rline3\n"
|
||||
normalized = content.replace('\r\n', '\n').replace('\r', '\n')
|
||||
assert normalized == "line1\nline2\nline3\n"
|
||||
77
tests/unit/utils/test_wopi_token.py
Normal file
77
tests/unit/utils/test_wopi_token.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
WOPI Token 单元测试
|
||||
|
||||
测试 WOPI 访问令牌的生成和验证。
|
||||
"""
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
import utils.JWT as JWT
|
||||
from utils.JWT.wopi_token import create_wopi_token, verify_wopi_token
|
||||
|
||||
# 确保测试 secret key
|
||||
JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation"
|
||||
|
||||
|
||||
class TestWopiToken:
|
||||
"""WOPI Token 测试"""
|
||||
|
||||
def test_create_and_verify_token(self) -> None:
|
||||
"""创建和验证令牌"""
|
||||
file_id = uuid4()
|
||||
user_id = uuid4()
|
||||
|
||||
token, ttl = create_wopi_token(file_id, user_id, can_write=True)
|
||||
|
||||
assert isinstance(token, str)
|
||||
assert isinstance(ttl, int)
|
||||
assert ttl > 0
|
||||
|
||||
payload = verify_wopi_token(token)
|
||||
assert payload is not None
|
||||
assert payload.file_id == file_id
|
||||
assert payload.user_id == user_id
|
||||
assert payload.can_write is True
|
||||
|
||||
def test_verify_read_only_token(self) -> None:
|
||||
"""验证只读令牌"""
|
||||
file_id = uuid4()
|
||||
user_id = uuid4()
|
||||
|
||||
token, ttl = create_wopi_token(file_id, user_id, can_write=False)
|
||||
|
||||
payload = verify_wopi_token(token)
|
||||
assert payload is not None
|
||||
assert payload.can_write is False
|
||||
|
||||
def test_verify_invalid_token(self) -> None:
|
||||
"""验证无效令牌返回 None"""
|
||||
payload = verify_wopi_token("invalid_token_string")
|
||||
assert payload is None
|
||||
|
||||
def test_verify_non_wopi_token(self) -> None:
|
||||
"""验证非 WOPI 类型令牌返回 None"""
|
||||
import jwt as pyjwt
|
||||
# 创建一个不含 type=wopi 的令牌
|
||||
token = pyjwt.encode(
|
||||
{"file_id": str(uuid4()), "user_id": str(uuid4()), "type": "download"},
|
||||
JWT.SECRET_KEY,
|
||||
algorithm="HS256",
|
||||
)
|
||||
payload = verify_wopi_token(token)
|
||||
assert payload is None
|
||||
|
||||
def test_ttl_is_future_milliseconds(self) -> None:
|
||||
"""TTL 应为未来的毫秒时间戳"""
|
||||
import time
|
||||
|
||||
file_id = uuid4()
|
||||
user_id = uuid4()
|
||||
token, ttl = create_wopi_token(file_id, user_id)
|
||||
|
||||
current_ms = int(time.time() * 1000)
|
||||
# TTL 应大于当前时间
|
||||
assert ttl > current_ms
|
||||
# TTL 不应超过 11 小时后(10h + 余量)
|
||||
assert ttl < current_ms + 11 * 3600 * 1000
|
||||
@@ -25,11 +25,11 @@ async def load_secret_key() -> None:
|
||||
从数据库读取 JWT 的密钥。
|
||||
"""
|
||||
# 延迟导入以避免循环依赖
|
||||
from sqlmodels.database import get_session
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodels.setting import Setting
|
||||
|
||||
global SECRET_KEY
|
||||
async for session in get_session():
|
||||
async for session in DatabaseManager.get_session():
|
||||
setting: Setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == "auth") & (Setting.name == "secret_key")
|
||||
|
||||
67
utils/JWT/wopi_token.py
Normal file
67
utils/JWT/wopi_token.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
WOPI 访问令牌生成与验证。
|
||||
|
||||
使用 JWT 签名,payload 包含 file_id, user_id, can_write, exp。
|
||||
TTL 默认 10 小时(WOPI 规范推荐长 TTL)。
|
||||
"""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import jwt
|
||||
|
||||
from sqlmodels.wopi import WopiAccessTokenPayload
|
||||
|
||||
WOPI_TOKEN_TTL = timedelta(hours=10)
|
||||
"""WOPI 令牌有效期"""
|
||||
|
||||
|
||||
def create_wopi_token(
|
||||
file_id: UUID,
|
||||
user_id: UUID,
|
||||
can_write: bool = False,
|
||||
) -> tuple[str, int]:
|
||||
"""
|
||||
创建 WOPI 访问令牌。
|
||||
|
||||
:param file_id: 文件UUID
|
||||
:param user_id: 用户UUID
|
||||
:param can_write: 是否可写
|
||||
:return: (token_string, access_token_ttl_ms)
|
||||
"""
|
||||
from utils.JWT import SECRET_KEY
|
||||
|
||||
expire = datetime.now(timezone.utc) + WOPI_TOKEN_TTL
|
||||
payload = {
|
||||
"jti": str(uuid4()),
|
||||
"file_id": str(file_id),
|
||||
"user_id": str(user_id),
|
||||
"can_write": can_write,
|
||||
"exp": expire,
|
||||
"type": "wopi",
|
||||
}
|
||||
token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")
|
||||
# WOPI 规范要求 access_token_ttl 是毫秒级的 UNIX 时间戳
|
||||
access_token_ttl = int(expire.timestamp() * 1000)
|
||||
return token, access_token_ttl
|
||||
|
||||
|
||||
def verify_wopi_token(token: str) -> WopiAccessTokenPayload | None:
|
||||
"""
|
||||
验证 WOPI 访问令牌并返回 payload。
|
||||
|
||||
:param token: JWT 令牌字符串
|
||||
:return: WopiAccessTokenPayload 或 None(验证失败)
|
||||
"""
|
||||
from utils.JWT import SECRET_KEY
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
|
||||
if payload.get("type") != "wopi":
|
||||
return None
|
||||
return WopiAccessTokenPayload(
|
||||
file_id=UUID(payload["file_id"]),
|
||||
user_id=UUID(payload["user_id"]),
|
||||
can_write=payload.get("can_write", False),
|
||||
)
|
||||
except (jwt.InvalidTokenError, KeyError, ValueError):
|
||||
return None
|
||||
@@ -40,6 +40,10 @@ def raise_conflict(detail: str | None = None) -> NoReturn:
|
||||
"""Raises an HTTP 409 Conflict exception."""
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=detail)
|
||||
|
||||
def raise_unprocessable_entity(detail: str | None = None) -> NoReturn:
|
||||
"""Raises an HTTP 422 Unprocessable Content exception."""
|
||||
raise HTTPException(status_code=422, detail=detail)
|
||||
|
||||
def raise_precondition_required(detail: str | None = None) -> NoReturn:
|
||||
"""Raises an HTTP 428 Precondition required exception."""
|
||||
raise HTTPException(status_code=status.HTTP_428_PRECONDITION_REQUIRED, detail=detail)
|
||||
|
||||
142
uv.lock
generated
142
uv.lock
generated
@@ -474,6 +474,32 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e8/cb/2da4cc83f5edb9c3257d09e1e7ab7b23f049c7962cae8d842bbef0a9cec9/cryptography-46.0.3-cp38-abi3-win_arm64.whl", hash = "sha256:d89c3468de4cdc4f08a57e214384d0471911a3830fcdaf7a8cc587e42a866372", size = 2918740, upload-time = "2025-10-15T23:18:12.277Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cython"
|
||||
version = "3.2.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/91/85/7574c9cd44b69a27210444b6650f6477f56c75fee1b70d7672d3e4166167/cython-3.2.4.tar.gz", hash = "sha256:84226ecd313b233da27dc2eb3601b4f222b8209c3a7216d8733b031da1dc64e6", size = 3280291, upload-time = "2026-01-04T14:14:14.473Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/18/b5/1cfca43b7d20a0fdb1eac67313d6bb6b18d18897f82dd0f17436bdd2ba7f/cython-3.2.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:28e8075087a59756f2d059273184b8b639fe0f16cf17470bd91c39921bc154e0", size = 2960506, upload-time = "2026-01-04T14:15:16.733Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/71/bb/8f28c39c342621047fea349a82fac712a5e2b37546d2f737bbde48d5143d/cython-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:03893c88299a2c868bb741ba6513357acd104e7c42265809fd58dce1456a36fc", size = 3213148, upload-time = "2026-01-04T14:15:18.804Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7a/d2/16fa02f129ed2b627e88d9d9ebd5ade3eeb66392ae5ba85b259d2d52b047/cython-3.2.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f81eda419b5ada7b197bbc3c5f4494090e3884521ffd75a3876c93fbf66c9ca8", size = 3375764, upload-time = "2026-01-04T14:15:20.817Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/91/3f/deb8f023a5c10c0649eb81332a58c180fad27c7533bb4aae138b5bc34d92/cython-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:83266c356c13c68ffe658b4905279c993d8a5337bb0160fa90c8a3e297ea9a2e", size = 2754238, upload-time = "2026-01-04T14:15:23.001Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ee/d7/3bda3efce0c5c6ce79cc21285dbe6f60369c20364e112f5a506ee8a1b067/cython-3.2.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d4b4fd5332ab093131fa6172e8362f16adef3eac3179fd24bbdc392531cb82fa", size = 2971496, upload-time = "2026-01-04T14:15:25.038Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/89/ed/1021ffc80b9c4720b7ba869aea8422c82c84245ef117ebe47a556bdc00c3/cython-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e3b5ac54e95f034bc7fb07313996d27cbf71abc17b229b186c1540942d2dc28e", size = 3256146, upload-time = "2026-01-04T14:15:26.741Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/51/ca221ec7e94b3c5dc4138dcdcbd41178df1729c1e88c5dfb25f9d30ba3da/cython-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:90f43be4eaa6afd58ce20d970bb1657a3627c44e1760630b82aa256ba74b4acb", size = 3383458, upload-time = "2026-01-04T14:15:28.425Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/2e/1388fc0243240cd54994bb74f26aaaf3b2e22f89d3a2cf8da06d75d46ca2/cython-3.2.4-cp314-cp314-win_amd64.whl", hash = "sha256:983f9d2bb8a896e16fa68f2b37866ded35fa980195eefe62f764ddc5f9f5ef8e", size = 2791241, upload-time = "2026-01-04T14:15:30.448Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0a/8b/fd393f0923c82be4ec0db712fffb2ff0a7a131707b842c99bf24b549274d/cython-3.2.4-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:36bf3f5eb56d5281aafabecbaa6ed288bc11db87547bba4e1e52943ae6961ccf", size = 2875622, upload-time = "2026-01-04T14:15:39.749Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/73/48/48530d9b9d64ec11dbe0dd3178a5fe1e0b27977c1054ecffb82be81e9b6a/cython-3.2.4-cp39-abi3-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6d5267f22b6451eb1e2e1b88f6f78a2c9c8733a6ddefd4520d3968d26b824581", size = 3210669, upload-time = "2026-01-04T14:15:41.911Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5e/91/4865fbfef1f6bb4f21d79c46104a53d1a3fa4348286237e15eafb26e0828/cython-3.2.4-cp39-abi3-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3b6e58f73a69230218d5381817850ce6d0da5bb7e87eb7d528c7027cbba40b06", size = 2856835, upload-time = "2026-01-04T14:15:43.815Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/39/60317957dbef179572398253f29d28f75f94ab82d6d39ea3237fb6c89268/cython-3.2.4-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e71efb20048358a6b8ec604a0532961c50c067b5e63e345e2e359fff72feaee8", size = 2994408, upload-time = "2026-01-04T14:15:45.422Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8d/30/7c24d9292650db4abebce98abc9b49c820d40fa7c87921c0a84c32f4efe7/cython-3.2.4-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:28b1e363b024c4b8dcf52ff68125e635cb9cb4b0ba997d628f25e32543a71103", size = 2891478, upload-time = "2026-01-04T14:15:47.394Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/86/70/03dc3c962cde9da37a93cca8360e576f904d5f9beecfc9d70b1f820d2e5f/cython-3.2.4-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:31a90b4a2c47bb6d56baeb926948348ec968e932c1ae2c53239164e3e8880ccf", size = 3225663, upload-time = "2026-01-04T14:15:49.446Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b1/97/10b50c38313c37b1300325e2e53f48ea9a2c078a85c0c9572057135e31d5/cython-3.2.4-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e65e4773021f8dc8532010b4fbebe782c77f9a0817e93886e518c93bd6a44e9d", size = 3115628, upload-time = "2026-01-04T14:15:51.323Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8f/b1/d6a353c9b147848122a0db370863601fdf56de2d983b5c4a6a11e6ee3cd7/cython-3.2.4-cp39-abi3-win32.whl", hash = "sha256:2b1f12c0e4798293d2754e73cd6f35fa5bbdf072bdc14bc6fc442c059ef2d290", size = 2437463, upload-time = "2026-01-04T14:15:53.787Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2d/d8/319a1263b9c33b71343adfd407e5daffd453daef47ebc7b642820a8b68ed/cython-3.2.4-cp39-abi3-win_arm64.whl", hash = "sha256:3b8e62049afef9da931d55de82d8f46c9a147313b69d5ff6af6e9121d545ce7a", size = 2442754, upload-time = "2026-01-04T14:15:55.382Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/fa/d3c15189f7c52aaefbaea76fb012119b04b9013f4bf446cb4eb4c26c4e6b/cython-3.2.4-py3-none-any.whl", hash = "sha256:732fc93bc33ae4b14f6afaca663b916c2fdd5dcbfad7114e17fb2434eeaea45c", size = 1257078, upload-time = "2026-01-04T14:14:12.373Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "disknext-server"
|
||||
version = "0.0.1"
|
||||
@@ -486,6 +512,7 @@ dependencies = [
|
||||
{ name = "asyncpg" },
|
||||
{ name = "cachetools" },
|
||||
{ name = "captcha" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "fastapi", extra = ["standard"] },
|
||||
{ name = "httpx" },
|
||||
{ name = "itsdangerous" },
|
||||
@@ -502,8 +529,16 @@ dependencies = [
|
||||
{ name = "redis", extra = ["hiredis"] },
|
||||
{ name = "sqlalchemy" },
|
||||
{ name = "sqlmodel" },
|
||||
{ name = "sqlmodel-ext", extra = ["pgvector"] },
|
||||
{ name = "uvicorn" },
|
||||
{ name = "webauthn" },
|
||||
{ name = "whatthepatch" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
build = [
|
||||
{ name = "cython" },
|
||||
{ name = "setuptools" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
@@ -515,6 +550,8 @@ requires-dist = [
|
||||
{ name = "asyncpg", specifier = ">=0.31.0" },
|
||||
{ name = "cachetools", specifier = ">=6.2.4" },
|
||||
{ name = "captcha", specifier = ">=0.7.1" },
|
||||
{ name = "cryptography", specifier = ">=46.0.3" },
|
||||
{ name = "cython", marker = "extra == 'build'", specifier = ">=3.0.11" },
|
||||
{ name = "fastapi", extras = ["standard"], specifier = ">=0.122.0" },
|
||||
{ name = "httpx", specifier = ">=0.27.0" },
|
||||
{ name = "itsdangerous", specifier = ">=2.2.0" },
|
||||
@@ -529,11 +566,15 @@ requires-dist = [
|
||||
{ name = "python-dotenv", specifier = ">=1.2.1" },
|
||||
{ name = "python-multipart", specifier = ">=0.0.20" },
|
||||
{ name = "redis", extras = ["hiredis"], specifier = ">=7.1.0" },
|
||||
{ name = "setuptools", marker = "extra == 'build'", specifier = ">=75.0.0" },
|
||||
{ name = "sqlalchemy", specifier = ">=2.0.44" },
|
||||
{ name = "sqlmodel", specifier = ">=0.0.27" },
|
||||
{ name = "sqlmodel-ext", extras = ["pgvector"], specifier = ">=0.1.1" },
|
||||
{ name = "uvicorn", specifier = ">=0.38.0" },
|
||||
{ name = "webauthn", specifier = ">=2.7.0" },
|
||||
{ name = "whatthepatch", specifier = ">=1.0.6" },
|
||||
]
|
||||
provides-extras = ["build"]
|
||||
|
||||
[[package]]
|
||||
name = "dnspython"
|
||||
@@ -1101,6 +1142,56 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/da/7d22601b625e241d4f23ef1ebff8acfc60da633c9e7e7922e24d10f592b3/multidict-6.7.0-py3-none-any.whl", hash = "sha256:394fc5c42a333c9ffc3e421a4c85e08580d990e08b99f6bf35b4132114c5dcb3", size = 12317, upload-time = "2025-10-06T14:52:29.272Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "2.4.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/57/fd/0005efbd0af48e55eb3c7208af93f2862d4b1a56cd78e84309a2d959208d/numpy-2.4.2.tar.gz", hash = "sha256:659a6107e31a83c4e33f763942275fd278b21d095094044eb35569e86a21ddae", size = 20723651, upload-time = "2026-01-31T23:13:10.135Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a1/22/815b9fe25d1d7ae7d492152adbc7226d3eff731dffc38fe970589fcaaa38/numpy-2.4.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:25f2059807faea4b077a2b6837391b5d830864b3543627f381821c646f31a63c", size = 16663696, upload-time = "2026-01-31T23:11:17.516Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/09/f0/817d03a03f93ba9c6c8993de509277d84e69f9453601915e4a69554102a1/numpy-2.4.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bd3a7a9f5847d2fb8c2c6d1c862fa109c31a9abeca1a3c2bd5a64572955b2979", size = 14688322, upload-time = "2026-01-31T23:11:19.883Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/da/b4/f805ab79293c728b9a99438775ce51885fd4f31b76178767cfc718701a39/numpy-2.4.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:8e4549f8a3c6d13d55041925e912bfd834285ef1dd64d6bc7d542583355e2e98", size = 5198157, upload-time = "2026-01-31T23:11:22.375Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/74/09/826e4289844eccdcd64aac27d13b0fd3f32039915dd5b9ba01baae1f436c/numpy-2.4.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:aea4f66ff44dfddf8c2cffd66ba6538c5ec67d389285292fe428cb2c738c8aef", size = 6546330, upload-time = "2026-01-31T23:11:23.958Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/fb/cbfdbfa3057a10aea5422c558ac57538e6acc87ec1669e666d32ac198da7/numpy-2.4.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3cd545784805de05aafe1dde61752ea49a359ccba9760c1e5d1c88a93bbf2b7", size = 15660968, upload-time = "2026-01-31T23:11:25.713Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/dc/46066ce18d01645541f0186877377b9371b8fa8017fa8262002b4ef22612/numpy-2.4.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d0d9b7c93578baafcbc5f0b83eaf17b79d345c6f36917ba0c67f45226911d499", size = 16607311, upload-time = "2026-01-31T23:11:28.117Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/d9/4b5adfc39a43fa6bf918c6d544bc60c05236cc2f6339847fc5b35e6cb5b0/numpy-2.4.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f74f0f7779cc7ae07d1810aab8ac6b1464c3eafb9e283a40da7309d5e6e48fbb", size = 17012850, upload-time = "2026-01-31T23:11:30.888Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/20/adb6e6adde6d0130046e6fdfb7675cc62bc2f6b7b02239a09eb58435753d/numpy-2.4.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c7ac672d699bf36275c035e16b65539931347d68b70667d28984c9fb34e07fa7", size = 18334210, upload-time = "2026-01-31T23:11:33.214Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/0e/0a73b3dff26803a8c02baa76398015ea2a5434d9b8265a7898a6028c1591/numpy-2.4.2-cp313-cp313-win32.whl", hash = "sha256:8e9afaeb0beff068b4d9cd20d322ba0ee1cecfb0b08db145e4ab4dd44a6b5110", size = 5958199, upload-time = "2026-01-31T23:11:35.385Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/bc/6352f343522fcb2c04dbaf94cb30cca6fd32c1a750c06ad6231b4293708c/numpy-2.4.2-cp313-cp313-win_amd64.whl", hash = "sha256:7df2de1e4fba69a51c06c28f5a3de36731eb9639feb8e1cf7e4a7b0daf4cf622", size = 12310848, upload-time = "2026-01-31T23:11:38.001Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/8d/6da186483e308da5da1cc6918ce913dcfe14ffde98e710bfeff2a6158d4e/numpy-2.4.2-cp313-cp313-win_arm64.whl", hash = "sha256:0fece1d1f0a89c16b03442eae5c56dc0be0c7883b5d388e0c03f53019a4bfd71", size = 10221082, upload-time = "2026-01-31T23:11:40.392Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/a1/9510aa43555b44781968935c7548a8926274f815de42ad3997e9e83680dd/numpy-2.4.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5633c0da313330fd20c484c78cdd3f9b175b55e1a766c4a174230c6b70ad8262", size = 14815866, upload-time = "2026-01-31T23:11:42.495Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/36/30/6bbb5e76631a5ae46e7923dd16ca9d3f1c93cfa8d4ed79a129814a9d8db3/numpy-2.4.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:d9f64d786b3b1dd742c946c42d15b07497ed14af1a1f3ce840cce27daa0ce913", size = 5325631, upload-time = "2026-01-31T23:11:44.7Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/46/00/3a490938800c1923b567b3a15cd17896e68052e2145d8662aaf3e1ffc58f/numpy-2.4.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:b21041e8cb6a1eb5312dd1d2f80a94d91efffb7a06b70597d44f1bd2dfc315ab", size = 6646254, upload-time = "2026-01-31T23:11:46.341Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d3/e9/fac0890149898a9b609caa5af7455a948b544746e4b8fe7c212c8edd71f8/numpy-2.4.2-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:00ab83c56211a1d7c07c25e3217ea6695e50a3e2f255053686b081dc0b091a82", size = 15720138, upload-time = "2026-01-31T23:11:48.082Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ea/5c/08887c54e68e1e28df53709f1893ce92932cc6f01f7c3d4dc952f61ffd4e/numpy-2.4.2-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2fb882da679409066b4603579619341c6d6898fc83a8995199d5249f986e8e8f", size = 16655398, upload-time = "2026-01-31T23:11:50.293Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/89/253db0fa0e66e9129c745e4ef25631dc37d5f1314dad2b53e907b8538e6d/numpy-2.4.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:66cb9422236317f9d44b67b4d18f44efe6e9c7f8794ac0462978513359461554", size = 17079064, upload-time = "2026-01-31T23:11:52.927Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/d5/cbade46ce97c59c6c3da525e8d95b7abe8a42974a1dc5c1d489c10433e88/numpy-2.4.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0f01dcf33e73d80bd8dc0f20a71303abbafa26a19e23f6b68d1aa9990af90257", size = 18379680, upload-time = "2026-01-31T23:11:55.22Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/40/62/48f99ae172a4b63d981babe683685030e8a3df4f246c893ea5c6ef99f018/numpy-2.4.2-cp313-cp313t-win32.whl", hash = "sha256:52b913ec40ff7ae845687b0b34d8d93b60cb66dcee06996dd5c99f2fc9328657", size = 6082433, upload-time = "2026-01-31T23:11:58.096Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/07/38/e054a61cfe48ad9f1ed0d188e78b7e26859d0b60ef21cd9de4897cdb5326/numpy-2.4.2-cp313-cp313t-win_amd64.whl", hash = "sha256:5eea80d908b2c1f91486eb95b3fb6fab187e569ec9752ab7d9333d2e66bf2d6b", size = 12451181, upload-time = "2026-01-31T23:11:59.782Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/a4/a05c3a6418575e185dd84d0b9680b6bb2e2dc3e4202f036b7b4e22d6e9dc/numpy-2.4.2-cp313-cp313t-win_arm64.whl", hash = "sha256:fd49860271d52127d61197bb50b64f58454e9f578cb4b2c001a6de8b1f50b0b1", size = 10290756, upload-time = "2026-01-31T23:12:02.438Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/18/88/b7df6050bf18fdcfb7046286c6535cabbdd2064a3440fca3f069d319c16e/numpy-2.4.2-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:444be170853f1f9d528428eceb55f12918e4fda5d8805480f36a002f1415e09b", size = 16663092, upload-time = "2026-01-31T23:12:04.521Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/7a/1fee4329abc705a469a4afe6e69b1ef7e915117747886327104a8493a955/numpy-2.4.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d1240d50adff70c2a88217698ca844723068533f3f5c5fa6ee2e3220e3bdb000", size = 14698770, upload-time = "2026-01-31T23:12:06.96Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/0b/f9e49ba6c923678ad5bc38181c08ac5e53b7a5754dbca8e581aa1a56b1ff/numpy-2.4.2-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:7cdde6de52fb6664b00b056341265441192d1291c130e99183ec0d4b110ff8b1", size = 5208562, upload-time = "2026-01-31T23:12:09.632Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/12/d7de8f6f53f9bb76997e5e4c069eda2051e3fe134e9181671c4391677bb2/numpy-2.4.2-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:cda077c2e5b780200b6b3e09d0b42205a3d1c68f30c6dceb90401c13bff8fe74", size = 6543710, upload-time = "2026-01-31T23:12:11.969Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/09/63/c66418c2e0268a31a4cf8a8b512685748200f8e8e8ec6c507ce14e773529/numpy-2.4.2-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d30291931c915b2ab5717c2974bb95ee891a1cf22ebc16a8006bd59cd210d40a", size = 15677205, upload-time = "2026-01-31T23:12:14.33Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/6c/7f237821c9642fb2a04d2f1e88b4295677144ca93285fd76eff3bcba858d/numpy-2.4.2-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bba37bc29d4d85761deed3954a1bc62be7cf462b9510b51d367b769a8c8df325", size = 16611738, upload-time = "2026-01-31T23:12:16.525Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/a7/39c4cdda9f019b609b5c473899d87abff092fc908cfe4d1ecb2fcff453b0/numpy-2.4.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b2f0073ed0868db1dcd86e052d37279eef185b9c8db5bf61f30f46adac63c909", size = 17028888, upload-time = "2026-01-31T23:12:19.306Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/da/b3/e84bb64bdfea967cc10950d71090ec2d84b49bc691df0025dddb7c26e8e3/numpy-2.4.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:7f54844851cdb630ceb623dcec4db3240d1ac13d4990532446761baede94996a", size = 18339556, upload-time = "2026-01-31T23:12:21.816Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/88/f5/954a291bc1192a27081706862ac62bb5920fbecfbaa302f64682aa90beed/numpy-2.4.2-cp314-cp314-win32.whl", hash = "sha256:12e26134a0331d8dbd9351620f037ec470b7c75929cb8a1537f6bfe411152a1a", size = 6006899, upload-time = "2026-01-31T23:12:24.14Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/05/cb/eff72a91b2efdd1bc98b3b8759f6a1654aa87612fc86e3d87d6fe4f948c4/numpy-2.4.2-cp314-cp314-win_amd64.whl", hash = "sha256:068cdb2d0d644cdb45670810894f6a0600797a69c05f1ac478e8d31670b8ee75", size = 12443072, upload-time = "2026-01-31T23:12:26.33Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/37/75/62726948db36a56428fce4ba80a115716dc4fad6a3a4352487f8bb950966/numpy-2.4.2-cp314-cp314-win_arm64.whl", hash = "sha256:6ed0be1ee58eef41231a5c943d7d1375f093142702d5723ca2eb07db9b934b05", size = 10494886, upload-time = "2026-01-31T23:12:28.488Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/36/2f/ee93744f1e0661dc267e4b21940870cabfae187c092e1433b77b09b50ac4/numpy-2.4.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:98f16a80e917003a12c0580f97b5f875853ebc33e2eaa4bccfc8201ac6869308", size = 14818567, upload-time = "2026-01-31T23:12:30.709Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a7/24/6535212add7d76ff938d8bdc654f53f88d35cddedf807a599e180dcb8e66/numpy-2.4.2-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:20abd069b9cda45874498b245c8015b18ace6de8546bf50dfa8cea1696ed06ef", size = 5328372, upload-time = "2026-01-31T23:12:32.962Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5e/9d/c48f0a035725f925634bf6b8994253b43f2047f6778a54147d7e213bc5a7/numpy-2.4.2-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:e98c97502435b53741540a5717a6749ac2ada901056c7db951d33e11c885cc7d", size = 6649306, upload-time = "2026-01-31T23:12:34.797Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/81/05/7c73a9574cd4a53a25907bad38b59ac83919c0ddc8234ec157f344d57d9a/numpy-2.4.2-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:da6cad4e82cb893db4b69105c604d805e0c3ce11501a55b5e9f9083b47d2ffe8", size = 15722394, upload-time = "2026-01-31T23:12:36.565Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/fa/4de10089f21fc7d18442c4a767ab156b25c2a6eaf187c0db6d9ecdaeb43f/numpy-2.4.2-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9e4424677ce4b47fe73c8b5556d876571f7c6945d264201180db2dc34f676ab5", size = 16653343, upload-time = "2026-01-31T23:12:39.188Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/f9/d33e4ffc857f3763a57aa85650f2e82486832d7492280ac21ba9efda80da/numpy-2.4.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:2b8f157c8a6f20eb657e240f8985cc135598b2b46985c5bccbde7616dc9c6b1e", size = 17078045, upload-time = "2026-01-31T23:12:42.041Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/b8/54bdb43b6225badbea6389fa038c4ef868c44f5890f95dd530a218706da3/numpy-2.4.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5daf6f3914a733336dab21a05cdec343144600e964d2fcdabaac0c0269874b2a", size = 18380024, upload-time = "2026-01-31T23:12:44.331Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a5/55/6e1a61ded7af8df04016d81b5b02daa59f2ea9252ee0397cb9f631efe9e5/numpy-2.4.2-cp314-cp314t-win32.whl", hash = "sha256:8c50dd1fc8826f5b26a5ee4d77ca55d88a895f4e4819c7ecc2a9f5905047a443", size = 6153937, upload-time = "2026-01-31T23:12:47.229Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/aa/fa6118d1ed6d776b0983f3ceac9b1a5558e80df9365b1c3aa6d42bf9eee4/numpy-2.4.2-cp314-cp314t-win_amd64.whl", hash = "sha256:fcf92bee92742edd401ba41135185866f7026c502617f422eb432cfeca4fe236", size = 12631844, upload-time = "2026-01-31T23:12:48.997Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/32/0a/2ec5deea6dcd158f254a7b372fb09cfba5719419c8d66343bab35237b3fb/numpy-2.4.2-cp314-cp314t-win_arm64.whl", hash = "sha256:1f92f53998a17265194018d1cc321b2e96e900ca52d54c7c77837b71b9465181", size = 10565379, upload-time = "2026-01-31T23:12:51.345Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "orjson"
|
||||
version = "3.11.7"
|
||||
@@ -1148,6 +1239,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pgvector"
|
||||
version = "0.4.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/25/6c/6d8b4b03b958c02fa8687ec6063c49d952a189f8c91ebbe51e877dfab8f7/pgvector-0.4.2.tar.gz", hash = "sha256:322cac0c1dc5d41c9ecf782bd9991b7966685dee3a00bc873631391ed949513a", size = 31354, upload-time = "2025-12-05T01:07:17.87Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5a/26/6cee8a1ce8c43625ec561aff19df07f9776b7525d9002c86bceb3e0ac970/pgvector-0.4.2-py3-none-any.whl", hash = "sha256:549d45f7a18593783d5eec609ea1684a724ba8405c4cb182a0b2b08aeff04e08", size = 27441, upload-time = "2025-12-05T01:07:16.536Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pillow"
|
||||
version = "12.1.1"
|
||||
@@ -1648,6 +1751,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/19/8d77f9992e5cbfcaa9133c3bf63b4fbbb051248802e1e803fed5c552fbb2/sentry_sdk-2.48.0-py2.py3-none-any.whl", hash = "sha256:6b12ac256769d41825d9b7518444e57fa35b5642df4c7c5e322af4d2c8721172", size = 414555, upload-time = "2025-12-16T14:55:40.152Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "82.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/82/f3/748f4d6f65d1756b9ae577f329c951cda23fb900e4de9f70900ced962085/setuptools-82.0.0.tar.gz", hash = "sha256:22e0a2d69474c6ae4feb01951cb69d515ed23728cf96d05513d36e42b62b37cb", size = 1144893, upload-time = "2026-02-08T15:08:40.206Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e1/c6/76dc613121b793286a3f91621d7b75a2b493e0390ddca50f11993eadf192/setuptools-82.0.0-py3-none-any.whl", hash = "sha256:70b18734b607bd1da571d097d236cfcfacaf01de45717d59e6e04b96877532e0", size = 1003468, upload-time = "2026-02-08T15:08:38.723Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shellingham"
|
||||
version = "1.5.4"
|
||||
@@ -1699,6 +1811,27 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/3f/f71798ed4cc6784122af8d404df187cfc683db795b3eee542e2df23f94a9/sqlmodel-0.0.29-py3-none-any.whl", hash = "sha256:56e4b3e9ab5825f6cbdd40bc46f90f01c0f73c73f44843a03d494c0845a948cd", size = 29355, upload-time = "2025-12-23T20:59:45.368Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlmodel-ext"
|
||||
version = "0.1.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pydantic" },
|
||||
{ name = "sqlalchemy" },
|
||||
{ name = "sqlmodel" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ff/01/c2b5c4939a0d466fde14fd82fd6bbefa39f69b7b9c63b01368862773ce04/sqlmodel_ext-0.1.1.tar.gz", hash = "sha256:2faa65cd8130f8a2384d2dadfa566de1d6f372ce9b39c46ebd8d2955529870fc", size = 70519, upload-time = "2026-02-07T13:08:14.249Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/93/f424feac8ad208bfa541a8e2fe0bb53bcd272735e40b18ee1ad7185cfd1f/sqlmodel_ext-0.1.1-py3-none-any.whl", hash = "sha256:f572abe3070c247ddc6dfbfe83b7b6a70567fbfca0225bef60df167632e6d836", size = 59033, upload-time = "2026-02-07T13:08:19.197Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
pgvector = [
|
||||
{ name = "numpy" },
|
||||
{ name = "orjson" },
|
||||
{ name = "pgvector" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "starlette"
|
||||
version = "0.50.0"
|
||||
@@ -1898,6 +2031,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "whatthepatch"
|
||||
version = "1.0.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/06/28/55bc3e107a56fdcf7d5022cb32b8c21d98a9cc2df5cd9f3b93e10419099e/whatthepatch-1.0.7.tar.gz", hash = "sha256:9eefb4ebea5200408e02d413d2b4bc28daea6b78bb4b4d53431af7245f7d7edf", size = 34612, upload-time = "2024-11-16T17:21:22.153Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/93/af1d6ccb69ab6b5a00e03fa0cefa563f9862412667776ea15dd4eece3a90/whatthepatch-1.0.7-py3-none-any.whl", hash = "sha256:1b6f655fd31091c001c209529dfaabbabdbad438f5de14e3951266ea0fc6e7ed", size = 11964, upload-time = "2024-11-16T17:21:20.761Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "win32-setctime"
|
||||
version = "1.2.0"
|
||||
|
||||
Reference in New Issue
Block a user