Compare commits
22 Commits
a99091ea7a
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 15b2efe52a | |||
| 6c96c43bea | |||
| 9185f26b83 | |||
| f4052d229a | |||
| bc2182720d | |||
| eddf38d316 | |||
| 03e768d232 | |||
| bcb0a9b322 | |||
| 743a2c9d65 | |||
| 3639a31163 | |||
| 7200df6d87 | |||
| 40b6a31c98 | |||
| 19837b4817 | |||
| b5d09009e3 | |||
| 0b521ae8ab | |||
| eac0766e79 | |||
| 53b757de7a | |||
| 69f852a4ce | |||
| 800c85bf8d | |||
| 729773cae3 | |||
| d831c9c0d6 | |||
| 4c1b7a8aad |
@@ -5,7 +5,8 @@
|
|||||||
"Bash(findstr:*)",
|
"Bash(findstr:*)",
|
||||||
"Bash(find:*)",
|
"Bash(find:*)",
|
||||||
"Bash(yarn tsc:*)",
|
"Bash(yarn tsc:*)",
|
||||||
"Bash(dir:*)"
|
"Bash(dir:*)",
|
||||||
|
"mcp__server-notify__notify"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
37
.dockerignore
Normal file
37
.dockerignore
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
.git/
|
||||||
|
.gitignore
|
||||||
|
.github/
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
.venv/
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
.run/
|
||||||
|
.claude/
|
||||||
|
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
|
||||||
|
tests/
|
||||||
|
htmlcov/
|
||||||
|
.pytest_cache/
|
||||||
|
.coverage
|
||||||
|
coverage.xml
|
||||||
|
|
||||||
|
*.db
|
||||||
|
*.sqlite
|
||||||
|
*.sqlite3
|
||||||
|
*.log
|
||||||
|
logs/
|
||||||
|
data/
|
||||||
|
|
||||||
|
Dockerfile
|
||||||
|
.dockerignore
|
||||||
|
|
||||||
|
# Cython 编译产物
|
||||||
|
*.c
|
||||||
|
build/
|
||||||
|
|
||||||
|
# 许可证私钥和工具脚本
|
||||||
|
license_private.pem
|
||||||
|
scripts/
|
||||||
31
.gitea/workflows/test.yml
Normal file
31
.gitea/workflows/test.yml
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
name: Test
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main, develop]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container:
|
||||||
|
image: ghcr.io/catthehacker/ubuntu:act-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.13"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: uv sync
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: uv run pytest tests/ -v --tb=short
|
||||||
14
.github/copilot-instructions.md
vendored
14
.github/copilot-instructions.md
vendored
@@ -449,13 +449,13 @@ return device # 此时device已过期
|
|||||||
```python
|
```python
|
||||||
import asyncio
|
import asyncio
|
||||||
from sqlmodel import Field
|
from sqlmodel import Field
|
||||||
from sqlmodels.base import UUIDTableBase, SQLModelBase
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||||
|
|
||||||
class CharacterBase(SQLModelBase):
|
class CharacterBase(SQLModelBase):
|
||||||
name: str
|
name: str
|
||||||
"""角色名称"""
|
"""角色名称"""
|
||||||
|
|
||||||
class Character(CharacterBase, UUIDTableBase):
|
class Character(CharacterBase, UUIDTableBaseMixin):
|
||||||
"""充血模型:包含数据和业务逻辑"""
|
"""充血模型:包含数据和业务逻辑"""
|
||||||
|
|
||||||
# ==================== 运行时属性(在model_post_init初始化) ====================
|
# ==================== 运行时属性(在model_post_init初始化) ====================
|
||||||
@@ -570,11 +570,11 @@ async with character.init(session):
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from sqlmodel import Field
|
from sqlmodel import Field
|
||||||
from sqlmodels.base import (
|
from sqlmodel_ext import (
|
||||||
SQLModelBase,
|
SQLModelBase,
|
||||||
UUIDTableBase,
|
UUIDTableBaseMixin,
|
||||||
create_subclass_id_mixin,
|
create_subclass_id_mixin,
|
||||||
AutoPolymorphicIdentityMixin
|
AutoPolymorphicIdentityMixin,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 定义Base类(只有字段,无表)
|
# 1. 定义Base类(只有字段,无表)
|
||||||
@@ -591,7 +591,7 @@ class ASRBase(SQLModelBase):
|
|||||||
# 2. 定义抽象父类(有表)
|
# 2. 定义抽象父类(有表)
|
||||||
class ASR(
|
class ASR(
|
||||||
ASRBase,
|
ASRBase,
|
||||||
UUIDTableBase,
|
UUIDTableBaseMixin,
|
||||||
ABC,
|
ABC,
|
||||||
polymorphic_on='__polymorphic_name',
|
polymorphic_on='__polymorphic_name',
|
||||||
polymorphic_abstract=True
|
polymorphic_abstract=True
|
||||||
@@ -1148,7 +1148,7 @@ from sqlmodel import Field
|
|||||||
# 3. 本地应用导入(从项目根目录的包开始)
|
# 3. 本地应用导入(从项目根目录的包开始)
|
||||||
from dependencies import SessionDep
|
from dependencies import SessionDep
|
||||||
from sqlmodels.user import User
|
from sqlmodels.user import User
|
||||||
from sqlmodels.base import UUIDTableBase
|
from sqlmodel_ext import UUIDTableBaseMixin
|
||||||
|
|
||||||
# 4. 相对导入(同包内的模块)
|
# 4. 相对导入(同包内的模块)
|
||||||
from .base import BaseClass
|
from .base import BaseClass
|
||||||
|
|||||||
29
.github/workflows/test.yml
vendored
Normal file
29
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
name: Test
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main, develop]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.13"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: uv sync
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: uv run pytest tests/ -v --tb=short
|
||||||
15
.gitignore
vendored
15
.gitignore
vendored
@@ -1,8 +1,6 @@
|
|||||||
# Python
|
# Python
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
*.pyo
|
|
||||||
*.pyd
|
|
||||||
*.so
|
*.so
|
||||||
*.egg
|
*.egg
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
@@ -69,3 +67,16 @@ data/
|
|||||||
# JB 的运行配置(换设备用不了)
|
# JB 的运行配置(换设备用不了)
|
||||||
.run/
|
.run/
|
||||||
.xml
|
.xml
|
||||||
|
|
||||||
|
# 前端构建产物(Docker 构建时复制)
|
||||||
|
statics/
|
||||||
|
|
||||||
|
# Cython 编译产物
|
||||||
|
*.c
|
||||||
|
|
||||||
|
# 许可证密钥(保密)
|
||||||
|
license_private.pem
|
||||||
|
license.key
|
||||||
|
|
||||||
|
avatar/
|
||||||
|
.dev/
|
||||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[submodule "ee"]
|
||||||
|
path = ee
|
||||||
|
url = https://git.yxqi.cn/Yuerchu/disknext-ee.git
|
||||||
14
AGENTS.md
14
AGENTS.md
@@ -449,13 +449,13 @@ return device # 此时device已过期
|
|||||||
```python
|
```python
|
||||||
import asyncio
|
import asyncio
|
||||||
from sqlmodel import Field
|
from sqlmodel import Field
|
||||||
from sqlmodels.base import UUIDTableBase, SQLModelBase
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||||
|
|
||||||
class CharacterBase(SQLModelBase):
|
class CharacterBase(SQLModelBase):
|
||||||
name: str
|
name: str
|
||||||
"""角色名称"""
|
"""角色名称"""
|
||||||
|
|
||||||
class Character(CharacterBase, UUIDTableBase):
|
class Character(CharacterBase, UUIDTableBaseMixin):
|
||||||
"""充血模型:包含数据和业务逻辑"""
|
"""充血模型:包含数据和业务逻辑"""
|
||||||
|
|
||||||
# ==================== 运行时属性(在model_post_init初始化) ====================
|
# ==================== 运行时属性(在model_post_init初始化) ====================
|
||||||
@@ -570,11 +570,11 @@ async with character.init(session):
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from sqlmodel import Field
|
from sqlmodel import Field
|
||||||
from sqlmodels.base import (
|
from sqlmodel_ext import (
|
||||||
SQLModelBase,
|
SQLModelBase,
|
||||||
UUIDTableBase,
|
UUIDTableBaseMixin,
|
||||||
create_subclass_id_mixin,
|
create_subclass_id_mixin,
|
||||||
AutoPolymorphicIdentityMixin
|
AutoPolymorphicIdentityMixin,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 定义Base类(只有字段,无表)
|
# 1. 定义Base类(只有字段,无表)
|
||||||
@@ -591,7 +591,7 @@ class ASRBase(SQLModelBase):
|
|||||||
# 2. 定义抽象父类(有表)
|
# 2. 定义抽象父类(有表)
|
||||||
class ASR(
|
class ASR(
|
||||||
ASRBase,
|
ASRBase,
|
||||||
UUIDTableBase,
|
UUIDTableBaseMixin,
|
||||||
ABC,
|
ABC,
|
||||||
polymorphic_on='__polymorphic_name',
|
polymorphic_on='__polymorphic_name',
|
||||||
polymorphic_abstract=True
|
polymorphic_abstract=True
|
||||||
@@ -1148,7 +1148,7 @@ from sqlmodel import Field
|
|||||||
# 3. 本地应用导入(从项目根目录的包开始)
|
# 3. 本地应用导入(从项目根目录的包开始)
|
||||||
from dependencies import SessionDep
|
from dependencies import SessionDep
|
||||||
from sqlmodels.user import User
|
from sqlmodels.user import User
|
||||||
from sqlmodels.base import UUIDTableBase
|
from sqlmodel_ext import UUIDTableBaseMixin
|
||||||
|
|
||||||
# 4. 相对导入(同包内的模块)
|
# 4. 相对导入(同包内的模块)
|
||||||
from .base import BaseClass
|
from .base import BaseClass
|
||||||
|
|||||||
14
CLAUDE.md
14
CLAUDE.md
@@ -449,13 +449,13 @@ return device # 此时device已过期
|
|||||||
```python
|
```python
|
||||||
import asyncio
|
import asyncio
|
||||||
from sqlmodel import Field
|
from sqlmodel import Field
|
||||||
from sqlmodels.base import UUIDTableBase, SQLModelBase
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||||
|
|
||||||
class CharacterBase(SQLModelBase):
|
class CharacterBase(SQLModelBase):
|
||||||
name: str
|
name: str
|
||||||
"""角色名称"""
|
"""角色名称"""
|
||||||
|
|
||||||
class Character(CharacterBase, UUIDTableBase):
|
class Character(CharacterBase, UUIDTableBaseMixin):
|
||||||
"""充血模型:包含数据和业务逻辑"""
|
"""充血模型:包含数据和业务逻辑"""
|
||||||
|
|
||||||
# ==================== 运行时属性(在model_post_init初始化) ====================
|
# ==================== 运行时属性(在model_post_init初始化) ====================
|
||||||
@@ -570,11 +570,11 @@ async with character.init(session):
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from sqlmodel import Field
|
from sqlmodel import Field
|
||||||
from sqlmodels.base import (
|
from sqlmodel_ext import (
|
||||||
SQLModelBase,
|
SQLModelBase,
|
||||||
UUIDTableBase,
|
UUIDTableBaseMixin,
|
||||||
create_subclass_id_mixin,
|
create_subclass_id_mixin,
|
||||||
AutoPolymorphicIdentityMixin
|
AutoPolymorphicIdentityMixin,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 定义Base类(只有字段,无表)
|
# 1. 定义Base类(只有字段,无表)
|
||||||
@@ -591,7 +591,7 @@ class ASRBase(SQLModelBase):
|
|||||||
# 2. 定义抽象父类(有表)
|
# 2. 定义抽象父类(有表)
|
||||||
class ASR(
|
class ASR(
|
||||||
ASRBase,
|
ASRBase,
|
||||||
UUIDTableBase,
|
UUIDTableBaseMixin,
|
||||||
ABC,
|
ABC,
|
||||||
polymorphic_on='__polymorphic_name',
|
polymorphic_on='__polymorphic_name',
|
||||||
polymorphic_abstract=True
|
polymorphic_abstract=True
|
||||||
@@ -1148,7 +1148,7 @@ from sqlmodel import Field
|
|||||||
# 3. 本地应用导入(从项目根目录的包开始)
|
# 3. 本地应用导入(从项目根目录的包开始)
|
||||||
from dependencies import SessionDep
|
from dependencies import SessionDep
|
||||||
from sqlmodels.user import User
|
from sqlmodels.user import User
|
||||||
from sqlmodels.base import UUIDTableBase
|
from sqlmodel_ext import UUIDTableBaseMixin
|
||||||
|
|
||||||
# 4. 相对导入(同包内的模块)
|
# 4. 相对导入(同包内的模块)
|
||||||
from .base import BaseClass
|
from .base import BaseClass
|
||||||
|
|||||||
45
Dockerfile
45
Dockerfile
@@ -1,13 +1,52 @@
|
|||||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim
|
# ============================================================
|
||||||
|
# 基础层:安装运行时依赖
|
||||||
|
# ============================================================
|
||||||
|
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS base
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
COPY pyproject.toml uv.lock ./
|
COPY pyproject.toml uv.lock ./
|
||||||
|
|
||||||
RUN uv sync --frozen --no-dev
|
RUN uv sync --frozen --no-dev
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
EXPOSE 5213
|
# ============================================================
|
||||||
|
# Community 版本:删除 ee/ 目录
|
||||||
|
# ============================================================
|
||||||
|
FROM base AS community
|
||||||
|
|
||||||
|
RUN rm -rf ee/
|
||||||
|
COPY statics/ /app/statics/
|
||||||
|
|
||||||
|
EXPOSE 5213
|
||||||
|
CMD ["uv", "run", "fastapi", "run", "main.py", "--host", "0.0.0.0", "--port", "5213"]
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Pro 编译层:Cython 编译 ee/ 模块
|
||||||
|
# ============================================================
|
||||||
|
FROM base AS pro-builder
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends gcc libc6-dev && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
RUN uv sync --frozen --no-dev --extra build
|
||||||
|
|
||||||
|
RUN uv run python setup_cython.py build_ext --inplace && \
|
||||||
|
uv run python setup_cython.py clean_artifacts
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Pro 版本:包含编译后的 ee/ 模块(仅 __init__.py + .so)
|
||||||
|
# ============================================================
|
||||||
|
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS pro
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY pyproject.toml uv.lock ./
|
||||||
|
RUN uv sync --frozen --no-dev
|
||||||
|
|
||||||
|
COPY --from=pro-builder /app/ /app/
|
||||||
|
COPY statics/ /app/statics/
|
||||||
|
|
||||||
|
EXPOSE 5213
|
||||||
CMD ["uv", "run", "fastapi", "run", "main.py", "--host", "0.0.0.0", "--port", "5213"]
|
CMD ["uv", "run", "fastapi", "run", "main.py", "--host", "0.0.0.0", "--port", "5213"]
|
||||||
@@ -229,6 +229,12 @@ pytest tests/integration
|
|||||||
pytest --cov
|
pytest --cov
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 忘记密码
|
||||||
|
|
||||||
|
将密码字段设置为 `$argon2id$v=19$m=65536,t=3,p=4$09YTQpkw7eS4qW732OazkQ$Szzbi3VIaJXBJ02rkVKrSFCAKHjRTl+EQWk4PNxCYFI`
|
||||||
|
|
||||||
|
密码即可重设为 `11223344`
|
||||||
|
|
||||||
## 开发规范
|
## 开发规范
|
||||||
|
|
||||||
详细的开发规范请参阅 [CLAUDE.md](CLAUDE.md),主要包括:
|
详细的开发规范请参阅 [CLAUDE.md](CLAUDE.md),主要包括:
|
||||||
|
|||||||
594
docs/file-viewer-api.md
Normal file
594
docs/file-viewer-api.md
Normal file
@@ -0,0 +1,594 @@
|
|||||||
|
# 文件预览应用选择器 — 前端适配文档
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
文件预览系统类似 Android 的"使用什么应用打开"机制:用户点击文件时,前端根据扩展名查询可用查看器列表,展示选择弹窗,用户可选"仅此一次"或"始终使用"。
|
||||||
|
|
||||||
|
### 应用类型
|
||||||
|
|
||||||
|
| type | 说明 | 前端处理方式 |
|
||||||
|
|------|------|-------------|
|
||||||
|
| `builtin` | 前端内置组件 | 根据 `app_key` 路由到内置组件(如 `pdfjs`、`monaco`) |
|
||||||
|
| `iframe` | iframe 内嵌 | 将 `iframe_url_template` 中的 `{file_url}` 替换为文件下载 URL,嵌入 iframe |
|
||||||
|
| `wopi` | WOPI 协议 | 调用 `/file/{id}/wopi-session` 获取 `editor_url`,嵌入 iframe |
|
||||||
|
|
||||||
|
### 内置 app_key 映射
|
||||||
|
|
||||||
|
前端需要为以下 `app_key` 实现对应的内置预览组件:
|
||||||
|
|
||||||
|
| app_key | 组件 | 说明 |
|
||||||
|
|---------|------|------|
|
||||||
|
| `pdfjs` | PDF.js 阅读器 | pdf |
|
||||||
|
| `monaco` | Monaco Editor | txt, md, json, py, js, ts, html, css, ... |
|
||||||
|
| `markdown` | Markdown 渲染器 | md, markdown, mdx |
|
||||||
|
| `image_viewer` | 图片查看器 | jpg, png, gif, webp, svg, ... |
|
||||||
|
| `video_player` | HTML5 Video | mp4, webm, ogg, mov, mkv, m3u8 |
|
||||||
|
| `audio_player` | HTML5 Audio | mp3, wav, flac, aac, m4a, opus |
|
||||||
|
|
||||||
|
> `office_viewer`(iframe)、`collabora`(wopi)、`onlyoffice`(wopi)默认禁用,需管理员在后台启用和配置。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 文件下载 URL 与 iframe 预览
|
||||||
|
|
||||||
|
### 现有下载流程(两步式)
|
||||||
|
|
||||||
|
```
|
||||||
|
步骤1: POST /api/v1/file/download/{file_id} → { access_token, expires_in }
|
||||||
|
步骤2: GET /api/v1/file/download/{access_token} → 文件二进制流
|
||||||
|
```
|
||||||
|
|
||||||
|
- 步骤 1 需要 JWT 认证,返回一个下载令牌(有效期 1 小时)
|
||||||
|
- 步骤 2 **不需要认证**,用令牌直接下载,**令牌为一次性**,下载后失效
|
||||||
|
|
||||||
|
### 各类型查看器获取文件内容的方式
|
||||||
|
|
||||||
|
| type | 获取文件方式 | 说明 |
|
||||||
|
|------|-------------|------|
|
||||||
|
| `builtin` | 前端自行获取 | 前端用 JS 调用下载接口拿到 Blob/ArrayBuffer,传给内置组件渲染 |
|
||||||
|
| `iframe` | 需要公开可访问的 URL | 第三方服务(如 Office Online)会**从服务端拉取文件** |
|
||||||
|
| `wopi` | WOPI 协议自动处理 | 编辑器通过 `/wopi/files/{id}/contents` 获取,前端只需嵌入 `editor_url` |
|
||||||
|
|
||||||
|
### builtin 类型 — 前端自行获取
|
||||||
|
|
||||||
|
内置组件(pdfjs、monaco 等)运行在前端,直接用 JS 获取文件内容即可:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 方式 A:用下载令牌拼 URL(适用于 PDF.js 等需要 URL 的组件)
|
||||||
|
const { access_token } = await api.post(`/file/download/${fileId}`)
|
||||||
|
const fileUrl = `${baseUrl}/api/v1/file/download/${access_token}`
|
||||||
|
// 传给 PDF.js: pdfjsLib.getDocument(fileUrl)
|
||||||
|
|
||||||
|
// 方式 B:用 fetch + Authorization 头获取 Blob(适用于需要 ArrayBuffer 的组件)
|
||||||
|
const { access_token } = await api.post(`/file/download/${fileId}`)
|
||||||
|
const blob = await fetch(`${baseUrl}/api/v1/file/download/${access_token}`).then(r => r.blob())
|
||||||
|
// 传给 Monaco: monaco.editor.create(el, { value: await blob.text() })
|
||||||
|
```
|
||||||
|
|
||||||
|
### iframe 类型 — `{file_url}` 替换规则
|
||||||
|
|
||||||
|
`iframe_url_template` 中的 `{file_url}` 需要替换为一个**外部可访问的文件直链**。
|
||||||
|
|
||||||
|
**问题**:当前下载令牌是一次性的,而 Office Online 等服务可能多次请求该 URL。
|
||||||
|
|
||||||
|
**当前可行方案**:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 1. 创建下载令牌
|
||||||
|
const { access_token } = await api.post(`/file/download/${fileId}`)
|
||||||
|
|
||||||
|
// 2. 拼出完整的文件 URL(必须是公网可达的地址)
|
||||||
|
const fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
|
||||||
|
|
||||||
|
// 3. 替换模板
|
||||||
|
const iframeSrc = viewer.iframe_url_template.replace(
|
||||||
|
'{file_url}',
|
||||||
|
encodeURIComponent(fileUrl)
|
||||||
|
)
|
||||||
|
|
||||||
|
// 4. 嵌入 iframe
|
||||||
|
// <iframe src={iframeSrc} />
|
||||||
|
```
|
||||||
|
|
||||||
|
> **已知限制**:下载令牌为一次性使用。如果第三方服务多次拉取文件(如 Office Online 可能重试),
|
||||||
|
> 第二次请求会 404。后续版本将实现 `/file/get/{id}/{name}` 外链端点(多次可用),届时
|
||||||
|
> iframe 应改用外链 URL。目前建议:
|
||||||
|
>
|
||||||
|
> 1. **优先使用 WOPI 类型**(Collabora/OnlyOffice),不存在此限制
|
||||||
|
> 2. Office Online 预览在**文件较小**时通常只拉取一次,大多数场景可用
|
||||||
|
> 3. 如需稳定方案,可等待外链端点实现后再启用 iframe 类型应用
|
||||||
|
|
||||||
|
### wopi 类型 — 无需关心文件 URL
|
||||||
|
|
||||||
|
WOPI 类型的查看器完全由后端处理文件传输,前端只需:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 1. 创建 WOPI 会话
|
||||||
|
const session = await api.post(`/file/${fileId}/wopi-session`)
|
||||||
|
|
||||||
|
// 2. 直接嵌入编辑器
|
||||||
|
// <iframe src={session.editor_url} />
|
||||||
|
```
|
||||||
|
|
||||||
|
编辑器(Collabora/OnlyOffice)会通过 WOPI 协议自动从 `/wopi/files/{id}/contents` 获取文件内容,使用 `access_token` 认证,前端无需干预。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 用户端 API
|
||||||
|
|
||||||
|
### 1. 查询可用查看器
|
||||||
|
|
||||||
|
用户点击文件时调用,获取该扩展名的可用查看器列表。
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/v1/file/viewers?ext={extension}
|
||||||
|
Authorization: Bearer {token}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Query 参数**
|
||||||
|
|
||||||
|
| 参数 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| ext | string | 是 | 文件扩展名,最长 20 字符。支持带点号(`.pdf`)、大写(`PDF`),后端会自动规范化 |
|
||||||
|
|
||||||
|
**响应 200**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"viewers": [
|
||||||
|
{
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"name": "PDF 阅读器",
|
||||||
|
"app_key": "pdfjs",
|
||||||
|
"type": "builtin",
|
||||||
|
"icon": "file-pdf",
|
||||||
|
"description": "基于 pdf.js 的 PDF 在线阅读器",
|
||||||
|
"iframe_url_template": null,
|
||||||
|
"wopi_editor_url_template": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"default_viewer_id": null
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**字段说明**
|
||||||
|
|
||||||
|
| 字段 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| viewers | FileAppSummary[] | 可用查看器列表,已按优先级排序 |
|
||||||
|
| default_viewer_id | string \| null | 用户设置的"始终使用"查看器 UUID,未设置则为 null |
|
||||||
|
|
||||||
|
**FileAppSummary**
|
||||||
|
|
||||||
|
| 字段 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| id | UUID | 应用 UUID |
|
||||||
|
| name | string | 应用显示名称 |
|
||||||
|
| app_key | string | 应用唯一标识,前端路由用 |
|
||||||
|
| type | `"builtin"` \| `"iframe"` \| `"wopi"` | 应用类型 |
|
||||||
|
| icon | string \| null | 图标名称(可映射到 icon library) |
|
||||||
|
| description | string \| null | 应用描述 |
|
||||||
|
| iframe_url_template | string \| null | iframe 类型专用,URL 模板含 `{file_url}` 占位符 |
|
||||||
|
| wopi_editor_url_template | string \| null | wopi 类型专用,编辑器 URL 模板 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. 设置默认查看器("始终使用")
|
||||||
|
|
||||||
|
用户在选择弹窗中勾选"始终使用此应用"时调用。
|
||||||
|
|
||||||
|
```
|
||||||
|
PUT /api/v1/user/settings/file-viewers/default
|
||||||
|
Authorization: Bearer {token}
|
||||||
|
Content-Type: application/json
|
||||||
|
```
|
||||||
|
|
||||||
|
**请求体**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"extension": "pdf",
|
||||||
|
"app_id": "550e8400-e29b-41d4-a716-446655440000"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| extension | string | 是 | 文件扩展名(小写,无点号) |
|
||||||
|
| app_id | UUID | 是 | 选择的查看器应用 UUID |
|
||||||
|
|
||||||
|
**响应 200**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "660e8400-e29b-41d4-a716-446655440001",
|
||||||
|
"extension": "pdf",
|
||||||
|
"app": {
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"name": "PDF 阅读器",
|
||||||
|
"app_key": "pdfjs",
|
||||||
|
"type": "builtin",
|
||||||
|
"icon": "file-pdf",
|
||||||
|
"description": "基于 pdf.js 的 PDF 在线阅读器",
|
||||||
|
"iframe_url_template": null,
|
||||||
|
"wopi_editor_url_template": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**错误码**
|
||||||
|
|
||||||
|
| 状态码 | 说明 |
|
||||||
|
|--------|------|
|
||||||
|
| 400 | 该应用不支持此扩展名 |
|
||||||
|
| 404 | 应用不存在 |
|
||||||
|
|
||||||
|
> 同一扩展名只允许一个默认值。重复 PUT 同一 extension 会更新(upsert),不会冲突。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. 列出所有默认查看器设置
|
||||||
|
|
||||||
|
用于用户设置页展示"已设为始终使用"的列表。
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/v1/user/settings/file-viewers/defaults
|
||||||
|
Authorization: Bearer {token}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应 200**
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "660e8400-e29b-41d4-a716-446655440001",
|
||||||
|
"extension": "pdf",
|
||||||
|
"app": {
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"name": "PDF 阅读器",
|
||||||
|
"app_key": "pdfjs",
|
||||||
|
"type": "builtin",
|
||||||
|
"icon": "file-pdf",
|
||||||
|
"description": null,
|
||||||
|
"iframe_url_template": null,
|
||||||
|
"wopi_editor_url_template": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4. 撤销默认查看器设置
|
||||||
|
|
||||||
|
用户在设置页点击"取消始终使用"时调用。
|
||||||
|
|
||||||
|
```
|
||||||
|
DELETE /api/v1/user/settings/file-viewers/default/{id}
|
||||||
|
Authorization: Bearer {token}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应** 204 No Content
|
||||||
|
|
||||||
|
**错误码**
|
||||||
|
|
||||||
|
| 状态码 | 说明 |
|
||||||
|
|--------|------|
|
||||||
|
| 404 | 记录不存在或不属于当前用户 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 5. 创建 WOPI 会话
|
||||||
|
|
||||||
|
打开 WOPI 类型应用(如 Collabora、OnlyOffice)时调用。
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /api/v1/file/{file_id}/wopi-session
|
||||||
|
Authorization: Bearer {token}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应 200**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"wopi_src": "http://localhost:8000/wopi/files/770e8400-e29b-41d4-a716-446655440002",
|
||||||
|
"access_token": "eyJhbGciOiJIUzI1NiIs...",
|
||||||
|
"access_token_ttl": 1739577600000,
|
||||||
|
"editor_url": "http://collabora:9980/loleaflet/dist/loleaflet.html?WOPISrc=...&access_token=...&access_token_ttl=..."
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| 字段 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| wopi_src | string | WOPI 源 URL(传给编辑器) |
|
||||||
|
| access_token | string | WOPI 访问令牌 |
|
||||||
|
| access_token_ttl | int | 令牌过期毫秒时间戳 |
|
||||||
|
| editor_url | string | 完整的编辑器 URL,**直接嵌入 iframe 即可** |
|
||||||
|
|
||||||
|
**错误码**
|
||||||
|
|
||||||
|
| 状态码 | 说明 |
|
||||||
|
|--------|------|
|
||||||
|
| 400 | 文件无扩展名 / WOPI 应用未配置编辑器 URL |
|
||||||
|
| 403 | 用户组无权限 |
|
||||||
|
| 404 | 文件不存在 / 无可用 WOPI 查看器 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 前端交互流程
|
||||||
|
|
||||||
|
### 打开文件预览
|
||||||
|
|
||||||
|
```
|
||||||
|
用户点击文件
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
GET /file/viewers?ext={扩展名}
|
||||||
|
│
|
||||||
|
├── viewers 为空 → 提示"暂无可用的预览方式"
|
||||||
|
│
|
||||||
|
├── default_viewer_id 不为空 → 直接用对应 viewer 打开(跳过选择弹窗)
|
||||||
|
│
|
||||||
|
└── viewers.length == 1 → 直接用唯一 viewer 打开(可选策略)
|
||||||
|
│
|
||||||
|
└── viewers.length > 1 → 展示选择弹窗
|
||||||
|
│
|
||||||
|
├── 用户选择 + 不勾选"始终使用" → 仅此一次打开
|
||||||
|
│
|
||||||
|
└── 用户选择 + 勾选"始终使用" → PUT /user/settings/file-viewers/default
|
||||||
|
│
|
||||||
|
└── 然后打开
|
||||||
|
```
|
||||||
|
|
||||||
|
### 根据 type 打开查看器
|
||||||
|
|
||||||
|
```
|
||||||
|
获取到 viewer 对象
|
||||||
|
│
|
||||||
|
├── type == "builtin"
|
||||||
|
│ └── 根据 app_key 路由到内置组件
|
||||||
|
│ switch(app_key):
|
||||||
|
│ "pdfjs" → <PdfViewer />
|
||||||
|
│ "monaco" → <CodeEditor />
|
||||||
|
│ "markdown" → <MarkdownPreview />
|
||||||
|
│ "image_viewer" → <ImageViewer />
|
||||||
|
│ "video_player" → <VideoPlayer />
|
||||||
|
│ "audio_player" → <AudioPlayer />
|
||||||
|
│
|
||||||
|
│ 获取文件内容:
|
||||||
|
│ POST /file/download/{file_id} → { access_token }
|
||||||
|
│ fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
|
||||||
|
│ → 传 URL 或 fetch Blob 给内置组件
|
||||||
|
│
|
||||||
|
├── type == "iframe"
|
||||||
|
│ └── 1. POST /file/download/{file_id} → { access_token }
|
||||||
|
│ 2. fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
|
||||||
|
│ 3. iframeSrc = viewer.iframe_url_template
|
||||||
|
│ .replace("{file_url}", encodeURIComponent(fileUrl))
|
||||||
|
│ 4. <iframe src={iframeSrc} />
|
||||||
|
│
|
||||||
|
└── type == "wopi"
|
||||||
|
└── 1. POST /file/{file_id}/wopi-session → { editor_url }
|
||||||
|
2. <iframe src={editor_url} />
|
||||||
|
(编辑器自动通过 WOPI 协议获取文件,前端无需处理)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 管理员 API
|
||||||
|
|
||||||
|
所有管理端点需要管理员身份(JWT 中 group.admin == true)。
|
||||||
|
|
||||||
|
### 1. 列出所有文件应用
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/v1/admin/file-app/list?page=1&page_size=20
|
||||||
|
Authorization: Bearer {admin_token}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应 200**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"apps": [
|
||||||
|
{
|
||||||
|
"id": "...",
|
||||||
|
"name": "PDF 阅读器",
|
||||||
|
"app_key": "pdfjs",
|
||||||
|
"type": "builtin",
|
||||||
|
"icon": "file-pdf",
|
||||||
|
"description": "...",
|
||||||
|
"is_enabled": true,
|
||||||
|
"is_restricted": false,
|
||||||
|
"iframe_url_template": null,
|
||||||
|
"wopi_discovery_url": null,
|
||||||
|
"wopi_editor_url_template": null,
|
||||||
|
"extensions": ["pdf"],
|
||||||
|
"allowed_group_ids": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"total": 9
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 创建文件应用
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /api/v1/admin/file-app/
|
||||||
|
Authorization: Bearer {admin_token}
|
||||||
|
Content-Type: application/json
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "自定义查看器",
|
||||||
|
"app_key": "my_viewer",
|
||||||
|
"type": "iframe",
|
||||||
|
"description": "自定义 iframe 查看器",
|
||||||
|
"is_enabled": true,
|
||||||
|
"is_restricted": false,
|
||||||
|
"iframe_url_template": "https://example.com/view?url={file_url}",
|
||||||
|
"extensions": ["pdf", "docx"],
|
||||||
|
"allowed_group_ids": []
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应** 201 — 返回 FileAppResponse(同列表中的单项)
|
||||||
|
|
||||||
|
**错误码**: 409 — app_key 已存在
|
||||||
|
|
||||||
|
### 3. 获取应用详情
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/v1/admin/file-app/{id}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应** 200 — FileAppResponse
|
||||||
|
|
||||||
|
### 4. 更新应用
|
||||||
|
|
||||||
|
```
|
||||||
|
PATCH /api/v1/admin/file-app/{id}
|
||||||
|
```
|
||||||
|
|
||||||
|
只传需要更新的字段:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "新名称",
|
||||||
|
"is_enabled": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应** 200 — FileAppResponse
|
||||||
|
|
||||||
|
### 5. 删除应用
|
||||||
|
|
||||||
|
```
|
||||||
|
DELETE /api/v1/admin/file-app/{id}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应** 204 No Content(级联删除扩展名关联、用户偏好、用户组关联)
|
||||||
|
|
||||||
|
### 6. 全量替换扩展名列表
|
||||||
|
|
||||||
|
```
|
||||||
|
PUT /api/v1/admin/file-app/{id}/extensions
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"extensions": ["doc", "docx", "odt"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应** 200 — FileAppResponse
|
||||||
|
|
||||||
|
### 7. 全量替换允许的用户组
|
||||||
|
|
||||||
|
```
|
||||||
|
PUT /api/v1/admin/file-app/{id}/groups
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"group_ids": ["uuid-1", "uuid-2"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应** 200 — FileAppResponse
|
||||||
|
|
||||||
|
> `is_restricted` 为 `true` 时,只有 `allowed_group_ids` 中的用户组成员能看到此应用。`is_restricted` 为 `false` 时所有用户可见,`allowed_group_ids` 不生效。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## TypeScript 类型参考
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
type FileAppType = 'builtin' | 'iframe' | 'wopi'
|
||||||
|
|
||||||
|
interface FileAppSummary {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
app_key: string
|
||||||
|
type: FileAppType
|
||||||
|
icon: string | null
|
||||||
|
description: string | null
|
||||||
|
iframe_url_template: string | null
|
||||||
|
wopi_editor_url_template: string | null
|
||||||
|
}
|
||||||
|
|
||||||
|
interface FileViewersResponse {
|
||||||
|
viewers: FileAppSummary[]
|
||||||
|
default_viewer_id: string | null
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SetDefaultViewerRequest {
|
||||||
|
extension: string
|
||||||
|
app_id: string
|
||||||
|
}
|
||||||
|
|
||||||
|
interface UserFileAppDefaultResponse {
|
||||||
|
id: string
|
||||||
|
extension: string
|
||||||
|
app: FileAppSummary
|
||||||
|
}
|
||||||
|
|
||||||
|
interface WopiSessionResponse {
|
||||||
|
wopi_src: string
|
||||||
|
access_token: string
|
||||||
|
access_token_ttl: number
|
||||||
|
editor_url: string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== 管理员类型 ==========
|
||||||
|
|
||||||
|
interface FileAppResponse {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
app_key: string
|
||||||
|
type: FileAppType
|
||||||
|
icon: string | null
|
||||||
|
description: string | null
|
||||||
|
is_enabled: boolean
|
||||||
|
is_restricted: boolean
|
||||||
|
iframe_url_template: string | null
|
||||||
|
wopi_discovery_url: string | null
|
||||||
|
wopi_editor_url_template: string | null
|
||||||
|
extensions: string[]
|
||||||
|
allowed_group_ids: string[]
|
||||||
|
}
|
||||||
|
|
||||||
|
interface FileAppListResponse {
|
||||||
|
apps: FileAppResponse[]
|
||||||
|
total: number
|
||||||
|
}
|
||||||
|
|
||||||
|
interface FileAppCreateRequest {
|
||||||
|
name: string
|
||||||
|
app_key: string
|
||||||
|
type: FileAppType
|
||||||
|
icon?: string
|
||||||
|
description?: string
|
||||||
|
is_enabled?: boolean // default: true
|
||||||
|
is_restricted?: boolean // default: false
|
||||||
|
iframe_url_template?: string
|
||||||
|
wopi_discovery_url?: string
|
||||||
|
wopi_editor_url_template?: string
|
||||||
|
extensions?: string[] // default: []
|
||||||
|
allowed_group_ids?: string[] // default: []
|
||||||
|
}
|
||||||
|
|
||||||
|
interface FileAppUpdateRequest {
|
||||||
|
name?: string
|
||||||
|
app_key?: string
|
||||||
|
type?: FileAppType
|
||||||
|
icon?: string
|
||||||
|
description?: string
|
||||||
|
is_enabled?: boolean
|
||||||
|
is_restricted?: boolean
|
||||||
|
iframe_url_template?: string
|
||||||
|
wopi_discovery_url?: string
|
||||||
|
wopi_editor_url_template?: string
|
||||||
|
}
|
||||||
|
```
|
||||||
242
docs/text-editor-api.md
Normal file
242
docs/text-editor-api.md
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
# 文本文件在线编辑 — 前端适配文档
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
Monaco Editor 打开文本文件时,通过 GET 获取内容和哈希作为编辑基线;保存时用 jsdiff 计算 unified diff,仅发送差异部分,后端验证无并发冲突后应用 patch。
|
||||||
|
|
||||||
|
```
|
||||||
|
打开文件: GET /api/v1/file/content/{file_id} → { content, hash, size }
|
||||||
|
保存文件: PATCH /api/v1/file/content/{file_id} ← { patch, base_hash }
|
||||||
|
→ { new_hash, new_size }
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 约定
|
||||||
|
|
||||||
|
| 项目 | 约定 |
|
||||||
|
|------|------|
|
||||||
|
| 编码 | 全程 UTF-8 |
|
||||||
|
| 换行符 | 后端 GET 时统一规范化为 `\n`,前端无需处理 `\r\n` |
|
||||||
|
| hash 算法 | SHA-256,hex 编码(64 字符),基于 UTF-8 bytes 计算 |
|
||||||
|
| diff 格式 | jsdiff `createPatch()` 输出的标准 unified diff |
|
||||||
|
| 空 diff | 前端自行判断,内容未变时不发请求 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## GET /api/v1/file/content/{file_id}
|
||||||
|
|
||||||
|
获取文本文件内容。
|
||||||
|
|
||||||
|
### 请求
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/v1/file/content/{file_id}
|
||||||
|
Authorization: Bearer <token>
|
||||||
|
```
|
||||||
|
|
||||||
|
### 响应 200
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"content": "line1\nline2\nline3\n",
|
||||||
|
"hash": "a1b2c3d4...(64字符 SHA-256 hex)",
|
||||||
|
"size": 18
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| 字段 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| `content` | string | 文件文本内容,换行符已规范化为 `\n` |
|
||||||
|
| `hash` | string | 基于规范化内容 UTF-8 bytes 的 SHA-256 hex |
|
||||||
|
| `size` | number | 规范化后的字节大小 |
|
||||||
|
|
||||||
|
### 错误
|
||||||
|
|
||||||
|
| 状态码 | 说明 |
|
||||||
|
|--------|------|
|
||||||
|
| 400 | 文件不是有效的 UTF-8 文本(二进制文件) |
|
||||||
|
| 401 | 未认证 |
|
||||||
|
| 404 | 文件不存在 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## PATCH /api/v1/file/content/{file_id}
|
||||||
|
|
||||||
|
增量保存文本文件。
|
||||||
|
|
||||||
|
### 请求
|
||||||
|
|
||||||
|
```
|
||||||
|
PATCH /api/v1/file/content/{file_id}
|
||||||
|
Authorization: Bearer <token>
|
||||||
|
Content-Type: application/json
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"patch": "--- a\n+++ b\n@@ -1,3 +1,3 @@\n line1\n-line2\n+LINE2\n line3\n",
|
||||||
|
"base_hash": "a1b2c3d4...(GET 返回的 hash)"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| 字段 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| `patch` | string | jsdiff `createPatch()` 生成的 unified diff |
|
||||||
|
| `base_hash` | string | 编辑前 GET 返回的 `hash` 值 |
|
||||||
|
|
||||||
|
### 响应 200
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"new_hash": "e5f6a7b8...(64字符)",
|
||||||
|
"new_size": 18
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
保存成功后,前端应将 `new_hash` 作为新的 `base_hash`,用于下次保存。
|
||||||
|
|
||||||
|
### 错误
|
||||||
|
|
||||||
|
| 状态码 | 说明 | 前端处理 |
|
||||||
|
|--------|------|----------|
|
||||||
|
| 401 | 未认证 | — |
|
||||||
|
| 404 | 文件不存在 | — |
|
||||||
|
| 409 | `base_hash` 不匹配(并发冲突) | 提示用户刷新,重新加载内容 |
|
||||||
|
| 422 | patch 格式无效或应用失败 | 回退到全量保存或提示用户 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 前端实现参考
|
||||||
|
|
||||||
|
### 依赖
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npm install jsdiff
|
||||||
|
```
|
||||||
|
|
||||||
|
### 计算 hash
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
async function sha256(text: string): Promise<string> {
|
||||||
|
const bytes = new TextEncoder().encode(text);
|
||||||
|
const hashBuffer = await crypto.subtle.digest("SHA-256", bytes);
|
||||||
|
const hashArray = Array.from(new Uint8Array(hashBuffer));
|
||||||
|
return hashArray.map(b => b.toString(16).padStart(2, "0")).join("");
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 打开文件
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
interface TextContent {
|
||||||
|
content: string;
|
||||||
|
hash: string;
|
||||||
|
size: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function openFile(fileId: string): Promise<TextContent> {
|
||||||
|
const resp = await fetch(`/api/v1/file/content/${fileId}`, {
|
||||||
|
headers: { Authorization: `Bearer ${token}` },
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!resp.ok) {
|
||||||
|
if (resp.status === 400) throw new Error("该文件不是文本文件");
|
||||||
|
throw new Error("获取文件内容失败");
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.json();
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 保存文件
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { createPatch } from "diff";
|
||||||
|
|
||||||
|
interface PatchResult {
|
||||||
|
new_hash: string;
|
||||||
|
new_size: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function saveFile(
|
||||||
|
fileId: string,
|
||||||
|
originalContent: string,
|
||||||
|
currentContent: string,
|
||||||
|
baseHash: string,
|
||||||
|
): Promise<PatchResult | null> {
|
||||||
|
// 内容未变,不发请求
|
||||||
|
if (originalContent === currentContent) return null;
|
||||||
|
|
||||||
|
const patch = createPatch("file", originalContent, currentContent);
|
||||||
|
|
||||||
|
const resp = await fetch(`/api/v1/file/content/${fileId}`, {
|
||||||
|
method: "PATCH",
|
||||||
|
headers: {
|
||||||
|
Authorization: `Bearer ${token}`,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ patch, base_hash: baseHash }),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (resp.status === 409) {
|
||||||
|
// 并发冲突,需要用户决策
|
||||||
|
throw new Error("CONFLICT");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!resp.ok) throw new Error("保存失败");
|
||||||
|
|
||||||
|
return resp.json();
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 完整编辑流程
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 1. 打开
|
||||||
|
const file = await openFile(fileId);
|
||||||
|
let baseContent = file.content;
|
||||||
|
let baseHash = file.hash;
|
||||||
|
|
||||||
|
// 2. 用户在 Monaco 中编辑...
|
||||||
|
editor.setValue(baseContent);
|
||||||
|
|
||||||
|
// 3. 保存(Ctrl+S)
|
||||||
|
const currentContent = editor.getValue();
|
||||||
|
const result = await saveFile(fileId, baseContent, currentContent, baseHash);
|
||||||
|
|
||||||
|
if (result) {
|
||||||
|
// 更新基线
|
||||||
|
baseContent = currentContent;
|
||||||
|
baseHash = result.new_hash;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 冲突处理建议
|
||||||
|
|
||||||
|
当 PATCH 返回 409 时,说明文件已被其他会话修改:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
try {
|
||||||
|
await saveFile(fileId, baseContent, currentContent, baseHash);
|
||||||
|
} catch (e) {
|
||||||
|
if (e.message === "CONFLICT") {
|
||||||
|
// 方案 A:提示用户,提供"覆盖"和"放弃"选项
|
||||||
|
// 方案 B:重新 GET 最新内容,展示 diff 让用户合并
|
||||||
|
const latest = await openFile(fileId);
|
||||||
|
// 展示合并 UI...
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## hash 一致性验证
|
||||||
|
|
||||||
|
前端可以在 GET 后本地验证 hash,确保传输无误:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const file = await openFile(fileId);
|
||||||
|
const localHash = await sha256(file.content);
|
||||||
|
console.assert(localHash === file.hash, "hash 不一致,内容可能损坏");
|
||||||
|
```
|
||||||
1
ee
Submodule
1
ee
Submodule
Submodule ee added at cc32d8db91
14
license_public.pem
Normal file
14
license_public.pem
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
-----BEGIN PUBLIC KEY-----
|
||||||
|
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAyNltXQ/Nuechx3kjj3T5
|
||||||
|
oR6pZvTmpsDowqqxXJy7FXUI8d7XprhV+HrBQPsrT/Ngo9FwW3XyiK10m1WrzpGW
|
||||||
|
eaf9990Z5Z2naEn5TzGrh71p/D7mZcNGVumo9uAuhtNEemm6xB3FoyGYZj7X0cwA
|
||||||
|
VDvIiKAwYyRJX2LqVh1/tZM6tTO3oaGZXRMZzCNUPFSo4ZZudU3Boa5oQg08evu4
|
||||||
|
vaOqeFrMX47R3MSUmO9hOh+NS53XNqO0f0zw5sv95CtyR5qvJ4gpkgYaRCSQFd19
|
||||||
|
TnHU5saFVrH9jdADz1tdkMYcyYE+uJActZBapxCHSYB2tSCKWjDxeUFl/oY/ZFtY
|
||||||
|
l4MNz1ovkjNhpmR3g+I5fbvN0cxDIjnZ9vJ84ozGqTGT9s1jHaLbpLri/vhuT4F2
|
||||||
|
7kifXk8ImwtMZpZvzhmucH9/5VgcWKNuMATzEMif+YjFpuOGx8gc1XL1W/3q+dH0
|
||||||
|
EFESp+/knjcVIfwpAkIKyV7XvDgFHsif1SeI0zZMW4utowVvGocP1ZzK5BGNTk2z
|
||||||
|
CEtQDO7Rqo+UDckOJSG66VW3c2QO8o6uuy6fzx7q0MFEmUMwGf2iMVtR/KnXe99C
|
||||||
|
enOT0BpU1EQvqssErUqivDss7jm98iD8M/TCE7pFboqZ+SC9G+QAqNIQNFWh8bWA
|
||||||
|
R9hyXM/x5ysHd6MC4eEQnhMCAwEAAQ==
|
||||||
|
-----END PUBLIC KEY-----
|
||||||
76
main.py
76
main.py
@@ -1,28 +1,62 @@
|
|||||||
|
from pathlib import Path
|
||||||
from typing import NoReturn
|
from typing import NoReturn
|
||||||
|
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
|
from loguru import logger as l
|
||||||
|
|
||||||
from utils.conf import appmeta
|
from routers import router
|
||||||
from utils.http.http_exceptions import raise_internal_error
|
from routers.dav import dav_app
|
||||||
from utils.lifespan import lifespan
|
from routers.dav.provider import EventLoopRef
|
||||||
|
from service.redis import RedisManager
|
||||||
|
from service.storage import S3StorageService
|
||||||
from sqlmodels.database_connection import DatabaseManager
|
from sqlmodels.database_connection import DatabaseManager
|
||||||
from sqlmodels.migration import migration
|
from sqlmodels.migration import migration
|
||||||
from utils import JWT
|
from utils import JWT
|
||||||
from routers import router
|
from utils.conf import appmeta
|
||||||
from service.redis import RedisManager
|
from utils.http.http_exceptions import raise_internal_error
|
||||||
from loguru import logger as l
|
from utils.lifespan import lifespan
|
||||||
|
|
||||||
|
# 尝试加载企业版功能
|
||||||
|
_has_ee: bool = False
|
||||||
|
try:
|
||||||
|
from ee import init_ee
|
||||||
|
from ee.license import LicenseError
|
||||||
|
from ee.routers import ee_router
|
||||||
|
|
||||||
|
_has_ee = True
|
||||||
|
|
||||||
|
async def _init_ee() -> None:
|
||||||
|
"""启动时验证许可证,路由由 license_valid_required 依赖保护"""
|
||||||
|
try:
|
||||||
|
await init_ee()
|
||||||
|
except LicenseError as exc:
|
||||||
|
l.critical(f"许可证验证失败: {exc}")
|
||||||
|
raise SystemExit(1) from exc
|
||||||
|
|
||||||
|
lifespan.add_startup(_init_ee)
|
||||||
|
except ImportError as exc:
|
||||||
|
ee_router = None
|
||||||
|
l.info(f"以 Community 版本运行 (原因: {exc})")
|
||||||
|
|
||||||
|
STATICS_DIR: Path = (Path(__file__).parent / "statics").resolve()
|
||||||
|
"""前端静态文件目录(由 Docker 构建时复制)"""
|
||||||
|
|
||||||
async def _init_db() -> None:
|
async def _init_db() -> None:
|
||||||
"""初始化数据库连接引擎"""
|
"""初始化数据库连接引擎"""
|
||||||
await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug)
|
await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug)
|
||||||
|
|
||||||
|
# 捕获事件循环引用(供 WSGI 线程桥接使用)
|
||||||
|
lifespan.add_startup(EventLoopRef.capture)
|
||||||
|
|
||||||
# 添加初始化数据库启动项
|
# 添加初始化数据库启动项
|
||||||
lifespan.add_startup(_init_db)
|
lifespan.add_startup(_init_db)
|
||||||
lifespan.add_startup(migration)
|
lifespan.add_startup(migration)
|
||||||
lifespan.add_startup(JWT.load_secret_key)
|
lifespan.add_startup(JWT.load_secret_key)
|
||||||
lifespan.add_startup(RedisManager.connect)
|
lifespan.add_startup(RedisManager.connect)
|
||||||
|
lifespan.add_startup(S3StorageService.initialize_session)
|
||||||
|
|
||||||
# 添加关闭项
|
# 添加关闭项
|
||||||
|
lifespan.add_shutdown(S3StorageService.close_session)
|
||||||
lifespan.add_shutdown(DatabaseManager.close)
|
lifespan.add_shutdown(DatabaseManager.close)
|
||||||
lifespan.add_shutdown(RedisManager.disconnect)
|
lifespan.add_shutdown(RedisManager.disconnect)
|
||||||
|
|
||||||
@@ -63,6 +97,36 @@ async def handle_unexpected_exceptions(
|
|||||||
|
|
||||||
# 挂载路由
|
# 挂载路由
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
if _has_ee:
|
||||||
|
app.include_router(ee_router, prefix="/api/v1")
|
||||||
|
|
||||||
|
# 挂载 WebDAV 协议端点(优先于 SPA catch-all)
|
||||||
|
app.mount("/dav", dav_app)
|
||||||
|
|
||||||
|
# 挂载前端静态文件(仅当 statics/ 目录存在时,即 Docker 部署环境)
|
||||||
|
if STATICS_DIR.is_dir():
|
||||||
|
from starlette.staticfiles import StaticFiles
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
|
_assets_dir: Path = STATICS_DIR / "assets"
|
||||||
|
if _assets_dir.is_dir():
|
||||||
|
app.mount("/assets", StaticFiles(directory=_assets_dir), name="assets")
|
||||||
|
|
||||||
|
@app.get("/{path:path}")
|
||||||
|
async def spa_fallback(path: str) -> FileResponse:
|
||||||
|
"""
|
||||||
|
SPA fallback 路由
|
||||||
|
|
||||||
|
优先级:API 路由 > /assets 静态挂载 > 此 catch-all 路由。
|
||||||
|
若请求路径对应 statics/ 下的真实文件则直接返回,否则返回 index.html。
|
||||||
|
"""
|
||||||
|
file_path: Path = (STATICS_DIR / path).resolve()
|
||||||
|
# 防止路径穿越
|
||||||
|
if file_path.is_relative_to(STATICS_DIR) and path and file_path.is_file():
|
||||||
|
return FileResponse(file_path)
|
||||||
|
return FileResponse(STATICS_DIR / "index.html")
|
||||||
|
|
||||||
|
l.info(f"前端静态文件已挂载: {STATICS_DIR}")
|
||||||
|
|
||||||
# 防止直接运行 main.py
|
# 防止直接运行 main.py
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from fastapi import Depends, Form, Query
|
|||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from sqlmodels.database_connection import DatabaseManager
|
from sqlmodels.database_connection import DatabaseManager
|
||||||
from sqlmodels.mixin import TimeFilterRequest, TableViewRequest
|
from sqlmodel_ext import TimeFilterRequest, TableViewRequest
|
||||||
from sqlmodels.user import UserFilterParams, UserStatus
|
from sqlmodels.user import UserFilterParams, UserStatus
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,10 +11,13 @@ dependencies = [
|
|||||||
"argon2-cffi>=25.1.0",
|
"argon2-cffi>=25.1.0",
|
||||||
"asyncpg>=0.31.0",
|
"asyncpg>=0.31.0",
|
||||||
"cachetools>=6.2.4",
|
"cachetools>=6.2.4",
|
||||||
|
"captcha>=0.7.1",
|
||||||
|
"cryptography>=46.0.3",
|
||||||
"fastapi[standard]>=0.122.0",
|
"fastapi[standard]>=0.122.0",
|
||||||
"httpx>=0.27.0",
|
"httpx>=0.27.0",
|
||||||
"itsdangerous>=2.2.0",
|
"itsdangerous>=2.2.0",
|
||||||
"loguru>=0.7.3",
|
"loguru>=0.7.3",
|
||||||
|
"orjson>=3.11.7",
|
||||||
"pyjwt>=2.10.1",
|
"pyjwt>=2.10.1",
|
||||||
"pyotp>=2.9.0",
|
"pyotp>=2.9.0",
|
||||||
"pytest>=9.0.2",
|
"pytest>=9.0.2",
|
||||||
@@ -26,8 +29,18 @@ dependencies = [
|
|||||||
"redis[hiredis]>=7.1.0",
|
"redis[hiredis]>=7.1.0",
|
||||||
"sqlalchemy>=2.0.44",
|
"sqlalchemy>=2.0.44",
|
||||||
"sqlmodel>=0.0.27",
|
"sqlmodel>=0.0.27",
|
||||||
|
"sqlmodel-ext[pgvector]>=0.1.1",
|
||||||
"uvicorn>=0.38.0",
|
"uvicorn>=0.38.0",
|
||||||
"webauthn>=2.7.0",
|
"webauthn>=2.7.0",
|
||||||
|
"whatthepatch>=1.0.6",
|
||||||
|
"wsgidav>=4.3.0",
|
||||||
|
"a2wsgi>=1.10.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
build = [
|
||||||
|
"cython>=3.0.11",
|
||||||
|
"setuptools>=75.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from .api import router as api_router
|
from .api import router as api_router
|
||||||
|
from .wopi import wopi_router
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
router.include_router(api_router)
|
router.include_router(api_router)
|
||||||
|
router.include_router(wopi_router)
|
||||||
@@ -5,15 +5,16 @@ from utils.conf import appmeta
|
|||||||
from .admin import admin_router
|
from .admin import admin_router
|
||||||
|
|
||||||
from .callback import callback_router
|
from .callback import callback_router
|
||||||
|
from .category import category_router
|
||||||
from .directory import directory_router
|
from .directory import directory_router
|
||||||
from .download import download_router
|
from .download import download_router
|
||||||
from .file import router as file_router
|
from .file import router as file_router
|
||||||
from .object import object_router
|
from .object import object_router
|
||||||
from .share import share_router
|
from .share import share_router
|
||||||
|
from .trash import trash_router
|
||||||
from .site import site_router
|
from .site import site_router
|
||||||
from .slave import slave_router
|
from .slave import slave_router
|
||||||
from .user import user_router
|
from .user import user_router
|
||||||
from .vas import vas_router
|
|
||||||
from .webdav import webdav_router
|
from .webdav import webdav_router
|
||||||
|
|
||||||
router = APIRouter(prefix="/v1")
|
router = APIRouter(prefix="/v1")
|
||||||
@@ -23,14 +24,15 @@ router = APIRouter(prefix="/v1")
|
|||||||
if appmeta.mode == "master":
|
if appmeta.mode == "master":
|
||||||
router.include_router(admin_router)
|
router.include_router(admin_router)
|
||||||
router.include_router(callback_router)
|
router.include_router(callback_router)
|
||||||
|
router.include_router(category_router)
|
||||||
router.include_router(directory_router)
|
router.include_router(directory_router)
|
||||||
router.include_router(download_router)
|
router.include_router(download_router)
|
||||||
router.include_router(file_router)
|
router.include_router(file_router)
|
||||||
router.include_router(object_router)
|
router.include_router(object_router)
|
||||||
router.include_router(share_router)
|
router.include_router(share_router)
|
||||||
router.include_router(site_router)
|
router.include_router(site_router)
|
||||||
|
router.include_router(trash_router)
|
||||||
router.include_router(user_router)
|
router.include_router(user_router)
|
||||||
router.include_router(vas_router)
|
|
||||||
router.include_router(webdav_router)
|
router.include_router(webdav_router)
|
||||||
elif appmeta.mode == "slave":
|
elif appmeta.mode == "slave":
|
||||||
router.include_router(slave_router)
|
router.include_router(slave_router)
|
||||||
|
|||||||
@@ -9,20 +9,27 @@ from sqlmodels import (
|
|||||||
User, ResponseBase,
|
User, ResponseBase,
|
||||||
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
|
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
|
||||||
)
|
)
|
||||||
from sqlmodels.base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase
|
||||||
from sqlmodels.setting import (
|
from sqlmodels.setting import (
|
||||||
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
|
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
|
||||||
)
|
)
|
||||||
from sqlmodels.setting import SettingsType
|
from sqlmodels.setting import SettingsType
|
||||||
from utils import http_exceptions
|
from utils import http_exceptions
|
||||||
from utils.conf import appmeta
|
from utils.conf import appmeta
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ee.service import get_cached_license
|
||||||
|
except ImportError:
|
||||||
|
get_cached_license = None
|
||||||
|
|
||||||
from .file import admin_file_router
|
from .file import admin_file_router
|
||||||
|
from .file_app import admin_file_app_router
|
||||||
from .group import admin_group_router
|
from .group import admin_group_router
|
||||||
from .policy import admin_policy_router
|
from .policy import admin_policy_router
|
||||||
from .share import admin_share_router
|
from .share import admin_share_router
|
||||||
from .task import admin_task_router
|
from .task import admin_task_router
|
||||||
from .user import admin_user_router
|
from .user import admin_user_router
|
||||||
from .vas import admin_vas_router
|
from .theme import admin_theme_router
|
||||||
|
|
||||||
|
|
||||||
class Aria2TestRequest(SQLModelBase):
|
class Aria2TestRequest(SQLModelBase):
|
||||||
@@ -43,10 +50,11 @@ admin_router = APIRouter(
|
|||||||
admin_router.include_router(admin_group_router)
|
admin_router.include_router(admin_group_router)
|
||||||
admin_router.include_router(admin_user_router)
|
admin_router.include_router(admin_user_router)
|
||||||
admin_router.include_router(admin_file_router)
|
admin_router.include_router(admin_file_router)
|
||||||
|
admin_router.include_router(admin_file_app_router)
|
||||||
admin_router.include_router(admin_policy_router)
|
admin_router.include_router(admin_policy_router)
|
||||||
admin_router.include_router(admin_share_router)
|
admin_router.include_router(admin_share_router)
|
||||||
admin_router.include_router(admin_task_router)
|
admin_router.include_router(admin_task_router)
|
||||||
admin_router.include_router(admin_vas_router)
|
admin_router.include_router(admin_theme_router)
|
||||||
|
|
||||||
# 离线下载 /api/admin/aria2
|
# 离线下载 /api/admin/aria2
|
||||||
admin_aria2_router = APIRouter(
|
admin_aria2_router = APIRouter(
|
||||||
@@ -155,9 +163,19 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
|
|||||||
if site_url_setting and site_url_setting.value:
|
if site_url_setting and site_url_setting.value:
|
||||||
site_urls.append(site_url_setting.value)
|
site_urls.append(site_url_setting.value)
|
||||||
|
|
||||||
# 许可证信息(从设置读取或使用默认值)
|
# 许可证信息(Pro 版本从缓存读取,CE 版本永不过期)
|
||||||
|
if appmeta.IsPro and get_cached_license:
|
||||||
|
payload = get_cached_license()
|
||||||
license_info = LicenseInfo(
|
license_info = LicenseInfo(
|
||||||
expired_at=now + timedelta(days=365),
|
expired_at=payload.expires_at,
|
||||||
|
signed_at=payload.issued_at,
|
||||||
|
root_domains=[],
|
||||||
|
domains=[payload.domain],
|
||||||
|
vol_domains=[],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
license_info = LicenseInfo(
|
||||||
|
expired_at=datetime.max,
|
||||||
signed_at=now,
|
signed_at=now,
|
||||||
root_domains=[],
|
root_domains=[],
|
||||||
domains=[],
|
domains=[],
|
||||||
@@ -221,11 +239,11 @@ async def router_admin_update_settings(
|
|||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
existing.value = item.value
|
existing.value = item.value
|
||||||
await existing.save(session)
|
existing = await existing.save(session)
|
||||||
updated_count += 1
|
updated_count += 1
|
||||||
else:
|
else:
|
||||||
new_setting = Setting(type=item.type, name=item.name, value=item.value)
|
new_setting = Setting(type=item.type, name=item.name, value=item.value)
|
||||||
await new_setting.save(session)
|
new_setting = await new_setting.save(session)
|
||||||
created_count += 1
|
created_count += 1
|
||||||
|
|
||||||
l.info(f"管理员更新了 {updated_count} 个设置项,新建了 {created_count} 个设置项")
|
l.info(f"管理员更新了 {updated_count} 个设置项,新建了 {created_count} 个设置项")
|
||||||
@@ -279,16 +297,17 @@ async def router_admin_get_settings(
|
|||||||
path='/test',
|
path='/test',
|
||||||
summary='测试 Aria2 连接',
|
summary='测试 Aria2 连接',
|
||||||
description='Test Aria2 RPC connection',
|
description='Test Aria2 RPC connection',
|
||||||
dependencies=[Depends(admin_required)]
|
dependencies=[Depends(admin_required)],
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_admin_aira2_test(
|
async def router_admin_aira2_test(
|
||||||
request: Aria2TestRequest,
|
request: Aria2TestRequest,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
测试 Aria2 RPC 连接。
|
测试 Aria2 RPC 连接。
|
||||||
|
|
||||||
:param request: 测试请求
|
:param request: 测试请求
|
||||||
:return: 测试结果
|
:raises HTTPException: 连接失败时抛出 400
|
||||||
"""
|
"""
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
@@ -303,22 +322,18 @@ async def router_admin_aira2_test(
|
|||||||
async with aiohttp.ClientSession() as client:
|
async with aiohttp.ClientSession() as client:
|
||||||
async with client.post(request.rpc_url, json=payload, timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
async with client.post(request.rpc_url, json=payload, timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
return ResponseBase(
|
raise HTTPException(
|
||||||
code=400,
|
status_code=400,
|
||||||
msg=f"连接失败,HTTP {resp.status}"
|
detail=f"连接失败,HTTP {resp.status}",
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await resp.json()
|
result = await resp.json()
|
||||||
if "error" in result:
|
if "error" in result:
|
||||||
return ResponseBase(
|
raise HTTPException(
|
||||||
code=400,
|
status_code=400,
|
||||||
msg=f"Aria2 错误: {result['error']['message']}"
|
detail=f"Aria2 错误: {result['error']['message']}",
|
||||||
)
|
)
|
||||||
|
except HTTPException:
|
||||||
version = result.get("result", {}).get("version", "unknown")
|
raise
|
||||||
return ResponseBase(data={
|
|
||||||
"connected": True,
|
|
||||||
"version": version,
|
|
||||||
})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ResponseBase(code=400, msg=f"连接失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"连接失败: {str(e)}")
|
||||||
@@ -54,7 +54,7 @@ async def _set_ban_recursive(
|
|||||||
obj.banned_by = None
|
obj.banned_by = None
|
||||||
obj.ban_reason = None
|
obj.ban_reason = None
|
||||||
|
|
||||||
await obj.save(session)
|
obj = await obj.save(session)
|
||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
@@ -131,9 +131,7 @@ async def router_admin_preview_file(
|
|||||||
:param file_id: 文件UUID
|
:param file_id: 文件UUID
|
||||||
:return: 文件内容
|
:return: 文件内容
|
||||||
"""
|
"""
|
||||||
file_obj = await Object.get(session, Object.id == file_id)
|
file_obj = await Object.get_exist_one(session, file_id)
|
||||||
if not file_obj:
|
|
||||||
raise HTTPException(status_code=404, detail="文件不存在")
|
|
||||||
|
|
||||||
if not file_obj.is_file:
|
if not file_obj.is_file:
|
||||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||||
@@ -182,9 +180,7 @@ async def router_admin_ban_file(
|
|||||||
:param claims: 当前管理员 JWT claims
|
:param claims: 当前管理员 JWT claims
|
||||||
:return: 封禁结果
|
:return: 封禁结果
|
||||||
"""
|
"""
|
||||||
file_obj = await Object.get(session, Object.id == file_id)
|
file_obj = await Object.get_exist_one(session, file_id)
|
||||||
if not file_obj:
|
|
||||||
raise HTTPException(status_code=404, detail="文件不存在")
|
|
||||||
|
|
||||||
count = await _set_ban_recursive(session, file_obj, request.ban, claims.sub, request.reason)
|
count = await _set_ban_recursive(session, file_obj, request.ban, claims.sub, request.reason)
|
||||||
|
|
||||||
@@ -212,9 +208,7 @@ async def router_admin_delete_file(
|
|||||||
:param delete_physical: 是否同时删除物理文件
|
:param delete_physical: 是否同时删除物理文件
|
||||||
:return: 删除结果
|
:return: 删除结果
|
||||||
"""
|
"""
|
||||||
file_obj = await Object.get(session, Object.id == file_id)
|
file_obj = await Object.get_exist_one(session, file_id)
|
||||||
if not file_obj:
|
|
||||||
raise HTTPException(status_code=404, detail="文件不存在")
|
|
||||||
|
|
||||||
if not file_obj.is_file:
|
if not file_obj.is_file:
|
||||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||||
|
|||||||
450
routers/api/v1/admin/file_app/__init__.py
Normal file
450
routers/api/v1/admin/file_app/__init__.py
Normal file
@@ -0,0 +1,450 @@
|
|||||||
|
"""
|
||||||
|
管理员文件应用管理端点
|
||||||
|
|
||||||
|
提供文件查看器应用的 CRUD、扩展名管理、用户组权限管理和 WOPI Discovery。
|
||||||
|
"""
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from fastapi import APIRouter, Depends, status
|
||||||
|
from loguru import logger as l
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from middleware.auth import admin_required
|
||||||
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
|
from service.wopi import parse_wopi_discovery_xml
|
||||||
|
from sqlmodels import (
|
||||||
|
FileApp,
|
||||||
|
FileAppCreateRequest,
|
||||||
|
FileAppExtension,
|
||||||
|
FileAppGroupLink,
|
||||||
|
FileAppListResponse,
|
||||||
|
FileAppResponse,
|
||||||
|
FileAppUpdateRequest,
|
||||||
|
ExtensionUpdateRequest,
|
||||||
|
GroupAccessUpdateRequest,
|
||||||
|
WopiDiscoveredExtension,
|
||||||
|
WopiDiscoveryResponse,
|
||||||
|
)
|
||||||
|
from sqlmodels.file_app import FileAppType
|
||||||
|
from utils import http_exceptions
|
||||||
|
|
||||||
|
admin_file_app_router = APIRouter(
|
||||||
|
prefix="/file-app",
|
||||||
|
tags=["admin", "file_app"],
|
||||||
|
dependencies=[Depends(admin_required)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_file_app_router.get(
|
||||||
|
path='/list',
|
||||||
|
summary='列出所有文件应用',
|
||||||
|
)
|
||||||
|
async def list_file_apps(
|
||||||
|
session: SessionDep,
|
||||||
|
table_view: TableViewRequestDep,
|
||||||
|
) -> FileAppListResponse:
|
||||||
|
"""
|
||||||
|
列出所有文件应用端点(分页)
|
||||||
|
|
||||||
|
认证:管理员权限
|
||||||
|
"""
|
||||||
|
result = await FileApp.get_with_count(
|
||||||
|
session,
|
||||||
|
table_view=table_view,
|
||||||
|
)
|
||||||
|
|
||||||
|
apps: list[FileAppResponse] = []
|
||||||
|
for app in result.items:
|
||||||
|
extensions = await FileAppExtension.get(
|
||||||
|
session,
|
||||||
|
FileAppExtension.app_id == app.id,
|
||||||
|
fetch_mode="all",
|
||||||
|
)
|
||||||
|
group_links_result = await session.exec(
|
||||||
|
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
|
||||||
|
)
|
||||||
|
group_links: list[FileAppGroupLink] = list(group_links_result.all())
|
||||||
|
apps.append(FileAppResponse.from_app(app, extensions, group_links))
|
||||||
|
|
||||||
|
return FileAppListResponse(apps=apps, total=result.count)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_file_app_router.post(
|
||||||
|
path='/',
|
||||||
|
summary='创建文件应用',
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
)
|
||||||
|
async def create_file_app(
|
||||||
|
session: SessionDep,
|
||||||
|
request: FileAppCreateRequest,
|
||||||
|
) -> FileAppResponse:
|
||||||
|
"""
|
||||||
|
创建文件应用端点
|
||||||
|
|
||||||
|
认证:管理员权限
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 409: app_key 已存在
|
||||||
|
"""
|
||||||
|
# 检查 app_key 唯一
|
||||||
|
existing = await FileApp.get(session, FileApp.app_key == request.app_key)
|
||||||
|
if existing:
|
||||||
|
http_exceptions.raise_conflict(f"应用标识 '{request.app_key}' 已存在")
|
||||||
|
|
||||||
|
# 创建应用
|
||||||
|
app = FileApp(
|
||||||
|
name=request.name,
|
||||||
|
app_key=request.app_key,
|
||||||
|
type=request.type,
|
||||||
|
icon=request.icon,
|
||||||
|
description=request.description,
|
||||||
|
is_enabled=request.is_enabled,
|
||||||
|
is_restricted=request.is_restricted,
|
||||||
|
iframe_url_template=request.iframe_url_template,
|
||||||
|
wopi_discovery_url=request.wopi_discovery_url,
|
||||||
|
wopi_editor_url_template=request.wopi_editor_url_template,
|
||||||
|
)
|
||||||
|
app = await app.save(session)
|
||||||
|
app_id = app.id
|
||||||
|
|
||||||
|
# 创建扩展名关联
|
||||||
|
extensions: list[FileAppExtension] = []
|
||||||
|
for i, ext in enumerate(request.extensions):
|
||||||
|
normalized = ext.lower().strip().lstrip('.')
|
||||||
|
ext_record = FileAppExtension(
|
||||||
|
app_id=app_id,
|
||||||
|
extension=normalized,
|
||||||
|
priority=i,
|
||||||
|
)
|
||||||
|
ext_record = await ext_record.save(session)
|
||||||
|
extensions.append(ext_record)
|
||||||
|
|
||||||
|
# 创建用户组关联
|
||||||
|
group_links: list[FileAppGroupLink] = []
|
||||||
|
for group_id in request.allowed_group_ids:
|
||||||
|
link = FileAppGroupLink(app_id=app_id, group_id=group_id)
|
||||||
|
session.add(link)
|
||||||
|
group_links.append(link)
|
||||||
|
if group_links:
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(app)
|
||||||
|
|
||||||
|
l.info(f"创建文件应用: {app.name} ({app.app_key})")
|
||||||
|
|
||||||
|
return FileAppResponse.from_app(app, extensions, group_links)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_file_app_router.get(
|
||||||
|
path='/{app_id}',
|
||||||
|
summary='获取文件应用详情',
|
||||||
|
)
|
||||||
|
async def get_file_app(
|
||||||
|
session: SessionDep,
|
||||||
|
app_id: UUID,
|
||||||
|
) -> FileAppResponse:
|
||||||
|
"""
|
||||||
|
获取文件应用详情端点
|
||||||
|
|
||||||
|
认证:管理员权限
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 应用不存在
|
||||||
|
"""
|
||||||
|
app = await FileApp.get_exist_one(session, app_id)
|
||||||
|
|
||||||
|
extensions = await FileAppExtension.get(
|
||||||
|
session,
|
||||||
|
FileAppExtension.app_id == app.id,
|
||||||
|
fetch_mode="all",
|
||||||
|
)
|
||||||
|
group_links_result = await session.exec(
|
||||||
|
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
|
||||||
|
)
|
||||||
|
group_links: list[FileAppGroupLink] = list(group_links_result.all())
|
||||||
|
|
||||||
|
return FileAppResponse.from_app(app, extensions, group_links)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_file_app_router.patch(
|
||||||
|
path='/{app_id}',
|
||||||
|
summary='更新文件应用',
|
||||||
|
)
|
||||||
|
async def update_file_app(
|
||||||
|
session: SessionDep,
|
||||||
|
app_id: UUID,
|
||||||
|
request: FileAppUpdateRequest,
|
||||||
|
) -> FileAppResponse:
|
||||||
|
"""
|
||||||
|
更新文件应用端点
|
||||||
|
|
||||||
|
认证:管理员权限
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 应用不存在
|
||||||
|
- 409: 新 app_key 已被其他应用使用
|
||||||
|
"""
|
||||||
|
app = await FileApp.get_exist_one(session, app_id)
|
||||||
|
|
||||||
|
# 检查 app_key 唯一性
|
||||||
|
if request.app_key is not None and request.app_key != app.app_key:
|
||||||
|
existing = await FileApp.get(session, FileApp.app_key == request.app_key)
|
||||||
|
if existing:
|
||||||
|
http_exceptions.raise_conflict(f"应用标识 '{request.app_key}' 已存在")
|
||||||
|
|
||||||
|
# 更新非 None 字段
|
||||||
|
update_data = request.model_dump(exclude_unset=True)
|
||||||
|
for key, value in update_data.items():
|
||||||
|
setattr(app, key, value)
|
||||||
|
|
||||||
|
app = await app.save(session)
|
||||||
|
|
||||||
|
extensions = await FileAppExtension.get(
|
||||||
|
session,
|
||||||
|
FileAppExtension.app_id == app.id,
|
||||||
|
fetch_mode="all",
|
||||||
|
)
|
||||||
|
group_links_result = await session.exec(
|
||||||
|
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
|
||||||
|
)
|
||||||
|
group_links: list[FileAppGroupLink] = list(group_links_result.all())
|
||||||
|
|
||||||
|
l.info(f"更新文件应用: {app.name} ({app.app_key})")
|
||||||
|
|
||||||
|
return FileAppResponse.from_app(app, extensions, group_links)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_file_app_router.delete(
|
||||||
|
path='/{app_id}',
|
||||||
|
summary='删除文件应用',
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
)
|
||||||
|
async def delete_file_app(
|
||||||
|
session: SessionDep,
|
||||||
|
app_id: UUID,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
删除文件应用端点(级联删除扩展名、用户偏好和用户组关联)
|
||||||
|
|
||||||
|
认证:管理员权限
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 应用不存在
|
||||||
|
"""
|
||||||
|
app = await FileApp.get_exist_one(session, app_id)
|
||||||
|
|
||||||
|
app_name = app.app_key
|
||||||
|
await FileApp.delete(session, app)
|
||||||
|
l.info(f"删除文件应用: {app_name}")
|
||||||
|
|
||||||
|
|
||||||
|
@admin_file_app_router.put(
|
||||||
|
path='/{app_id}/extensions',
|
||||||
|
summary='全量替换扩展名列表',
|
||||||
|
)
|
||||||
|
async def update_extensions(
|
||||||
|
session: SessionDep,
|
||||||
|
app_id: UUID,
|
||||||
|
request: ExtensionUpdateRequest,
|
||||||
|
) -> FileAppResponse:
|
||||||
|
"""
|
||||||
|
全量替换扩展名列表端点
|
||||||
|
|
||||||
|
先删除旧的扩展名关联,再创建新的。
|
||||||
|
|
||||||
|
认证:管理员权限
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 应用不存在
|
||||||
|
"""
|
||||||
|
app = await FileApp.get_exist_one(session, app_id)
|
||||||
|
|
||||||
|
# 保留旧扩展名的 wopi_action_url(Discovery 填充的值)
|
||||||
|
old_extensions: list[FileAppExtension] = await FileAppExtension.get(
|
||||||
|
session,
|
||||||
|
FileAppExtension.app_id == app_id,
|
||||||
|
fetch_mode="all",
|
||||||
|
)
|
||||||
|
old_url_map: dict[str, str] = {
|
||||||
|
ext.extension: ext.wopi_action_url
|
||||||
|
for ext in old_extensions
|
||||||
|
if ext.wopi_action_url
|
||||||
|
}
|
||||||
|
for old_ext in old_extensions:
|
||||||
|
await FileAppExtension.delete(session, old_ext, commit=False)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
# 创建新的扩展名(保留已有的 wopi_action_url)
|
||||||
|
new_extensions: list[FileAppExtension] = []
|
||||||
|
for i, ext in enumerate(request.extensions):
|
||||||
|
normalized = ext.lower().strip().lstrip('.')
|
||||||
|
ext_record = FileAppExtension(
|
||||||
|
app_id=app_id,
|
||||||
|
extension=normalized,
|
||||||
|
priority=i,
|
||||||
|
wopi_action_url=old_url_map.get(normalized),
|
||||||
|
)
|
||||||
|
session.add(ext_record)
|
||||||
|
new_extensions.append(ext_record)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
# refresh commit 后过期的对象
|
||||||
|
await session.refresh(app)
|
||||||
|
for ext_record in new_extensions:
|
||||||
|
await session.refresh(ext_record)
|
||||||
|
|
||||||
|
group_links_result = await session.exec(
|
||||||
|
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app_id)
|
||||||
|
)
|
||||||
|
group_links: list[FileAppGroupLink] = list(group_links_result.all())
|
||||||
|
|
||||||
|
l.info(f"更新文件应用 {app.app_key} 的扩展名: {request.extensions}")
|
||||||
|
|
||||||
|
return FileAppResponse.from_app(app, new_extensions, group_links)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_file_app_router.put(
|
||||||
|
path='/{app_id}/groups',
|
||||||
|
summary='全量替换允许的用户组',
|
||||||
|
)
|
||||||
|
async def update_group_access(
|
||||||
|
session: SessionDep,
|
||||||
|
app_id: UUID,
|
||||||
|
request: GroupAccessUpdateRequest,
|
||||||
|
) -> FileAppResponse:
|
||||||
|
"""
|
||||||
|
全量替换允许的用户组端点
|
||||||
|
|
||||||
|
先删除旧的关联,再创建新的。
|
||||||
|
|
||||||
|
认证:管理员权限
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 应用不存在
|
||||||
|
"""
|
||||||
|
app = await FileApp.get_exist_one(session, app_id)
|
||||||
|
|
||||||
|
# 删除旧的用户组关联
|
||||||
|
old_links_result = await session.exec(
|
||||||
|
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app_id)
|
||||||
|
)
|
||||||
|
old_links: list[FileAppGroupLink] = list(old_links_result.all())
|
||||||
|
for old_link in old_links:
|
||||||
|
await session.delete(old_link)
|
||||||
|
|
||||||
|
# 创建新的用户组关联
|
||||||
|
new_links: list[FileAppGroupLink] = []
|
||||||
|
for group_id in request.group_ids:
|
||||||
|
link = FileAppGroupLink(app_id=app_id, group_id=group_id)
|
||||||
|
session.add(link)
|
||||||
|
new_links.append(link)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(app)
|
||||||
|
|
||||||
|
extensions = await FileAppExtension.get(
|
||||||
|
session,
|
||||||
|
FileAppExtension.app_id == app_id,
|
||||||
|
fetch_mode="all",
|
||||||
|
)
|
||||||
|
|
||||||
|
l.info(f"更新文件应用 {app.app_key} 的用户组权限: {request.group_ids}")
|
||||||
|
|
||||||
|
return FileAppResponse.from_app(app, extensions, new_links)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_file_app_router.post(
|
||||||
|
path='/{app_id}/discover',
|
||||||
|
summary='执行 WOPI Discovery',
|
||||||
|
)
|
||||||
|
async def discover_wopi(
|
||||||
|
session: SessionDep,
|
||||||
|
app_id: UUID,
|
||||||
|
) -> WopiDiscoveryResponse:
|
||||||
|
"""
|
||||||
|
从 WOPI 服务端获取 Discovery XML 并自动配置扩展名和 URL 模板。
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 验证 FileApp 存在且为 WOPI 类型
|
||||||
|
2. 使用 FileApp.wopi_discovery_url 获取 Discovery XML
|
||||||
|
3. 解析 XML,提取扩展名和动作 URL
|
||||||
|
4. 全量替换 FileAppExtension 记录(带 wopi_action_url)
|
||||||
|
|
||||||
|
认证:管理员权限
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 应用不存在
|
||||||
|
- 400: 非 WOPI 类型 / discovery URL 未配置 / XML 解析失败
|
||||||
|
- 502: WOPI 服务端不可达或返回无效响应
|
||||||
|
"""
|
||||||
|
app = await FileApp.get_exist_one(session, app_id)
|
||||||
|
|
||||||
|
if app.type != FileAppType.WOPI:
|
||||||
|
http_exceptions.raise_bad_request("仅 WOPI 类型应用支持自动发现")
|
||||||
|
|
||||||
|
if not app.wopi_discovery_url:
|
||||||
|
http_exceptions.raise_bad_request("未配置 WOPI Discovery URL")
|
||||||
|
|
||||||
|
# commit 后对象会过期,先保存需要的值
|
||||||
|
discovery_url = app.wopi_discovery_url
|
||||||
|
app_key = app.app_key
|
||||||
|
|
||||||
|
# 获取 Discovery XML
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as client:
|
||||||
|
async with client.get(
|
||||||
|
discovery_url,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=15),
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
http_exceptions.raise_bad_gateway(
|
||||||
|
f"WOPI 服务端返回 HTTP {resp.status}"
|
||||||
|
)
|
||||||
|
xml_content = await resp.text()
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
http_exceptions.raise_bad_gateway(f"无法连接 WOPI 服务端: {e}")
|
||||||
|
|
||||||
|
# 解析 XML
|
||||||
|
try:
|
||||||
|
action_urls, app_names = parse_wopi_discovery_xml(xml_content)
|
||||||
|
except ValueError as e:
|
||||||
|
http_exceptions.raise_bad_request(str(e))
|
||||||
|
|
||||||
|
if not action_urls:
|
||||||
|
return WopiDiscoveryResponse(app_names=app_names)
|
||||||
|
|
||||||
|
# 全量替换扩展名
|
||||||
|
old_extensions: list[FileAppExtension] = await FileAppExtension.get(
|
||||||
|
session,
|
||||||
|
FileAppExtension.app_id == app_id,
|
||||||
|
fetch_mode="all",
|
||||||
|
)
|
||||||
|
for old_ext in old_extensions:
|
||||||
|
await FileAppExtension.delete(session, old_ext, commit=False)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
new_extensions: list[FileAppExtension] = []
|
||||||
|
discovered: list[WopiDiscoveredExtension] = []
|
||||||
|
for i, (ext, action_url) in enumerate(sorted(action_urls.items())):
|
||||||
|
ext_record = FileAppExtension(
|
||||||
|
app_id=app_id,
|
||||||
|
extension=ext,
|
||||||
|
priority=i,
|
||||||
|
wopi_action_url=action_url,
|
||||||
|
)
|
||||||
|
session.add(ext_record)
|
||||||
|
new_extensions.append(ext_record)
|
||||||
|
discovered.append(WopiDiscoveredExtension(extension=ext, action_url=action_url))
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
l.info(
|
||||||
|
f"WOPI Discovery 完成: 应用 {app_key}, "
|
||||||
|
f"发现 {len(discovered)} 个扩展名"
|
||||||
|
)
|
||||||
|
|
||||||
|
return WopiDiscoveryResponse(
|
||||||
|
discovered_extensions=discovered,
|
||||||
|
app_names=app_names,
|
||||||
|
applied_count=len(discovered),
|
||||||
|
)
|
||||||
@@ -55,7 +55,7 @@ async def router_admin_get_groups(
|
|||||||
async def router_admin_get_group(
|
async def router_admin_get_group(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
group_id: UUID,
|
group_id: UUID,
|
||||||
) -> ResponseBase:
|
) -> GroupDetailResponse:
|
||||||
"""
|
"""
|
||||||
根据用户组ID获取用户组详细信息。
|
根据用户组ID获取用户组详细信息。
|
||||||
|
|
||||||
@@ -63,17 +63,12 @@ async def router_admin_get_group(
|
|||||||
:param group_id: 用户组UUID
|
:param group_id: 用户组UUID
|
||||||
:return: 用户组详情
|
:return: 用户组详情
|
||||||
"""
|
"""
|
||||||
group = await Group.get(session, Group.id == group_id, load=[Group.options, Group.policies])
|
group = await Group.get_exist_one(session, group_id, load=[Group.options, Group.policies])
|
||||||
|
|
||||||
if not group:
|
|
||||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
|
||||||
|
|
||||||
# 直接访问已加载的关系,无需额外查询
|
# 直接访问已加载的关系,无需额外查询
|
||||||
policies = group.policies
|
policies = group.policies
|
||||||
user_count = await User.count(session, User.group_id == group_id)
|
user_count = await User.count(session, User.group_id == group_id)
|
||||||
response = GroupDetailResponse.from_group(group, user_count, policies)
|
return GroupDetailResponse.from_group(group, user_count, policies)
|
||||||
|
|
||||||
return ResponseBase(data=response.model_dump())
|
|
||||||
|
|
||||||
|
|
||||||
@admin_group_router.get(
|
@admin_group_router.get(
|
||||||
@@ -96,9 +91,7 @@ async def router_admin_get_group_members(
|
|||||||
:return: 分页成员列表
|
:return: 分页成员列表
|
||||||
"""
|
"""
|
||||||
# 验证组存在
|
# 验证组存在
|
||||||
group = await Group.get(session, Group.id == group_id)
|
await Group.get_exist_one(session, group_id)
|
||||||
if not group:
|
|
||||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
|
||||||
|
|
||||||
result = await User.get_with_count(session, User.group_id == group_id, table_view=table_view)
|
result = await User.get_with_count(session, User.group_id == group_id, table_view=table_view)
|
||||||
|
|
||||||
@@ -140,10 +133,11 @@ async def router_admin_create_group(
|
|||||||
speed_limit=request.speed_limit,
|
speed_limit=request.speed_limit,
|
||||||
)
|
)
|
||||||
group = await group.save(session)
|
group = await group.save(session)
|
||||||
|
group_id_val: UUID = group.id
|
||||||
|
|
||||||
# 创建选项
|
# 创建选项
|
||||||
options = GroupOptions(
|
options = GroupOptions(
|
||||||
group_id=group.id,
|
group_id=group_id_val,
|
||||||
share_download=request.share_download,
|
share_download=request.share_download,
|
||||||
share_free=request.share_free,
|
share_free=request.share_free,
|
||||||
relocate=request.relocate,
|
relocate=request.relocate,
|
||||||
@@ -156,11 +150,11 @@ async def router_admin_create_group(
|
|||||||
aria2=request.aria2,
|
aria2=request.aria2,
|
||||||
redirected_source=request.redirected_source,
|
redirected_source=request.redirected_source,
|
||||||
)
|
)
|
||||||
await options.save(session)
|
options = await options.save(session)
|
||||||
|
|
||||||
# 关联存储策略
|
# 关联存储策略
|
||||||
for policy_id in request.policy_ids:
|
for policy_id in request.policy_ids:
|
||||||
link = GroupPolicyLink(group_id=group.id, policy_id=policy_id)
|
link = GroupPolicyLink(group_id=group_id_val, policy_id=policy_id)
|
||||||
session.add(link)
|
session.add(link)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
@@ -187,9 +181,7 @@ async def router_admin_update_group(
|
|||||||
:param request: 更新请求
|
:param request: 更新请求
|
||||||
:return: 更新结果
|
:return: 更新结果
|
||||||
"""
|
"""
|
||||||
group = await Group.get(session, Group.id == group_id, load=Group.options)
|
group = await Group.get_exist_one(session, group_id, load=Group.options)
|
||||||
if not group:
|
|
||||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
|
||||||
|
|
||||||
# 检查名称唯一性(如果要更新名称)
|
# 检查名称唯一性(如果要更新名称)
|
||||||
if request.name and request.name != group.name:
|
if request.name and request.name != group.name:
|
||||||
@@ -219,7 +211,7 @@ async def router_admin_update_group(
|
|||||||
if options_data:
|
if options_data:
|
||||||
for key, value in options_data.items():
|
for key, value in options_data.items():
|
||||||
setattr(group.options, key, value)
|
setattr(group.options, key, value)
|
||||||
await group.options.save(session)
|
group.options = await group.options.save(session)
|
||||||
|
|
||||||
# 更新策略关联
|
# 更新策略关联
|
||||||
if request.policy_ids is not None:
|
if request.policy_ids is not None:
|
||||||
@@ -257,9 +249,7 @@ async def router_admin_delete_group(
|
|||||||
:param group_id: 用户组UUID
|
:param group_id: 用户组UUID
|
||||||
:return: 删除结果
|
:return: 删除结果
|
||||||
"""
|
"""
|
||||||
group = await Group.get(session, Group.id == group_id)
|
group = await Group.get_exist_one(session, group_id)
|
||||||
if not group:
|
|
||||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
|
||||||
|
|
||||||
# 检查是否有用户属于该组
|
# 检查是否有用户属于该组
|
||||||
user_count = await User.count(session, User.group_id == group_id)
|
user_count = await User.count(session, User.group_id == group_id)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
@@ -7,16 +8,95 @@ from sqlmodel import Field
|
|||||||
from middleware.auth import admin_required
|
from middleware.auth import admin_required
|
||||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
from sqlmodels import (
|
from sqlmodels import (
|
||||||
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
|
Policy, PolicyCreateRequest, PolicyOptions, PolicyType, PolicySummary,
|
||||||
ListResponse, Object, )
|
PolicyUpdateRequest, ResponseBase, ListResponse, Object,
|
||||||
from sqlmodels.base import SQLModelBase
|
)
|
||||||
from service.storage import DirectoryCreationError, LocalStorageService
|
from sqlmodel_ext import SQLModelBase
|
||||||
|
from service.storage import DirectoryCreationError, LocalStorageService, S3StorageService
|
||||||
|
|
||||||
admin_policy_router = APIRouter(
|
admin_policy_router = APIRouter(
|
||||||
prefix='/policy',
|
prefix='/policy',
|
||||||
tags=['admin', 'admin_policy']
|
tags=['admin', 'admin_policy']
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PathTestResponse(SQLModelBase):
|
||||||
|
"""路径测试响应"""
|
||||||
|
|
||||||
|
path: str
|
||||||
|
"""解析后的路径"""
|
||||||
|
|
||||||
|
is_exists: bool
|
||||||
|
"""路径是否存在"""
|
||||||
|
|
||||||
|
is_writable: bool
|
||||||
|
"""路径是否可写"""
|
||||||
|
|
||||||
|
|
||||||
|
class PolicyGroupInfo(SQLModelBase):
|
||||||
|
"""策略关联的用户组信息"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
"""用户组UUID"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""用户组名称"""
|
||||||
|
|
||||||
|
|
||||||
|
class PolicyDetailResponse(SQLModelBase):
|
||||||
|
"""存储策略详情响应"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
"""策略UUID"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""策略名称"""
|
||||||
|
|
||||||
|
type: str
|
||||||
|
"""策略类型"""
|
||||||
|
|
||||||
|
server: str | None
|
||||||
|
"""服务器地址"""
|
||||||
|
|
||||||
|
bucket_name: str | None
|
||||||
|
"""存储桶名称"""
|
||||||
|
|
||||||
|
is_private: bool
|
||||||
|
"""是否私有"""
|
||||||
|
|
||||||
|
base_url: str | None
|
||||||
|
"""基础URL"""
|
||||||
|
|
||||||
|
access_key: str | None
|
||||||
|
"""Access Key"""
|
||||||
|
|
||||||
|
secret_key: str | None
|
||||||
|
"""Secret Key"""
|
||||||
|
|
||||||
|
max_size: int
|
||||||
|
"""最大文件尺寸"""
|
||||||
|
|
||||||
|
auto_rename: bool
|
||||||
|
"""是否自动重命名"""
|
||||||
|
|
||||||
|
dir_name_rule: str | None
|
||||||
|
"""目录命名规则"""
|
||||||
|
|
||||||
|
file_name_rule: str | None
|
||||||
|
"""文件命名规则"""
|
||||||
|
|
||||||
|
is_origin_link_enable: bool
|
||||||
|
"""是否启用外链"""
|
||||||
|
|
||||||
|
options: dict[str, Any] | None
|
||||||
|
"""策略选项"""
|
||||||
|
|
||||||
|
groups: list[PolicyGroupInfo]
|
||||||
|
"""关联的用户组"""
|
||||||
|
|
||||||
|
object_count: int
|
||||||
|
"""使用此策略的对象数量"""
|
||||||
|
|
||||||
class PolicyTestPathRequest(SQLModelBase):
|
class PolicyTestPathRequest(SQLModelBase):
|
||||||
"""测试本地路径请求 DTO"""
|
"""测试本地路径请求 DTO"""
|
||||||
|
|
||||||
@@ -33,9 +113,45 @@ class PolicyTestSlaveRequest(SQLModelBase):
|
|||||||
secret: str
|
secret: str
|
||||||
"""从机通信密钥"""
|
"""从机通信密钥"""
|
||||||
|
|
||||||
class PolicyCreateRequest(PolicyBase):
|
class PolicyTestS3Request(SQLModelBase):
|
||||||
"""创建存储策略请求 DTO,继承 PolicyBase 中的所有字段"""
|
"""测试 S3 连接请求 DTO"""
|
||||||
pass
|
|
||||||
|
server: str = Field(max_length=255)
|
||||||
|
"""S3 端点地址"""
|
||||||
|
|
||||||
|
bucket_name: str = Field(max_length=255)
|
||||||
|
"""存储桶名称"""
|
||||||
|
|
||||||
|
access_key: str
|
||||||
|
"""Access Key"""
|
||||||
|
|
||||||
|
secret_key: str
|
||||||
|
"""Secret Key"""
|
||||||
|
|
||||||
|
s3_region: str = Field(default='us-east-1', max_length=64)
|
||||||
|
"""S3 区域"""
|
||||||
|
|
||||||
|
s3_path_style: bool = False
|
||||||
|
"""是否使用路径风格"""
|
||||||
|
|
||||||
|
|
||||||
|
class PolicyTestS3Response(SQLModelBase):
|
||||||
|
"""S3 连接测试响应"""
|
||||||
|
|
||||||
|
is_connected: bool
|
||||||
|
"""连接是否成功"""
|
||||||
|
|
||||||
|
message: str
|
||||||
|
"""测试结果消息"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Options 字段集合(用于分离 Policy 与 Options 字段) ====================
|
||||||
|
|
||||||
|
_OPTIONS_FIELDS: set[str] = {
|
||||||
|
'token', 'file_type', 'mimetype', 'od_redirect',
|
||||||
|
'chunk_size', 's3_path_style', 's3_region',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@admin_policy_router.get(
|
@admin_policy_router.get(
|
||||||
path='/list',
|
path='/list',
|
||||||
@@ -70,7 +186,7 @@ async def router_policy_list(
|
|||||||
)
|
)
|
||||||
async def router_policy_test_path(
|
async def router_policy_test_path(
|
||||||
request: PolicyTestPathRequest,
|
request: PolicyTestPathRequest,
|
||||||
) -> ResponseBase:
|
) -> PathTestResponse:
|
||||||
"""
|
"""
|
||||||
测试本地存储路径是否可用。
|
测试本地存储路径是否可用。
|
||||||
|
|
||||||
@@ -97,22 +213,23 @@ async def router_policy_test_path(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return ResponseBase(data={
|
return PathTestResponse(
|
||||||
"path": str(path),
|
path=str(path),
|
||||||
"exists": is_exists,
|
is_exists=is_exists,
|
||||||
"writable": is_writable,
|
is_writable=is_writable,
|
||||||
})
|
)
|
||||||
|
|
||||||
|
|
||||||
@admin_policy_router.post(
|
@admin_policy_router.post(
|
||||||
path='/test/slave',
|
path='/test/slave',
|
||||||
summary='测试从机通信',
|
summary='测试从机通信',
|
||||||
description='Test slave node communication',
|
description='Test slave node communication',
|
||||||
dependencies=[Depends(admin_required)]
|
dependencies=[Depends(admin_required)],
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_policy_test_slave(
|
async def router_policy_test_slave(
|
||||||
request: PolicyTestSlaveRequest,
|
request: PolicyTestSlaveRequest,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
测试从机RPC通信。
|
测试从机RPC通信。
|
||||||
|
|
||||||
@@ -129,25 +246,28 @@ async def router_policy_test_slave(
|
|||||||
timeout=aiohttp.ClientTimeout(total=10)
|
timeout=aiohttp.ClientTimeout(total=10)
|
||||||
) as resp:
|
) as resp:
|
||||||
if resp.status == 200:
|
if resp.status == 200:
|
||||||
return ResponseBase(data={"connected": True})
|
return
|
||||||
else:
|
else:
|
||||||
return ResponseBase(
|
raise HTTPException(
|
||||||
code=400,
|
status_code=400,
|
||||||
msg=f"从机响应错误,HTTP {resp.status}"
|
detail=f"从机响应错误,HTTP {resp.status}",
|
||||||
)
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ResponseBase(code=400, msg=f"连接失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"连接失败: {str(e)}")
|
||||||
|
|
||||||
@admin_policy_router.post(
|
@admin_policy_router.post(
|
||||||
path='/',
|
path='/',
|
||||||
summary='创建存储策略',
|
summary='创建存储策略',
|
||||||
description='创建新的存储策略。对于本地存储策略,会自动创建物理目录。',
|
description='创建新的存储策略。对于本地存储策略,会自动创建物理目录。',
|
||||||
dependencies=[Depends(admin_required)]
|
dependencies=[Depends(admin_required)],
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_policy_add_policy(
|
async def router_policy_add_policy(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
request: PolicyCreateRequest,
|
request: PolicyCreateRequest,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
创建存储策略端点
|
创建存储策略端点
|
||||||
|
|
||||||
@@ -201,12 +321,18 @@ async def router_policy_add_policy(
|
|||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
policy = await policy.save(session)
|
policy = await policy.save(session)
|
||||||
|
|
||||||
return ResponseBase(data={
|
# 创建策略选项
|
||||||
"id": str(policy.id),
|
options = PolicyOptions(
|
||||||
"name": policy.name,
|
policy_id=policy.id,
|
||||||
"type": policy.type.value,
|
token=request.token,
|
||||||
"server": policy.server,
|
file_type=request.file_type,
|
||||||
})
|
mimetype=request.mimetype,
|
||||||
|
od_redirect=request.od_redirect,
|
||||||
|
chunk_size=request.chunk_size,
|
||||||
|
s3_path_style=request.s3_path_style,
|
||||||
|
s3_region=request.s3_region,
|
||||||
|
)
|
||||||
|
options = await options.save(session)
|
||||||
|
|
||||||
@admin_policy_router.post(
|
@admin_policy_router.post(
|
||||||
path='/cors',
|
path='/cors',
|
||||||
@@ -257,9 +383,7 @@ async def router_policy_onddrive_oauth(
|
|||||||
:param policy_id: 存储策略UUID
|
:param policy_id: 存储策略UUID
|
||||||
:return: OAuth URL
|
:return: OAuth URL
|
||||||
"""
|
"""
|
||||||
policy = await Policy.get(session, Policy.id == policy_id)
|
policy = await Policy.get_exist_one(session, policy_id)
|
||||||
if not policy:
|
|
||||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
|
||||||
|
|
||||||
# TODO: 实现OneDrive OAuth
|
# TODO: 实现OneDrive OAuth
|
||||||
raise HTTPException(status_code=501, detail="OneDrive OAuth暂未实现")
|
raise HTTPException(status_code=501, detail="OneDrive OAuth暂未实现")
|
||||||
@@ -274,7 +398,7 @@ async def router_policy_onddrive_oauth(
|
|||||||
async def router_policy_get_policy(
|
async def router_policy_get_policy(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
policy_id: UUID,
|
policy_id: UUID,
|
||||||
) -> ResponseBase:
|
) -> PolicyDetailResponse:
|
||||||
"""
|
"""
|
||||||
获取存储策略详情。
|
获取存储策略详情。
|
||||||
|
|
||||||
@@ -282,9 +406,7 @@ async def router_policy_get_policy(
|
|||||||
:param policy_id: 存储策略UUID
|
:param policy_id: 存储策略UUID
|
||||||
:return: 策略详情
|
:return: 策略详情
|
||||||
"""
|
"""
|
||||||
policy = await Policy.get(session, Policy.id == policy_id, load=Policy.options)
|
policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)
|
||||||
if not policy:
|
|
||||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
|
||||||
|
|
||||||
# 获取使用此策略的用户组
|
# 获取使用此策略的用户组
|
||||||
groups = await policy.awaitable_attrs.groups
|
groups = await policy.awaitable_attrs.groups
|
||||||
@@ -292,35 +414,38 @@ async def router_policy_get_policy(
|
|||||||
# 统计使用此策略的对象数量
|
# 统计使用此策略的对象数量
|
||||||
object_count = await Object.count(session, Object.policy_id == policy_id)
|
object_count = await Object.count(session, Object.policy_id == policy_id)
|
||||||
|
|
||||||
return ResponseBase(data={
|
return PolicyDetailResponse(
|
||||||
"id": str(policy.id),
|
id=str(policy.id),
|
||||||
"name": policy.name,
|
name=policy.name,
|
||||||
"type": policy.type.value,
|
type=policy.type.value,
|
||||||
"server": policy.server,
|
server=policy.server,
|
||||||
"bucket_name": policy.bucket_name,
|
bucket_name=policy.bucket_name,
|
||||||
"is_private": policy.is_private,
|
is_private=policy.is_private,
|
||||||
"base_url": policy.base_url,
|
base_url=policy.base_url,
|
||||||
"max_size": policy.max_size,
|
access_key=policy.access_key,
|
||||||
"auto_rename": policy.auto_rename,
|
secret_key=policy.secret_key,
|
||||||
"dir_name_rule": policy.dir_name_rule,
|
max_size=policy.max_size,
|
||||||
"file_name_rule": policy.file_name_rule,
|
auto_rename=policy.auto_rename,
|
||||||
"is_origin_link_enable": policy.is_origin_link_enable,
|
dir_name_rule=policy.dir_name_rule,
|
||||||
"options": policy.options.model_dump() if policy.options else None,
|
file_name_rule=policy.file_name_rule,
|
||||||
"groups": [{"id": str(g.id), "name": g.name} for g in groups],
|
is_origin_link_enable=policy.is_origin_link_enable,
|
||||||
"object_count": object_count,
|
options=policy.options.model_dump() if policy.options else None,
|
||||||
})
|
groups=[PolicyGroupInfo(id=str(g.id), name=g.name) for g in groups],
|
||||||
|
object_count=object_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@admin_policy_router.delete(
|
@admin_policy_router.delete(
|
||||||
path='/{policy_id}',
|
path='/{policy_id}',
|
||||||
summary='删除存储策略',
|
summary='删除存储策略',
|
||||||
description='Delete storage policy by ID',
|
description='Delete storage policy by ID',
|
||||||
dependencies=[Depends(admin_required)]
|
dependencies=[Depends(admin_required)],
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_policy_delete_policy(
|
async def router_policy_delete_policy(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
policy_id: UUID,
|
policy_id: UUID,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
删除存储策略。
|
删除存储策略。
|
||||||
|
|
||||||
@@ -330,9 +455,7 @@ async def router_policy_delete_policy(
|
|||||||
:param policy_id: 存储策略UUID
|
:param policy_id: 存储策略UUID
|
||||||
:return: 删除结果
|
:return: 删除结果
|
||||||
"""
|
"""
|
||||||
policy = await Policy.get(session, Policy.id == policy_id)
|
policy = await Policy.get_exist_one(session, policy_id)
|
||||||
if not policy:
|
|
||||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
|
||||||
|
|
||||||
# 检查是否有文件使用此策略
|
# 检查是否有文件使用此策略
|
||||||
file_count = await Object.count(session, Object.policy_id == policy_id)
|
file_count = await Object.count(session, Object.policy_id == policy_id)
|
||||||
@@ -346,4 +469,105 @@ async def router_policy_delete_policy(
|
|||||||
await Policy.delete(session, policy)
|
await Policy.delete(session, policy)
|
||||||
|
|
||||||
l.info(f"管理员删除了存储策略: {policy_name}")
|
l.info(f"管理员删除了存储策略: {policy_name}")
|
||||||
return ResponseBase(data={"deleted": True})
|
|
||||||
|
|
||||||
|
@admin_policy_router.patch(
|
||||||
|
path='/{policy_id}',
|
||||||
|
summary='更新存储策略',
|
||||||
|
description='更新存储策略配置。策略类型创建后不可更改。',
|
||||||
|
dependencies=[Depends(admin_required)],
|
||||||
|
status_code=204,
|
||||||
|
)
|
||||||
|
async def router_policy_update_policy(
|
||||||
|
session: SessionDep,
|
||||||
|
policy_id: UUID,
|
||||||
|
request: PolicyUpdateRequest,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
更新存储策略端点
|
||||||
|
|
||||||
|
功能:
|
||||||
|
- 更新策略基础字段和扩展选项
|
||||||
|
- 策略类型(type)不可更改
|
||||||
|
|
||||||
|
认证:
|
||||||
|
- 需要管理员权限
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param policy_id: 存储策略UUID
|
||||||
|
:param request: 更新请求
|
||||||
|
"""
|
||||||
|
policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)
|
||||||
|
|
||||||
|
# 检查名称唯一性(如果要更新名称)
|
||||||
|
if request.name and request.name != policy.name:
|
||||||
|
existing = await Policy.get(session, Policy.name == request.name)
|
||||||
|
if existing:
|
||||||
|
raise HTTPException(status_code=409, detail="策略名称已存在")
|
||||||
|
|
||||||
|
# 分离 Policy 字段和 Options 字段
|
||||||
|
all_data = request.model_dump(exclude_unset=True)
|
||||||
|
policy_data = {k: v for k, v in all_data.items() if k not in _OPTIONS_FIELDS}
|
||||||
|
options_data = {k: v for k, v in all_data.items() if k in _OPTIONS_FIELDS}
|
||||||
|
|
||||||
|
# 更新 Policy 基础字段
|
||||||
|
if policy_data:
|
||||||
|
for key, value in policy_data.items():
|
||||||
|
setattr(policy, key, value)
|
||||||
|
policy = await policy.save(session)
|
||||||
|
|
||||||
|
# 更新或创建 PolicyOptions
|
||||||
|
if options_data:
|
||||||
|
if policy.options:
|
||||||
|
for key, value in options_data.items():
|
||||||
|
setattr(policy.options, key, value)
|
||||||
|
policy.options = await policy.options.save(session)
|
||||||
|
else:
|
||||||
|
options = PolicyOptions(policy_id=policy.id, **options_data)
|
||||||
|
options = await options.save(session)
|
||||||
|
|
||||||
|
l.info(f"管理员更新了存储策略: {policy_id}")
|
||||||
|
|
||||||
|
|
||||||
|
@admin_policy_router.post(
|
||||||
|
path='/test/s3',
|
||||||
|
summary='测试 S3 连接',
|
||||||
|
description='测试 S3 存储端点的连通性和凭据有效性。',
|
||||||
|
dependencies=[Depends(admin_required)],
|
||||||
|
)
|
||||||
|
async def router_policy_test_s3(
|
||||||
|
request: PolicyTestS3Request,
|
||||||
|
) -> PolicyTestS3Response:
|
||||||
|
"""
|
||||||
|
测试 S3 连接端点
|
||||||
|
|
||||||
|
通过向 S3 端点发送 HEAD Bucket 请求,验证凭据和网络连通性。
|
||||||
|
|
||||||
|
:param request: 测试请求
|
||||||
|
:return: 测试结果
|
||||||
|
"""
|
||||||
|
from service.storage import S3APIError
|
||||||
|
|
||||||
|
# 构造临时 Policy 对象用于创建 S3StorageService
|
||||||
|
temp_policy = Policy(
|
||||||
|
name="__test__",
|
||||||
|
type=PolicyType.S3,
|
||||||
|
server=request.server,
|
||||||
|
bucket_name=request.bucket_name,
|
||||||
|
access_key=request.access_key,
|
||||||
|
secret_key=request.secret_key,
|
||||||
|
)
|
||||||
|
s3_service = S3StorageService(
|
||||||
|
temp_policy,
|
||||||
|
region=request.s3_region,
|
||||||
|
is_path_style=request.s3_path_style,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用 file_exists 发送 HEAD 请求来验证连通性
|
||||||
|
await s3_service.file_exists("__connection_test__")
|
||||||
|
return PolicyTestS3Response(is_connected=True, message="连接成功")
|
||||||
|
except S3APIError as e:
|
||||||
|
return PolicyTestS3Response(is_connected=False, message=f"S3 API 错误: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
return PolicyTestS3Response(is_connected=False, message=f"连接失败: {e}")
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
@@ -6,8 +7,53 @@ from loguru import logger as l
|
|||||||
from middleware.auth import admin_required
|
from middleware.auth import admin_required
|
||||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
from sqlmodels import (
|
from sqlmodels import (
|
||||||
ResponseBase, ListResponse,
|
ListResponse,
|
||||||
Share, AdminShareListItem, )
|
Share, AdminShareListItem,
|
||||||
|
)
|
||||||
|
from sqlmodel_ext import SQLModelBase
|
||||||
|
|
||||||
|
|
||||||
|
class ShareDetailResponse(SQLModelBase):
|
||||||
|
"""分享详情响应"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
"""分享UUID"""
|
||||||
|
|
||||||
|
code: str
|
||||||
|
"""分享码"""
|
||||||
|
|
||||||
|
views: int
|
||||||
|
"""浏览次数"""
|
||||||
|
|
||||||
|
downloads: int
|
||||||
|
"""下载次数"""
|
||||||
|
|
||||||
|
remain_downloads: int | None
|
||||||
|
"""剩余下载次数"""
|
||||||
|
|
||||||
|
expires: datetime | None
|
||||||
|
"""过期时间"""
|
||||||
|
|
||||||
|
preview_enabled: bool
|
||||||
|
"""是否启用预览"""
|
||||||
|
|
||||||
|
score: int
|
||||||
|
"""评分"""
|
||||||
|
|
||||||
|
has_password: bool
|
||||||
|
"""是否有密码"""
|
||||||
|
|
||||||
|
user_id: str
|
||||||
|
"""用户UUID"""
|
||||||
|
|
||||||
|
username: str | None
|
||||||
|
"""用户名"""
|
||||||
|
|
||||||
|
object: dict | None
|
||||||
|
"""关联对象信息"""
|
||||||
|
|
||||||
|
created_at: str
|
||||||
|
"""创建时间"""
|
||||||
|
|
||||||
admin_share_router = APIRouter(
|
admin_share_router = APIRouter(
|
||||||
prefix='/share',
|
prefix='/share',
|
||||||
@@ -53,8 +99,8 @@ async def router_admin_get_share_list(
|
|||||||
)
|
)
|
||||||
async def router_admin_get_share(
|
async def router_admin_get_share(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
share_id: int,
|
share_id: UUID,
|
||||||
) -> ResponseBase:
|
) -> ShareDetailResponse:
|
||||||
"""
|
"""
|
||||||
获取分享详情。
|
获取分享详情。
|
||||||
|
|
||||||
@@ -69,38 +115,39 @@ async def router_admin_get_share(
|
|||||||
obj = await share.awaitable_attrs.object
|
obj = await share.awaitable_attrs.object
|
||||||
user = await share.awaitable_attrs.user
|
user = await share.awaitable_attrs.user
|
||||||
|
|
||||||
return ResponseBase(data={
|
return ShareDetailResponse(
|
||||||
"id": share.id,
|
id=share.id,
|
||||||
"code": share.code,
|
code=share.code,
|
||||||
"views": share.views,
|
views=share.views,
|
||||||
"downloads": share.downloads,
|
downloads=share.downloads,
|
||||||
"remain_downloads": share.remain_downloads,
|
remain_downloads=share.remain_downloads,
|
||||||
"expires": share.expires.isoformat() if share.expires else None,
|
expires=share.expires,
|
||||||
"preview_enabled": share.preview_enabled,
|
preview_enabled=share.preview_enabled,
|
||||||
"score": share.score,
|
score=share.score,
|
||||||
"has_password": bool(share.password),
|
has_password=bool(share.password),
|
||||||
"user_id": str(share.user_id),
|
user_id=str(share.user_id),
|
||||||
"username": user.email if user else None,
|
username=user.email if user else None,
|
||||||
"object": {
|
object={
|
||||||
"id": str(obj.id),
|
"id": str(obj.id),
|
||||||
"name": obj.name,
|
"name": obj.name,
|
||||||
"type": obj.type.value,
|
"type": obj.type.value,
|
||||||
"size": obj.size,
|
"size": obj.size,
|
||||||
} if obj else None,
|
} if obj else None,
|
||||||
"created_at": share.created_at.isoformat(),
|
created_at=share.created_at.isoformat(),
|
||||||
})
|
)
|
||||||
|
|
||||||
|
|
||||||
@admin_share_router.delete(
|
@admin_share_router.delete(
|
||||||
path='/{share_id}',
|
path='/{share_id}',
|
||||||
summary='删除分享',
|
summary='删除分享',
|
||||||
description='Delete share by ID',
|
description='Delete share by ID',
|
||||||
dependencies=[Depends(admin_required)]
|
dependencies=[Depends(admin_required)],
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_admin_delete_share(
|
async def router_admin_delete_share(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
share_id: int,
|
share_id: UUID,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
删除分享。
|
删除分享。
|
||||||
|
|
||||||
@@ -108,11 +155,8 @@ async def router_admin_delete_share(
|
|||||||
:param share_id: 分享ID
|
:param share_id: 分享ID
|
||||||
:return: 删除结果
|
:return: 删除结果
|
||||||
"""
|
"""
|
||||||
share = await Share.get(session, Share.id == share_id)
|
share = await Share.get_exist_one(session, share_id)
|
||||||
if not share:
|
|
||||||
raise HTTPException(status_code=404, detail="分享不存在")
|
|
||||||
|
|
||||||
await Share.delete(session, share)
|
await Share.delete(session, share)
|
||||||
|
|
||||||
l.info(f"管理员删除了分享: {share.code}")
|
l.info(f"管理员删除了分享: {share.code}")
|
||||||
return ResponseBase(data={"deleted": True})
|
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
@@ -6,9 +7,44 @@ from loguru import logger as l
|
|||||||
from middleware.auth import admin_required
|
from middleware.auth import admin_required
|
||||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
from sqlmodels import (
|
from sqlmodels import (
|
||||||
ResponseBase, ListResponse,
|
ListResponse,
|
||||||
Task, TaskSummary,
|
Task, TaskSummary, TaskStatus, TaskType,
|
||||||
)
|
)
|
||||||
|
from sqlmodel_ext import SQLModelBase
|
||||||
|
|
||||||
|
|
||||||
|
class TaskDetailResponse(SQLModelBase):
|
||||||
|
"""任务详情响应"""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
"""任务ID"""
|
||||||
|
|
||||||
|
status: TaskStatus
|
||||||
|
"""任务状态"""
|
||||||
|
|
||||||
|
type: TaskType
|
||||||
|
"""任务类型"""
|
||||||
|
|
||||||
|
progress: int
|
||||||
|
"""任务进度"""
|
||||||
|
|
||||||
|
error: str | None
|
||||||
|
"""错误信息"""
|
||||||
|
|
||||||
|
user_id: str
|
||||||
|
"""用户UUID"""
|
||||||
|
|
||||||
|
username: str | None
|
||||||
|
"""用户名"""
|
||||||
|
|
||||||
|
props: dict[str, Any] | None
|
||||||
|
"""任务属性"""
|
||||||
|
|
||||||
|
created_at: str
|
||||||
|
"""创建时间"""
|
||||||
|
|
||||||
|
updated_at: str
|
||||||
|
"""更新时间"""
|
||||||
|
|
||||||
admin_task_router = APIRouter(
|
admin_task_router = APIRouter(
|
||||||
prefix='/task',
|
prefix='/task',
|
||||||
@@ -67,7 +103,7 @@ async def router_admin_get_task_list(
|
|||||||
async def router_admin_get_task(
|
async def router_admin_get_task(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
task_id: int,
|
task_id: int,
|
||||||
) -> ResponseBase:
|
) -> TaskDetailResponse:
|
||||||
"""
|
"""
|
||||||
获取任务详情。
|
获取任务详情。
|
||||||
|
|
||||||
@@ -82,30 +118,31 @@ async def router_admin_get_task(
|
|||||||
user = await task.awaitable_attrs.user
|
user = await task.awaitable_attrs.user
|
||||||
props = await task.awaitable_attrs.props
|
props = await task.awaitable_attrs.props
|
||||||
|
|
||||||
return ResponseBase(data={
|
return TaskDetailResponse(
|
||||||
"id": task.id,
|
id=task.id,
|
||||||
"status": task.status,
|
status=task.status,
|
||||||
"type": task.type,
|
type=task.type,
|
||||||
"progress": task.progress,
|
progress=task.progress,
|
||||||
"error": task.error,
|
error=task.error,
|
||||||
"user_id": str(task.user_id),
|
user_id=str(task.user_id),
|
||||||
"username": user.email if user else None,
|
username=user.email if user else None,
|
||||||
"props": props.model_dump() if props else None,
|
props=props.model_dump() if props else None,
|
||||||
"created_at": task.created_at.isoformat(),
|
created_at=task.created_at.isoformat(),
|
||||||
"updated_at": task.updated_at.isoformat(),
|
updated_at=task.updated_at.isoformat(),
|
||||||
})
|
)
|
||||||
|
|
||||||
|
|
||||||
@admin_task_router.delete(
|
@admin_task_router.delete(
|
||||||
path='/{task_id}',
|
path='/{task_id}',
|
||||||
summary='删除任务',
|
summary='删除任务',
|
||||||
description='Delete task by ID',
|
description='Delete task by ID',
|
||||||
dependencies=[Depends(admin_required)]
|
dependencies=[Depends(admin_required)],
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_admin_delete_task(
|
async def router_admin_delete_task(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
task_id: int,
|
task_id: int,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
删除任务。
|
删除任务。
|
||||||
|
|
||||||
@@ -113,11 +150,8 @@ async def router_admin_delete_task(
|
|||||||
:param task_id: 任务ID
|
:param task_id: 任务ID
|
||||||
:return: 删除结果
|
:return: 删除结果
|
||||||
"""
|
"""
|
||||||
task = await Task.get(session, Task.id == task_id)
|
task = await Task.get_exist_one(session, task_id)
|
||||||
if not task:
|
|
||||||
raise HTTPException(status_code=404, detail="任务不存在")
|
|
||||||
|
|
||||||
await Task.delete(session, task)
|
await Task.delete(session, task)
|
||||||
|
|
||||||
l.info(f"管理员删除了任务: {task_id}")
|
l.info(f"管理员删除了任务: {task_id}")
|
||||||
return ResponseBase(data={"deleted": True})
|
|
||||||
187
routers/api/v1/admin/theme/__init__.py
Normal file
187
routers/api/v1/admin/theme/__init__.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, status
|
||||||
|
from loguru import logger as l
|
||||||
|
from sqlalchemy import update as sql_update
|
||||||
|
|
||||||
|
from middleware.auth import admin_required
|
||||||
|
from middleware.dependencies import SessionDep
|
||||||
|
from sqlmodels import (
|
||||||
|
ThemePreset,
|
||||||
|
ThemePresetCreateRequest,
|
||||||
|
ThemePresetUpdateRequest,
|
||||||
|
ThemePresetResponse,
|
||||||
|
ThemePresetListResponse,
|
||||||
|
)
|
||||||
|
from utils import http_exceptions
|
||||||
|
|
||||||
|
admin_theme_router = APIRouter(
|
||||||
|
prefix="/theme",
|
||||||
|
tags=["admin", "admin_theme"],
|
||||||
|
dependencies=[Depends(admin_required)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_theme_router.get(
|
||||||
|
path='/',
|
||||||
|
summary='获取主题预设列表',
|
||||||
|
)
|
||||||
|
async def router_admin_theme_list(session: SessionDep) -> ThemePresetListResponse:
|
||||||
|
"""
|
||||||
|
获取所有主题预设列表
|
||||||
|
|
||||||
|
认证:需要管理员权限
|
||||||
|
|
||||||
|
响应:
|
||||||
|
- ThemePresetListResponse: 包含所有主题预设的列表
|
||||||
|
"""
|
||||||
|
presets: list[ThemePreset] = await ThemePreset.get(session, fetch_mode="all")
|
||||||
|
return ThemePresetListResponse(
|
||||||
|
themes=[ThemePresetResponse.from_preset(p) for p in presets]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_theme_router.post(
|
||||||
|
path='/',
|
||||||
|
summary='创建主题预设',
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
)
|
||||||
|
async def router_admin_theme_create(
|
||||||
|
session: SessionDep,
|
||||||
|
request: ThemePresetCreateRequest,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
创建新的主题预设
|
||||||
|
|
||||||
|
认证:需要管理员权限
|
||||||
|
|
||||||
|
请求体:
|
||||||
|
- name: 预设名称(唯一)
|
||||||
|
- colors: 颜色配置对象
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 409: 名称已存在
|
||||||
|
"""
|
||||||
|
# 检查名称唯一性
|
||||||
|
existing = await ThemePreset.get(session, ThemePreset.name == request.name)
|
||||||
|
if existing:
|
||||||
|
http_exceptions.raise_conflict(f"主题预设名称 '{request.name}' 已存在")
|
||||||
|
|
||||||
|
preset = ThemePreset(
|
||||||
|
name=request.name,
|
||||||
|
**request.colors.model_dump(),
|
||||||
|
)
|
||||||
|
preset = await preset.save(session)
|
||||||
|
l.info(f"管理员创建了主题预设: {request.name}")
|
||||||
|
|
||||||
|
|
||||||
|
@admin_theme_router.patch(
|
||||||
|
path='/{preset_id}',
|
||||||
|
summary='更新主题预设',
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
)
|
||||||
|
async def router_admin_theme_update(
|
||||||
|
session: SessionDep,
|
||||||
|
preset_id: UUID,
|
||||||
|
request: ThemePresetUpdateRequest,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
部分更新主题预设
|
||||||
|
|
||||||
|
认证:需要管理员权限
|
||||||
|
|
||||||
|
路径参数:
|
||||||
|
- preset_id: 预设UUID
|
||||||
|
|
||||||
|
请求体(均可选):
|
||||||
|
- name: 预设名称
|
||||||
|
- colors: 颜色配置对象
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 预设不存在
|
||||||
|
- 409: 名称已被其他预设使用
|
||||||
|
"""
|
||||||
|
preset = await ThemePreset.get_exist_one(session, preset_id)
|
||||||
|
|
||||||
|
# 检查名称唯一性(排除自身)
|
||||||
|
if request.name is not None and request.name != preset.name:
|
||||||
|
existing = await ThemePreset.get(session, ThemePreset.name == request.name)
|
||||||
|
if existing:
|
||||||
|
http_exceptions.raise_conflict(f"主题预设名称 '{request.name}' 已存在")
|
||||||
|
preset.name = request.name
|
||||||
|
|
||||||
|
# 更新颜色字段
|
||||||
|
if request.colors is not None:
|
||||||
|
color_data = request.colors.model_dump()
|
||||||
|
for key, value in color_data.items():
|
||||||
|
setattr(preset, key, value)
|
||||||
|
|
||||||
|
preset = await preset.save(session)
|
||||||
|
l.info(f"管理员更新了主题预设: {preset.name}")
|
||||||
|
|
||||||
|
|
||||||
|
@admin_theme_router.delete(
|
||||||
|
path='/{preset_id}',
|
||||||
|
summary='删除主题预设',
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
)
|
||||||
|
async def router_admin_theme_delete(
|
||||||
|
session: SessionDep,
|
||||||
|
preset_id: UUID,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
删除主题预设
|
||||||
|
|
||||||
|
认证:需要管理员权限
|
||||||
|
|
||||||
|
路径参数:
|
||||||
|
- preset_id: 预设UUID
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 预设不存在
|
||||||
|
|
||||||
|
副作用:
|
||||||
|
- 关联用户的 theme_preset_id 会被数据库 SET NULL
|
||||||
|
"""
|
||||||
|
preset = await ThemePreset.get_exist_one(session, preset_id)
|
||||||
|
|
||||||
|
await preset.delete(session)
|
||||||
|
l.info(f"管理员删除了主题预设: {preset.name}")
|
||||||
|
|
||||||
|
|
||||||
|
@admin_theme_router.patch(
|
||||||
|
path='/{preset_id}/default',
|
||||||
|
summary='设为默认主题预设',
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
)
|
||||||
|
async def router_admin_theme_set_default(
|
||||||
|
session: SessionDep,
|
||||||
|
preset_id: UUID,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
将指定预设设为默认主题
|
||||||
|
|
||||||
|
认证:需要管理员权限
|
||||||
|
|
||||||
|
路径参数:
|
||||||
|
- preset_id: 预设UUID
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 预设不存在
|
||||||
|
|
||||||
|
逻辑:
|
||||||
|
- 事务中先清除所有旧默认,再设新默认
|
||||||
|
"""
|
||||||
|
preset = await ThemePreset.get_exist_one(session, preset_id)
|
||||||
|
|
||||||
|
# 清除所有旧默认
|
||||||
|
await session.execute(
|
||||||
|
sql_update(ThemePreset)
|
||||||
|
.where(ThemePreset.is_default == True) # noqa: E712
|
||||||
|
.values(is_default=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 设新默认
|
||||||
|
preset.is_default = True
|
||||||
|
preset = await preset.save(session)
|
||||||
|
l.info(f"管理员将主题预设 '{preset.name}' 设为默认")
|
||||||
@@ -12,6 +12,7 @@ from sqlmodels import (
|
|||||||
Group, Object, ObjectType, Setting, SettingsType,
|
Group, Object, ObjectType, Setting, SettingsType,
|
||||||
BatchDeleteRequest,
|
BatchDeleteRequest,
|
||||||
)
|
)
|
||||||
|
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||||
from sqlmodels.user import (
|
from sqlmodels.user import (
|
||||||
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, UserStatus,
|
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, UserStatus,
|
||||||
)
|
)
|
||||||
@@ -83,14 +84,27 @@ async def router_admin_create_user(
|
|||||||
"""
|
"""
|
||||||
创建一个新的用户,设置邮箱、密码、用户组等信息。
|
创建一个新的用户,设置邮箱、密码、用户组等信息。
|
||||||
|
|
||||||
|
管理员创建用户时,若提供了 email + password,
|
||||||
|
会同时创建 AuthIdentity(provider=email_password)。
|
||||||
|
|
||||||
:param session: 数据库会话
|
:param session: 数据库会话
|
||||||
:param request: 创建用户请求 DTO
|
:param request: 创建用户请求 DTO
|
||||||
:return: 创建结果
|
:return: 创建结果
|
||||||
"""
|
"""
|
||||||
|
# 如果提供了邮箱,检查唯一性(User 表和 AuthIdentity 表)
|
||||||
|
if request.email:
|
||||||
existing_user = await User.get(session, User.email == request.email)
|
existing_user = await User.get(session, User.email == request.email)
|
||||||
if existing_user:
|
if existing_user:
|
||||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||||
|
|
||||||
|
existing_identity = await AuthIdentity.get(
|
||||||
|
session,
|
||||||
|
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||||
|
& (AuthIdentity.identifier == request.email),
|
||||||
|
)
|
||||||
|
if existing_identity:
|
||||||
|
raise HTTPException(status_code=409, detail="该邮箱已被绑定")
|
||||||
|
|
||||||
# 验证用户组存在
|
# 验证用户组存在
|
||||||
group = await Group.get(session, Group.id == request.group_id)
|
group = await Group.get(session, Group.id == request.group_id)
|
||||||
if not group:
|
if not group:
|
||||||
@@ -98,12 +112,25 @@ async def router_admin_create_user(
|
|||||||
|
|
||||||
user = User(
|
user = User(
|
||||||
email=request.email,
|
email=request.email,
|
||||||
password=Password.hash(request.password),
|
|
||||||
nickname=request.nickname,
|
nickname=request.nickname,
|
||||||
group_id=request.group_id,
|
group_id=request.group_id,
|
||||||
status=request.status,
|
status=request.status,
|
||||||
)
|
)
|
||||||
user = await user.save(session)
|
user = await user.save(session)
|
||||||
|
|
||||||
|
# 如果提供了邮箱和密码,创建邮箱密码认证身份
|
||||||
|
if request.email and request.password:
|
||||||
|
identity = AuthIdentity(
|
||||||
|
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||||
|
identifier=request.email,
|
||||||
|
credential=Password.hash(request.password),
|
||||||
|
is_primary=True,
|
||||||
|
is_verified=True,
|
||||||
|
user_id=user.id,
|
||||||
|
)
|
||||||
|
identity = await identity.save(session)
|
||||||
|
|
||||||
|
user = await User.get(session, User.id == user.id, load=User.group)
|
||||||
return user.to_public()
|
return user.to_public()
|
||||||
|
|
||||||
|
|
||||||
@@ -127,9 +154,7 @@ async def router_admin_update_user(
|
|||||||
:param request: 更新请求
|
:param request: 更新请求
|
||||||
:return: 更新结果
|
:return: 更新结果
|
||||||
"""
|
"""
|
||||||
user = await User.get(session, User.id == user_id)
|
user = await User.get_exist_one(session, user_id)
|
||||||
if not user:
|
|
||||||
raise HTTPException(status_code=404, detail="用户不存在")
|
|
||||||
|
|
||||||
# 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
|
# 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
|
||||||
default_admin_setting = await Setting.get(
|
default_admin_setting = await Setting.get(
|
||||||
@@ -148,17 +173,7 @@ async def router_admin_update_user(
|
|||||||
if not group:
|
if not group:
|
||||||
raise HTTPException(status_code=400, detail="目标用户组不存在")
|
raise HTTPException(status_code=400, detail="目标用户组不存在")
|
||||||
|
|
||||||
# 如果更新密码,需要加密
|
|
||||||
update_data = request.model_dump(exclude_unset=True)
|
update_data = request.model_dump(exclude_unset=True)
|
||||||
if 'password' in update_data and update_data['password']:
|
|
||||||
update_data['password'] = Password.hash(update_data['password'])
|
|
||||||
elif 'password' in update_data:
|
|
||||||
del update_data['password'] # 空密码不更新
|
|
||||||
|
|
||||||
# 验证两步验证密钥格式(如果提供了值且不为 None,长度必须为 32)
|
|
||||||
if 'two_factor' in update_data and update_data['two_factor'] is not None:
|
|
||||||
if len(update_data['two_factor']) != 32:
|
|
||||||
raise HTTPException(status_code=400, detail="两步验证密钥必须为32位字符串")
|
|
||||||
|
|
||||||
# 记录旧 status 以便检测变更
|
# 记录旧 status 以便检测变更
|
||||||
old_status = user.status
|
old_status = user.status
|
||||||
@@ -175,7 +190,7 @@ async def router_admin_update_user(
|
|||||||
elif old_status != UserStatus.ACTIVE and new_status == UserStatus.ACTIVE:
|
elif old_status != UserStatus.ACTIVE and new_status == UserStatus.ACTIVE:
|
||||||
await UserBanStore.unban(str(user_id))
|
await UserBanStore.unban(str(user_id))
|
||||||
|
|
||||||
l.info(f"管理员更新了用户: {request.email}")
|
l.info(f"管理员更新了用户: {user.email}")
|
||||||
|
|
||||||
|
|
||||||
@admin_user_router.delete(
|
@admin_user_router.delete(
|
||||||
@@ -198,9 +213,17 @@ async def router_admin_delete_users(
|
|||||||
:param request: 批量删除请求,包含待删除用户的 UUID 列表
|
:param request: 批量删除请求,包含待删除用户的 UUID 列表
|
||||||
:return: 删除结果(已删除数 / 总请求数)
|
:return: 删除结果(已删除数 / 总请求数)
|
||||||
"""
|
"""
|
||||||
deleted = 0
|
|
||||||
for uid in request.ids:
|
for uid in request.ids:
|
||||||
user = await User.get(session, User.id == uid)
|
user = await User.get(session, User.id == uid, load=User.group)
|
||||||
|
|
||||||
|
# 安全检查:默认管理员不允许被删除(通过 Setting 中的 default_admin_id 识别)
|
||||||
|
default_admin_setting = await Setting.get(
|
||||||
|
session,
|
||||||
|
(Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
|
||||||
|
)
|
||||||
|
if user and default_admin_setting and default_admin_setting.value == str(uid):
|
||||||
|
raise HTTPException(status_code=403, detail=f"默认管理员不允许被删除")
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
await User.delete(session, user)
|
await User.delete(session, user)
|
||||||
l.info(f"管理员删除了用户: {user.email}")
|
l.info(f"管理员删除了用户: {user.email}")
|
||||||
@@ -228,13 +251,12 @@ async def router_admin_calibrate_storage(
|
|||||||
:param user_id: 用户UUID
|
:param user_id: 用户UUID
|
||||||
:return: 校准结果
|
:return: 校准结果
|
||||||
"""
|
"""
|
||||||
user = await User.get(session, User.id == user_id)
|
user = await User.get_exist_one(session, user_id)
|
||||||
if not user:
|
|
||||||
raise HTTPException(status_code=404, detail="用户不存在")
|
|
||||||
|
|
||||||
previous_storage = user.storage
|
previous_storage = user.storage
|
||||||
|
|
||||||
# 计算实际存储量 - 使用 SQL 聚合
|
# 计算实际存储量 - 使用 SQL 聚合
|
||||||
|
# [TODO] 不应这么计算,看看 SQLModel_Ext 库怎么解决
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(func.sum(Object.size), func.count(Object.id)).where(
|
select(func.sum(Object.size), func.count(Object.id)).where(
|
||||||
|
|||||||
@@ -1,81 +0,0 @@
|
|||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
|
|
||||||
from middleware.auth import admin_required
|
|
||||||
from middleware.dependencies import SessionDep
|
|
||||||
from sqlmodels import (
|
|
||||||
ResponseBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
admin_vas_router = APIRouter(
|
|
||||||
prefix='/vas',
|
|
||||||
tags=['admin', 'admin_vas']
|
|
||||||
)
|
|
||||||
|
|
||||||
@admin_vas_router.get(
|
|
||||||
path='/list',
|
|
||||||
summary='获取增值服务列表',
|
|
||||||
description='Get VAS list (orders and storage packs)',
|
|
||||||
dependencies=[Depends(admin_required)]
|
|
||||||
)
|
|
||||||
async def router_admin_get_vas_list(
|
|
||||||
session: SessionDep,
|
|
||||||
user_id: UUID | None = None,
|
|
||||||
page: int = 1,
|
|
||||||
page_size: int = 20,
|
|
||||||
) -> ResponseBase:
|
|
||||||
"""
|
|
||||||
获取增值服务列表(订单和存储包)。
|
|
||||||
|
|
||||||
:param session: 数据库会话
|
|
||||||
:param user_id: 按用户筛选
|
|
||||||
:param page: 页码
|
|
||||||
:param page_size: 每页数量
|
|
||||||
:return: 增值服务列表
|
|
||||||
"""
|
|
||||||
# TODO: 实现增值服务列表
|
|
||||||
# 需要查询 Order 和 StoragePack 模型
|
|
||||||
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
|
|
||||||
|
|
||||||
|
|
||||||
@admin_vas_router.get(
|
|
||||||
path='/{vas_id}',
|
|
||||||
summary='获取增值服务详情',
|
|
||||||
description='Get VAS detail by ID',
|
|
||||||
dependencies=[Depends(admin_required)]
|
|
||||||
)
|
|
||||||
async def router_admin_get_vas(
|
|
||||||
session: SessionDep,
|
|
||||||
vas_id: UUID,
|
|
||||||
) -> ResponseBase:
|
|
||||||
"""
|
|
||||||
获取增值服务详情。
|
|
||||||
|
|
||||||
:param session: 数据库会话
|
|
||||||
:param vas_id: 增值服务UUID
|
|
||||||
:return: 增值服务详情
|
|
||||||
"""
|
|
||||||
# TODO: 实现增值服务详情
|
|
||||||
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
|
|
||||||
|
|
||||||
|
|
||||||
@admin_vas_router.delete(
|
|
||||||
path='/{vas_id}',
|
|
||||||
summary='删除增值服务',
|
|
||||||
description='Delete VAS by ID',
|
|
||||||
dependencies=[Depends(admin_required)]
|
|
||||||
)
|
|
||||||
async def router_admin_delete_vas(
|
|
||||||
session: SessionDep,
|
|
||||||
vas_id: UUID,
|
|
||||||
) -> ResponseBase:
|
|
||||||
"""
|
|
||||||
删除增值服务。
|
|
||||||
|
|
||||||
:param session: 数据库会话
|
|
||||||
:param vas_id: 增值服务UUID
|
|
||||||
:return: 删除结果
|
|
||||||
"""
|
|
||||||
# TODO: 实现增值服务删除
|
|
||||||
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
|
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from fastapi import APIRouter, Query
|
from fastapi import APIRouter, Query
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
|
from loguru import logger as l
|
||||||
|
|
||||||
from sqlmodels import ResponseBase
|
from sqlmodels import ResponseBase
|
||||||
import service.oauth
|
import service.oauth
|
||||||
@@ -15,18 +16,12 @@ oauth_router = APIRouter(
|
|||||||
tags=["callback", "oauth"],
|
tags=["callback", "oauth"],
|
||||||
)
|
)
|
||||||
|
|
||||||
pay_router = APIRouter(
|
|
||||||
prefix='/callback/pay',
|
|
||||||
tags=["callback", "pay"],
|
|
||||||
)
|
|
||||||
|
|
||||||
upload_router = APIRouter(
|
upload_router = APIRouter(
|
||||||
prefix='/callback/upload',
|
prefix='/callback/upload',
|
||||||
tags=["callback", "upload"],
|
tags=["callback", "upload"],
|
||||||
)
|
)
|
||||||
|
|
||||||
callback_router.include_router(oauth_router)
|
callback_router.include_router(oauth_router)
|
||||||
callback_router.include_router(pay_router)
|
|
||||||
callback_router.include_router(upload_router)
|
callback_router.include_router(upload_router)
|
||||||
|
|
||||||
@oauth_router.post(
|
@oauth_router.post(
|
||||||
@@ -64,91 +59,17 @@ async def router_callback_github(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
access_token = await service.oauth.github.get_access_token(code)
|
access_token = await service.oauth.github.get_access_token(code)
|
||||||
# [TODO] 把access_token写数据库里
|
|
||||||
if not access_token:
|
if not access_token:
|
||||||
return PlainTextResponse("Failed to retrieve access token from GitHub.", status_code=400)
|
return PlainTextResponse("GitHub 认证失败", status_code=400)
|
||||||
|
|
||||||
user_data = await service.oauth.github.get_user_info(access_token.access_token)
|
user_data = await service.oauth.github.get_user_info(access_token.access_token)
|
||||||
# [TODO] 把user_data写数据库里
|
# [TODO] 把 access_token 和 user_data 写数据库,生成 JWT,重定向到前端
|
||||||
|
l.info(f"GitHub OAuth 回调成功: user={user_data.user_data.login}")
|
||||||
|
|
||||||
return PlainTextResponse(f"User information processed successfully, code: {code}, user_data: {user_data.json_dump()}", status_code=200)
|
return PlainTextResponse("认证成功,功能开发中", status_code=200)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return PlainTextResponse(f"An error occurred: {str(e)}", status_code=500)
|
l.error(f"GitHub OAuth 回调异常: {e}")
|
||||||
|
return PlainTextResponse("认证过程中发生错误,请重试", status_code=500)
|
||||||
@pay_router.post(
|
|
||||||
path='/alipay',
|
|
||||||
summary='支付宝支付回调',
|
|
||||||
description='Handle Alipay payment callback and return payment status.',
|
|
||||||
)
|
|
||||||
def router_callback_alipay() -> ResponseBase:
|
|
||||||
"""
|
|
||||||
Handle Alipay payment callback and return payment status.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ResponseBase: A model containing the response data for the Alipay payment callback.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@pay_router.post(
|
|
||||||
path='/wechat',
|
|
||||||
summary='微信支付回调',
|
|
||||||
description='Handle WeChat Pay payment callback and return payment status.',
|
|
||||||
)
|
|
||||||
def router_callback_wechat() -> ResponseBase:
|
|
||||||
"""
|
|
||||||
Handle WeChat Pay payment callback and return payment status.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ResponseBase: A model containing the response data for the WeChat Pay payment callback.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@pay_router.post(
|
|
||||||
path='/stripe',
|
|
||||||
summary='Stripe支付回调',
|
|
||||||
description='Handle Stripe payment callback and return payment status.',
|
|
||||||
)
|
|
||||||
def router_callback_stripe() -> ResponseBase:
|
|
||||||
"""
|
|
||||||
Handle Stripe payment callback and return payment status.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ResponseBase: A model containing the response data for the Stripe payment callback.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@pay_router.get(
|
|
||||||
path='/easypay',
|
|
||||||
summary='易支付回调',
|
|
||||||
description='Handle EasyPay payment callback and return payment status.',
|
|
||||||
)
|
|
||||||
def router_callback_easypay() -> PlainTextResponse:
|
|
||||||
"""
|
|
||||||
Handle EasyPay payment callback and return payment status.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PlainTextResponse: A response containing the payment status for the EasyPay payment callback.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
# return PlainTextResponse("success", status_code=200)
|
|
||||||
|
|
||||||
@pay_router.get(
|
|
||||||
path='/custom/{order_no}/{id}',
|
|
||||||
summary='自定义支付回调',
|
|
||||||
description='Handle custom payment callback and return payment status.',
|
|
||||||
)
|
|
||||||
def router_callback_custom(order_no: str, id: str) -> ResponseBase:
|
|
||||||
"""
|
|
||||||
Handle custom payment callback and return payment status.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
order_no (str): The order number for the payment.
|
|
||||||
id (str): The ID associated with the payment.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ResponseBase: A model containing the response data for the custom payment callback.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@upload_router.post(
|
@upload_router.post(
|
||||||
path='/remote/{session_id}/{key}',
|
path='/remote/{session_id}/{key}',
|
||||||
|
|||||||
100
routers/api/v1/category/__init__.py
Normal file
100
routers/api/v1/category/__init__.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""
|
||||||
|
文件分类筛选端点
|
||||||
|
|
||||||
|
按文件类型分类(图片/视频/音频/文档)查询用户的所有文件,
|
||||||
|
跨目录搜索,支持分页。扩展名映射从数据库 Setting 表读取。
|
||||||
|
"""
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from loguru import logger as l
|
||||||
|
|
||||||
|
from middleware.auth import auth_required
|
||||||
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
|
from sqlmodels import (
|
||||||
|
FileCategory,
|
||||||
|
ListResponse,
|
||||||
|
Object,
|
||||||
|
ObjectResponse,
|
||||||
|
ObjectType,
|
||||||
|
Setting,
|
||||||
|
SettingsType,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
|
||||||
|
category_router = APIRouter(
|
||||||
|
prefix="/category",
|
||||||
|
tags=["category"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@category_router.get(
|
||||||
|
path="/{category}",
|
||||||
|
summary="按分类获取文件列表",
|
||||||
|
)
|
||||||
|
async def router_category_list(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
category: FileCategory,
|
||||||
|
table_view: TableViewRequestDep,
|
||||||
|
) -> ListResponse[ObjectResponse]:
|
||||||
|
"""
|
||||||
|
按文件类型分类查询用户的所有文件
|
||||||
|
|
||||||
|
跨所有目录搜索,返回分页结果。
|
||||||
|
扩展名配置从数据库 Setting 表读取(type=file_category)。
|
||||||
|
|
||||||
|
认证:
|
||||||
|
- JWT token in Authorization header
|
||||||
|
|
||||||
|
路径参数:
|
||||||
|
- category: 文件分类(image / video / audio / document)
|
||||||
|
|
||||||
|
查询参数:
|
||||||
|
- offset: 分页偏移量(默认0)
|
||||||
|
- limit: 每页数量(默认20,最大100)
|
||||||
|
- desc: 是否降序(默认true)
|
||||||
|
- order: 排序字段(created_at / updated_at)
|
||||||
|
|
||||||
|
响应:
|
||||||
|
- ListResponse[ObjectResponse]: 分页文件列表
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- HTTPException 422: category 参数无效
|
||||||
|
- HTTPException 404: 该分类未配置扩展名
|
||||||
|
"""
|
||||||
|
# 从数据库读取该分类的扩展名配置
|
||||||
|
setting = await Setting.get(
|
||||||
|
session,
|
||||||
|
(Setting.type == SettingsType.FILE_CATEGORY) & (Setting.name == category.value),
|
||||||
|
)
|
||||||
|
if not setting or not setting.value:
|
||||||
|
raise HTTPException(status_code=404, detail=f"分类 {category.value} 未配置扩展名")
|
||||||
|
|
||||||
|
extensions = [ext.strip() for ext in setting.value.split(",") if ext.strip()]
|
||||||
|
if not extensions:
|
||||||
|
raise HTTPException(status_code=404, detail=f"分类 {category.value} 扩展名列表为空")
|
||||||
|
|
||||||
|
result = await Object.get_by_category(
|
||||||
|
session,
|
||||||
|
user.id,
|
||||||
|
extensions,
|
||||||
|
table_view=table_view,
|
||||||
|
)
|
||||||
|
|
||||||
|
items = [
|
||||||
|
ObjectResponse(
|
||||||
|
id=obj.id,
|
||||||
|
name=obj.name,
|
||||||
|
type=ObjectType.FILE,
|
||||||
|
size=obj.size,
|
||||||
|
mime_type=obj.mime_type,
|
||||||
|
thumb=False,
|
||||||
|
created_at=obj.created_at,
|
||||||
|
updated_at=obj.updated_at,
|
||||||
|
source_enabled=False,
|
||||||
|
)
|
||||||
|
for obj in result.items
|
||||||
|
]
|
||||||
|
|
||||||
|
return ListResponse(count=result.count, items=items)
|
||||||
@@ -57,7 +57,7 @@ async def _get_directory_response(
|
|||||||
policy_response = PolicyResponse(
|
policy_response = PolicyResponse(
|
||||||
id=policy.id,
|
id=policy.id,
|
||||||
name=policy.name,
|
name=policy.name,
|
||||||
type=policy.type.value,
|
type=policy.type,
|
||||||
max_size=policy.max_size,
|
max_size=policy.max_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -139,12 +139,13 @@ async def router_directory_get(
|
|||||||
@directory_router.post(
|
@directory_router.post(
|
||||||
path="/",
|
path="/",
|
||||||
summary="创建目录",
|
summary="创建目录",
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_directory_create(
|
async def router_directory_create(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[User, Depends(auth_required)],
|
user: Annotated[User, Depends(auth_required)],
|
||||||
request: DirectoryCreateRequest
|
request: DirectoryCreateRequest
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
创建目录
|
创建目录
|
||||||
|
|
||||||
@@ -162,8 +163,11 @@ async def router_directory_create(
|
|||||||
if "/" in name or "\\" in name:
|
if "/" in name or "\\" in name:
|
||||||
raise HTTPException(status_code=400, detail="目录名称不能包含斜杠")
|
raise HTTPException(status_code=400, detail="目录名称不能包含斜杠")
|
||||||
|
|
||||||
# 通过 UUID 获取父目录
|
# 通过 UUID 获取父目录(排除已删除的)
|
||||||
parent = await Object.get(session, Object.id == request.parent_id)
|
parent = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == request.parent_id) & (Object.deleted_at == None)
|
||||||
|
)
|
||||||
if not parent or parent.owner_id != user.id:
|
if not parent or parent.owner_id != user.id:
|
||||||
raise HTTPException(status_code=404, detail="父目录不存在")
|
raise HTTPException(status_code=404, detail="父目录不存在")
|
||||||
|
|
||||||
@@ -173,17 +177,26 @@ async def router_directory_create(
|
|||||||
if parent.is_banned:
|
if parent.is_banned:
|
||||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||||
|
|
||||||
# 检查是否已存在同名对象
|
# 检查是否已存在同名对象(仅检查未删除的)
|
||||||
existing = await Object.get(
|
existing = await Object.get(
|
||||||
session,
|
session,
|
||||||
(Object.owner_id == user.id) &
|
(Object.owner_id == user.id) &
|
||||||
(Object.parent_id == parent.id) &
|
(Object.parent_id == parent.id) &
|
||||||
(Object.name == name)
|
(Object.name == name) &
|
||||||
|
(Object.deleted_at == None)
|
||||||
)
|
)
|
||||||
if existing:
|
if existing:
|
||||||
raise HTTPException(status_code=409, detail="同名文件或目录已存在")
|
raise HTTPException(status_code=409, detail="同名文件或目录已存在")
|
||||||
|
|
||||||
policy_id = request.policy_id if request.policy_id else parent.policy_id
|
policy_id = request.policy_id if request.policy_id else parent.policy_id
|
||||||
|
|
||||||
|
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
|
||||||
|
if request.policy_id:
|
||||||
|
group = await user.awaitable_attrs.group
|
||||||
|
await session.refresh(group, ['policies'])
|
||||||
|
if request.policy_id not in {p.id for p in group.policies}:
|
||||||
|
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
|
||||||
|
|
||||||
parent_id = parent.id # 在 save 前保存
|
parent_id = parent.id # 在 save 前保存
|
||||||
|
|
||||||
new_folder = Object(
|
new_folder = Object(
|
||||||
@@ -193,14 +206,4 @@ async def router_directory_create(
|
|||||||
parent_id=parent_id,
|
parent_id=parent_id,
|
||||||
policy_id=policy_id,
|
policy_id=policy_id,
|
||||||
)
|
)
|
||||||
new_folder_id = new_folder.id # 在 save 前保存 UUID
|
new_folder = await new_folder.save(session)
|
||||||
new_folder_name = new_folder.name
|
|
||||||
await new_folder.save(session)
|
|
||||||
|
|
||||||
return ResponseBase(
|
|
||||||
data={
|
|
||||||
"id": new_folder_id,
|
|
||||||
"name": new_folder_name,
|
|
||||||
"parent_id": parent_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
106
routers/api/v1/file/viewers/__init__.py
Normal file
106
routers/api/v1/file/viewers/__init__.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""
|
||||||
|
文件查看器查询端点
|
||||||
|
|
||||||
|
提供按文件扩展名查询可用查看器的功能,包含用户组访问控制过滤。
|
||||||
|
"""
|
||||||
|
from typing import Annotated
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from sqlalchemy import and_
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from middleware.auth import auth_required
|
||||||
|
from middleware.dependencies import SessionDep
|
||||||
|
from sqlmodels import (
|
||||||
|
FileApp,
|
||||||
|
FileAppExtension,
|
||||||
|
FileAppGroupLink,
|
||||||
|
FileAppSummary,
|
||||||
|
FileViewersResponse,
|
||||||
|
User,
|
||||||
|
UserFileAppDefault,
|
||||||
|
)
|
||||||
|
|
||||||
|
viewers_router = APIRouter(prefix="/viewers", tags=["file", "viewers"])
|
||||||
|
|
||||||
|
|
||||||
|
@viewers_router.get(
|
||||||
|
path='',
|
||||||
|
summary='查询可用文件查看器',
|
||||||
|
description='根据文件扩展名查询可用的查看器应用列表。',
|
||||||
|
)
|
||||||
|
async def get_viewers(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
ext: Annotated[str, Query(max_length=20, description="文件扩展名")],
|
||||||
|
) -> FileViewersResponse:
|
||||||
|
"""
|
||||||
|
查询可用文件查看器端点
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 规范化扩展名(小写,去点号)
|
||||||
|
2. 查询匹配的已启用应用
|
||||||
|
3. 按用户组权限过滤
|
||||||
|
4. 按 priority 排序
|
||||||
|
5. 查询用户默认偏好
|
||||||
|
|
||||||
|
认证:JWT token 必填
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 401: 未授权
|
||||||
|
"""
|
||||||
|
# 规范化扩展名
|
||||||
|
normalized_ext = ext.lower().strip().lstrip('.')
|
||||||
|
|
||||||
|
# 查询匹配扩展名的应用(已启用的)
|
||||||
|
ext_records: list[FileAppExtension] = await FileAppExtension.get(
|
||||||
|
session,
|
||||||
|
and_(
|
||||||
|
FileAppExtension.extension == normalized_ext,
|
||||||
|
),
|
||||||
|
fetch_mode="all",
|
||||||
|
load=FileAppExtension.app,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 过滤和收集可用应用
|
||||||
|
user_group_id = user.group_id
|
||||||
|
viewers: list[tuple[FileAppSummary, int]] = []
|
||||||
|
|
||||||
|
for ext_record in ext_records:
|
||||||
|
app: FileApp = ext_record.app
|
||||||
|
if not app.is_enabled:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if app.is_restricted:
|
||||||
|
# 检查用户组权限(FileAppGroupLink 是纯关联表,使用 session 查询)
|
||||||
|
stmt = select(FileAppGroupLink).where(
|
||||||
|
and_(
|
||||||
|
FileAppGroupLink.app_id == app.id,
|
||||||
|
FileAppGroupLink.group_id == user_group_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await session.exec(stmt)
|
||||||
|
group_link = result.first()
|
||||||
|
if not group_link:
|
||||||
|
continue
|
||||||
|
|
||||||
|
viewers.append((app.to_summary(), ext_record.priority))
|
||||||
|
|
||||||
|
# 按 priority 排序
|
||||||
|
viewers.sort(key=lambda x: x[1])
|
||||||
|
|
||||||
|
# 查询用户默认偏好
|
||||||
|
user_default: UserFileAppDefault | None = await UserFileAppDefault.get(
|
||||||
|
session,
|
||||||
|
and_(
|
||||||
|
UserFileAppDefault.user_id == user.id,
|
||||||
|
UserFileAppDefault.extension == normalized_ext,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return FileViewersResponse(
|
||||||
|
viewers=[v[0] for v in viewers],
|
||||||
|
default_viewer_id=user_default.app_id if user_default else None,
|
||||||
|
)
|
||||||
@@ -8,14 +8,14 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||||
from loguru import logger as l
|
from loguru import logger as l
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
from middleware.auth import auth_required
|
from middleware.auth import auth_required
|
||||||
from middleware.dependencies import SessionDep
|
from middleware.dependencies import SessionDep
|
||||||
from sqlmodels import (
|
from sqlmodels import (
|
||||||
CreateFileRequest,
|
CreateFileRequest,
|
||||||
|
Group,
|
||||||
Object,
|
Object,
|
||||||
ObjectCopyRequest,
|
ObjectCopyRequest,
|
||||||
ObjectDeleteRequest,
|
ObjectDeleteRequest,
|
||||||
@@ -23,170 +23,54 @@ from sqlmodels import (
|
|||||||
ObjectPropertyDetailResponse,
|
ObjectPropertyDetailResponse,
|
||||||
ObjectPropertyResponse,
|
ObjectPropertyResponse,
|
||||||
ObjectRenameRequest,
|
ObjectRenameRequest,
|
||||||
|
ObjectSwitchPolicyRequest,
|
||||||
ObjectType,
|
ObjectType,
|
||||||
PhysicalFile,
|
PhysicalFile,
|
||||||
Policy,
|
Policy,
|
||||||
PolicyType,
|
PolicyType,
|
||||||
ResponseBase,
|
Task,
|
||||||
|
TaskProps,
|
||||||
|
TaskStatus,
|
||||||
|
TaskSummaryBase,
|
||||||
|
TaskType,
|
||||||
User,
|
User,
|
||||||
|
# 元数据相关
|
||||||
|
ObjectMetadata,
|
||||||
|
MetadataResponse,
|
||||||
|
MetadataPatchRequest,
|
||||||
|
INTERNAL_NAMESPACES,
|
||||||
|
USER_WRITABLE_NAMESPACES,
|
||||||
)
|
)
|
||||||
from service.storage import LocalStorageService
|
from service.storage import (
|
||||||
|
LocalStorageService,
|
||||||
|
adjust_user_storage,
|
||||||
|
copy_object_recursive,
|
||||||
|
migrate_file_with_task,
|
||||||
|
migrate_directory_files,
|
||||||
|
)
|
||||||
|
from service.storage.object import soft_delete_objects
|
||||||
|
from sqlmodels.database_connection import DatabaseManager
|
||||||
from utils import http_exceptions
|
from utils import http_exceptions
|
||||||
|
|
||||||
|
from .custom_property import router as custom_property_router
|
||||||
|
|
||||||
object_router = APIRouter(
|
object_router = APIRouter(
|
||||||
prefix="/object",
|
prefix="/object",
|
||||||
tags=["object"]
|
tags=["object"]
|
||||||
)
|
)
|
||||||
|
object_router.include_router(custom_property_router)
|
||||||
|
|
||||||
async def _delete_object_recursive(
|
|
||||||
session: AsyncSession,
|
|
||||||
obj: Object,
|
|
||||||
user_id: UUID,
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
递归删除对象(软删除)
|
|
||||||
|
|
||||||
对于文件:
|
|
||||||
- 减少 PhysicalFile 引用计数
|
|
||||||
- 只有引用计数为0时才移动物理文件到回收站
|
|
||||||
|
|
||||||
对于目录:
|
|
||||||
- 递归处理所有子对象
|
|
||||||
|
|
||||||
:param session: 数据库会话
|
|
||||||
:param obj: 要删除的对象
|
|
||||||
:param user_id: 用户UUID
|
|
||||||
:return: 删除的对象数量
|
|
||||||
"""
|
|
||||||
deleted_count = 0
|
|
||||||
|
|
||||||
# 在任何数据库操作前保存所有需要的属性,避免 commit 后对象过期导致懒加载失败
|
|
||||||
obj_id = obj.id
|
|
||||||
obj_name = obj.name
|
|
||||||
obj_is_folder = obj.is_folder
|
|
||||||
obj_is_file = obj.is_file
|
|
||||||
obj_physical_file_id = obj.physical_file_id
|
|
||||||
|
|
||||||
if obj_is_folder:
|
|
||||||
# 递归删除子对象
|
|
||||||
children = await Object.get_children(session, user_id, obj_id)
|
|
||||||
for child in children:
|
|
||||||
deleted_count += await _delete_object_recursive(session, child, user_id)
|
|
||||||
|
|
||||||
# 如果是文件,处理物理文件引用
|
|
||||||
if obj_is_file and obj_physical_file_id:
|
|
||||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj_physical_file_id)
|
|
||||||
if physical_file:
|
|
||||||
# 减少引用计数
|
|
||||||
new_count = physical_file.decrement_reference()
|
|
||||||
|
|
||||||
if physical_file.can_be_deleted:
|
|
||||||
# 引用计数为0,移动物理文件到回收站
|
|
||||||
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
|
|
||||||
if policy and policy.type == PolicyType.LOCAL:
|
|
||||||
try:
|
|
||||||
storage_service = LocalStorageService(policy)
|
|
||||||
await storage_service.move_to_trash(
|
|
||||||
source_path=physical_file.storage_path,
|
|
||||||
user_id=user_id,
|
|
||||||
object_id=obj_id,
|
|
||||||
)
|
|
||||||
l.debug(f"物理文件已移动到回收站: {obj_name}")
|
|
||||||
except Exception as e:
|
|
||||||
l.warning(f"移动物理文件到回收站失败: {obj_name}, 错误: {e}")
|
|
||||||
|
|
||||||
# 删除 PhysicalFile 记录
|
|
||||||
await PhysicalFile.delete(session, physical_file)
|
|
||||||
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
|
|
||||||
else:
|
|
||||||
# 还有其他引用,只更新引用计数
|
|
||||||
await physical_file.save(session)
|
|
||||||
l.debug(f"物理文件仍有 {new_count} 个引用,不删除: {physical_file.storage_path}")
|
|
||||||
|
|
||||||
# 使用条件删除,避免访问过期的 obj 实例
|
|
||||||
await Object.delete(session, condition=Object.id == obj_id)
|
|
||||||
deleted_count += 1
|
|
||||||
|
|
||||||
return deleted_count
|
|
||||||
|
|
||||||
|
|
||||||
async def _copy_object_recursive(
|
|
||||||
session: AsyncSession,
|
|
||||||
src: Object,
|
|
||||||
dst_parent_id: UUID,
|
|
||||||
user_id: UUID,
|
|
||||||
) -> tuple[int, list[UUID]]:
|
|
||||||
"""
|
|
||||||
递归复制对象
|
|
||||||
|
|
||||||
对于文件:
|
|
||||||
- 增加 PhysicalFile 引用计数
|
|
||||||
- 创建新的 Object 记录指向同一 PhysicalFile
|
|
||||||
|
|
||||||
对于目录:
|
|
||||||
- 创建新目录
|
|
||||||
- 递归复制所有子对象
|
|
||||||
|
|
||||||
:param session: 数据库会话
|
|
||||||
:param src: 源对象
|
|
||||||
:param dst_parent_id: 目标父目录UUID
|
|
||||||
:param user_id: 用户UUID
|
|
||||||
:return: (复制数量, 新对象UUID列表)
|
|
||||||
"""
|
|
||||||
copied_count = 0
|
|
||||||
new_ids: list[UUID] = []
|
|
||||||
|
|
||||||
# 在 save() 之前保存需要的属性值,避免 commit 后对象过期导致懒加载失败
|
|
||||||
src_is_folder = src.is_folder
|
|
||||||
src_id = src.id
|
|
||||||
|
|
||||||
# 创建新的 Object 记录
|
|
||||||
new_obj = Object(
|
|
||||||
name=src.name,
|
|
||||||
type=src.type,
|
|
||||||
size=src.size,
|
|
||||||
password=src.password,
|
|
||||||
parent_id=dst_parent_id,
|
|
||||||
owner_id=user_id,
|
|
||||||
policy_id=src.policy_id,
|
|
||||||
physical_file_id=src.physical_file_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 如果是文件,增加物理文件引用计数
|
|
||||||
if src.is_file and src.physical_file_id:
|
|
||||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == src.physical_file_id)
|
|
||||||
if physical_file:
|
|
||||||
physical_file.increment_reference()
|
|
||||||
await physical_file.save(session)
|
|
||||||
|
|
||||||
new_obj = await new_obj.save(session)
|
|
||||||
copied_count += 1
|
|
||||||
new_ids.append(new_obj.id)
|
|
||||||
|
|
||||||
# 如果是目录,递归复制子对象
|
|
||||||
if src_is_folder:
|
|
||||||
children = await Object.get_children(session, user_id, src_id)
|
|
||||||
for child in children:
|
|
||||||
child_count, child_ids = await _copy_object_recursive(
|
|
||||||
session, child, new_obj.id, user_id
|
|
||||||
)
|
|
||||||
copied_count += child_count
|
|
||||||
new_ids.extend(child_ids)
|
|
||||||
|
|
||||||
return copied_count, new_ids
|
|
||||||
|
|
||||||
|
|
||||||
@object_router.post(
|
@object_router.post(
|
||||||
path='/',
|
path='/',
|
||||||
summary='创建空白文件',
|
summary='创建空白文件',
|
||||||
description='在指定目录下创建空白文件。',
|
description='在指定目录下创建空白文件。',
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_object_create(
|
async def router_object_create(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[User, Depends(auth_required)],
|
user: Annotated[User, Depends(auth_required)],
|
||||||
request: CreateFileRequest,
|
request: CreateFileRequest,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
创建空白文件端点
|
创建空白文件端点
|
||||||
|
|
||||||
@@ -201,8 +85,11 @@ async def router_object_create(
|
|||||||
if not request.name or '/' in request.name or '\\' in request.name:
|
if not request.name or '/' in request.name or '\\' in request.name:
|
||||||
raise HTTPException(status_code=400, detail="无效的文件名")
|
raise HTTPException(status_code=400, detail="无效的文件名")
|
||||||
|
|
||||||
# 验证父目录
|
# 验证父目录(排除已删除的)
|
||||||
parent = await Object.get(session, Object.id == request.parent_id)
|
parent = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == request.parent_id) & (Object.deleted_at == None)
|
||||||
|
)
|
||||||
if not parent or parent.owner_id != user_id:
|
if not parent or parent.owner_id != user_id:
|
||||||
raise HTTPException(status_code=404, detail="父目录不存在")
|
raise HTTPException(status_code=404, detail="父目录不存在")
|
||||||
|
|
||||||
@@ -212,21 +99,20 @@ async def router_object_create(
|
|||||||
if parent.is_banned:
|
if parent.is_banned:
|
||||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||||
|
|
||||||
# 检查是否已存在同名文件
|
# 检查是否已存在同名文件(仅检查未删除的)
|
||||||
existing = await Object.get(
|
existing = await Object.get(
|
||||||
session,
|
session,
|
||||||
(Object.owner_id == user_id) &
|
(Object.owner_id == user_id) &
|
||||||
(Object.parent_id == parent.id) &
|
(Object.parent_id == parent.id) &
|
||||||
(Object.name == request.name)
|
(Object.name == request.name) &
|
||||||
|
(Object.deleted_at == None)
|
||||||
)
|
)
|
||||||
if existing:
|
if existing:
|
||||||
raise HTTPException(status_code=409, detail="同名文件已存在")
|
raise HTTPException(status_code=409, detail="同名文件已存在")
|
||||||
|
|
||||||
# 确定存储策略
|
# 确定存储策略
|
||||||
policy_id = request.policy_id or parent.policy_id
|
policy_id = request.policy_id or parent.policy_id
|
||||||
policy = await Policy.get(session, Policy.id == policy_id)
|
policy = await Policy.get_exist_one(session, policy_id)
|
||||||
if not policy:
|
|
||||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
|
||||||
|
|
||||||
parent_id = parent.id
|
parent_id = parent.id
|
||||||
|
|
||||||
@@ -261,44 +147,45 @@ async def router_object_create(
|
|||||||
owner_id=user_id,
|
owner_id=user_id,
|
||||||
policy_id=policy_id,
|
policy_id=policy_id,
|
||||||
)
|
)
|
||||||
await file_object.save(session)
|
file_object = await file_object.save(session)
|
||||||
|
|
||||||
l.info(f"创建空白文件: {request.name}")
|
l.info(f"创建空白文件: {request.name}")
|
||||||
|
|
||||||
return ResponseBase()
|
|
||||||
|
|
||||||
|
|
||||||
@object_router.delete(
|
@object_router.delete(
|
||||||
path='/',
|
path='/',
|
||||||
summary='删除对象',
|
summary='删除对象',
|
||||||
description='删除一个或多个对象(文件或目录),文件会移动到用户回收站。',
|
description='删除一个或多个对象(文件或目录),文件会移动到用户回收站。',
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_object_delete(
|
async def router_object_delete(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[User, Depends(auth_required)],
|
user: Annotated[User, Depends(auth_required)],
|
||||||
request: ObjectDeleteRequest,
|
request: ObjectDeleteRequest,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
删除对象端点(软删除)
|
删除对象端点(软删除到回收站)
|
||||||
|
|
||||||
流程:
|
流程:
|
||||||
1. 验证对象存在且属于当前用户
|
1. 验证对象存在且属于当前用户
|
||||||
2. 对于文件,减少物理文件引用计数
|
2. 设置 deleted_at 时间戳
|
||||||
3. 如果引用计数为0,移动物理文件到 .trash 目录
|
3. 保存原 parent_id 到 deleted_original_parent_id
|
||||||
4. 对于目录,递归处理子对象
|
4. 将 parent_id 置 NULL 脱离文件树
|
||||||
5. 从数据库中删除记录
|
5. 子对象和物理文件不做任何变更
|
||||||
|
|
||||||
:param session: 数据库会话
|
:param session: 数据库会话
|
||||||
:param user: 当前登录用户
|
:param user: 当前登录用户
|
||||||
:param request: 删除请求(包含待删除对象的UUID列表)
|
:param request: 删除请求(包含待删除对象的UUID列表)
|
||||||
:return: 删除结果
|
:return: 删除结果
|
||||||
"""
|
"""
|
||||||
# 存储 user.id,避免后续 save() 导致 user 过期后无法访问
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
deleted_count = 0
|
objects_to_delete: list[Object] = []
|
||||||
|
|
||||||
for obj_id in request.ids:
|
for obj_id in request.ids:
|
||||||
obj = await Object.get(session, Object.id == obj_id)
|
obj = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == obj_id) & (Object.deleted_at == None)
|
||||||
|
)
|
||||||
if not obj or obj.owner_id != user_id:
|
if not obj or obj.owner_id != user_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -307,30 +194,24 @@ async def router_object_delete(
|
|||||||
l.warning(f"尝试删除根目录被阻止: {obj.name}")
|
l.warning(f"尝试删除根目录被阻止: {obj.name}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 递归删除(包含引用计数逻辑)
|
objects_to_delete.append(obj)
|
||||||
count = await _delete_object_recursive(session, obj, user_id)
|
|
||||||
deleted_count += count
|
|
||||||
|
|
||||||
l.info(f"用户 {user_id} 删除了 {deleted_count} 个对象")
|
if objects_to_delete:
|
||||||
|
deleted_count = await soft_delete_objects(session, objects_to_delete)
|
||||||
return ResponseBase(
|
l.info(f"用户 {user_id} 软删除了 {deleted_count} 个对象到回收站")
|
||||||
data={
|
|
||||||
"deleted": deleted_count,
|
|
||||||
"total": len(request.ids),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@object_router.patch(
|
@object_router.patch(
|
||||||
path='/',
|
path='/',
|
||||||
summary='移动对象',
|
summary='移动对象',
|
||||||
description='移动一个或多个对象到目标目录',
|
description='移动一个或多个对象到目标目录',
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_object_move(
|
async def router_object_move(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[User, Depends(auth_required)],
|
user: Annotated[User, Depends(auth_required)],
|
||||||
request: ObjectMoveRequest,
|
request: ObjectMoveRequest,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
移动对象端点
|
移动对象端点
|
||||||
|
|
||||||
@@ -342,8 +223,11 @@ async def router_object_move(
|
|||||||
# 存储 user.id,避免后续 save() 导致 user 过期后无法访问
|
# 存储 user.id,避免后续 save() 导致 user 过期后无法访问
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
|
|
||||||
# 验证目标目录
|
# 验证目标目录(排除已删除的)
|
||||||
dst = await Object.get(session, Object.id == request.dst_id)
|
dst = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == request.dst_id) & (Object.deleted_at == None)
|
||||||
|
)
|
||||||
if not dst or dst.owner_id != user_id:
|
if not dst or dst.owner_id != user_id:
|
||||||
raise HTTPException(status_code=404, detail="目标目录不存在")
|
raise HTTPException(status_code=404, detail="目标目录不存在")
|
||||||
|
|
||||||
@@ -360,7 +244,10 @@ async def router_object_move(
|
|||||||
moved_count = 0
|
moved_count = 0
|
||||||
|
|
||||||
for src_id in request.src_ids:
|
for src_id in request.src_ids:
|
||||||
src = await Object.get(session, Object.id == src_id)
|
src = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == src_id) & (Object.deleted_at == None)
|
||||||
|
)
|
||||||
if not src or src.owner_id != user_id:
|
if not src or src.owner_id != user_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -388,12 +275,13 @@ async def router_object_move(
|
|||||||
if is_cycle:
|
if is_cycle:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查目标目录下是否存在同名对象
|
# 检查目标目录下是否存在同名对象(仅检查未删除的)
|
||||||
existing = await Object.get(
|
existing = await Object.get(
|
||||||
session,
|
session,
|
||||||
(Object.owner_id == user_id) &
|
(Object.owner_id == user_id) &
|
||||||
(Object.parent_id == dst_id) &
|
(Object.parent_id == dst_id) &
|
||||||
(Object.name == src.name)
|
(Object.name == src.name) &
|
||||||
|
(Object.deleted_at == None)
|
||||||
)
|
)
|
||||||
if existing:
|
if existing:
|
||||||
continue # 跳过重名对象
|
continue # 跳过重名对象
|
||||||
@@ -405,24 +293,18 @@ async def router_object_move(
|
|||||||
# 统一提交所有更改
|
# 统一提交所有更改
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
return ResponseBase(
|
|
||||||
data={
|
|
||||||
"moved": moved_count,
|
|
||||||
"total": len(request.src_ids),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@object_router.post(
|
@object_router.post(
|
||||||
path='/copy',
|
path='/copy',
|
||||||
summary='复制对象',
|
summary='复制对象',
|
||||||
description='复制一个或多个对象到目标目录。文件复制仅增加物理文件引用计数,不复制物理文件。',
|
description='复制一个或多个对象到目标目录。文件复制仅增加物理文件引用计数,不复制物理文件。',
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_object_copy(
|
async def router_object_copy(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[User, Depends(auth_required)],
|
user: Annotated[User, Depends(auth_required)],
|
||||||
request: ObjectCopyRequest,
|
request: ObjectCopyRequest,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
复制对象端点
|
复制对象端点
|
||||||
|
|
||||||
@@ -443,8 +325,11 @@ async def router_object_copy(
|
|||||||
# 存储 user.id,避免后续 save() 导致 user 过期后无法访问
|
# 存储 user.id,避免后续 save() 导致 user 过期后无法访问
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
|
|
||||||
# 验证目标目录
|
# 验证目标目录(排除已删除的)
|
||||||
dst = await Object.get(session, Object.id == request.dst_id)
|
dst = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == request.dst_id) & (Object.deleted_at == None)
|
||||||
|
)
|
||||||
if not dst or dst.owner_id != user_id:
|
if not dst or dst.owner_id != user_id:
|
||||||
raise HTTPException(status_code=404, detail="目标目录不存在")
|
raise HTTPException(status_code=404, detail="目标目录不存在")
|
||||||
|
|
||||||
@@ -456,20 +341,25 @@ async def router_object_copy(
|
|||||||
|
|
||||||
copied_count = 0
|
copied_count = 0
|
||||||
new_ids: list[UUID] = []
|
new_ids: list[UUID] = []
|
||||||
|
total_copied_size = 0
|
||||||
|
|
||||||
for src_id in request.src_ids:
|
for src_id in request.src_ids:
|
||||||
src = await Object.get(session, Object.id == src_id)
|
src = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == src_id) & (Object.deleted_at == None)
|
||||||
|
)
|
||||||
if not src or src.owner_id != user_id:
|
if not src or src.owner_id != user_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if src.is_banned:
|
if src.is_banned:
|
||||||
continue
|
http_exceptions.raise_banned("源对象已被封禁,无法执行此操作")
|
||||||
|
|
||||||
# 不能复制根目录
|
# 不能复制根目录
|
||||||
if src.parent_id is None:
|
if src.parent_id is None:
|
||||||
continue
|
http_exceptions.raise_banned("无法复制根目录")
|
||||||
|
|
||||||
# 不能复制到自身
|
# 不能复制到自身
|
||||||
|
# [TODO] 视为创建副本
|
||||||
if src.id == dst.id:
|
if src.id == dst.id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -485,42 +375,42 @@ async def router_object_copy(
|
|||||||
if is_cycle:
|
if is_cycle:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查目标目录下是否存在同名对象
|
# 检查目标目录下是否存在同名对象(仅检查未删除的)
|
||||||
existing = await Object.get(
|
existing = await Object.get(
|
||||||
session,
|
session,
|
||||||
(Object.owner_id == user_id) &
|
(Object.owner_id == user_id) &
|
||||||
(Object.parent_id == dst.id) &
|
(Object.parent_id == dst.id) &
|
||||||
(Object.name == src.name)
|
(Object.name == src.name) &
|
||||||
|
(Object.deleted_at == None)
|
||||||
)
|
)
|
||||||
if existing:
|
if existing:
|
||||||
continue # 跳过重名对象
|
# [TODO] 应当询问用户是否覆盖、跳过或创建副本
|
||||||
|
continue
|
||||||
|
|
||||||
# 递归复制
|
# 递归复制
|
||||||
count, ids = await _copy_object_recursive(session, src, dst.id, user_id)
|
count, ids, copied_size = await copy_object_recursive(session, src, dst.id, user_id)
|
||||||
copied_count += count
|
copied_count += count
|
||||||
new_ids.extend(ids)
|
new_ids.extend(ids)
|
||||||
|
total_copied_size += copied_size
|
||||||
|
|
||||||
|
# 更新用户存储配额
|
||||||
|
if total_copied_size > 0:
|
||||||
|
await adjust_user_storage(session, user_id, total_copied_size)
|
||||||
|
|
||||||
l.info(f"用户 {user_id} 复制了 {copied_count} 个对象")
|
l.info(f"用户 {user_id} 复制了 {copied_count} 个对象")
|
||||||
|
|
||||||
return ResponseBase(
|
|
||||||
data={
|
|
||||||
"copied": copied_count,
|
|
||||||
"total": len(request.src_ids),
|
|
||||||
"new_ids": new_ids,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@object_router.post(
|
@object_router.post(
|
||||||
path='/rename',
|
path='/rename',
|
||||||
summary='重命名对象',
|
summary='重命名对象',
|
||||||
description='重命名对象(文件或目录)。',
|
description='重命名对象(文件或目录)。',
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_object_rename(
|
async def router_object_rename(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[User, Depends(auth_required)],
|
user: Annotated[User, Depends(auth_required)],
|
||||||
request: ObjectRenameRequest,
|
request: ObjectRenameRequest,
|
||||||
) -> ResponseBase:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
重命名对象端点
|
重命名对象端点
|
||||||
|
|
||||||
@@ -539,8 +429,11 @@ async def router_object_rename(
|
|||||||
# 存储 user.id,避免后续 save() 导致 user 过期后无法访问
|
# 存储 user.id,避免后续 save() 导致 user 过期后无法访问
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
|
|
||||||
# 验证对象存在
|
# 验证对象存在(排除已删除的)
|
||||||
obj = await Object.get(session, Object.id == request.id)
|
obj = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == request.id) & (Object.deleted_at == None)
|
||||||
|
)
|
||||||
if not obj:
|
if not obj:
|
||||||
raise HTTPException(status_code=404, detail="对象不存在")
|
raise HTTPException(status_code=404, detail="对象不存在")
|
||||||
|
|
||||||
@@ -562,28 +455,27 @@ async def router_object_rename(
|
|||||||
if '/' in new_name or '\\' in new_name:
|
if '/' in new_name or '\\' in new_name:
|
||||||
raise HTTPException(status_code=400, detail="名称不能包含斜杠")
|
raise HTTPException(status_code=400, detail="名称不能包含斜杠")
|
||||||
|
|
||||||
# 如果名称没有变化,直接返回成功
|
# 如果名称没有变化,直接返回
|
||||||
if obj.name == new_name:
|
if obj.name == new_name:
|
||||||
return ResponseBase(data={"success": True})
|
return # noqa: already 204
|
||||||
|
|
||||||
# 检查同目录下是否存在同名对象
|
# 检查同目录下是否存在同名对象(仅检查未删除的)
|
||||||
existing = await Object.get(
|
existing = await Object.get(
|
||||||
session,
|
session,
|
||||||
(Object.owner_id == user_id) &
|
(Object.owner_id == user_id) &
|
||||||
(Object.parent_id == obj.parent_id) &
|
(Object.parent_id == obj.parent_id) &
|
||||||
(Object.name == new_name)
|
(Object.name == new_name) &
|
||||||
|
(Object.deleted_at == None)
|
||||||
)
|
)
|
||||||
if existing:
|
if existing:
|
||||||
raise HTTPException(status_code=409, detail="同名对象已存在")
|
raise HTTPException(status_code=409, detail="同名对象已存在")
|
||||||
|
|
||||||
# 更新名称
|
# 更新名称
|
||||||
obj.name = new_name
|
obj.name = new_name
|
||||||
await obj.save(session)
|
obj = await obj.save(session)
|
||||||
|
|
||||||
l.info(f"用户 {user_id} 将对象 {obj.id} 重命名为 {new_name}")
|
l.info(f"用户 {user_id} 将对象 {obj.id} 重命名为 {new_name}")
|
||||||
|
|
||||||
return ResponseBase(data={"success": True})
|
|
||||||
|
|
||||||
|
|
||||||
@object_router.get(
|
@object_router.get(
|
||||||
path='/property/{id}',
|
path='/property/{id}',
|
||||||
@@ -603,7 +495,10 @@ async def router_object_property(
|
|||||||
:param id: 对象UUID
|
:param id: 对象UUID
|
||||||
:return: 对象基本属性
|
:return: 对象基本属性
|
||||||
"""
|
"""
|
||||||
obj = await Object.get(session, Object.id == id)
|
obj = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == id) & (Object.deleted_at == None)
|
||||||
|
)
|
||||||
if not obj:
|
if not obj:
|
||||||
raise HTTPException(status_code=404, detail="对象不存在")
|
raise HTTPException(status_code=404, detail="对象不存在")
|
||||||
|
|
||||||
@@ -615,6 +510,7 @@ async def router_object_property(
|
|||||||
name=obj.name,
|
name=obj.name,
|
||||||
type=obj.type,
|
type=obj.type,
|
||||||
size=obj.size,
|
size=obj.size,
|
||||||
|
mime_type=obj.mime_type,
|
||||||
created_at=obj.created_at,
|
created_at=obj.created_at,
|
||||||
updated_at=obj.updated_at,
|
updated_at=obj.updated_at,
|
||||||
parent_id=obj.parent_id,
|
parent_id=obj.parent_id,
|
||||||
@@ -641,8 +537,8 @@ async def router_object_property_detail(
|
|||||||
"""
|
"""
|
||||||
obj = await Object.get(
|
obj = await Object.get(
|
||||||
session,
|
session,
|
||||||
Object.id == id,
|
(Object.id == id) & (Object.deleted_at == None),
|
||||||
load=Object.file_metadata,
|
load=Object.metadata_entries,
|
||||||
)
|
)
|
||||||
if not obj:
|
if not obj:
|
||||||
raise HTTPException(status_code=404, detail="对象不存在")
|
raise HTTPException(status_code=404, detail="对象不存在")
|
||||||
@@ -665,35 +561,301 @@ async def router_object_property_detail(
|
|||||||
total_views = sum(s.views for s in shares)
|
total_views = sum(s.views for s in shares)
|
||||||
total_downloads = sum(s.downloads for s in shares)
|
total_downloads = sum(s.downloads for s in shares)
|
||||||
|
|
||||||
# 获取物理文件引用计数
|
# 获取物理文件信息(引用计数、校验和)
|
||||||
reference_count = 1
|
reference_count = 1
|
||||||
|
checksum_md5: str | None = None
|
||||||
|
checksum_sha256: str | None = None
|
||||||
if obj.physical_file_id:
|
if obj.physical_file_id:
|
||||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
|
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
|
||||||
if physical_file:
|
if physical_file:
|
||||||
reference_count = physical_file.reference_count
|
reference_count = physical_file.reference_count
|
||||||
|
checksum_md5 = physical_file.checksum_md5
|
||||||
|
checksum_sha256 = physical_file.checksum_sha256
|
||||||
|
|
||||||
# 构建响应
|
# 构建元数据字典(排除内部命名空间)
|
||||||
response = ObjectPropertyDetailResponse(
|
metadata: dict[str, str] = {}
|
||||||
|
for entry in obj.metadata_entries:
|
||||||
|
ns = entry.name.split(":")[0] if ":" in entry.name else ""
|
||||||
|
if ns not in INTERNAL_NAMESPACES:
|
||||||
|
metadata[entry.name] = entry.value
|
||||||
|
|
||||||
|
return ObjectPropertyDetailResponse(
|
||||||
id=obj.id,
|
id=obj.id,
|
||||||
name=obj.name,
|
name=obj.name,
|
||||||
type=obj.type,
|
type=obj.type,
|
||||||
size=obj.size,
|
size=obj.size,
|
||||||
|
mime_type=obj.mime_type,
|
||||||
created_at=obj.created_at,
|
created_at=obj.created_at,
|
||||||
updated_at=obj.updated_at,
|
updated_at=obj.updated_at,
|
||||||
parent_id=obj.parent_id,
|
parent_id=obj.parent_id,
|
||||||
|
checksum_md5=checksum_md5,
|
||||||
|
checksum_sha256=checksum_sha256,
|
||||||
policy_name=policy_name,
|
policy_name=policy_name,
|
||||||
share_count=share_count,
|
share_count=share_count,
|
||||||
total_views=total_views,
|
total_views=total_views,
|
||||||
total_downloads=total_downloads,
|
total_downloads=total_downloads,
|
||||||
reference_count=reference_count,
|
reference_count=reference_count,
|
||||||
|
metadatas=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 添加文件元数据
|
|
||||||
if obj.file_metadata:
|
|
||||||
response.mime_type = obj.file_metadata.mime_type
|
|
||||||
response.width = obj.file_metadata.width
|
|
||||||
response.height = obj.file_metadata.height
|
|
||||||
response.duration = obj.file_metadata.duration
|
|
||||||
response.checksum_md5 = obj.file_metadata.checksum_md5
|
|
||||||
|
|
||||||
return response
|
@object_router.patch(
|
||||||
|
path='/{object_id}/policy',
|
||||||
|
summary='切换对象存储策略',
|
||||||
|
)
|
||||||
|
async def router_object_switch_policy(
|
||||||
|
session: SessionDep,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
object_id: UUID,
|
||||||
|
request: ObjectSwitchPolicyRequest,
|
||||||
|
) -> TaskSummaryBase:
|
||||||
|
"""
|
||||||
|
切换对象的存储策略
|
||||||
|
|
||||||
|
文件:立即创建后台迁移任务,将文件从源策略搬到目标策略。
|
||||||
|
目录:更新目录 policy_id(新文件使用新策略);
|
||||||
|
若 is_migrate_existing=True,额外创建后台任务迁移所有已有文件。
|
||||||
|
|
||||||
|
认证:JWT Bearer Token
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 对象不存在
|
||||||
|
- 403: 无权操作此对象 / 用户组无权使用目标策略
|
||||||
|
- 400: 目标策略与当前相同 / 不能对根目录操作
|
||||||
|
"""
|
||||||
|
user_id = user.id
|
||||||
|
|
||||||
|
# 查找对象
|
||||||
|
obj = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == object_id) & (Object.deleted_at == None)
|
||||||
|
)
|
||||||
|
if not obj:
|
||||||
|
http_exceptions.raise_not_found("对象不存在")
|
||||||
|
if obj.owner_id != user_id:
|
||||||
|
http_exceptions.raise_forbidden("无权操作此对象")
|
||||||
|
if obj.is_banned:
|
||||||
|
http_exceptions.raise_banned()
|
||||||
|
|
||||||
|
# 根目录不能直接切换策略(应通过子对象或子目录操作)
|
||||||
|
if obj.parent_id is None:
|
||||||
|
raise HTTPException(status_code=400, detail="不能对根目录切换存储策略,请对子目录操作")
|
||||||
|
|
||||||
|
# 校验目标策略存在
|
||||||
|
dest_policy = await Policy.get(session, Policy.id == request.policy_id)
|
||||||
|
if not dest_policy:
|
||||||
|
http_exceptions.raise_not_found("目标存储策略不存在")
|
||||||
|
|
||||||
|
# 校验用户组权限
|
||||||
|
group: Group = await user.awaitable_attrs.group
|
||||||
|
await session.refresh(group, ['policies'])
|
||||||
|
allowed_ids = {p.id for p in group.policies}
|
||||||
|
if request.policy_id not in allowed_ids:
|
||||||
|
http_exceptions.raise_forbidden("当前用户组无权使用该存储策略")
|
||||||
|
|
||||||
|
# 不能切换到相同策略
|
||||||
|
if obj.policy_id == request.policy_id:
|
||||||
|
raise HTTPException(status_code=400, detail="目标策略与当前策略相同")
|
||||||
|
|
||||||
|
# 保存必要的属性,避免 save 后对象过期
|
||||||
|
src_policy_id = obj.policy_id
|
||||||
|
obj_id = obj.id
|
||||||
|
obj_is_file = obj.type == ObjectType.FILE
|
||||||
|
dest_policy_id = request.policy_id
|
||||||
|
dest_policy_name = dest_policy.name
|
||||||
|
|
||||||
|
# 创建任务记录
|
||||||
|
task = Task(
|
||||||
|
type=TaskType.POLICY_MIGRATE,
|
||||||
|
status=TaskStatus.QUEUED,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
task = await task.save(session)
|
||||||
|
task_id = task.id
|
||||||
|
|
||||||
|
task_props = TaskProps(
|
||||||
|
task_id=task_id,
|
||||||
|
source_policy_id=src_policy_id,
|
||||||
|
dest_policy_id=dest_policy_id,
|
||||||
|
object_id=obj_id,
|
||||||
|
)
|
||||||
|
task_props = await task_props.save(session)
|
||||||
|
|
||||||
|
if obj_is_file:
|
||||||
|
# 文件:后台迁移
|
||||||
|
async def _run_file_migration() -> None:
|
||||||
|
async with DatabaseManager.session() as bg_session:
|
||||||
|
bg_obj = await Object.get(bg_session, Object.id == obj_id)
|
||||||
|
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
|
||||||
|
bg_task = await Task.get(bg_session, Task.id == task_id)
|
||||||
|
await migrate_file_with_task(bg_session, bg_obj, bg_policy, bg_task)
|
||||||
|
|
||||||
|
background_tasks.add_task(_run_file_migration)
|
||||||
|
else:
|
||||||
|
# 目录:先更新目录自身的 policy_id
|
||||||
|
obj = await Object.get(session, Object.id == obj_id)
|
||||||
|
obj.policy_id = dest_policy_id
|
||||||
|
obj = await obj.save(session)
|
||||||
|
|
||||||
|
if request.is_migrate_existing:
|
||||||
|
# 后台迁移所有已有文件
|
||||||
|
async def _run_dir_migration() -> None:
|
||||||
|
async with DatabaseManager.session() as bg_session:
|
||||||
|
bg_folder = await Object.get(bg_session, Object.id == obj_id)
|
||||||
|
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
|
||||||
|
bg_task = await Task.get(bg_session, Task.id == task_id)
|
||||||
|
await migrate_directory_files(bg_session, bg_folder, bg_policy, bg_task)
|
||||||
|
|
||||||
|
background_tasks.add_task(_run_dir_migration)
|
||||||
|
else:
|
||||||
|
# 不迁移已有文件,直接完成任务
|
||||||
|
task = await Task.get(session, Task.id == task_id)
|
||||||
|
task.status = TaskStatus.COMPLETED
|
||||||
|
task.progress = 100
|
||||||
|
task = await task.save(session)
|
||||||
|
|
||||||
|
# 重新获取 task 以读取最新状态
|
||||||
|
task = await Task.get(session, Task.id == task_id)
|
||||||
|
|
||||||
|
l.info(f"用户 {user_id} 请求切换对象 {obj_id} 存储策略 → {dest_policy_name}")
|
||||||
|
|
||||||
|
return TaskSummaryBase(
|
||||||
|
id=task.id,
|
||||||
|
type=task.type,
|
||||||
|
status=task.status,
|
||||||
|
progress=task.progress,
|
||||||
|
error=task.error,
|
||||||
|
user_id=task.user_id,
|
||||||
|
created_at=task.created_at,
|
||||||
|
updated_at=task.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 元数据端点 ====================
|
||||||
|
|
||||||
|
@object_router.get(
|
||||||
|
path='/{object_id}/metadata',
|
||||||
|
summary='获取对象元数据',
|
||||||
|
description='获取对象的元数据键值对,可按命名空间过滤。',
|
||||||
|
)
|
||||||
|
async def router_get_object_metadata(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
object_id: UUID,
|
||||||
|
ns: str | None = None,
|
||||||
|
) -> MetadataResponse:
|
||||||
|
"""
|
||||||
|
获取对象元数据端点
|
||||||
|
|
||||||
|
认证:JWT token 必填
|
||||||
|
|
||||||
|
查询参数:
|
||||||
|
- ns: 逗号分隔的命名空间列表(如 exif,stream),不传返回所有非内部命名空间
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 对象不存在
|
||||||
|
- 403: 无权查看此对象
|
||||||
|
"""
|
||||||
|
obj = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == object_id) & (Object.deleted_at == None),
|
||||||
|
load=Object.metadata_entries,
|
||||||
|
)
|
||||||
|
if not obj:
|
||||||
|
raise HTTPException(status_code=404, detail="对象不存在")
|
||||||
|
|
||||||
|
if obj.owner_id != user.id:
|
||||||
|
raise HTTPException(status_code=403, detail="无权查看此对象")
|
||||||
|
|
||||||
|
# 解析命名空间过滤
|
||||||
|
ns_filter: set[str] | None = None
|
||||||
|
if ns:
|
||||||
|
ns_filter = {n.strip() for n in ns.split(",") if n.strip()}
|
||||||
|
# 不允许查看内部命名空间
|
||||||
|
ns_filter -= INTERNAL_NAMESPACES
|
||||||
|
|
||||||
|
# 构建元数据字典
|
||||||
|
metadata: dict[str, str] = {}
|
||||||
|
for entry in obj.metadata_entries:
|
||||||
|
entry_ns = entry.name.split(":")[0] if ":" in entry.name else ""
|
||||||
|
if entry_ns in INTERNAL_NAMESPACES:
|
||||||
|
continue
|
||||||
|
if ns_filter is not None and entry_ns not in ns_filter:
|
||||||
|
continue
|
||||||
|
metadata[entry.name] = entry.value
|
||||||
|
|
||||||
|
return MetadataResponse(metadatas=metadata)
|
||||||
|
|
||||||
|
|
||||||
|
@object_router.patch(
|
||||||
|
path='/{object_id}/metadata',
|
||||||
|
summary='批量更新对象元数据',
|
||||||
|
description='批量设置或删除对象的元数据条目。仅允许修改 custom: 命名空间。',
|
||||||
|
status_code=204,
|
||||||
|
)
|
||||||
|
async def router_patch_object_metadata(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
object_id: UUID,
|
||||||
|
request: MetadataPatchRequest,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
批量更新对象元数据端点
|
||||||
|
|
||||||
|
请求体中值为 None 的键将被删除,其余键将被设置/更新。
|
||||||
|
用户只能修改 custom: 命名空间的条目。
|
||||||
|
|
||||||
|
认证:JWT token 必填
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 400: 尝试修改非 custom: 命名空间的条目
|
||||||
|
- 404: 对象不存在
|
||||||
|
- 403: 无权操作此对象
|
||||||
|
"""
|
||||||
|
obj = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == object_id) & (Object.deleted_at == None),
|
||||||
|
)
|
||||||
|
if not obj:
|
||||||
|
raise HTTPException(status_code=404, detail="对象不存在")
|
||||||
|
|
||||||
|
if obj.owner_id != user.id:
|
||||||
|
raise HTTPException(status_code=403, detail="无权操作此对象")
|
||||||
|
|
||||||
|
for patch in request.patches:
|
||||||
|
# 验证命名空间
|
||||||
|
patch_ns = patch.key.split(":")[0] if ":" in patch.key else ""
|
||||||
|
if patch_ns not in USER_WRITABLE_NAMESPACES:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"不允许修改命名空间 '{patch_ns}' 的元数据,仅允许 custom: 命名空间",
|
||||||
|
)
|
||||||
|
|
||||||
|
if patch.value is None:
|
||||||
|
# 删除元数据条目
|
||||||
|
existing = await ObjectMetadata.get(
|
||||||
|
session,
|
||||||
|
(ObjectMetadata.object_id == object_id) & (ObjectMetadata.name == patch.key),
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
await ObjectMetadata.delete(session, instances=existing)
|
||||||
|
else:
|
||||||
|
# 设置/更新元数据条目
|
||||||
|
existing = await ObjectMetadata.get(
|
||||||
|
session,
|
||||||
|
(ObjectMetadata.object_id == object_id) & (ObjectMetadata.name == patch.key),
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
existing.value = patch.value
|
||||||
|
existing = await existing.save(session)
|
||||||
|
else:
|
||||||
|
entry = ObjectMetadata(
|
||||||
|
object_id=object_id,
|
||||||
|
name=patch.key,
|
||||||
|
value=patch.value,
|
||||||
|
is_public=True,
|
||||||
|
)
|
||||||
|
entry = await entry.save(session)
|
||||||
|
|
||||||
|
l.info(f"用户 {user.id} 更新了对象 {object_id} 的 {len(request.patches)} 条元数据")
|
||||||
|
|||||||
168
routers/api/v1/object/custom_property/__init__.py
Normal file
168
routers/api/v1/object/custom_property/__init__.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""
|
||||||
|
用户自定义属性定义路由
|
||||||
|
|
||||||
|
提供自定义属性模板的增删改查功能。
|
||||||
|
用户可以定义类型化的属性模板(如标签、评分、分类等),
|
||||||
|
然后通过元数据 PATCH 端点为对象设置属性值。
|
||||||
|
|
||||||
|
路由前缀:/custom_property
|
||||||
|
"""
|
||||||
|
from typing import Annotated
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from loguru import logger as l
|
||||||
|
|
||||||
|
from middleware.auth import auth_required
|
||||||
|
from middleware.dependencies import SessionDep
|
||||||
|
from sqlmodels import (
|
||||||
|
CustomPropertyDefinition,
|
||||||
|
CustomPropertyCreateRequest,
|
||||||
|
CustomPropertyUpdateRequest,
|
||||||
|
CustomPropertyResponse,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/custom_property",
|
||||||
|
tags=["custom_property"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
path='',
|
||||||
|
summary='获取自定义属性定义列表',
|
||||||
|
description='获取当前用户的所有自定义属性定义,按 sort_order 排序。',
|
||||||
|
)
|
||||||
|
async def router_list_custom_properties(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
) -> list[CustomPropertyResponse]:
|
||||||
|
"""
|
||||||
|
获取自定义属性定义列表端点
|
||||||
|
|
||||||
|
认证:JWT token 必填
|
||||||
|
|
||||||
|
返回当前用户定义的所有自定义属性模板。
|
||||||
|
"""
|
||||||
|
definitions = await CustomPropertyDefinition.get(
|
||||||
|
session,
|
||||||
|
CustomPropertyDefinition.owner_id == user.id,
|
||||||
|
fetch_mode="all",
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
CustomPropertyResponse(
|
||||||
|
id=d.id,
|
||||||
|
name=d.name,
|
||||||
|
type=d.type,
|
||||||
|
icon=d.icon,
|
||||||
|
options=d.options,
|
||||||
|
default_value=d.default_value,
|
||||||
|
sort_order=d.sort_order,
|
||||||
|
)
|
||||||
|
for d in sorted(definitions, key=lambda x: x.sort_order)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
path='',
|
||||||
|
summary='创建自定义属性定义',
|
||||||
|
description='创建一个新的自定义属性模板。',
|
||||||
|
status_code=204,
|
||||||
|
)
|
||||||
|
async def router_create_custom_property(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
request: CustomPropertyCreateRequest,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
创建自定义属性定义端点
|
||||||
|
|
||||||
|
认证:JWT token 必填
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 400: 请求数据无效
|
||||||
|
- 409: 同名属性已存在
|
||||||
|
"""
|
||||||
|
# 检查同名属性
|
||||||
|
existing = await CustomPropertyDefinition.get(
|
||||||
|
session,
|
||||||
|
(CustomPropertyDefinition.owner_id == user.id) &
|
||||||
|
(CustomPropertyDefinition.name == request.name),
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise HTTPException(status_code=409, detail="同名自定义属性已存在")
|
||||||
|
|
||||||
|
definition = CustomPropertyDefinition(
|
||||||
|
owner_id=user.id,
|
||||||
|
name=request.name,
|
||||||
|
type=request.type,
|
||||||
|
icon=request.icon,
|
||||||
|
options=request.options,
|
||||||
|
default_value=request.default_value,
|
||||||
|
)
|
||||||
|
definition = await definition.save(session)
|
||||||
|
|
||||||
|
l.info(f"用户 {user.id} 创建了自定义属性: {request.name}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch(
|
||||||
|
path='/{id}',
|
||||||
|
summary='更新自定义属性定义',
|
||||||
|
description='更新自定义属性模板的名称、图标、选项等。',
|
||||||
|
status_code=204,
|
||||||
|
)
|
||||||
|
async def router_update_custom_property(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
id: UUID,
|
||||||
|
request: CustomPropertyUpdateRequest,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
更新自定义属性定义端点
|
||||||
|
|
||||||
|
认证:JWT token 必填
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 属性定义不存在
|
||||||
|
- 403: 无权操作此属性
|
||||||
|
"""
|
||||||
|
definition = await CustomPropertyDefinition.get_exist_one(session, id)
|
||||||
|
|
||||||
|
if definition.owner_id != user.id:
|
||||||
|
raise HTTPException(status_code=403, detail="无权操作此属性")
|
||||||
|
|
||||||
|
definition = await definition.update(session, request)
|
||||||
|
|
||||||
|
l.info(f"用户 {user.id} 更新了自定义属性: {id}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
path='/{id}',
|
||||||
|
summary='删除自定义属性定义',
|
||||||
|
description='删除自定义属性模板。注意:不会自动清理已使用该属性的元数据条目。',
|
||||||
|
status_code=204,
|
||||||
|
)
|
||||||
|
async def router_delete_custom_property(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
id: UUID,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
删除自定义属性定义端点
|
||||||
|
|
||||||
|
认证:JWT token 必填
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 属性定义不存在
|
||||||
|
- 403: 无权操作此属性
|
||||||
|
"""
|
||||||
|
definition = await CustomPropertyDefinition.get_exist_one(session, id)
|
||||||
|
|
||||||
|
if definition.owner_id != user.id:
|
||||||
|
raise HTTPException(status_code=403, detail="无权操作此属性")
|
||||||
|
|
||||||
|
await CustomPropertyDefinition.delete(session, instances=definition)
|
||||||
|
|
||||||
|
l.info(f"用户 {user.id} 删除了自定义属性: {id}")
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
from uuid import uuid4
|
from uuid import UUID, uuid4
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||||
@@ -9,11 +9,14 @@ from middleware.auth import auth_required
|
|||||||
from middleware.dependencies import SessionDep
|
from middleware.dependencies import SessionDep
|
||||||
from sqlmodels import ResponseBase
|
from sqlmodels import ResponseBase
|
||||||
from sqlmodels.user import User
|
from sqlmodels.user import User
|
||||||
from sqlmodels.share import Share, ShareCreateRequest, ShareResponse
|
from sqlmodels.share import (
|
||||||
from sqlmodels.object import Object
|
Share, ShareCreateRequest, CreateShareResponse, ShareResponse,
|
||||||
from sqlmodels.mixin import ListResponse, TableViewRequest
|
ShareDetailResponse, ShareOwnerInfo, ShareObjectItem,
|
||||||
|
)
|
||||||
|
from sqlmodels.object import Object, ObjectType
|
||||||
|
from sqlmodel_ext import ListResponse, TableViewRequest
|
||||||
from utils import http_exceptions
|
from utils import http_exceptions
|
||||||
from utils.password.pwd import Password
|
from utils.password.pwd import Password, PasswordStatus
|
||||||
|
|
||||||
share_router = APIRouter(
|
share_router = APIRouter(
|
||||||
prefix='/share',
|
prefix='/share',
|
||||||
@@ -22,21 +25,92 @@ share_router = APIRouter(
|
|||||||
|
|
||||||
@share_router.get(
|
@share_router.get(
|
||||||
path='/{id}',
|
path='/{id}',
|
||||||
summary='获取分享',
|
summary='获取分享详情',
|
||||||
description='Get shared content by info type and ID.',
|
description='Get share detail by share ID. No authentication required.',
|
||||||
)
|
)
|
||||||
def router_share_get(info: str, id: str) -> ResponseBase:
|
async def router_share_get(
|
||||||
|
session: SessionDep,
|
||||||
|
id: UUID,
|
||||||
|
password: str | None = Query(default=None),
|
||||||
|
) -> ShareDetailResponse:
|
||||||
"""
|
"""
|
||||||
Get shared content by info type and ID.
|
获取分享详情
|
||||||
|
|
||||||
Args:
|
认证:无需登录
|
||||||
info (str): The type of information being shared.
|
|
||||||
id (str): The ID of the shared content.
|
|
||||||
|
|
||||||
Returns:
|
流程:
|
||||||
dict: A dictionary containing shared content information.
|
1. 通过分享ID查找分享
|
||||||
|
2. 检查过期、封禁状态
|
||||||
|
3. 验证提取码(如果有)
|
||||||
|
4. 返回分享详情(含文件树和分享者信息)
|
||||||
"""
|
"""
|
||||||
http_exceptions.raise_not_implemented()
|
# 1. 查询分享(预加载 user 和 object)
|
||||||
|
share = await Share.get_exist_one(session, id, load=[Share.user, Share.object])
|
||||||
|
|
||||||
|
# 2. 检查过期
|
||||||
|
now = datetime.now()
|
||||||
|
if share.expires and share.expires < now:
|
||||||
|
http_exceptions.raise_not_found(detail="分享已过期")
|
||||||
|
|
||||||
|
# 3. 获取关联对象
|
||||||
|
obj = await share.awaitable_attrs.object
|
||||||
|
user = await share.awaitable_attrs.user
|
||||||
|
|
||||||
|
# 4. 检查封禁和软删除
|
||||||
|
if obj and obj.is_banned:
|
||||||
|
http_exceptions.raise_banned()
|
||||||
|
if obj and obj.deleted_at:
|
||||||
|
http_exceptions.raise_not_found(detail="分享关联的文件已被删除")
|
||||||
|
|
||||||
|
# 5. 检查密码
|
||||||
|
if share.password:
|
||||||
|
if not password:
|
||||||
|
http_exceptions.raise_precondition_required(detail="请输入提取码")
|
||||||
|
if Password.verify(share.password, password) != PasswordStatus.VALID:
|
||||||
|
http_exceptions.raise_forbidden(detail="提取码错误")
|
||||||
|
|
||||||
|
# 6. 加载子对象(目录分享)
|
||||||
|
children_items: list[ShareObjectItem] = []
|
||||||
|
if obj and obj.type == ObjectType.FOLDER:
|
||||||
|
children = await Object.get_children(session, obj.owner_id, obj.id)
|
||||||
|
children_items = [
|
||||||
|
ShareObjectItem(
|
||||||
|
id=child.id,
|
||||||
|
name=child.name,
|
||||||
|
type=child.type,
|
||||||
|
size=child.size,
|
||||||
|
created_at=child.created_at,
|
||||||
|
updated_at=child.updated_at,
|
||||||
|
)
|
||||||
|
for child in children
|
||||||
|
]
|
||||||
|
|
||||||
|
# 7. 构建响应(在 save 之前,避免 MissingGreenlet)
|
||||||
|
response = ShareDetailResponse(
|
||||||
|
expires=share.expires,
|
||||||
|
preview_enabled=share.preview_enabled,
|
||||||
|
score=share.score,
|
||||||
|
created_at=share.created_at,
|
||||||
|
owner=ShareOwnerInfo(
|
||||||
|
nickname=user.nickname if user else None,
|
||||||
|
avatar=user.avatar if user else "default",
|
||||||
|
),
|
||||||
|
object=ShareObjectItem(
|
||||||
|
id=obj.id,
|
||||||
|
name=obj.name,
|
||||||
|
type=obj.type,
|
||||||
|
size=obj.size,
|
||||||
|
created_at=obj.created_at,
|
||||||
|
updated_at=obj.updated_at,
|
||||||
|
),
|
||||||
|
children=children_items,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. 递增浏览次数(最后执行,避免 MissingGreenlet)
|
||||||
|
share.views += 1
|
||||||
|
await share.save(session, refresh=False)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
@share_router.put(
|
@share_router.put(
|
||||||
path='/download/{id}',
|
path='/download/{id}',
|
||||||
@@ -226,7 +300,7 @@ async def router_share_create(
|
|||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[User, Depends(auth_required)],
|
user: Annotated[User, Depends(auth_required)],
|
||||||
request: ShareCreateRequest,
|
request: ShareCreateRequest,
|
||||||
) -> ShareResponse:
|
) -> CreateShareResponse:
|
||||||
"""
|
"""
|
||||||
创建新分享
|
创建新分享
|
||||||
|
|
||||||
@@ -237,10 +311,13 @@ async def router_share_create(
|
|||||||
2. 生成随机分享码(uuid4)
|
2. 生成随机分享码(uuid4)
|
||||||
3. 如果有密码则加密存储
|
3. 如果有密码则加密存储
|
||||||
4. 创建 Share 记录并保存
|
4. 创建 Share 记录并保存
|
||||||
5. 返回分享信息
|
5. 返回分享 ID
|
||||||
"""
|
"""
|
||||||
# 验证对象存在且属于当前用户
|
# 验证对象存在且属于当前用户(排除已删除的)
|
||||||
obj = await Object.get(session, Object.id == request.object_id)
|
obj = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == request.object_id) & (Object.deleted_at == None)
|
||||||
|
)
|
||||||
if not obj or obj.owner_id != user.id:
|
if not obj or obj.owner_id != user.id:
|
||||||
raise HTTPException(status_code=404, detail="对象不存在或无权限")
|
raise HTTPException(status_code=404, detail="对象不存在或无权限")
|
||||||
|
|
||||||
@@ -256,11 +333,12 @@ async def router_share_create(
|
|||||||
hashed_password = Password.hash(request.password)
|
hashed_password = Password.hash(request.password)
|
||||||
|
|
||||||
# 创建分享记录
|
# 创建分享记录
|
||||||
|
user_id = user.id
|
||||||
share = Share(
|
share = Share(
|
||||||
code=code,
|
code=code,
|
||||||
password=hashed_password,
|
password=hashed_password,
|
||||||
object_id=request.object_id,
|
object_id=request.object_id,
|
||||||
user_id=user.id,
|
user_id=user_id,
|
||||||
expires=request.expires,
|
expires=request.expires,
|
||||||
remain_downloads=request.remain_downloads,
|
remain_downloads=request.remain_downloads,
|
||||||
preview_enabled=request.preview_enabled,
|
preview_enabled=request.preview_enabled,
|
||||||
@@ -269,24 +347,9 @@ async def router_share_create(
|
|||||||
)
|
)
|
||||||
share = await share.save(session)
|
share = await share.save(session)
|
||||||
|
|
||||||
l.info(f"用户 {user.id} 创建分享: {share.code}")
|
l.info(f"用户 {user_id} 创建分享: {share.code}")
|
||||||
|
|
||||||
# 返回响应
|
return CreateShareResponse(share_id=share.id)
|
||||||
return ShareResponse(
|
|
||||||
id=share.id,
|
|
||||||
code=share.code,
|
|
||||||
object_id=share.object_id,
|
|
||||||
source_name=share.source_name,
|
|
||||||
views=share.views,
|
|
||||||
downloads=share.downloads,
|
|
||||||
remain_downloads=share.remain_downloads,
|
|
||||||
expires=share.expires,
|
|
||||||
preview_enabled=share.preview_enabled,
|
|
||||||
score=share.score,
|
|
||||||
created_at=share.created_at,
|
|
||||||
is_expired=share.expires is not None and share.expires < datetime.now(),
|
|
||||||
has_password=share.password is not None,
|
|
||||||
)
|
|
||||||
|
|
||||||
@share_router.get(
|
@share_router.get(
|
||||||
path='/',
|
path='/',
|
||||||
@@ -406,16 +469,29 @@ def router_share_update(id: str) -> ResponseBase:
|
|||||||
path='/{id}',
|
path='/{id}',
|
||||||
summary='删除分享',
|
summary='删除分享',
|
||||||
description='Delete a share by ID.',
|
description='Delete a share by ID.',
|
||||||
dependencies=[Depends(auth_required)]
|
status_code=204,
|
||||||
)
|
)
|
||||||
def router_share_delete(id: str) -> ResponseBase:
|
async def router_share_delete(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
id: UUID,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Delete a share by ID.
|
删除分享
|
||||||
|
|
||||||
Args:
|
认证:需要 JWT token
|
||||||
id (str): The ID of the share to be deleted.
|
|
||||||
|
|
||||||
Returns:
|
流程:
|
||||||
ResponseBase: A model containing the response data for the deleted share.
|
1. 通过分享ID查找分享
|
||||||
|
2. 验证分享属于当前用户
|
||||||
|
3. 删除分享记录
|
||||||
"""
|
"""
|
||||||
http_exceptions.raise_not_implemented()
|
share = await Share.get_exist_one(session, id)
|
||||||
|
if share.user_id != user.id:
|
||||||
|
http_exceptions.raise_forbidden(detail="无权删除此分享")
|
||||||
|
|
||||||
|
user_id = user.id
|
||||||
|
share_code = share.code
|
||||||
|
await Share.delete(session, share)
|
||||||
|
|
||||||
|
l.info(f"用户 {user_id} 删除了分享: {share_code}")
|
||||||
@@ -1,7 +1,12 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from middleware.dependencies import SessionDep
|
from middleware.dependencies import SessionDep
|
||||||
from sqlmodels import ResponseBase, Setting, SettingsType, SiteConfigResponse
|
from sqlmodels import (
|
||||||
|
ResponseBase, Setting, SettingsType, SiteConfigResponse,
|
||||||
|
ThemePreset, ThemePresetResponse, ThemePresetListResponse,
|
||||||
|
AuthMethodConfig,
|
||||||
|
)
|
||||||
|
from sqlmodels.auth_identity import AuthProviderType
|
||||||
from sqlmodels.setting import CaptchaType
|
from sqlmodels.setting import CaptchaType
|
||||||
from utils import http_exceptions
|
from utils import http_exceptions
|
||||||
|
|
||||||
@@ -41,6 +46,22 @@ def router_site_captcha():
|
|||||||
"""
|
"""
|
||||||
http_exceptions.raise_not_implemented()
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
|
@site_router.get(
|
||||||
|
path='/themes',
|
||||||
|
summary='获取主题预设列表',
|
||||||
|
)
|
||||||
|
async def router_site_themes(session: SessionDep) -> ThemePresetListResponse:
|
||||||
|
"""
|
||||||
|
获取所有主题预设列表
|
||||||
|
|
||||||
|
无需认证,前端初始化时调用。
|
||||||
|
"""
|
||||||
|
presets: list[ThemePreset] = await ThemePreset.get(session, fetch_mode="all")
|
||||||
|
return ThemePresetListResponse(
|
||||||
|
themes=[ThemePresetResponse.from_preset(p) for p in presets]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@site_router.get(
|
@site_router.get(
|
||||||
path='/config',
|
path='/config',
|
||||||
summary='站点全局配置',
|
summary='站点全局配置',
|
||||||
@@ -51,7 +72,7 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
|||||||
获取站点全局配置
|
获取站点全局配置
|
||||||
|
|
||||||
无需认证。前端在初始化时调用此端点获取验证码类型、
|
无需认证。前端在初始化时调用此端点获取验证码类型、
|
||||||
登录/注册/找回密码是否需要验证码等配置。
|
登录/注册/找回密码是否需要验证码、可用的认证方式等配置。
|
||||||
"""
|
"""
|
||||||
# 批量查询所需设置
|
# 批量查询所需设置
|
||||||
settings: list[Setting] = await Setting.get(
|
settings: list[Setting] = await Setting.get(
|
||||||
@@ -59,7 +80,10 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
|||||||
(Setting.type == SettingsType.BASIC) |
|
(Setting.type == SettingsType.BASIC) |
|
||||||
(Setting.type == SettingsType.LOGIN) |
|
(Setting.type == SettingsType.LOGIN) |
|
||||||
(Setting.type == SettingsType.REGISTER) |
|
(Setting.type == SettingsType.REGISTER) |
|
||||||
(Setting.type == SettingsType.CAPTCHA),
|
(Setting.type == SettingsType.CAPTCHA) |
|
||||||
|
(Setting.type == SettingsType.AUTH) |
|
||||||
|
(Setting.type == SettingsType.OAUTH) |
|
||||||
|
(Setting.type == SettingsType.AVATAR),
|
||||||
fetch_mode="all",
|
fetch_mode="all",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -75,12 +99,32 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
|||||||
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
|
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
|
||||||
captcha_key = s.get("captcha_CloudflareKey") or None
|
captcha_key = s.get("captcha_CloudflareKey") or None
|
||||||
|
|
||||||
|
# 构建认证方式列表
|
||||||
|
auth_methods: list[AuthMethodConfig] = [
|
||||||
|
AuthMethodConfig(provider=AuthProviderType.EMAIL_PASSWORD, is_enabled=s.get("auth_email_password_enabled") == "1"),
|
||||||
|
AuthMethodConfig(provider=AuthProviderType.PHONE_SMS, is_enabled=s.get("auth_phone_sms_enabled") == "1"),
|
||||||
|
AuthMethodConfig(provider=AuthProviderType.GITHUB, is_enabled=s.get("github_enabled") == "1"),
|
||||||
|
AuthMethodConfig(provider=AuthProviderType.QQ, is_enabled=s.get("qq_enabled") == "1"),
|
||||||
|
AuthMethodConfig(provider=AuthProviderType.PASSKEY, is_enabled=s.get("auth_passkey_enabled") == "1"),
|
||||||
|
AuthMethodConfig(provider=AuthProviderType.MAGIC_LINK, is_enabled=s.get("auth_magic_link_enabled") == "1"),
|
||||||
|
]
|
||||||
|
|
||||||
return SiteConfigResponse(
|
return SiteConfigResponse(
|
||||||
title=s.get("siteName") or "DiskNext",
|
title=s.get("siteName") or "DiskNext",
|
||||||
|
logo_light=s.get("logo_light") or None,
|
||||||
|
logo_dark=s.get("logo_dark") or None,
|
||||||
register_enabled=s.get("register_enabled") == "1",
|
register_enabled=s.get("register_enabled") == "1",
|
||||||
login_captcha=s.get("login_captcha") == "1",
|
login_captcha=s.get("login_captcha") == "1",
|
||||||
reg_captcha=s.get("reg_captcha") == "1",
|
reg_captcha=s.get("reg_captcha") == "1",
|
||||||
forget_captcha=s.get("forget_captcha") == "1",
|
forget_captcha=s.get("forget_captcha") == "1",
|
||||||
captcha_type=captcha_type,
|
captcha_type=captcha_type,
|
||||||
captcha_key=captcha_key,
|
captcha_key=captcha_key,
|
||||||
|
auth_methods=auth_methods,
|
||||||
|
password_required=s.get("auth_password_required") == "1",
|
||||||
|
phone_binding_required=s.get("auth_phone_binding_required") == "1",
|
||||||
|
email_binding_required=s.get("auth_email_binding_required") == "1",
|
||||||
|
avatar_max_size=int(s["avatar_size"]),
|
||||||
|
footer_code=s.get("footer_code"),
|
||||||
|
tos_url=s.get("tos_url"),
|
||||||
|
privacy_url=s.get("privacy_url"),
|
||||||
)
|
)
|
||||||
@@ -20,15 +20,15 @@ slave_aria2_router = APIRouter(
|
|||||||
summary='测试用路由',
|
summary='测试用路由',
|
||||||
description='Test route for checking connectivity.',
|
description='Test route for checking connectivity.',
|
||||||
)
|
)
|
||||||
def router_slave_ping() -> ResponseBase:
|
def router_slave_ping() -> str:
|
||||||
"""
|
"""
|
||||||
Test route for checking connectivity.
|
Test route for checking connectivity.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ResponseBase: A response model indicating success.
|
str: 后端版本号
|
||||||
"""
|
"""
|
||||||
from utils.conf.appmeta import BackendVersion
|
from utils.conf.appmeta import BackendVersion
|
||||||
return ResponseBase(data=BackendVersion)
|
return BackendVersion
|
||||||
|
|
||||||
@slave_router.post(
|
@slave_router.post(
|
||||||
path='/post',
|
path='/post',
|
||||||
|
|||||||
161
routers/api/v1/trash/__init__.py
Normal file
161
routers/api/v1/trash/__init__.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""
|
||||||
|
回收站路由
|
||||||
|
|
||||||
|
提供回收站管理功能:列出、恢复、永久删除、清空。
|
||||||
|
|
||||||
|
路由前缀:/trash
|
||||||
|
"""
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from loguru import logger as l
|
||||||
|
|
||||||
|
from middleware.auth import auth_required
|
||||||
|
from middleware.dependencies import SessionDep
|
||||||
|
from sqlmodels import Object, User
|
||||||
|
from sqlmodels.object import TrashDeleteRequest, TrashItemResponse, TrashRestoreRequest
|
||||||
|
from service.storage.object import (
|
||||||
|
permanently_delete_objects,
|
||||||
|
restore_objects,
|
||||||
|
soft_delete_objects,
|
||||||
|
)
|
||||||
|
|
||||||
|
trash_router = APIRouter(
|
||||||
|
prefix="/trash",
|
||||||
|
tags=["trash"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@trash_router.get(
|
||||||
|
path='/',
|
||||||
|
summary='列出回收站内容',
|
||||||
|
description='获取当前用户回收站中的所有顶层对象。',
|
||||||
|
)
|
||||||
|
async def router_trash_list(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
) -> list[TrashItemResponse]:
|
||||||
|
"""
|
||||||
|
列出回收站内容
|
||||||
|
|
||||||
|
认证:需要 JWT token
|
||||||
|
|
||||||
|
返回回收站中被直接删除的顶层对象列表,
|
||||||
|
不包含其子对象(子对象在恢复/永久删除时会随顶层对象一起处理)。
|
||||||
|
"""
|
||||||
|
items = await Object.get_trash_items(session, user.id)
|
||||||
|
|
||||||
|
return [
|
||||||
|
TrashItemResponse(
|
||||||
|
id=item.id,
|
||||||
|
name=item.name,
|
||||||
|
type=item.type,
|
||||||
|
size=item.size,
|
||||||
|
deleted_at=item.deleted_at,
|
||||||
|
original_parent_id=item.deleted_original_parent_id,
|
||||||
|
)
|
||||||
|
for item in items
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@trash_router.patch(
|
||||||
|
path='/restore',
|
||||||
|
summary='恢复对象',
|
||||||
|
description='从回收站恢复一个或多个对象到原位置。如果原位置不存在则恢复到根目录。',
|
||||||
|
status_code=204,
|
||||||
|
)
|
||||||
|
async def router_trash_restore(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
request: TrashRestoreRequest,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
从回收站恢复对象
|
||||||
|
|
||||||
|
认证:需要 JWT token
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 验证对象存在且在回收站中(deleted_at IS NOT NULL)
|
||||||
|
2. 检查原父目录是否存在且未删除
|
||||||
|
3. 存在 → 恢复到原位置;不存在 → 恢复到根目录
|
||||||
|
4. 处理同名冲突(自动重命名)
|
||||||
|
5. 清除 deleted_at 和 deleted_original_parent_id
|
||||||
|
"""
|
||||||
|
user_id = user.id
|
||||||
|
objects_to_restore: list[Object] = []
|
||||||
|
|
||||||
|
for obj_id in request.ids:
|
||||||
|
obj = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == obj_id) & (Object.owner_id == user_id) & (Object.deleted_at != None)
|
||||||
|
)
|
||||||
|
if obj:
|
||||||
|
objects_to_restore.append(obj)
|
||||||
|
|
||||||
|
if objects_to_restore:
|
||||||
|
restored_count = await restore_objects(session, objects_to_restore, user_id)
|
||||||
|
l.info(f"用户 {user_id} 从回收站恢复了 {restored_count} 个对象")
|
||||||
|
|
||||||
|
|
||||||
|
@trash_router.delete(
|
||||||
|
path='/',
|
||||||
|
summary='永久删除对象',
|
||||||
|
description='永久删除回收站中的指定对象,包括物理文件和数据库记录。',
|
||||||
|
status_code=204,
|
||||||
|
)
|
||||||
|
async def router_trash_delete(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
request: TrashDeleteRequest,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
永久删除回收站中的对象
|
||||||
|
|
||||||
|
认证:需要 JWT token
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 验证对象存在且在回收站中
|
||||||
|
2. BFS 收集所有子文件的 PhysicalFile
|
||||||
|
3. 处理引用计数,引用为 0 时物理删除文件
|
||||||
|
4. 硬删除根 Object(CASCADE 清理子对象)
|
||||||
|
5. 更新用户存储配额
|
||||||
|
"""
|
||||||
|
user_id = user.id
|
||||||
|
objects_to_delete: list[Object] = []
|
||||||
|
|
||||||
|
for obj_id in request.ids:
|
||||||
|
obj = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == obj_id) & (Object.owner_id == user_id) & (Object.deleted_at != None)
|
||||||
|
)
|
||||||
|
if obj:
|
||||||
|
objects_to_delete.append(obj)
|
||||||
|
|
||||||
|
if objects_to_delete:
|
||||||
|
deleted_count = await permanently_delete_objects(session, objects_to_delete, user_id)
|
||||||
|
l.info(f"用户 {user_id} 永久删除了 {deleted_count} 个对象")
|
||||||
|
|
||||||
|
|
||||||
|
@trash_router.delete(
|
||||||
|
path='/empty',
|
||||||
|
summary='清空回收站',
|
||||||
|
description='永久删除回收站中的所有对象。',
|
||||||
|
status_code=204,
|
||||||
|
)
|
||||||
|
async def router_trash_empty(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
清空回收站
|
||||||
|
|
||||||
|
认证:需要 JWT token
|
||||||
|
|
||||||
|
获取回收站中所有顶层对象,逐个执行永久删除。
|
||||||
|
"""
|
||||||
|
user_id = user.id
|
||||||
|
trash_items = await Object.get_trash_items(session, user_id)
|
||||||
|
|
||||||
|
if trash_items:
|
||||||
|
deleted_count = await permanently_delete_objects(session, trash_items, user_id)
|
||||||
|
l.info(f"用户 {user_id} 清空回收站,共删除 {deleted_count} 个对象")
|
||||||
@@ -1,18 +1,31 @@
|
|||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from fastapi.responses import FileResponse, RedirectResponse
|
||||||
|
from itsdangerous import URLSafeTimedSerializer
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from webauthn import generate_registration_options
|
from webauthn import (
|
||||||
from webauthn.helpers import options_to_json_dict
|
generate_authentication_options,
|
||||||
|
generate_registration_options,
|
||||||
|
verify_registration_response,
|
||||||
|
)
|
||||||
|
from webauthn.helpers import bytes_to_base64url, options_to_json
|
||||||
|
from webauthn.helpers.structs import PublicKeyCredentialDescriptor
|
||||||
|
|
||||||
import service
|
import service
|
||||||
import sqlmodels
|
import sqlmodels
|
||||||
from middleware.auth import auth_required
|
from middleware.auth import auth_required
|
||||||
from middleware.dependencies import SessionDep, require_captcha
|
from middleware.dependencies import SessionDep, require_captcha
|
||||||
from service.captcha import CaptchaScene
|
from service.captcha import CaptchaScene
|
||||||
|
from service.redis.challenge_store import ChallengeStore
|
||||||
|
from service.webauthn import get_rp_config
|
||||||
|
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||||
from sqlmodels.user import UserStatus
|
from sqlmodels.user import UserStatus
|
||||||
|
from sqlmodels.user_authn import UserAuthn
|
||||||
from utils import JWT, Password, http_exceptions
|
from utils import JWT, Password, http_exceptions
|
||||||
from .settings import user_settings_router
|
from .settings import user_settings_router
|
||||||
|
|
||||||
@@ -23,59 +36,36 @@ user_router = APIRouter(
|
|||||||
|
|
||||||
user_router.include_router(user_settings_router)
|
user_router.include_router(user_settings_router)
|
||||||
|
|
||||||
class OAuth2PasswordWithExtrasForm:
|
|
||||||
"""
|
|
||||||
扩展 OAuth2 密码表单。
|
|
||||||
|
|
||||||
在标准 username/password 基础上添加 otp_code 字段。
|
|
||||||
captcha_code 由 require_captcha 依赖注入单独处理。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
username: Annotated[str, Form()],
|
|
||||||
password: Annotated[str, Form()],
|
|
||||||
otp_code: Annotated[str | None, Form(min_length=6, max_length=6)] = None,
|
|
||||||
):
|
|
||||||
self.username = username
|
|
||||||
self.password = password
|
|
||||||
self.otp_code = otp_code
|
|
||||||
|
|
||||||
|
|
||||||
@user_router.post(
|
@user_router.post(
|
||||||
path='/session',
|
path='/session',
|
||||||
summary='用户登录',
|
summary='用户登录(统一入口)',
|
||||||
description='用户登录端点,支持验证码校验和两步验证。',
|
description='统一登录端点,支持多种认证方式。',
|
||||||
dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))],
|
|
||||||
)
|
)
|
||||||
async def router_user_session(
|
async def router_user_session(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
form_data: Annotated[OAuth2PasswordWithExtrasForm, Depends()],
|
request: sqlmodels.UnifiedLoginRequest,
|
||||||
) -> sqlmodels.TokenResponse:
|
) -> sqlmodels.TokenResponse:
|
||||||
"""
|
"""
|
||||||
用户登录端点
|
统一登录端点
|
||||||
|
|
||||||
表单字段:
|
请求体:
|
||||||
- username: 用户邮箱
|
- provider: 登录方式(email_password / github / qq / passkey / magic_link)
|
||||||
- password: 用户密码
|
- identifier: 标识符(邮箱 / OAuth code / credential_id / magic link token)
|
||||||
- captcha_code: 验证码 token(可选,由 require_captcha 依赖校验)
|
- credential: 凭证(密码 / WebAuthn assertion 等)
|
||||||
- otp_code: 两步验证码(可选,仅在用户启用 2FA 时需要)
|
- two_fa_code: 两步验证码(可选)
|
||||||
|
- redirect_uri: OAuth 回调地址(可选)
|
||||||
|
- captcha: 验证码(可选)
|
||||||
|
|
||||||
错误处理:
|
错误处理:
|
||||||
- 400: 需要验证码但未提供
|
- 400: 登录方式未启用 / 参数错误
|
||||||
- 401: 邮箱/密码错误,或 2FA 验证码错误
|
- 401: 凭证错误
|
||||||
- 403: 账户已禁用 / 验证码验证失败
|
- 403: 账户已禁用
|
||||||
- 428: 需要两步验证但未提供 otp_code
|
- 428: 需要两步验证
|
||||||
|
- 501: 暂未实现的登录方式
|
||||||
"""
|
"""
|
||||||
return await service.user.login(
|
return await service.user.unified_login(session, request)
|
||||||
session,
|
|
||||||
sqlmodels.LoginRequest(
|
|
||||||
email=form_data.username,
|
|
||||||
password=form_data.password,
|
|
||||||
two_fa_code=form_data.otp_code,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
@user_router.post(
|
@user_router.post(
|
||||||
path='/session/refresh',
|
path='/session/refresh',
|
||||||
@@ -150,41 +140,82 @@ async def router_user_session_refresh(
|
|||||||
|
|
||||||
@user_router.post(
|
@user_router.post(
|
||||||
path='/',
|
path='/',
|
||||||
summary='用户注册',
|
summary='用户注册(统一入口)',
|
||||||
description='User registration endpoint.',
|
description='User registration endpoint.',
|
||||||
status_code=204,
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_user_register(
|
async def router_user_register(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
request: sqlmodels.RegisterRequest,
|
request: sqlmodels.UnifiedRegisterRequest,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
用户注册端点
|
统一注册端点
|
||||||
|
|
||||||
流程:
|
流程:
|
||||||
1. 验证用户名唯一性
|
1. 检查注册开关
|
||||||
2. 获取默认用户组
|
2. 检查 provider 启用
|
||||||
3. 创建用户记录
|
3. 验证 identifier 唯一性(AuthIdentity 表)
|
||||||
4. 创建用户根目录(name="/")
|
4. 创建 User + AuthIdentity + 根目录
|
||||||
|
|
||||||
:param session: 数据库会话
|
请求体:
|
||||||
:param request: 注册请求
|
- provider: 注册方式(email_password / phone_sms)
|
||||||
:return: 注册结果
|
- identifier: 标识符(邮箱 / 手机号)
|
||||||
:raises HTTPException 400: 用户名已存在
|
- credential: 凭证(密码 / 短信验证码)
|
||||||
:raises HTTPException 500: 默认用户组或存储策略不存在
|
- nickname: 昵称(可选)
|
||||||
|
- captcha: 验证码(可选)
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 400: 注册未开放 / 参数错误
|
||||||
|
- 409: 邮箱或手机号已存在
|
||||||
|
- 501: 暂未实现的注册方式
|
||||||
"""
|
"""
|
||||||
# 1. 验证邮箱唯一性
|
# 1. 检查注册开关
|
||||||
|
register_setting = await sqlmodels.Setting.get(
|
||||||
|
session,
|
||||||
|
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER)
|
||||||
|
& (sqlmodels.Setting.name == "register_enabled"),
|
||||||
|
)
|
||||||
|
if not register_setting or register_setting.value != "1":
|
||||||
|
http_exceptions.raise_bad_request("注册功能未开放")
|
||||||
|
|
||||||
|
# 2. 目前只支持 email_password 注册
|
||||||
|
if request.provider == AuthProviderType.PHONE_SMS:
|
||||||
|
http_exceptions.raise_not_implemented("短信注册暂未开放")
|
||||||
|
elif request.provider != AuthProviderType.EMAIL_PASSWORD:
|
||||||
|
http_exceptions.raise_bad_request("不支持的注册方式")
|
||||||
|
|
||||||
|
# 3. 检查密码是否必填
|
||||||
|
password_required_setting = await sqlmodels.Setting.get(
|
||||||
|
session,
|
||||||
|
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||||
|
& (sqlmodels.Setting.name == "auth_password_required"),
|
||||||
|
)
|
||||||
|
is_password_required = not password_required_setting or password_required_setting.value != "0"
|
||||||
|
if is_password_required and not request.credential:
|
||||||
|
http_exceptions.raise_bad_request("密码不能为空")
|
||||||
|
|
||||||
|
# 4. 验证 identifier 唯一性(AuthIdentity 表)
|
||||||
|
existing_identity = await AuthIdentity.get(
|
||||||
|
session,
|
||||||
|
(AuthIdentity.provider == request.provider)
|
||||||
|
& (AuthIdentity.identifier == request.identifier),
|
||||||
|
)
|
||||||
|
if existing_identity:
|
||||||
|
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||||
|
|
||||||
|
# 同时检查 User.email 唯一性(防止旧数据冲突)
|
||||||
existing_user = await sqlmodels.User.get(
|
existing_user = await sqlmodels.User.get(
|
||||||
session,
|
session,
|
||||||
sqlmodels.User.email == request.email
|
sqlmodels.User.email == request.identifier,
|
||||||
)
|
)
|
||||||
if existing_user:
|
if existing_user:
|
||||||
raise HTTPException(status_code=400, detail="邮箱已存在")
|
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||||
|
|
||||||
# 2. 获取默认用户组(从设置中读取 UUID)
|
# 5. 获取默认用户组
|
||||||
default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get(
|
default_group_setting = await sqlmodels.Setting.get(
|
||||||
session,
|
session,
|
||||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group")
|
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER)
|
||||||
|
& (sqlmodels.Setting.name == "default_group"),
|
||||||
)
|
)
|
||||||
if default_group_setting is None or not default_group_setting.value:
|
if default_group_setting is None or not default_group_setting.value:
|
||||||
logger.error("默认用户组不存在")
|
logger.error("默认用户组不存在")
|
||||||
@@ -196,21 +227,33 @@ async def router_user_register(
|
|||||||
logger.error("默认用户组不存在")
|
logger.error("默认用户组不存在")
|
||||||
http_exceptions.raise_internal_error()
|
http_exceptions.raise_internal_error()
|
||||||
|
|
||||||
# 3. 创建用户
|
# 6. 创建用户
|
||||||
hashed_password = Password.hash(request.password)
|
|
||||||
new_user = sqlmodels.User(
|
new_user = sqlmodels.User(
|
||||||
email=request.email,
|
email=request.identifier,
|
||||||
password=hashed_password,
|
nickname=request.nickname,
|
||||||
group_id=default_group.id,
|
group_id=default_group.id,
|
||||||
)
|
)
|
||||||
new_user_id = new_user.id
|
new_user_id = new_user.id
|
||||||
await new_user.save(session)
|
new_user = await new_user.save(session)
|
||||||
|
|
||||||
# 4. 创建用户根目录
|
# 7. 创建 AuthIdentity
|
||||||
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
|
hashed_password = Password.hash(request.credential) if request.credential else None
|
||||||
if not default_policy:
|
identity = AuthIdentity(
|
||||||
logger.error("默认存储策略不存在")
|
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||||
|
identifier=request.identifier,
|
||||||
|
credential=hashed_password,
|
||||||
|
is_primary=True,
|
||||||
|
is_verified=False,
|
||||||
|
user_id=new_user_id,
|
||||||
|
)
|
||||||
|
identity = await identity.save(session)
|
||||||
|
|
||||||
|
# 8. 创建用户根目录(使用用户组关联的第一个存储策略)
|
||||||
|
await session.refresh(default_group, ['policies'])
|
||||||
|
if not default_group.policies:
|
||||||
|
logger.error("默认用户组未关联任何存储策略")
|
||||||
http_exceptions.raise_internal_error()
|
http_exceptions.raise_internal_error()
|
||||||
|
default_policy = default_group.policies[0]
|
||||||
|
|
||||||
await sqlmodels.Object(
|
await sqlmodels.Object(
|
||||||
name="/",
|
name="/",
|
||||||
@@ -220,6 +263,66 @@ async def router_user_register(
|
|||||||
policy_id=default_policy.id,
|
policy_id=default_policy.id,
|
||||||
).save(session)
|
).save(session)
|
||||||
|
|
||||||
|
|
||||||
|
@user_router.post(
|
||||||
|
path='/magic-link',
|
||||||
|
summary='发送 Magic Link 邮件',
|
||||||
|
description='生成 Magic Link token 并发送到指定邮箱。',
|
||||||
|
status_code=204,
|
||||||
|
)
|
||||||
|
async def router_user_magic_link(
|
||||||
|
session: SessionDep,
|
||||||
|
request: sqlmodels.MagicLinkRequest,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
发送 Magic Link 邮件
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 验证邮箱对应的 AuthIdentity 存在
|
||||||
|
2. 生成签名 token
|
||||||
|
3. 发送邮件(包含带 token 的链接)
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 400: Magic Link 未启用
|
||||||
|
- 404: 邮箱未注册
|
||||||
|
"""
|
||||||
|
# 检查 magic_link 是否启用
|
||||||
|
magic_link_setting = await sqlmodels.Setting.get(
|
||||||
|
session,
|
||||||
|
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||||
|
& (sqlmodels.Setting.name == "auth_magic_link_enabled"),
|
||||||
|
)
|
||||||
|
if not magic_link_setting or magic_link_setting.value != "1":
|
||||||
|
http_exceptions.raise_bad_request("Magic Link 登录未启用")
|
||||||
|
|
||||||
|
# 验证邮箱存在
|
||||||
|
identity = await AuthIdentity.get(
|
||||||
|
session,
|
||||||
|
(AuthIdentity.identifier == request.email)
|
||||||
|
& (
|
||||||
|
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||||
|
| (AuthIdentity.provider == AuthProviderType.MAGIC_LINK)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not identity:
|
||||||
|
http_exceptions.raise_not_found("该邮箱未注册")
|
||||||
|
|
||||||
|
# 生成签名 token
|
||||||
|
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||||
|
token = serializer.dumps(request.email, salt="magic-link-salt")
|
||||||
|
|
||||||
|
# 获取站点 URL
|
||||||
|
site_url_setting = await sqlmodels.Setting.get(
|
||||||
|
session,
|
||||||
|
(sqlmodels.Setting.type == sqlmodels.SettingsType.BASIC)
|
||||||
|
& (sqlmodels.Setting.name == "siteURL"),
|
||||||
|
)
|
||||||
|
site_url = site_url_setting.value if site_url_setting else "http://localhost"
|
||||||
|
|
||||||
|
# TODO: 发送邮件(包含 {site_url}/auth/magic-link?token={token})
|
||||||
|
logger.info(f"Magic Link token 已为 {request.email} 生成 (邮件发送待实现)")
|
||||||
|
|
||||||
|
|
||||||
@user_router.post(
|
@user_router.post(
|
||||||
path='/code',
|
path='/code',
|
||||||
summary='发送验证码邮件',
|
summary='发送验证码邮件',
|
||||||
@@ -236,46 +339,6 @@ def router_user_email_code(
|
|||||||
"""
|
"""
|
||||||
http_exceptions.raise_not_implemented()
|
http_exceptions.raise_not_implemented()
|
||||||
|
|
||||||
@user_router.get(
|
|
||||||
path='/qq',
|
|
||||||
summary='初始化QQ登录',
|
|
||||||
description='Initialize QQ login for a user.',
|
|
||||||
)
|
|
||||||
def router_user_qq() -> sqlmodels.ResponseBase:
|
|
||||||
"""
|
|
||||||
Initialize QQ login for a user.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: A dictionary containing QQ login initialization information.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@user_router.get(
|
|
||||||
path='authn/{username}',
|
|
||||||
summary='WebAuthn登录初始化',
|
|
||||||
description='Initialize WebAuthn login for a user.',
|
|
||||||
)
|
|
||||||
async def router_user_authn(username: str) -> sqlmodels.ResponseBase:
|
|
||||||
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@user_router.post(
|
|
||||||
path='authn/finish/{username}',
|
|
||||||
summary='WebAuthn登录',
|
|
||||||
description='Finish WebAuthn login for a user.',
|
|
||||||
)
|
|
||||||
def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
|
|
||||||
"""
|
|
||||||
Finish WebAuthn login for a user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
username (str): The username of the user.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: A dictionary containing WebAuthn login information.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@user_router.get(
|
@user_router.get(
|
||||||
path='/profile/{id}',
|
path='/profile/{id}',
|
||||||
summary='获取用户主页展示用分享',
|
summary='获取用户主页展示用分享',
|
||||||
@@ -296,20 +359,78 @@ def router_user_profile(id: str) -> sqlmodels.ResponseBase:
|
|||||||
@user_router.get(
|
@user_router.get(
|
||||||
path='/avatar/{id}/{size}',
|
path='/avatar/{id}/{size}',
|
||||||
summary='获取用户头像',
|
summary='获取用户头像',
|
||||||
description='Get user avatar by ID and size.',
|
response_model=None,
|
||||||
)
|
)
|
||||||
def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase:
|
async def router_user_avatar(
|
||||||
|
session: SessionDep,
|
||||||
|
id: UUID,
|
||||||
|
size: int = 128,
|
||||||
|
) -> FileResponse | RedirectResponse:
|
||||||
"""
|
"""
|
||||||
Get user avatar by ID and size.
|
获取指定用户指定尺寸的头像(公开端点,无需认证)
|
||||||
|
|
||||||
Args:
|
路径参数:
|
||||||
id (str): The user ID.
|
- id: 用户 UUID
|
||||||
size (int): The size of the avatar image.
|
- size: 请求的头像尺寸(px),默认 128
|
||||||
|
|
||||||
Returns:
|
行为:
|
||||||
str: A Base64 encoded string of the user avatar image.
|
- default: 302 重定向到 Gravatar identicon
|
||||||
|
- gravatar: 302 重定向到 Gravatar(使用用户邮箱 MD5)
|
||||||
|
- file: 返回本地 WebP 文件
|
||||||
|
|
||||||
|
响应:
|
||||||
|
- 200: image/webp(file 模式)
|
||||||
|
- 302: 重定向到外部 URL(default/gravatar 模式)
|
||||||
|
- 404: 用户不存在
|
||||||
|
|
||||||
|
缓存:Cache-Control: public, max-age=3600
|
||||||
"""
|
"""
|
||||||
http_exceptions.raise_not_implemented()
|
import aiofiles.os
|
||||||
|
|
||||||
|
from service.avatar import (
|
||||||
|
get_avatar_file_path,
|
||||||
|
get_avatar_settings,
|
||||||
|
gravatar_url,
|
||||||
|
resolve_avatar_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
user = await sqlmodels.User.get(session, sqlmodels.User.id == id)
|
||||||
|
if not user:
|
||||||
|
http_exceptions.raise_not_found("用户不存在")
|
||||||
|
|
||||||
|
avatar_path, _, size_l, size_m, size_s = await get_avatar_settings(session)
|
||||||
|
|
||||||
|
if user.avatar == "file":
|
||||||
|
size_label = resolve_avatar_size(size, size_l, size_m, size_s)
|
||||||
|
file_path = get_avatar_file_path(avatar_path, user.id, size_label)
|
||||||
|
|
||||||
|
if not await aiofiles.os.path.exists(file_path):
|
||||||
|
# 文件丢失,降级为 identicon
|
||||||
|
fallback_url = gravatar_url(str(user.id), size, "https://www.gravatar.com/")
|
||||||
|
return RedirectResponse(url=fallback_url, status_code=302)
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
path=file_path,
|
||||||
|
media_type="image/webp",
|
||||||
|
headers={"Cache-Control": "public, max-age=3600"},
|
||||||
|
)
|
||||||
|
|
||||||
|
elif user.avatar == "gravatar":
|
||||||
|
gravatar_setting = await sqlmodels.Setting.get(
|
||||||
|
session,
|
||||||
|
(sqlmodels.Setting.type == sqlmodels.SettingsType.AVATAR)
|
||||||
|
& (sqlmodels.Setting.name == "gravatar_server"),
|
||||||
|
)
|
||||||
|
server = gravatar_setting.value if gravatar_setting else "https://www.gravatar.com/"
|
||||||
|
email = user.email or str(user.id)
|
||||||
|
url = gravatar_url(email, size, server)
|
||||||
|
return RedirectResponse(url=url, status_code=302)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# default: identicon
|
||||||
|
email_or_id = user.email or str(user.id)
|
||||||
|
url = gravatar_url(email_or_id, size, "https://www.gravatar.com/")
|
||||||
|
return RedirectResponse(url=url, status_code=302)
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
# 需要登录的接口
|
# 需要登录的接口
|
||||||
@@ -348,8 +469,6 @@ async def router_user_me(
|
|||||||
return sqlmodels.UserResponse(
|
return sqlmodels.UserResponse(
|
||||||
id=user.id,
|
id=user.id,
|
||||||
email=user.email,
|
email=user.email,
|
||||||
status=user.status,
|
|
||||||
score=user.score,
|
|
||||||
nickname=user.nickname,
|
nickname=user.nickname,
|
||||||
avatar=user.avatar,
|
avatar=user.avatar,
|
||||||
created_at=user.created_at,
|
created_at=user.created_at,
|
||||||
@@ -375,9 +494,24 @@ async def router_user_storage(
|
|||||||
if not group:
|
if not group:
|
||||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||||
|
|
||||||
# [TODO] 总空间加上用户购买的额外空间
|
# 查询用户所有未过期容量包的 size 总和
|
||||||
|
from datetime import datetime
|
||||||
|
from sqlalchemy import func, select, and_, or_
|
||||||
|
|
||||||
total: int = group.max_storage
|
now = datetime.now()
|
||||||
|
stmt = select(func.coalesce(func.sum(sqlmodels.StoragePack.size), 0)).where(
|
||||||
|
and_(
|
||||||
|
sqlmodels.StoragePack.user_id == user.id,
|
||||||
|
or_(
|
||||||
|
sqlmodels.StoragePack.expired_time.is_(None),
|
||||||
|
sqlmodels.StoragePack.expired_time > now,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await session.exec(stmt)
|
||||||
|
active_packs_total: int = result.scalar_one()
|
||||||
|
|
||||||
|
total: int = group.max_storage + active_packs_total
|
||||||
used: int = user.storage
|
used: int = user.storage
|
||||||
free: int = max(0, total - used)
|
free: int = max(0, total - used)
|
||||||
|
|
||||||
@@ -389,57 +523,177 @@ async def router_user_storage(
|
|||||||
|
|
||||||
@user_router.put(
|
@user_router.put(
|
||||||
path='/authn/start',
|
path='/authn/start',
|
||||||
summary='WebAuthn登录初始化',
|
summary='注册 Passkey 凭证(初始化)',
|
||||||
description='Initialize WebAuthn login for a user.',
|
description='Initialize Passkey registration for a user.',
|
||||||
dependencies=[Depends(auth_required)],
|
dependencies=[Depends(auth_required)],
|
||||||
)
|
)
|
||||||
async def router_user_authn_start(
|
async def router_user_authn_start(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
) -> sqlmodels.ResponseBase:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Initialize WebAuthn login for a user.
|
Passkey 注册初始化(需要登录)
|
||||||
|
|
||||||
Returns:
|
返回 WebAuthn registration options,前端使用 navigator.credentials.create() 处理。
|
||||||
dict: A dictionary containing WebAuthn initialization information.
|
|
||||||
|
错误处理:
|
||||||
|
- 400: Passkey 未启用
|
||||||
"""
|
"""
|
||||||
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
|
|
||||||
authn_setting = await sqlmodels.Setting.get(
|
authn_setting = await sqlmodels.Setting.get(
|
||||||
session,
|
session,
|
||||||
(sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled")
|
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTHN)
|
||||||
|
& (sqlmodels.Setting.name == "authn_enabled"),
|
||||||
)
|
)
|
||||||
if not authn_setting or authn_setting.value != "1":
|
if not authn_setting or authn_setting.value != "1":
|
||||||
raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
|
http_exceptions.raise_bad_request("Passkey 未启用")
|
||||||
|
|
||||||
site_url_setting = await sqlmodels.Setting.get(
|
rp_id, rp_name, _origin = await get_rp_config(session)
|
||||||
|
|
||||||
|
# 查询用户已注册凭证,用于 exclude_credentials
|
||||||
|
existing_authns: list[UserAuthn] = await UserAuthn.get(
|
||||||
session,
|
session,
|
||||||
(sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteURL")
|
UserAuthn.user_id == user.id,
|
||||||
|
fetch_mode="all",
|
||||||
)
|
)
|
||||||
site_title_setting = await sqlmodels.Setting.get(
|
exclude_credentials: list[PublicKeyCredentialDescriptor] = [
|
||||||
session,
|
PublicKeyCredentialDescriptor(
|
||||||
(sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteTitle")
|
id=authn.credential_id,
|
||||||
|
transports=authn.transports.split(",") if authn.transports else [],
|
||||||
)
|
)
|
||||||
|
for authn in existing_authns
|
||||||
|
]
|
||||||
|
|
||||||
options = generate_registration_options(
|
options = generate_registration_options(
|
||||||
rp_id=site_url_setting.value if site_url_setting else "",
|
rp_id=rp_id,
|
||||||
rp_name=site_title_setting.value if site_title_setting else "",
|
rp_name=rp_name,
|
||||||
user_name=user.email,
|
user_id=user.id.bytes,
|
||||||
user_display_name=user.nickname or user.email,
|
user_name=user.email or str(user.id),
|
||||||
|
user_display_name=user.nickname or user.email or str(user.id),
|
||||||
|
exclude_credentials=exclude_credentials if exclude_credentials else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return sqlmodels.ResponseBase(data=options_to_json_dict(options))
|
# 存储 challenge
|
||||||
|
await ChallengeStore.store(f"reg:{user.id}", options.challenge)
|
||||||
|
|
||||||
|
return json.loads(options_to_json(options))
|
||||||
|
|
||||||
|
|
||||||
@user_router.put(
|
@user_router.put(
|
||||||
path='/authn/finish',
|
path='/authn/finish',
|
||||||
summary='WebAuthn登录',
|
summary='注册 Passkey 凭证(完成)',
|
||||||
description='Finish WebAuthn login for a user.',
|
description='Finish Passkey registration for a user.',
|
||||||
dependencies=[Depends(auth_required)],
|
dependencies=[Depends(auth_required)],
|
||||||
|
status_code=201,
|
||||||
)
|
)
|
||||||
def router_user_authn_finish() -> sqlmodels.ResponseBase:
|
async def router_user_authn_finish(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
request: sqlmodels.AuthnFinishRequest,
|
||||||
|
) -> sqlmodels.AuthnDetailResponse:
|
||||||
"""
|
"""
|
||||||
Finish WebAuthn login for a user.
|
Passkey 注册完成(需要登录)
|
||||||
|
|
||||||
Returns:
|
接收前端 navigator.credentials.create() 返回的凭证数据,
|
||||||
dict: A dictionary containing WebAuthn login information.
|
验证后创建 UserAuthn 行 + AuthIdentity(provider=passkey)。
|
||||||
|
|
||||||
|
请求体:
|
||||||
|
- credential: navigator.credentials.create() 返回的 JSON 字符串
|
||||||
|
- name: 凭证名称(可选)
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 400: challenge 已过期或无效 / 验证失败
|
||||||
"""
|
"""
|
||||||
http_exceptions.raise_not_implemented()
|
# 取出 challenge(一次性)
|
||||||
|
challenge: bytes | None = await ChallengeStore.retrieve_and_delete(f"reg:{user.id}")
|
||||||
|
if challenge is None:
|
||||||
|
http_exceptions.raise_bad_request("注册会话已过期,请重新开始")
|
||||||
|
|
||||||
|
rp_id, _rp_name, origin = await get_rp_config(session)
|
||||||
|
|
||||||
|
# 验证注册响应
|
||||||
|
try:
|
||||||
|
verification = verify_registration_response(
|
||||||
|
credential=request.credential,
|
||||||
|
expected_challenge=challenge,
|
||||||
|
expected_rp_id=rp_id,
|
||||||
|
expected_origin=origin,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"WebAuthn 注册验证失败: {e}")
|
||||||
|
http_exceptions.raise_bad_request("Passkey 验证失败")
|
||||||
|
|
||||||
|
# 编码为 base64url 存储
|
||||||
|
credential_id_b64: str = bytes_to_base64url(verification.credential_id)
|
||||||
|
credential_public_key_b64: str = bytes_to_base64url(verification.credential_public_key)
|
||||||
|
|
||||||
|
# 提取 transports
|
||||||
|
credential_dict: dict = json.loads(request.credential)
|
||||||
|
response_dict: dict = credential_dict.get("response", {})
|
||||||
|
transports_list: list[str] = response_dict.get("transports", [])
|
||||||
|
transports_str: str | None = ",".join(transports_list) if transports_list else None
|
||||||
|
|
||||||
|
# 创建 UserAuthn 记录
|
||||||
|
authn = UserAuthn(
|
||||||
|
credential_id=credential_id_b64,
|
||||||
|
credential_public_key=credential_public_key_b64,
|
||||||
|
sign_count=verification.sign_count,
|
||||||
|
credential_device_type=verification.credential_device_type,
|
||||||
|
credential_backed_up=verification.credential_backed_up,
|
||||||
|
transports=transports_str,
|
||||||
|
name=request.name,
|
||||||
|
user_id=user.id,
|
||||||
|
)
|
||||||
|
authn = await authn.save(session)
|
||||||
|
|
||||||
|
# 创建 AuthIdentity(provider=passkey,identifier=credential_id_b64)
|
||||||
|
identity = AuthIdentity(
|
||||||
|
provider=AuthProviderType.PASSKEY,
|
||||||
|
identifier=credential_id_b64,
|
||||||
|
is_primary=False,
|
||||||
|
is_verified=True,
|
||||||
|
user_id=user.id,
|
||||||
|
)
|
||||||
|
identity = await identity.save(session)
|
||||||
|
|
||||||
|
return authn.to_detail_response()
|
||||||
|
|
||||||
|
|
||||||
|
@user_router.post(
|
||||||
|
path='/authn/options',
|
||||||
|
summary='获取 Passkey 登录 options(无需登录)',
|
||||||
|
description='Generate authentication options for Passkey login.',
|
||||||
|
)
|
||||||
|
async def router_user_authn_options(
|
||||||
|
session: SessionDep,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
获取 Passkey 登录的 authentication options(无需登录)
|
||||||
|
|
||||||
|
前端调用此端点获取 options 后使用 navigator.credentials.get() 处理。
|
||||||
|
使用 Discoverable Credentials 模式(空 allow_credentials),
|
||||||
|
由浏览器/平台决定展示哪些凭证。
|
||||||
|
|
||||||
|
返回值包含 ``challenge_token`` 字段,前端在登录请求中作为 ``identifier`` 传入。
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 400: Passkey 未启用
|
||||||
|
"""
|
||||||
|
authn_setting = await sqlmodels.Setting.get(
|
||||||
|
session,
|
||||||
|
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTHN)
|
||||||
|
& (sqlmodels.Setting.name == "authn_enabled"),
|
||||||
|
)
|
||||||
|
if not authn_setting or authn_setting.value != "1":
|
||||||
|
http_exceptions.raise_bad_request("Passkey 未启用")
|
||||||
|
|
||||||
|
rp_id, _rp_name, _origin = await get_rp_config(session)
|
||||||
|
|
||||||
|
options = generate_authentication_options(rp_id=rp_id)
|
||||||
|
|
||||||
|
# 生成 challenge_token 用于关联 challenge
|
||||||
|
challenge_token: str = str(uuid4())
|
||||||
|
await ChallengeStore.store(f"auth:{challenge_token}", options.challenge)
|
||||||
|
|
||||||
|
result: dict = json.loads(options_to_json(options))
|
||||||
|
result["challenge_token"] = challenge_token
|
||||||
|
return result
|
||||||
|
|||||||
@@ -1,33 +1,56 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||||
|
|
||||||
import sqlmodels
|
import sqlmodels
|
||||||
from middleware.auth import auth_required
|
from middleware.auth import auth_required
|
||||||
from middleware.dependencies import SessionDep
|
from middleware.dependencies import SessionDep
|
||||||
|
from sqlmodels import (
|
||||||
|
BUILTIN_DEFAULT_COLORS, ThemePreset, UserThemeUpdateRequest,
|
||||||
|
SettingOption, UserSettingUpdateRequest,
|
||||||
|
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
|
||||||
|
ChangePasswordRequest,
|
||||||
|
AuthnDetailResponse, AuthnRenameRequest,
|
||||||
|
PolicySummary,
|
||||||
|
)
|
||||||
|
from sqlmodels.color import ThemeColorsBase
|
||||||
|
from sqlmodels.user_authn import UserAuthn
|
||||||
from utils import JWT, Password, http_exceptions
|
from utils import JWT, Password, http_exceptions
|
||||||
|
from utils.password.pwd import PasswordStatus, TwoFactorResponse, TwoFactorVerifyRequest
|
||||||
|
from .file_viewers import file_viewers_router
|
||||||
|
|
||||||
user_settings_router = APIRouter(
|
user_settings_router = APIRouter(
|
||||||
prefix='/settings',
|
prefix='/settings',
|
||||||
tags=["user", "user_settings"],
|
tags=["user", "user_settings"],
|
||||||
dependencies=[Depends(auth_required)],
|
dependencies=[Depends(auth_required)],
|
||||||
)
|
)
|
||||||
|
user_settings_router.include_router(file_viewers_router)
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.get(
|
@user_settings_router.get(
|
||||||
path='/policies',
|
path='/policies',
|
||||||
summary='获取用户可选存储策略',
|
summary='获取用户可选存储策略',
|
||||||
description='Get user selectable storage policies.',
|
|
||||||
)
|
)
|
||||||
def router_user_settings_policies() -> sqlmodels.ResponseBase:
|
async def router_user_settings_policies(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
) -> list[PolicySummary]:
|
||||||
"""
|
"""
|
||||||
Get user selectable storage policies.
|
获取当前用户所在组可选的存储策略列表
|
||||||
|
|
||||||
Returns:
|
返回用户组关联的所有存储策略的摘要信息。
|
||||||
dict: A dictionary containing available storage policies for the user.
|
|
||||||
"""
|
"""
|
||||||
http_exceptions.raise_not_implemented()
|
group = await user.awaitable_attrs.group
|
||||||
|
await session.refresh(group, ['policies'])
|
||||||
|
return [
|
||||||
|
PolicySummary(
|
||||||
|
id=p.id, name=p.name, type=p.type,
|
||||||
|
server=p.server, max_size=p.max_size, is_private=p.is_private,
|
||||||
|
)
|
||||||
|
for p in group.policies
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.get(
|
@user_settings_router.get(
|
||||||
@@ -67,77 +90,309 @@ def router_user_settings_tasks() -> sqlmodels.ResponseBase:
|
|||||||
summary='获取当前用户设定',
|
summary='获取当前用户设定',
|
||||||
description='Get current user settings.',
|
description='Get current user settings.',
|
||||||
)
|
)
|
||||||
def router_user_settings(
|
async def router_user_settings(
|
||||||
|
session: SessionDep,
|
||||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
) -> sqlmodels.UserSettingResponse:
|
) -> sqlmodels.UserSettingResponse:
|
||||||
"""
|
"""
|
||||||
Get current user settings.
|
获取当前用户设定
|
||||||
|
|
||||||
Returns:
|
主题颜色合并策略:
|
||||||
dict: A dictionary containing the current user settings.
|
1. 用户有颜色快照(7个字段均有值)→ 直接使用快照
|
||||||
|
2. 否则查找默认预设 → 使用默认预设颜色
|
||||||
|
3. 无默认预设 → 使用内置默认值
|
||||||
"""
|
"""
|
||||||
|
# 计算主题颜色
|
||||||
|
has_snapshot = all([
|
||||||
|
user.color_primary, user.color_secondary, user.color_success,
|
||||||
|
user.color_info, user.color_warning, user.color_error, user.color_neutral,
|
||||||
|
])
|
||||||
|
if has_snapshot:
|
||||||
|
theme_colors = ThemeColorsBase(
|
||||||
|
primary=user.color_primary,
|
||||||
|
secondary=user.color_secondary,
|
||||||
|
success=user.color_success,
|
||||||
|
info=user.color_info,
|
||||||
|
warning=user.color_warning,
|
||||||
|
error=user.color_error,
|
||||||
|
neutral=user.color_neutral,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
default_preset: ThemePreset | None = await ThemePreset.get(
|
||||||
|
session, ThemePreset.is_default == True # noqa: E712
|
||||||
|
)
|
||||||
|
if default_preset:
|
||||||
|
theme_colors = ThemeColorsBase(
|
||||||
|
primary=default_preset.primary,
|
||||||
|
secondary=default_preset.secondary,
|
||||||
|
success=default_preset.success,
|
||||||
|
info=default_preset.info,
|
||||||
|
warning=default_preset.warning,
|
||||||
|
error=default_preset.error,
|
||||||
|
neutral=default_preset.neutral,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
theme_colors = BUILTIN_DEFAULT_COLORS
|
||||||
|
|
||||||
|
# 检查是否启用了两步验证(从 email_password AuthIdentity 的 extra_data 中读取)
|
||||||
|
has_two_factor = False
|
||||||
|
email_identity: AuthIdentity | None = await AuthIdentity.get(
|
||||||
|
session,
|
||||||
|
(AuthIdentity.user_id == user.id)
|
||||||
|
& (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
|
||||||
|
)
|
||||||
|
if email_identity and email_identity.extra_data:
|
||||||
|
import orjson
|
||||||
|
extra: dict = orjson.loads(email_identity.extra_data)
|
||||||
|
has_two_factor = bool(extra.get("two_factor"))
|
||||||
|
|
||||||
return sqlmodels.UserSettingResponse(
|
return sqlmodels.UserSettingResponse(
|
||||||
id=user.id,
|
id=user.id,
|
||||||
email=user.email,
|
email=user.email,
|
||||||
|
phone=user.phone,
|
||||||
nickname=user.nickname,
|
nickname=user.nickname,
|
||||||
created_at=user.created_at,
|
created_at=user.created_at,
|
||||||
group_name=user.group.name,
|
group_name=user.group.name,
|
||||||
language=user.language,
|
language=user.language,
|
||||||
timezone=user.timezone,
|
timezone=user.timezone,
|
||||||
group_expires=user.group_expires,
|
group_expires=user.group_expires,
|
||||||
two_factor=user.two_factor is not None,
|
two_factor=has_two_factor,
|
||||||
|
theme_preset_id=user.theme_preset_id,
|
||||||
|
theme_colors=theme_colors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.post(
|
@user_settings_router.post(
|
||||||
path='/avatar',
|
path='/avatar',
|
||||||
summary='从文件上传头像',
|
summary='从文件上传头像',
|
||||||
description='Upload user avatar from file.',
|
status_code=204,
|
||||||
dependencies=[Depends(auth_required)],
|
|
||||||
)
|
)
|
||||||
def router_user_settings_avatar() -> sqlmodels.ResponseBase:
|
async def router_user_settings_avatar(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Upload user avatar from file.
|
上传头像文件
|
||||||
|
|
||||||
Returns:
|
认证:JWT token
|
||||||
dict: A dictionary containing the result of the avatar upload.
|
请求体:multipart/form-data,file 字段
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 验证文件 MIME 类型(JPEG/PNG/GIF/WebP)
|
||||||
|
2. 验证文件大小 <= avatar_size 设置(默认 2MB)
|
||||||
|
3. 调用 Pillow 验证图片有效性并处理(居中裁剪、缩放 L/M/S)
|
||||||
|
4. 保存三种尺寸的 WebP 文件
|
||||||
|
5. 更新 User.avatar = "file"
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 400: 文件类型不支持 / 图片无法解析
|
||||||
|
- 413: 文件过大
|
||||||
"""
|
"""
|
||||||
http_exceptions.raise_not_implemented()
|
from service.avatar import (
|
||||||
|
ALLOWED_CONTENT_TYPES,
|
||||||
|
get_avatar_settings,
|
||||||
|
process_and_save_avatar,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证 MIME 类型
|
||||||
|
if file.content_type not in ALLOWED_CONTENT_TYPES:
|
||||||
|
http_exceptions.raise_bad_request(
|
||||||
|
f"不支持的图片格式,允许: {', '.join(ALLOWED_CONTENT_TYPES)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 读取并验证大小
|
||||||
|
_, max_upload_size, _, _, _ = await get_avatar_settings(session)
|
||||||
|
raw_bytes = await file.read()
|
||||||
|
if len(raw_bytes) > max_upload_size:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413,
|
||||||
|
detail=f"文件过大,最大允许 {max_upload_size} 字节",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理并保存(内部会验证图片有效性,无效抛出 ValueError)
|
||||||
|
try:
|
||||||
|
await process_and_save_avatar(session, user.id, raw_bytes)
|
||||||
|
except ValueError as e:
|
||||||
|
http_exceptions.raise_bad_request(str(e))
|
||||||
|
|
||||||
|
# 更新用户头像字段
|
||||||
|
user.avatar = "file"
|
||||||
|
user = await user.save(session)
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.put(
|
@user_settings_router.put(
|
||||||
path='/avatar',
|
path='/avatar',
|
||||||
summary='设定为Gravatar头像',
|
summary='设定为 Gravatar 头像',
|
||||||
description='Set user avatar to Gravatar.',
|
status_code=204,
|
||||||
dependencies=[Depends(auth_required)],
|
|
||||||
)
|
)
|
||||||
def router_user_settings_avatar_gravatar() -> sqlmodels.ResponseBase:
|
async def router_user_settings_avatar_gravatar(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Set user avatar to Gravatar.
|
将头像切换为 Gravatar
|
||||||
|
|
||||||
Returns:
|
认证:JWT token
|
||||||
dict: A dictionary containing the result of setting the Gravatar avatar.
|
|
||||||
|
流程:
|
||||||
|
1. 验证用户有邮箱(Gravatar 基于邮箱 MD5)
|
||||||
|
2. 如果当前是 FILE 头像,删除本地文件
|
||||||
|
3. 更新 User.avatar = "gravatar"
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 400: 用户没有邮箱
|
||||||
"""
|
"""
|
||||||
http_exceptions.raise_not_implemented()
|
from service.avatar import delete_avatar_files
|
||||||
|
|
||||||
|
if not user.email:
|
||||||
|
http_exceptions.raise_bad_request("Gravatar 需要邮箱,请先绑定邮箱")
|
||||||
|
|
||||||
|
if user.avatar == "file":
|
||||||
|
await delete_avatar_files(session, user.id)
|
||||||
|
|
||||||
|
user.avatar = "gravatar"
|
||||||
|
user = await user.save(session)
|
||||||
|
|
||||||
|
|
||||||
|
@user_settings_router.delete(
|
||||||
|
path='/avatar',
|
||||||
|
summary='重置头像为默认',
|
||||||
|
status_code=204,
|
||||||
|
)
|
||||||
|
async def router_user_settings_avatar_delete(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
重置头像为默认
|
||||||
|
|
||||||
|
认证:JWT token
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 如果当前是 FILE 头像,删除本地文件
|
||||||
|
2. 更新 User.avatar = "default"
|
||||||
|
"""
|
||||||
|
from service.avatar import delete_avatar_files
|
||||||
|
|
||||||
|
if user.avatar == "file":
|
||||||
|
await delete_avatar_files(session, user.id)
|
||||||
|
|
||||||
|
user.avatar = "default"
|
||||||
|
user = await user.save(session)
|
||||||
|
|
||||||
|
|
||||||
|
@user_settings_router.patch(
|
||||||
|
path='/theme',
|
||||||
|
summary='更新用户主题设置',
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
)
|
||||||
|
async def router_user_settings_theme(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
request: UserThemeUpdateRequest,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
更新用户主题设置
|
||||||
|
|
||||||
|
请求体(均可选):
|
||||||
|
- theme_preset_id: 主题预设UUID
|
||||||
|
- theme_colors: 颜色配置对象(写入颜色快照)
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 指定的主题预设不存在
|
||||||
|
"""
|
||||||
|
# 验证 preset_id 存在性
|
||||||
|
if request.theme_preset_id is not None:
|
||||||
|
preset: ThemePreset | None = await ThemePreset.get(
|
||||||
|
session, ThemePreset.id == request.theme_preset_id
|
||||||
|
)
|
||||||
|
if not preset:
|
||||||
|
http_exceptions.raise_not_found("主题预设不存在")
|
||||||
|
user.theme_preset_id = request.theme_preset_id
|
||||||
|
|
||||||
|
# 将颜色解构到快照列
|
||||||
|
if request.theme_colors is not None:
|
||||||
|
user.color_primary = request.theme_colors.primary
|
||||||
|
user.color_secondary = request.theme_colors.secondary
|
||||||
|
user.color_success = request.theme_colors.success
|
||||||
|
user.color_info = request.theme_colors.info
|
||||||
|
user.color_warning = request.theme_colors.warning
|
||||||
|
user.color_error = request.theme_colors.error
|
||||||
|
user.color_neutral = request.theme_colors.neutral
|
||||||
|
|
||||||
|
user = await user.save(session)
|
||||||
|
|
||||||
|
|
||||||
|
@user_settings_router.patch(
|
||||||
|
path='/password',
|
||||||
|
summary='修改密码',
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
)
|
||||||
|
async def router_user_settings_change_password(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
request: ChangePasswordRequest,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
修改当前用户密码
|
||||||
|
|
||||||
|
请求体:
|
||||||
|
- old_password: 当前密码
|
||||||
|
- new_password: 新密码(至少 8 位)
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 400: 用户没有邮箱密码认证身份
|
||||||
|
- 403: 当前密码错误
|
||||||
|
"""
|
||||||
|
email_identity: AuthIdentity | None = await AuthIdentity.get(
|
||||||
|
session,
|
||||||
|
(AuthIdentity.user_id == user.id)
|
||||||
|
& (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
|
||||||
|
)
|
||||||
|
if not email_identity or not email_identity.credential:
|
||||||
|
http_exceptions.raise_bad_request("未找到邮箱密码认证身份")
|
||||||
|
|
||||||
|
verify_result = Password.verify(email_identity.credential, request.old_password)
|
||||||
|
if verify_result == PasswordStatus.INVALID:
|
||||||
|
http_exceptions.raise_forbidden("当前密码错误")
|
||||||
|
|
||||||
|
email_identity.credential = Password.hash(request.new_password)
|
||||||
|
email_identity = await email_identity.save(session)
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.patch(
|
@user_settings_router.patch(
|
||||||
path='/{option}',
|
path='/{option}',
|
||||||
summary='更新用户设定',
|
summary='更新用户设定',
|
||||||
description='Update user settings.',
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
dependencies=[Depends(auth_required)],
|
|
||||||
)
|
)
|
||||||
def router_user_settings_patch(option: str) -> sqlmodels.ResponseBase:
|
async def router_user_settings_patch(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
option: SettingOption,
|
||||||
|
request: UserSettingUpdateRequest,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Update user settings.
|
更新单个用户设置项
|
||||||
|
|
||||||
Args:
|
路径参数:
|
||||||
option (str): The setting option to update.
|
- option: 设置项名称(nickname / language / timezone)
|
||||||
|
|
||||||
Returns:
|
请求体:
|
||||||
dict: A dictionary containing the result of the settings update.
|
- 包含与 option 同名的字段及其新值
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 422: 无效的 option 或字段值不符合约束
|
||||||
|
- 400: 必填字段值缺失
|
||||||
"""
|
"""
|
||||||
http_exceptions.raise_not_implemented()
|
value = getattr(request, option.value)
|
||||||
|
|
||||||
|
# language / timezone 不允许设为 null
|
||||||
|
if value is None and option != SettingOption.NICKNAME:
|
||||||
|
http_exceptions.raise_bad_request(f"设置项 {option.value} 不允许为空")
|
||||||
|
|
||||||
|
setattr(user, option.value, value)
|
||||||
|
user = await user.save(session)
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.get(
|
@user_settings_router.get(
|
||||||
@@ -148,17 +403,13 @@ def router_user_settings_patch(option: str) -> sqlmodels.ResponseBase:
|
|||||||
)
|
)
|
||||||
async def router_user_settings_2fa(
|
async def router_user_settings_2fa(
|
||||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
) -> sqlmodels.ResponseBase:
|
) -> TwoFactorResponse:
|
||||||
"""
|
"""
|
||||||
Get two-factor authentication initialization information.
|
获取两步验证初始化信息
|
||||||
|
|
||||||
Returns:
|
返回 setup_token(用于后续验证请求)和 uri(用于生成二维码)。
|
||||||
dict: A dictionary containing two-factor authentication setup information.
|
|
||||||
"""
|
"""
|
||||||
|
return await Password.generate_totp(name=user.email or str(user.id))
|
||||||
return sqlmodels.ResponseBase(
|
|
||||||
data=await Password.generate_totp(user.email)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@user_settings_router.post(
|
@user_settings_router.post(
|
||||||
@@ -166,38 +417,276 @@ async def router_user_settings_2fa(
|
|||||||
summary='启用两步验证',
|
summary='启用两步验证',
|
||||||
description='Enable two-factor authentication.',
|
description='Enable two-factor authentication.',
|
||||||
dependencies=[Depends(auth_required)],
|
dependencies=[Depends(auth_required)],
|
||||||
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def router_user_settings_2fa_enable(
|
async def router_user_settings_2fa_enable(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
setup_token: str,
|
request: TwoFactorVerifyRequest,
|
||||||
code: str,
|
) -> None:
|
||||||
) -> sqlmodels.ResponseBase:
|
|
||||||
"""
|
"""
|
||||||
Enable two-factor authentication for the user.
|
启用两步验证
|
||||||
|
|
||||||
Returns:
|
将 2FA secret 存储到 email_password AuthIdentity 的 extra_data 中。
|
||||||
dict: A dictionary containing the result of enabling two-factor authentication.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 解包 Token,设置有效期(例如 600秒)
|
secret = serializer.loads(request.setup_token, salt="2fa-setup-salt", max_age=600)
|
||||||
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
|
|
||||||
except SignatureExpired:
|
except SignatureExpired:
|
||||||
raise HTTPException(status_code=400, detail="Setup session expired")
|
raise HTTPException(status_code=400, detail="Setup session expired")
|
||||||
except BadSignature:
|
except BadSignature:
|
||||||
raise HTTPException(status_code=400, detail="Invalid token")
|
raise HTTPException(status_code=400, detail="Invalid token")
|
||||||
|
|
||||||
# 2. 验证用户输入的 6 位验证码
|
if Password.verify_totp(secret, request.code) != PasswordStatus.VALID:
|
||||||
if not Password.verify_totp(secret, code):
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
||||||
|
|
||||||
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
|
# 将 secret 存储到 AuthIdentity.extra_data 中
|
||||||
user.two_factor = secret
|
email_identity: AuthIdentity | None = await AuthIdentity.get(
|
||||||
user = await user.save(session)
|
session,
|
||||||
|
(AuthIdentity.user_id == user.id)
|
||||||
return sqlmodels.ResponseBase(
|
& (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
|
||||||
data={"message": "Two-factor authentication enabled successfully"}
|
|
||||||
)
|
)
|
||||||
|
if not email_identity:
|
||||||
|
raise HTTPException(status_code=400, detail="未找到邮箱密码认证身份")
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
extra: dict = orjson.loads(email_identity.extra_data) if email_identity.extra_data else {}
|
||||||
|
extra["two_factor"] = secret
|
||||||
|
email_identity.extra_data = orjson.dumps(extra).decode('utf-8')
|
||||||
|
email_identity = await email_identity.save(session)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 认证身份管理 ====================
|
||||||
|
|
||||||
|
@user_settings_router.get(
|
||||||
|
path='/identities',
|
||||||
|
summary='列出已绑定的认证身份',
|
||||||
|
)
|
||||||
|
async def router_user_settings_identities(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
) -> list[AuthIdentityResponse]:
|
||||||
|
"""
|
||||||
|
列出当前用户已绑定的所有认证身份
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 认证身份列表,包含 provider、identifier、display_name 等
|
||||||
|
"""
|
||||||
|
identities: list[AuthIdentity] = await AuthIdentity.get(
|
||||||
|
session,
|
||||||
|
AuthIdentity.user_id == user.id,
|
||||||
|
fetch_mode="all",
|
||||||
|
)
|
||||||
|
return [identity.to_response() for identity in identities]
|
||||||
|
|
||||||
|
|
||||||
|
@user_settings_router.post(
|
||||||
|
path='/identity',
|
||||||
|
summary='绑定新的认证身份',
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
)
|
||||||
|
async def router_user_settings_bind_identity(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
request: BindIdentityRequest,
|
||||||
|
) -> AuthIdentityResponse:
|
||||||
|
"""
|
||||||
|
绑定新的登录方式
|
||||||
|
|
||||||
|
请求体:
|
||||||
|
- provider: 提供者类型
|
||||||
|
- identifier: 标识符(邮箱 / 手机号 / OAuth code)
|
||||||
|
- credential: 凭证(密码、验证码等)
|
||||||
|
- redirect_uri: OAuth 回调地址(可选)
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 400: provider 未启用
|
||||||
|
- 409: 该身份已被其他用户绑定
|
||||||
|
"""
|
||||||
|
# 检查是否已被绑定
|
||||||
|
existing = await AuthIdentity.get(
|
||||||
|
session,
|
||||||
|
(AuthIdentity.provider == request.provider)
|
||||||
|
& (AuthIdentity.identifier == request.identifier),
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise HTTPException(status_code=409, detail="该身份已被绑定")
|
||||||
|
|
||||||
|
# 处理密码类型的凭证
|
||||||
|
credential: str | None = None
|
||||||
|
if request.provider == AuthProviderType.EMAIL_PASSWORD and request.credential:
|
||||||
|
credential = Password.hash(request.credential)
|
||||||
|
|
||||||
|
identity = AuthIdentity(
|
||||||
|
provider=request.provider,
|
||||||
|
identifier=request.identifier,
|
||||||
|
credential=credential,
|
||||||
|
is_primary=False,
|
||||||
|
is_verified=False,
|
||||||
|
user_id=user.id,
|
||||||
|
)
|
||||||
|
identity = await identity.save(session)
|
||||||
|
return identity.to_response()
|
||||||
|
|
||||||
|
|
||||||
|
@user_settings_router.delete(
|
||||||
|
path='/identity/{identity_id}',
|
||||||
|
summary='解绑认证身份',
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
)
|
||||||
|
async def router_user_settings_unbind_identity(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
identity_id: UUID,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
解绑一个认证身份
|
||||||
|
|
||||||
|
约束:
|
||||||
|
- 不能解绑最后一个身份
|
||||||
|
- 站长配置强制绑定邮箱/手机号时,不能解绑对应身份
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 身份不存在或不属于当前用户
|
||||||
|
- 400: 不能解绑最后一个身份 / 不能解绑强制绑定的身份
|
||||||
|
"""
|
||||||
|
# 查找目标身份
|
||||||
|
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||||
|
session,
|
||||||
|
(AuthIdentity.id == identity_id) & (AuthIdentity.user_id == user.id),
|
||||||
|
)
|
||||||
|
if not identity:
|
||||||
|
http_exceptions.raise_not_found("认证身份不存在")
|
||||||
|
|
||||||
|
# 检查是否为最后一个身份
|
||||||
|
all_identities: list[AuthIdentity] = await AuthIdentity.get(
|
||||||
|
session,
|
||||||
|
AuthIdentity.user_id == user.id,
|
||||||
|
fetch_mode="all",
|
||||||
|
)
|
||||||
|
if len(all_identities) <= 1:
|
||||||
|
http_exceptions.raise_bad_request("不能解绑最后一个认证身份")
|
||||||
|
|
||||||
|
# 检查强制绑定约束
|
||||||
|
if identity.provider == AuthProviderType.EMAIL_PASSWORD:
|
||||||
|
email_required_setting = await sqlmodels.Setting.get(
|
||||||
|
session,
|
||||||
|
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||||
|
& (sqlmodels.Setting.name == "auth_email_binding_required"),
|
||||||
|
)
|
||||||
|
if email_required_setting and email_required_setting.value == "1":
|
||||||
|
http_exceptions.raise_bad_request("站长要求必须绑定邮箱,不能解绑")
|
||||||
|
|
||||||
|
if identity.provider == AuthProviderType.PHONE_SMS:
|
||||||
|
phone_required_setting = await sqlmodels.Setting.get(
|
||||||
|
session,
|
||||||
|
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
|
||||||
|
& (sqlmodels.Setting.name == "auth_phone_binding_required"),
|
||||||
|
)
|
||||||
|
if phone_required_setting and phone_required_setting.value == "1":
|
||||||
|
http_exceptions.raise_bad_request("站长要求必须绑定手机号,不能解绑")
|
||||||
|
|
||||||
|
await AuthIdentity.delete(session, identity)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== WebAuthn 凭证管理 ====================
|
||||||
|
|
||||||
|
@user_settings_router.get(
|
||||||
|
path='/authns',
|
||||||
|
summary='列出用户所有 WebAuthn 凭证',
|
||||||
|
)
|
||||||
|
async def router_user_settings_authns(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
) -> list[AuthnDetailResponse]:
|
||||||
|
"""
|
||||||
|
列出当前用户所有已注册的 WebAuthn 凭证
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 凭证列表,包含 credential_id、name、device_type 等
|
||||||
|
"""
|
||||||
|
authns: list[UserAuthn] = await UserAuthn.get(
|
||||||
|
session,
|
||||||
|
UserAuthn.user_id == user.id,
|
||||||
|
fetch_mode="all",
|
||||||
|
)
|
||||||
|
return [authn.to_detail_response() for authn in authns]
|
||||||
|
|
||||||
|
|
||||||
|
@user_settings_router.patch(
|
||||||
|
path='/authn/{authn_id}',
|
||||||
|
summary='重命名 WebAuthn 凭证',
|
||||||
|
)
|
||||||
|
async def router_user_settings_rename_authn(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
authn_id: int,
|
||||||
|
request: AuthnRenameRequest,
|
||||||
|
) -> AuthnDetailResponse:
|
||||||
|
"""
|
||||||
|
重命名一个 WebAuthn 凭证
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 凭证不存在或不属于当前用户
|
||||||
|
"""
|
||||||
|
authn: UserAuthn | None = await UserAuthn.get(
|
||||||
|
session,
|
||||||
|
(UserAuthn.id == authn_id) & (UserAuthn.user_id == user.id),
|
||||||
|
)
|
||||||
|
if not authn:
|
||||||
|
http_exceptions.raise_not_found("WebAuthn 凭证不存在")
|
||||||
|
|
||||||
|
authn.name = request.name
|
||||||
|
authn = await authn.save(session)
|
||||||
|
return authn.to_detail_response()
|
||||||
|
|
||||||
|
|
||||||
|
@user_settings_router.delete(
|
||||||
|
path='/authn/{authn_id}',
|
||||||
|
summary='删除 WebAuthn 凭证',
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
)
|
||||||
|
async def router_user_settings_delete_authn(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||||
|
authn_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
删除一个 WebAuthn 凭证
|
||||||
|
|
||||||
|
同时删除对应的 AuthIdentity(provider=passkey) 记录。
|
||||||
|
如果这是用户最后一个认证身份,拒绝删除。
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 凭证不存在或不属于当前用户
|
||||||
|
- 400: 不能删除最后一个认证身份
|
||||||
|
"""
|
||||||
|
authn: UserAuthn | None = await UserAuthn.get(
|
||||||
|
session,
|
||||||
|
(UserAuthn.id == authn_id) & (UserAuthn.user_id == user.id),
|
||||||
|
)
|
||||||
|
if not authn:
|
||||||
|
http_exceptions.raise_not_found("WebAuthn 凭证不存在")
|
||||||
|
|
||||||
|
# 检查是否为最后一个认证身份
|
||||||
|
all_identities: list[AuthIdentity] = await AuthIdentity.get(
|
||||||
|
session,
|
||||||
|
AuthIdentity.user_id == user.id,
|
||||||
|
fetch_mode="all",
|
||||||
|
)
|
||||||
|
if len(all_identities) <= 1:
|
||||||
|
http_exceptions.raise_bad_request("不能删除最后一个认证身份")
|
||||||
|
|
||||||
|
# 删除对应的 AuthIdentity
|
||||||
|
passkey_identity: AuthIdentity | None = await AuthIdentity.get(
|
||||||
|
session,
|
||||||
|
(AuthIdentity.provider == AuthProviderType.PASSKEY)
|
||||||
|
& (AuthIdentity.identifier == authn.credential_id)
|
||||||
|
& (AuthIdentity.user_id == user.id),
|
||||||
|
)
|
||||||
|
if passkey_identity:
|
||||||
|
await AuthIdentity.delete(session, passkey_identity, commit=False)
|
||||||
|
|
||||||
|
# 删除 UserAuthn
|
||||||
|
await UserAuthn.delete(session, authn)
|
||||||
|
|||||||
146
routers/api/v1/user/settings/file_viewers/__init__.py
Normal file
146
routers/api/v1/user/settings/file_viewers/__init__.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""
|
||||||
|
用户文件查看器偏好设置端点
|
||||||
|
|
||||||
|
提供用户"始终使用"默认查看器的增删查功能。
|
||||||
|
"""
|
||||||
|
from typing import Annotated
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, status
|
||||||
|
from sqlalchemy import and_
|
||||||
|
|
||||||
|
from middleware.auth import auth_required
|
||||||
|
from middleware.dependencies import SessionDep
|
||||||
|
from sqlmodels import (
|
||||||
|
FileApp,
|
||||||
|
FileAppExtension,
|
||||||
|
SetDefaultViewerRequest,
|
||||||
|
User,
|
||||||
|
UserFileAppDefault,
|
||||||
|
UserFileAppDefaultResponse,
|
||||||
|
)
|
||||||
|
from utils import http_exceptions
|
||||||
|
|
||||||
|
file_viewers_router = APIRouter(
|
||||||
|
prefix='/file-viewers',
|
||||||
|
tags=["user", "user_settings", "file_viewers"],
|
||||||
|
dependencies=[Depends(auth_required)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@file_viewers_router.put(
|
||||||
|
path='/default',
|
||||||
|
summary='设置默认查看器',
|
||||||
|
description='为指定扩展名设置"始终使用"的查看器。',
|
||||||
|
)
|
||||||
|
async def set_default_viewer(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
request: SetDefaultViewerRequest,
|
||||||
|
) -> UserFileAppDefaultResponse:
|
||||||
|
"""
|
||||||
|
设置默认查看器端点
|
||||||
|
|
||||||
|
如果用户已有该扩展名的默认设置,则更新;否则创建新记录。
|
||||||
|
|
||||||
|
认证:JWT token 必填
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 应用不存在
|
||||||
|
- 400: 应用不支持该扩展名
|
||||||
|
"""
|
||||||
|
# 规范化扩展名
|
||||||
|
normalized_ext = request.extension.lower().strip().lstrip('.')
|
||||||
|
|
||||||
|
# 验证应用存在
|
||||||
|
app: FileApp | None = await FileApp.get(session, FileApp.id == request.app_id)
|
||||||
|
if not app:
|
||||||
|
http_exceptions.raise_not_found("应用不存在")
|
||||||
|
|
||||||
|
# 验证应用支持该扩展名
|
||||||
|
ext_record: FileAppExtension | None = await FileAppExtension.get(
|
||||||
|
session,
|
||||||
|
and_(
|
||||||
|
FileAppExtension.app_id == app.id,
|
||||||
|
FileAppExtension.extension == normalized_ext,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not ext_record:
|
||||||
|
http_exceptions.raise_bad_request("该应用不支持此扩展名")
|
||||||
|
|
||||||
|
# 查找已有记录
|
||||||
|
existing: UserFileAppDefault | None = await UserFileAppDefault.get(
|
||||||
|
session,
|
||||||
|
and_(
|
||||||
|
UserFileAppDefault.user_id == user.id,
|
||||||
|
UserFileAppDefault.extension == normalized_ext,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
existing.app_id = request.app_id
|
||||||
|
existing = await existing.save(session, load=UserFileAppDefault.app)
|
||||||
|
return existing.to_response()
|
||||||
|
else:
|
||||||
|
new_default = UserFileAppDefault(
|
||||||
|
user_id=user.id,
|
||||||
|
extension=normalized_ext,
|
||||||
|
app_id=request.app_id,
|
||||||
|
)
|
||||||
|
new_default = await new_default.save(session, load=UserFileAppDefault.app)
|
||||||
|
return new_default.to_response()
|
||||||
|
|
||||||
|
|
||||||
|
@file_viewers_router.get(
|
||||||
|
path='/defaults',
|
||||||
|
summary='列出所有默认查看器设置',
|
||||||
|
description='获取当前用户所有"始终使用"的查看器偏好。',
|
||||||
|
)
|
||||||
|
async def list_default_viewers(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
) -> list[UserFileAppDefaultResponse]:
|
||||||
|
"""
|
||||||
|
列出所有默认查看器设置端点
|
||||||
|
|
||||||
|
认证:JWT token 必填
|
||||||
|
"""
|
||||||
|
defaults: list[UserFileAppDefault] = await UserFileAppDefault.get(
|
||||||
|
session,
|
||||||
|
UserFileAppDefault.user_id == user.id,
|
||||||
|
fetch_mode="all",
|
||||||
|
load=UserFileAppDefault.app,
|
||||||
|
)
|
||||||
|
return [d.to_response() for d in defaults]
|
||||||
|
|
||||||
|
|
||||||
|
@file_viewers_router.delete(
|
||||||
|
path='/default/{default_id}',
|
||||||
|
summary='撤销默认查看器设置',
|
||||||
|
description='删除指定的"始终使用"偏好。',
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
)
|
||||||
|
async def delete_default_viewer(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
default_id: UUID,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
撤销默认查看器设置端点
|
||||||
|
|
||||||
|
认证:JWT token 必填
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 404: 记录不存在或不属于当前用户
|
||||||
|
"""
|
||||||
|
existing: UserFileAppDefault | None = await UserFileAppDefault.get(
|
||||||
|
session,
|
||||||
|
and_(
|
||||||
|
UserFileAppDefault.id == default_id,
|
||||||
|
UserFileAppDefault.user_id == user.id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not existing:
|
||||||
|
http_exceptions.raise_not_found("默认设置不存在")
|
||||||
|
|
||||||
|
await UserFileAppDefault.delete(session, existing)
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
from fastapi import APIRouter, Depends
|
|
||||||
|
|
||||||
from middleware.auth import auth_required
|
|
||||||
from sqlmodels import ResponseBase
|
|
||||||
from utils import http_exceptions
|
|
||||||
|
|
||||||
vas_router = APIRouter(
|
|
||||||
prefix="/vas",
|
|
||||||
tags=["vas"]
|
|
||||||
)
|
|
||||||
|
|
||||||
@vas_router.get(
|
|
||||||
path='/pack',
|
|
||||||
summary='获取容量包及配额信息',
|
|
||||||
description='Get information about storage packs and quotas.',
|
|
||||||
dependencies=[Depends(auth_required)]
|
|
||||||
)
|
|
||||||
def router_vas_pack() -> ResponseBase:
|
|
||||||
"""
|
|
||||||
Get information about storage packs and quotas.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ResponseBase: A model containing the response data for storage packs and quotas.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@vas_router.get(
|
|
||||||
path='/product',
|
|
||||||
summary='获取商品信息,同时返回支付信息',
|
|
||||||
description='Get product information along with payment details.',
|
|
||||||
dependencies=[Depends(auth_required)]
|
|
||||||
)
|
|
||||||
def router_vas_product() -> ResponseBase:
|
|
||||||
"""
|
|
||||||
Get product information along with payment details.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ResponseBase: A model containing the response data for products and payment information.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@vas_router.post(
|
|
||||||
path='/order',
|
|
||||||
summary='新建支付订单',
|
|
||||||
description='Create an order for a product.',
|
|
||||||
dependencies=[Depends(auth_required)]
|
|
||||||
)
|
|
||||||
def router_vas_order() -> ResponseBase:
|
|
||||||
"""
|
|
||||||
Create an order for a product.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ResponseBase: A model containing the response data for the created order.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@vas_router.get(
|
|
||||||
path='/order/{id}',
|
|
||||||
summary='查询订单状态',
|
|
||||||
description='Get information about a specific payment order by ID.',
|
|
||||||
dependencies=[Depends(auth_required)]
|
|
||||||
)
|
|
||||||
def router_vas_order_get(id: str) -> ResponseBase:
|
|
||||||
"""
|
|
||||||
Get information about a specific payment order by ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
id (str): The ID of the order to retrieve information for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ResponseBase: A model containing the response data for the specified order.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@vas_router.get(
|
|
||||||
path='/redeem',
|
|
||||||
summary='获取兑换码信息',
|
|
||||||
description='Get information about a specific redemption code.',
|
|
||||||
dependencies=[Depends(auth_required)]
|
|
||||||
)
|
|
||||||
def router_vas_redeem(code: str) -> ResponseBase:
|
|
||||||
"""
|
|
||||||
Get information about a specific redemption code.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
code (str): The redemption code to retrieve information for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ResponseBase: A model containing the response data for the specified redemption code.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@vas_router.post(
|
|
||||||
path='/redeem',
|
|
||||||
summary='执行兑换',
|
|
||||||
description='Redeem a redemption code for a product or service.',
|
|
||||||
dependencies=[Depends(auth_required)]
|
|
||||||
)
|
|
||||||
def router_vas_redeem_post() -> ResponseBase:
|
|
||||||
"""
|
|
||||||
Redeem a redemption code for a product or service.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ResponseBase: A model containing the response data for the redeemed code.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
@@ -1,110 +1,207 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
|
from loguru import logger as l
|
||||||
|
|
||||||
from middleware.auth import auth_required
|
from middleware.auth import auth_required
|
||||||
from sqlmodels import ResponseBase
|
from middleware.dependencies import SessionDep
|
||||||
|
from sqlmodels import (
|
||||||
|
Object,
|
||||||
|
User,
|
||||||
|
WebDAV,
|
||||||
|
WebDAVAccountResponse,
|
||||||
|
WebDAVCreateRequest,
|
||||||
|
WebDAVUpdateRequest,
|
||||||
|
)
|
||||||
|
from service.redis.webdav_auth_cache import WebDAVAuthCache
|
||||||
from utils import http_exceptions
|
from utils import http_exceptions
|
||||||
|
from utils.password.pwd import Password
|
||||||
|
|
||||||
# WebDAV 管理路由
|
|
||||||
webdav_router = APIRouter(
|
webdav_router = APIRouter(
|
||||||
prefix='/webdav',
|
prefix='/webdav',
|
||||||
tags=["webdav"],
|
tags=["webdav"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_webdav_enabled(user: User) -> None:
|
||||||
|
"""检查用户组是否启用了 WebDAV 功能"""
|
||||||
|
if not user.group.web_dav_enabled:
|
||||||
|
http_exceptions.raise_forbidden("WebDAV 功能未启用")
|
||||||
|
|
||||||
|
|
||||||
|
def _to_response(account: WebDAV) -> WebDAVAccountResponse:
|
||||||
|
"""将 WebDAV 数据库模型转换为响应 DTO"""
|
||||||
|
return WebDAVAccountResponse(
|
||||||
|
id=account.id,
|
||||||
|
name=account.name,
|
||||||
|
root=account.root,
|
||||||
|
readonly=account.readonly,
|
||||||
|
use_proxy=account.use_proxy,
|
||||||
|
created_at=str(account.created_at),
|
||||||
|
updated_at=str(account.updated_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@webdav_router.get(
|
@webdav_router.get(
|
||||||
path='/accounts',
|
path='/accounts',
|
||||||
summary='获取账号信息',
|
summary='获取账号列表',
|
||||||
description='Get account information for WebDAV.',
|
|
||||||
dependencies=[Depends(auth_required)],
|
|
||||||
)
|
)
|
||||||
def router_webdav_accounts() -> ResponseBase:
|
async def list_accounts(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
) -> list[WebDAVAccountResponse]:
|
||||||
"""
|
"""
|
||||||
Get account information for WebDAV.
|
列出当前用户所有 WebDAV 账户
|
||||||
|
|
||||||
Returns:
|
认证:JWT Bearer Token
|
||||||
ResponseBase: A model containing the response data for the account information.
|
|
||||||
"""
|
"""
|
||||||
http_exceptions.raise_not_implemented()
|
_check_webdav_enabled(user)
|
||||||
|
user_id: UUID = user.id
|
||||||
|
|
||||||
|
accounts: list[WebDAV] = await WebDAV.get(
|
||||||
|
session,
|
||||||
|
WebDAV.user_id == user_id,
|
||||||
|
fetch_mode="all",
|
||||||
|
)
|
||||||
|
return [_to_response(a) for a in accounts]
|
||||||
|
|
||||||
|
|
||||||
@webdav_router.post(
|
@webdav_router.post(
|
||||||
path='/accounts',
|
path='/accounts',
|
||||||
summary='新建账号',
|
summary='创建账号',
|
||||||
description='Create a new WebDAV account.',
|
status_code=201,
|
||||||
dependencies=[Depends(auth_required)],
|
|
||||||
)
|
)
|
||||||
def router_webdav_create_account() -> ResponseBase:
|
async def create_account(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
request: WebDAVCreateRequest,
|
||||||
|
) -> WebDAVAccountResponse:
|
||||||
"""
|
"""
|
||||||
Create a new WebDAV account.
|
创建 WebDAV 账户
|
||||||
|
|
||||||
Returns:
|
认证:JWT Bearer Token
|
||||||
ResponseBase: A model containing the response data for the created account.
|
|
||||||
|
错误处理:
|
||||||
|
- 403: WebDAV 功能未启用
|
||||||
|
- 400: 根目录路径不存在或不是目录
|
||||||
|
- 409: 账户名已存在
|
||||||
"""
|
"""
|
||||||
http_exceptions.raise_not_implemented()
|
_check_webdav_enabled(user)
|
||||||
|
user_id: UUID = user.id
|
||||||
|
|
||||||
@webdav_router.delete(
|
# 验证账户名唯一
|
||||||
path='/accounts/{id}',
|
existing = await WebDAV.get(
|
||||||
summary='删除账号',
|
session,
|
||||||
description='Delete a WebDAV account by its ID.',
|
(WebDAV.name == request.name) & (WebDAV.user_id == user_id),
|
||||||
dependencies=[Depends(auth_required)],
|
)
|
||||||
)
|
if existing:
|
||||||
def router_webdav_delete_account(id: str) -> ResponseBase:
|
http_exceptions.raise_conflict("账户名已存在")
|
||||||
"""
|
|
||||||
Delete a WebDAV account by its ID.
|
|
||||||
|
|
||||||
Args:
|
# 验证 root 路径存在且为目录
|
||||||
id (str): The ID of the account to be deleted.
|
root_obj = await Object.get_by_path(session, user_id, request.root)
|
||||||
|
if not root_obj or not root_obj.is_folder:
|
||||||
|
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
|
||||||
|
|
||||||
Returns:
|
# 创建账户
|
||||||
ResponseBase: A model containing the response data for the deletion operation.
|
account = WebDAV(
|
||||||
"""
|
name=request.name,
|
||||||
http_exceptions.raise_not_implemented()
|
password=Password.hash(request.password),
|
||||||
|
root=request.root,
|
||||||
|
readonly=request.readonly,
|
||||||
|
use_proxy=request.use_proxy,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
account = await account.save(session)
|
||||||
|
|
||||||
@webdav_router.post(
|
l.info(f"用户 {user_id} 创建 WebDAV 账户: {account.name}")
|
||||||
path='/mount',
|
return _to_response(account)
|
||||||
summary='新建目录挂载',
|
|
||||||
description='Create a new WebDAV mount point.',
|
|
||||||
dependencies=[Depends(auth_required)],
|
|
||||||
)
|
|
||||||
def router_webdav_create_mount() -> ResponseBase:
|
|
||||||
"""
|
|
||||||
Create a new WebDAV mount point.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ResponseBase: A model containing the response data for the created mount point.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@webdav_router.delete(
|
|
||||||
path='/mount/{id}',
|
|
||||||
summary='删除目录挂载',
|
|
||||||
description='Delete a WebDAV mount point by its ID.',
|
|
||||||
dependencies=[Depends(auth_required)],
|
|
||||||
)
|
|
||||||
def router_webdav_delete_mount(id: str) -> ResponseBase:
|
|
||||||
"""
|
|
||||||
Delete a WebDAV mount point by its ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
id (str): The ID of the mount point to be deleted.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ResponseBase: A model containing the response data for the deletion operation.
|
|
||||||
"""
|
|
||||||
http_exceptions.raise_not_implemented()
|
|
||||||
|
|
||||||
@webdav_router.patch(
|
@webdav_router.patch(
|
||||||
path='accounts/{id}',
|
path='/accounts/{account_id}',
|
||||||
summary='更新账号信息',
|
summary='更新账号',
|
||||||
description='Update WebDAV account information by ID.',
|
|
||||||
dependencies=[Depends(auth_required)],
|
|
||||||
)
|
)
|
||||||
def router_webdav_update_account(id: str) -> ResponseBase:
|
async def update_account(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
account_id: int,
|
||||||
|
request: WebDAVUpdateRequest,
|
||||||
|
) -> WebDAVAccountResponse:
|
||||||
"""
|
"""
|
||||||
Update WebDAV account information by ID.
|
更新 WebDAV 账户
|
||||||
|
|
||||||
Args:
|
认证:JWT Bearer Token
|
||||||
id (str): The ID of the account to be updated.
|
|
||||||
|
|
||||||
Returns:
|
错误处理:
|
||||||
ResponseBase: A model containing the response data for the updated account.
|
- 403: WebDAV 功能未启用
|
||||||
|
- 404: 账户不存在
|
||||||
|
- 400: 根目录路径不存在或不是目录
|
||||||
"""
|
"""
|
||||||
http_exceptions.raise_not_implemented()
|
_check_webdav_enabled(user)
|
||||||
|
user_id: UUID = user.id
|
||||||
|
|
||||||
|
account = await WebDAV.get(
|
||||||
|
session,
|
||||||
|
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
|
||||||
|
)
|
||||||
|
if not account:
|
||||||
|
http_exceptions.raise_not_found("WebDAV 账户不存在")
|
||||||
|
|
||||||
|
# 验证 root 路径
|
||||||
|
if request.root is not None:
|
||||||
|
root_obj = await Object.get_by_path(session, user_id, request.root)
|
||||||
|
if not root_obj or not root_obj.is_folder:
|
||||||
|
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
|
||||||
|
|
||||||
|
# 密码哈希后原地替换,update() 会通过 model_dump(exclude_unset=True) 只取已设置字段
|
||||||
|
is_password_changed = request.password is not None
|
||||||
|
if is_password_changed:
|
||||||
|
request.password = Password.hash(request.password)
|
||||||
|
|
||||||
|
account = await account.update(session, request)
|
||||||
|
|
||||||
|
# 密码变更时清除认证缓存
|
||||||
|
if is_password_changed:
|
||||||
|
await WebDAVAuthCache.invalidate_account(user_id, account.name)
|
||||||
|
|
||||||
|
l.info(f"用户 {user_id} 更新 WebDAV 账户: {account.name}")
|
||||||
|
return _to_response(account)
|
||||||
|
|
||||||
|
|
||||||
|
@webdav_router.delete(
|
||||||
|
path='/accounts/{account_id}',
|
||||||
|
summary='删除账号',
|
||||||
|
status_code=204,
|
||||||
|
)
|
||||||
|
async def delete_account(
|
||||||
|
session: SessionDep,
|
||||||
|
user: Annotated[User, Depends(auth_required)],
|
||||||
|
account_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
删除 WebDAV 账户
|
||||||
|
|
||||||
|
认证:JWT Bearer Token
|
||||||
|
|
||||||
|
错误处理:
|
||||||
|
- 403: WebDAV 功能未启用
|
||||||
|
- 404: 账户不存在
|
||||||
|
"""
|
||||||
|
_check_webdav_enabled(user)
|
||||||
|
user_id: UUID = user.id
|
||||||
|
|
||||||
|
account = await WebDAV.get(
|
||||||
|
session,
|
||||||
|
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
|
||||||
|
)
|
||||||
|
if not account:
|
||||||
|
http_exceptions.raise_not_found("WebDAV 账户不存在")
|
||||||
|
|
||||||
|
account_name = account.name
|
||||||
|
await WebDAV.delete(session, account)
|
||||||
|
|
||||||
|
# 清除认证缓存
|
||||||
|
await WebDAVAuthCache.invalidate_account(user_id, account_name)
|
||||||
|
|
||||||
|
l.info(f"用户 {user_id} 删除 WebDAV 账户: {account_name}")
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
# WebDAV 操作路由
|
|
||||||
35
routers/dav/__init__.py
Normal file
35
routers/dav/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
WebDAV 协议入口
|
||||||
|
|
||||||
|
使用 WsgiDAV + a2wsgi 提供 WebDAV 协议支持。
|
||||||
|
WsgiDAV 在 a2wsgi 的线程池中运行,不阻塞 FastAPI 事件循环。
|
||||||
|
"""
|
||||||
|
from a2wsgi import WSGIMiddleware
|
||||||
|
from wsgidav.wsgidav_app import WsgiDAVApp
|
||||||
|
|
||||||
|
from .domain_controller import DiskNextDomainController
|
||||||
|
from .provider import DiskNextDAVProvider
|
||||||
|
|
||||||
|
_wsgidav_config: dict[str, object] = {
|
||||||
|
"provider_mapping": {
|
||||||
|
"/": DiskNextDAVProvider(),
|
||||||
|
},
|
||||||
|
"http_authenticator": {
|
||||||
|
"domain_controller": DiskNextDomainController,
|
||||||
|
"accept_basic": True,
|
||||||
|
"accept_digest": False,
|
||||||
|
"default_to_digest": False,
|
||||||
|
},
|
||||||
|
"verbose": 1,
|
||||||
|
# 使用 WsgiDAV 内置的内存锁管理器
|
||||||
|
"lock_storage": True,
|
||||||
|
# 禁用 WsgiDAV 的目录浏览器(纯 DAV 协议)
|
||||||
|
"dir_browser": {
|
||||||
|
"enable": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_wsgidav_app = WsgiDAVApp(_wsgidav_config)
|
||||||
|
|
||||||
|
dav_app = WSGIMiddleware(_wsgidav_app, workers=10)
|
||||||
|
"""ASGI 应用,挂载到 /dav 路径"""
|
||||||
148
routers/dav/domain_controller.py
Normal file
148
routers/dav/domain_controller.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""
|
||||||
|
WebDAV 认证控制器
|
||||||
|
|
||||||
|
实现 WsgiDAV 的 BaseDomainController 接口,使用 HTTP Basic Auth
|
||||||
|
通过 DiskNext 的 WebDAV 账户模型进行认证。
|
||||||
|
|
||||||
|
用户名格式: {email}/{webdav_account_name}
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from loguru import logger as l
|
||||||
|
from wsgidav.dc.base_dc import BaseDomainController
|
||||||
|
|
||||||
|
from routers.dav.provider import EventLoopRef, _get_session
|
||||||
|
from service.redis.webdav_auth_cache import WebDAVAuthCache
|
||||||
|
from sqlmodels.user import User, UserStatus
|
||||||
|
from sqlmodels.webdav import WebDAV
|
||||||
|
from utils.password.pwd import Password, PasswordStatus
|
||||||
|
|
||||||
|
|
||||||
|
async def _authenticate(
|
||||||
|
email: str,
|
||||||
|
account_name: str,
|
||||||
|
password: str,
|
||||||
|
) -> tuple[UUID, int] | None:
|
||||||
|
"""
|
||||||
|
异步认证 WebDAV 用户。
|
||||||
|
|
||||||
|
:param email: 用户邮箱
|
||||||
|
:param account_name: WebDAV 账户名
|
||||||
|
:param password: 明文密码
|
||||||
|
:return: (user_id, webdav_id) 或 None
|
||||||
|
"""
|
||||||
|
# 1. 查缓存
|
||||||
|
cached = await WebDAVAuthCache.get(email, account_name, password)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# 2. 缓存未命中,查库验证
|
||||||
|
async with _get_session() as session:
|
||||||
|
user = await User.get(session, User.email == email, load=User.group)
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
if user.status != UserStatus.ACTIVE:
|
||||||
|
return None
|
||||||
|
if not user.group.web_dav_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
account = await WebDAV.get(
|
||||||
|
session,
|
||||||
|
(WebDAV.name == account_name) & (WebDAV.user_id == user.id),
|
||||||
|
)
|
||||||
|
if not account:
|
||||||
|
return None
|
||||||
|
|
||||||
|
status = Password.verify(account.password, password)
|
||||||
|
if status == PasswordStatus.INVALID:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_id: UUID = user.id
|
||||||
|
webdav_id: int = account.id
|
||||||
|
|
||||||
|
# 3. 写入缓存
|
||||||
|
await WebDAVAuthCache.set(email, account_name, password, user_id, webdav_id)
|
||||||
|
|
||||||
|
return user_id, webdav_id
|
||||||
|
|
||||||
|
|
||||||
|
class DiskNextDomainController(BaseDomainController):
|
||||||
|
"""
|
||||||
|
DiskNext WebDAV 认证控制器
|
||||||
|
|
||||||
|
用户名格式: {email}/{webdav_account_name}
|
||||||
|
密码: WebDAV 账户密码(创建账户时设置)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, wsgidav_app: object, config: dict[str, object]) -> None:
|
||||||
|
super().__init__(wsgidav_app, config)
|
||||||
|
|
||||||
|
def get_domain_realm(self, path_info: str, environ: dict[str, object]) -> str:
|
||||||
|
"""返回 realm 名称"""
|
||||||
|
return "DiskNext WebDAV"
|
||||||
|
|
||||||
|
def require_authentication(self, realm: str, environ: dict[str, object]) -> bool:
|
||||||
|
"""所有请求都需要认证"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def is_share_anonymous(self, path_info: str) -> bool:
|
||||||
|
"""不支持匿名访问"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def supports_http_digest_auth(self) -> bool:
|
||||||
|
"""不支持 Digest 认证(密码存的是 Argon2 哈希,无法反推)"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def basic_auth_user(
|
||||||
|
self,
|
||||||
|
realm: str,
|
||||||
|
user_name: str,
|
||||||
|
password: str,
|
||||||
|
environ: dict[str, object],
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
HTTP Basic Auth 认证。
|
||||||
|
|
||||||
|
用户名格式: {email}/{webdav_account_name}
|
||||||
|
在 WSGI 线程中通过 anyio.from_thread.run 调用异步认证逻辑。
|
||||||
|
"""
|
||||||
|
# 解析用户名
|
||||||
|
if "/" not in user_name:
|
||||||
|
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
|
||||||
|
return False
|
||||||
|
|
||||||
|
email, account_name = user_name.split("/", 1)
|
||||||
|
if not email or not account_name:
|
||||||
|
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 在 WSGI 线程中调用异步认证
|
||||||
|
future = asyncio.run_coroutine_threadsafe(
|
||||||
|
_authenticate(email, account_name, password),
|
||||||
|
EventLoopRef.get(),
|
||||||
|
)
|
||||||
|
result = future.result()
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
l.debug(f"WebDAV 认证失败: {email}/{account_name}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
user_id, webdav_id = result
|
||||||
|
|
||||||
|
# 将认证信息存入 environ,供 Provider 使用
|
||||||
|
environ["disknext.user_id"] = user_id
|
||||||
|
environ["disknext.webdav_id"] = webdav_id
|
||||||
|
environ["disknext.email"] = email
|
||||||
|
environ["disknext.account_name"] = account_name
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def digest_auth_user(
|
||||||
|
self,
|
||||||
|
realm: str,
|
||||||
|
user_name: str,
|
||||||
|
environ: dict[str, object],
|
||||||
|
) -> bool:
|
||||||
|
"""不支持 Digest 认证"""
|
||||||
|
return False
|
||||||
645
routers/dav/provider.py
Normal file
645
routers/dav/provider.py
Normal file
@@ -0,0 +1,645 @@
|
|||||||
|
"""
|
||||||
|
DiskNext WebDAV 存储 Provider
|
||||||
|
|
||||||
|
将 WsgiDAV 的文件操作映射到 DiskNext 的 Object 模型。
|
||||||
|
所有异步数据库/文件操作通过 asyncio.run_coroutine_threadsafe() 桥接。
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import mimetypes
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import ClassVar
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from loguru import logger as l
|
||||||
|
from wsgidav.dav_error import (
|
||||||
|
DAVError,
|
||||||
|
HTTP_FORBIDDEN,
|
||||||
|
HTTP_INSUFFICIENT_STORAGE,
|
||||||
|
HTTP_NOT_FOUND,
|
||||||
|
)
|
||||||
|
from wsgidav.dav_provider import DAVCollection, DAVNonCollection, DAVProvider
|
||||||
|
|
||||||
|
from service.storage import LocalStorageService, adjust_user_storage
|
||||||
|
from sqlmodels.database_connection import DatabaseManager
|
||||||
|
from sqlmodels.object import Object, ObjectType
|
||||||
|
from sqlmodels.physical_file import PhysicalFile
|
||||||
|
from sqlmodels.policy import Policy
|
||||||
|
from sqlmodels.user import User
|
||||||
|
from sqlmodels.webdav import WebDAV
|
||||||
|
|
||||||
|
|
||||||
|
class EventLoopRef:
|
||||||
|
"""持有主线程事件循环引用,供 WSGI 线程使用"""
|
||||||
|
_loop: ClassVar[asyncio.AbstractEventLoop | None] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def capture(cls) -> None:
|
||||||
|
"""在 async 上下文中调用,捕获当前事件循环"""
|
||||||
|
cls._loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get(cls) -> asyncio.AbstractEventLoop:
|
||||||
|
if cls._loop is None:
|
||||||
|
raise RuntimeError("事件循环尚未捕获,请先调用 EventLoopRef.capture()")
|
||||||
|
return cls._loop
|
||||||
|
|
||||||
|
|
||||||
|
def _run_async(coro): # type: ignore[no-untyped-def]
|
||||||
|
"""在 WSGI 线程中通过 run_coroutine_threadsafe 运行协程"""
|
||||||
|
future = asyncio.run_coroutine_threadsafe(coro, EventLoopRef.get())
|
||||||
|
return future.result()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session(): # type: ignore[no-untyped-def]
|
||||||
|
"""获取数据库会话上下文管理器"""
|
||||||
|
return DatabaseManager._async_session_factory()
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 异步辅助函数 ====================
|
||||||
|
|
||||||
|
async def _get_webdav_account(webdav_id: int) -> WebDAV | None:
|
||||||
|
"""获取 WebDAV 账户"""
|
||||||
|
async with _get_session() as session:
|
||||||
|
return await WebDAV.get(session, WebDAV.id == webdav_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_object_by_path(user_id: UUID, path: str) -> Object | None:
|
||||||
|
"""根据路径获取对象"""
|
||||||
|
async with _get_session() as session:
|
||||||
|
return await Object.get_by_path(session, user_id, path)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_children(user_id: UUID, parent_id: UUID) -> list[Object]:
|
||||||
|
"""获取目录子对象"""
|
||||||
|
async with _get_session() as session:
|
||||||
|
return await Object.get_children(session, user_id, parent_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_object_by_id(object_id: UUID) -> Object | None:
|
||||||
|
"""根据ID获取对象"""
|
||||||
|
async with _get_session() as session:
|
||||||
|
return await Object.get(session, Object.id == object_id, load=Object.physical_file)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_user(user_id: UUID) -> User | None:
|
||||||
|
"""获取用户(含 group 关系)"""
|
||||||
|
async with _get_session() as session:
|
||||||
|
return await User.get(session, User.id == user_id, load=User.group)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_policy(policy_id: UUID) -> Policy | None:
|
||||||
|
"""获取存储策略"""
|
||||||
|
async with _get_session() as session:
|
||||||
|
return await Policy.get(session, Policy.id == policy_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_folder(
|
||||||
|
name: str,
|
||||||
|
parent_id: UUID,
|
||||||
|
owner_id: UUID,
|
||||||
|
policy_id: UUID,
|
||||||
|
) -> Object:
|
||||||
|
"""创建目录对象"""
|
||||||
|
async with _get_session() as session:
|
||||||
|
obj = Object(
|
||||||
|
name=name,
|
||||||
|
type=ObjectType.FOLDER,
|
||||||
|
size=0,
|
||||||
|
parent_id=parent_id,
|
||||||
|
owner_id=owner_id,
|
||||||
|
policy_id=policy_id,
|
||||||
|
)
|
||||||
|
obj = await obj.save(session)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_file(
|
||||||
|
name: str,
|
||||||
|
parent_id: UUID,
|
||||||
|
owner_id: UUID,
|
||||||
|
policy_id: UUID,
|
||||||
|
) -> Object:
|
||||||
|
"""创建空文件对象"""
|
||||||
|
async with _get_session() as session:
|
||||||
|
obj = Object(
|
||||||
|
name=name,
|
||||||
|
type=ObjectType.FILE,
|
||||||
|
size=0,
|
||||||
|
parent_id=parent_id,
|
||||||
|
owner_id=owner_id,
|
||||||
|
policy_id=policy_id,
|
||||||
|
)
|
||||||
|
obj = await obj.save(session)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
async def _soft_delete_object(object_id: UUID) -> None:
|
||||||
|
"""软删除对象(移入回收站)"""
|
||||||
|
from service.storage import soft_delete_objects
|
||||||
|
|
||||||
|
async with _get_session() as session:
|
||||||
|
obj = await Object.get(session, Object.id == object_id)
|
||||||
|
if obj:
|
||||||
|
await soft_delete_objects(session, [obj])
|
||||||
|
|
||||||
|
|
||||||
|
async def _finalize_upload(
|
||||||
|
object_id: UUID,
|
||||||
|
physical_path: str,
|
||||||
|
size: int,
|
||||||
|
owner_id: UUID,
|
||||||
|
policy_id: UUID,
|
||||||
|
) -> None:
|
||||||
|
"""上传完成后更新对象元数据和物理文件记录"""
|
||||||
|
async with _get_session() as session:
|
||||||
|
# 获取存储路径(相对路径)
|
||||||
|
policy = await Policy.get(session, Policy.id == policy_id)
|
||||||
|
if not policy or not policy.server:
|
||||||
|
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
|
||||||
|
|
||||||
|
base_path = Path(policy.server).resolve()
|
||||||
|
full_path = Path(physical_path).resolve()
|
||||||
|
storage_path = str(full_path.relative_to(base_path))
|
||||||
|
|
||||||
|
# 创建 PhysicalFile 记录
|
||||||
|
pf = PhysicalFile(
|
||||||
|
storage_path=storage_path,
|
||||||
|
size=size,
|
||||||
|
policy_id=policy_id,
|
||||||
|
reference_count=1,
|
||||||
|
)
|
||||||
|
pf = await pf.save(session)
|
||||||
|
|
||||||
|
# 更新 Object
|
||||||
|
obj = await Object.get(session, Object.id == object_id)
|
||||||
|
if obj:
|
||||||
|
obj.sqlmodel_update({'size': size, 'physical_file_id': pf.id})
|
||||||
|
obj = await obj.save(session)
|
||||||
|
|
||||||
|
# 更新用户存储用量
|
||||||
|
if size > 0:
|
||||||
|
await adjust_user_storage(session, owner_id, size)
|
||||||
|
|
||||||
|
|
||||||
|
async def _move_object(
|
||||||
|
object_id: UUID,
|
||||||
|
new_parent_id: UUID,
|
||||||
|
new_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""移动/重命名对象"""
|
||||||
|
async with _get_session() as session:
|
||||||
|
obj = await Object.get(session, Object.id == object_id)
|
||||||
|
if obj:
|
||||||
|
obj.sqlmodel_update({'parent_id': new_parent_id, 'name': new_name})
|
||||||
|
obj = await obj.save(session)
|
||||||
|
|
||||||
|
|
||||||
|
async def _copy_object_recursive(
|
||||||
|
src_id: UUID,
|
||||||
|
dst_parent_id: UUID,
|
||||||
|
dst_name: str,
|
||||||
|
owner_id: UUID,
|
||||||
|
) -> None:
|
||||||
|
"""递归复制对象"""
|
||||||
|
from service.storage import copy_object_recursive
|
||||||
|
|
||||||
|
async with _get_session() as session:
|
||||||
|
src = await Object.get(session, Object.id == src_id)
|
||||||
|
if not src:
|
||||||
|
return
|
||||||
|
await copy_object_recursive(session, src, dst_parent_id, owner_id, new_name=dst_name)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 辅助工具 ====================
|
||||||
|
|
||||||
|
def _get_environ_info(environ: dict[str, object]) -> tuple[UUID, int]:
|
||||||
|
"""从 environ 中提取认证信息"""
|
||||||
|
user_id: UUID = environ["disknext.user_id"] # type: ignore[assignment]
|
||||||
|
webdav_id: int = environ["disknext.webdav_id"] # type: ignore[assignment]
|
||||||
|
return user_id, webdav_id
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_dav_path(account_root: str, dav_path: str) -> str:
|
||||||
|
"""
|
||||||
|
将 DAV 相对路径映射到 DiskNext 绝对路径。
|
||||||
|
|
||||||
|
:param account_root: 账户挂载根路径,如 "/" 或 "/docs"
|
||||||
|
:param dav_path: DAV 请求路径,如 "/" 或 "/photos/cat.jpg"
|
||||||
|
:return: DiskNext 内部路径,如 "/docs/photos/cat.jpg"
|
||||||
|
"""
|
||||||
|
# 规范化根路径
|
||||||
|
root = account_root.rstrip("/")
|
||||||
|
if not root:
|
||||||
|
root = ""
|
||||||
|
|
||||||
|
# 规范化 DAV 路径
|
||||||
|
if not dav_path or dav_path == "/":
|
||||||
|
return root + "/" if root else "/"
|
||||||
|
|
||||||
|
if not dav_path.startswith("/"):
|
||||||
|
dav_path = "/" + dav_path
|
||||||
|
|
||||||
|
full = root + dav_path
|
||||||
|
return full if full else "/"
|
||||||
|
|
||||||
|
|
||||||
|
def _check_readonly(environ: dict[str, object]) -> None:
|
||||||
|
"""检查账户是否只读,只读则抛出 403"""
|
||||||
|
account = environ.get("disknext.webdav_account")
|
||||||
|
if account and getattr(account, 'readonly', False):
|
||||||
|
raise DAVError(HTTP_FORBIDDEN, "WebDAV 账户为只读模式")
|
||||||
|
|
||||||
|
|
||||||
|
def _check_storage_quota(user: User, additional_bytes: int) -> None:
|
||||||
|
"""检查存储配额"""
|
||||||
|
max_storage = user.group.max_storage
|
||||||
|
if max_storage > 0 and user.storage + additional_bytes > max_storage:
|
||||||
|
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaLimitedWriter(io.RawIOBase):
|
||||||
|
"""带配额限制的写入流包装器"""
|
||||||
|
|
||||||
|
def __init__(self, stream: io.BufferedWriter, max_bytes: int) -> None:
|
||||||
|
self._stream = stream
|
||||||
|
self._max_bytes = max_bytes
|
||||||
|
self._bytes_written = 0
|
||||||
|
|
||||||
|
def writable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def write(self, b: bytes | bytearray) -> int:
|
||||||
|
if self._bytes_written + len(b) > self._max_bytes:
|
||||||
|
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
||||||
|
written = self._stream.write(b)
|
||||||
|
self._bytes_written += written
|
||||||
|
return written
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
self._stream.close()
|
||||||
|
super().close()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bytes_written(self) -> int:
|
||||||
|
return self._bytes_written
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Provider ====================
|
||||||
|
|
||||||
|
class DiskNextDAVProvider(DAVProvider):
|
||||||
|
"""DiskNext WebDAV 存储 Provider"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def get_resource_inst(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
environ: dict[str, object],
|
||||||
|
) -> 'DiskNextCollection | DiskNextFile | None':
|
||||||
|
"""
|
||||||
|
将 WebDAV 路径映射到资源对象。
|
||||||
|
|
||||||
|
首次调用时加载 WebDAV 账户信息并缓存到 environ。
|
||||||
|
"""
|
||||||
|
user_id, webdav_id = _get_environ_info(environ)
|
||||||
|
|
||||||
|
# 首次请求时加载账户信息
|
||||||
|
if "disknext.webdav_account" not in environ:
|
||||||
|
account = _run_async(_get_webdav_account(webdav_id))
|
||||||
|
if not account:
|
||||||
|
return None
|
||||||
|
environ["disknext.webdav_account"] = account
|
||||||
|
|
||||||
|
account: WebDAV = environ["disknext.webdav_account"] # type: ignore[no-redef]
|
||||||
|
disknext_path = _resolve_dav_path(account.root, path)
|
||||||
|
|
||||||
|
obj = _run_async(_get_object_by_path(user_id, disknext_path))
|
||||||
|
if not obj:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if obj.is_folder:
|
||||||
|
return DiskNextCollection(path, environ, obj, user_id, account)
|
||||||
|
else:
|
||||||
|
return DiskNextFile(path, environ, obj, user_id, account)
|
||||||
|
|
||||||
|
def is_readonly(self) -> bool:
|
||||||
|
"""只读由账户级别控制,不在 provider 级别限制"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Collection(目录) ====================
|
||||||
|
|
||||||
|
class DiskNextCollection(DAVCollection):
|
||||||
|
"""DiskNext 目录资源"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
environ: dict[str, object],
|
||||||
|
obj: Object,
|
||||||
|
user_id: UUID,
|
||||||
|
account: WebDAV,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(path, environ)
|
||||||
|
self._obj = obj
|
||||||
|
self._user_id = user_id
|
||||||
|
self._account = account
|
||||||
|
|
||||||
|
def get_display_info(self) -> dict[str, str]:
|
||||||
|
return {"type": "Directory"}
|
||||||
|
|
||||||
|
def get_member_names(self) -> list[str]:
|
||||||
|
"""获取子对象名称列表"""
|
||||||
|
children = _run_async(_get_children(self._user_id, self._obj.id))
|
||||||
|
return [c.name for c in children]
|
||||||
|
|
||||||
|
def get_member(self, name: str) -> 'DiskNextCollection | DiskNextFile | None':
|
||||||
|
"""获取指定名称的子资源"""
|
||||||
|
member_path = self.path.rstrip("/") + "/" + name
|
||||||
|
account_root = self._account.root
|
||||||
|
disknext_path = _resolve_dav_path(account_root, member_path)
|
||||||
|
|
||||||
|
obj = _run_async(_get_object_by_path(self._user_id, disknext_path))
|
||||||
|
if not obj:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if obj.is_folder:
|
||||||
|
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
|
||||||
|
else:
|
||||||
|
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
|
||||||
|
|
||||||
|
def get_creation_date(self) -> float | None:
|
||||||
|
if self._obj.created_at:
|
||||||
|
return self._obj.created_at.timestamp()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_last_modified(self) -> float | None:
|
||||||
|
if self._obj.updated_at:
|
||||||
|
return self._obj.updated_at.timestamp()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def create_empty_resource(self, name: str) -> 'DiskNextFile':
|
||||||
|
"""创建空文件(PUT 操作的第一步)"""
|
||||||
|
_check_readonly(self.environ)
|
||||||
|
|
||||||
|
obj = _run_async(_create_file(
|
||||||
|
name=name,
|
||||||
|
parent_id=self._obj.id,
|
||||||
|
owner_id=self._user_id,
|
||||||
|
policy_id=self._obj.policy_id,
|
||||||
|
))
|
||||||
|
|
||||||
|
member_path = self.path.rstrip("/") + "/" + name
|
||||||
|
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
|
||||||
|
|
||||||
|
def create_collection(self, name: str) -> 'DiskNextCollection':
|
||||||
|
"""创建子目录(MKCOL)"""
|
||||||
|
_check_readonly(self.environ)
|
||||||
|
|
||||||
|
obj = _run_async(_create_folder(
|
||||||
|
name=name,
|
||||||
|
parent_id=self._obj.id,
|
||||||
|
owner_id=self._user_id,
|
||||||
|
policy_id=self._obj.policy_id,
|
||||||
|
))
|
||||||
|
|
||||||
|
member_path = self.path.rstrip("/") + "/" + name
|
||||||
|
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
"""软删除目录"""
|
||||||
|
_check_readonly(self.environ)
|
||||||
|
_run_async(_soft_delete_object(self._obj.id))
|
||||||
|
|
||||||
|
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
|
||||||
|
"""复制或移动目录"""
|
||||||
|
_check_readonly(self.environ)
|
||||||
|
|
||||||
|
account_root = self._account.root
|
||||||
|
dest_disknext = _resolve_dav_path(account_root, dest_path)
|
||||||
|
|
||||||
|
# 解析目标父路径和新名称
|
||||||
|
if "/" in dest_disknext.rstrip("/"):
|
||||||
|
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
|
||||||
|
new_name = dest_disknext.rsplit("/", 1)[1]
|
||||||
|
else:
|
||||||
|
parent_path = "/"
|
||||||
|
new_name = dest_disknext.lstrip("/")
|
||||||
|
|
||||||
|
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
|
||||||
|
if not dest_parent:
|
||||||
|
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
|
||||||
|
|
||||||
|
if is_move:
|
||||||
|
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
|
||||||
|
else:
|
||||||
|
_run_async(_copy_object_recursive(
|
||||||
|
self._obj.id, dest_parent.id, new_name, self._user_id,
|
||||||
|
))
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def support_recursive_delete(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def support_recursive_move(self, dest_path: str) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== NonCollection(文件) ====================
|
||||||
|
|
||||||
|
class DiskNextFile(DAVNonCollection):
|
||||||
|
"""DiskNext 文件资源"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
environ: dict[str, object],
|
||||||
|
obj: Object,
|
||||||
|
user_id: UUID,
|
||||||
|
account: WebDAV,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(path, environ)
|
||||||
|
self._obj = obj
|
||||||
|
self._user_id = user_id
|
||||||
|
self._account = account
|
||||||
|
self._write_path: str | None = None
|
||||||
|
self._write_stream: io.BufferedWriter | QuotaLimitedWriter | None = None
|
||||||
|
|
||||||
|
def get_content_length(self) -> int | None:
|
||||||
|
return self._obj.size if self._obj.size else 0
|
||||||
|
|
||||||
|
def get_content_type(self) -> str | None:
|
||||||
|
# 尝试从文件名推断 MIME 类型
|
||||||
|
mime, _ = mimetypes.guess_type(self._obj.name)
|
||||||
|
return mime or "application/octet-stream"
|
||||||
|
|
||||||
|
def get_creation_date(self) -> float | None:
|
||||||
|
if self._obj.created_at:
|
||||||
|
return self._obj.created_at.timestamp()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_last_modified(self) -> float | None:
|
||||||
|
if self._obj.updated_at:
|
||||||
|
return self._obj.updated_at.timestamp()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_display_info(self) -> dict[str, str]:
|
||||||
|
return {"type": "File"}
|
||||||
|
|
||||||
|
def get_content(self) -> io.BufferedReader | None:
|
||||||
|
"""
|
||||||
|
返回文件内容的可读流。
|
||||||
|
|
||||||
|
WsgiDAV 在线程中运行,可安全使用同步 open()。
|
||||||
|
"""
|
||||||
|
obj_with_file = _run_async(_get_object_by_id(self._obj.id))
|
||||||
|
if not obj_with_file or not obj_with_file.physical_file:
|
||||||
|
return None
|
||||||
|
|
||||||
|
pf = obj_with_file.physical_file
|
||||||
|
policy = _run_async(_get_policy(obj_with_file.policy_id))
|
||||||
|
if not policy or not policy.server:
|
||||||
|
return None
|
||||||
|
|
||||||
|
full_path = Path(policy.server).resolve() / pf.storage_path
|
||||||
|
if not full_path.is_file():
|
||||||
|
l.warning(f"WebDAV: 物理文件不存在: {full_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return open(full_path, "rb") # noqa: SIM115
|
||||||
|
|
||||||
|
def begin_write(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
content_type: str | None = None,
|
||||||
|
) -> io.BufferedWriter | QuotaLimitedWriter:
|
||||||
|
"""
|
||||||
|
开始写入文件(PUT 操作)。
|
||||||
|
|
||||||
|
返回一个可写的文件流,WsgiDAV 将向其中写入请求体数据。
|
||||||
|
当用户有配额限制时,返回 QuotaLimitedWriter 在写入过程中实时检查配额。
|
||||||
|
"""
|
||||||
|
_check_readonly(self.environ)
|
||||||
|
|
||||||
|
# 检查配额
|
||||||
|
remaining_quota: int = 0
|
||||||
|
user = _run_async(_get_user(self._user_id))
|
||||||
|
if user:
|
||||||
|
max_storage = user.group.max_storage
|
||||||
|
if max_storage > 0:
|
||||||
|
remaining_quota = max_storage - user.storage
|
||||||
|
if remaining_quota <= 0:
|
||||||
|
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
||||||
|
# Content-Length 预检(如果有的话)
|
||||||
|
content_length = self.environ.get("CONTENT_LENGTH")
|
||||||
|
if content_length and int(content_length) > remaining_quota:
|
||||||
|
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
||||||
|
|
||||||
|
# 获取策略以确定存储路径
|
||||||
|
policy = _run_async(_get_policy(self._obj.policy_id))
|
||||||
|
if not policy or not policy.server:
|
||||||
|
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
|
||||||
|
|
||||||
|
storage_service = LocalStorageService(policy)
|
||||||
|
dir_path, storage_name, full_path = _run_async(
|
||||||
|
storage_service.generate_file_path(
|
||||||
|
user_id=self._user_id,
|
||||||
|
original_filename=self._obj.name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._write_path = full_path
|
||||||
|
raw_stream = open(full_path, "wb") # noqa: SIM115
|
||||||
|
|
||||||
|
# 有配额限制时使用包装流,实时检查写入量
|
||||||
|
if remaining_quota > 0:
|
||||||
|
self._write_stream = QuotaLimitedWriter(raw_stream, remaining_quota)
|
||||||
|
else:
|
||||||
|
self._write_stream = raw_stream
|
||||||
|
|
||||||
|
return self._write_stream
|
||||||
|
|
||||||
|
def end_write(self, *, with_errors: bool) -> None:
|
||||||
|
"""写入完成后的收尾工作"""
|
||||||
|
if self._write_stream:
|
||||||
|
self._write_stream.close()
|
||||||
|
self._write_stream = None
|
||||||
|
|
||||||
|
if with_errors:
|
||||||
|
if self._write_path:
|
||||||
|
file_path = Path(self._write_path)
|
||||||
|
if file_path.exists():
|
||||||
|
file_path.unlink()
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self._write_path:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取文件大小
|
||||||
|
file_path = Path(self._write_path)
|
||||||
|
if not file_path.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
size = file_path.stat().st_size
|
||||||
|
|
||||||
|
# 更新数据库记录
|
||||||
|
_run_async(_finalize_upload(
|
||||||
|
object_id=self._obj.id,
|
||||||
|
physical_path=self._write_path,
|
||||||
|
size=size,
|
||||||
|
owner_id=self._user_id,
|
||||||
|
policy_id=self._obj.policy_id,
|
||||||
|
))
|
||||||
|
|
||||||
|
l.debug(f"WebDAV 文件写入完成: {self._obj.name}, size={size}")
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
"""软删除文件"""
|
||||||
|
_check_readonly(self.environ)
|
||||||
|
_run_async(_soft_delete_object(self._obj.id))
|
||||||
|
|
||||||
|
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
|
||||||
|
"""复制或移动文件"""
|
||||||
|
_check_readonly(self.environ)
|
||||||
|
|
||||||
|
account_root = self._account.root
|
||||||
|
dest_disknext = _resolve_dav_path(account_root, dest_path)
|
||||||
|
|
||||||
|
# 解析目标父路径和新名称
|
||||||
|
if "/" in dest_disknext.rstrip("/"):
|
||||||
|
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
|
||||||
|
new_name = dest_disknext.rsplit("/", 1)[1]
|
||||||
|
else:
|
||||||
|
parent_path = "/"
|
||||||
|
new_name = dest_disknext.lstrip("/")
|
||||||
|
|
||||||
|
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
|
||||||
|
if not dest_parent:
|
||||||
|
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
|
||||||
|
|
||||||
|
if is_move:
|
||||||
|
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
|
||||||
|
else:
|
||||||
|
_run_async(_copy_object_recursive(
|
||||||
|
self._obj.id, dest_parent.id, new_name, self._user_id,
|
||||||
|
))
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def support_content_length(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_etag(self) -> str | None:
|
||||||
|
"""返回 ETag(基于ID和更新时间),WsgiDAV 会自动加双引号"""
|
||||||
|
if self._obj.updated_at:
|
||||||
|
return f"{self._obj.id}-{int(self._obj.updated_at.timestamp())}"
|
||||||
|
return None
|
||||||
|
|
||||||
|
def support_etag(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def support_ranges(self) -> bool:
|
||||||
|
return True
|
||||||
11
routers/wopi/__init__.py
Normal file
11
routers/wopi/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""
|
||||||
|
WOPI(Web Application Open Platform Interface)路由
|
||||||
|
|
||||||
|
挂载在根级别 /wopi(非 /api/v1 下),因为 WOPI 客户端要求标准路径。
|
||||||
|
"""
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from .files import wopi_files_router
|
||||||
|
|
||||||
|
wopi_router = APIRouter(prefix="/wopi", tags=["wopi"])
|
||||||
|
wopi_router.include_router(wopi_files_router)
|
||||||
203
routers/wopi/files/__init__.py
Normal file
203
routers/wopi/files/__init__.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""
|
||||||
|
WOPI 文件操作端点
|
||||||
|
|
||||||
|
实现 WOPI 协议的核心文件操作接口:
|
||||||
|
- CheckFileInfo: 获取文件元数据
|
||||||
|
- GetFile: 下载文件内容
|
||||||
|
- PutFile: 上传/更新文件内容
|
||||||
|
"""
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Query, Request, Response
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from loguru import logger as l
|
||||||
|
|
||||||
|
from middleware.dependencies import SessionDep
|
||||||
|
from sqlmodels import Object, PhysicalFile, Policy, PolicyType, User, WopiFileInfo
|
||||||
|
from service.storage import LocalStorageService
|
||||||
|
from utils import http_exceptions
|
||||||
|
from utils.JWT.wopi_token import verify_wopi_token
|
||||||
|
|
||||||
|
wopi_files_router = APIRouter(prefix="/files", tags=["wopi"])
|
||||||
|
|
||||||
|
|
||||||
|
@wopi_files_router.get(
|
||||||
|
path='/{file_id}',
|
||||||
|
summary='WOPI CheckFileInfo',
|
||||||
|
description='返回文件的元数据信息。',
|
||||||
|
)
|
||||||
|
async def check_file_info(
|
||||||
|
session: SessionDep,
|
||||||
|
file_id: UUID,
|
||||||
|
access_token: str = Query(...),
|
||||||
|
) -> JSONResponse:
|
||||||
|
"""
|
||||||
|
WOPI CheckFileInfo 端点
|
||||||
|
|
||||||
|
认证:WOPI access_token(query 参数)
|
||||||
|
|
||||||
|
返回 WOPI 规范的 PascalCase JSON。
|
||||||
|
"""
|
||||||
|
# 验证令牌
|
||||||
|
payload = verify_wopi_token(access_token)
|
||||||
|
if not payload or payload.file_id != file_id:
|
||||||
|
http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")
|
||||||
|
|
||||||
|
# 获取文件
|
||||||
|
file_obj: Object | None = await Object.get(
|
||||||
|
session,
|
||||||
|
Object.id == file_id,
|
||||||
|
)
|
||||||
|
if not file_obj or not file_obj.is_file:
|
||||||
|
http_exceptions.raise_not_found("文件不存在")
|
||||||
|
|
||||||
|
# 获取用户信息
|
||||||
|
user: User | None = await User.get(session, User.id == payload.user_id)
|
||||||
|
user_name = user.nickname or user.email or str(payload.user_id) if user else str(payload.user_id)
|
||||||
|
|
||||||
|
# 构建响应
|
||||||
|
info = WopiFileInfo(
|
||||||
|
base_file_name=file_obj.name,
|
||||||
|
size=file_obj.size or 0,
|
||||||
|
owner_id=str(file_obj.owner_id),
|
||||||
|
user_id=str(payload.user_id),
|
||||||
|
user_friendly_name=user_name,
|
||||||
|
version=file_obj.updated_at.isoformat() if file_obj.updated_at else "",
|
||||||
|
user_can_write=payload.can_write,
|
||||||
|
read_only=not payload.can_write,
|
||||||
|
supports_update=payload.can_write,
|
||||||
|
)
|
||||||
|
|
||||||
|
return JSONResponse(content=info.to_wopi_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@wopi_files_router.get(
|
||||||
|
path='/{file_id}/contents',
|
||||||
|
summary='WOPI GetFile',
|
||||||
|
description='返回文件的二进制内容。',
|
||||||
|
)
|
||||||
|
async def get_file(
|
||||||
|
session: SessionDep,
|
||||||
|
file_id: UUID,
|
||||||
|
access_token: str = Query(...),
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
WOPI GetFile 端点
|
||||||
|
|
||||||
|
认证:WOPI access_token(query 参数)
|
||||||
|
|
||||||
|
返回文件的原始二进制内容。
|
||||||
|
"""
|
||||||
|
# 验证令牌
|
||||||
|
payload = verify_wopi_token(access_token)
|
||||||
|
if not payload or payload.file_id != file_id:
|
||||||
|
http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")
|
||||||
|
|
||||||
|
# 获取文件
|
||||||
|
file_obj: Object | None = await Object.get(session, Object.id == file_id)
|
||||||
|
if not file_obj or not file_obj.is_file:
|
||||||
|
http_exceptions.raise_not_found("文件不存在")
|
||||||
|
|
||||||
|
# 获取物理文件
|
||||||
|
physical_file: PhysicalFile | None = await file_obj.awaitable_attrs.physical_file
|
||||||
|
if not physical_file or not physical_file.storage_path:
|
||||||
|
http_exceptions.raise_internal_error("文件存储路径丢失")
|
||||||
|
|
||||||
|
# 获取策略
|
||||||
|
policy: Policy | None = await Policy.get(session, Policy.id == file_obj.policy_id)
|
||||||
|
if not policy:
|
||||||
|
http_exceptions.raise_internal_error("存储策略不存在")
|
||||||
|
|
||||||
|
if policy.type == PolicyType.LOCAL:
|
||||||
|
storage_service = LocalStorageService(policy)
|
||||||
|
if not await storage_service.file_exists(physical_file.storage_path):
|
||||||
|
http_exceptions.raise_not_found("物理文件不存在")
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
async with aiofiles.open(physical_file.storage_path, 'rb') as f:
|
||||||
|
content = await f.read()
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
content=content,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
headers={"X-WOPI-ItemVersion": file_obj.updated_at.isoformat() if file_obj.updated_at else ""},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||||
|
|
||||||
|
|
||||||
|
@wopi_files_router.post(
|
||||||
|
path='/{file_id}/contents',
|
||||||
|
summary='WOPI PutFile',
|
||||||
|
description='更新文件内容。',
|
||||||
|
)
|
||||||
|
async def put_file(
|
||||||
|
session: SessionDep,
|
||||||
|
request: Request,
|
||||||
|
file_id: UUID,
|
||||||
|
access_token: str = Query(...),
|
||||||
|
) -> JSONResponse:
|
||||||
|
"""
|
||||||
|
WOPI PutFile 端点
|
||||||
|
|
||||||
|
认证:WOPI access_token(query 参数,需要写权限)
|
||||||
|
|
||||||
|
接收请求体中的文件二进制内容并覆盖存储。
|
||||||
|
"""
|
||||||
|
# 验证令牌
|
||||||
|
payload = verify_wopi_token(access_token)
|
||||||
|
if not payload or payload.file_id != file_id:
|
||||||
|
http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")
|
||||||
|
|
||||||
|
if not payload.can_write:
|
||||||
|
http_exceptions.raise_forbidden("没有写入权限")
|
||||||
|
|
||||||
|
# 获取文件
|
||||||
|
file_obj: Object | None = await Object.get(session, Object.id == file_id)
|
||||||
|
if not file_obj or not file_obj.is_file:
|
||||||
|
http_exceptions.raise_not_found("文件不存在")
|
||||||
|
|
||||||
|
# 获取物理文件
|
||||||
|
physical_file: PhysicalFile | None = await file_obj.awaitable_attrs.physical_file
|
||||||
|
if not physical_file or not physical_file.storage_path:
|
||||||
|
http_exceptions.raise_internal_error("文件存储路径丢失")
|
||||||
|
|
||||||
|
# 获取策略
|
||||||
|
policy: Policy | None = await Policy.get(session, Policy.id == file_obj.policy_id)
|
||||||
|
if not policy:
|
||||||
|
http_exceptions.raise_internal_error("存储策略不存在")
|
||||||
|
|
||||||
|
# 读取请求体
|
||||||
|
content = await request.body()
|
||||||
|
|
||||||
|
if policy.type == PolicyType.LOCAL:
|
||||||
|
import aiofiles
|
||||||
|
async with aiofiles.open(physical_file.storage_path, 'wb') as f:
|
||||||
|
await f.write(content)
|
||||||
|
|
||||||
|
# 更新文件大小
|
||||||
|
new_size = len(content)
|
||||||
|
old_size = file_obj.size or 0
|
||||||
|
file_obj.size = new_size
|
||||||
|
file_obj = await file_obj.save(session, commit=False)
|
||||||
|
|
||||||
|
# 更新物理文件大小
|
||||||
|
physical_file.size = new_size
|
||||||
|
await physical_file.save(session, commit=False)
|
||||||
|
|
||||||
|
# 更新用户存储配额
|
||||||
|
size_diff = new_size - old_size
|
||||||
|
if size_diff != 0:
|
||||||
|
from service.storage import adjust_user_storage
|
||||||
|
await adjust_user_storage(session, file_obj.owner_id, size_diff, commit=False)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
l.info(f"WOPI PutFile: file_id={file_id}, new_size={new_size}")
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
content={"ItemVersion": file_obj.updated_at.isoformat() if file_obj.updated_at else ""},
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||||
@@ -57,7 +57,7 @@ class CaptchaScene(StrEnum):
|
|||||||
async def verify_captcha_if_needed(
|
async def verify_captcha_if_needed(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
scene: CaptchaScene,
|
scene: CaptchaScene,
|
||||||
captcha_code: str | None,
|
captcha_code: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
通用验证码校验:查询设置判断是否需要,需要则校验。
|
通用验证码校验:查询设置判断是否需要,需要则校验。
|
||||||
@@ -81,23 +81,19 @@ async def verify_captcha_if_needed(
|
|||||||
if not scene_setting or scene_setting.value != "1":
|
if not scene_setting or scene_setting.value != "1":
|
||||||
return
|
return
|
||||||
|
|
||||||
# 2. 需要但未提供
|
# 2. 查询验证码类型和密钥
|
||||||
if not captcha_code:
|
|
||||||
http_exceptions.raise_bad_request(detail="请完成验证码验证")
|
|
||||||
|
|
||||||
# 3. 查询验证码类型和密钥
|
|
||||||
captcha_settings: list[Setting] = await Setting.get(
|
captcha_settings: list[Setting] = await Setting.get(
|
||||||
session, Setting.type == SettingsType.CAPTCHA, fetch_mode="all",
|
session, Setting.type == SettingsType.CAPTCHA, fetch_mode="all",
|
||||||
)
|
)
|
||||||
s: dict[str, str | None] = {item.name: item.value for item in captcha_settings}
|
s: dict[str, str | None] = {item.name: item.value for item in captcha_settings}
|
||||||
captcha_type = CaptchaType(s.get("captcha_type") or "default")
|
captcha_type = CaptchaType(s.get("captcha_type") or "default")
|
||||||
|
|
||||||
# 4. DEFAULT 图片验证码尚未实现,跳过
|
# 3. DEFAULT 图片验证码尚未实现,跳过
|
||||||
if captcha_type == CaptchaType.DEFAULT:
|
if captcha_type == CaptchaType.DEFAULT:
|
||||||
l.warning("DEFAULT 图片验证码尚未实现,跳过验证")
|
l.warning("DEFAULT 图片验证码尚未实现,跳过验证")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 5. 选择验证器和密钥
|
# 4. 选择验证器和密钥
|
||||||
if captcha_type == CaptchaType.GCAPTCHA:
|
if captcha_type == CaptchaType.GCAPTCHA:
|
||||||
secret = s.get("captcha_ReCaptchaSecret")
|
secret = s.get("captcha_ReCaptchaSecret")
|
||||||
verifier: CaptchaBase = GCaptcha()
|
verifier: CaptchaBase = GCaptcha()
|
||||||
@@ -112,7 +108,7 @@ async def verify_captcha_if_needed(
|
|||||||
l.error(f"验证码密钥未配置: captcha_type={captcha_type}")
|
l.error(f"验证码密钥未配置: captcha_type={captcha_type}")
|
||||||
http_exceptions.raise_internal_error()
|
http_exceptions.raise_internal_error()
|
||||||
|
|
||||||
# 6. 调用第三方 API 校验
|
# 5. 调用第三方 API 校验
|
||||||
is_valid = await verifier.verify_captcha(
|
is_valid = await verifier.verify_captcha(
|
||||||
CaptchaRequestBase(response=captcha_code, secret=secret)
|
CaptchaRequestBase(response=captcha_code, secret=secret)
|
||||||
)
|
)
|
||||||
|
|||||||
5
service/captcha/default.py
Normal file
5
service/captcha/default.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from captcha.image import ImageCaptcha
|
||||||
|
|
||||||
|
captcha = ImageCaptcha()
|
||||||
|
|
||||||
|
print(captcha.generate())
|
||||||
68
service/redis/challenge_store.py
Normal file
68
service/redis/challenge_store.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""
|
||||||
|
WebAuthn Challenge 一次性存储
|
||||||
|
|
||||||
|
支持 Redis(首选,使用 GETDEL 原子操作)和内存 TTLCache(降级)。
|
||||||
|
Challenge 存储后 5 分钟过期,取出即删除(防重放)。
|
||||||
|
"""
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
from cachetools import TTLCache
|
||||||
|
from loguru import logger as l
|
||||||
|
|
||||||
|
from . import RedisManager
|
||||||
|
|
||||||
|
# Challenge 过期时间(秒)
|
||||||
|
_CHALLENGE_TTL: int = 300
|
||||||
|
|
||||||
|
|
||||||
|
class ChallengeStore:
|
||||||
|
"""
|
||||||
|
WebAuthn Challenge 一次性存储管理器
|
||||||
|
|
||||||
|
根据 Redis 可用性自动选择存储后端:
|
||||||
|
- Redis 可用:使用 Redis GETDEL 原子操作
|
||||||
|
- Redis 不可用:使用内存 TTLCache(仅单实例)
|
||||||
|
|
||||||
|
Key 约定:
|
||||||
|
- 注册: ``reg:{user_id}``
|
||||||
|
- 登录: ``auth:{challenge_token}``
|
||||||
|
"""
|
||||||
|
|
||||||
|
_memory_cache: ClassVar[TTLCache[str, bytes]] = TTLCache(
|
||||||
|
maxsize=10000,
|
||||||
|
ttl=_CHALLENGE_TTL,
|
||||||
|
)
|
||||||
|
"""内存缓存降级方案"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def store(cls, key: str, challenge: bytes) -> None:
|
||||||
|
"""
|
||||||
|
存储 challenge,TTL 5 分钟。
|
||||||
|
|
||||||
|
:param key: 存储键(如 ``reg:{user_id}`` 或 ``auth:{token}``)
|
||||||
|
:param challenge: challenge 字节数据
|
||||||
|
"""
|
||||||
|
client = RedisManager.get_client()
|
||||||
|
|
||||||
|
if client is not None:
|
||||||
|
redis_key = f"webauthn_challenge:{key}"
|
||||||
|
await client.set(redis_key, challenge, ex=_CHALLENGE_TTL)
|
||||||
|
else:
|
||||||
|
cls._memory_cache[key] = challenge
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def retrieve_and_delete(cls, key: str) -> bytes | None:
|
||||||
|
"""
|
||||||
|
一次性取出并删除 challenge(防重放)。
|
||||||
|
|
||||||
|
:param key: 存储键
|
||||||
|
:return: challenge 字节数据,过期或不存在时返回 None
|
||||||
|
"""
|
||||||
|
client = RedisManager.get_client()
|
||||||
|
|
||||||
|
if client is not None:
|
||||||
|
redis_key = f"webauthn_challenge:{key}"
|
||||||
|
result: bytes | None = await client.getdel(redis_key)
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
return cls._memory_cache.pop(key, None)
|
||||||
128
service/redis/webdav_auth_cache.py
Normal file
128
service/redis/webdav_auth_cache.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""
|
||||||
|
WebDAV 认证缓存
|
||||||
|
|
||||||
|
缓存 HTTP Basic Auth 的认证结果,避免每次请求都查库 + Argon2 验证。
|
||||||
|
支持 Redis(首选)和内存缓存(降级)两种存储后端。
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
from typing import ClassVar
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from cachetools import TTLCache
|
||||||
|
from loguru import logger as l
|
||||||
|
|
||||||
|
from . import RedisManager
|
||||||
|
|
||||||
|
_AUTH_TTL: int = 300
|
||||||
|
"""认证缓存 TTL(秒),5 分钟"""
|
||||||
|
|
||||||
|
|
||||||
|
class WebDAVAuthCache:
|
||||||
|
"""
|
||||||
|
WebDAV 认证结果缓存
|
||||||
|
|
||||||
|
缓存键格式: webdav_auth:{email}/{account_name}:{sha256(password)}
|
||||||
|
缓存值格式: {user_id}:{webdav_id}
|
||||||
|
|
||||||
|
密码的 SHA256 作为缓存键的一部分,密码变更后旧缓存自然 miss。
|
||||||
|
"""
|
||||||
|
|
||||||
|
_memory_cache: ClassVar[TTLCache[str, str]] = TTLCache(maxsize=10000, ttl=_AUTH_TTL)
|
||||||
|
"""内存缓存降级方案"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _build_key(cls, email: str, account_name: str, password: str) -> str:
|
||||||
|
"""构建缓存键"""
|
||||||
|
pwd_hash = hashlib.sha256(password.encode()).hexdigest()[:16]
|
||||||
|
return f"webdav_auth:{email}/{account_name}:{pwd_hash}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get(
|
||||||
|
cls,
|
||||||
|
email: str,
|
||||||
|
account_name: str,
|
||||||
|
password: str,
|
||||||
|
) -> tuple[UUID, int] | None:
|
||||||
|
"""
|
||||||
|
查询缓存中的认证结果。
|
||||||
|
|
||||||
|
:param email: 用户邮箱
|
||||||
|
:param account_name: WebDAV 账户名
|
||||||
|
:param password: 用户提供的明文密码
|
||||||
|
:return: (user_id, webdav_id) 或 None(缓存未命中)
|
||||||
|
"""
|
||||||
|
key = cls._build_key(email, account_name, password)
|
||||||
|
|
||||||
|
client = RedisManager.get_client()
|
||||||
|
if client is not None:
|
||||||
|
value = await client.get(key)
|
||||||
|
if value is not None:
|
||||||
|
raw = value.decode() if isinstance(value, bytes) else value
|
||||||
|
user_id_str, webdav_id_str = raw.split(":", 1)
|
||||||
|
return UUID(user_id_str), int(webdav_id_str)
|
||||||
|
else:
|
||||||
|
raw = cls._memory_cache.get(key)
|
||||||
|
if raw is not None:
|
||||||
|
user_id_str, webdav_id_str = raw.split(":", 1)
|
||||||
|
return UUID(user_id_str), int(webdav_id_str)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def set(
|
||||||
|
cls,
|
||||||
|
email: str,
|
||||||
|
account_name: str,
|
||||||
|
password: str,
|
||||||
|
user_id: UUID,
|
||||||
|
webdav_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
写入认证结果到缓存。
|
||||||
|
|
||||||
|
:param email: 用户邮箱
|
||||||
|
:param account_name: WebDAV 账户名
|
||||||
|
:param password: 用户提供的明文密码
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:param webdav_id: WebDAV 账户ID
|
||||||
|
"""
|
||||||
|
key = cls._build_key(email, account_name, password)
|
||||||
|
value = f"{user_id}:{webdav_id}"
|
||||||
|
|
||||||
|
client = RedisManager.get_client()
|
||||||
|
if client is not None:
|
||||||
|
await client.set(key, value, ex=_AUTH_TTL)
|
||||||
|
else:
|
||||||
|
cls._memory_cache[key] = value
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def invalidate_account(cls, user_id: UUID, account_name: str) -> None:
|
||||||
|
"""
|
||||||
|
失效指定账户的所有缓存。
|
||||||
|
|
||||||
|
由于缓存键包含 password hash,无法精确删除,
|
||||||
|
Redis 端使用 pattern scan 删除,内存端清空全部。
|
||||||
|
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:param account_name: WebDAV 账户名
|
||||||
|
"""
|
||||||
|
client = RedisManager.get_client()
|
||||||
|
if client is not None:
|
||||||
|
pattern = f"webdav_auth:*/{account_name}:*"
|
||||||
|
cursor: int = 0
|
||||||
|
while True:
|
||||||
|
cursor, keys = await client.scan(cursor, match=pattern, count=100)
|
||||||
|
if keys:
|
||||||
|
await client.delete(*keys)
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 内存缓存无法按 pattern 删除,清除所有含该账户名的条目
|
||||||
|
keys_to_delete = [
|
||||||
|
k for k in cls._memory_cache
|
||||||
|
if f"/{account_name}:" in k
|
||||||
|
]
|
||||||
|
for k in keys_to_delete:
|
||||||
|
cls._memory_cache.pop(k, None)
|
||||||
|
|
||||||
|
l.debug(f"已清除 WebDAV 认证缓存: user={user_id}, account={account_name}")
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
提供文件存储相关的服务,包括:
|
提供文件存储相关的服务,包括:
|
||||||
- 本地存储服务
|
- 本地存储服务
|
||||||
|
- S3 存储服务
|
||||||
- 命名规则解析器
|
- 命名规则解析器
|
||||||
- 存储异常定义
|
- 存储异常定义
|
||||||
"""
|
"""
|
||||||
@@ -11,6 +12,8 @@ from .exceptions import (
|
|||||||
FileReadError,
|
FileReadError,
|
||||||
FileWriteError,
|
FileWriteError,
|
||||||
InvalidPathError,
|
InvalidPathError,
|
||||||
|
S3APIError,
|
||||||
|
S3MultipartUploadError,
|
||||||
StorageException,
|
StorageException,
|
||||||
StorageFileNotFoundError,
|
StorageFileNotFoundError,
|
||||||
UploadSessionExpiredError,
|
UploadSessionExpiredError,
|
||||||
@@ -18,3 +21,13 @@ from .exceptions import (
|
|||||||
)
|
)
|
||||||
from .local_storage import LocalStorageService
|
from .local_storage import LocalStorageService
|
||||||
from .naming_rule import NamingContext, NamingRuleParser
|
from .naming_rule import NamingContext, NamingRuleParser
|
||||||
|
from .object import (
|
||||||
|
adjust_user_storage,
|
||||||
|
copy_object_recursive,
|
||||||
|
delete_object_recursive,
|
||||||
|
permanently_delete_objects,
|
||||||
|
restore_objects,
|
||||||
|
soft_delete_objects,
|
||||||
|
)
|
||||||
|
from .migrate import migrate_file_with_task, migrate_directory_files
|
||||||
|
from .s3_storage import S3StorageService
|
||||||
@@ -43,3 +43,13 @@ class UploadSessionExpiredError(StorageException):
|
|||||||
class InvalidPathError(StorageException):
|
class InvalidPathError(StorageException):
|
||||||
"""无效的路径"""
|
"""无效的路径"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class S3APIError(StorageException):
|
||||||
|
"""S3 API 请求错误"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class S3MultipartUploadError(S3APIError):
|
||||||
|
"""S3 分片上传错误"""
|
||||||
|
pass
|
||||||
|
|||||||
@@ -263,15 +263,49 @@ class LocalStorageService:
|
|||||||
"""
|
"""
|
||||||
删除文件(物理删除)
|
删除文件(物理删除)
|
||||||
|
|
||||||
|
删除文件后会尝试清理因此变空的父目录。
|
||||||
|
|
||||||
:param path: 完整文件路径
|
:param path: 完整文件路径
|
||||||
"""
|
"""
|
||||||
if await self.file_exists(path):
|
if await self.file_exists(path):
|
||||||
try:
|
try:
|
||||||
await aiofiles.os.remove(path)
|
await aiofiles.os.remove(path)
|
||||||
l.debug(f"已删除文件: {path}")
|
l.debug(f"已删除文件: {path}")
|
||||||
|
await self._cleanup_empty_parents(path)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
l.warning(f"删除文件失败 {path}: {e}")
|
l.warning(f"删除文件失败 {path}: {e}")
|
||||||
|
|
||||||
|
async def _cleanup_empty_parents(self, file_path: str) -> None:
|
||||||
|
"""
|
||||||
|
从被删文件的父目录开始,向上逐级删除空目录
|
||||||
|
|
||||||
|
在以下情况停止:
|
||||||
|
|
||||||
|
- 到达存储根目录(_base_path)
|
||||||
|
- 遇到非空目录
|
||||||
|
- 遇到 .trash 目录
|
||||||
|
- 删除失败(权限、并发等)
|
||||||
|
|
||||||
|
:param file_path: 被删文件的完整路径
|
||||||
|
"""
|
||||||
|
current = Path(file_path).parent
|
||||||
|
|
||||||
|
while current != self._base_path and str(current).startswith(str(self._base_path)):
|
||||||
|
if current.name == '.trash':
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
entries = await aiofiles.os.listdir(str(current))
|
||||||
|
if entries:
|
||||||
|
break
|
||||||
|
|
||||||
|
await aiofiles.os.rmdir(str(current))
|
||||||
|
l.debug(f"已清理空目录: {current}")
|
||||||
|
current = current.parent
|
||||||
|
except OSError as e:
|
||||||
|
l.debug(f"清理空目录失败(忽略): {current}: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
async def move_to_trash(
|
async def move_to_trash(
|
||||||
self,
|
self,
|
||||||
source_path: str,
|
source_path: str,
|
||||||
@@ -304,6 +338,7 @@ class LocalStorageService:
|
|||||||
try:
|
try:
|
||||||
await aiofiles.os.rename(source_path, str(trash_path))
|
await aiofiles.os.rename(source_path, str(trash_path))
|
||||||
l.info(f"文件已移动到回收站: {source_path} -> {trash_path}")
|
l.info(f"文件已移动到回收站: {source_path} -> {trash_path}")
|
||||||
|
await self._cleanup_empty_parents(source_path)
|
||||||
return str(trash_path)
|
return str(trash_path)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
raise StorageException(f"移动文件到回收站失败: {e}")
|
raise StorageException(f"移动文件到回收站失败: {e}")
|
||||||
|
|||||||
291
service/storage/migrate.py
Normal file
291
service/storage/migrate.py
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
"""
|
||||||
|
存储策略迁移服务
|
||||||
|
|
||||||
|
提供跨存储策略的文件迁移功能:
|
||||||
|
- 单文件迁移:从源策略下载 → 上传到目标策略 → 更新数据库记录
|
||||||
|
- 目录批量迁移:递归遍历目录下所有文件逐个迁移,同时更新子目录的 policy_id
|
||||||
|
"""
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from loguru import logger as l
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from sqlmodels.object import Object, ObjectType
|
||||||
|
from sqlmodels.physical_file import PhysicalFile
|
||||||
|
from sqlmodels.policy import Policy, PolicyType
|
||||||
|
from sqlmodels.task import Task, TaskStatus
|
||||||
|
|
||||||
|
from .local_storage import LocalStorageService
|
||||||
|
from .s3_storage import S3StorageService
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_storage_service(
|
||||||
|
policy: Policy,
|
||||||
|
) -> LocalStorageService | S3StorageService:
|
||||||
|
"""
|
||||||
|
根据策略类型创建对应的存储服务实例
|
||||||
|
|
||||||
|
:param policy: 存储策略
|
||||||
|
:return: 存储服务实例
|
||||||
|
"""
|
||||||
|
if policy.type == PolicyType.LOCAL:
|
||||||
|
return LocalStorageService(policy)
|
||||||
|
elif policy.type == PolicyType.S3:
|
||||||
|
return await S3StorageService.from_policy(policy)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的存储策略类型: {policy.type}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _read_file_from_storage(
|
||||||
|
service: LocalStorageService | S3StorageService,
|
||||||
|
storage_path: str,
|
||||||
|
) -> bytes:
|
||||||
|
"""
|
||||||
|
从存储服务读取文件内容
|
||||||
|
|
||||||
|
:param service: 存储服务实例
|
||||||
|
:param storage_path: 文件存储路径
|
||||||
|
:return: 文件二进制内容
|
||||||
|
"""
|
||||||
|
if isinstance(service, LocalStorageService):
|
||||||
|
return await service.read_file(storage_path)
|
||||||
|
else:
|
||||||
|
return await service.download_file(storage_path)
|
||||||
|
|
||||||
|
|
||||||
|
async def _write_file_to_storage(
|
||||||
|
service: LocalStorageService | S3StorageService,
|
||||||
|
storage_path: str,
|
||||||
|
data: bytes,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
将文件内容写入存储服务
|
||||||
|
|
||||||
|
:param service: 存储服务实例
|
||||||
|
:param storage_path: 文件存储路径
|
||||||
|
:param data: 文件二进制内容
|
||||||
|
"""
|
||||||
|
if isinstance(service, LocalStorageService):
|
||||||
|
await service.write_file(storage_path, data)
|
||||||
|
else:
|
||||||
|
await service.upload_file(storage_path, data)
|
||||||
|
|
||||||
|
|
||||||
|
async def _delete_file_from_storage(
|
||||||
|
service: LocalStorageService | S3StorageService,
|
||||||
|
storage_path: str,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
从存储服务删除文件
|
||||||
|
|
||||||
|
:param service: 存储服务实例
|
||||||
|
:param storage_path: 文件存储路径
|
||||||
|
"""
|
||||||
|
if isinstance(service, LocalStorageService):
|
||||||
|
await service.delete_file(storage_path)
|
||||||
|
else:
|
||||||
|
await service.delete_file(storage_path)
|
||||||
|
|
||||||
|
|
||||||
|
async def migrate_single_file(
|
||||||
|
session: AsyncSession,
|
||||||
|
obj: Object,
|
||||||
|
dest_policy: Policy,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
将单个文件对象从当前存储策略迁移到目标策略
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 获取源物理文件和存储服务
|
||||||
|
2. 读取源文件内容
|
||||||
|
3. 在目标存储中生成新路径并写入
|
||||||
|
4. 创建新的 PhysicalFile 记录
|
||||||
|
5. 更新 Object 的 policy_id 和 physical_file_id
|
||||||
|
6. 旧 PhysicalFile 引用计数 -1,如为 0 则删除源物理文件
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param obj: 待迁移的文件对象(必须为文件类型)
|
||||||
|
:param dest_policy: 目标存储策略
|
||||||
|
"""
|
||||||
|
if obj.type != ObjectType.FILE:
|
||||||
|
raise ValueError(f"只能迁移文件对象,当前类型: {obj.type}")
|
||||||
|
|
||||||
|
# 获取源策略和物理文件
|
||||||
|
src_policy: Policy = await obj.awaitable_attrs.policy
|
||||||
|
old_physical: PhysicalFile | None = await obj.awaitable_attrs.physical_file
|
||||||
|
|
||||||
|
if not old_physical:
|
||||||
|
l.warning(f"文件 {obj.id} 没有关联物理文件,跳过迁移")
|
||||||
|
return
|
||||||
|
|
||||||
|
if src_policy.id == dest_policy.id:
|
||||||
|
l.debug(f"文件 {obj.id} 已在目标策略中,跳过")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 1. 从源存储读取文件
|
||||||
|
src_service = await _get_storage_service(src_policy)
|
||||||
|
data = await _read_file_from_storage(src_service, old_physical.storage_path)
|
||||||
|
|
||||||
|
# 2. 在目标存储生成新路径并写入
|
||||||
|
dest_service = await _get_storage_service(dest_policy)
|
||||||
|
_dir_path, _storage_name, new_storage_path = await dest_service.generate_file_path(
|
||||||
|
user_id=obj.owner_id,
|
||||||
|
original_filename=obj.name,
|
||||||
|
)
|
||||||
|
await _write_file_to_storage(dest_service, new_storage_path, data)
|
||||||
|
|
||||||
|
# 3. 创建新的 PhysicalFile
|
||||||
|
new_physical = PhysicalFile(
|
||||||
|
storage_path=new_storage_path,
|
||||||
|
size=old_physical.size,
|
||||||
|
checksum_md5=old_physical.checksum_md5,
|
||||||
|
policy_id=dest_policy.id,
|
||||||
|
reference_count=1,
|
||||||
|
)
|
||||||
|
new_physical = await new_physical.save(session)
|
||||||
|
|
||||||
|
# 4. 更新 Object
|
||||||
|
obj.policy_id = dest_policy.id
|
||||||
|
obj.physical_file_id = new_physical.id
|
||||||
|
obj = await obj.save(session)
|
||||||
|
|
||||||
|
# 5. 旧 PhysicalFile 引用计数 -1
|
||||||
|
old_physical.decrement_reference()
|
||||||
|
if old_physical.can_be_deleted:
|
||||||
|
# 删除源存储中的物理文件
|
||||||
|
try:
|
||||||
|
await _delete_file_from_storage(src_service, old_physical.storage_path)
|
||||||
|
except Exception as e:
|
||||||
|
l.warning(f"删除源文件失败(不影响迁移结果): {old_physical.storage_path}: {e}")
|
||||||
|
await PhysicalFile.delete(session, old_physical)
|
||||||
|
else:
|
||||||
|
old_physical = await old_physical.save(session)
|
||||||
|
|
||||||
|
l.info(f"文件迁移完成: {obj.name} ({obj.id}), {src_policy.name} → {dest_policy.name}")
|
||||||
|
|
||||||
|
|
||||||
|
async def migrate_file_with_task(
|
||||||
|
session: AsyncSession,
|
||||||
|
obj: Object,
|
||||||
|
dest_policy: Policy,
|
||||||
|
task: Task,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
迁移单个文件并更新任务状态
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param obj: 待迁移的文件对象
|
||||||
|
:param dest_policy: 目标存储策略
|
||||||
|
:param task: 关联的任务记录
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
task.status = TaskStatus.RUNNING
|
||||||
|
task.progress = 0
|
||||||
|
task = await task.save(session)
|
||||||
|
|
||||||
|
await migrate_single_file(session, obj, dest_policy)
|
||||||
|
|
||||||
|
task.status = TaskStatus.COMPLETED
|
||||||
|
task.progress = 100
|
||||||
|
task = await task.save(session)
|
||||||
|
except Exception as e:
|
||||||
|
l.error(f"文件迁移任务失败: {obj.id}: {e}")
|
||||||
|
task.status = TaskStatus.ERROR
|
||||||
|
task.error = str(e)[:500]
|
||||||
|
task = await task.save(session)
|
||||||
|
|
||||||
|
|
||||||
|
async def migrate_directory_files(
|
||||||
|
session: AsyncSession,
|
||||||
|
folder: Object,
|
||||||
|
dest_policy: Policy,
|
||||||
|
task: Task,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
迁移目录下所有文件到目标存储策略
|
||||||
|
|
||||||
|
递归遍历目录树,将所有文件迁移到目标策略。
|
||||||
|
子目录的 policy_id 同步更新。
|
||||||
|
任务进度按文件数比例更新。
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param folder: 目录对象
|
||||||
|
:param dest_policy: 目标存储策略
|
||||||
|
:param task: 关联的任务记录
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
task.status = TaskStatus.RUNNING
|
||||||
|
task.progress = 0
|
||||||
|
task = await task.save(session)
|
||||||
|
|
||||||
|
# 收集所有需要迁移的文件
|
||||||
|
files_to_migrate: list[Object] = []
|
||||||
|
folders_to_update: list[Object] = []
|
||||||
|
await _collect_objects_recursive(session, folder, files_to_migrate, folders_to_update)
|
||||||
|
|
||||||
|
total = len(files_to_migrate)
|
||||||
|
migrated = 0
|
||||||
|
errors: list[str] = []
|
||||||
|
|
||||||
|
for file_obj in files_to_migrate:
|
||||||
|
try:
|
||||||
|
await migrate_single_file(session, file_obj, dest_policy)
|
||||||
|
migrated += 1
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"{file_obj.name}: {e}"
|
||||||
|
l.error(f"迁移文件失败: {error_msg}")
|
||||||
|
errors.append(error_msg)
|
||||||
|
|
||||||
|
# 更新进度
|
||||||
|
if total > 0:
|
||||||
|
task.progress = min(99, int(migrated / total * 100))
|
||||||
|
task = await task.save(session)
|
||||||
|
|
||||||
|
# 更新所有子目录的 policy_id
|
||||||
|
for sub_folder in folders_to_update:
|
||||||
|
sub_folder.policy_id = dest_policy.id
|
||||||
|
sub_folder = await sub_folder.save(session)
|
||||||
|
|
||||||
|
# 完成任务
|
||||||
|
if errors:
|
||||||
|
task.status = TaskStatus.ERROR
|
||||||
|
task.error = f"部分文件迁移失败 ({len(errors)}/{total}): " + "; ".join(errors[:5])
|
||||||
|
else:
|
||||||
|
task.status = TaskStatus.COMPLETED
|
||||||
|
|
||||||
|
task.progress = 100
|
||||||
|
task = await task.save(session)
|
||||||
|
|
||||||
|
l.info(
|
||||||
|
f"目录迁移完成: {folder.name} ({folder.id}), "
|
||||||
|
f"成功 {migrated}/{total}, 错误 {len(errors)}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
l.error(f"目录迁移任务失败: {folder.id}: {e}")
|
||||||
|
task.status = TaskStatus.ERROR
|
||||||
|
task.error = str(e)[:500]
|
||||||
|
task = await task.save(session)
|
||||||
|
|
||||||
|
|
||||||
|
async def _collect_objects_recursive(
|
||||||
|
session: AsyncSession,
|
||||||
|
folder: Object,
|
||||||
|
files: list[Object],
|
||||||
|
folders: list[Object],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
递归收集目录下所有文件和子目录
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param folder: 当前目录
|
||||||
|
:param files: 文件列表(输出)
|
||||||
|
:param folders: 子目录列表(输出)
|
||||||
|
"""
|
||||||
|
children: list[Object] = await Object.get_children(session, folder.owner_id, folder.id)
|
||||||
|
|
||||||
|
for child in children:
|
||||||
|
if child.type == ObjectType.FILE:
|
||||||
|
files.append(child)
|
||||||
|
elif child.type == ObjectType.FOLDER:
|
||||||
|
folders.append(child)
|
||||||
|
await _collect_objects_recursive(session, child, files, folders)
|
||||||
@@ -23,7 +23,7 @@ import string
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from sqlmodels.base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase
|
||||||
|
|
||||||
|
|
||||||
class NamingContext(SQLModelBase):
|
class NamingContext(SQLModelBase):
|
||||||
|
|||||||
505
service/storage/object.py
Normal file
505
service/storage/object.py
Normal file
@@ -0,0 +1,505 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from loguru import logger as l
|
||||||
|
from sqlalchemy import update as sql_update
|
||||||
|
from sqlalchemy.sql.functions import func
|
||||||
|
from middleware.dependencies import SessionDep
|
||||||
|
|
||||||
|
from .local_storage import LocalStorageService
|
||||||
|
from .s3_storage import S3StorageService
|
||||||
|
from sqlmodels import (
|
||||||
|
Object,
|
||||||
|
PhysicalFile,
|
||||||
|
Policy,
|
||||||
|
PolicyType,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def adjust_user_storage(
|
||||||
|
session: SessionDep,
|
||||||
|
user_id: UUID,
|
||||||
|
delta: int,
|
||||||
|
commit: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
原子更新用户已用存储空间
|
||||||
|
|
||||||
|
使用 SQL UPDATE SET storage = GREATEST(0, storage + delta) 避免竞态条件。
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:param delta: 变化量(正数增加,负数减少)
|
||||||
|
:param commit: 是否立即提交
|
||||||
|
"""
|
||||||
|
if delta == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
sql_update(User)
|
||||||
|
.where(User.id == user_id)
|
||||||
|
.values(storage=func.greatest(0, User.storage + delta))
|
||||||
|
)
|
||||||
|
await session.execute(stmt)
|
||||||
|
|
||||||
|
if commit:
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
l.debug(f"用户 {user_id} 存储配额变更: {'+' if delta > 0 else ''}{delta} bytes")
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 软删除 ====================
|
||||||
|
|
||||||
|
async def soft_delete_objects(
|
||||||
|
session: SessionDep,
|
||||||
|
objects: list[Object],
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
软删除对象列表
|
||||||
|
|
||||||
|
只标记顶层对象:设置 deleted_at、保存原 parent_id 到 deleted_original_parent_id、
|
||||||
|
将 parent_id 置 NULL 脱离文件树。子对象保持不变,物理文件不移动。
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param objects: 待软删除的对象列表
|
||||||
|
:return: 软删除的对象数量
|
||||||
|
"""
|
||||||
|
deleted_count = 0
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
for obj in objects:
|
||||||
|
obj.deleted_at = now
|
||||||
|
obj.deleted_original_parent_id = obj.parent_id
|
||||||
|
obj.parent_id = None
|
||||||
|
await obj.save(session, commit=False, refresh=False)
|
||||||
|
deleted_count += 1
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 恢复 ====================
|
||||||
|
|
||||||
|
async def _resolve_name_conflict(
|
||||||
|
session: SessionDep,
|
||||||
|
user_id: UUID,
|
||||||
|
parent_id: UUID,
|
||||||
|
name: str,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
解决同名冲突,返回不冲突的名称
|
||||||
|
|
||||||
|
命名规则:原名称 → 原名称 (1) → 原名称 (2) → ...
|
||||||
|
对于有扩展名的文件:name.ext → name (1).ext → name (2).ext → ...
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:param parent_id: 父目录UUID
|
||||||
|
:param name: 原始名称
|
||||||
|
:return: 不冲突的名称
|
||||||
|
"""
|
||||||
|
existing = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.owner_id == user_id) &
|
||||||
|
(Object.parent_id == parent_id) &
|
||||||
|
(Object.name == name) &
|
||||||
|
(Object.deleted_at == None)
|
||||||
|
)
|
||||||
|
if not existing:
|
||||||
|
return name
|
||||||
|
|
||||||
|
# 分离文件名和扩展名
|
||||||
|
if '.' in name:
|
||||||
|
base, ext = name.rsplit('.', 1)
|
||||||
|
ext = f".{ext}"
|
||||||
|
else:
|
||||||
|
base = name
|
||||||
|
ext = ""
|
||||||
|
|
||||||
|
counter = 1
|
||||||
|
while True:
|
||||||
|
new_name = f"{base} ({counter}){ext}"
|
||||||
|
existing = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.owner_id == user_id) &
|
||||||
|
(Object.parent_id == parent_id) &
|
||||||
|
(Object.name == new_name) &
|
||||||
|
(Object.deleted_at == None)
|
||||||
|
)
|
||||||
|
if not existing:
|
||||||
|
return new_name
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
|
||||||
|
async def restore_objects(
|
||||||
|
session: SessionDep,
|
||||||
|
objects: list[Object],
|
||||||
|
user_id: UUID,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
从回收站恢复对象
|
||||||
|
|
||||||
|
检查原父目录是否存在且未删除:
|
||||||
|
- 存在 → 恢复到原位置
|
||||||
|
- 不存在 → 恢复到用户根目录
|
||||||
|
处理同名冲突(自动重命名)。
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param objects: 待恢复的对象列表(必须是回收站中的顶层对象)
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:return: 恢复的对象数量
|
||||||
|
"""
|
||||||
|
root = await Object.get_root(session, user_id)
|
||||||
|
if not root:
|
||||||
|
raise ValueError("用户根目录不存在")
|
||||||
|
|
||||||
|
restored_count = 0
|
||||||
|
|
||||||
|
for obj in objects:
|
||||||
|
if not obj.deleted_at:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 确定恢复目标目录
|
||||||
|
target_parent_id = root.id
|
||||||
|
if obj.deleted_original_parent_id:
|
||||||
|
original_parent = await Object.get(
|
||||||
|
session,
|
||||||
|
(Object.id == obj.deleted_original_parent_id) & (Object.deleted_at == None)
|
||||||
|
)
|
||||||
|
if original_parent:
|
||||||
|
target_parent_id = original_parent.id
|
||||||
|
|
||||||
|
# 解决同名冲突
|
||||||
|
resolved_name = await _resolve_name_conflict(
|
||||||
|
session, user_id, target_parent_id, obj.name
|
||||||
|
)
|
||||||
|
|
||||||
|
# 恢复对象
|
||||||
|
obj.parent_id = target_parent_id
|
||||||
|
obj.deleted_at = None
|
||||||
|
obj.deleted_original_parent_id = None
|
||||||
|
if resolved_name != obj.name:
|
||||||
|
obj.name = resolved_name
|
||||||
|
await obj.save(session, commit=False, refresh=False)
|
||||||
|
restored_count += 1
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
return restored_count
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 永久删除 ====================
|
||||||
|
|
||||||
|
async def _collect_file_entries_all(
|
||||||
|
session: SessionDep,
|
||||||
|
user_id: UUID,
|
||||||
|
root: Object,
|
||||||
|
) -> tuple[list[tuple[UUID, str, UUID]], int, int]:
|
||||||
|
"""
|
||||||
|
BFS 收集子树中所有文件的物理文件信息(包含已删除和未删除的子对象)
|
||||||
|
|
||||||
|
只执行 SELECT 查询,不触发 commit,ORM 对象始终有效。
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:param root: 根对象
|
||||||
|
:return: (文件条目列表[(obj_id, name, physical_file_id)], 总对象数, 总文件大小)
|
||||||
|
"""
|
||||||
|
file_entries: list[tuple[UUID, str, UUID]] = []
|
||||||
|
total_count = 1
|
||||||
|
total_file_size = 0
|
||||||
|
|
||||||
|
# 根对象本身是文件
|
||||||
|
if root.is_file and root.physical_file_id:
|
||||||
|
file_entries.append((root.id, root.name, root.physical_file_id))
|
||||||
|
total_file_size += root.size
|
||||||
|
|
||||||
|
# BFS 遍历子目录(使用 get_all_children 包含所有子对象)
|
||||||
|
if root.is_folder:
|
||||||
|
queue: list[UUID] = [root.id]
|
||||||
|
while queue:
|
||||||
|
parent_id = queue.pop(0)
|
||||||
|
children = await Object.get_all_children(session, user_id, parent_id)
|
||||||
|
for child in children:
|
||||||
|
total_count += 1
|
||||||
|
if child.is_file and child.physical_file_id:
|
||||||
|
file_entries.append((child.id, child.name, child.physical_file_id))
|
||||||
|
total_file_size += child.size
|
||||||
|
elif child.is_folder:
|
||||||
|
queue.append(child.id)
|
||||||
|
|
||||||
|
return file_entries, total_count, total_file_size
|
||||||
|
|
||||||
|
|
||||||
|
async def permanently_delete_objects(
|
||||||
|
session: SessionDep,
|
||||||
|
objects: list[Object],
|
||||||
|
user_id: UUID,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
永久删除回收站中的对象
|
||||||
|
|
||||||
|
验证对象在回收站中(deleted_at IS NOT NULL),
|
||||||
|
BFS 收集所有子文件的 PhysicalFile 信息,
|
||||||
|
处理引用计数,引用为 0 时物理删除文件,
|
||||||
|
最后硬删除根 Object(CASCADE 自动清理子对象)。
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param objects: 待永久删除的对象列表
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:return: 永久删除的对象数量
|
||||||
|
"""
|
||||||
|
total_deleted = 0
|
||||||
|
|
||||||
|
for obj in objects:
|
||||||
|
if not obj.deleted_at:
|
||||||
|
l.warning(f"对象 {obj.id} 不在回收站中,跳过永久删除")
|
||||||
|
continue
|
||||||
|
|
||||||
|
root_id = obj.id
|
||||||
|
file_entries, obj_count, total_file_size = await _collect_file_entries_all(
|
||||||
|
session, user_id, obj
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理 PhysicalFile 引用计数
|
||||||
|
for obj_id, obj_name, physical_file_id in file_entries:
|
||||||
|
physical_file = await PhysicalFile.get(session, PhysicalFile.id == physical_file_id)
|
||||||
|
if not physical_file:
|
||||||
|
continue
|
||||||
|
|
||||||
|
physical_file.decrement_reference()
|
||||||
|
|
||||||
|
if physical_file.can_be_deleted:
|
||||||
|
# 物理删除文件
|
||||||
|
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
|
||||||
|
if policy:
|
||||||
|
try:
|
||||||
|
if policy.type == PolicyType.LOCAL:
|
||||||
|
storage_service = LocalStorageService(policy)
|
||||||
|
await storage_service.delete_file(physical_file.storage_path)
|
||||||
|
elif policy.type == PolicyType.S3:
|
||||||
|
s3_service = await S3StorageService.from_policy(policy)
|
||||||
|
await s3_service.delete_file(physical_file.storage_path)
|
||||||
|
l.debug(f"物理文件已删除: {obj_name}")
|
||||||
|
except Exception as e:
|
||||||
|
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
|
||||||
|
|
||||||
|
await PhysicalFile.delete(session, physical_file, commit=False)
|
||||||
|
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
|
||||||
|
else:
|
||||||
|
physical_file = await physical_file.save(session, commit=False)
|
||||||
|
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
|
||||||
|
|
||||||
|
# 更新用户存储配额
|
||||||
|
if total_file_size > 0:
|
||||||
|
await adjust_user_storage(session, user_id, -total_file_size, commit=False)
|
||||||
|
|
||||||
|
# 硬删除根对象,CASCADE 自动删除所有子对象(不立即提交,避免其余对象过期)
|
||||||
|
await Object.delete(session, condition=Object.id == root_id, commit=False)
|
||||||
|
|
||||||
|
total_deleted += obj_count
|
||||||
|
|
||||||
|
# 统一提交所有变更
|
||||||
|
await session.commit()
|
||||||
|
return total_deleted
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 旧接口(保持向后兼容) ====================
|
||||||
|
|
||||||
|
async def _collect_file_entries(
|
||||||
|
session: SessionDep,
|
||||||
|
user_id: UUID,
|
||||||
|
root: Object,
|
||||||
|
) -> tuple[list[tuple[UUID, str, UUID]], int, int]:
|
||||||
|
"""
|
||||||
|
BFS 收集子树中所有文件的物理文件信息
|
||||||
|
|
||||||
|
只执行 SELECT 查询,不触发 commit,ORM 对象始终有效。
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:param root: 根对象
|
||||||
|
:return: (文件条目列表[(obj_id, name, physical_file_id)], 总对象数, 总文件大小)
|
||||||
|
"""
|
||||||
|
file_entries: list[tuple[UUID, str, UUID]] = []
|
||||||
|
total_count = 1
|
||||||
|
total_file_size = 0
|
||||||
|
|
||||||
|
# 根对象本身是文件
|
||||||
|
if root.is_file and root.physical_file_id:
|
||||||
|
file_entries.append((root.id, root.name, root.physical_file_id))
|
||||||
|
total_file_size += root.size
|
||||||
|
|
||||||
|
# BFS 遍历子目录
|
||||||
|
if root.is_folder:
|
||||||
|
queue: list[UUID] = [root.id]
|
||||||
|
while queue:
|
||||||
|
parent_id = queue.pop(0)
|
||||||
|
children = await Object.get_children(session, user_id, parent_id)
|
||||||
|
for child in children:
|
||||||
|
total_count += 1
|
||||||
|
if child.is_file and child.physical_file_id:
|
||||||
|
file_entries.append((child.id, child.name, child.physical_file_id))
|
||||||
|
total_file_size += child.size
|
||||||
|
elif child.is_folder:
|
||||||
|
queue.append(child.id)
|
||||||
|
|
||||||
|
return file_entries, total_count, total_file_size
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_object_recursive(
|
||||||
|
session: SessionDep,
|
||||||
|
obj: Object,
|
||||||
|
user_id: UUID,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
删除对象及其所有子对象(硬删除)
|
||||||
|
|
||||||
|
两阶段策略:
|
||||||
|
1. BFS 只读收集所有文件的 PhysicalFile 信息
|
||||||
|
2. 批量处理引用计数(commit=False),最后删除根对象触发 CASCADE
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param obj: 要删除的对象
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:return: 删除的对象数量
|
||||||
|
"""
|
||||||
|
# 阶段一:只读收集(不触发任何 commit)
|
||||||
|
root_id = obj.id
|
||||||
|
file_entries, total_count, total_file_size = await _collect_file_entries(session, user_id, obj)
|
||||||
|
|
||||||
|
# 阶段二:批量处理 PhysicalFile 引用(全部 commit=False)
|
||||||
|
for obj_id, obj_name, physical_file_id in file_entries:
|
||||||
|
physical_file = await PhysicalFile.get(session, PhysicalFile.id == physical_file_id)
|
||||||
|
if not physical_file:
|
||||||
|
continue
|
||||||
|
|
||||||
|
physical_file.decrement_reference()
|
||||||
|
|
||||||
|
if physical_file.can_be_deleted:
|
||||||
|
# 物理删除文件
|
||||||
|
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
|
||||||
|
if policy:
|
||||||
|
try:
|
||||||
|
if policy.type == PolicyType.LOCAL:
|
||||||
|
storage_service = LocalStorageService(policy)
|
||||||
|
await storage_service.delete_file(physical_file.storage_path)
|
||||||
|
elif policy.type == PolicyType.S3:
|
||||||
|
options = await policy.awaitable_attrs.options
|
||||||
|
s3_service = S3StorageService(
|
||||||
|
policy,
|
||||||
|
region=options.s3_region if options else 'us-east-1',
|
||||||
|
is_path_style=options.s3_path_style if options else False,
|
||||||
|
)
|
||||||
|
await s3_service.delete_file(physical_file.storage_path)
|
||||||
|
l.debug(f"物理文件已删除: {obj_name}")
|
||||||
|
except Exception as e:
|
||||||
|
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
|
||||||
|
|
||||||
|
await PhysicalFile.delete(session, physical_file, commit=False)
|
||||||
|
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
|
||||||
|
else:
|
||||||
|
physical_file = await physical_file.save(session, commit=False)
|
||||||
|
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
|
||||||
|
|
||||||
|
# 阶段三:更新用户存储配额(与删除在同一事务中)
|
||||||
|
if total_file_size > 0:
|
||||||
|
await adjust_user_storage(session, user_id, -total_file_size, commit=False)
|
||||||
|
|
||||||
|
# 阶段四:删除根对象,数据库 CASCADE 自动删除所有子对象
|
||||||
|
# commit=True(默认),一次性提交所有 PhysicalFile 变更 + Object 删除 + 配额更新
|
||||||
|
await Object.delete(session, condition=Object.id == root_id)
|
||||||
|
|
||||||
|
return total_count
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 复制 ====================
|
||||||
|
|
||||||
|
async def _copy_object_recursive(
|
||||||
|
session: SessionDep,
|
||||||
|
src: Object,
|
||||||
|
dst_parent_id: UUID,
|
||||||
|
user_id: UUID,
|
||||||
|
) -> tuple[int, list[UUID], int]:
|
||||||
|
"""
|
||||||
|
递归复制对象(内部实现)
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param src: 源对象
|
||||||
|
:param dst_parent_id: 目标父目录UUID
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:return: (复制数量, 新对象UUID列表, 复制的总文件大小)
|
||||||
|
"""
|
||||||
|
copied_count = 0
|
||||||
|
new_ids: list[UUID] = []
|
||||||
|
total_copied_size = 0
|
||||||
|
|
||||||
|
# 在 save() 之前保存需要的属性值,避免 commit 后对象过期导致懒加载失败
|
||||||
|
src_is_folder = src.is_folder
|
||||||
|
src_is_file = src.is_file
|
||||||
|
src_id = src.id
|
||||||
|
src_size = src.size
|
||||||
|
src_physical_file_id = src.physical_file_id
|
||||||
|
|
||||||
|
# 创建新的 Object 记录
|
||||||
|
new_obj = Object(
|
||||||
|
name=src.name,
|
||||||
|
type=src.type,
|
||||||
|
size=src.size,
|
||||||
|
password=src.password,
|
||||||
|
parent_id=dst_parent_id,
|
||||||
|
owner_id=user_id,
|
||||||
|
policy_id=src.policy_id,
|
||||||
|
physical_file_id=src.physical_file_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果是文件,增加物理文件引用计数
|
||||||
|
if src_is_file and src_physical_file_id:
|
||||||
|
physical_file = await PhysicalFile.get(session, PhysicalFile.id == src_physical_file_id)
|
||||||
|
if physical_file:
|
||||||
|
physical_file.increment_reference()
|
||||||
|
physical_file = await physical_file.save(session)
|
||||||
|
total_copied_size += src_size
|
||||||
|
|
||||||
|
new_obj = await new_obj.save(session)
|
||||||
|
copied_count += 1
|
||||||
|
new_ids.append(new_obj.id)
|
||||||
|
|
||||||
|
# 如果是目录,递归复制子对象
|
||||||
|
if src_is_folder:
|
||||||
|
children = await Object.get_children(session, user_id, src_id)
|
||||||
|
for child in children:
|
||||||
|
child_count, child_ids, child_size = await _copy_object_recursive(
|
||||||
|
session, child, new_obj.id, user_id
|
||||||
|
)
|
||||||
|
copied_count += child_count
|
||||||
|
new_ids.extend(child_ids)
|
||||||
|
total_copied_size += child_size
|
||||||
|
|
||||||
|
return copied_count, new_ids, total_copied_size
|
||||||
|
|
||||||
|
|
||||||
|
async def copy_object_recursive(
|
||||||
|
session: SessionDep,
|
||||||
|
src: Object,
|
||||||
|
dst_parent_id: UUID,
|
||||||
|
user_id: UUID,
|
||||||
|
) -> tuple[int, list[UUID], int]:
|
||||||
|
"""
|
||||||
|
递归复制对象
|
||||||
|
|
||||||
|
对于文件:
|
||||||
|
- 增加 PhysicalFile 引用计数
|
||||||
|
- 创建新的 Object 记录指向同一 PhysicalFile
|
||||||
|
|
||||||
|
对于目录:
|
||||||
|
- 创建新目录
|
||||||
|
- 递归复制所有子对象
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param src: 源对象
|
||||||
|
:param dst_parent_id: 目标父目录UUID
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:return: (复制数量, 新对象UUID列表, 复制的总文件大小)
|
||||||
|
"""
|
||||||
|
return await _copy_object_recursive(session, src, dst_parent_id, user_id)
|
||||||
709
service/storage/s3_storage.py
Normal file
709
service/storage/s3_storage.py
Normal file
@@ -0,0 +1,709 @@
|
|||||||
|
"""
|
||||||
|
S3 存储服务
|
||||||
|
|
||||||
|
使用 AWS Signature V4 签名的异步 S3 API 客户端。
|
||||||
|
从 Policy 配置中读取 S3 连接信息,提供文件上传/下载/删除及分片上传功能。
|
||||||
|
|
||||||
|
移植自 foxline-pro-backend-server 项目的 S3APIClient,
|
||||||
|
适配 DiskNext 现有的 Service 架构(与 LocalStorageService 平行)。
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import ClassVar, Literal
|
||||||
|
from urllib.parse import quote, urlencode
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from yarl import URL
|
||||||
|
from loguru import logger as l
|
||||||
|
|
||||||
|
from sqlmodels.policy import Policy
|
||||||
|
from .exceptions import S3APIError, S3MultipartUploadError
|
||||||
|
from .naming_rule import NamingContext, NamingRuleParser
|
||||||
|
|
||||||
|
|
||||||
|
def _sign(key: bytes, msg: str) -> bytes:
|
||||||
|
"""HMAC-SHA256 签名"""
|
||||||
|
return hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||||
|
|
||||||
|
|
||||||
|
_NS_AWS = "http://s3.amazonaws.com/doc/2006-03-01/"
|
||||||
|
|
||||||
|
|
||||||
|
class S3StorageService:
|
||||||
|
"""
|
||||||
|
S3 存储服务
|
||||||
|
|
||||||
|
使用 AWS Signature V4 签名的异步 S3 API 客户端。
|
||||||
|
从 Policy 配置中读取 S3 连接信息。
|
||||||
|
|
||||||
|
使用示例::
|
||||||
|
|
||||||
|
service = S3StorageService(policy, region='us-east-1')
|
||||||
|
await service.upload_file('path/to/file.txt', b'content')
|
||||||
|
data = await service.download_file('path/to/file.txt')
|
||||||
|
"""
|
||||||
|
|
||||||
|
_http_session: ClassVar[aiohttp.ClientSession | None] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
policy: Policy,
|
||||||
|
region: str = 'us-east-1',
|
||||||
|
is_path_style: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
:param policy: 存储策略(server=endpoint_url, bucket_name, access_key, secret_key)
|
||||||
|
:param region: S3 区域
|
||||||
|
:param is_path_style: 是否使用路径风格 URL
|
||||||
|
"""
|
||||||
|
if not policy.server:
|
||||||
|
raise S3APIError("S3 策略必须指定 server (endpoint URL)")
|
||||||
|
if not policy.bucket_name:
|
||||||
|
raise S3APIError("S3 策略必须指定 bucket_name")
|
||||||
|
if not policy.access_key:
|
||||||
|
raise S3APIError("S3 策略必须指定 access_key")
|
||||||
|
if not policy.secret_key:
|
||||||
|
raise S3APIError("S3 策略必须指定 secret_key")
|
||||||
|
|
||||||
|
self._policy = policy
|
||||||
|
self._endpoint_url = policy.server.rstrip("/")
|
||||||
|
self._bucket_name = policy.bucket_name
|
||||||
|
self._access_key = policy.access_key
|
||||||
|
self._secret_key = policy.secret_key
|
||||||
|
self._region = region
|
||||||
|
self._is_path_style = is_path_style
|
||||||
|
self._base_url = policy.base_url
|
||||||
|
|
||||||
|
# 从 endpoint_url 提取 host
|
||||||
|
self._host = self._endpoint_url.replace("https://", "").replace("http://", "").split("/")[0]
|
||||||
|
|
||||||
|
# ==================== 工厂方法 ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def from_policy(cls, policy: Policy) -> 'S3StorageService':
|
||||||
|
"""
|
||||||
|
根据 Policy 异步创建 S3StorageService(自动加载 options)
|
||||||
|
|
||||||
|
:param policy: 存储策略
|
||||||
|
:return: S3StorageService 实例
|
||||||
|
"""
|
||||||
|
options = await policy.awaitable_attrs.options
|
||||||
|
region = options.s3_region if options else 'us-east-1'
|
||||||
|
is_path_style = options.s3_path_style if options else False
|
||||||
|
return cls(policy, region=region, is_path_style=is_path_style)
|
||||||
|
|
||||||
|
# ==================== HTTP Session 管理 ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def initialize_session(cls) -> None:
|
||||||
|
"""初始化全局 aiohttp ClientSession"""
|
||||||
|
if cls._http_session is None or cls._http_session.closed:
|
||||||
|
cls._http_session = aiohttp.ClientSession()
|
||||||
|
l.info("S3StorageService HTTP session 已初始化")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def close_session(cls) -> None:
|
||||||
|
"""关闭全局 aiohttp ClientSession"""
|
||||||
|
if cls._http_session and not cls._http_session.closed:
|
||||||
|
await cls._http_session.close()
|
||||||
|
cls._http_session = None
|
||||||
|
l.info("S3StorageService HTTP session 已关闭")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_session(cls) -> aiohttp.ClientSession:
|
||||||
|
"""获取 HTTP session"""
|
||||||
|
if cls._http_session is None or cls._http_session.closed:
|
||||||
|
# 懒初始化,以防 initialize_session 未被调用
|
||||||
|
cls._http_session = aiohttp.ClientSession()
|
||||||
|
return cls._http_session
|
||||||
|
|
||||||
|
# ==================== AWS Signature V4 签名 ====================
|
||||||
|
|
||||||
|
def _get_signature_key(self, date_stamp: str) -> bytes:
|
||||||
|
"""生成 AWS Signature V4 签名密钥"""
|
||||||
|
k_date = _sign(f"AWS4{self._secret_key}".encode(), date_stamp)
|
||||||
|
k_region = _sign(k_date, self._region)
|
||||||
|
k_service = _sign(k_region, "s3")
|
||||||
|
return _sign(k_service, "aws4_request")
|
||||||
|
|
||||||
|
def _create_authorization_header(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
uri: str,
|
||||||
|
query_string: str,
|
||||||
|
headers: dict[str, str],
|
||||||
|
payload_hash: str,
|
||||||
|
amz_date: str,
|
||||||
|
date_stamp: str,
|
||||||
|
) -> str:
|
||||||
|
"""创建 AWS Signature V4 授权头"""
|
||||||
|
signed_headers = ";".join(sorted(k.lower() for k in headers.keys()))
|
||||||
|
canonical_headers = "".join(
|
||||||
|
f"{k.lower()}:{v.strip()}\n" for k, v in sorted(headers.items())
|
||||||
|
)
|
||||||
|
canonical_request = (
|
||||||
|
f"{method}\n{uri}\n{query_string}\n{canonical_headers}\n"
|
||||||
|
f"{signed_headers}\n{payload_hash}"
|
||||||
|
)
|
||||||
|
|
||||||
|
algorithm = "AWS4-HMAC-SHA256"
|
||||||
|
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
|
||||||
|
string_to_sign = (
|
||||||
|
f"{algorithm}\n{amz_date}\n{credential_scope}\n"
|
||||||
|
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
signing_key = self._get_signature_key(date_stamp)
|
||||||
|
signature = hmac.new(
|
||||||
|
signing_key, string_to_sign.encode(), hashlib.sha256
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"{algorithm} Credential={self._access_key}/{credential_scope}, "
|
||||||
|
f"SignedHeaders={signed_headers}, Signature={signature}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_headers(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
uri: str,
|
||||||
|
query_string: str = "",
|
||||||
|
payload: bytes = b"",
|
||||||
|
content_type: str | None = None,
|
||||||
|
extra_headers: dict[str, str] | None = None,
|
||||||
|
payload_hash: str | None = None,
|
||||||
|
host: str | None = None,
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
构建包含 AWS V4 签名的完整请求头
|
||||||
|
|
||||||
|
:param method: HTTP 方法
|
||||||
|
:param uri: 请求 URI
|
||||||
|
:param query_string: 查询字符串
|
||||||
|
:param payload: 请求体字节(用于计算哈希)
|
||||||
|
:param content_type: Content-Type
|
||||||
|
:param extra_headers: 额外请求头
|
||||||
|
:param payload_hash: 预计算的 payload 哈希,流式上传时传 "UNSIGNED-PAYLOAD"
|
||||||
|
:param host: Host 头(默认使用 self._host)
|
||||||
|
"""
|
||||||
|
now_utc = datetime.now(timezone.utc)
|
||||||
|
amz_date = now_utc.strftime("%Y%m%dT%H%M%SZ")
|
||||||
|
date_stamp = now_utc.strftime("%Y%m%d")
|
||||||
|
|
||||||
|
if payload_hash is None:
|
||||||
|
payload_hash = hashlib.sha256(payload).hexdigest()
|
||||||
|
|
||||||
|
effective_host = host or self._host
|
||||||
|
|
||||||
|
headers: dict[str, str] = {
|
||||||
|
"Host": effective_host,
|
||||||
|
"X-Amz-Date": amz_date,
|
||||||
|
"X-Amz-Content-Sha256": payload_hash,
|
||||||
|
}
|
||||||
|
if content_type:
|
||||||
|
headers["Content-Type"] = content_type
|
||||||
|
if extra_headers:
|
||||||
|
headers.update(extra_headers)
|
||||||
|
|
||||||
|
authorization = self._create_authorization_header(
|
||||||
|
method, uri, query_string, headers, payload_hash, amz_date, date_stamp
|
||||||
|
)
|
||||||
|
headers["Authorization"] = authorization
|
||||||
|
return headers
|
||||||
|
|
||||||
|
# ==================== 内部请求方法 ====================
|
||||||
|
|
||||||
|
def _build_uri(self, key: str | None = None) -> str:
|
||||||
|
"""
|
||||||
|
构建请求 URI
|
||||||
|
|
||||||
|
按 AWS S3 Signature V4 规范对路径进行 URI 编码(S3 仅需一次)。
|
||||||
|
斜杠作为路径分隔符保留不编码。
|
||||||
|
"""
|
||||||
|
if self._is_path_style:
|
||||||
|
if key:
|
||||||
|
return f"/{self._bucket_name}/{quote(key, safe='/')}"
|
||||||
|
return f"/{self._bucket_name}"
|
||||||
|
else:
|
||||||
|
if key:
|
||||||
|
return f"/{quote(key, safe='/')}"
|
||||||
|
return "/"
|
||||||
|
|
||||||
|
def _build_url(self, uri: str, query_string: str = "") -> str:
|
||||||
|
"""构建完整请求 URL"""
|
||||||
|
if self._is_path_style:
|
||||||
|
base = self._endpoint_url
|
||||||
|
else:
|
||||||
|
# 虚拟主机风格:bucket.endpoint
|
||||||
|
protocol = "https://" if self._endpoint_url.startswith("https://") else "http://"
|
||||||
|
base = f"{protocol}{self._bucket_name}.{self._host}"
|
||||||
|
|
||||||
|
url = f"{base}{uri}"
|
||||||
|
if query_string:
|
||||||
|
url = f"{url}?{query_string}"
|
||||||
|
return url
|
||||||
|
|
||||||
|
def _get_effective_host(self) -> str:
|
||||||
|
"""获取实际请求的 Host 头"""
|
||||||
|
if self._is_path_style:
|
||||||
|
return self._host
|
||||||
|
return f"{self._bucket_name}.{self._host}"
|
||||||
|
|
||||||
|
async def _request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
key: str | None = None,
|
||||||
|
query_params: dict[str, str] | None = None,
|
||||||
|
payload: bytes = b"",
|
||||||
|
content_type: str | None = None,
|
||||||
|
extra_headers: dict[str, str] | None = None,
|
||||||
|
) -> aiohttp.ClientResponse:
|
||||||
|
"""发送签名请求"""
|
||||||
|
uri = self._build_uri(key)
|
||||||
|
query_string = urlencode(sorted(query_params.items())) if query_params else ""
|
||||||
|
effective_host = self._get_effective_host()
|
||||||
|
|
||||||
|
headers = self._build_headers(
|
||||||
|
method, uri, query_string, payload, content_type,
|
||||||
|
extra_headers, host=effective_host,
|
||||||
|
)
|
||||||
|
|
||||||
|
url = self._build_url(uri, query_string)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._get_session().request(
|
||||||
|
method, URL(url, encoded=True),
|
||||||
|
headers=headers, data=payload if payload else None,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
raise S3APIError(f"S3 请求失败: {method} {url}: {e}") from e
|
||||||
|
|
||||||
|
async def _request_streaming(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
key: str,
|
||||||
|
data_stream: AsyncIterator[bytes],
|
||||||
|
content_length: int,
|
||||||
|
content_type: str | None = None,
|
||||||
|
) -> aiohttp.ClientResponse:
|
||||||
|
"""
|
||||||
|
发送流式签名请求(大文件上传)
|
||||||
|
|
||||||
|
使用 UNSIGNED-PAYLOAD 作为 payload hash。
|
||||||
|
"""
|
||||||
|
uri = self._build_uri(key)
|
||||||
|
effective_host = self._get_effective_host()
|
||||||
|
|
||||||
|
headers = self._build_headers(
|
||||||
|
method,
|
||||||
|
uri,
|
||||||
|
query_string="",
|
||||||
|
content_type=content_type,
|
||||||
|
extra_headers={"Content-Length": str(content_length)},
|
||||||
|
payload_hash="UNSIGNED-PAYLOAD",
|
||||||
|
host=effective_host,
|
||||||
|
)
|
||||||
|
|
||||||
|
url = self._build_url(uri)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._get_session().request(
|
||||||
|
method, URL(url, encoded=True),
|
||||||
|
headers=headers, data=data_stream,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
raise S3APIError(f"S3 流式请求失败: {method} {url}: {e}") from e
|
||||||
|
|
||||||
|
# ==================== 文件操作 ====================
|
||||||
|
|
||||||
|
async def upload_file(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
data: bytes,
|
||||||
|
content_type: str = 'application/octet-stream',
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
上传文件
|
||||||
|
|
||||||
|
:param key: S3 对象键
|
||||||
|
:param data: 文件内容
|
||||||
|
:param content_type: MIME 类型
|
||||||
|
"""
|
||||||
|
async with await self._request(
|
||||||
|
"PUT", key=key, payload=data, content_type=content_type,
|
||||||
|
) as response:
|
||||||
|
if response.status not in (200, 201):
|
||||||
|
body = await response.text()
|
||||||
|
raise S3APIError(
|
||||||
|
f"上传失败: {self._bucket_name}/{key}, "
|
||||||
|
f"状态: {response.status}, {body}"
|
||||||
|
)
|
||||||
|
l.debug(f"S3 上传成功: {self._bucket_name}/{key}")
|
||||||
|
|
||||||
|
async def upload_file_streaming(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
data_stream: AsyncIterator[bytes],
|
||||||
|
content_length: int,
|
||||||
|
content_type: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
流式上传文件(大文件,避免全部加载到内存)
|
||||||
|
|
||||||
|
:param key: S3 对象键
|
||||||
|
:param data_stream: 异步字节流迭代器
|
||||||
|
:param content_length: 数据总长度(必须准确)
|
||||||
|
:param content_type: MIME 类型
|
||||||
|
"""
|
||||||
|
async with await self._request_streaming(
|
||||||
|
"PUT", key=key, data_stream=data_stream,
|
||||||
|
content_length=content_length, content_type=content_type,
|
||||||
|
) as response:
|
||||||
|
if response.status not in (200, 201):
|
||||||
|
body = await response.text()
|
||||||
|
raise S3APIError(
|
||||||
|
f"流式上传失败: {self._bucket_name}/{key}, "
|
||||||
|
f"状态: {response.status}, {body}"
|
||||||
|
)
|
||||||
|
l.debug(f"S3 流式上传成功: {self._bucket_name}/{key}, 大小: {content_length}")
|
||||||
|
|
||||||
|
async def download_file(self, key: str) -> bytes:
|
||||||
|
"""
|
||||||
|
下载文件
|
||||||
|
|
||||||
|
:param key: S3 对象键
|
||||||
|
:return: 文件内容
|
||||||
|
"""
|
||||||
|
async with await self._request("GET", key=key) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
body = await response.text()
|
||||||
|
raise S3APIError(
|
||||||
|
f"下载失败: {self._bucket_name}/{key}, "
|
||||||
|
f"状态: {response.status}, {body}"
|
||||||
|
)
|
||||||
|
data = await response.read()
|
||||||
|
l.debug(f"S3 下载成功: {self._bucket_name}/{key}, 大小: {len(data)}")
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def delete_file(self, key: str) -> None:
|
||||||
|
"""
|
||||||
|
删除文件
|
||||||
|
|
||||||
|
:param key: S3 对象键
|
||||||
|
"""
|
||||||
|
async with await self._request("DELETE", key=key) as response:
|
||||||
|
if response.status in (200, 204):
|
||||||
|
l.debug(f"S3 删除成功: {self._bucket_name}/{key}")
|
||||||
|
else:
|
||||||
|
body = await response.text()
|
||||||
|
raise S3APIError(
|
||||||
|
f"删除失败: {self._bucket_name}/{key}, "
|
||||||
|
f"状态: {response.status}, {body}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def file_exists(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查文件是否存在
|
||||||
|
|
||||||
|
:param key: S3 对象键
|
||||||
|
:return: 是否存在
|
||||||
|
"""
|
||||||
|
async with await self._request("HEAD", key=key) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return True
|
||||||
|
elif response.status == 404:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
raise S3APIError(
|
||||||
|
f"检查文件存在性失败: {self._bucket_name}/{key}, 状态: {response.status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_file_size(self, key: str) -> int:
|
||||||
|
"""
|
||||||
|
获取文件大小
|
||||||
|
|
||||||
|
:param key: S3 对象键
|
||||||
|
:return: 文件大小(字节)
|
||||||
|
"""
|
||||||
|
async with await self._request("HEAD", key=key) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise S3APIError(
|
||||||
|
f"获取文件信息失败: {self._bucket_name}/{key}, 状态: {response.status}"
|
||||||
|
)
|
||||||
|
return int(response.headers.get("Content-Length", 0))
|
||||||
|
|
||||||
|
# ==================== Multipart Upload ====================
|
||||||
|
|
||||||
|
async def create_multipart_upload(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
content_type: str = 'application/octet-stream',
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
创建分片上传任务
|
||||||
|
|
||||||
|
:param key: S3 对象键
|
||||||
|
:param content_type: MIME 类型
|
||||||
|
:return: Upload ID
|
||||||
|
"""
|
||||||
|
async with await self._request(
|
||||||
|
"POST",
|
||||||
|
key=key,
|
||||||
|
query_params={"uploads": ""},
|
||||||
|
content_type=content_type,
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
body = await response.text()
|
||||||
|
raise S3MultipartUploadError(
|
||||||
|
f"创建分片上传失败: {self._bucket_name}/{key}, "
|
||||||
|
f"状态: {response.status}, {body}"
|
||||||
|
)
|
||||||
|
|
||||||
|
body = await response.text()
|
||||||
|
root = ET.fromstring(body)
|
||||||
|
|
||||||
|
# 查找 UploadId 元素(支持命名空间)
|
||||||
|
upload_id_elem = root.find("UploadId")
|
||||||
|
if upload_id_elem is None:
|
||||||
|
upload_id_elem = root.find(f"{{{_NS_AWS}}}UploadId")
|
||||||
|
if upload_id_elem is None or not upload_id_elem.text:
|
||||||
|
raise S3MultipartUploadError(
|
||||||
|
f"创建分片上传响应中未找到 UploadId: {body}"
|
||||||
|
)
|
||||||
|
|
||||||
|
upload_id = upload_id_elem.text
|
||||||
|
l.debug(f"S3 分片上传已创建: {self._bucket_name}/{key}, upload_id={upload_id}")
|
||||||
|
return upload_id
|
||||||
|
|
||||||
|
async def upload_part(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
upload_id: str,
|
||||||
|
part_number: int,
|
||||||
|
data: bytes,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
上传单个分片
|
||||||
|
|
||||||
|
:param key: S3 对象键
|
||||||
|
:param upload_id: 分片上传 ID
|
||||||
|
:param part_number: 分片编号(从 1 开始)
|
||||||
|
:param data: 分片数据
|
||||||
|
:return: ETag
|
||||||
|
"""
|
||||||
|
async with await self._request(
|
||||||
|
"PUT",
|
||||||
|
key=key,
|
||||||
|
query_params={
|
||||||
|
"partNumber": str(part_number),
|
||||||
|
"uploadId": upload_id,
|
||||||
|
},
|
||||||
|
payload=data,
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
body = await response.text()
|
||||||
|
raise S3MultipartUploadError(
|
||||||
|
f"上传分片失败: {self._bucket_name}/{key}, "
|
||||||
|
f"part={part_number}, 状态: {response.status}, {body}"
|
||||||
|
)
|
||||||
|
|
||||||
|
etag = response.headers.get("ETag", "").strip('"')
|
||||||
|
l.debug(
|
||||||
|
f"S3 分片上传成功: {self._bucket_name}/{key}, "
|
||||||
|
f"part={part_number}, etag={etag}"
|
||||||
|
)
|
||||||
|
return etag
|
||||||
|
|
||||||
|
async def complete_multipart_upload(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
upload_id: str,
|
||||||
|
parts: list[tuple[int, str]],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
完成分片上传
|
||||||
|
|
||||||
|
:param key: S3 对象键
|
||||||
|
:param upload_id: 分片上传 ID
|
||||||
|
:param parts: 分片列表 [(part_number, etag)]
|
||||||
|
"""
|
||||||
|
# 按 part_number 排序
|
||||||
|
parts_sorted = sorted(parts, key=lambda p: p[0])
|
||||||
|
|
||||||
|
# 构建 CompleteMultipartUpload XML
|
||||||
|
xml_parts = ''.join(
|
||||||
|
f"<Part><PartNumber>{pn}</PartNumber><ETag>{etag}</ETag></Part>"
|
||||||
|
for pn, etag in parts_sorted
|
||||||
|
)
|
||||||
|
payload = f'<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUpload>{xml_parts}</CompleteMultipartUpload>'
|
||||||
|
payload_bytes = payload.encode('utf-8')
|
||||||
|
|
||||||
|
async with await self._request(
|
||||||
|
"POST",
|
||||||
|
key=key,
|
||||||
|
query_params={"uploadId": upload_id},
|
||||||
|
payload=payload_bytes,
|
||||||
|
content_type="application/xml",
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
body = await response.text()
|
||||||
|
raise S3MultipartUploadError(
|
||||||
|
f"完成分片上传失败: {self._bucket_name}/{key}, "
|
||||||
|
f"状态: {response.status}, {body}"
|
||||||
|
)
|
||||||
|
l.info(
|
||||||
|
f"S3 分片上传已完成: {self._bucket_name}/{key}, "
|
||||||
|
f"共 {len(parts)} 个分片"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def abort_multipart_upload(self, key: str, upload_id: str) -> None:
|
||||||
|
"""
|
||||||
|
取消分片上传
|
||||||
|
|
||||||
|
:param key: S3 对象键
|
||||||
|
:param upload_id: 分片上传 ID
|
||||||
|
"""
|
||||||
|
async with await self._request(
|
||||||
|
"DELETE",
|
||||||
|
key=key,
|
||||||
|
query_params={"uploadId": upload_id},
|
||||||
|
) as response:
|
||||||
|
if response.status in (200, 204):
|
||||||
|
l.debug(f"S3 分片上传已取消: {self._bucket_name}/{key}")
|
||||||
|
else:
|
||||||
|
body = await response.text()
|
||||||
|
l.warning(
|
||||||
|
f"取消分片上传失败: {self._bucket_name}/{key}, "
|
||||||
|
f"状态: {response.status}, {body}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== 预签名 URL ====================
|
||||||
|
|
||||||
|
def generate_presigned_url(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
method: Literal['GET', 'PUT'] = 'GET',
|
||||||
|
expires_in: int = 3600,
|
||||||
|
filename: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
生成 S3 预签名 URL(AWS Signature V4 Query String)
|
||||||
|
|
||||||
|
:param key: S3 对象键
|
||||||
|
:param method: HTTP 方法(GET 下载,PUT 上传)
|
||||||
|
:param expires_in: URL 有效期(秒)
|
||||||
|
:param filename: 文件名(GET 请求时设置 Content-Disposition)
|
||||||
|
:return: 预签名 URL
|
||||||
|
"""
|
||||||
|
current_time = datetime.now(timezone.utc)
|
||||||
|
amz_date = current_time.strftime("%Y%m%dT%H%M%SZ")
|
||||||
|
date_stamp = current_time.strftime("%Y%m%d")
|
||||||
|
|
||||||
|
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
|
||||||
|
credential = f"{self._access_key}/{credential_scope}"
|
||||||
|
|
||||||
|
uri = self._build_uri(key)
|
||||||
|
effective_host = self._get_effective_host()
|
||||||
|
|
||||||
|
query_params: dict[str, str] = {
|
||||||
|
'X-Amz-Algorithm': 'AWS4-HMAC-SHA256',
|
||||||
|
'X-Amz-Credential': credential,
|
||||||
|
'X-Amz-Date': amz_date,
|
||||||
|
'X-Amz-Expires': str(expires_in),
|
||||||
|
'X-Amz-SignedHeaders': 'host',
|
||||||
|
}
|
||||||
|
|
||||||
|
# GET 请求时添加 Content-Disposition
|
||||||
|
if method == "GET" and filename:
|
||||||
|
encoded_filename = quote(filename, safe='')
|
||||||
|
query_params['response-content-disposition'] = (
|
||||||
|
f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||||
|
)
|
||||||
|
|
||||||
|
canonical_query_string = "&".join(
|
||||||
|
f"{quote(k, safe='')}={quote(v, safe='')}"
|
||||||
|
for k, v in sorted(query_params.items())
|
||||||
|
)
|
||||||
|
|
||||||
|
canonical_headers = f"host:{effective_host}\n"
|
||||||
|
signed_headers = "host"
|
||||||
|
payload_hash = "UNSIGNED-PAYLOAD"
|
||||||
|
|
||||||
|
canonical_request = (
|
||||||
|
f"{method}\n"
|
||||||
|
f"{uri}\n"
|
||||||
|
f"{canonical_query_string}\n"
|
||||||
|
f"{canonical_headers}\n"
|
||||||
|
f"{signed_headers}\n"
|
||||||
|
f"{payload_hash}"
|
||||||
|
)
|
||||||
|
|
||||||
|
algorithm = "AWS4-HMAC-SHA256"
|
||||||
|
string_to_sign = (
|
||||||
|
f"{algorithm}\n"
|
||||||
|
f"{amz_date}\n"
|
||||||
|
f"{credential_scope}\n"
|
||||||
|
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
signing_key = self._get_signature_key(date_stamp)
|
||||||
|
signature = hmac.new(
|
||||||
|
signing_key, string_to_sign.encode(), hashlib.sha256
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
base_url = self._build_url(uri)
|
||||||
|
return (
|
||||||
|
f"{base_url}?"
|
||||||
|
f"{canonical_query_string}&"
|
||||||
|
f"X-Amz-Signature={signature}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== 路径生成 ====================
|
||||||
|
|
||||||
|
async def generate_file_path(
|
||||||
|
self,
|
||||||
|
user_id: UUID,
|
||||||
|
original_filename: str,
|
||||||
|
) -> tuple[str, str, str]:
|
||||||
|
"""
|
||||||
|
根据命名规则生成 S3 文件存储路径
|
||||||
|
|
||||||
|
与 LocalStorageService.generate_file_path 接口一致。
|
||||||
|
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:param original_filename: 原始文件名
|
||||||
|
:return: (相对目录路径, 存储文件名, 完整存储路径)
|
||||||
|
"""
|
||||||
|
context = NamingContext(
|
||||||
|
user_id=user_id,
|
||||||
|
original_filename=original_filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析目录规则
|
||||||
|
dir_path = ""
|
||||||
|
if self._policy.dir_name_rule:
|
||||||
|
dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context)
|
||||||
|
|
||||||
|
# 解析文件名规则
|
||||||
|
if self._policy.auto_rename and self._policy.file_name_rule:
|
||||||
|
storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context)
|
||||||
|
# 确保有扩展名
|
||||||
|
if '.' in original_filename and '.' not in storage_name:
|
||||||
|
ext = original_filename.rsplit('.', 1)[1]
|
||||||
|
storage_name = f"{storage_name}.{ext}"
|
||||||
|
else:
|
||||||
|
storage_name = original_filename
|
||||||
|
|
||||||
|
# S3 不需要创建目录,直接拼接路径
|
||||||
|
if dir_path:
|
||||||
|
storage_path = f"{dir_path}/{storage_name}"
|
||||||
|
else:
|
||||||
|
storage_path = storage_name
|
||||||
|
|
||||||
|
return dir_path, storage_name, storage_path
|
||||||
@@ -1 +1 @@
|
|||||||
from .login import login
|
from .login import unified_login
|
||||||
|
|||||||
@@ -1,83 +1,428 @@
|
|||||||
from uuid import uuid4
|
"""
|
||||||
|
统一登录服务
|
||||||
|
|
||||||
from loguru import logger
|
支持多种认证方式:邮箱密码、GitHub OAuth、QQ OAuth、Passkey、Magic Link、手机短信(预留)。
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from middleware.dependencies import SessionDep
|
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||||
from sqlmodels import LoginRequest, TokenResponse, User
|
from loguru import logger as l
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from service.redis.token_store import TokenStore
|
||||||
|
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||||
from sqlmodels.group import GroupClaims, GroupOptions
|
from sqlmodels.group import GroupClaims, GroupOptions
|
||||||
from sqlmodels.user import UserStatus
|
from sqlmodels.object import Object, ObjectType
|
||||||
from utils import http_exceptions
|
from sqlmodels.policy import Policy
|
||||||
from utils.JWT import create_access_token, create_refresh_token
|
from sqlmodels.setting import Setting, SettingsType
|
||||||
|
from sqlmodels.user import TokenResponse, UnifiedLoginRequest, User, UserStatus
|
||||||
|
from utils import JWT, http_exceptions
|
||||||
from utils.password.pwd import Password, PasswordStatus
|
from utils.password.pwd import Password, PasswordStatus
|
||||||
|
|
||||||
|
|
||||||
async def login(
|
async def unified_login(
|
||||||
session: SessionDep,
|
session: AsyncSession,
|
||||||
login_request: LoginRequest,
|
request: UnifiedLoginRequest,
|
||||||
) -> TokenResponse:
|
) -> TokenResponse:
|
||||||
"""
|
"""
|
||||||
根据账号密码进行登录。
|
统一登录入口,根据 provider 分发到不同的登录逻辑。
|
||||||
如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。
|
|
||||||
|
|
||||||
:param session: 数据库会话
|
:param session: 数据库会话
|
||||||
:param login_request: 登录请求
|
:param request: 统一登录请求
|
||||||
|
:return: TokenResponse
|
||||||
:return: TokenResponse 对象或状态码或 None
|
|
||||||
"""
|
"""
|
||||||
# 获取用户信息(预加载 group 关系)
|
await _check_provider_enabled(session, request.provider)
|
||||||
current_user: User = await User.get(
|
|
||||||
|
match request.provider:
|
||||||
|
case AuthProviderType.EMAIL_PASSWORD:
|
||||||
|
user = await _login_email_password(session, request)
|
||||||
|
case AuthProviderType.GITHUB:
|
||||||
|
user = await _login_oauth(session, request, AuthProviderType.GITHUB)
|
||||||
|
case AuthProviderType.QQ:
|
||||||
|
user = await _login_oauth(session, request, AuthProviderType.QQ)
|
||||||
|
case AuthProviderType.PASSKEY:
|
||||||
|
user = await _login_passkey(session, request)
|
||||||
|
case AuthProviderType.MAGIC_LINK:
|
||||||
|
user = await _login_magic_link(session, request)
|
||||||
|
case AuthProviderType.PHONE_SMS:
|
||||||
|
http_exceptions.raise_not_implemented("短信登录暂未开放")
|
||||||
|
case _:
|
||||||
|
http_exceptions.raise_bad_request(f"不支持的登录方式: {request.provider}")
|
||||||
|
|
||||||
|
return await _issue_tokens(session, user)
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_provider_enabled(session: AsyncSession, provider: AuthProviderType) -> None:
|
||||||
|
"""检查认证方式是否已被站长启用"""
|
||||||
|
# OAuth 类型从 OAUTH 设置中查询
|
||||||
|
if provider in (AuthProviderType.GITHUB, AuthProviderType.QQ):
|
||||||
|
setting_name = f"{provider.value}_enabled"
|
||||||
|
setting = await Setting.get(
|
||||||
session,
|
session,
|
||||||
User.email == login_request.email,
|
(Setting.type == SettingsType.OAUTH) & (Setting.name == setting_name),
|
||||||
fetch_mode="first",
|
)
|
||||||
load=User.group,
|
if not setting or setting.value != "1":
|
||||||
) #type: ignore
|
http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用")
|
||||||
|
return
|
||||||
|
|
||||||
# 验证用户是否存在
|
# 其他类型从 AUTH 设置中查询
|
||||||
if not current_user:
|
setting_name = f"auth_{provider.value}_enabled"
|
||||||
logger.debug(f"Cannot find user with email: {login_request.email}")
|
setting = await Setting.get(
|
||||||
http_exceptions.raise_unauthorized("Invalid email or password")
|
session,
|
||||||
|
(Setting.type == SettingsType.AUTH) & (Setting.name == setting_name),
|
||||||
|
)
|
||||||
|
if not setting or setting.value != "1":
|
||||||
|
http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用")
|
||||||
|
|
||||||
# 验证密码是否正确
|
|
||||||
if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID:
|
|
||||||
logger.debug(f"Password verification failed for user: {login_request.email}")
|
|
||||||
http_exceptions.raise_unauthorized("Invalid email or password")
|
|
||||||
|
|
||||||
# 验证用户是否可登录(修复:显式枚举比较,StrEnum 永远 truthy)
|
async def _login_email_password(
|
||||||
if current_user.status != UserStatus.ACTIVE:
|
session: AsyncSession,
|
||||||
http_exceptions.raise_forbidden("Your account is disabled")
|
request: UnifiedLoginRequest,
|
||||||
|
) -> User:
|
||||||
|
"""邮箱+密码登录"""
|
||||||
|
if not request.credential:
|
||||||
|
http_exceptions.raise_bad_request("密码不能为空")
|
||||||
|
|
||||||
# 检查两步验证
|
# 查找 AuthIdentity
|
||||||
if current_user.two_factor:
|
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||||
# 用户已启用两步验证
|
session,
|
||||||
if not login_request.two_fa_code:
|
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||||
logger.debug(f"2FA required for user: {login_request.email}")
|
& (AuthIdentity.identifier == request.identifier),
|
||||||
http_exceptions.raise_precondition_required("2FA required")
|
)
|
||||||
|
if not identity:
|
||||||
|
l.debug(f"未找到邮箱密码身份: {request.identifier}")
|
||||||
|
http_exceptions.raise_unauthorized("邮箱或密码错误")
|
||||||
|
|
||||||
# 验证 OTP 码
|
# 验证密码
|
||||||
if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID:
|
if not identity.credential:
|
||||||
logger.debug(f"Invalid 2FA code for user: {login_request.email}")
|
http_exceptions.raise_unauthorized("邮箱或密码错误")
|
||||||
http_exceptions.raise_unauthorized("Invalid 2FA code")
|
|
||||||
|
|
||||||
|
if Password.verify(identity.credential, request.credential) != PasswordStatus.VALID:
|
||||||
|
l.debug(f"密码验证失败: {request.identifier}")
|
||||||
|
http_exceptions.raise_unauthorized("邮箱或密码错误")
|
||||||
|
|
||||||
|
# 加载用户
|
||||||
|
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||||
|
if not user:
|
||||||
|
http_exceptions.raise_unauthorized("用户不存在")
|
||||||
|
|
||||||
|
# 验证用户状态
|
||||||
|
if user.status != UserStatus.ACTIVE:
|
||||||
|
http_exceptions.raise_forbidden("账户已被禁用")
|
||||||
|
|
||||||
|
# 检查两步验证(从 AuthIdentity.extra_data 中读取 2FA secret)
|
||||||
|
if identity.extra_data:
|
||||||
|
import orjson
|
||||||
|
extra: dict = orjson.loads(identity.extra_data)
|
||||||
|
two_factor_secret: str | None = extra.get("two_factor")
|
||||||
|
if two_factor_secret:
|
||||||
|
if not request.two_fa_code:
|
||||||
|
l.debug(f"需要两步验证: {request.identifier}")
|
||||||
|
http_exceptions.raise_precondition_required("需要两步验证")
|
||||||
|
if Password.verify_totp(two_factor_secret, request.two_fa_code) != PasswordStatus.VALID:
|
||||||
|
l.debug(f"两步验证失败: {request.identifier}")
|
||||||
|
http_exceptions.raise_unauthorized("两步验证码错误")
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def _login_oauth(
|
||||||
|
session: AsyncSession,
|
||||||
|
request: UnifiedLoginRequest,
|
||||||
|
provider: AuthProviderType,
|
||||||
|
) -> User:
|
||||||
|
"""
|
||||||
|
OAuth 登录(GitHub / QQ)
|
||||||
|
|
||||||
|
identifier 为 OAuth authorization code,后端换取 access_token 再获取用户信息。
|
||||||
|
"""
|
||||||
|
# 读取 OAuth 配置
|
||||||
|
client_id_setting = await Setting.get(
|
||||||
|
session,
|
||||||
|
(Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_id"),
|
||||||
|
)
|
||||||
|
client_secret_setting = await Setting.get(
|
||||||
|
session,
|
||||||
|
(Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_secret"),
|
||||||
|
)
|
||||||
|
if not client_id_setting or not client_secret_setting:
|
||||||
|
http_exceptions.raise_bad_request(f"{provider.value} OAuth 未配置")
|
||||||
|
|
||||||
|
client_id = client_id_setting.value or ""
|
||||||
|
client_secret = client_secret_setting.value or ""
|
||||||
|
|
||||||
|
# 根据 provider 创建对应的 OAuth 客户端
|
||||||
|
if provider == AuthProviderType.GITHUB:
|
||||||
|
from service.oauth import GithubOAuth
|
||||||
|
oauth_client = GithubOAuth(client_id, client_secret)
|
||||||
|
token_resp = await oauth_client.get_access_token(code=request.identifier)
|
||||||
|
user_info_resp = await oauth_client.get_user_info(token_resp)
|
||||||
|
openid = str(user_info_resp.user_data.id)
|
||||||
|
nickname = user_info_resp.user_data.name or user_info_resp.user_data.login
|
||||||
|
avatar_url = user_info_resp.user_data.avatar_url
|
||||||
|
email = user_info_resp.user_data.email
|
||||||
|
elif provider == AuthProviderType.QQ:
|
||||||
|
from service.oauth import QQOAuth
|
||||||
|
oauth_client = QQOAuth(client_id, client_secret)
|
||||||
|
token_resp = await oauth_client.get_access_token(
|
||||||
|
code=request.identifier,
|
||||||
|
redirect_uri=request.redirect_uri or "",
|
||||||
|
)
|
||||||
|
openid_resp = await oauth_client.get_openid(token_resp.access_token)
|
||||||
|
user_info_resp = await oauth_client.get_user_info(
|
||||||
|
token_resp,
|
||||||
|
app_id=client_id,
|
||||||
|
openid=openid_resp.openid,
|
||||||
|
)
|
||||||
|
openid = openid_resp.openid
|
||||||
|
nickname = user_info_resp.user_data.nickname
|
||||||
|
avatar_url = user_info_resp.user_data.figureurl_qq_2 or user_info_resp.user_data.figureurl_2
|
||||||
|
email = None
|
||||||
|
else:
|
||||||
|
http_exceptions.raise_bad_request(f"不支持的 OAuth 提供者: {provider.value}")
|
||||||
|
|
||||||
|
# 查找已有 AuthIdentity
|
||||||
|
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||||
|
session,
|
||||||
|
(AuthIdentity.provider == provider) & (AuthIdentity.identifier == openid),
|
||||||
|
)
|
||||||
|
|
||||||
|
if identity:
|
||||||
|
# 已绑定 → 更新 OAuth 信息并返回关联用户
|
||||||
|
identity.display_name = nickname
|
||||||
|
identity.avatar_url = avatar_url
|
||||||
|
identity = await identity.save(session)
|
||||||
|
|
||||||
|
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||||
|
if not user:
|
||||||
|
http_exceptions.raise_unauthorized("用户不存在")
|
||||||
|
if user.status != UserStatus.ACTIVE:
|
||||||
|
http_exceptions.raise_forbidden("账户已被禁用")
|
||||||
|
return user
|
||||||
|
|
||||||
|
# 未绑定 → 自动注册
|
||||||
|
user = await _auto_register_oauth_user(
|
||||||
|
session,
|
||||||
|
provider=provider,
|
||||||
|
openid=openid,
|
||||||
|
nickname=nickname,
|
||||||
|
avatar_url=avatar_url,
|
||||||
|
email=email,
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def _auto_register_oauth_user(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
provider: AuthProviderType,
|
||||||
|
openid: str,
|
||||||
|
nickname: str | None,
|
||||||
|
avatar_url: str | None,
|
||||||
|
email: str | None,
|
||||||
|
) -> User:
|
||||||
|
"""OAuth 自动注册用户"""
|
||||||
|
# 获取默认用户组
|
||||||
|
default_group_setting = await Setting.get(
|
||||||
|
session,
|
||||||
|
(Setting.type == SettingsType.REGISTER) & (Setting.name == "default_group"),
|
||||||
|
)
|
||||||
|
if not default_group_setting or not default_group_setting.value:
|
||||||
|
l.error("默认用户组未配置")
|
||||||
|
http_exceptions.raise_internal_error()
|
||||||
|
|
||||||
|
default_group_id = UUID(default_group_setting.value)
|
||||||
|
|
||||||
|
# 创建用户
|
||||||
|
new_user = User(
|
||||||
|
email=email,
|
||||||
|
nickname=nickname,
|
||||||
|
avatar=avatar_url or "default",
|
||||||
|
group_id=default_group_id,
|
||||||
|
)
|
||||||
|
new_user_id = new_user.id
|
||||||
|
new_user = await new_user.save(session)
|
||||||
|
|
||||||
|
# 创建 AuthIdentity
|
||||||
|
identity = AuthIdentity(
|
||||||
|
provider=provider,
|
||||||
|
identifier=openid,
|
||||||
|
display_name=nickname,
|
||||||
|
avatar_url=avatar_url,
|
||||||
|
is_primary=True,
|
||||||
|
is_verified=True,
|
||||||
|
user_id=new_user_id,
|
||||||
|
)
|
||||||
|
identity = await identity.save(session)
|
||||||
|
|
||||||
|
# 创建用户根目录
|
||||||
|
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||||
|
if default_policy:
|
||||||
|
await Object(
|
||||||
|
name="/",
|
||||||
|
type=ObjectType.FOLDER,
|
||||||
|
owner_id=new_user_id,
|
||||||
|
parent_id=None,
|
||||||
|
policy_id=default_policy.id,
|
||||||
|
).save(session)
|
||||||
|
|
||||||
|
# 重新加载用户(含 group 关系)
|
||||||
|
user: User = await User.get(session, User.id == new_user_id, load=User.group)
|
||||||
|
l.info(f"OAuth 自动注册用户: provider={provider.value}, openid={openid}")
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def _login_passkey(
|
||||||
|
session: AsyncSession,
|
||||||
|
request: UnifiedLoginRequest,
|
||||||
|
) -> User:
|
||||||
|
"""
|
||||||
|
Passkey/WebAuthn 登录(Discoverable Credentials 模式)
|
||||||
|
|
||||||
|
identifier 为 challenge_token(前端从 ``POST /authn/options`` 获取),
|
||||||
|
credential 为 JSON 格式的 authenticator assertion response。
|
||||||
|
"""
|
||||||
|
from webauthn import verify_authentication_response
|
||||||
|
from webauthn.helpers import base64url_to_bytes
|
||||||
|
|
||||||
|
from service.redis.challenge_store import ChallengeStore
|
||||||
|
from service.webauthn import get_rp_config
|
||||||
|
from sqlmodels.user_authn import UserAuthn
|
||||||
|
|
||||||
|
if not request.credential:
|
||||||
|
http_exceptions.raise_bad_request("WebAuthn assertion response 不能为空")
|
||||||
|
|
||||||
|
if not request.identifier:
|
||||||
|
http_exceptions.raise_bad_request("challenge_token 不能为空")
|
||||||
|
|
||||||
|
# 从 ChallengeStore 取出 challenge(一次性,防重放)
|
||||||
|
challenge: bytes | None = await ChallengeStore.retrieve_and_delete(f"auth:{request.identifier}")
|
||||||
|
if challenge is None:
|
||||||
|
http_exceptions.raise_unauthorized("登录会话已过期,请重新获取 options")
|
||||||
|
|
||||||
|
# 从 assertion JSON 中解析 credential_id(Discoverable Credentials 模式)
|
||||||
|
import orjson
|
||||||
|
credential_dict: dict = orjson.loads(request.credential)
|
||||||
|
credential_id_b64: str | None = credential_dict.get("id")
|
||||||
|
if not credential_id_b64:
|
||||||
|
http_exceptions.raise_bad_request("缺少凭证 ID")
|
||||||
|
|
||||||
|
# 查找 UserAuthn 记录
|
||||||
|
authn: UserAuthn | None = await UserAuthn.get(
|
||||||
|
session,
|
||||||
|
UserAuthn.credential_id == credential_id_b64,
|
||||||
|
)
|
||||||
|
if not authn:
|
||||||
|
http_exceptions.raise_unauthorized("Passkey 凭证未注册")
|
||||||
|
|
||||||
|
# 获取 RP 配置
|
||||||
|
rp_id, _rp_name, origin = await get_rp_config(session)
|
||||||
|
|
||||||
|
# 验证 WebAuthn assertion
|
||||||
|
try:
|
||||||
|
verification = verify_authentication_response(
|
||||||
|
credential=request.credential,
|
||||||
|
expected_rp_id=rp_id,
|
||||||
|
expected_origin=origin,
|
||||||
|
expected_challenge=challenge,
|
||||||
|
credential_public_key=base64url_to_bytes(authn.credential_public_key),
|
||||||
|
credential_current_sign_count=authn.sign_count,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
l.warning(f"WebAuthn 验证失败: {e}")
|
||||||
|
http_exceptions.raise_unauthorized("Passkey 验证失败")
|
||||||
|
|
||||||
|
# 更新签名计数
|
||||||
|
authn.sign_count = verification.new_sign_count
|
||||||
|
authn = await authn.save(session)
|
||||||
|
|
||||||
|
# 加载用户
|
||||||
|
user: User = await User.get(session, User.id == authn.user_id, load=User.group)
|
||||||
|
if not user:
|
||||||
|
http_exceptions.raise_unauthorized("用户不存在")
|
||||||
|
if user.status != UserStatus.ACTIVE:
|
||||||
|
http_exceptions.raise_forbidden("账户已被禁用")
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def _login_magic_link(
|
||||||
|
session: AsyncSession,
|
||||||
|
request: UnifiedLoginRequest,
|
||||||
|
) -> User:
|
||||||
|
"""
|
||||||
|
Magic Link 登录
|
||||||
|
|
||||||
|
identifier 为签名 token,由 itsdangerous 生成。
|
||||||
|
"""
|
||||||
|
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||||
|
|
||||||
|
try:
|
||||||
|
email = serializer.loads(request.identifier, salt="magic-link-salt", max_age=600)
|
||||||
|
except SignatureExpired:
|
||||||
|
http_exceptions.raise_unauthorized("Magic Link 已过期")
|
||||||
|
except BadSignature:
|
||||||
|
http_exceptions.raise_unauthorized("Magic Link 无效")
|
||||||
|
|
||||||
|
# 防重放:使用 token 哈希作为标识符
|
||||||
|
token_hash = hashlib.sha256(request.identifier.encode()).hexdigest()
|
||||||
|
is_first_use = await TokenStore.mark_used(f"magic_link:{token_hash}", ttl=600)
|
||||||
|
if not is_first_use:
|
||||||
|
http_exceptions.raise_unauthorized("Magic Link 已被使用")
|
||||||
|
|
||||||
|
# 查找绑定了该邮箱的 AuthIdentity(email_password 或 magic_link)
|
||||||
|
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||||
|
session,
|
||||||
|
(AuthIdentity.identifier == email)
|
||||||
|
& (
|
||||||
|
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||||
|
| (AuthIdentity.provider == AuthProviderType.MAGIC_LINK)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not identity:
|
||||||
|
http_exceptions.raise_unauthorized("该邮箱未注册")
|
||||||
|
|
||||||
|
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||||
|
if not user:
|
||||||
|
http_exceptions.raise_unauthorized("用户不存在")
|
||||||
|
if user.status != UserStatus.ACTIVE:
|
||||||
|
http_exceptions.raise_forbidden("账户已被禁用")
|
||||||
|
|
||||||
|
# 标记邮箱已验证
|
||||||
|
if not identity.is_verified:
|
||||||
|
identity.is_verified = True
|
||||||
|
identity = await identity.save(session)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def _issue_tokens(session: AsyncSession, user: User) -> TokenResponse:
|
||||||
|
"""
|
||||||
|
签发 JWT 双令牌(access + refresh)
|
||||||
|
|
||||||
|
提取自原 login.py 的签发逻辑,供所有 provider 共用。
|
||||||
|
"""
|
||||||
# 加载 GroupOptions
|
# 加载 GroupOptions
|
||||||
group_options: GroupOptions | None = await GroupOptions.get(
|
group_options: GroupOptions | None = await GroupOptions.get(
|
||||||
session,
|
session,
|
||||||
GroupOptions.group_id == current_user.group_id,
|
GroupOptions.group_id == user.group_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建权限快照
|
# 构建权限快照
|
||||||
current_user.group.options = group_options
|
user.group.options = group_options
|
||||||
group_claims = GroupClaims.from_group(current_user.group)
|
group_claims = GroupClaims.from_group(user.group)
|
||||||
|
|
||||||
# 创建令牌
|
# 创建令牌
|
||||||
access_token = create_access_token(
|
access_token = JWT.create_access_token(
|
||||||
sub=current_user.id,
|
sub=user.id,
|
||||||
jti=uuid4(),
|
jti=uuid4(),
|
||||||
status=current_user.status.value,
|
status=user.status.value,
|
||||||
group=group_claims,
|
group=group_claims,
|
||||||
)
|
)
|
||||||
refresh_token = create_refresh_token(
|
refresh_token = JWT.create_refresh_token(
|
||||||
sub=current_user.id,
|
sub=user.id,
|
||||||
jti=uuid4()
|
jti=uuid4(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return TokenResponse(
|
return TokenResponse(
|
||||||
|
|||||||
41
service/webauthn.py
Normal file
41
service/webauthn.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""
|
||||||
|
WebAuthn RP(Relying Party)配置辅助模块
|
||||||
|
|
||||||
|
从数据库 Setting 中读取 siteURL / siteTitle,
|
||||||
|
解析出 rp_id、rp_name、origin,供注册/登录流程复用。
|
||||||
|
"""
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from sqlmodels.setting import Setting, SettingsType
|
||||||
|
|
||||||
|
|
||||||
|
async def get_rp_config(session: AsyncSession) -> tuple[str, str, str]:
|
||||||
|
"""
|
||||||
|
获取 WebAuthn RP 配置。
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:return: ``(rp_id, rp_name, origin)`` 元组
|
||||||
|
|
||||||
|
- ``rp_id``: 站点域名(从 siteURL 解析,如 ``example.com``)
|
||||||
|
- ``rp_name``: 站点标题
|
||||||
|
- ``origin``: 完整 origin(如 ``https://example.com``)
|
||||||
|
"""
|
||||||
|
site_url_setting: Setting | None = await Setting.get(
|
||||||
|
session,
|
||||||
|
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"),
|
||||||
|
)
|
||||||
|
site_title_setting: Setting | None = await Setting.get(
|
||||||
|
session,
|
||||||
|
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteTitle"),
|
||||||
|
)
|
||||||
|
|
||||||
|
site_url: str = site_url_setting.value if site_url_setting and site_url_setting.value else "https://localhost"
|
||||||
|
rp_name: str = site_title_setting.value if site_title_setting and site_title_setting.value else "DiskNext"
|
||||||
|
|
||||||
|
parsed = urlparse(site_url)
|
||||||
|
rp_id: str = parsed.hostname or "localhost"
|
||||||
|
origin: str = f"{parsed.scheme}://{parsed.netloc}" if parsed.netloc else site_url
|
||||||
|
|
||||||
|
return rp_id, rp_name, origin
|
||||||
185
service/wopi/__init__.py
Normal file
185
service/wopi/__init__.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
"""
|
||||||
|
WOPI Discovery 服务模块
|
||||||
|
|
||||||
|
解析 WOPI 服务端(Collabora / OnlyOffice 等)的 Discovery XML,
|
||||||
|
提取支持的文件扩展名及对应的编辑器 URL 模板。
|
||||||
|
|
||||||
|
参考:Cloudreve pkg/wopi/discovery.go 和 pkg/wopi/wopi.go
|
||||||
|
"""
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||||
|
|
||||||
|
from loguru import logger as l
|
||||||
|
|
||||||
|
# WOPI URL 模板中已知的查询参数占位符及其替换值
|
||||||
|
# 值为 None 表示删除该参数,非 None 表示替换为该值
|
||||||
|
# 参考 Cloudreve pkg/wopi/wopi.go queryPlaceholders
|
||||||
|
_WOPI_QUERY_PLACEHOLDERS: dict[str, str | None] = {
|
||||||
|
'BUSINESS_USER': None,
|
||||||
|
'DC_LLCC': 'lng',
|
||||||
|
'DISABLE_ASYNC': None,
|
||||||
|
'DISABLE_CHAT': None,
|
||||||
|
'EMBEDDED': 'true',
|
||||||
|
'FULLSCREEN': 'true',
|
||||||
|
'HOST_SESSION_ID': None,
|
||||||
|
'SESSION_CONTEXT': None,
|
||||||
|
'RECORDING': None,
|
||||||
|
'THEME_ID': 'darkmode',
|
||||||
|
'UI_LLCC': 'lng',
|
||||||
|
'VALIDATOR_TEST_CATEGORY': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
_WOPI_SRC_PLACEHOLDER = 'WOPI_SOURCE'
|
||||||
|
|
||||||
|
|
||||||
|
def process_wopi_action_url(raw_urlsrc: str) -> str:
|
||||||
|
"""
|
||||||
|
将 WOPI Discovery 中的原始 urlsrc 转换为 DiskNext 可用的 URL 模板。
|
||||||
|
|
||||||
|
处理流程(参考 Cloudreve generateActionUrl):
|
||||||
|
1. 去除 ``<>`` 占位符标记
|
||||||
|
2. 解析查询参数,替换/删除已知占位符
|
||||||
|
3. ``WOPI_SOURCE`` → ``{wopi_src}``
|
||||||
|
|
||||||
|
注意:access_token 和 access_token_ttl 不放在 URL 中,
|
||||||
|
根据 WOPI 规范它们通过 POST 表单字段传递给编辑器。
|
||||||
|
|
||||||
|
:param raw_urlsrc: WOPI Discovery XML 中的 urlsrc 原始值
|
||||||
|
:return: 处理后的 URL 模板字符串,包含 {wopi_src} 占位符
|
||||||
|
"""
|
||||||
|
# 去除 <> 标记
|
||||||
|
cleaned = raw_urlsrc.replace('<', '').replace('>', '')
|
||||||
|
parsed = urlparse(cleaned)
|
||||||
|
raw_params = parse_qs(parsed.query, keep_blank_values=True)
|
||||||
|
|
||||||
|
new_params: list[tuple[str, str]] = []
|
||||||
|
is_src_replaced = False
|
||||||
|
|
||||||
|
for key, values in raw_params.items():
|
||||||
|
value = values[0] if values else ''
|
||||||
|
|
||||||
|
# WOPI_SOURCE 占位符 → {wopi_src}
|
||||||
|
if value == _WOPI_SRC_PLACEHOLDER:
|
||||||
|
new_params.append((key, '{wopi_src}'))
|
||||||
|
is_src_replaced = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 已知占位符
|
||||||
|
if value in _WOPI_QUERY_PLACEHOLDERS:
|
||||||
|
replacement = _WOPI_QUERY_PLACEHOLDERS[value]
|
||||||
|
if replacement is not None:
|
||||||
|
new_params.append((key, replacement))
|
||||||
|
# replacement 为 None 时删除该参数
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 其他参数保留原值
|
||||||
|
new_params.append((key, value))
|
||||||
|
|
||||||
|
# 如果没有找到 WOPI_SOURCE 占位符,手动添加 WOPISrc
|
||||||
|
if not is_src_replaced:
|
||||||
|
new_params.append(('WOPISrc', '{wopi_src}'))
|
||||||
|
|
||||||
|
# LibreOffice/Collabora 需要 lang 参数(避免重复添加)
|
||||||
|
existing_keys = {k for k, _ in new_params}
|
||||||
|
if 'lang' not in existing_keys:
|
||||||
|
new_params.append(('lang', 'lng'))
|
||||||
|
|
||||||
|
# 注意:access_token 和 access_token_ttl 不放在 URL 中
|
||||||
|
# 根据 WOPI 规范,它们通过 POST 表单字段传递给编辑器
|
||||||
|
|
||||||
|
# 重建 URL
|
||||||
|
new_query = urlencode(new_params, safe='{}')
|
||||||
|
result = urlunparse((
|
||||||
|
parsed.scheme,
|
||||||
|
parsed.netloc,
|
||||||
|
parsed.path,
|
||||||
|
parsed.params,
|
||||||
|
new_query,
|
||||||
|
'',
|
||||||
|
))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def parse_wopi_discovery_xml(xml_content: str) -> tuple[dict[str, str], list[str]]:
|
||||||
|
"""
|
||||||
|
解析 WOPI Discovery XML,提取扩展名到 URL 模板的映射。
|
||||||
|
|
||||||
|
XML 结构::
|
||||||
|
|
||||||
|
<wopi-discovery>
|
||||||
|
<net-zone name="external-https">
|
||||||
|
<app name="Writer" favIconUrl="...">
|
||||||
|
<action name="edit" ext="docx" urlsrc="https://..."/>
|
||||||
|
<action name="view" ext="docx" urlsrc="https://..."/>
|
||||||
|
</app>
|
||||||
|
</net-zone>
|
||||||
|
</wopi-discovery>
|
||||||
|
|
||||||
|
动作优先级:edit > embedview > view(参考 Cloudreve discovery.go)
|
||||||
|
|
||||||
|
:param xml_content: WOPI Discovery 端点返回的 XML 字符串
|
||||||
|
:return: (action_urls, app_names) 元组
|
||||||
|
action_urls: {extension: processed_url_template}
|
||||||
|
app_names: 发现的应用名称列表
|
||||||
|
:raises ValueError: XML 解析失败或格式无效
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
root = ET.fromstring(xml_content)
|
||||||
|
except ET.ParseError as e:
|
||||||
|
raise ValueError(f"WOPI Discovery XML 解析失败: {e}")
|
||||||
|
|
||||||
|
# 查找 net-zone(可能有多个,取第一个非空的)
|
||||||
|
net_zones = root.findall('net-zone')
|
||||||
|
if not net_zones:
|
||||||
|
raise ValueError("WOPI Discovery XML 缺少 net-zone 节点")
|
||||||
|
|
||||||
|
# ext_actions: {extension: {action_name: urlsrc}}
|
||||||
|
ext_actions: dict[str, dict[str, str]] = {}
|
||||||
|
app_names: list[str] = []
|
||||||
|
|
||||||
|
for net_zone in net_zones:
|
||||||
|
for app_elem in net_zone.findall('app'):
|
||||||
|
app_name = app_elem.get('name', '')
|
||||||
|
if app_name:
|
||||||
|
app_names.append(app_name)
|
||||||
|
|
||||||
|
for action_elem in app_elem.findall('action'):
|
||||||
|
action_name = action_elem.get('name', '')
|
||||||
|
ext = action_elem.get('ext', '')
|
||||||
|
urlsrc = action_elem.get('urlsrc', '')
|
||||||
|
|
||||||
|
if not ext or not urlsrc:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 只关注 edit / embedview / view 三种动作
|
||||||
|
if action_name not in ('edit', 'embedview', 'view'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ext not in ext_actions:
|
||||||
|
ext_actions[ext] = {}
|
||||||
|
ext_actions[ext][action_name] = urlsrc
|
||||||
|
|
||||||
|
# 为每个扩展名选择最佳 URL: edit > embedview > view
|
||||||
|
action_urls: dict[str, str] = {}
|
||||||
|
for ext, actions_map in ext_actions.items():
|
||||||
|
selected_urlsrc: str | None = None
|
||||||
|
for preferred in ('edit', 'embedview', 'view'):
|
||||||
|
if preferred in actions_map:
|
||||||
|
selected_urlsrc = actions_map[preferred]
|
||||||
|
break
|
||||||
|
|
||||||
|
if selected_urlsrc:
|
||||||
|
action_urls[ext] = process_wopi_action_url(selected_urlsrc)
|
||||||
|
|
||||||
|
# 去重 app_names
|
||||||
|
seen: set[str] = set()
|
||||||
|
unique_names: list[str] = []
|
||||||
|
for name in app_names:
|
||||||
|
if name not in seen:
|
||||||
|
seen.add(name)
|
||||||
|
unique_names.append(name)
|
||||||
|
|
||||||
|
l.info(f"WOPI Discovery 解析完成: {len(action_urls)} 个扩展名, 应用: {unique_names}")
|
||||||
|
|
||||||
|
return action_urls, unique_names
|
||||||
92
setup_cython.py
Normal file
92
setup_cython.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""
|
||||||
|
Cython 编译脚本 — 将 ee/ 下的纯逻辑文件编译为 .so
|
||||||
|
|
||||||
|
用法:
|
||||||
|
uv run --extra build python setup_cython.py build_ext --inplace
|
||||||
|
|
||||||
|
编译规则:
|
||||||
|
- 跳过 __init__.py(Python 包发现需要)
|
||||||
|
- 只编译 .py 文件(纯函数 / 服务逻辑)
|
||||||
|
|
||||||
|
编译后清理(Pro Docker 构建用):
|
||||||
|
uv run --extra build python setup_cython.py clean_artifacts
|
||||||
|
"""
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
EE_DIR = Path("ee")
|
||||||
|
|
||||||
|
# 跳过 __init__.py —— 包发现需要原始 .py
|
||||||
|
SKIP_NAMES = {"__init__.py"}
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_modules() -> list[str]:
|
||||||
|
"""收集 ee/ 下需要编译的 .py 文件路径(点分模块名)。"""
|
||||||
|
modules: list[str] = []
|
||||||
|
for py_file in EE_DIR.rglob("*.py"):
|
||||||
|
if py_file.name in SKIP_NAMES:
|
||||||
|
continue
|
||||||
|
# ee/license.py → ee.license
|
||||||
|
module = str(py_file.with_suffix("")).replace("\\", "/").replace("/", ".")
|
||||||
|
modules.append(module)
|
||||||
|
return modules
|
||||||
|
|
||||||
|
|
||||||
|
def clean_artifacts() -> None:
|
||||||
|
"""删除已编译的 .py 源码、.c 中间文件和 build/ 目录。"""
|
||||||
|
for py_file in EE_DIR.rglob("*.py"):
|
||||||
|
if py_file.name in SKIP_NAMES:
|
||||||
|
continue
|
||||||
|
# 只删除有对应 .so / .pyd 的源文件
|
||||||
|
parent = py_file.parent
|
||||||
|
stem = py_file.stem
|
||||||
|
has_compiled = (
|
||||||
|
any(parent.glob(f"{stem}*.so")) or
|
||||||
|
any(parent.glob(f"{stem}*.pyd"))
|
||||||
|
)
|
||||||
|
if has_compiled:
|
||||||
|
py_file.unlink()
|
||||||
|
print(f"已删除源码: {py_file}")
|
||||||
|
|
||||||
|
# 删除 .c 中间文件
|
||||||
|
for c_file in EE_DIR.rglob("*.c"):
|
||||||
|
c_file.unlink()
|
||||||
|
print(f"已删除中间文件: {c_file}")
|
||||||
|
|
||||||
|
# 删除 build/ 目录
|
||||||
|
build_dir = Path("build")
|
||||||
|
if build_dir.exists():
|
||||||
|
shutil.rmtree(build_dir)
|
||||||
|
print(f"已删除: {build_dir}/")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if len(sys.argv) > 1 and sys.argv[1] == "clean_artifacts":
|
||||||
|
clean_artifacts()
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# 动态导入(仅在编译时需要)
|
||||||
|
from Cython.Build import cythonize
|
||||||
|
from setuptools import Extension, setup
|
||||||
|
|
||||||
|
modules = _collect_modules()
|
||||||
|
if not modules:
|
||||||
|
print("未找到需要编译的模块")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
print(f"即将编译以下模块: {modules}")
|
||||||
|
|
||||||
|
extensions = [
|
||||||
|
Extension(mod, [mod.replace(".", "/") + ".py"])
|
||||||
|
for mod in modules
|
||||||
|
]
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name="disknext-ee",
|
||||||
|
packages=[],
|
||||||
|
ext_modules=cythonize(
|
||||||
|
extensions,
|
||||||
|
compiler_directives={'language_level': "3"},
|
||||||
|
),
|
||||||
|
)
|
||||||
@@ -954,18 +954,11 @@ class PolicyType(StrEnum):
|
|||||||
S3 = "s3" # S3 兼容存储
|
S3 = "s3" # S3 兼容存储
|
||||||
```
|
```
|
||||||
|
|
||||||
### StorageType
|
### PolicyType
|
||||||
```python
|
```python
|
||||||
class StorageType(StrEnum):
|
class PolicyType(StrEnum):
|
||||||
LOCAL = "local" # 本地存储
|
LOCAL = "local" # 本地存储
|
||||||
QINIU = "qiniu" # 七牛云
|
S3 = "s3" # S3 兼容存储
|
||||||
TENCENT = "tencent" # 腾讯云
|
|
||||||
ALIYUN = "aliyun" # 阿里云
|
|
||||||
ONEDRIVE = "onedrive" # OneDrive
|
|
||||||
GOOGLE_DRIVE = "google_drive" # Google Drive
|
|
||||||
DROPBOX = "dropbox" # Dropbox
|
|
||||||
WEBDAV = "webdav" # WebDAV
|
|
||||||
REMOTE = "remote" # 远程存储
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### UserStatus
|
### UserStatus
|
||||||
|
|||||||
@@ -1,9 +1,17 @@
|
|||||||
|
from .auth_identity import (
|
||||||
|
AuthIdentity,
|
||||||
|
AuthIdentityResponse,
|
||||||
|
AuthProviderType,
|
||||||
|
BindIdentityRequest,
|
||||||
|
ChangePasswordRequest,
|
||||||
|
)
|
||||||
from .user import (
|
from .user import (
|
||||||
BatchDeleteRequest,
|
BatchDeleteRequest,
|
||||||
JWTPayload,
|
JWTPayload,
|
||||||
LoginRequest,
|
MagicLinkRequest,
|
||||||
|
UnifiedLoginRequest,
|
||||||
|
UnifiedRegisterRequest,
|
||||||
RefreshTokenRequest,
|
RefreshTokenRequest,
|
||||||
RegisterRequest,
|
|
||||||
AccessTokenBase,
|
AccessTokenBase,
|
||||||
RefreshTokenBase,
|
RefreshTokenBase,
|
||||||
TokenResponse,
|
TokenResponse,
|
||||||
@@ -13,14 +21,28 @@ from .user import (
|
|||||||
UserPublic,
|
UserPublic,
|
||||||
UserResponse,
|
UserResponse,
|
||||||
UserSettingResponse,
|
UserSettingResponse,
|
||||||
|
UserThemeUpdateRequest,
|
||||||
|
SettingOption,
|
||||||
|
UserSettingUpdateRequest,
|
||||||
WebAuthnInfo,
|
WebAuthnInfo,
|
||||||
|
UserTwoFactorResponse,
|
||||||
# 管理员DTO
|
# 管理员DTO
|
||||||
UserAdminUpdateRequest,
|
UserAdminUpdateRequest,
|
||||||
UserCalibrateResponse,
|
UserCalibrateResponse,
|
||||||
UserAdminDetailResponse,
|
UserAdminDetailResponse,
|
||||||
)
|
)
|
||||||
from .user_authn import AuthnResponse, UserAuthn
|
from .user_authn import (
|
||||||
from .color import ThemeResponse
|
AuthnDetailResponse,
|
||||||
|
AuthnFinishRequest,
|
||||||
|
AuthnRenameRequest,
|
||||||
|
UserAuthn,
|
||||||
|
)
|
||||||
|
from .color import ChromaticColor, NeutralColor, ThemeColorsBase, BUILTIN_DEFAULT_COLORS
|
||||||
|
from .theme_preset import (
|
||||||
|
ThemePreset, ThemePresetBase,
|
||||||
|
ThemePresetCreateRequest, ThemePresetUpdateRequest,
|
||||||
|
ThemePresetResponse, ThemePresetListResponse,
|
||||||
|
)
|
||||||
|
|
||||||
from .download import (
|
from .download import (
|
||||||
Download,
|
Download,
|
||||||
@@ -47,18 +69,20 @@ from .object import (
|
|||||||
CreateUploadSessionRequest,
|
CreateUploadSessionRequest,
|
||||||
DirectoryCreateRequest,
|
DirectoryCreateRequest,
|
||||||
DirectoryResponse,
|
DirectoryResponse,
|
||||||
FileMetadata,
|
|
||||||
FileMetadataBase,
|
|
||||||
Object,
|
Object,
|
||||||
ObjectBase,
|
ObjectBase,
|
||||||
ObjectCopyRequest,
|
ObjectCopyRequest,
|
||||||
ObjectDeleteRequest,
|
ObjectDeleteRequest,
|
||||||
|
ObjectFileFinalize,
|
||||||
ObjectMoveRequest,
|
ObjectMoveRequest,
|
||||||
|
ObjectMoveUpdate,
|
||||||
ObjectPropertyDetailResponse,
|
ObjectPropertyDetailResponse,
|
||||||
ObjectPropertyResponse,
|
ObjectPropertyResponse,
|
||||||
ObjectRenameRequest,
|
ObjectRenameRequest,
|
||||||
ObjectResponse,
|
ObjectResponse,
|
||||||
|
ObjectSwitchPolicyRequest,
|
||||||
ObjectType,
|
ObjectType,
|
||||||
|
FileCategory,
|
||||||
PolicyResponse,
|
PolicyResponse,
|
||||||
UploadChunkResponse,
|
UploadChunkResponse,
|
||||||
UploadSession,
|
UploadSession,
|
||||||
@@ -68,24 +92,75 @@ from .object import (
|
|||||||
AdminFileResponse,
|
AdminFileResponse,
|
||||||
AdminFileListResponse,
|
AdminFileListResponse,
|
||||||
FileBanRequest,
|
FileBanRequest,
|
||||||
|
# 回收站DTO
|
||||||
|
TrashItemResponse,
|
||||||
|
TrashRestoreRequest,
|
||||||
|
TrashDeleteRequest,
|
||||||
|
)
|
||||||
|
from .object_metadata import (
|
||||||
|
ObjectMetadata,
|
||||||
|
ObjectMetadataBase,
|
||||||
|
MetadataNamespace,
|
||||||
|
MetadataResponse,
|
||||||
|
MetadataPatchItem,
|
||||||
|
MetadataPatchRequest,
|
||||||
|
INTERNAL_NAMESPACES,
|
||||||
|
USER_WRITABLE_NAMESPACES,
|
||||||
|
)
|
||||||
|
from .custom_property import (
|
||||||
|
CustomPropertyDefinition,
|
||||||
|
CustomPropertyDefinitionBase,
|
||||||
|
CustomPropertyType,
|
||||||
|
CustomPropertyCreateRequest,
|
||||||
|
CustomPropertyUpdateRequest,
|
||||||
|
CustomPropertyResponse,
|
||||||
)
|
)
|
||||||
from .physical_file import PhysicalFile, PhysicalFileBase
|
from .physical_file import PhysicalFile, PhysicalFileBase
|
||||||
from .uri import DiskNextURI, FileSystemNamespace
|
from .uri import DiskNextURI, FileSystemNamespace
|
||||||
from .order import Order, OrderStatus, OrderType
|
from .order import (
|
||||||
from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary
|
Order, OrderStatus, OrderType,
|
||||||
from .redeem import Redeem, RedeemType
|
CreateOrderRequest, OrderResponse,
|
||||||
|
)
|
||||||
|
from .policy import (
|
||||||
|
Policy, PolicyBase, PolicyCreateRequest, PolicyOptions, PolicyOptionsBase,
|
||||||
|
PolicyType, PolicySummary, PolicyUpdateRequest,
|
||||||
|
)
|
||||||
|
from .product import (
|
||||||
|
Product, ProductBase, ProductType, PaymentMethod,
|
||||||
|
ProductCreateRequest, ProductUpdateRequest, ProductResponse,
|
||||||
|
)
|
||||||
|
from .redeem import (
|
||||||
|
Redeem, RedeemType,
|
||||||
|
RedeemCreateRequest, RedeemUseRequest, RedeemInfoResponse, RedeemAdminResponse,
|
||||||
|
)
|
||||||
from .report import Report, ReportReason
|
from .report import Report, ReportReason
|
||||||
from .setting import (
|
from .setting import (
|
||||||
Setting, SettingsType, SiteConfigResponse,
|
Setting, SettingsType, SiteConfigResponse, AuthMethodConfig,
|
||||||
# 管理员DTO
|
# 管理员DTO
|
||||||
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
|
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
|
||||||
)
|
)
|
||||||
from .share import Share, ShareBase, ShareCreateRequest, ShareResponse, AdminShareListItem
|
from .share import (
|
||||||
|
Share, ShareBase, ShareCreateRequest, CreateShareResponse, ShareResponse,
|
||||||
|
ShareOwnerInfo, ShareObjectItem, ShareDetailResponse,
|
||||||
|
AdminShareListItem,
|
||||||
|
)
|
||||||
from .source_link import SourceLink
|
from .source_link import SourceLink
|
||||||
from .storage_pack import StoragePack
|
from .storage_pack import StoragePack, StoragePackResponse
|
||||||
from .tag import Tag, TagType
|
from .tag import Tag, TagType
|
||||||
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
|
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary, TaskSummaryBase
|
||||||
from .webdav import WebDAV
|
from .webdav import (
|
||||||
|
WebDAV, WebDAVBase,
|
||||||
|
WebDAVCreateRequest, WebDAVUpdateRequest, WebDAVAccountResponse,
|
||||||
|
)
|
||||||
|
from .file_app import (
|
||||||
|
FileApp, FileAppType, FileAppExtension, FileAppGroupLink, UserFileAppDefault,
|
||||||
|
# DTO
|
||||||
|
FileAppSummary, FileViewersResponse, SetDefaultViewerRequest, UserFileAppDefaultResponse,
|
||||||
|
FileAppCreateRequest, FileAppUpdateRequest, FileAppResponse, FileAppListResponse,
|
||||||
|
ExtensionUpdateRequest, GroupAccessUpdateRequest, WopiSessionResponse,
|
||||||
|
WopiDiscoveredExtension, WopiDiscoveryResponse,
|
||||||
|
)
|
||||||
|
from .wopi import WopiFileInfo, WopiAccessTokenPayload
|
||||||
|
|
||||||
from .database_connection import DatabaseManager
|
from .database_connection import DatabaseManager
|
||||||
|
|
||||||
@@ -102,5 +177,5 @@ from .model_base import (
|
|||||||
AdminSummaryResponse,
|
AdminSummaryResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
# mixin 中的通用分页模型
|
# 通用分页模型
|
||||||
from .mixin import ListResponse
|
from sqlmodel_ext import ListResponse
|
||||||
|
|||||||
148
sqlmodels/auth_identity.py
Normal file
148
sqlmodels/auth_identity.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""
|
||||||
|
认证身份模块
|
||||||
|
|
||||||
|
一个用户可拥有多种登录方式(邮箱密码、OAuth、Passkey、Magic Link 等)。
|
||||||
|
AuthIdentity 表存储每种认证方式的凭证信息。
|
||||||
|
"""
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||||
|
|
||||||
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100, Str128, Str255, Text1024
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
|
class AuthProviderType(StrEnum):
|
||||||
|
"""认证提供者类型"""
|
||||||
|
|
||||||
|
EMAIL_PASSWORD = "email_password"
|
||||||
|
"""邮箱+密码"""
|
||||||
|
|
||||||
|
PHONE_SMS = "phone_sms"
|
||||||
|
"""手机号+短信验证码(预留)"""
|
||||||
|
|
||||||
|
GITHUB = "github"
|
||||||
|
"""GitHub OAuth"""
|
||||||
|
|
||||||
|
QQ = "qq"
|
||||||
|
"""QQ OAuth"""
|
||||||
|
|
||||||
|
PASSKEY = "passkey"
|
||||||
|
"""Passkey/WebAuthn"""
|
||||||
|
|
||||||
|
MAGIC_LINK = "magic_link"
|
||||||
|
"""邮箱 Magic Link"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== DTO 模型 ====================
|
||||||
|
|
||||||
|
class AuthIdentityResponse(SQLModelBase):
|
||||||
|
"""认证身份响应 DTO(列表展示用)"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
"""身份UUID"""
|
||||||
|
|
||||||
|
provider: AuthProviderType
|
||||||
|
"""提供者类型"""
|
||||||
|
|
||||||
|
identifier: str
|
||||||
|
"""标识符(邮箱/手机号/OAuth openid)"""
|
||||||
|
|
||||||
|
display_name: str | None = None
|
||||||
|
"""显示名称(OAuth 昵称等)"""
|
||||||
|
|
||||||
|
avatar_url: str | None = None
|
||||||
|
"""头像 URL"""
|
||||||
|
|
||||||
|
is_primary: bool = False
|
||||||
|
"""是否主要身份"""
|
||||||
|
|
||||||
|
is_verified: bool = False
|
||||||
|
"""是否已验证"""
|
||||||
|
|
||||||
|
|
||||||
|
class BindIdentityRequest(SQLModelBase):
|
||||||
|
"""绑定认证身份请求 DTO"""
|
||||||
|
|
||||||
|
provider: AuthProviderType
|
||||||
|
"""提供者类型"""
|
||||||
|
|
||||||
|
identifier: str
|
||||||
|
"""标识符(邮箱/手机号/OAuth code)"""
|
||||||
|
|
||||||
|
credential: str | None = None
|
||||||
|
"""凭证(密码、验证码等)"""
|
||||||
|
|
||||||
|
redirect_uri: str | None = None
|
||||||
|
"""OAuth 回调地址"""
|
||||||
|
|
||||||
|
|
||||||
|
class ChangePasswordRequest(SQLModelBase):
|
||||||
|
"""修改密码请求 DTO"""
|
||||||
|
|
||||||
|
old_password: str = Field(min_length=1)
|
||||||
|
"""当前密码"""
|
||||||
|
|
||||||
|
new_password: Str128 = Field(min_length=8)
|
||||||
|
"""新密码(至少 8 位)"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 数据库模型 ====================
|
||||||
|
|
||||||
|
class AuthIdentity(SQLModelBase, UUIDTableBaseMixin):
|
||||||
|
"""用户认证身份 — 一个用户可以有多种登录方式"""
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("provider", "identifier", name="uq_auth_identity_provider_identifier"),
|
||||||
|
)
|
||||||
|
|
||||||
|
provider: AuthProviderType = Field(index=True)
|
||||||
|
"""提供者类型"""
|
||||||
|
|
||||||
|
identifier: Str255 = Field(index=True)
|
||||||
|
"""标识符(邮箱/手机号/OAuth openid)"""
|
||||||
|
|
||||||
|
credential: Text1024 | None = None
|
||||||
|
"""凭证(Argon2 哈希密码 / null)"""
|
||||||
|
|
||||||
|
display_name: Str100 | None = None
|
||||||
|
"""OAuth 昵称"""
|
||||||
|
|
||||||
|
avatar_url: str | None = Field(default=None, max_length=512)
|
||||||
|
"""OAuth 头像 URL"""
|
||||||
|
|
||||||
|
extra_data: str | None = None
|
||||||
|
"""JSON 附加数据(2FA secret、OAuth refresh_token 等)"""
|
||||||
|
|
||||||
|
is_primary: bool = False
|
||||||
|
"""是否主要身份"""
|
||||||
|
|
||||||
|
is_verified: bool = False
|
||||||
|
"""是否已验证"""
|
||||||
|
|
||||||
|
# 外键
|
||||||
|
user_id: UUID = Field(
|
||||||
|
foreign_key="user.id",
|
||||||
|
index=True,
|
||||||
|
ondelete="CASCADE",
|
||||||
|
)
|
||||||
|
"""所属用户UUID"""
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
user: "User" = Relationship(back_populates="auth_identities")
|
||||||
|
|
||||||
|
def to_response(self) -> AuthIdentityResponse:
|
||||||
|
"""转换为响应 DTO"""
|
||||||
|
return AuthIdentityResponse(
|
||||||
|
id=self.id,
|
||||||
|
provider=self.provider,
|
||||||
|
identifier=self.identifier,
|
||||||
|
display_name=self.display_name,
|
||||||
|
avatar_url=self.avatar_url,
|
||||||
|
is_primary=self.is_primary,
|
||||||
|
is_verified=self.is_verified,
|
||||||
|
)
|
||||||
@@ -1,657 +0,0 @@
|
|||||||
# SQLModels Base Module
|
|
||||||
|
|
||||||
This module provides `SQLModelBase`, the root base class for all SQLModel models in this project. It includes a custom metaclass with automatic type injection and Python 3.14 compatibility.
|
|
||||||
|
|
||||||
**Note**: Table base classes (`TableBaseMixin`, `UUIDTableBaseMixin`) and polymorphic utilities have been migrated to the [`sqlmodels.mixin`](../mixin/README.md) module. See the mixin documentation for CRUD operations, polymorphic inheritance patterns, and pagination utilities.
|
|
||||||
|
|
||||||
## Table of Contents
|
|
||||||
|
|
||||||
- [Overview](#overview)
|
|
||||||
- [Migration Notice](#migration-notice)
|
|
||||||
- [Python 3.14 Compatibility](#python-314-compatibility)
|
|
||||||
- [Core Component](#core-component)
|
|
||||||
- [SQLModelBase](#sqlmodelbase)
|
|
||||||
- [Metaclass Features](#metaclass-features)
|
|
||||||
- [Automatic sa_type Injection](#automatic-sa_type-injection)
|
|
||||||
- [Table Configuration](#table-configuration)
|
|
||||||
- [Polymorphic Support](#polymorphic-support)
|
|
||||||
- [Custom Types Integration](#custom-types-integration)
|
|
||||||
- [Best Practices](#best-practices)
|
|
||||||
- [Troubleshooting](#troubleshooting)
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
The `sqlmodels.base` module provides `SQLModelBase`, the foundational base class for all SQLModel models. It features:
|
|
||||||
|
|
||||||
- **Smart metaclass** that automatically extracts and injects SQLAlchemy types from type annotations
|
|
||||||
- **Python 3.14 compatibility** through comprehensive PEP 649/749 support
|
|
||||||
- **Flexible configuration** through class parameters and automatic docstring support
|
|
||||||
- **Type-safe annotations** with automatic validation
|
|
||||||
|
|
||||||
All models in this project should directly or indirectly inherit from `SQLModelBase`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Migration Notice
|
|
||||||
|
|
||||||
As of the recent refactoring, the following components have been moved:
|
|
||||||
|
|
||||||
| Component | Old Location | New Location |
|
|
||||||
|-----------|-------------|--------------|
|
|
||||||
| `TableBase` → `TableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
|
||||||
| `UUIDTableBase` → `UUIDTableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
|
||||||
| `PolymorphicBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
|
||||||
| `create_subclass_id_mixin()` | `sqlmodels.base` | `sqlmodels.mixin` |
|
|
||||||
| `AutoPolymorphicIdentityMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
|
||||||
| `TableViewRequest` | `sqlmodels.base` | `sqlmodels.mixin` |
|
|
||||||
| `now()`, `now_date()` | `sqlmodels.base` | `sqlmodels.mixin` |
|
|
||||||
|
|
||||||
**Update your imports**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# ❌ Old (deprecated)
|
|
||||||
from sqlmodels.base import TableBase, UUIDTableBase
|
|
||||||
|
|
||||||
# ✅ New (correct)
|
|
||||||
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
|
|
||||||
```
|
|
||||||
|
|
||||||
For detailed documentation on table mixins, CRUD operations, and polymorphic patterns, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Python 3.14 Compatibility
|
|
||||||
|
|
||||||
### Overview
|
|
||||||
|
|
||||||
This module provides full compatibility with **Python 3.14's PEP 649** (Deferred Evaluation of Annotations) and **PEP 749** (making it the default).
|
|
||||||
|
|
||||||
**Key Changes in Python 3.14**:
|
|
||||||
- Annotations are no longer evaluated at class definition time
|
|
||||||
- Type hints are stored as deferred code objects
|
|
||||||
- `__annotate__` function generates annotations on demand
|
|
||||||
- Forward references become `ForwardRef` objects
|
|
||||||
|
|
||||||
### Implementation Strategy
|
|
||||||
|
|
||||||
We use **`typing.get_type_hints()`** as the universal annotations resolver:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[...]:
|
|
||||||
# Create temporary proxy class
|
|
||||||
temp_cls = type('AnnotationProxy', (object,), dict(attrs))
|
|
||||||
|
|
||||||
# Use get_type_hints with include_extras=True
|
|
||||||
evaluated = get_type_hints(
|
|
||||||
temp_cls,
|
|
||||||
globalns=module_globals,
|
|
||||||
localns=localns,
|
|
||||||
include_extras=True # Preserve Annotated metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
return dict(evaluated), {}, module_globals, localns
|
|
||||||
```
|
|
||||||
|
|
||||||
**Why `get_type_hints()`?**
|
|
||||||
- ✅ Works across Python 3.10-3.14+
|
|
||||||
- ✅ Handles PEP 649 automatically
|
|
||||||
- ✅ Preserves `Annotated` metadata (with `include_extras=True`)
|
|
||||||
- ✅ Resolves forward references
|
|
||||||
- ✅ Recommended by Python documentation
|
|
||||||
|
|
||||||
### SQLModel Compatibility Patch
|
|
||||||
|
|
||||||
**Problem**: SQLModel's `get_sqlalchemy_type()` doesn't recognize custom types with `__sqlmodel_sa_type__` attribute.
|
|
||||||
|
|
||||||
**Solution**: Global monkey-patch that checks for SQLAlchemy type before falling back to original logic:
|
|
||||||
|
|
||||||
```python
|
|
||||||
if sys.version_info >= (3, 14):
|
|
||||||
def _patched_get_sqlalchemy_type(field):
|
|
||||||
annotation = getattr(field, 'annotation', None)
|
|
||||||
if annotation is not None:
|
|
||||||
# Priority 1: Check __sqlmodel_sa_type__ attribute
|
|
||||||
# Handles NumpyVector[dims, dtype] and similar custom types
|
|
||||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
|
||||||
return annotation.__sqlmodel_sa_type__
|
|
||||||
|
|
||||||
# Priority 2: Check Annotated metadata
|
|
||||||
if get_origin(annotation) is Annotated:
|
|
||||||
for metadata in get_args(annotation)[1:]:
|
|
||||||
if hasattr(metadata, '__sqlmodel_sa_type__'):
|
|
||||||
return metadata.__sqlmodel_sa_type__
|
|
||||||
|
|
||||||
# ... handle ForwardRef, ClassVar, etc.
|
|
||||||
|
|
||||||
return _original_get_sqlalchemy_type(field)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Supported Patterns
|
|
||||||
|
|
||||||
#### Pattern 1: Direct Custom Type Usage
|
|
||||||
```python
|
|
||||||
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
|
||||||
embedding: NumpyVector[256, np.float32]
|
|
||||||
"""Voice embedding - sa_type automatically extracted"""
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Pattern 2: Annotated Wrapper
|
|
||||||
```python
|
|
||||||
from typing import Annotated
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
|
|
||||||
|
|
||||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
|
||||||
embedding: EmbeddingVector
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Pattern 3: Array Type
|
|
||||||
```python
|
|
||||||
from sqlmodels.sqlmodel_types.dialects.postgresql import Array
|
|
||||||
from sqlmodels.mixin import TableBaseMixin
|
|
||||||
|
|
||||||
class ServerConfig(TableBaseMixin, table=True):
|
|
||||||
protocols: Array[ProtocolEnum]
|
|
||||||
"""Allowed protocols - sa_type from Array handler"""
|
|
||||||
```
|
|
||||||
|
|
||||||
### Migration from Python 3.13
|
|
||||||
|
|
||||||
**No code changes required!** The implementation is transparent:
|
|
||||||
|
|
||||||
- Uses `typing.get_type_hints()` which works in both Python 3.13 and 3.14
|
|
||||||
- Custom types already use `__sqlmodel_sa_type__` attribute
|
|
||||||
- Monkey-patch only activates for Python 3.14+
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Core Component
|
|
||||||
|
|
||||||
### SQLModelBase
|
|
||||||
|
|
||||||
`SQLModelBase` is the root base class for all SQLModel models. It uses a custom metaclass (`__DeclarativeMeta`) that provides advanced features beyond standard SQLModel capabilities.
|
|
||||||
|
|
||||||
**Key Features**:
|
|
||||||
- Automatic `use_attribute_docstrings` configuration (use docstrings instead of `Field(description=...)`)
|
|
||||||
- Automatic `validate_by_name` configuration
|
|
||||||
- Custom metaclass for sa_type injection and polymorphic setup
|
|
||||||
- Integration with Pydantic v2
|
|
||||||
- Python 3.14 PEP 649 compatibility
|
|
||||||
|
|
||||||
**Usage**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from sqlmodels.base import SQLModelBase
|
|
||||||
|
|
||||||
class UserBase(SQLModelBase):
|
|
||||||
name: str
|
|
||||||
"""User's display name"""
|
|
||||||
|
|
||||||
email: str
|
|
||||||
"""User's email address"""
|
|
||||||
```
|
|
||||||
|
|
||||||
**Important Notes**:
|
|
||||||
- Use **docstrings** for field descriptions, not `Field(description=...)`
|
|
||||||
- Do NOT override `model_config` in subclasses (it's already configured in SQLModelBase)
|
|
||||||
- This class should be used for non-table models (DTOs, request/response models)
|
|
||||||
|
|
||||||
**For table models**, use mixins from `sqlmodels.mixin`:
|
|
||||||
- `TableBaseMixin` - Integer primary key with timestamps
|
|
||||||
- `UUIDTableBaseMixin` - UUID primary key with timestamps
|
|
||||||
|
|
||||||
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete table mixin documentation.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Metaclass Features
|
|
||||||
|
|
||||||
### Automatic sa_type Injection
|
|
||||||
|
|
||||||
The metaclass automatically extracts SQLAlchemy types from custom type annotations, enabling clean syntax for complex database types.
|
|
||||||
|
|
||||||
**Before** (verbose):
|
|
||||||
```python
|
|
||||||
from sqlmodels.sqlmodel_types.dialects.postgresql.numpy_vector import _NumpyVectorSQLAlchemyType
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
|
||||||
embedding: np.ndarray = Field(
|
|
||||||
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**After** (clean):
|
|
||||||
```python
|
|
||||||
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
|
||||||
embedding: NumpyVector[256, np.float32]
|
|
||||||
"""Speaker voice embedding"""
|
|
||||||
```
|
|
||||||
|
|
||||||
**How It Works**:
|
|
||||||
|
|
||||||
The metaclass uses a three-tier detection strategy:
|
|
||||||
|
|
||||||
1. **Direct `__sqlmodel_sa_type__` attribute** (Priority 1)
|
|
||||||
```python
|
|
||||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
|
||||||
return annotation.__sqlmodel_sa_type__
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Annotated metadata** (Priority 2)
|
|
||||||
```python
|
|
||||||
# For Annotated[np.ndarray, NumpyVector[256, np.float32]]
|
|
||||||
if get_origin(annotation) is typing.Annotated:
|
|
||||||
for item in metadata_items:
|
|
||||||
if hasattr(item, '__sqlmodel_sa_type__'):
|
|
||||||
return item.__sqlmodel_sa_type__
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Pydantic Core Schema metadata** (Priority 3)
|
|
||||||
```python
|
|
||||||
schema = annotation.__get_pydantic_core_schema__(...)
|
|
||||||
if schema['metadata'].get('sa_type'):
|
|
||||||
return schema['metadata']['sa_type']
|
|
||||||
```
|
|
||||||
|
|
||||||
After extracting `sa_type`, the metaclass:
|
|
||||||
- Creates `Field(sa_type=sa_type)` if no Field is defined
|
|
||||||
- Injects `sa_type` into existing Field if not already set
|
|
||||||
- Respects explicit `Field(sa_type=...)` (no override)
|
|
||||||
|
|
||||||
**Supported Patterns**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
# Pattern 1: Direct usage (recommended)
|
|
||||||
class Model(UUIDTableBaseMixin, table=True):
|
|
||||||
embedding: NumpyVector[256, np.float32]
|
|
||||||
|
|
||||||
# Pattern 2: With Field constraints
|
|
||||||
class Model(UUIDTableBaseMixin, table=True):
|
|
||||||
embedding: NumpyVector[256, np.float32] = Field(nullable=False)
|
|
||||||
|
|
||||||
# Pattern 3: Annotated wrapper
|
|
||||||
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
|
|
||||||
|
|
||||||
class Model(UUIDTableBaseMixin, table=True):
|
|
||||||
embedding: EmbeddingVector
|
|
||||||
|
|
||||||
# Pattern 4: Explicit sa_type (override)
|
|
||||||
class Model(UUIDTableBaseMixin, table=True):
|
|
||||||
embedding: NumpyVector[256, np.float32] = Field(
|
|
||||||
sa_type=_NumpyVectorSQLAlchemyType(128, np.float16)
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Table Configuration
|
|
||||||
|
|
||||||
The metaclass provides smart defaults and flexible configuration:
|
|
||||||
|
|
||||||
**Automatic `table=True`**:
|
|
||||||
```python
|
|
||||||
# Classes inheriting from TableBaseMixin automatically get table=True
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
class MyModel(UUIDTableBaseMixin): # table=True is automatic
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**Convenient mapper arguments**:
|
|
||||||
```python
|
|
||||||
# Instead of verbose __mapper_args__
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
class MyModel(
|
|
||||||
UUIDTableBaseMixin,
|
|
||||||
polymorphic_on='_polymorphic_name',
|
|
||||||
polymorphic_abstract=True
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Equivalent to:
|
|
||||||
class MyModel(UUIDTableBaseMixin):
|
|
||||||
__mapper_args__ = {
|
|
||||||
'polymorphic_on': '_polymorphic_name',
|
|
||||||
'polymorphic_abstract': True
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Smart merging**:
|
|
||||||
```python
|
|
||||||
# Dictionary and keyword arguments are merged
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
class MyModel(
|
|
||||||
UUIDTableBaseMixin,
|
|
||||||
mapper_args={'version_id_col': 'version'},
|
|
||||||
polymorphic_on='type' # Merged into __mapper_args__
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
### Polymorphic Support
|
|
||||||
|
|
||||||
The metaclass supports SQLAlchemy's joined table inheritance through convenient parameters:
|
|
||||||
|
|
||||||
**Supported parameters**:
|
|
||||||
- `polymorphic_on`: Discriminator column name
|
|
||||||
- `polymorphic_identity`: Identity value for this class
|
|
||||||
- `polymorphic_abstract`: Whether this is an abstract base
|
|
||||||
- `table_args`: SQLAlchemy table arguments
|
|
||||||
- `table_name`: Override table name (becomes `__tablename__`)
|
|
||||||
|
|
||||||
**For complete polymorphic inheritance patterns**, including `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Custom Types Integration
|
|
||||||
|
|
||||||
### Using NumpyVector
|
|
||||||
|
|
||||||
The `NumpyVector` type demonstrates automatic sa_type injection:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
|
||||||
embedding: NumpyVector[256, np.float32]
|
|
||||||
"""Speaker voice embedding - sa_type automatically injected"""
|
|
||||||
```
|
|
||||||
|
|
||||||
**How NumpyVector works**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# NumpyVector[dims, dtype] returns a class with:
|
|
||||||
class _NumpyVectorType:
|
|
||||||
__sqlmodel_sa_type__ = _NumpyVectorSQLAlchemyType(dimensions, dtype)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
|
||||||
return handler.generate_schema(np.ndarray)
|
|
||||||
```
|
|
||||||
|
|
||||||
This dual approach ensures:
|
|
||||||
1. Metaclass can extract `sa_type` via `__sqlmodel_sa_type__`
|
|
||||||
2. Pydantic can validate as `np.ndarray`
|
|
||||||
|
|
||||||
### Creating Custom SQLAlchemy Types
|
|
||||||
|
|
||||||
To create types that work with automatic injection, provide one of:
|
|
||||||
|
|
||||||
**Option 1: `__sqlmodel_sa_type__` attribute** (preferred):
|
|
||||||
|
|
||||||
```python
|
|
||||||
from sqlalchemy import TypeDecorator, String
|
|
||||||
|
|
||||||
class UpperCaseString(TypeDecorator):
|
|
||||||
impl = String
|
|
||||||
|
|
||||||
def process_bind_param(self, value, dialect):
|
|
||||||
return value.upper() if value else value
|
|
||||||
|
|
||||||
class UpperCaseType:
|
|
||||||
__sqlmodel_sa_type__ = UpperCaseString()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
|
||||||
return core_schema.str_schema()
|
|
||||||
|
|
||||||
# Usage
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
class MyModel(UUIDTableBaseMixin, table=True):
|
|
||||||
code: UpperCaseType # Automatically uses UpperCaseString()
|
|
||||||
```
|
|
||||||
|
|
||||||
**Option 2: Pydantic metadata with sa_type**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def __get_pydantic_core_schema__(self, source_type, handler):
|
|
||||||
return core_schema.json_or_python_schema(
|
|
||||||
json_schema=core_schema.str_schema(),
|
|
||||||
python_schema=core_schema.str_schema(),
|
|
||||||
metadata={'sa_type': UpperCaseString()}
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Option 3: Using Annotated**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from typing import Annotated
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
UpperCase = Annotated[str, UpperCaseType()]
|
|
||||||
|
|
||||||
class MyModel(UUIDTableBaseMixin, table=True):
|
|
||||||
code: UpperCase
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Best Practices
|
|
||||||
|
|
||||||
### 1. Inherit from correct base classes
|
|
||||||
|
|
||||||
```python
|
|
||||||
from sqlmodels.base import SQLModelBase
|
|
||||||
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
|
|
||||||
|
|
||||||
# ✅ For non-table models (DTOs, requests, responses)
|
|
||||||
class UserBase(SQLModelBase):
|
|
||||||
name: str
|
|
||||||
|
|
||||||
# ✅ For table models with UUID primary key
|
|
||||||
class User(UserBase, UUIDTableBaseMixin, table=True):
|
|
||||||
email: str
|
|
||||||
|
|
||||||
# ✅ For table models with custom primary key
|
|
||||||
class LegacyUser(TableBaseMixin, table=True):
|
|
||||||
id: int = Field(primary_key=True)
|
|
||||||
username: str
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Use docstrings for field descriptions
|
|
||||||
|
|
||||||
```python
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
# ✅ Recommended
|
|
||||||
class User(UUIDTableBaseMixin, table=True):
|
|
||||||
name: str
|
|
||||||
"""User's display name"""
|
|
||||||
|
|
||||||
# ❌ Avoid
|
|
||||||
class User(UUIDTableBaseMixin, table=True):
|
|
||||||
name: str = Field(description="User's display name")
|
|
||||||
```
|
|
||||||
|
|
||||||
**Why?** SQLModelBase has `use_attribute_docstrings=True`, so docstrings automatically become field descriptions in API docs.
|
|
||||||
|
|
||||||
### 3. Leverage automatic sa_type injection
|
|
||||||
|
|
||||||
```python
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
# ✅ Clean and recommended
|
|
||||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
|
||||||
embedding: NumpyVector[256, np.float32]
|
|
||||||
"""Voice embedding"""
|
|
||||||
|
|
||||||
# ❌ Verbose and unnecessary
|
|
||||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
|
||||||
embedding: np.ndarray = Field(
|
|
||||||
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. Follow polymorphic naming conventions
|
|
||||||
|
|
||||||
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete polymorphic inheritance patterns using `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`.
|
|
||||||
|
|
||||||
### 5. Separate Base, Parent, and Implementation classes
|
|
||||||
|
|
||||||
```python
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from sqlmodels.base import SQLModelBase
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin, PolymorphicBaseMixin
|
|
||||||
|
|
||||||
# ✅ Recommended structure
|
|
||||||
class ASRBase(SQLModelBase):
|
|
||||||
"""Pure data fields, no table"""
|
|
||||||
name: str
|
|
||||||
base_url: str
|
|
||||||
|
|
||||||
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
|
||||||
"""Abstract parent with table"""
|
|
||||||
@abstractmethod
|
|
||||||
async def transcribe(self, audio: bytes) -> str:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class WhisperASR(ASR, table=True):
|
|
||||||
"""Concrete implementation"""
|
|
||||||
model_size: str
|
|
||||||
|
|
||||||
async def transcribe(self, audio: bytes) -> str:
|
|
||||||
# Implementation
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**Why?**
|
|
||||||
- Base class can be reused for DTOs
|
|
||||||
- Parent class defines the polymorphic hierarchy
|
|
||||||
- Implementation classes are clean and focused
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
|
|
||||||
### Issue: ValueError: X has no matching SQLAlchemy type
|
|
||||||
|
|
||||||
**Solution**: Ensure your custom type provides `__sqlmodel_sa_type__` attribute or proper Pydantic metadata with `sa_type`.
|
|
||||||
|
|
||||||
```python
|
|
||||||
# ✅ Provide __sqlmodel_sa_type__
|
|
||||||
class MyType:
|
|
||||||
__sqlmodel_sa_type__ = MyCustomSQLAlchemyType()
|
|
||||||
```
|
|
||||||
|
|
||||||
### Issue: Can't generate DDL for NullType()
|
|
||||||
|
|
||||||
**Symptoms**: Error during table creation saying a column has `NullType`.
|
|
||||||
|
|
||||||
**Root Cause**: Custom type's `sa_type` not detected by SQLModel.
|
|
||||||
|
|
||||||
**Solution**:
|
|
||||||
1. Ensure your type has `__sqlmodel_sa_type__` class attribute
|
|
||||||
2. Check that the monkey-patch is active (`sys.version_info >= (3, 14)`)
|
|
||||||
3. Verify type annotation is correct (not a string forward reference)
|
|
||||||
|
|
||||||
```python
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
# ✅ Correct
|
|
||||||
class Model(UUIDTableBaseMixin, table=True):
|
|
||||||
data: NumpyVector[256, np.float32] # __sqlmodel_sa_type__ detected
|
|
||||||
|
|
||||||
# ❌ Wrong (string annotation)
|
|
||||||
class Model(UUIDTableBaseMixin, table=True):
|
|
||||||
data: 'NumpyVector[256, np.float32]' # sa_type lost
|
|
||||||
```
|
|
||||||
|
|
||||||
### Issue: Polymorphic identity conflicts
|
|
||||||
|
|
||||||
**Symptoms**: SQLAlchemy raises errors about duplicate polymorphic identities.
|
|
||||||
|
|
||||||
**Solution**:
|
|
||||||
1. Check that each concrete class has a unique identity
|
|
||||||
2. Use `AutoPolymorphicIdentityMixin` for automatic naming
|
|
||||||
3. Manually specify identity if needed:
|
|
||||||
```python
|
|
||||||
class MyClass(Parent, polymorphic_identity='unique.name', table=True):
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
### Issue: Python 3.14 annotation errors
|
|
||||||
|
|
||||||
**Symptoms**: Errors related to `__annotations__` or type resolution.
|
|
||||||
|
|
||||||
**Solution**: The implementation uses `get_type_hints()` which handles PEP 649 automatically. If issues persist:
|
|
||||||
1. Check for manual `__annotations__` manipulation (avoid it)
|
|
||||||
2. Ensure all types are properly imported
|
|
||||||
3. Avoid `from __future__ import annotations` (can cause SQLModel issues)
|
|
||||||
|
|
||||||
### Issue: Polymorphic and CRUD-related errors
|
|
||||||
|
|
||||||
For issues related to polymorphic inheritance, CRUD operations, or table mixins, see the troubleshooting section in [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Implementation Details
|
|
||||||
|
|
||||||
For developers modifying this module:
|
|
||||||
|
|
||||||
**Core files**:
|
|
||||||
- `sqlmodel_base.py` - Contains `__DeclarativeMeta` and `SQLModelBase`
|
|
||||||
- `../mixin/table.py` - Contains `TableBaseMixin` and `UUIDTableBaseMixin`
|
|
||||||
- `../mixin/polymorphic.py` - Contains `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`
|
|
||||||
|
|
||||||
**Key functions in this module**:
|
|
||||||
|
|
||||||
1. **`_resolve_annotations(attrs: dict[str, Any])`**
|
|
||||||
- Uses `typing.get_type_hints()` for Python 3.14 compatibility
|
|
||||||
- Returns tuple: `(annotations, annotation_strings, globalns, localns)`
|
|
||||||
- Preserves `Annotated` metadata with `include_extras=True`
|
|
||||||
|
|
||||||
2. **`_extract_sa_type_from_annotation(annotation: Any) -> Any | None`**
|
|
||||||
- Extracts SQLAlchemy type from type annotations
|
|
||||||
- Supports `__sqlmodel_sa_type__`, `Annotated`, and Pydantic core schema
|
|
||||||
- Called by metaclass during class creation
|
|
||||||
|
|
||||||
3. **`_patched_get_sqlalchemy_type(field)`** (Python 3.14+)
|
|
||||||
- Global monkey-patch for SQLModel
|
|
||||||
- Checks `__sqlmodel_sa_type__` before falling back to original logic
|
|
||||||
- Handles custom types like `NumpyVector` and `Array`
|
|
||||||
|
|
||||||
4. **`__DeclarativeMeta.__new__()`**
|
|
||||||
- Processes class definition parameters
|
|
||||||
- Injects `sa_type` into field definitions
|
|
||||||
- Sets up `__mapper_args__`, `__table_args__`, etc.
|
|
||||||
- Handles Python 3.14 annotations via `get_type_hints()`
|
|
||||||
|
|
||||||
**Metaclass processing order**:
|
|
||||||
1. Check if class should be a table (`_has_table_mixin`)
|
|
||||||
2. Collect `__mapper_args__` from kwargs and explicit dict
|
|
||||||
3. Process `table_args`, `table_name`, `abstract` parameters
|
|
||||||
4. Resolve annotations using `get_type_hints()`
|
|
||||||
5. For each field, try to extract `sa_type` and inject into Field
|
|
||||||
6. Call parent metaclass with cleaned kwargs
|
|
||||||
|
|
||||||
For table mixin implementation details, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## See Also
|
|
||||||
|
|
||||||
**Project Documentation**:
|
|
||||||
- [SQLModel Mixin Documentation](../mixin/README.md) - Table mixins, CRUD operations, polymorphic patterns
|
|
||||||
- [Project Coding Standards (CLAUDE.md)](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/CLAUDE.md)
|
|
||||||
- [Custom SQLModel Types Guide](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/sqlmodels/sqlmodel_types/README.md)
|
|
||||||
|
|
||||||
**External References**:
|
|
||||||
- [SQLAlchemy Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
|
|
||||||
- [Pydantic V2 Documentation](https://docs.pydantic.dev/latest/)
|
|
||||||
- [SQLModel Documentation](https://sqlmodel.tiangolo.com/)
|
|
||||||
- [PEP 649: Deferred Evaluation of Annotations](https://peps.python.org/pep-0649/)
|
|
||||||
- [PEP 749: Implementing PEP 649](https://peps.python.org/pep-0749/)
|
|
||||||
- [Python Annotations Best Practices](https://docs.python.org/3/howto/annotations.html)
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
"""
|
|
||||||
SQLModel 基础模块
|
|
||||||
|
|
||||||
包含:
|
|
||||||
- SQLModelBase: 所有 SQLModel 类的基类(真正的基类)
|
|
||||||
|
|
||||||
注意:
|
|
||||||
TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 sqlmodels.mixin
|
|
||||||
为了避免循环导入,此处不再重新导出它们
|
|
||||||
请直接从 sqlmodels.mixin 导入这些类
|
|
||||||
"""
|
|
||||||
from .sqlmodel_base import SQLModelBase
|
|
||||||
@@ -1,846 +0,0 @@
|
|||||||
import sys
|
|
||||||
import typing
|
|
||||||
from typing import Any, Mapping, get_args, get_origin, get_type_hints
|
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
|
||||||
from pydantic.fields import FieldInfo
|
|
||||||
from pydantic_core import PydanticUndefined as Undefined
|
|
||||||
from sqlalchemy.orm import Mapped
|
|
||||||
from sqlmodel import Field, SQLModel
|
|
||||||
from sqlmodel.main import SQLModelMetaclass
|
|
||||||
|
|
||||||
# Python 3.14+ PEP 649支持
|
|
||||||
if sys.version_info >= (3, 14):
|
|
||||||
import annotationlib
|
|
||||||
|
|
||||||
# 全局Monkey-patch: 修复SQLModel在Python 3.14上的兼容性问题
|
|
||||||
import sqlmodel.main
|
|
||||||
_original_get_sqlalchemy_type = sqlmodel.main.get_sqlalchemy_type
|
|
||||||
|
|
||||||
def _patched_get_sqlalchemy_type(field):
|
|
||||||
"""
|
|
||||||
修复SQLModel的get_sqlalchemy_type函数,处理Python 3.14的类型问题。
|
|
||||||
|
|
||||||
问题:
|
|
||||||
1. ForwardRef对象(来自Relationship字段)会导致issubclass错误
|
|
||||||
2. typing._GenericAlias对象(如ClassVar[T])也会导致同样问题
|
|
||||||
3. list/dict等泛型类型在没有Field/Relationship时可能导致错误
|
|
||||||
4. Mapped类型在Python 3.14下可能出现在annotation中
|
|
||||||
5. Annotated类型可能包含sa_type metadata(如Array[T])
|
|
||||||
6. 自定义类型(如NumpyVector)有__sqlmodel_sa_type__属性
|
|
||||||
7. Pydantic已处理的Annotated类型会将metadata存储在field.metadata中
|
|
||||||
|
|
||||||
解决:
|
|
||||||
- 优先检查field.metadata中的__get_pydantic_core_schema__(Pydantic已处理的情况)
|
|
||||||
- 检测__sqlmodel_sa_type__属性(NumpyVector等)
|
|
||||||
- 检测Relationship/ClassVar等返回None
|
|
||||||
- 对于Annotated类型,尝试提取sa_type metadata
|
|
||||||
- 其他情况调用原始函数
|
|
||||||
"""
|
|
||||||
# 优先检查 field.metadata(Pydantic已处理Annotated类型的情况)
|
|
||||||
# 当使用 Array[T] 或 Annotated[T, metadata] 时,Pydantic会将metadata存储在这里
|
|
||||||
metadata = getattr(field, 'metadata', None)
|
|
||||||
if metadata:
|
|
||||||
# metadata是一个列表,包含所有Annotated的元数据项
|
|
||||||
for metadata_item in metadata:
|
|
||||||
# 检查metadata_item是否有__get_pydantic_core_schema__方法
|
|
||||||
if hasattr(metadata_item, '__get_pydantic_core_schema__'):
|
|
||||||
try:
|
|
||||||
# 调用获取schema
|
|
||||||
schema = metadata_item.__get_pydantic_core_schema__(None, None)
|
|
||||||
# 检查schema的metadata中是否有sa_type
|
|
||||||
if isinstance(schema, dict) and 'metadata' in schema:
|
|
||||||
sa_type = schema['metadata'].get('sa_type')
|
|
||||||
if sa_type is not None:
|
|
||||||
return sa_type
|
|
||||||
except (TypeError, AttributeError, KeyError):
|
|
||||||
# Pydantic schema获取可能失败(类型不匹配、缺少属性等)
|
|
||||||
# 这是正常情况,继续检查下一个metadata项
|
|
||||||
pass
|
|
||||||
|
|
||||||
annotation = getattr(field, 'annotation', None)
|
|
||||||
if annotation is not None:
|
|
||||||
# 优先检查 __sqlmodel_sa_type__ 属性
|
|
||||||
# 这处理 NumpyVector[dims, dtype] 等自定义类型
|
|
||||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
|
||||||
return annotation.__sqlmodel_sa_type__
|
|
||||||
|
|
||||||
# 检查自定义类型(如JSON100K)的 __get_pydantic_core_schema__ 方法
|
|
||||||
# 这些类型在schema的metadata中定义sa_type
|
|
||||||
if hasattr(annotation, '__get_pydantic_core_schema__'):
|
|
||||||
try:
|
|
||||||
# 调用获取schema(传None作为handler,因为我们只需要metadata)
|
|
||||||
schema = annotation.__get_pydantic_core_schema__(annotation, lambda x: None)
|
|
||||||
# 检查schema的metadata中是否有sa_type
|
|
||||||
if isinstance(schema, dict) and 'metadata' in schema:
|
|
||||||
sa_type = schema['metadata'].get('sa_type')
|
|
||||||
if sa_type is not None:
|
|
||||||
return sa_type
|
|
||||||
except (TypeError, AttributeError, KeyError):
|
|
||||||
# Schema获取失败,继续其他检查
|
|
||||||
pass
|
|
||||||
|
|
||||||
anno_type_name = type(annotation).__name__
|
|
||||||
|
|
||||||
# ForwardRef: Relationship字段的annotation
|
|
||||||
if anno_type_name == 'ForwardRef':
|
|
||||||
return None
|
|
||||||
|
|
||||||
# AnnotatedAlias: 检查是否有sa_type metadata(如Array[T])
|
|
||||||
if anno_type_name == 'AnnotatedAlias' or anno_type_name == '_AnnotatedAlias':
|
|
||||||
from typing import get_origin, get_args
|
|
||||||
import typing
|
|
||||||
|
|
||||||
# 尝试提取Annotated的metadata
|
|
||||||
if hasattr(typing, 'get_args'):
|
|
||||||
args = get_args(annotation)
|
|
||||||
# args[0]是实际类型,args[1:]是metadata
|
|
||||||
for metadata in args[1:]:
|
|
||||||
# 检查metadata是否有__get_pydantic_core_schema__方法
|
|
||||||
if hasattr(metadata, '__get_pydantic_core_schema__'):
|
|
||||||
try:
|
|
||||||
# 调用获取schema
|
|
||||||
schema = metadata.__get_pydantic_core_schema__(None, None)
|
|
||||||
# 检查schema中是否有sa_type
|
|
||||||
if isinstance(schema, dict) and 'metadata' in schema:
|
|
||||||
sa_type = schema['metadata'].get('sa_type')
|
|
||||||
if sa_type is not None:
|
|
||||||
return sa_type
|
|
||||||
except (TypeError, AttributeError, KeyError):
|
|
||||||
# Annotated metadata的schema获取可能失败
|
|
||||||
# 这是正常的类型检查过程,继续检查下一个metadata
|
|
||||||
pass
|
|
||||||
|
|
||||||
# _GenericAlias或GenericAlias: typing泛型类型
|
|
||||||
if anno_type_name in ('_GenericAlias', 'GenericAlias'):
|
|
||||||
from typing import get_origin
|
|
||||||
import typing
|
|
||||||
origin = get_origin(annotation)
|
|
||||||
|
|
||||||
# ClassVar必须跳过
|
|
||||||
if origin is typing.ClassVar:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# list/dict/tuple/set等内置泛型,如果字段没有明确的Field或Relationship,也跳过
|
|
||||||
# 这通常意味着它是Relationship字段或类变量
|
|
||||||
if origin in (list, dict, tuple, set):
|
|
||||||
# 检查field_info是否存在且有意义
|
|
||||||
# Relationship字段会有特殊的field_info
|
|
||||||
field_info = getattr(field, 'field_info', None)
|
|
||||||
if field_info is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Mapped: SQLAlchemy 2.0的Mapped类型,SQLModel不应该处理
|
|
||||||
# 这可能是从父类继承的字段或Python 3.14注解处理的副作用
|
|
||||||
# 检查类型名称和annotation的字符串表示
|
|
||||||
if 'Mapped' in anno_type_name or 'Mapped' in str(annotation):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 检查annotation是否是Mapped类或其实例
|
|
||||||
try:
|
|
||||||
from sqlalchemy.orm import Mapped as SAMapped
|
|
||||||
# 检查origin(对于Mapped[T]这种泛型)
|
|
||||||
from typing import get_origin
|
|
||||||
if get_origin(annotation) is SAMapped:
|
|
||||||
return None
|
|
||||||
# 检查类型本身
|
|
||||||
if annotation is SAMapped or isinstance(annotation, type) and issubclass(annotation, SAMapped):
|
|
||||||
return None
|
|
||||||
except (ImportError, TypeError):
|
|
||||||
# 如果SQLAlchemy没有Mapped或检查失败,继续
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 其他情况正常处理
|
|
||||||
return _original_get_sqlalchemy_type(field)
|
|
||||||
|
|
||||||
sqlmodel.main.get_sqlalchemy_type = _patched_get_sqlalchemy_type
|
|
||||||
|
|
||||||
# 第二个Monkey-patch: 修复继承表类中InstrumentedAttribute作为默认值的问题
|
|
||||||
# 在Python 3.14 + SQLModel组合下,当子类(如SMSBaoProvider)继承父类(如VerificationCodeProvider)时,
|
|
||||||
# 父类的关系字段(如server_config)会在子类的model_fields中出现,
|
|
||||||
# 但其default值错误地设置为InstrumentedAttribute对象,而不是None
|
|
||||||
# 这导致实例化时尝试设置InstrumentedAttribute为字段值,触发SQLAlchemy内部错误
|
|
||||||
import sqlmodel._compat as _compat
|
|
||||||
from sqlalchemy.orm import attributes as _sa_attributes
|
|
||||||
|
|
||||||
_original_sqlmodel_table_construct = _compat.sqlmodel_table_construct
|
|
||||||
|
|
||||||
def _patched_sqlmodel_table_construct(self_instance, values):
|
|
||||||
"""
|
|
||||||
修复sqlmodel_table_construct,跳过InstrumentedAttribute默认值
|
|
||||||
|
|
||||||
问题:
|
|
||||||
- 继承自polymorphic基类的表类(如FishAudioTTS, SMSBaoProvider)
|
|
||||||
- 其model_fields中的继承字段default值为InstrumentedAttribute
|
|
||||||
- 原函数尝试将InstrumentedAttribute设置为字段值
|
|
||||||
- SQLAlchemy无法处理,抛出 '_sa_instance_state' 错误
|
|
||||||
|
|
||||||
解决:
|
|
||||||
- 只设置用户提供的值和非InstrumentedAttribute默认值
|
|
||||||
- InstrumentedAttribute默认值跳过(让SQLAlchemy自己处理)
|
|
||||||
"""
|
|
||||||
cls = type(self_instance)
|
|
||||||
|
|
||||||
# 收集要设置的字段值
|
|
||||||
fields_to_set = {}
|
|
||||||
|
|
||||||
for name, field in cls.model_fields.items():
|
|
||||||
# 如果用户提供了值,直接使用
|
|
||||||
if name in values:
|
|
||||||
fields_to_set[name] = values[name]
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 否则检查默认值
|
|
||||||
# 跳过InstrumentedAttribute默认值 - 这些是继承字段的错误默认值
|
|
||||||
if isinstance(field.default, _sa_attributes.InstrumentedAttribute):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 使用正常的默认值
|
|
||||||
if field.default is not Undefined:
|
|
||||||
fields_to_set[name] = field.default
|
|
||||||
elif field.default_factory is not None:
|
|
||||||
fields_to_set[name] = field.get_default(call_default_factory=True)
|
|
||||||
|
|
||||||
# 设置属性 - 只设置非InstrumentedAttribute值
|
|
||||||
for key, value in fields_to_set.items():
|
|
||||||
if not isinstance(value, _sa_attributes.InstrumentedAttribute):
|
|
||||||
setattr(self_instance, key, value)
|
|
||||||
|
|
||||||
# 设置Pydantic内部属性
|
|
||||||
object.__setattr__(self_instance, '__pydantic_fields_set__', set(values.keys()))
|
|
||||||
if not cls.__pydantic_root_model__:
|
|
||||||
_extra = None
|
|
||||||
if cls.model_config.get('extra') == 'allow':
|
|
||||||
_extra = {}
|
|
||||||
for k, v in values.items():
|
|
||||||
if k not in cls.model_fields:
|
|
||||||
_extra[k] = v
|
|
||||||
object.__setattr__(self_instance, '__pydantic_extra__', _extra)
|
|
||||||
|
|
||||||
if cls.__pydantic_post_init__:
|
|
||||||
self_instance.model_post_init(None)
|
|
||||||
elif not cls.__pydantic_root_model__:
|
|
||||||
object.__setattr__(self_instance, '__pydantic_private__', None)
|
|
||||||
|
|
||||||
# 设置关系
|
|
||||||
for key in self_instance.__sqlmodel_relationships__:
|
|
||||||
value = values.get(key, Undefined)
|
|
||||||
if value is not Undefined:
|
|
||||||
setattr(self_instance, key, value)
|
|
||||||
|
|
||||||
return self_instance
|
|
||||||
|
|
||||||
_compat.sqlmodel_table_construct = _patched_sqlmodel_table_construct
|
|
||||||
else:
|
|
||||||
annotationlib = None
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_sa_type_from_annotation(annotation: Any) -> Any | None:
|
|
||||||
"""
|
|
||||||
从类型注解中提取SQLAlchemy类型。
|
|
||||||
|
|
||||||
支持以下形式:
|
|
||||||
1. NumpyVector[256, np.float32] - 直接使用类型(有__sqlmodel_sa_type__属性)
|
|
||||||
2. Annotated[np.ndarray, NumpyVector[256, np.float32]] - Annotated包装
|
|
||||||
3. 任何有__get_pydantic_core_schema__且返回metadata['sa_type']的类型
|
|
||||||
|
|
||||||
Args:
|
|
||||||
annotation: 字段的类型注解
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
提取到的SQLAlchemy类型,如果没有则返回None
|
|
||||||
"""
|
|
||||||
# 方法1:直接检查类型本身是否有__sqlmodel_sa_type__属性
|
|
||||||
# 这涵盖了 NumpyVector[256, np.float32] 这种直接使用的情况
|
|
||||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
|
||||||
return annotation.__sqlmodel_sa_type__
|
|
||||||
|
|
||||||
# 方法2:检查是否为Annotated类型
|
|
||||||
if get_origin(annotation) is typing.Annotated:
|
|
||||||
# 获取元数据项(跳过第一个实际类型参数)
|
|
||||||
args = get_args(annotation)
|
|
||||||
if len(args) >= 2:
|
|
||||||
metadata_items = args[1:] # 第一个是实际类型,后面都是元数据
|
|
||||||
|
|
||||||
# 遍历元数据,查找包含sa_type的项
|
|
||||||
for item in metadata_items:
|
|
||||||
# 检查元数据项是否有__sqlmodel_sa_type__属性
|
|
||||||
if hasattr(item, '__sqlmodel_sa_type__'):
|
|
||||||
return item.__sqlmodel_sa_type__
|
|
||||||
|
|
||||||
# 检查是否有__get_pydantic_core_schema__方法
|
|
||||||
if hasattr(item, '__get_pydantic_core_schema__'):
|
|
||||||
try:
|
|
||||||
# 调用该方法获取core schema
|
|
||||||
schema = item.__get_pydantic_core_schema__(
|
|
||||||
annotation,
|
|
||||||
lambda x: None # 虚拟handler
|
|
||||||
)
|
|
||||||
# 检查schema的metadata中是否有sa_type
|
|
||||||
if isinstance(schema, dict) and 'metadata' in schema:
|
|
||||||
sa_type = schema['metadata'].get('sa_type')
|
|
||||||
if sa_type is not None:
|
|
||||||
return sa_type
|
|
||||||
except (TypeError, AttributeError, KeyError, ValueError):
|
|
||||||
# Pydantic core schema获取可能失败:
|
|
||||||
# - TypeError: 参数不匹配
|
|
||||||
# - AttributeError: metadata不存在
|
|
||||||
# - KeyError: schema结构不符合预期
|
|
||||||
# - ValueError: 无效的类型定义
|
|
||||||
# 这是正常的类型探测过程,继续检查下一个metadata项
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 方法3:检查类型本身是否有__get_pydantic_core_schema__
|
|
||||||
# (虽然NumpyVector已经在方法1处理,但这是通用的fallback)
|
|
||||||
if hasattr(annotation, '__get_pydantic_core_schema__'):
|
|
||||||
try:
|
|
||||||
schema = annotation.__get_pydantic_core_schema__(
|
|
||||||
annotation,
|
|
||||||
lambda x: None # 虚拟handler
|
|
||||||
)
|
|
||||||
if isinstance(schema, dict) and 'metadata' in schema:
|
|
||||||
sa_type = schema['metadata'].get('sa_type')
|
|
||||||
if sa_type is not None:
|
|
||||||
return sa_type
|
|
||||||
except (TypeError, AttributeError, KeyError, ValueError):
|
|
||||||
# 类型本身的schema获取失败
|
|
||||||
# 这是正常的fallback机制,annotation可能不支持此协议
|
|
||||||
pass
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[
|
|
||||||
dict[str, Any],
|
|
||||||
dict[str, str],
|
|
||||||
Mapping[str, Any],
|
|
||||||
Mapping[str, Any],
|
|
||||||
]:
|
|
||||||
"""
|
|
||||||
Resolve annotations from a class namespace with Python 3.14 (PEP 649) support.
|
|
||||||
|
|
||||||
This helper prefers evaluated annotations (Format.VALUE) so that `typing.Annotated`
|
|
||||||
metadata and custom types remain accessible. Forward references that cannot be
|
|
||||||
evaluated are replaced with typing.ForwardRef placeholders to avoid aborting the
|
|
||||||
whole resolution process.
|
|
||||||
"""
|
|
||||||
raw_annotations = attrs.get('__annotations__') or {}
|
|
||||||
try:
|
|
||||||
base_annotations = dict(raw_annotations)
|
|
||||||
except TypeError:
|
|
||||||
base_annotations = {}
|
|
||||||
|
|
||||||
module_name = attrs.get('__module__')
|
|
||||||
module_globals: dict[str, Any]
|
|
||||||
if module_name and module_name in sys.modules:
|
|
||||||
module_globals = dict(sys.modules[module_name].__dict__)
|
|
||||||
else:
|
|
||||||
module_globals = {}
|
|
||||||
|
|
||||||
module_globals.setdefault('__builtins__', __builtins__)
|
|
||||||
localns: dict[str, Any] = dict(attrs)
|
|
||||||
|
|
||||||
try:
|
|
||||||
temp_cls = type('AnnotationProxy', (object,), dict(attrs))
|
|
||||||
temp_cls.__module__ = module_name
|
|
||||||
extras_kw = {'include_extras': True} if sys.version_info >= (3, 10) else {}
|
|
||||||
evaluated = get_type_hints(
|
|
||||||
temp_cls,
|
|
||||||
globalns=module_globals,
|
|
||||||
localns=localns,
|
|
||||||
**extras_kw,
|
|
||||||
)
|
|
||||||
except (NameError, AttributeError, TypeError, RecursionError):
|
|
||||||
# get_type_hints可能失败的原因:
|
|
||||||
# - NameError: 前向引用无法解析(类型尚未定义)
|
|
||||||
# - AttributeError: 模块或类型不存在
|
|
||||||
# - TypeError: 无效的类型注解
|
|
||||||
# - RecursionError: 循环依赖的类型定义
|
|
||||||
# 这是正常情况,回退到原始注解字符串
|
|
||||||
evaluated = base_annotations
|
|
||||||
|
|
||||||
return dict(evaluated), {}, module_globals, localns
|
|
||||||
|
|
||||||
|
|
||||||
def _evaluate_annotation_from_string(
|
|
||||||
field_name: str,
|
|
||||||
annotation_strings: dict[str, str],
|
|
||||||
current_type: Any,
|
|
||||||
globalns: Mapping[str, Any],
|
|
||||||
localns: Mapping[str, Any],
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Attempt to re-evaluate the original annotation string for a field.
|
|
||||||
|
|
||||||
This is used as a fallback when the resolved annotation lost its metadata
|
|
||||||
(e.g., Annotated wrappers) and we need to recover custom sa_type data.
|
|
||||||
"""
|
|
||||||
if not annotation_strings:
|
|
||||||
return current_type
|
|
||||||
|
|
||||||
expr = annotation_strings.get(field_name)
|
|
||||||
if not expr or not isinstance(expr, str):
|
|
||||||
return current_type
|
|
||||||
|
|
||||||
try:
|
|
||||||
return eval(expr, globalns, localns)
|
|
||||||
except (NameError, SyntaxError, AttributeError, TypeError):
|
|
||||||
# eval可能失败的原因:
|
|
||||||
# - NameError: 类型名称在namespace中不存在
|
|
||||||
# - SyntaxError: 注解字符串有语法错误
|
|
||||||
# - AttributeError: 访问不存在的模块属性
|
|
||||||
# - TypeError: 无效的类型表达式
|
|
||||||
# 这是正常的fallback机制,返回当前已解析的类型
|
|
||||||
return current_type
|
|
||||||
|
|
||||||
|
|
||||||
class __DeclarativeMeta(SQLModelMetaclass):
|
|
||||||
"""
|
|
||||||
一个智能的混合模式元类,它提供了灵活性和清晰度:
|
|
||||||
|
|
||||||
1. **自动设置 `table=True`**: 如果一个类继承了 `TableBaseMixin`,则自动应用 `table=True`。
|
|
||||||
2. **明确的字典参数**: 支持 `mapper_args={...}`, `table_args={...}`, `table_name='...'`。
|
|
||||||
3. **便捷的关键字参数**: 支持最常见的 mapper 参数作为顶级关键字(如 `polymorphic_on`)。
|
|
||||||
4. **智能合并**: 当字典和关键字同时提供时,会自动合并,且关键字参数有更高优先级。
|
|
||||||
"""
|
|
||||||
|
|
||||||
_KNOWN_MAPPER_KEYS = {
|
|
||||||
"polymorphic_on",
|
|
||||||
"polymorphic_identity",
|
|
||||||
"polymorphic_abstract",
|
|
||||||
"version_id_col",
|
|
||||||
"concrete",
|
|
||||||
}
|
|
||||||
|
|
||||||
def __new__(cls, name, bases, attrs, **kwargs):
|
|
||||||
# 1. 约定优于配置:自动设置 table=True
|
|
||||||
is_intended_as_table = any(getattr(b, '_has_table_mixin', False) for b in bases)
|
|
||||||
if is_intended_as_table and 'table' not in kwargs:
|
|
||||||
kwargs['table'] = True
|
|
||||||
|
|
||||||
# 2. 智能合并 __mapper_args__
|
|
||||||
collected_mapper_args = {}
|
|
||||||
|
|
||||||
# 首先,处理明确的 mapper_args 字典 (优先级较低)
|
|
||||||
if 'mapper_args' in kwargs:
|
|
||||||
collected_mapper_args.update(kwargs.pop('mapper_args'))
|
|
||||||
|
|
||||||
# 其次,处理便捷的关键字参数 (优先级更高)
|
|
||||||
for key in cls._KNOWN_MAPPER_KEYS:
|
|
||||||
if key in kwargs:
|
|
||||||
# .pop() 获取值并移除,避免传递给父类
|
|
||||||
collected_mapper_args[key] = kwargs.pop(key)
|
|
||||||
|
|
||||||
# 如果收集到了任何 mapper 参数,则更新到类的属性中
|
|
||||||
if collected_mapper_args:
|
|
||||||
existing = attrs.get('__mapper_args__', {}).copy()
|
|
||||||
existing.update(collected_mapper_args)
|
|
||||||
attrs['__mapper_args__'] = existing
|
|
||||||
|
|
||||||
# 3. 处理其他明确的参数
|
|
||||||
if 'table_args' in kwargs:
|
|
||||||
attrs['__table_args__'] = kwargs.pop('table_args')
|
|
||||||
if 'table_name' in kwargs:
|
|
||||||
attrs['__tablename__'] = kwargs.pop('table_name')
|
|
||||||
if 'abstract' in kwargs:
|
|
||||||
attrs['__abstract__'] = kwargs.pop('abstract')
|
|
||||||
|
|
||||||
# 4. 从Annotated元数据中提取sa_type并注入到Field
|
|
||||||
# 重要:必须在调用父类__new__之前处理,因为SQLModel会消费annotations
|
|
||||||
#
|
|
||||||
# Python 3.14兼容性问题:
|
|
||||||
# - SQLModel在Python 3.14上会因为ClassVar[T]类型而崩溃(issubclass错误)
|
|
||||||
# - 我们必须在SQLModel看到annotations之前过滤掉ClassVar字段
|
|
||||||
# - 虽然PEP 749建议不修改__annotations__,但这是修复SQLModel bug的必要措施
|
|
||||||
#
|
|
||||||
# 获取annotations的策略:
|
|
||||||
# - Python 3.14+: 优先从__annotate__获取(如果存在)
|
|
||||||
# - fallback: 从__annotations__读取(如果存在)
|
|
||||||
# - 最终fallback: 空字典
|
|
||||||
annotations, annotation_strings, eval_globals, eval_locals = _resolve_annotations(attrs)
|
|
||||||
|
|
||||||
if annotations:
|
|
||||||
attrs['__annotations__'] = annotations
|
|
||||||
if annotationlib is not None:
|
|
||||||
# 在Python 3.14中禁用descriptor,转为普通dict
|
|
||||||
attrs['__annotate__'] = None
|
|
||||||
|
|
||||||
for field_name, field_type in annotations.items():
|
|
||||||
field_type = _evaluate_annotation_from_string(
|
|
||||||
field_name,
|
|
||||||
annotation_strings,
|
|
||||||
field_type,
|
|
||||||
eval_globals,
|
|
||||||
eval_locals,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 跳过字符串或ForwardRef类型注解,让SQLModel自己处理
|
|
||||||
if isinstance(field_type, str) or isinstance(field_type, typing.ForwardRef):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 跳过特殊类型的字段
|
|
||||||
origin = get_origin(field_type)
|
|
||||||
|
|
||||||
# 跳过 ClassVar 字段 - 它们不是数据库字段
|
|
||||||
if origin is typing.ClassVar:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 跳过 Mapped 字段 - SQLAlchemy 2.0+ 的声明式字段,已经有 mapped_column
|
|
||||||
if origin is Mapped:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 尝试从注解中提取sa_type
|
|
||||||
sa_type = _extract_sa_type_from_annotation(field_type)
|
|
||||||
|
|
||||||
if sa_type is not None:
|
|
||||||
# 检查字段是否已有Field定义
|
|
||||||
field_value = attrs.get(field_name, Undefined)
|
|
||||||
|
|
||||||
if field_value is Undefined:
|
|
||||||
# 没有Field定义,创建一个新的Field并注入sa_type
|
|
||||||
attrs[field_name] = Field(sa_type=sa_type)
|
|
||||||
elif isinstance(field_value, FieldInfo):
|
|
||||||
# 已有Field定义,检查是否已设置sa_type
|
|
||||||
# 注意:只有在未设置时才注入,尊重显式配置
|
|
||||||
# SQLModel使用Undefined作为"未设置"的标记
|
|
||||||
if not hasattr(field_value, 'sa_type') or field_value.sa_type is Undefined:
|
|
||||||
field_value.sa_type = sa_type
|
|
||||||
# 如果field_value是其他类型(如默认值),不处理
|
|
||||||
# SQLModel会在后续处理中将其转换为Field
|
|
||||||
|
|
||||||
# 5. 调用父类的 __new__ 方法,传入被清理过的 kwargs
|
|
||||||
result = super().__new__(cls, name, bases, attrs, **kwargs)
|
|
||||||
|
|
||||||
# 6. 修复:在联表继承场景下,继承父类的 __sqlmodel_relationships__
|
|
||||||
# SQLModel 为每个 table=True 的类创建新的空 __sqlmodel_relationships__
|
|
||||||
# 这导致子类丢失父类的关系定义,触发错误的 Column 创建
|
|
||||||
# 必须在 super().__new__() 之后修复,因为 SQLModel 会覆盖我们预设的值
|
|
||||||
if kwargs.get('table', False):
|
|
||||||
for base in bases:
|
|
||||||
if hasattr(base, '__sqlmodel_relationships__'):
|
|
||||||
for rel_name, rel_info in base.__sqlmodel_relationships__.items():
|
|
||||||
# 只继承子类没有重新定义的关系
|
|
||||||
if rel_name not in result.__sqlmodel_relationships__:
|
|
||||||
result.__sqlmodel_relationships__[rel_name] = rel_info
|
|
||||||
# 同时修复被错误创建的 Column - 恢复为父类的 relationship
|
|
||||||
if hasattr(base, rel_name):
|
|
||||||
base_attr = getattr(base, rel_name)
|
|
||||||
setattr(result, rel_name, base_attr)
|
|
||||||
|
|
||||||
# 7. 检测:禁止子类重定义父类的 Relationship 字段
|
|
||||||
# 子类重定义同名的 Relationship 字段会导致 SQLAlchemy 关系映射混乱,
|
|
||||||
# 应该在类定义时立即报错,而不是在运行时出现难以调试的问题。
|
|
||||||
for base in bases:
|
|
||||||
parent_relationships = getattr(base, '__sqlmodel_relationships__', {})
|
|
||||||
for rel_name in parent_relationships:
|
|
||||||
# 检查当前类是否在 attrs 中重新定义了这个关系字段
|
|
||||||
if rel_name in attrs:
|
|
||||||
raise TypeError(
|
|
||||||
f"类 {name} 不允许重定义父类 {base.__name__} 的 Relationship 字段 '{rel_name}'。"
|
|
||||||
f"如需修改关系配置,请在父类中修改。"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 8. 修复:从 model_fields/__pydantic_fields__ 中移除 Relationship 字段
|
|
||||||
# SQLModel 0.0.27 bug:子类会错误地继承父类的 Relationship 字段到 model_fields
|
|
||||||
# 这导致 Pydantic 尝试为 Relationship 字段生成 schema,因为类型是
|
|
||||||
# Mapped[list['Character']] 这种前向引用,Pydantic 无法解析,
|
|
||||||
# 导致 __pydantic_complete__ = False
|
|
||||||
#
|
|
||||||
# 修复策略:
|
|
||||||
# - 检查类的 __sqlmodel_relationships__ 属性
|
|
||||||
# - 从 model_fields 和 __pydantic_fields__ 中移除这些字段
|
|
||||||
# - Relationship 字段由 SQLAlchemy 管理,不需要 Pydantic 参与
|
|
||||||
relationships = getattr(result, '__sqlmodel_relationships__', {})
|
|
||||||
if relationships:
|
|
||||||
model_fields = getattr(result, 'model_fields', {})
|
|
||||||
pydantic_fields = getattr(result, '__pydantic_fields__', {})
|
|
||||||
|
|
||||||
fields_removed = False
|
|
||||||
for rel_name in relationships:
|
|
||||||
if rel_name in model_fields:
|
|
||||||
del model_fields[rel_name]
|
|
||||||
fields_removed = True
|
|
||||||
if rel_name in pydantic_fields:
|
|
||||||
del pydantic_fields[rel_name]
|
|
||||||
fields_removed = True
|
|
||||||
|
|
||||||
# 如果移除了字段,重新构建 Pydantic 模式
|
|
||||||
# 注意:只在有字段被移除时才 rebuild,避免不必要的开销
|
|
||||||
if fields_removed and hasattr(result, 'model_rebuild'):
|
|
||||||
result.model_rebuild(force=True)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
cls,
|
|
||||||
classname: str,
|
|
||||||
bases: tuple[type, ...],
|
|
||||||
dict_: dict[str, typing.Any],
|
|
||||||
**kw: typing.Any,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
重写 SQLModel 的 __init__ 以支持联表继承(Joined Table Inheritance)
|
|
||||||
|
|
||||||
SQLModel 原始行为:
|
|
||||||
- 如果任何基类是表模型,则不调用 DeclarativeMeta.__init__
|
|
||||||
- 这阻止了子类创建自己的表
|
|
||||||
|
|
||||||
修复逻辑:
|
|
||||||
- 检测联表继承场景(子类有自己的 __tablename__ 且有外键指向父表)
|
|
||||||
- 强制调用 DeclarativeMeta.__init__ 来创建子表
|
|
||||||
"""
|
|
||||||
from sqlmodel.main import is_table_model_class, DeclarativeMeta, ModelMetaclass
|
|
||||||
|
|
||||||
# 检查是否是表模型
|
|
||||||
if not is_table_model_class(cls):
|
|
||||||
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 检查是否有基类是表模型
|
|
||||||
base_is_table = any(is_table_model_class(base) for base in bases)
|
|
||||||
|
|
||||||
if not base_is_table:
|
|
||||||
# 没有基类是表模型,走正常的 SQLModel 流程
|
|
||||||
# 处理关系字段
|
|
||||||
cls._setup_relationships()
|
|
||||||
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 关键:检测联表继承场景
|
|
||||||
# 条件:
|
|
||||||
# 1. 当前类的 __tablename__ 与父类不同(表示需要新表)
|
|
||||||
# 2. 当前类有字段带有 foreign_key 指向父表
|
|
||||||
current_tablename = getattr(cls, '__tablename__', None)
|
|
||||||
|
|
||||||
# 查找父表信息
|
|
||||||
parent_table = None
|
|
||||||
parent_tablename = None
|
|
||||||
for base in bases:
|
|
||||||
if is_table_model_class(base) and hasattr(base, '__tablename__'):
|
|
||||||
parent_tablename = base.__tablename__
|
|
||||||
break
|
|
||||||
|
|
||||||
# 检查是否有不同的 tablename
|
|
||||||
has_different_tablename = (
|
|
||||||
current_tablename is not None
|
|
||||||
and parent_tablename is not None
|
|
||||||
and current_tablename != parent_tablename
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查是否有外键字段指向父表的主键
|
|
||||||
# 注意:由于字段合并,我们需要检查直接基类的 model_fields
|
|
||||||
# 而不是当前类的合并后的 model_fields
|
|
||||||
has_fk_to_parent = False
|
|
||||||
|
|
||||||
def _normalize_tablename(name: str) -> str:
|
|
||||||
"""标准化表名以进行比较(移除下划线,转小写)"""
|
|
||||||
return name.replace('_', '').lower()
|
|
||||||
|
|
||||||
def _fk_matches_parent(fk_str: str, parent_table: str) -> bool:
|
|
||||||
"""检查 FK 字符串是否指向父表"""
|
|
||||||
if not fk_str or not parent_table:
|
|
||||||
return False
|
|
||||||
# FK 格式: "tablename.column" 或 "schema.tablename.column"
|
|
||||||
parts = fk_str.split('.')
|
|
||||||
if len(parts) >= 2:
|
|
||||||
fk_table = parts[-2] # 取倒数第二个作为表名
|
|
||||||
# 标准化比较(处理下划线差异)
|
|
||||||
return _normalize_tablename(fk_table) == _normalize_tablename(parent_table)
|
|
||||||
return False
|
|
||||||
|
|
||||||
if has_different_tablename and parent_tablename:
|
|
||||||
# 首先检查当前类的 model_fields
|
|
||||||
for field_name, field_info in cls.model_fields.items():
|
|
||||||
fk = getattr(field_info, 'foreign_key', None)
|
|
||||||
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
|
|
||||||
has_fk_to_parent = True
|
|
||||||
break
|
|
||||||
|
|
||||||
# 如果没找到,检查直接基类的 model_fields(解决 mixin 字段被覆盖的问题)
|
|
||||||
if not has_fk_to_parent:
|
|
||||||
for base in bases:
|
|
||||||
if hasattr(base, 'model_fields'):
|
|
||||||
for field_name, field_info in base.model_fields.items():
|
|
||||||
fk = getattr(field_info, 'foreign_key', None)
|
|
||||||
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
|
|
||||||
has_fk_to_parent = True
|
|
||||||
break
|
|
||||||
if has_fk_to_parent:
|
|
||||||
break
|
|
||||||
|
|
||||||
is_joined_inheritance = has_different_tablename and has_fk_to_parent
|
|
||||||
|
|
||||||
if is_joined_inheritance:
|
|
||||||
# 联表继承:需要创建子表
|
|
||||||
|
|
||||||
# 修复外键字段:由于字段合并,外键信息可能丢失
|
|
||||||
# 需要从基类的 mixin 中找回外键信息,并重建列
|
|
||||||
from sqlalchemy import Column, ForeignKey, inspect as sa_inspect
|
|
||||||
from sqlalchemy.dialects.postgresql import UUID as SA_UUID
|
|
||||||
from sqlalchemy.exc import NoInspectionAvailable
|
|
||||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
|
||||||
|
|
||||||
# 联表继承:子表只应该有 id(FK 到父表)+ 子类特有的字段
|
|
||||||
# 所有继承自祖先表的列都不应该在子表中重复创建
|
|
||||||
|
|
||||||
# 收集整个继承链中所有祖先表的列名(这些列不应该在子表中重复)
|
|
||||||
# 需要遍历整个 MRO,因为可能是多级继承(如 Tool -> Function -> GetWeatherFunction)
|
|
||||||
ancestor_column_names: set[str] = set()
|
|
||||||
for ancestor in cls.__mro__:
|
|
||||||
if ancestor is cls:
|
|
||||||
continue # 跳过当前类
|
|
||||||
if is_table_model_class(ancestor):
|
|
||||||
try:
|
|
||||||
# 使用 inspect() 获取 mapper 的公开属性
|
|
||||||
# 源码确认: mapper.local_table 是公开属性 (mapper.py:979-998)
|
|
||||||
mapper = sa_inspect(ancestor)
|
|
||||||
for col in mapper.local_table.columns:
|
|
||||||
# 跳过 _polymorphic_name 列(鉴别器,由根父表管理)
|
|
||||||
if col.name.startswith('_polymorphic'):
|
|
||||||
continue
|
|
||||||
ancestor_column_names.add(col.name)
|
|
||||||
except NoInspectionAvailable:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 找到子类自己定义的字段(不在父类中的)
|
|
||||||
child_own_fields: set[str] = set()
|
|
||||||
for field_name in cls.model_fields:
|
|
||||||
# 检查这个字段是否是在当前类直接定义的(不是继承的)
|
|
||||||
# 通过检查父类是否有这个字段来判断
|
|
||||||
is_inherited = False
|
|
||||||
for base in bases:
|
|
||||||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
|
||||||
is_inherited = True
|
|
||||||
break
|
|
||||||
if not is_inherited:
|
|
||||||
child_own_fields.add(field_name)
|
|
||||||
|
|
||||||
# 从子类类属性中移除父表已有的列定义
|
|
||||||
# 这样 SQLAlchemy 就不会在子表中创建这些列
|
|
||||||
fk_field_name = None
|
|
||||||
for base in bases:
|
|
||||||
if hasattr(base, 'model_fields'):
|
|
||||||
for field_name, field_info in base.model_fields.items():
|
|
||||||
fk = getattr(field_info, 'foreign_key', None)
|
|
||||||
pk = getattr(field_info, 'primary_key', False)
|
|
||||||
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
|
|
||||||
fk_field_name = field_name
|
|
||||||
# 找到了外键字段,重建它
|
|
||||||
# 创建一个新的 Column 对象包含外键约束
|
|
||||||
new_col = Column(
|
|
||||||
field_name,
|
|
||||||
SA_UUID(as_uuid=True),
|
|
||||||
ForeignKey(fk),
|
|
||||||
primary_key=pk if pk else False
|
|
||||||
)
|
|
||||||
setattr(cls, field_name, new_col)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
break
|
|
||||||
|
|
||||||
# 移除继承自祖先表的列属性(除了 FK/PK 和子类自己的字段)
|
|
||||||
# 这防止 SQLAlchemy 在子表中创建重复列
|
|
||||||
# 注意:在 __init__ 阶段,列是 Column 对象,不是 InstrumentedAttribute
|
|
||||||
for col_name in ancestor_column_names:
|
|
||||||
if col_name == fk_field_name:
|
|
||||||
continue # 保留 FK/PK 列(子表的主键,同时是父表的外键)
|
|
||||||
if col_name == 'id':
|
|
||||||
continue # id 会被 FK 字段覆盖
|
|
||||||
if col_name in child_own_fields:
|
|
||||||
continue # 保留子类自己定义的字段
|
|
||||||
|
|
||||||
# 检查类属性是否是 Column 或 InstrumentedAttribute
|
|
||||||
if col_name in cls.__dict__:
|
|
||||||
attr = cls.__dict__[col_name]
|
|
||||||
# Column 对象或 InstrumentedAttribute 都需要删除
|
|
||||||
if isinstance(attr, (Column, InstrumentedAttribute)):
|
|
||||||
try:
|
|
||||||
delattr(cls, col_name)
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 找到子类自己定义的关系(不在父类中的)
|
|
||||||
# 继承的关系会从父类自动获取,只需要设置子类新增的关系
|
|
||||||
child_own_relationships: set[str] = set()
|
|
||||||
for rel_name in cls.__sqlmodel_relationships__:
|
|
||||||
is_inherited = False
|
|
||||||
for base in bases:
|
|
||||||
if hasattr(base, '__sqlmodel_relationships__') and rel_name in base.__sqlmodel_relationships__:
|
|
||||||
is_inherited = True
|
|
||||||
break
|
|
||||||
if not is_inherited:
|
|
||||||
child_own_relationships.add(rel_name)
|
|
||||||
|
|
||||||
# 只为子类自己定义的新关系调用关系设置
|
|
||||||
if child_own_relationships:
|
|
||||||
cls._setup_relationships(only_these=child_own_relationships)
|
|
||||||
|
|
||||||
# 强制调用 DeclarativeMeta.__init__
|
|
||||||
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
|
|
||||||
else:
|
|
||||||
# 非联表继承:单表继承或正常 Pydantic 模型
|
|
||||||
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
|
||||||
|
|
||||||
def _setup_relationships(cls, only_these: set[str] | None = None) -> None:
|
|
||||||
"""
|
|
||||||
设置 SQLAlchemy 关系字段(从 SQLModel 源码复制)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
only_these: 如果提供,只设置这些关系(用于 joined table inheritance 子类)
|
|
||||||
如果为 None,设置所有关系(默认行为)
|
|
||||||
"""
|
|
||||||
from sqlalchemy.orm import relationship, Mapped
|
|
||||||
from sqlalchemy import inspect
|
|
||||||
from sqlmodel.main import get_relationship_to
|
|
||||||
from typing import get_origin
|
|
||||||
|
|
||||||
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
|
|
||||||
# 如果指定了 only_these,只设置这些关系
|
|
||||||
if only_these is not None and rel_name not in only_these:
|
|
||||||
continue
|
|
||||||
if rel_info.sa_relationship:
|
|
||||||
setattr(cls, rel_name, rel_info.sa_relationship)
|
|
||||||
continue
|
|
||||||
|
|
||||||
raw_ann = cls.__annotations__[rel_name]
|
|
||||||
origin: typing.Any = get_origin(raw_ann)
|
|
||||||
if origin is Mapped:
|
|
||||||
ann = raw_ann.__args__[0]
|
|
||||||
else:
|
|
||||||
ann = raw_ann
|
|
||||||
cls.__annotations__[rel_name] = Mapped[ann]
|
|
||||||
|
|
||||||
relationship_to = get_relationship_to(
|
|
||||||
name=rel_name, rel_info=rel_info, annotation=ann
|
|
||||||
)
|
|
||||||
rel_kwargs: dict[str, typing.Any] = {}
|
|
||||||
if rel_info.back_populates:
|
|
||||||
rel_kwargs["back_populates"] = rel_info.back_populates
|
|
||||||
if rel_info.cascade_delete:
|
|
||||||
rel_kwargs["cascade"] = "all, delete-orphan"
|
|
||||||
if rel_info.passive_deletes:
|
|
||||||
rel_kwargs["passive_deletes"] = rel_info.passive_deletes
|
|
||||||
if rel_info.link_model:
|
|
||||||
ins = inspect(rel_info.link_model)
|
|
||||||
local_table = getattr(ins, "local_table")
|
|
||||||
if local_table is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Couldn't find secondary table for {rel_info.link_model}"
|
|
||||||
)
|
|
||||||
rel_kwargs["secondary"] = local_table
|
|
||||||
|
|
||||||
rel_args: list[typing.Any] = []
|
|
||||||
if rel_info.sa_relationship_args:
|
|
||||||
rel_args.extend(rel_info.sa_relationship_args)
|
|
||||||
if rel_info.sa_relationship_kwargs:
|
|
||||||
rel_kwargs.update(rel_info.sa_relationship_kwargs)
|
|
||||||
|
|
||||||
rel_value = relationship(relationship_to, *rel_args, **rel_kwargs)
|
|
||||||
setattr(cls, rel_name, rel_value)
|
|
||||||
|
|
||||||
|
|
||||||
class SQLModelBase(SQLModel, metaclass=__DeclarativeMeta):
|
|
||||||
"""此类必须和TableBase系列类搭配使用"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(use_attribute_docstrings=True, validate_by_name=True)
|
|
||||||
@@ -1,7 +1,71 @@
|
|||||||
from .base import SQLModelBase
|
from enum import StrEnum
|
||||||
|
|
||||||
class ThemeResponse(SQLModelBase):
|
from sqlmodel_ext import SQLModelBase
|
||||||
"""主题响应 DTO"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
class ChromaticColor(StrEnum):
|
||||||
|
"""有彩色枚举(17种 Tailwind 调色板颜色)"""
|
||||||
|
|
||||||
|
RED = "red"
|
||||||
|
ORANGE = "orange"
|
||||||
|
AMBER = "amber"
|
||||||
|
YELLOW = "yellow"
|
||||||
|
LIME = "lime"
|
||||||
|
GREEN = "green"
|
||||||
|
EMERALD = "emerald"
|
||||||
|
TEAL = "teal"
|
||||||
|
CYAN = "cyan"
|
||||||
|
SKY = "sky"
|
||||||
|
BLUE = "blue"
|
||||||
|
INDIGO = "indigo"
|
||||||
|
VIOLET = "violet"
|
||||||
|
PURPLE = "purple"
|
||||||
|
FUCHSIA = "fuchsia"
|
||||||
|
PINK = "pink"
|
||||||
|
ROSE = "rose"
|
||||||
|
|
||||||
|
|
||||||
|
class NeutralColor(StrEnum):
|
||||||
|
"""无彩色枚举(5种灰色调)"""
|
||||||
|
|
||||||
|
SLATE = "slate"
|
||||||
|
GRAY = "gray"
|
||||||
|
ZINC = "zinc"
|
||||||
|
NEUTRAL = "neutral"
|
||||||
|
STONE = "stone"
|
||||||
|
|
||||||
|
|
||||||
|
class ThemeColorsBase(SQLModelBase):
|
||||||
|
"""嵌套颜色 DTO,API 请求/响应层使用"""
|
||||||
|
|
||||||
|
primary: ChromaticColor
|
||||||
|
"""主色调"""
|
||||||
|
|
||||||
|
secondary: ChromaticColor
|
||||||
|
"""辅助色"""
|
||||||
|
|
||||||
|
success: ChromaticColor
|
||||||
|
"""成功色"""
|
||||||
|
|
||||||
|
info: ChromaticColor
|
||||||
|
"""信息色"""
|
||||||
|
|
||||||
|
warning: ChromaticColor
|
||||||
|
"""警告色"""
|
||||||
|
|
||||||
|
error: ChromaticColor
|
||||||
|
"""错误色"""
|
||||||
|
|
||||||
|
neutral: NeutralColor
|
||||||
|
"""中性色"""
|
||||||
|
|
||||||
|
|
||||||
|
BUILTIN_DEFAULT_COLORS = ThemeColorsBase(
|
||||||
|
primary=ChromaticColor.GREEN,
|
||||||
|
secondary=ChromaticColor.BLUE,
|
||||||
|
success=ChromaticColor.GREEN,
|
||||||
|
info=ChromaticColor.BLUE,
|
||||||
|
warning=ChromaticColor.YELLOW,
|
||||||
|
error=ChromaticColor.RED,
|
||||||
|
neutral=NeutralColor.ZINC,
|
||||||
|
)
|
||||||
|
|||||||
135
sqlmodels/custom_property.py
Normal file
135
sqlmodels/custom_property.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""
|
||||||
|
用户自定义属性定义模型
|
||||||
|
|
||||||
|
允许用户定义类型化的自定义属性模板(如标签、评分、分类等),
|
||||||
|
实际值通过 ObjectMetadata KV 表存储,键名格式:custom:{property_definition_id}。
|
||||||
|
|
||||||
|
支持的属性类型:text, number, boolean, select, multi_select, rating, link
|
||||||
|
"""
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import JSON
|
||||||
|
from sqlmodel import Field, Relationship
|
||||||
|
|
||||||
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 枚举 ====================
|
||||||
|
|
||||||
|
class CustomPropertyType(StrEnum):
|
||||||
|
"""自定义属性值类型枚举"""
|
||||||
|
TEXT = "text"
|
||||||
|
"""文本"""
|
||||||
|
NUMBER = "number"
|
||||||
|
"""数字"""
|
||||||
|
BOOLEAN = "boolean"
|
||||||
|
"""布尔值"""
|
||||||
|
SELECT = "select"
|
||||||
|
"""单选"""
|
||||||
|
MULTI_SELECT = "multi_select"
|
||||||
|
"""多选"""
|
||||||
|
RATING = "rating"
|
||||||
|
"""评分(1-5)"""
|
||||||
|
LINK = "link"
|
||||||
|
"""链接"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Base 模型 ====================
|
||||||
|
|
||||||
|
class CustomPropertyDefinitionBase(SQLModelBase):
|
||||||
|
"""自定义属性定义基础模型"""
|
||||||
|
|
||||||
|
name: Str100
|
||||||
|
"""属性显示名称"""
|
||||||
|
|
||||||
|
type: CustomPropertyType
|
||||||
|
"""属性值类型"""
|
||||||
|
|
||||||
|
icon: Str100 | None = None
|
||||||
|
"""图标标识(iconify 名称)"""
|
||||||
|
|
||||||
|
options: list[str] | None = Field(default=None, sa_type=JSON)
|
||||||
|
"""可选值列表(仅 select/multi_select 类型)"""
|
||||||
|
|
||||||
|
default_value: str | None = Field(default=None, max_length=500)
|
||||||
|
"""默认值"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 数据库模型 ====================
|
||||||
|
|
||||||
|
class CustomPropertyDefinition(CustomPropertyDefinitionBase, UUIDTableBaseMixin):
|
||||||
|
"""
|
||||||
|
用户自定义属性定义
|
||||||
|
|
||||||
|
每个用户独立管理自己的属性模板。
|
||||||
|
实际属性值存储在 ObjectMetadata 表中,键名格式:custom:{id}。
|
||||||
|
"""
|
||||||
|
|
||||||
|
owner_id: UUID = Field(
|
||||||
|
foreign_key="user.id",
|
||||||
|
ondelete="CASCADE",
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
"""所有者用户UUID"""
|
||||||
|
|
||||||
|
sort_order: int = 0
|
||||||
|
"""排序顺序"""
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
owner: "User" = Relationship()
|
||||||
|
"""所有者"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== DTO 模型 ====================
|
||||||
|
|
||||||
|
class CustomPropertyCreateRequest(SQLModelBase):
|
||||||
|
"""创建自定义属性请求 DTO"""
|
||||||
|
|
||||||
|
name: Str100
|
||||||
|
"""属性显示名称"""
|
||||||
|
|
||||||
|
type: CustomPropertyType
|
||||||
|
"""属性值类型"""
|
||||||
|
|
||||||
|
icon: str | None = None
|
||||||
|
"""图标标识"""
|
||||||
|
|
||||||
|
options: list[str] | None = None
|
||||||
|
"""可选值列表(仅 select/multi_select 类型)"""
|
||||||
|
|
||||||
|
default_value: str | None = None
|
||||||
|
"""默认值"""
|
||||||
|
|
||||||
|
|
||||||
|
class CustomPropertyUpdateRequest(SQLModelBase):
|
||||||
|
"""更新自定义属性请求 DTO"""
|
||||||
|
|
||||||
|
name: str | None = None
|
||||||
|
"""属性显示名称"""
|
||||||
|
|
||||||
|
icon: str | None = None
|
||||||
|
"""图标标识"""
|
||||||
|
|
||||||
|
options: list[str] | None = None
|
||||||
|
"""可选值列表"""
|
||||||
|
|
||||||
|
default_value: str | None = None
|
||||||
|
"""默认值"""
|
||||||
|
|
||||||
|
sort_order: int | None = None
|
||||||
|
"""排序顺序"""
|
||||||
|
|
||||||
|
|
||||||
|
class CustomPropertyResponse(CustomPropertyDefinitionBase):
|
||||||
|
"""自定义属性响应 DTO"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
"""属性定义UUID"""
|
||||||
|
|
||||||
|
sort_order: int
|
||||||
|
"""排序顺序"""
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
from sqlmodel import SQLModel
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
from utils.conf import appmeta
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from typing import AsyncGenerator
|
|
||||||
|
|
||||||
ASYNC_DATABASE_URL = appmeta.database_url
|
|
||||||
|
|
||||||
engine: AsyncEngine = create_async_engine(
|
|
||||||
ASYNC_DATABASE_URL,
|
|
||||||
echo=appmeta.debug,
|
|
||||||
connect_args={
|
|
||||||
"check_same_thread": False
|
|
||||||
} if ASYNC_DATABASE_URL.startswith("sqlite") else {},
|
|
||||||
future=True,
|
|
||||||
# pool_size=POOL_SIZE,
|
|
||||||
# max_overflow=64,
|
|
||||||
)
|
|
||||||
|
|
||||||
_async_session_factory = sessionmaker(engine, class_=AsyncSession)
|
|
||||||
|
|
||||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
|
||||||
async with _async_session_factory() as session:
|
|
||||||
yield session
|
|
||||||
|
|
||||||
async def init_db(
|
|
||||||
url: str = ASYNC_DATABASE_URL
|
|
||||||
):
|
|
||||||
"""创建数据库结构"""
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(SQLModel.metadata.create_all)
|
|
||||||
|
|
||||||
@@ -4,8 +4,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlmodel import Field, Relationship, UniqueConstraint, Index
|
from sqlmodel import Field, Relationship, UniqueConstraint, Index
|
||||||
|
|
||||||
from .base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableBaseMixin, Str255
|
||||||
from .mixin import UUIDTableBaseMixin, TableBaseMixin
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
@@ -142,7 +141,7 @@ class Download(DownloadBase, UUIDTableBaseMixin):
|
|||||||
speed: int = Field(default=0)
|
speed: int = Field(default=0)
|
||||||
"""下载速度(bytes/s)"""
|
"""下载速度(bytes/s)"""
|
||||||
|
|
||||||
parent: str | None = Field(default=None, max_length=255)
|
parent: Str255 | None = None
|
||||||
"""父任务标识"""
|
"""父任务标识"""
|
||||||
|
|
||||||
error: str | None = Field(default=None)
|
error: str | None = Field(default=None)
|
||||||
|
|||||||
435
sqlmodels/file_app.py
Normal file
435
sqlmodels/file_app.py
Normal file
@@ -0,0 +1,435 @@
|
|||||||
|
"""
|
||||||
|
文件查看器应用模块
|
||||||
|
|
||||||
|
提供文件预览应用选择器系统的数据模型和 DTO。
|
||||||
|
类似 Android 的"使用什么应用打开"机制:
|
||||||
|
- 管理员注册应用(内置/iframe/WOPI)
|
||||||
|
- 用户按扩展名查询可用查看器
|
||||||
|
- 用户可设置"始终使用"偏好
|
||||||
|
- 支持用户组级别的访问控制
|
||||||
|
|
||||||
|
架构:
|
||||||
|
FileApp (应用注册表)
|
||||||
|
├── FileAppExtension (扩展名关联)
|
||||||
|
├── FileAppGroupLink (用户组访问控制)
|
||||||
|
└── UserFileAppDefault (用户默认偏好)
|
||||||
|
"""
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||||
|
|
||||||
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str100, Str255, Text1024
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .group import Group
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 枚举 ====================
|
||||||
|
|
||||||
|
class FileAppType(StrEnum):
|
||||||
|
"""文件应用类型"""
|
||||||
|
|
||||||
|
BUILTIN = "builtin"
|
||||||
|
"""前端内置查看器(如 pdf.js, Monaco)"""
|
||||||
|
|
||||||
|
IFRAME = "iframe"
|
||||||
|
"""iframe 内嵌第三方服务"""
|
||||||
|
|
||||||
|
WOPI = "wopi"
|
||||||
|
"""WOPI 协议(OnlyOffice / Collabora)"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Link 表 ====================
|
||||||
|
|
||||||
|
class FileAppGroupLink(SQLModelBase, table=True):
|
||||||
|
"""应用-用户组访问控制关联表"""
|
||||||
|
|
||||||
|
app_id: UUID = Field(foreign_key="fileapp.id", primary_key=True, ondelete="CASCADE")
|
||||||
|
"""关联的应用UUID"""
|
||||||
|
|
||||||
|
group_id: UUID = Field(foreign_key="group.id", primary_key=True, ondelete="CASCADE")
|
||||||
|
"""关联的用户组UUID"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== DTO 模型 ====================
|
||||||
|
|
||||||
|
class FileAppSummary(SQLModelBase):
|
||||||
|
"""查看器列表项 DTO,用于选择器弹窗"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
"""应用UUID"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""应用名称"""
|
||||||
|
|
||||||
|
app_key: str
|
||||||
|
"""应用唯一标识"""
|
||||||
|
|
||||||
|
type: FileAppType
|
||||||
|
"""应用类型"""
|
||||||
|
|
||||||
|
icon: str | None = None
|
||||||
|
"""图标名称/URL"""
|
||||||
|
|
||||||
|
description: str | None = None
|
||||||
|
"""应用描述"""
|
||||||
|
|
||||||
|
iframe_url_template: str | None = None
|
||||||
|
"""iframe URL 模板"""
|
||||||
|
|
||||||
|
wopi_editor_url_template: str | None = None
|
||||||
|
"""WOPI 编辑器 URL 模板"""
|
||||||
|
|
||||||
|
|
||||||
|
class FileViewersResponse(SQLModelBase):
|
||||||
|
"""查看器查询响应 DTO"""
|
||||||
|
|
||||||
|
viewers: list[FileAppSummary] = []
|
||||||
|
"""可用查看器列表(已按 priority 排序)"""
|
||||||
|
|
||||||
|
default_viewer_id: UUID | None = None
|
||||||
|
"""用户默认查看器UUID(如果已设置"始终使用")"""
|
||||||
|
|
||||||
|
|
||||||
|
class SetDefaultViewerRequest(SQLModelBase):
|
||||||
|
"""设置默认查看器请求 DTO"""
|
||||||
|
|
||||||
|
extension: str = Field(max_length=20)
|
||||||
|
"""文件扩展名(小写,无点号)"""
|
||||||
|
|
||||||
|
app_id: UUID
|
||||||
|
"""应用UUID"""
|
||||||
|
|
||||||
|
|
||||||
|
class UserFileAppDefaultResponse(SQLModelBase):
|
||||||
|
"""用户默认查看器响应 DTO"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
"""记录UUID"""
|
||||||
|
|
||||||
|
extension: str
|
||||||
|
"""扩展名"""
|
||||||
|
|
||||||
|
app: FileAppSummary
|
||||||
|
"""关联的应用摘要"""
|
||||||
|
|
||||||
|
|
||||||
|
class FileAppCreateRequest(SQLModelBase):
|
||||||
|
"""管理员创建应用请求 DTO"""
|
||||||
|
|
||||||
|
name: Str100
|
||||||
|
"""应用名称"""
|
||||||
|
|
||||||
|
app_key: str = Field(max_length=50)
|
||||||
|
"""应用唯一标识"""
|
||||||
|
|
||||||
|
type: FileAppType
|
||||||
|
"""应用类型"""
|
||||||
|
|
||||||
|
icon: Str255 | None = None
|
||||||
|
"""图标名称/URL"""
|
||||||
|
|
||||||
|
description: str | None = Field(default=None, max_length=500)
|
||||||
|
"""应用描述"""
|
||||||
|
|
||||||
|
is_enabled: bool = True
|
||||||
|
"""是否启用"""
|
||||||
|
|
||||||
|
is_restricted: bool = False
|
||||||
|
"""是否限制用户组访问"""
|
||||||
|
|
||||||
|
iframe_url_template: Text1024 | None = None
|
||||||
|
"""iframe URL 模板"""
|
||||||
|
|
||||||
|
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
||||||
|
"""WOPI 发现端点 URL"""
|
||||||
|
|
||||||
|
wopi_editor_url_template: Text1024 | None = None
|
||||||
|
"""WOPI 编辑器 URL 模板"""
|
||||||
|
|
||||||
|
extensions: list[str] = []
|
||||||
|
"""关联的扩展名列表"""
|
||||||
|
|
||||||
|
allowed_group_ids: list[UUID] = []
|
||||||
|
"""允许访问的用户组UUID列表"""
|
||||||
|
|
||||||
|
|
||||||
|
class FileAppUpdateRequest(SQLModelBase):
|
||||||
|
"""管理员更新应用请求 DTO(所有字段可选)"""
|
||||||
|
|
||||||
|
name: Str100 | None = None
|
||||||
|
"""应用名称"""
|
||||||
|
|
||||||
|
app_key: str | None = Field(default=None, max_length=50)
|
||||||
|
"""应用唯一标识"""
|
||||||
|
|
||||||
|
type: FileAppType | None = None
|
||||||
|
"""应用类型"""
|
||||||
|
|
||||||
|
icon: Str255 | None = None
|
||||||
|
"""图标名称/URL"""
|
||||||
|
|
||||||
|
description: str | None = Field(default=None, max_length=500)
|
||||||
|
"""应用描述"""
|
||||||
|
|
||||||
|
is_enabled: bool | None = None
|
||||||
|
"""是否启用"""
|
||||||
|
|
||||||
|
is_restricted: bool | None = None
|
||||||
|
"""是否限制用户组访问"""
|
||||||
|
|
||||||
|
iframe_url_template: Text1024 | None = None
|
||||||
|
"""iframe URL 模板"""
|
||||||
|
|
||||||
|
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
||||||
|
"""WOPI 发现端点 URL"""
|
||||||
|
|
||||||
|
wopi_editor_url_template: Text1024 | None = None
|
||||||
|
"""WOPI 编辑器 URL 模板"""
|
||||||
|
|
||||||
|
|
||||||
|
class FileAppResponse(SQLModelBase):
|
||||||
|
"""管理员应用详情响应 DTO"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
"""应用UUID"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""应用名称"""
|
||||||
|
|
||||||
|
app_key: str
|
||||||
|
"""应用唯一标识"""
|
||||||
|
|
||||||
|
type: FileAppType
|
||||||
|
"""应用类型"""
|
||||||
|
|
||||||
|
icon: str | None = None
|
||||||
|
"""图标名称/URL"""
|
||||||
|
|
||||||
|
description: str | None = None
|
||||||
|
"""应用描述"""
|
||||||
|
|
||||||
|
is_enabled: bool = True
|
||||||
|
"""是否启用"""
|
||||||
|
|
||||||
|
is_restricted: bool = False
|
||||||
|
"""是否限制用户组访问"""
|
||||||
|
|
||||||
|
iframe_url_template: str | None = None
|
||||||
|
"""iframe URL 模板"""
|
||||||
|
|
||||||
|
wopi_discovery_url: str | None = None
|
||||||
|
"""WOPI 发现端点 URL"""
|
||||||
|
|
||||||
|
wopi_editor_url_template: str | None = None
|
||||||
|
"""WOPI 编辑器 URL 模板"""
|
||||||
|
|
||||||
|
extensions: list[str] = []
|
||||||
|
"""关联的扩展名列表"""
|
||||||
|
|
||||||
|
allowed_group_ids: list[UUID] = []
|
||||||
|
"""允许访问的用户组UUID列表"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_app(
|
||||||
|
cls,
|
||||||
|
app: "FileApp",
|
||||||
|
extensions: list["FileAppExtension"],
|
||||||
|
group_links: list[FileAppGroupLink],
|
||||||
|
) -> "FileAppResponse":
|
||||||
|
"""从 ORM 对象构建 DTO"""
|
||||||
|
return cls(
|
||||||
|
id=app.id,
|
||||||
|
name=app.name,
|
||||||
|
app_key=app.app_key,
|
||||||
|
type=app.type,
|
||||||
|
icon=app.icon,
|
||||||
|
description=app.description,
|
||||||
|
is_enabled=app.is_enabled,
|
||||||
|
is_restricted=app.is_restricted,
|
||||||
|
iframe_url_template=app.iframe_url_template,
|
||||||
|
wopi_discovery_url=app.wopi_discovery_url,
|
||||||
|
wopi_editor_url_template=app.wopi_editor_url_template,
|
||||||
|
extensions=[ext.extension for ext in extensions],
|
||||||
|
allowed_group_ids=[link.group_id for link in group_links],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FileAppListResponse(SQLModelBase):
|
||||||
|
"""管理员应用列表响应 DTO"""
|
||||||
|
|
||||||
|
apps: list[FileAppResponse] = []
|
||||||
|
"""应用列表"""
|
||||||
|
|
||||||
|
total: int = 0
|
||||||
|
"""总数"""
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionUpdateRequest(SQLModelBase):
|
||||||
|
"""扩展名全量替换请求 DTO"""
|
||||||
|
|
||||||
|
extensions: list[str]
|
||||||
|
"""扩展名列表(小写,无点号)"""
|
||||||
|
|
||||||
|
|
||||||
|
class GroupAccessUpdateRequest(SQLModelBase):
|
||||||
|
"""用户组权限全量替换请求 DTO"""
|
||||||
|
|
||||||
|
group_ids: list[UUID]
|
||||||
|
"""允许访问的用户组UUID列表"""
|
||||||
|
|
||||||
|
|
||||||
|
class WopiSessionResponse(SQLModelBase):
|
||||||
|
"""WOPI 会话响应 DTO"""
|
||||||
|
|
||||||
|
wopi_src: str
|
||||||
|
"""WOPI 源 URL"""
|
||||||
|
|
||||||
|
access_token: str
|
||||||
|
"""WOPI 访问令牌"""
|
||||||
|
|
||||||
|
access_token_ttl: int
|
||||||
|
"""令牌过期时间戳(毫秒,WOPI 规范要求)"""
|
||||||
|
|
||||||
|
editor_url: str
|
||||||
|
"""完整的编辑器 URL"""
|
||||||
|
|
||||||
|
|
||||||
|
class WopiDiscoveredExtension(SQLModelBase):
|
||||||
|
"""单个 WOPI Discovery 发现的扩展名"""
|
||||||
|
|
||||||
|
extension: str
|
||||||
|
"""文件扩展名"""
|
||||||
|
|
||||||
|
action_url: str
|
||||||
|
"""处理后的动作 URL 模板"""
|
||||||
|
|
||||||
|
|
||||||
|
class WopiDiscoveryResponse(SQLModelBase):
|
||||||
|
"""WOPI Discovery 结果响应 DTO"""
|
||||||
|
|
||||||
|
discovered_extensions: list[WopiDiscoveredExtension] = []
|
||||||
|
"""发现的扩展名及其 URL 模板"""
|
||||||
|
|
||||||
|
app_names: list[str] = []
|
||||||
|
"""WOPI 服务端报告的应用名称(如 Writer、Calc、Impress)"""
|
||||||
|
|
||||||
|
applied_count: int = 0
|
||||||
|
"""已应用到 FileAppExtension 的数量"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 数据库模型 ====================
|
||||||
|
|
||||||
|
class FileApp(SQLModelBase, UUIDTableBaseMixin):
|
||||||
|
"""文件查看器应用注册表"""
|
||||||
|
|
||||||
|
name: Str100
|
||||||
|
"""应用名称"""
|
||||||
|
|
||||||
|
app_key: str = Field(max_length=50, unique=True, index=True)
|
||||||
|
"""应用唯一标识,前端路由用"""
|
||||||
|
|
||||||
|
type: FileAppType
|
||||||
|
"""应用类型"""
|
||||||
|
|
||||||
|
icon: Str255 | None = None
|
||||||
|
"""图标名称/URL"""
|
||||||
|
|
||||||
|
description: str | None = Field(default=None, max_length=500)
|
||||||
|
"""应用描述"""
|
||||||
|
|
||||||
|
is_enabled: bool = True
|
||||||
|
"""是否启用"""
|
||||||
|
|
||||||
|
is_restricted: bool = False
|
||||||
|
"""是否限制用户组访问"""
|
||||||
|
|
||||||
|
iframe_url_template: Text1024 | None = None
|
||||||
|
"""iframe URL 模板,支持 {file_url} 占位符"""
|
||||||
|
|
||||||
|
wopi_discovery_url: str | None = Field(default=None, max_length=512)
|
||||||
|
"""WOPI 客户端发现端点 URL"""
|
||||||
|
|
||||||
|
wopi_editor_url_template: Text1024 | None = None
|
||||||
|
"""WOPI 编辑器 URL 模板,支持 {wopi_src} {access_token} {access_token_ttl}"""
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
extensions: list["FileAppExtension"] = Relationship(
|
||||||
|
back_populates="app",
|
||||||
|
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||||
|
)
|
||||||
|
|
||||||
|
user_defaults: list["UserFileAppDefault"] = Relationship(
|
||||||
|
back_populates="app",
|
||||||
|
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed_groups: list["Group"] = Relationship(
|
||||||
|
link_model=FileAppGroupLink,
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_summary(self) -> FileAppSummary:
|
||||||
|
"""转换为摘要 DTO"""
|
||||||
|
return FileAppSummary(
|
||||||
|
id=self.id,
|
||||||
|
name=self.name,
|
||||||
|
app_key=self.app_key,
|
||||||
|
type=self.type,
|
||||||
|
icon=self.icon,
|
||||||
|
description=self.description,
|
||||||
|
iframe_url_template=self.iframe_url_template,
|
||||||
|
wopi_editor_url_template=self.wopi_editor_url_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FileAppExtension(SQLModelBase, TableBaseMixin):
|
||||||
|
"""扩展名关联表"""
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("app_id", "extension", name="uq_fileappextension_app_extension"),
|
||||||
|
)
|
||||||
|
|
||||||
|
app_id: UUID = Field(foreign_key="fileapp.id", index=True, ondelete="CASCADE")
|
||||||
|
"""关联的应用UUID"""
|
||||||
|
|
||||||
|
extension: str = Field(max_length=20, index=True)
|
||||||
|
"""扩展名(小写,无点号)"""
|
||||||
|
|
||||||
|
priority: int = Field(default=0, ge=0)
|
||||||
|
"""排序优先级(越小越优先)"""
|
||||||
|
|
||||||
|
wopi_action_url: str | None = Field(default=None, max_length=2048)
|
||||||
|
"""WOPI 动作 URL 模板(Discovery 自动填充),支持 {wopi_src} {access_token} {access_token_ttl}"""
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
app: FileApp = Relationship(back_populates="extensions")
|
||||||
|
|
||||||
|
|
||||||
|
class UserFileAppDefault(SQLModelBase, UUIDTableBaseMixin):
|
||||||
|
"""用户"始终使用"偏好"""
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("user_id", "extension", name="uq_userfileappdefault_user_extension"),
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id: UUID = Field(foreign_key="user.id", index=True, ondelete="CASCADE")
|
||||||
|
"""用户UUID"""
|
||||||
|
|
||||||
|
extension: str = Field(max_length=20)
|
||||||
|
"""扩展名(小写,无点号)"""
|
||||||
|
|
||||||
|
app_id: UUID = Field(foreign_key="fileapp.id", index=True, ondelete="CASCADE")
|
||||||
|
"""关联的应用UUID"""
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
app: FileApp = Relationship(back_populates="user_defaults")
|
||||||
|
|
||||||
|
def to_response(self) -> UserFileAppDefaultResponse:
|
||||||
|
"""转换为响应 DTO(需预加载 app 关系)"""
|
||||||
|
return UserFileAppDefaultResponse(
|
||||||
|
id=self.id,
|
||||||
|
extension=self.extension,
|
||||||
|
app=self.app.to_summary(),
|
||||||
|
)
|
||||||
@@ -2,10 +2,10 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger
|
||||||
from sqlmodel import Field, Relationship, text
|
from sqlmodel import Field, Relationship, text
|
||||||
|
|
||||||
from .base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str255
|
||||||
from .mixin import TableBaseMixin, UUIDTableBaseMixin
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
@@ -67,7 +67,7 @@ class GroupAllOptionsBase(GroupOptionsBase):
|
|||||||
class GroupCreateRequest(GroupAllOptionsBase):
|
class GroupCreateRequest(GroupAllOptionsBase):
|
||||||
"""创建用户组请求 DTO"""
|
"""创建用户组请求 DTO"""
|
||||||
|
|
||||||
name: str = Field(max_length=255)
|
name: Str255
|
||||||
"""用户组名称"""
|
"""用户组名称"""
|
||||||
|
|
||||||
max_storage: int = Field(default=0, ge=0)
|
max_storage: int = Field(default=0, ge=0)
|
||||||
@@ -92,7 +92,7 @@ class GroupCreateRequest(GroupAllOptionsBase):
|
|||||||
class GroupUpdateRequest(SQLModelBase):
|
class GroupUpdateRequest(SQLModelBase):
|
||||||
"""更新用户组请求 DTO(所有字段可选)"""
|
"""更新用户组请求 DTO(所有字段可选)"""
|
||||||
|
|
||||||
name: str | None = Field(default=None, max_length=255)
|
name: Str255 | None = None
|
||||||
"""用户组名称"""
|
"""用户组名称"""
|
||||||
|
|
||||||
max_storage: int | None = Field(default=None, ge=0)
|
max_storage: int | None = Field(default=None, ge=0)
|
||||||
@@ -258,10 +258,10 @@ class GroupOptions(GroupAllOptionsBase, TableBaseMixin):
|
|||||||
class Group(GroupBase, UUIDTableBaseMixin):
|
class Group(GroupBase, UUIDTableBaseMixin):
|
||||||
"""用户组模型"""
|
"""用户组模型"""
|
||||||
|
|
||||||
name: str = Field(max_length=255, unique=True)
|
name: Str255 = Field(unique=True)
|
||||||
"""用户组名"""
|
"""用户组名"""
|
||||||
|
|
||||||
max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
max_storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
|
||||||
"""最大存储空间(字节)"""
|
"""最大存储空间(字节)"""
|
||||||
|
|
||||||
share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,543 +0,0 @@
|
|||||||
# SQLModel Mixin Module
|
|
||||||
|
|
||||||
This module provides composable Mixin classes for SQLModel entities, enabling reusable functionality such as CRUD operations, polymorphic inheritance, JWT authentication, and standardized response DTOs.
|
|
||||||
|
|
||||||
## Module Overview
|
|
||||||
|
|
||||||
The `sqlmodels.mixin` module contains various Mixin classes that follow the "Composition over Inheritance" design philosophy. These mixins provide:
|
|
||||||
|
|
||||||
- **CRUD Operations**: Async database operations (add, save, update, delete, get, count)
|
|
||||||
- **Polymorphic Inheritance**: Tools for joined table inheritance patterns
|
|
||||||
- **JWT Authentication**: Token generation and validation
|
|
||||||
- **Pagination & Sorting**: Standardized table view parameters
|
|
||||||
- **Response DTOs**: Consistent id/timestamp fields for API responses
|
|
||||||
|
|
||||||
## Module Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
sqlmodels/mixin/
|
|
||||||
├── __init__.py # Module exports
|
|
||||||
├── polymorphic.py # PolymorphicBaseMixin, create_subclass_id_mixin, AutoPolymorphicIdentityMixin
|
|
||||||
├── table.py # TableBaseMixin, UUIDTableBaseMixin, TableViewRequest
|
|
||||||
├── info_response.py # Response DTO Mixins (IntIdInfoMixin, UUIDIdInfoMixin, etc.)
|
|
||||||
└── jwt/ # JWT authentication
|
|
||||||
├── __init__.py
|
|
||||||
├── key.py # JWTKey database model
|
|
||||||
├── payload.py # JWTPayloadBase
|
|
||||||
├── manager.py # JWTManager singleton
|
|
||||||
├── auth.py # JWTAuthMixin
|
|
||||||
├── exceptions.py # JWT-related exceptions
|
|
||||||
└── responses.py # TokenResponse DTO
|
|
||||||
```
|
|
||||||
|
|
||||||
## Dependency Hierarchy
|
|
||||||
|
|
||||||
The module has a strict import order to avoid circular dependencies:
|
|
||||||
|
|
||||||
1. **polymorphic.py** - Only depends on `SQLModelBase`
|
|
||||||
2. **table.py** - Depends on `polymorphic.py`
|
|
||||||
3. **jwt/** - May depend on both `polymorphic.py` and `table.py`
|
|
||||||
4. **info_response.py** - Only depends on `SQLModelBase`
|
|
||||||
|
|
||||||
## Core Components
|
|
||||||
|
|
||||||
### 1. TableBaseMixin
|
|
||||||
|
|
||||||
Base mixin for database table models with integer primary keys.
|
|
||||||
|
|
||||||
**Features:**
|
|
||||||
- Provides CRUD methods: `add()`, `save()`, `update()`, `delete()`, `get()`, `count()`, `get_exist_one()`
|
|
||||||
- Automatic timestamp management (`created_at`, `updated_at`)
|
|
||||||
- Async relationship loading support (via `AsyncAttrs`)
|
|
||||||
- Pagination and sorting via `TableViewRequest`
|
|
||||||
- Polymorphic subclass loading support
|
|
||||||
|
|
||||||
**Fields:**
|
|
||||||
- `id: int | None` - Integer primary key (auto-increment)
|
|
||||||
- `created_at: datetime` - Record creation timestamp
|
|
||||||
- `updated_at: datetime` - Record update timestamp (auto-updated)
|
|
||||||
|
|
||||||
**Usage:**
|
|
||||||
```python
|
|
||||||
from sqlmodels.mixin import TableBaseMixin
|
|
||||||
from sqlmodels.base import SQLModelBase
|
|
||||||
|
|
||||||
class User(SQLModelBase, TableBaseMixin, table=True):
|
|
||||||
name: str
|
|
||||||
email: str
|
|
||||||
"""User email"""
|
|
||||||
|
|
||||||
# CRUD operations
|
|
||||||
async def example(session: AsyncSession):
|
|
||||||
# Add
|
|
||||||
user = User(name="Alice", email="alice@example.com")
|
|
||||||
user = await user.save(session)
|
|
||||||
|
|
||||||
# Get
|
|
||||||
user = await User.get(session, User.id == 1)
|
|
||||||
|
|
||||||
# Update
|
|
||||||
update_data = UserUpdateRequest(name="Alice Smith")
|
|
||||||
user = await user.update(session, update_data)
|
|
||||||
|
|
||||||
# Delete
|
|
||||||
await User.delete(session, user)
|
|
||||||
|
|
||||||
# Count
|
|
||||||
count = await User.count(session, User.is_active == True)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Important Notes:**
|
|
||||||
- `save()` and `update()` return refreshed instances - **always use the return value**:
|
|
||||||
```python
|
|
||||||
# ✅ Correct
|
|
||||||
device = await device.save(session)
|
|
||||||
return device
|
|
||||||
|
|
||||||
# ❌ Wrong - device is expired after commit
|
|
||||||
await device.save(session)
|
|
||||||
return device
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. UUIDTableBaseMixin
|
|
||||||
|
|
||||||
Extends `TableBaseMixin` with UUID primary keys instead of integers.
|
|
||||||
|
|
||||||
**Differences from TableBaseMixin:**
|
|
||||||
- `id: UUID` - UUID primary key (auto-generated via `uuid.uuid4()`)
|
|
||||||
- `get_exist_one()` accepts `UUID` instead of `int`
|
|
||||||
|
|
||||||
**Usage:**
|
|
||||||
```python
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
class Character(SQLModelBase, UUIDTableBaseMixin, table=True):
|
|
||||||
name: str
|
|
||||||
description: str | None = None
|
|
||||||
"""Character description"""
|
|
||||||
```
|
|
||||||
|
|
||||||
**Recommendation:** Use `UUIDTableBaseMixin` for most new models, as UUIDs provide better scalability and avoid ID collisions.
|
|
||||||
|
|
||||||
### 3. TableViewRequest
|
|
||||||
|
|
||||||
Standardized pagination and sorting parameters for LIST endpoints.
|
|
||||||
|
|
||||||
**Fields:**
|
|
||||||
- `offset: int | None` - Skip first N records (default: 0)
|
|
||||||
- `limit: int | None` - Return max N records (default: 50, max: 100)
|
|
||||||
- `desc: bool | None` - Sort descending (default: True)
|
|
||||||
- `order: Literal["created_at", "updated_at"] | None` - Sort field (default: "created_at")
|
|
||||||
|
|
||||||
**Usage with TableBaseMixin.get():**
|
|
||||||
```python
|
|
||||||
from dependencies import TableViewRequestDep
|
|
||||||
|
|
||||||
@router.get("/list")
|
|
||||||
async def list_characters(
|
|
||||||
session: SessionDep,
|
|
||||||
table_view: TableViewRequestDep
|
|
||||||
) -> list[Character]:
|
|
||||||
"""List characters with pagination and sorting"""
|
|
||||||
return await Character.get(
|
|
||||||
session,
|
|
||||||
fetch_mode="all",
|
|
||||||
table_view=table_view # Automatically handles pagination and sorting
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Manual usage:**
|
|
||||||
```python
|
|
||||||
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
|
|
||||||
characters = await Character.get(session, fetch_mode="all", table_view=table_view)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Backward Compatibility:**
|
|
||||||
The traditional `offset`, `limit`, `order_by` parameters still work, but `table_view` is recommended for new code.
|
|
||||||
|
|
||||||
### 4. PolymorphicBaseMixin
|
|
||||||
|
|
||||||
Base mixin for joined table inheritance, automatically configuring polymorphic settings.
|
|
||||||
|
|
||||||
**Automatic Configuration:**
|
|
||||||
- Defines `_polymorphic_name: str` field (indexed)
|
|
||||||
- Sets `polymorphic_on='_polymorphic_name'`
|
|
||||||
- Detects abstract classes (via ABC and abstract methods) and sets `polymorphic_abstract=True`
|
|
||||||
|
|
||||||
**Methods:**
|
|
||||||
- `get_concrete_subclasses()` - Get all non-abstract subclasses (for `selectin_polymorphic`)
|
|
||||||
- `get_polymorphic_discriminator()` - Get the polymorphic discriminator field name
|
|
||||||
- `get_identity_to_class_map()` - Map `polymorphic_identity` to subclass types
|
|
||||||
|
|
||||||
**Usage:**
|
|
||||||
```python
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from sqlmodels.mixin import PolymorphicBaseMixin, UUIDTableBaseMixin
|
|
||||||
|
|
||||||
class Tool(PolymorphicBaseMixin, UUIDTableBaseMixin, ABC):
|
|
||||||
"""Abstract base class for all tools"""
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
"""Tool description"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def execute(self, params: dict) -> dict:
|
|
||||||
"""Execute the tool"""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**Why Single Underscore Prefix?**
|
|
||||||
- SQLAlchemy maps single-underscore fields to database columns
|
|
||||||
- Pydantic treats them as private (excluded from serialization)
|
|
||||||
- Double-underscore fields would be excluded by SQLAlchemy (not mapped to database)
|
|
||||||
|
|
||||||
### 5. create_subclass_id_mixin()
|
|
||||||
|
|
||||||
Factory function to create ID mixins for subclasses in joined table inheritance.
|
|
||||||
|
|
||||||
**Purpose:** In joined table inheritance, subclasses need a foreign key pointing to the parent table's primary key. This function generates a mixin class providing that foreign key field.
|
|
||||||
|
|
||||||
**Signature:**
|
|
||||||
```python
|
|
||||||
def create_subclass_id_mixin(parent_table_name: str) -> type[SQLModelBase]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
parent_table_name: Parent table name (e.g., 'asr', 'tts', 'tool', 'function')
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A mixin class containing id field (foreign key + primary key)
|
|
||||||
"""
|
|
||||||
```
|
|
||||||
|
|
||||||
**Usage:**
|
|
||||||
```python
|
|
||||||
from sqlmodels.mixin import create_subclass_id_mixin
|
|
||||||
|
|
||||||
# Create mixin for ASR subclasses
|
|
||||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
|
||||||
|
|
||||||
class FunASR(ASRSubclassIdMixin, ASR, AutoPolymorphicIdentityMixin, table=True):
|
|
||||||
"""FunASR implementation"""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**Important:** The ID mixin **must be first in the inheritance list** to ensure MRO (Method Resolution Order) correctly overrides the parent's `id` field.
|
|
||||||
|
|
||||||
### 6. AutoPolymorphicIdentityMixin
|
|
||||||
|
|
||||||
Automatically generates `polymorphic_identity` based on class name.
|
|
||||||
|
|
||||||
**Naming Convention:**
|
|
||||||
- Format: `{parent_identity}.{classname_lowercase}`
|
|
||||||
- If no parent identity exists, uses `{classname_lowercase}`
|
|
||||||
|
|
||||||
**Usage:**
|
|
||||||
```python
|
|
||||||
from sqlmodels.mixin import AutoPolymorphicIdentityMixin
|
|
||||||
|
|
||||||
class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
|
||||||
"""Base class for function-type tools"""
|
|
||||||
pass
|
|
||||||
# polymorphic_identity = 'function'
|
|
||||||
|
|
||||||
class GetWeatherFunction(Function, table=True):
|
|
||||||
"""Weather query function"""
|
|
||||||
pass
|
|
||||||
# polymorphic_identity = 'function.getweatherfunction'
|
|
||||||
```
|
|
||||||
|
|
||||||
**Manual Override:**
|
|
||||||
```python
|
|
||||||
class CustomTool(
|
|
||||||
Tool,
|
|
||||||
AutoPolymorphicIdentityMixin,
|
|
||||||
polymorphic_identity='custom_name', # Override auto-generated name
|
|
||||||
table=True
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
### 7. JWTAuthMixin
|
|
||||||
|
|
||||||
Provides JWT token generation and validation for entity classes (User, Client).
|
|
||||||
|
|
||||||
**Methods:**
|
|
||||||
- `async issue_jwt(session: AsyncSession) -> str` - Generate JWT token for current instance
|
|
||||||
- `@classmethod async from_jwt(session: AsyncSession, token: str) -> Self` - Validate token and retrieve entity
|
|
||||||
|
|
||||||
**Requirements:**
|
|
||||||
Subclasses must define:
|
|
||||||
- `JWTPayload` - Payload model (inherits from `JWTPayloadBase`)
|
|
||||||
- `jwt_key_purpose` - ClassVar specifying the JWT key purpose enum value
|
|
||||||
|
|
||||||
**Usage:**
|
|
||||||
```python
|
|
||||||
from sqlmodels.mixin import JWTAuthMixin, UUIDTableBaseMixin
|
|
||||||
|
|
||||||
class User(SQLModelBase, UUIDTableBaseMixin, JWTAuthMixin, table=True):
|
|
||||||
JWTPayload = UserJWTPayload # Define payload model
|
|
||||||
jwt_key_purpose: ClassVar[JWTKeyPurposeEnum] = JWTKeyPurposeEnum.user
|
|
||||||
|
|
||||||
email: str
|
|
||||||
is_admin: bool = False
|
|
||||||
is_active: bool = True
|
|
||||||
"""User active status"""
|
|
||||||
|
|
||||||
# Generate token
|
|
||||||
async def login(session: AsyncSession, user: User) -> str:
|
|
||||||
token = await user.issue_jwt(session)
|
|
||||||
return token
|
|
||||||
|
|
||||||
# Validate token
|
|
||||||
async def verify(session: AsyncSession, token: str) -> User:
|
|
||||||
user = await User.from_jwt(session, token)
|
|
||||||
return user
|
|
||||||
```
|
|
||||||
|
|
||||||
### 8. Response DTO Mixins
|
|
||||||
|
|
||||||
Mixins for standardized InfoResponse DTOs, defining id and timestamp fields.
|
|
||||||
|
|
||||||
**Available Mixins:**
|
|
||||||
- `IntIdInfoMixin` - Integer ID field
|
|
||||||
- `UUIDIdInfoMixin` - UUID ID field
|
|
||||||
- `DatetimeInfoMixin` - `created_at` and `updated_at` fields
|
|
||||||
- `IntIdDatetimeInfoMixin` - Integer ID + timestamps
|
|
||||||
- `UUIDIdDatetimeInfoMixin` - UUID ID + timestamps
|
|
||||||
|
|
||||||
**Design Note:** These fields are non-nullable in DTOs because database records always have these values when returned.
|
|
||||||
|
|
||||||
**Usage:**
|
|
||||||
```python
|
|
||||||
from sqlmodels.mixin import UUIDIdDatetimeInfoMixin
|
|
||||||
|
|
||||||
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
|
|
||||||
"""Character response DTO with id and timestamps"""
|
|
||||||
pass # Inherits id, created_at, updated_at from mixin
|
|
||||||
```
|
|
||||||
|
|
||||||
## Complete Joined Table Inheritance Example
|
|
||||||
|
|
||||||
Here's a complete example demonstrating polymorphic inheritance:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from sqlmodels.base import SQLModelBase
|
|
||||||
from sqlmodels.mixin import (
|
|
||||||
UUIDTableBaseMixin,
|
|
||||||
PolymorphicBaseMixin,
|
|
||||||
create_subclass_id_mixin,
|
|
||||||
AutoPolymorphicIdentityMixin,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. Define Base class (fields only, no table)
|
|
||||||
class ASRBase(SQLModelBase):
|
|
||||||
name: str
|
|
||||||
"""Configuration name"""
|
|
||||||
|
|
||||||
base_url: str
|
|
||||||
"""Service URL"""
|
|
||||||
|
|
||||||
# 2. Define abstract parent class (with table)
|
|
||||||
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
|
||||||
"""Abstract base class for ASR configurations"""
|
|
||||||
# PolymorphicBaseMixin automatically provides:
|
|
||||||
# - _polymorphic_name field
|
|
||||||
# - polymorphic_on='_polymorphic_name'
|
|
||||||
# - polymorphic_abstract=True (when ABC with abstract methods)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def transcribe(self, pcm_data: bytes) -> str:
|
|
||||||
"""Transcribe audio to text"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 3. Create ID Mixin for second-level subclasses
|
|
||||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
|
||||||
|
|
||||||
# 4. Create second-level abstract class (if needed)
|
|
||||||
class FunASR(
|
|
||||||
ASRSubclassIdMixin,
|
|
||||||
ASR,
|
|
||||||
AutoPolymorphicIdentityMixin,
|
|
||||||
polymorphic_abstract=True
|
|
||||||
):
|
|
||||||
"""FunASR abstract base (may have multiple implementations)"""
|
|
||||||
pass
|
|
||||||
# polymorphic_identity = 'funasr'
|
|
||||||
|
|
||||||
# 5. Create concrete implementation classes
|
|
||||||
class FunASRLocal(FunASR, table=True):
|
|
||||||
"""FunASR local deployment"""
|
|
||||||
# polymorphic_identity = 'funasr.funasrlocal'
|
|
||||||
|
|
||||||
async def transcribe(self, pcm_data: bytes) -> str:
|
|
||||||
# Implementation...
|
|
||||||
return "transcribed text"
|
|
||||||
|
|
||||||
# 6. Get all concrete subclasses (for selectin_polymorphic)
|
|
||||||
concrete_asrs = ASR.get_concrete_subclasses()
|
|
||||||
# Returns: [FunASRLocal, ...]
|
|
||||||
```
|
|
||||||
|
|
||||||
## Import Guidelines
|
|
||||||
|
|
||||||
**Standard Import:**
|
|
||||||
```python
|
|
||||||
from sqlmodels.mixin import (
|
|
||||||
TableBaseMixin,
|
|
||||||
UUIDTableBaseMixin,
|
|
||||||
PolymorphicBaseMixin,
|
|
||||||
TableViewRequest,
|
|
||||||
create_subclass_id_mixin,
|
|
||||||
AutoPolymorphicIdentityMixin,
|
|
||||||
JWTAuthMixin,
|
|
||||||
UUIDIdDatetimeInfoMixin,
|
|
||||||
now,
|
|
||||||
now_date,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Backward Compatibility:**
|
|
||||||
Some exports are also available from `sqlmodels.base` for backward compatibility:
|
|
||||||
```python
|
|
||||||
# Legacy import path (still works)
|
|
||||||
from sqlmodels.base import UUIDTableBase, TableViewRequest
|
|
||||||
|
|
||||||
# Recommended new import path
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin, TableViewRequest
|
|
||||||
```
|
|
||||||
|
|
||||||
## Best Practices
|
|
||||||
|
|
||||||
### 1. Mixin Order Matters
|
|
||||||
|
|
||||||
**Correct Order:**
|
|
||||||
```python
|
|
||||||
# ✅ ID Mixin first, then parent, then AutoPolymorphicIdentityMixin
|
|
||||||
class SubTool(ToolSubclassIdMixin, Tool, AutoPolymorphicIdentityMixin, table=True):
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**Wrong Order:**
|
|
||||||
```python
|
|
||||||
# ❌ ID Mixin not first - won't override parent's id field
|
|
||||||
class SubTool(Tool, ToolSubclassIdMixin, AutoPolymorphicIdentityMixin, table=True):
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Always Use Return Values from save() and update()
|
|
||||||
|
|
||||||
```python
|
|
||||||
# ✅ Correct - use returned instance
|
|
||||||
device = await device.save(session)
|
|
||||||
return device
|
|
||||||
|
|
||||||
# ❌ Wrong - device is expired after commit
|
|
||||||
await device.save(session)
|
|
||||||
return device # AttributeError when accessing fields
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Prefer table_view Over Manual Pagination
|
|
||||||
|
|
||||||
```python
|
|
||||||
# ✅ Recommended - consistent across all endpoints
|
|
||||||
characters = await Character.get(
|
|
||||||
session,
|
|
||||||
fetch_mode="all",
|
|
||||||
table_view=table_view
|
|
||||||
)
|
|
||||||
|
|
||||||
# ⚠️ Works but not recommended - manual parameter management
|
|
||||||
characters = await Character.get(
|
|
||||||
session,
|
|
||||||
fetch_mode="all",
|
|
||||||
offset=0,
|
|
||||||
limit=20,
|
|
||||||
order_by=[desc(Character.created_at)]
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. Polymorphic Loading for Many Subclasses
|
|
||||||
|
|
||||||
```python
|
|
||||||
# When loading relationships with > 10 polymorphic subclasses, use load_polymorphic='all'
|
|
||||||
tool_set = await ToolSet.get(
|
|
||||||
session,
|
|
||||||
ToolSet.id == tool_set_id,
|
|
||||||
load=ToolSet.tools,
|
|
||||||
load_polymorphic='all' # Two-phase query - only loads actual related subclasses
|
|
||||||
)
|
|
||||||
|
|
||||||
# For fewer subclasses, specify the list explicitly
|
|
||||||
tool_set = await ToolSet.get(
|
|
||||||
session,
|
|
||||||
ToolSet.id == tool_set_id,
|
|
||||||
load=ToolSet.tools,
|
|
||||||
load_polymorphic=[GetWeatherFunction, CodeInterpreterFunction]
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5. Response DTOs Should Inherit Base Classes
|
|
||||||
|
|
||||||
```python
|
|
||||||
# ✅ Correct - inherits from CharacterBase
|
|
||||||
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# ❌ Wrong - doesn't inherit from CharacterBase
|
|
||||||
class CharacterInfoResponse(SQLModelBase, UUIDIdDatetimeInfoMixin):
|
|
||||||
name: str # Duplicated field definition
|
|
||||||
description: str | None = None
|
|
||||||
```
|
|
||||||
|
|
||||||
**Reason:** Inheriting from Base classes ensures:
|
|
||||||
- Type checking via `isinstance(obj, XxxBase)`
|
|
||||||
- Consistency across related DTOs
|
|
||||||
- Future field additions automatically propagate
|
|
||||||
|
|
||||||
### 6. Use Specific Types, Not Containers
|
|
||||||
|
|
||||||
```python
|
|
||||||
# ✅ Correct - specific DTO for config updates
|
|
||||||
class GetWeatherFunctionUpdateRequest(GetWeatherFunctionConfigBase):
|
|
||||||
weather_api_key: str | None = None
|
|
||||||
default_location: str | None = None
|
|
||||||
"""Default location"""
|
|
||||||
|
|
||||||
# ❌ Wrong - lose type safety
|
|
||||||
class ToolUpdateRequest(SQLModelBase):
|
|
||||||
config: dict[str, Any] # No field validation
|
|
||||||
```
|
|
||||||
|
|
||||||
## Type Variables
|
|
||||||
|
|
||||||
```python
|
|
||||||
from sqlmodels.mixin import T, M
|
|
||||||
|
|
||||||
T = TypeVar("T", bound="TableBaseMixin") # For CRUD methods
|
|
||||||
M = TypeVar("M", bound="SQLModel") # For update() method
|
|
||||||
```
|
|
||||||
|
|
||||||
## Utility Functions
|
|
||||||
|
|
||||||
```python
|
|
||||||
from sqlmodels.mixin import now, now_date
|
|
||||||
|
|
||||||
# Lambda functions for default factories
|
|
||||||
now = lambda: datetime.now()
|
|
||||||
now_date = lambda: datetime.now().date()
|
|
||||||
```
|
|
||||||
|
|
||||||
## Related Modules
|
|
||||||
|
|
||||||
- **sqlmodels.base** - Base classes (`SQLModelBase`, backward-compatible exports)
|
|
||||||
- **dependencies** - FastAPI dependencies (`SessionDep`, `TableViewRequestDep`)
|
|
||||||
- **sqlmodels.user** - User model with JWT authentication
|
|
||||||
- **sqlmodels.client** - Client model with JWT authentication
|
|
||||||
- **sqlmodels.character.llm.openai_compatibles.tools** - Polymorphic tool hierarchy
|
|
||||||
|
|
||||||
## Additional Resources
|
|
||||||
|
|
||||||
- `POLYMORPHIC_NAME_DESIGN.md` - Design rationale for `_polymorphic_name` field
|
|
||||||
- `CLAUDE.md` - Project coding standards and design philosophy
|
|
||||||
- SQLAlchemy Documentation - [Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
"""
|
|
||||||
SQLModel Mixin模块
|
|
||||||
|
|
||||||
提供各种Mixin类供SQLModel实体使用。
|
|
||||||
|
|
||||||
包含:
|
|
||||||
- polymorphic: 联表继承工具(create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin)
|
|
||||||
- optimistic_lock: 乐观锁(OptimisticLockMixin, OptimisticLockError)
|
|
||||||
- table: 表基类(TableBaseMixin, UUIDTableBaseMixin)
|
|
||||||
- table: 查询参数类(TimeFilterRequest, PaginationRequest, TableViewRequest)
|
|
||||||
- relation_preload: 关系预加载(RelationPreloadMixin, requires_relations)
|
|
||||||
- jwt/: JWT认证相关(JWTAuthMixin, JWTManager, JWTKey等)- 需要时直接从 .jwt 导入
|
|
||||||
- info_response: InfoResponse DTO的id/时间戳Mixin
|
|
||||||
|
|
||||||
导入顺序很重要,避免循环导入:
|
|
||||||
1. polymorphic(只依赖 SQLModelBase)
|
|
||||||
2. optimistic_lock(只依赖 SQLAlchemy)
|
|
||||||
3. table(依赖 polymorphic 和 optimistic_lock)
|
|
||||||
4. relation_preload(只依赖 SQLModelBase)
|
|
||||||
|
|
||||||
注意:jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig,
|
|
||||||
而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
|
|
||||||
"""
|
|
||||||
# polymorphic 必须先导入
|
|
||||||
from .polymorphic import (
|
|
||||||
AutoPolymorphicIdentityMixin,
|
|
||||||
PolymorphicBaseMixin,
|
|
||||||
create_subclass_id_mixin,
|
|
||||||
register_sti_column_properties_for_all_subclasses,
|
|
||||||
register_sti_columns_for_all_subclasses,
|
|
||||||
)
|
|
||||||
# optimistic_lock 只依赖 SQLAlchemy,必须在 table 之前
|
|
||||||
from .optimistic_lock import (
|
|
||||||
OptimisticLockError,
|
|
||||||
OptimisticLockMixin,
|
|
||||||
)
|
|
||||||
# table 依赖 polymorphic 和 optimistic_lock
|
|
||||||
from .table import (
|
|
||||||
ListResponse,
|
|
||||||
PaginationRequest,
|
|
||||||
T,
|
|
||||||
TableBaseMixin,
|
|
||||||
TableViewRequest,
|
|
||||||
TimeFilterRequest,
|
|
||||||
UUIDTableBaseMixin,
|
|
||||||
now,
|
|
||||||
now_date,
|
|
||||||
)
|
|
||||||
# relation_preload 只依赖 SQLModelBase
|
|
||||||
from .relation_preload import (
|
|
||||||
RelationPreloadMixin,
|
|
||||||
requires_relations,
|
|
||||||
)
|
|
||||||
# jwt 不在此处导入(避免循环:jwt/manager.py → ServerConfig → mixin → jwt)
|
|
||||||
# 需要时直接从 sqlmodels.mixin.jwt 导入
|
|
||||||
from .info_response import (
|
|
||||||
DatetimeInfoMixin,
|
|
||||||
IntIdDatetimeInfoMixin,
|
|
||||||
IntIdInfoMixin,
|
|
||||||
UUIDIdDatetimeInfoMixin,
|
|
||||||
UUIDIdInfoMixin,
|
|
||||||
)
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
"""
|
|
||||||
InfoResponse DTO Mixin模块
|
|
||||||
|
|
||||||
提供用于InfoResponse类型DTO的Mixin,统一定义id/created_at/updated_at字段。
|
|
||||||
|
|
||||||
设计说明:
|
|
||||||
- 这些Mixin用于**响应DTO**,不是数据库表
|
|
||||||
- 从数据库返回时这些字段永远不为空,所以定义为必填字段
|
|
||||||
- TableBase中的id=None和default_factory=now是正确的(入库前为None,数据库生成)
|
|
||||||
- 这些Mixin让DTO明确表示"返回给客户端时这些字段必定有值"
|
|
||||||
"""
|
|
||||||
from datetime import datetime
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlmodels.base import SQLModelBase
|
|
||||||
|
|
||||||
|
|
||||||
class IntIdInfoMixin(SQLModelBase):
|
|
||||||
"""整数ID响应mixin - 用于InfoResponse DTO"""
|
|
||||||
id: int
|
|
||||||
"""记录ID"""
|
|
||||||
|
|
||||||
|
|
||||||
class UUIDIdInfoMixin(SQLModelBase):
|
|
||||||
"""UUID ID响应mixin - 用于InfoResponse DTO"""
|
|
||||||
id: UUID
|
|
||||||
"""记录ID"""
|
|
||||||
|
|
||||||
|
|
||||||
class DatetimeInfoMixin(SQLModelBase):
|
|
||||||
"""时间戳响应mixin - 用于InfoResponse DTO"""
|
|
||||||
created_at: datetime
|
|
||||||
"""创建时间"""
|
|
||||||
|
|
||||||
updated_at: datetime
|
|
||||||
"""更新时间"""
|
|
||||||
|
|
||||||
|
|
||||||
class IntIdDatetimeInfoMixin(IntIdInfoMixin, DatetimeInfoMixin):
|
|
||||||
"""整数ID + 时间戳响应mixin"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class UUIDIdDatetimeInfoMixin(UUIDIdInfoMixin, DatetimeInfoMixin):
|
|
||||||
"""UUID ID + 时间戳响应mixin"""
|
|
||||||
pass
|
|
||||||
@@ -1,90 +0,0 @@
|
|||||||
"""
|
|
||||||
乐观锁 Mixin
|
|
||||||
|
|
||||||
提供基于 SQLAlchemy version_id_col 机制的乐观锁支持。
|
|
||||||
|
|
||||||
乐观锁适用场景:
|
|
||||||
- 涉及"状态转换"的表(如:待支付 -> 已支付)
|
|
||||||
- 涉及"数值变动"的表(如:余额、库存)
|
|
||||||
|
|
||||||
不适用场景:
|
|
||||||
- 日志表、纯插入表、低价值统计表
|
|
||||||
- 能用 UPDATE table SET col = col + 1 解决的简单计数问题
|
|
||||||
|
|
||||||
使用示例:
|
|
||||||
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
|
|
||||||
status: OrderStatusEnum
|
|
||||||
amount: Decimal
|
|
||||||
|
|
||||||
# save/update 时自动检查版本号
|
|
||||||
# 如果版本号不匹配(其他事务已修改),会抛出 OptimisticLockError
|
|
||||||
try:
|
|
||||||
order = await order.save(session)
|
|
||||||
except OptimisticLockError as e:
|
|
||||||
# 处理冲突:重新查询并重试,或报错给用户
|
|
||||||
l.warning(f"乐观锁冲突: {e}")
|
|
||||||
"""
|
|
||||||
from typing import ClassVar
|
|
||||||
|
|
||||||
from sqlalchemy.orm.exc import StaleDataError
|
|
||||||
|
|
||||||
|
|
||||||
class OptimisticLockError(Exception):
|
|
||||||
"""
|
|
||||||
乐观锁冲突异常
|
|
||||||
|
|
||||||
当 save/update 操作检测到版本号不匹配时抛出。
|
|
||||||
这意味着在读取和写入之间,其他事务已经修改了该记录。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
model_class: 发生冲突的模型类名
|
|
||||||
record_id: 记录 ID(如果可用)
|
|
||||||
expected_version: 期望的版本号(如果可用)
|
|
||||||
original_error: 原始的 StaleDataError
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
message: str,
|
|
||||||
model_class: str | None = None,
|
|
||||||
record_id: str | None = None,
|
|
||||||
expected_version: int | None = None,
|
|
||||||
original_error: StaleDataError | None = None,
|
|
||||||
):
|
|
||||||
super().__init__(message)
|
|
||||||
self.model_class = model_class
|
|
||||||
self.record_id = record_id
|
|
||||||
self.expected_version = expected_version
|
|
||||||
self.original_error = original_error
|
|
||||||
|
|
||||||
|
|
||||||
class OptimisticLockMixin:
|
|
||||||
"""
|
|
||||||
乐观锁 Mixin
|
|
||||||
|
|
||||||
使用 SQLAlchemy 的 version_id_col 机制实现乐观锁。
|
|
||||||
每次 UPDATE 时自动检查并增加版本号,如果版本号不匹配(即其他事务已修改),
|
|
||||||
session.commit() 会抛出 StaleDataError,被 save/update 方法捕获并转换为 OptimisticLockError。
|
|
||||||
|
|
||||||
原理:
|
|
||||||
1. 每条记录有一个 version 字段,初始值为 0
|
|
||||||
2. 每次 UPDATE 时,SQLAlchemy 生成的 SQL 类似:
|
|
||||||
UPDATE table SET ..., version = version + 1 WHERE id = ? AND version = ?
|
|
||||||
3. 如果 WHERE 条件不匹配(version 已被其他事务修改),
|
|
||||||
UPDATE 影响 0 行,SQLAlchemy 抛出 StaleDataError
|
|
||||||
|
|
||||||
继承顺序:
|
|
||||||
OptimisticLockMixin 必须放在 TableBaseMixin/UUIDTableBaseMixin 之前:
|
|
||||||
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
|
|
||||||
...
|
|
||||||
|
|
||||||
配套重试:
|
|
||||||
如果加了乐观锁,业务层需要处理 OptimisticLockError:
|
|
||||||
- 报错给用户:"数据已被修改,请刷新后重试"
|
|
||||||
- 自动重试:重新查询最新数据并再次尝试
|
|
||||||
"""
|
|
||||||
_has_optimistic_lock: ClassVar[bool] = True
|
|
||||||
"""标记此类启用了乐观锁"""
|
|
||||||
|
|
||||||
version: int = 0
|
|
||||||
"""乐观锁版本号,每次更新自动递增"""
|
|
||||||
@@ -1,710 +0,0 @@
|
|||||||
"""
|
|
||||||
联表继承(Joined Table Inheritance)的通用工具
|
|
||||||
|
|
||||||
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
|
|
||||||
|
|
||||||
Usage Example:
|
|
||||||
|
|
||||||
from sqlmodels.base import SQLModelBase
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
from sqlmodels.mixin.polymorphic import (
|
|
||||||
PolymorphicBaseMixin,
|
|
||||||
create_subclass_id_mixin,
|
|
||||||
AutoPolymorphicIdentityMixin
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. 定义Base类(只有字段,无表)
|
|
||||||
class ASRBase(SQLModelBase):
|
|
||||||
name: str
|
|
||||||
\"\"\"配置名称\"\"\"
|
|
||||||
|
|
||||||
base_url: str
|
|
||||||
\"\"\"服务地址\"\"\"
|
|
||||||
|
|
||||||
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
|
|
||||||
class ASR(
|
|
||||||
ASRBase,
|
|
||||||
UUIDTableBaseMixin,
|
|
||||||
PolymorphicBaseMixin,
|
|
||||||
ABC
|
|
||||||
):
|
|
||||||
\"\"\"ASR配置的抽象基类\"\"\"
|
|
||||||
# PolymorphicBaseMixin 自动提供:
|
|
||||||
# - _polymorphic_name 字段
|
|
||||||
# - polymorphic_on='_polymorphic_name'
|
|
||||||
# - polymorphic_abstract=True(当有抽象方法时)
|
|
||||||
|
|
||||||
# 3. 为第二层子类创建ID Mixin
|
|
||||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
|
||||||
|
|
||||||
# 4. 创建第二层抽象类(如果需要)
|
|
||||||
class FunASR(
|
|
||||||
ASRSubclassIdMixin,
|
|
||||||
ASR,
|
|
||||||
AutoPolymorphicIdentityMixin,
|
|
||||||
polymorphic_abstract=True
|
|
||||||
):
|
|
||||||
\"\"\"FunASR的抽象基类,可能有多个实现\"\"\"
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 5. 创建具体实现类
|
|
||||||
class FunASRLocal(FunASR, table=True):
|
|
||||||
\"\"\"FunASR本地部署版本\"\"\"
|
|
||||||
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 6. 获取所有具体子类(用于 selectin_polymorphic)
|
|
||||||
concrete_asrs = ASR.get_concrete_subclasses()
|
|
||||||
# 返回 [FunASRLocal, ...]
|
|
||||||
"""
|
|
||||||
import uuid
|
|
||||||
from abc import ABC
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from loguru import logger as l
|
|
||||||
from pydantic.fields import FieldInfo
|
|
||||||
from pydantic_core import PydanticUndefined
|
|
||||||
from sqlalchemy import Column, String, inspect
|
|
||||||
from sqlalchemy.orm import ColumnProperty, Mapped, mapped_column
|
|
||||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
|
||||||
from sqlmodel import Field
|
|
||||||
from sqlmodel.main import get_column_from_field
|
|
||||||
|
|
||||||
from sqlmodels.base.sqlmodel_base import SQLModelBase
|
|
||||||
|
|
||||||
# 用于延迟注册 STI 子类列的队列
|
|
||||||
# 在所有模型加载完成后,调用 register_sti_columns_for_all_subclasses() 处理
|
|
||||||
_sti_subclasses_to_register: list[type] = []
|
|
||||||
|
|
||||||
|
|
||||||
def register_sti_columns_for_all_subclasses() -> None:
|
|
||||||
"""
|
|
||||||
为所有已注册的 STI 子类执行列注册(第一阶段:添加列到表)
|
|
||||||
|
|
||||||
此函数应在 configure_mappers() 之前调用。
|
|
||||||
将 STI 子类的字段添加到父表的 metadata 中。
|
|
||||||
同时修复被 Column 对象污染的 model_fields。
|
|
||||||
"""
|
|
||||||
for cls in _sti_subclasses_to_register:
|
|
||||||
try:
|
|
||||||
cls._register_sti_columns()
|
|
||||||
except Exception as e:
|
|
||||||
l.warning(f"注册 STI 子类 {cls.__name__} 的列时出错: {e}")
|
|
||||||
|
|
||||||
# 修复被 Column 对象污染的 model_fields
|
|
||||||
# 必须在列注册后立即修复,因为 Column 污染在类定义时就已发生
|
|
||||||
try:
|
|
||||||
_fix_polluted_model_fields(cls)
|
|
||||||
except Exception as e:
|
|
||||||
l.warning(f"修复 STI 子类 {cls.__name__} 的 model_fields 时出错: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def register_sti_column_properties_for_all_subclasses() -> None:
|
|
||||||
"""
|
|
||||||
为所有已注册的 STI 子类添加列属性到 mapper(第二阶段)
|
|
||||||
|
|
||||||
此函数应在 configure_mappers() 之后调用。
|
|
||||||
将 STI 子类的字段作为 ColumnProperty 添加到 mapper 中。
|
|
||||||
"""
|
|
||||||
for cls in _sti_subclasses_to_register:
|
|
||||||
try:
|
|
||||||
cls._register_sti_column_properties()
|
|
||||||
except Exception as e:
|
|
||||||
l.warning(f"注册 STI 子类 {cls.__name__} 的列属性时出错: {e}")
|
|
||||||
|
|
||||||
# 清空队列
|
|
||||||
_sti_subclasses_to_register.clear()
|
|
||||||
|
|
||||||
|
|
||||||
def _fix_polluted_model_fields(cls: type) -> None:
|
|
||||||
"""
|
|
||||||
修复被 SQLAlchemy InstrumentedAttribute 或 Column 污染的 model_fields
|
|
||||||
|
|
||||||
当 SQLModel 类继承有表的父类时,SQLAlchemy 会在类上创建 InstrumentedAttribute
|
|
||||||
或 Column 对象替换原始的字段默认值。这会导致 Pydantic 在构建子类 model_fields
|
|
||||||
时错误地使用这些 SQLAlchemy 对象作为默认值。
|
|
||||||
|
|
||||||
此函数从 MRO 中查找原始的字段定义,并修复被污染的 model_fields。
|
|
||||||
|
|
||||||
:param cls: 要修复的类
|
|
||||||
"""
|
|
||||||
if not hasattr(cls, 'model_fields'):
|
|
||||||
return
|
|
||||||
|
|
||||||
def find_original_field_info(field_name: str) -> FieldInfo | None:
|
|
||||||
"""从 MRO 中查找字段的原始定义(未被污染的)"""
|
|
||||||
for base in cls.__mro__[1:]: # 跳过自己
|
|
||||||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
|
||||||
field_info = base.model_fields[field_name]
|
|
||||||
# 跳过被 InstrumentedAttribute 或 Column 污染的
|
|
||||||
if not isinstance(field_info.default, (InstrumentedAttribute, Column)):
|
|
||||||
return field_info
|
|
||||||
return None
|
|
||||||
|
|
||||||
for field_name, current_field in cls.model_fields.items():
|
|
||||||
# 检查是否被污染(default 是 InstrumentedAttribute 或 Column)
|
|
||||||
# Column 污染发生在 STI 继承链中:当 FunctionBase.show_arguments = True
|
|
||||||
# 被继承到有表的子类时,SQLModel 会创建一个 Column 对象替换原始默认值
|
|
||||||
if not isinstance(current_field.default, (InstrumentedAttribute, Column)):
|
|
||||||
continue # 未被污染,跳过
|
|
||||||
|
|
||||||
# 从父类查找原始定义
|
|
||||||
original = find_original_field_info(field_name)
|
|
||||||
if original is None:
|
|
||||||
continue # 找不到原始定义,跳过
|
|
||||||
|
|
||||||
# 根据原始定义的 default/default_factory 来修复
|
|
||||||
if original.default_factory:
|
|
||||||
# 有 default_factory(如 uuid.uuid4, now)
|
|
||||||
new_field = FieldInfo(
|
|
||||||
default_factory=original.default_factory,
|
|
||||||
annotation=current_field.annotation,
|
|
||||||
json_schema_extra=current_field.json_schema_extra,
|
|
||||||
)
|
|
||||||
elif original.default is not PydanticUndefined:
|
|
||||||
# 有明确的 default 值(如 None, 0, True),且不是 PydanticUndefined
|
|
||||||
# PydanticUndefined 表示字段没有默认值(必填)
|
|
||||||
new_field = FieldInfo(
|
|
||||||
default=original.default,
|
|
||||||
annotation=current_field.annotation,
|
|
||||||
json_schema_extra=current_field.json_schema_extra,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
continue # 既没有 default_factory 也没有有效的 default,跳过
|
|
||||||
|
|
||||||
# 复制 SQLModel 特有的属性
|
|
||||||
if hasattr(current_field, 'foreign_key'):
|
|
||||||
new_field.foreign_key = current_field.foreign_key
|
|
||||||
if hasattr(current_field, 'primary_key'):
|
|
||||||
new_field.primary_key = current_field.primary_key
|
|
||||||
|
|
||||||
cls.model_fields[field_name] = new_field
|
|
||||||
|
|
||||||
|
|
||||||
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
|
|
||||||
"""
|
|
||||||
动态创建SubclassIdMixin类
|
|
||||||
|
|
||||||
在联表继承中,子类需要一个外键指向父表的主键。
|
|
||||||
此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function')
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4)
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
|
||||||
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
|
|
||||||
... pass
|
|
||||||
|
|
||||||
Note:
|
|
||||||
- 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id
|
|
||||||
- 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase)
|
|
||||||
- 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin)
|
|
||||||
"""
|
|
||||||
if not parent_table_name:
|
|
||||||
raise ValueError("parent_table_name 不能为空")
|
|
||||||
|
|
||||||
# 转换为PascalCase作为类名
|
|
||||||
class_name_parts = parent_table_name.split('_')
|
|
||||||
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
|
|
||||||
|
|
||||||
# 使用闭包捕获parent_table_name
|
|
||||||
_parent_table_name = parent_table_name
|
|
||||||
|
|
||||||
# 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields
|
|
||||||
class SubclassIdMixin(SQLModelBase):
|
|
||||||
# 定义id字段
|
|
||||||
id: UUID = Field(
|
|
||||||
default_factory=uuid.uuid4,
|
|
||||||
foreign_key=f'{_parent_table_name}.id',
|
|
||||||
primary_key=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __pydantic_init_subclass__(cls, **kwargs):
|
|
||||||
"""
|
|
||||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
|
||||||
|
|
||||||
修复联表继承中子类字段的 default_factory 丢失问题。
|
|
||||||
SQLAlchemy 的 InstrumentedAttribute 或 Column 会污染从父类继承的字段,
|
|
||||||
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
|
|
||||||
"""
|
|
||||||
super().__pydantic_init_subclass__(**kwargs)
|
|
||||||
_fix_polluted_model_fields(cls)
|
|
||||||
|
|
||||||
# 设置类名和文档
|
|
||||||
SubclassIdMixin.__name__ = class_name
|
|
||||||
SubclassIdMixin.__qualname__ = class_name
|
|
||||||
SubclassIdMixin.__doc__ = f"""
|
|
||||||
{parent_table_name}子类的ID Mixin
|
|
||||||
|
|
||||||
用于{parent_table_name}的子类,提供外键指向父表。
|
|
||||||
通过MRO确保此id字段覆盖继承的id字段。
|
|
||||||
"""
|
|
||||||
|
|
||||||
return SubclassIdMixin
|
|
||||||
|
|
||||||
|
|
||||||
class AutoPolymorphicIdentityMixin:
|
|
||||||
"""
|
|
||||||
自动生成polymorphic_identity的Mixin,并支持STI子类列注册
|
|
||||||
|
|
||||||
使用此Mixin的类会自动根据类名生成polymorphic_identity。
|
|
||||||
格式:{parent_polymorphic_identity}.{classname_lowercase}
|
|
||||||
|
|
||||||
如果没有父类的polymorphic_identity,则直接使用类名小写。
|
|
||||||
|
|
||||||
**重要:数据库迁移注意事项**
|
|
||||||
|
|
||||||
编写数据迁移脚本时,必须使用完整的 polymorphic_identity 格式,包括父类前缀!
|
|
||||||
|
|
||||||
例如,对于以下继承链::
|
|
||||||
|
|
||||||
LLM (polymorphic_on='_polymorphic_name')
|
|
||||||
└── AnthropicCompatibleLLM (polymorphic_identity='anthropiccompatiblellm')
|
|
||||||
└── TuziAnthropicLLM (polymorphic_identity='anthropiccompatiblellm.tuzianthropicllm')
|
|
||||||
|
|
||||||
迁移脚本中设置 _polymorphic_name 时::
|
|
||||||
|
|
||||||
# ❌ 错误:缺少父类前缀
|
|
||||||
UPDATE llm SET _polymorphic_name = 'tuzianthropicllm' WHERE id = :id
|
|
||||||
|
|
||||||
# ✅ 正确:包含完整的继承链前缀
|
|
||||||
UPDATE llm SET _polymorphic_name = 'anthropiccompatiblellm.tuzianthropicllm' WHERE id = :id
|
|
||||||
|
|
||||||
**STI(单表继承)支持**:
|
|
||||||
当子类与父类共用同一张表(STI模式)时,此Mixin会自动将子类的新字段
|
|
||||||
添加到父表的列定义中。这解决了SQLModel在STI模式下子类字段不被
|
|
||||||
注册到父表的问题。
|
|
||||||
|
|
||||||
Example (JTI):
|
|
||||||
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
|
|
||||||
... __polymorphic_name: str
|
|
||||||
...
|
|
||||||
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
|
||||||
... pass
|
|
||||||
... # polymorphic_identity 会自动设置为 'function'
|
|
||||||
...
|
|
||||||
>>> class CodeInterpreterFunction(Function, table=True):
|
|
||||||
... pass
|
|
||||||
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
|
|
||||||
|
|
||||||
Example (STI):
|
|
||||||
>>> class UserFile(UUIDTableBaseMixin, PolymorphicBaseMixin, table=True, polymorphic_abstract=True):
|
|
||||||
... user_id: UUID
|
|
||||||
...
|
|
||||||
>>> class PendingFile(UserFile, AutoPolymorphicIdentityMixin, table=True):
|
|
||||||
... upload_deadline: datetime | None = None # 自动添加到 userfile 表
|
|
||||||
... # polymorphic_identity 会自动设置为 'pendingfile'
|
|
||||||
|
|
||||||
Note:
|
|
||||||
- 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留
|
|
||||||
- 此Mixin应该在继承列表中靠后的位置(在表基类之前)
|
|
||||||
- STI模式下,新字段会在类定义时自动添加到父表的metadata中
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
|
|
||||||
"""
|
|
||||||
子类化钩子,自动生成polymorphic_identity并处理STI列注册
|
|
||||||
|
|
||||||
Args:
|
|
||||||
polymorphic_identity: 如果手动指定,则使用指定的值
|
|
||||||
**kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True)
|
|
||||||
"""
|
|
||||||
super().__init_subclass__(**kwargs)
|
|
||||||
|
|
||||||
# 如果手动指定了polymorphic_identity,使用指定的值
|
|
||||||
if polymorphic_identity is not None:
|
|
||||||
identity = polymorphic_identity
|
|
||||||
else:
|
|
||||||
# 自动生成polymorphic_identity
|
|
||||||
class_name = cls.__name__.lower()
|
|
||||||
|
|
||||||
# 尝试从父类获取polymorphic_identity作为前缀
|
|
||||||
parent_identity = None
|
|
||||||
for base in cls.__mro__[1:]: # 跳过自己
|
|
||||||
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
|
|
||||||
parent_identity = base.__mapper_args__.get('polymorphic_identity')
|
|
||||||
if parent_identity:
|
|
||||||
break
|
|
||||||
|
|
||||||
# 构建identity
|
|
||||||
if parent_identity:
|
|
||||||
identity = f'{parent_identity}.{class_name}'
|
|
||||||
else:
|
|
||||||
identity = class_name
|
|
||||||
|
|
||||||
# 设置到__mapper_args__
|
|
||||||
if '__mapper_args__' not in cls.__dict__:
|
|
||||||
cls.__mapper_args__ = {}
|
|
||||||
|
|
||||||
# 只在尚未设置polymorphic_identity时设置
|
|
||||||
if 'polymorphic_identity' not in cls.__mapper_args__:
|
|
||||||
cls.__mapper_args__['polymorphic_identity'] = identity
|
|
||||||
|
|
||||||
# 注册 STI 子类列的延迟执行
|
|
||||||
# 由于 __init_subclass__ 在类定义过程中被调用,此时 model_fields 还不完整
|
|
||||||
# 需要在模块加载完成后调用 register_sti_columns_for_all_subclasses()
|
|
||||||
_sti_subclasses_to_register.append(cls)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __pydantic_init_subclass__(cls, **kwargs):
|
|
||||||
"""
|
|
||||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
|
||||||
|
|
||||||
修复 STI 继承中子类字段被 Column 对象污染的问题。
|
|
||||||
当 FunctionBase.show_arguments = True 等字段被继承到有表的子类时,
|
|
||||||
SQLModel 会创建一个 Column 对象替换原始默认值,导致实例化时字段值不正确。
|
|
||||||
"""
|
|
||||||
super().__pydantic_init_subclass__(**kwargs)
|
|
||||||
_fix_polluted_model_fields(cls)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _register_sti_columns(cls) -> None:
|
|
||||||
"""
|
|
||||||
将STI子类的新字段注册到父表的列定义中
|
|
||||||
|
|
||||||
检测当前类是否是STI子类(与父类共用同一张表),
|
|
||||||
如果是,则将子类定义的新字段添加到父表的metadata中。
|
|
||||||
|
|
||||||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
|
||||||
"""
|
|
||||||
# 查找父表(在 MRO 中找到第一个有 __table__ 的父类)
|
|
||||||
parent_table = None
|
|
||||||
parent_fields: set[str] = set()
|
|
||||||
|
|
||||||
for base in cls.__mro__[1:]:
|
|
||||||
if hasattr(base, '__table__') and base.__table__ is not None:
|
|
||||||
parent_table = base.__table__
|
|
||||||
# 收集父类的所有字段名
|
|
||||||
if hasattr(base, 'model_fields'):
|
|
||||||
parent_fields.update(base.model_fields.keys())
|
|
||||||
break
|
|
||||||
|
|
||||||
if parent_table is None:
|
|
||||||
return # 没有找到父表,可能是根类
|
|
||||||
|
|
||||||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
|
||||||
# JTI 类有自己独立的表,不需要将列注册到父表
|
|
||||||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
|
||||||
if cls.__table__.name != parent_table.name:
|
|
||||||
return # JTI,跳过 STI 列注册
|
|
||||||
|
|
||||||
# 获取当前类的新字段(不在父类中的字段)
|
|
||||||
if not hasattr(cls, 'model_fields'):
|
|
||||||
return
|
|
||||||
|
|
||||||
existing_columns = {col.name for col in parent_table.columns}
|
|
||||||
|
|
||||||
for field_name, field_info in cls.model_fields.items():
|
|
||||||
# 跳过从父类继承的字段
|
|
||||||
if field_name in parent_fields:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 跳过私有字段和ClassVar
|
|
||||||
if field_name.startswith('_'):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 跳过已存在的列
|
|
||||||
if field_name in existing_columns:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 使用 SQLModel 的内置 API 创建列
|
|
||||||
try:
|
|
||||||
column = get_column_from_field(field_info)
|
|
||||||
column.name = field_name
|
|
||||||
column.key = field_name
|
|
||||||
# STI子类字段在数据库层面必须可空,因为其他子类的行不会有这些字段的值
|
|
||||||
# Pydantic层面的约束仍然有效(创建特定子类时会验证必填字段)
|
|
||||||
column.nullable = True
|
|
||||||
|
|
||||||
# 将列添加到父表
|
|
||||||
parent_table.append_column(column)
|
|
||||||
except Exception as e:
|
|
||||||
l.warning(f"为 {cls.__name__} 创建列 {field_name} 失败: {e}")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _register_sti_column_properties(cls) -> None:
|
|
||||||
"""
|
|
||||||
将 STI 子类的列作为 ColumnProperty 添加到 mapper
|
|
||||||
|
|
||||||
此方法在 configure_mappers() 之后调用,将已添加到表中的列
|
|
||||||
注册为 mapper 的属性,使 ORM 查询能正确识别这些列。
|
|
||||||
|
|
||||||
**重要**:子类的列属性会同时注册到子类和父类的 mapper 上。
|
|
||||||
这确保了查询父类时,SELECT 语句包含所有 STI 子类的列,
|
|
||||||
避免在响应序列化时触发懒加载(MissingGreenlet 错误)。
|
|
||||||
|
|
||||||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
|
||||||
"""
|
|
||||||
# 查找父表和父类(在 MRO 中找到第一个有 __table__ 的父类)
|
|
||||||
parent_table = None
|
|
||||||
parent_class = None
|
|
||||||
for base in cls.__mro__[1:]:
|
|
||||||
if hasattr(base, '__table__') and base.__table__ is not None:
|
|
||||||
parent_table = base.__table__
|
|
||||||
parent_class = base
|
|
||||||
break
|
|
||||||
|
|
||||||
if parent_table is None:
|
|
||||||
return # 没有找到父表,可能是根类
|
|
||||||
|
|
||||||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
|
||||||
# JTI 类有自己独立的表,不需要将列属性注册到 mapper
|
|
||||||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
|
||||||
if cls.__table__.name != parent_table.name:
|
|
||||||
return # JTI,跳过 STI 列属性注册
|
|
||||||
|
|
||||||
# 获取子类和父类的 mapper
|
|
||||||
child_mapper = inspect(cls).mapper
|
|
||||||
parent_mapper = inspect(parent_class).mapper
|
|
||||||
local_table = child_mapper.local_table
|
|
||||||
|
|
||||||
# 查找父类的所有字段名
|
|
||||||
parent_fields: set[str] = set()
|
|
||||||
if hasattr(parent_class, 'model_fields'):
|
|
||||||
parent_fields.update(parent_class.model_fields.keys())
|
|
||||||
|
|
||||||
if not hasattr(cls, 'model_fields'):
|
|
||||||
return
|
|
||||||
|
|
||||||
# 获取两个 mapper 已有的列属性
|
|
||||||
child_existing_props = {p.key for p in child_mapper.column_attrs}
|
|
||||||
parent_existing_props = {p.key for p in parent_mapper.column_attrs}
|
|
||||||
|
|
||||||
for field_name in cls.model_fields:
|
|
||||||
# 跳过从父类继承的字段
|
|
||||||
if field_name in parent_fields:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 跳过私有字段
|
|
||||||
if field_name.startswith('_'):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查表中是否有这个列
|
|
||||||
if field_name not in local_table.columns:
|
|
||||||
continue
|
|
||||||
|
|
||||||
column = local_table.columns[field_name]
|
|
||||||
|
|
||||||
# 添加到子类的 mapper(如果尚不存在)
|
|
||||||
if field_name not in child_existing_props:
|
|
||||||
try:
|
|
||||||
prop = ColumnProperty(column)
|
|
||||||
child_mapper.add_property(field_name, prop)
|
|
||||||
except Exception as e:
|
|
||||||
l.warning(f"为 {cls.__name__} 添加列属性 {field_name} 失败: {e}")
|
|
||||||
|
|
||||||
# 同时添加到父类的 mapper(确保查询父类时 SELECT 包含所有 STI 子类的列)
|
|
||||||
if field_name not in parent_existing_props:
|
|
||||||
try:
|
|
||||||
prop = ColumnProperty(column)
|
|
||||||
parent_mapper.add_property(field_name, prop)
|
|
||||||
except Exception as e:
|
|
||||||
l.warning(f"为父类 {parent_class.__name__} 添加子类 {cls.__name__} 的列属性 {field_name} 失败: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class PolymorphicBaseMixin:
|
|
||||||
"""
|
|
||||||
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
|
|
||||||
|
|
||||||
此 Mixin 自动设置以下内容:
|
|
||||||
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
|
|
||||||
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
|
|
||||||
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
|
|
||||||
|
|
||||||
使用场景:
|
|
||||||
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
|
|
||||||
|
|
||||||
用法示例:
|
|
||||||
```python
|
|
||||||
from abc import ABC
|
|
||||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
|
||||||
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
|
|
||||||
|
|
||||||
# 定义基类
|
|
||||||
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
|
||||||
__tablename__ = 'mytool'
|
|
||||||
|
|
||||||
# 不需要手动定义 _polymorphic_name
|
|
||||||
# 不需要手动设置 polymorphic_on
|
|
||||||
# 不需要手动设置 polymorphic_abstract
|
|
||||||
|
|
||||||
# 定义子类
|
|
||||||
class SpecificTool(MyTool):
|
|
||||||
__tablename__ = 'specifictool'
|
|
||||||
|
|
||||||
# 会自动继承 polymorphic 配置
|
|
||||||
```
|
|
||||||
|
|
||||||
自动行为:
|
|
||||||
1. 定义 `_polymorphic_name: str` 字段(带索引)
|
|
||||||
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
|
|
||||||
3. 自动检测抽象类:
|
|
||||||
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
|
|
||||||
- 否则设置为 False
|
|
||||||
|
|
||||||
手动覆盖:
|
|
||||||
可以在类定义时手动指定参数来覆盖自动行为:
|
|
||||||
```python
|
|
||||||
class MyTool(
|
|
||||||
UUIDTableBaseMixin,
|
|
||||||
PolymorphicBaseMixin,
|
|
||||||
ABC,
|
|
||||||
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
|
|
||||||
polymorphic_abstract=False # 强制不设为抽象类
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
注意事项:
|
|
||||||
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
|
|
||||||
- 适用于联表继承(joined table inheritance)场景
|
|
||||||
- 子类会自动继承 _polymorphic_name 字段定义
|
|
||||||
- 使用单下划线前缀是因为:
|
|
||||||
* SQLAlchemy 会映射单下划线字段为数据库列
|
|
||||||
* Pydantic 将其视为私有属性,不参与序列化
|
|
||||||
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
|
|
||||||
#
|
|
||||||
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
|
|
||||||
#
|
|
||||||
# 为什么这样做:
|
|
||||||
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
|
|
||||||
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
|
|
||||||
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
|
|
||||||
# 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema)
|
|
||||||
# 5. 内部代码仍然可以正常访问和修改此字段
|
|
||||||
#
|
|
||||||
# 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
|
|
||||||
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
|
|
||||||
"""
|
|
||||||
多态鉴别器字段,用于标识具体的子类类型
|
|
||||||
|
|
||||||
注意:此字段使用单下划线前缀,表示内部使用。
|
|
||||||
- ✅ 存储到数据库
|
|
||||||
- ✅ 不出现在 API 序列化中
|
|
||||||
- ✅ 防止外部直接修改
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init_subclass__(
|
|
||||||
cls,
|
|
||||||
polymorphic_on: str | None = None,
|
|
||||||
polymorphic_abstract: bool | None = None,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
在子类定义时自动配置 polymorphic 设置
|
|
||||||
|
|
||||||
Args:
|
|
||||||
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。
|
|
||||||
设置为其他值可以使用不同的字段作为多态鉴别器。
|
|
||||||
polymorphic_abstract: 是否为抽象类。
|
|
||||||
- None: 自动检测(默认)
|
|
||||||
- True: 强制设为抽象类
|
|
||||||
- False: 强制设为非抽象类
|
|
||||||
**kwargs: 传递给父类的其他参数
|
|
||||||
"""
|
|
||||||
super().__init_subclass__(**kwargs)
|
|
||||||
|
|
||||||
# 初始化 __mapper_args__(如果还没有)
|
|
||||||
if '__mapper_args__' not in cls.__dict__:
|
|
||||||
cls.__mapper_args__ = {}
|
|
||||||
|
|
||||||
# 设置 polymorphic_on(默认为 _polymorphic_name)
|
|
||||||
if 'polymorphic_on' not in cls.__mapper_args__:
|
|
||||||
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
|
|
||||||
|
|
||||||
# 自动检测或设置 polymorphic_abstract
|
|
||||||
if 'polymorphic_abstract' not in cls.__mapper_args__:
|
|
||||||
if polymorphic_abstract is None:
|
|
||||||
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
|
|
||||||
has_abc = ABC in cls.__mro__
|
|
||||||
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
|
|
||||||
polymorphic_abstract = has_abc and has_abstract_methods
|
|
||||||
|
|
||||||
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _is_joined_table_inheritance(cls) -> bool:
|
|
||||||
"""
|
|
||||||
检测当前类是否使用联表继承(Joined Table Inheritance)
|
|
||||||
|
|
||||||
通过检查子类是否有独立的表来判断:
|
|
||||||
- JTI: 子类有独立的 local_table(与父类不同)
|
|
||||||
- STI: 子类与父类共用同一个 local_table
|
|
||||||
|
|
||||||
:return: True 表示 JTI,False 表示 STI 或无子类
|
|
||||||
"""
|
|
||||||
mapper = inspect(cls)
|
|
||||||
base_table_name = mapper.local_table.name
|
|
||||||
|
|
||||||
# 检查所有直接子类
|
|
||||||
for subclass in cls.__subclasses__():
|
|
||||||
sub_mapper = inspect(subclass)
|
|
||||||
# 如果任何子类有不同的表名,说明是 JTI
|
|
||||||
if sub_mapper.local_table.name != base_table_name:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
|
|
||||||
"""
|
|
||||||
递归获取当前类的所有具体(非抽象)子类
|
|
||||||
|
|
||||||
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
|
|
||||||
可在任意多态基类上调用,返回该类的所有非抽象子类。
|
|
||||||
|
|
||||||
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
|
|
||||||
"""
|
|
||||||
result: list[type[PolymorphicBaseMixin]] = []
|
|
||||||
for subclass in cls.__subclasses__():
|
|
||||||
# 使用 inspect() 获取 mapper 的公开属性
|
|
||||||
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
|
|
||||||
mapper = inspect(subclass)
|
|
||||||
if not mapper.polymorphic_abstract:
|
|
||||||
result.append(subclass)
|
|
||||||
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
|
|
||||||
if hasattr(subclass, 'get_concrete_subclasses'):
|
|
||||||
result.extend(subclass.get_concrete_subclasses())
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_polymorphic_discriminator(cls) -> str:
|
|
||||||
"""
|
|
||||||
获取多态鉴别字段名
|
|
||||||
|
|
||||||
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
|
|
||||||
|
|
||||||
:return: 多态鉴别字段名(如 '_polymorphic_name')
|
|
||||||
:raises ValueError: 如果类未配置 polymorphic_on
|
|
||||||
"""
|
|
||||||
polymorphic_on = inspect(cls).polymorphic_on
|
|
||||||
if polymorphic_on is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"{cls.__name__} 未配置 polymorphic_on,"
|
|
||||||
f"请确保正确继承 PolymorphicBaseMixin"
|
|
||||||
)
|
|
||||||
return polymorphic_on.key
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
|
|
||||||
"""
|
|
||||||
获取 polymorphic_identity 到具体子类的映射
|
|
||||||
|
|
||||||
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
|
|
||||||
|
|
||||||
:return: identity 到子类的映射字典
|
|
||||||
"""
|
|
||||||
result: dict[str, type[PolymorphicBaseMixin]] = {}
|
|
||||||
for subclass in cls.get_concrete_subclasses():
|
|
||||||
identity = inspect(subclass).polymorphic_identity
|
|
||||||
if identity:
|
|
||||||
result[identity] = subclass
|
|
||||||
return result
|
|
||||||
@@ -1,470 +0,0 @@
|
|||||||
"""
|
|
||||||
关系预加载 Mixin
|
|
||||||
|
|
||||||
提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。
|
|
||||||
|
|
||||||
设计原则:
|
|
||||||
- 按需加载:只加载被调用方法需要的关系
|
|
||||||
- 增量加载:已加载的关系不重复加载
|
|
||||||
- 查询最优:相同关系只查询一次,不同关系增量查询
|
|
||||||
- 零侵入:调用方无需任何改动
|
|
||||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire
|
|
||||||
|
|
||||||
使用方式:
|
|
||||||
from sqlmodels.mixin import RelationPreloadMixin, requires_relations
|
|
||||||
|
|
||||||
class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True):
|
|
||||||
kling_video_generator: KlingO1Generator = Relationship(...)
|
|
||||||
|
|
||||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
|
||||||
async def cost(self, params, context, session) -> ToolCost:
|
|
||||||
# 自动加载,可以安全访问
|
|
||||||
price = self.kling_video_generator.kling_o1.pro_price_per_second
|
|
||||||
...
|
|
||||||
|
|
||||||
# 调用方 - 无需任何改动
|
|
||||||
await tool.cost(params, context, session) # 自动加载 cost 需要的关系
|
|
||||||
await tool._call(...) # 关系相同则跳过,否则增量加载
|
|
||||||
|
|
||||||
支持 AsyncGenerator:
|
|
||||||
@requires_relations('twitter_api')
|
|
||||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
|
||||||
yield ToolResponse(...) # 装饰器正确处理 async generator
|
|
||||||
"""
|
|
||||||
import inspect as python_inspect
|
|
||||||
from functools import wraps
|
|
||||||
from typing import Callable, TypeVar, ParamSpec, Any
|
|
||||||
|
|
||||||
from loguru import logger as l
|
|
||||||
from sqlalchemy import inspect as sa_inspect
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
from sqlmodel.main import RelationshipInfo
|
|
||||||
|
|
||||||
P = ParamSpec('P')
|
|
||||||
R = TypeVar('R')
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_session(
|
|
||||||
func: Callable,
|
|
||||||
args: tuple[Any, ...],
|
|
||||||
kwargs: dict[str, Any],
|
|
||||||
) -> AsyncSession | None:
|
|
||||||
"""
|
|
||||||
从方法参数中提取 AsyncSession
|
|
||||||
|
|
||||||
按以下顺序查找:
|
|
||||||
1. kwargs 中名为 'session' 的参数
|
|
||||||
2. 根据函数签名定位 'session' 参数的位置,从 args 提取
|
|
||||||
3. kwargs 中类型为 AsyncSession 的参数
|
|
||||||
"""
|
|
||||||
# 1. 优先从 kwargs 查找
|
|
||||||
if 'session' in kwargs:
|
|
||||||
return kwargs['session']
|
|
||||||
|
|
||||||
# 2. 从函数签名定位位置参数
|
|
||||||
try:
|
|
||||||
sig = python_inspect.signature(func)
|
|
||||||
param_names = list(sig.parameters.keys())
|
|
||||||
|
|
||||||
if 'session' in param_names:
|
|
||||||
# 计算位置(减去 self)
|
|
||||||
idx = param_names.index('session') - 1
|
|
||||||
if 0 <= idx < len(args):
|
|
||||||
return args[idx]
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 3. 遍历 kwargs 找 AsyncSession 类型
|
|
||||||
for value in kwargs.values():
|
|
||||||
if isinstance(value, AsyncSession):
|
|
||||||
return value
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool:
|
|
||||||
"""
|
|
||||||
检查对象的关系是否已加载(独立函数版本)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obj: 要检查的对象
|
|
||||||
rel_name: 关系属性名
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True 如果关系已加载,False 如果未加载或已过期
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
state = sa_inspect(obj)
|
|
||||||
return rel_name not in state.unloaded
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _find_relation_to_class(from_class: type, to_class: type) -> str | None:
|
|
||||||
"""
|
|
||||||
在类中查找指向目标类的关系属性名
|
|
||||||
|
|
||||||
Args:
|
|
||||||
from_class: 源类
|
|
||||||
to_class: 目标类
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
关系属性名,如果找不到则返回 None
|
|
||||||
|
|
||||||
Example:
|
|
||||||
_find_relation_to_class(KlingO1VideoFunction, KlingO1Generator)
|
|
||||||
# 返回 'kling_video_generator'
|
|
||||||
"""
|
|
||||||
for attr_name in dir(from_class):
|
|
||||||
try:
|
|
||||||
attr = getattr(from_class, attr_name, None)
|
|
||||||
if attr is None:
|
|
||||||
continue
|
|
||||||
# 检查是否是 SQLAlchemy InstrumentedAttribute(关系属性)
|
|
||||||
# parent.class_ 是关系所在的类,property.mapper.class_ 是关系指向的目标类
|
|
||||||
if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'):
|
|
||||||
target_class = attr.property.mapper.class_
|
|
||||||
if target_class == to_class:
|
|
||||||
return attr_name
|
|
||||||
except AttributeError:
|
|
||||||
continue
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
|
||||||
"""
|
|
||||||
装饰器:声明方法需要的关系,自动按需增量加载
|
|
||||||
|
|
||||||
参数格式:
|
|
||||||
- 字符串:本类属性名,如 'kling_video_generator'
|
|
||||||
- RelationshipInfo:外部类属性,如 KlingO1Generator.kling_o1
|
|
||||||
|
|
||||||
行为:
|
|
||||||
- 方法调用时自动检查关系是否已加载
|
|
||||||
- 未加载的关系会被增量加载(单次查询)
|
|
||||||
- 已加载的关系直接跳过
|
|
||||||
|
|
||||||
支持:
|
|
||||||
- 普通 async 方法:`async def cost(...) -> ToolCost`
|
|
||||||
- AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]`
|
|
||||||
|
|
||||||
Example:
|
|
||||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
|
||||||
async def cost(self, params, context, session) -> ToolCost:
|
|
||||||
# self.kling_video_generator.kling_o1 已自动加载
|
|
||||||
...
|
|
||||||
|
|
||||||
@requires_relations('twitter_api')
|
|
||||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
|
||||||
yield ToolResponse(...) # AsyncGenerator 正确处理
|
|
||||||
|
|
||||||
验证:
|
|
||||||
- 字符串格式的关系名在类创建时(__init_subclass__)验证
|
|
||||||
- 拼写错误会在导入时抛出 AttributeError
|
|
||||||
"""
|
|
||||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
|
||||||
# 检测是否是 async generator 函数
|
|
||||||
is_async_gen = python_inspect.isasyncgenfunction(func)
|
|
||||||
|
|
||||||
if is_async_gen:
|
|
||||||
# AsyncGenerator 需要特殊处理:wrapper 也必须是 async generator
|
|
||||||
@wraps(func)
|
|
||||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
|
||||||
session = _extract_session(func, args, kwargs)
|
|
||||||
if session is not None:
|
|
||||||
await self._ensure_relations_loaded(session, relations)
|
|
||||||
# 委托给原始 async generator,逐个 yield 值
|
|
||||||
async for item in func(self, *args, **kwargs):
|
|
||||||
yield item # type: ignore
|
|
||||||
else:
|
|
||||||
# 普通 async 函数:await 并返回结果
|
|
||||||
@wraps(func)
|
|
||||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
|
||||||
session = _extract_session(func, args, kwargs)
|
|
||||||
if session is not None:
|
|
||||||
await self._ensure_relations_loaded(session, relations)
|
|
||||||
return await func(self, *args, **kwargs)
|
|
||||||
|
|
||||||
# 保存关系声明供验证和内省使用
|
|
||||||
wrapper._required_relations = relations # type: ignore
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
class RelationPreloadMixin:
|
|
||||||
"""
|
|
||||||
关系预加载 Mixin
|
|
||||||
|
|
||||||
提供按需增量加载能力,确保 SQL 查询数理论最优。
|
|
||||||
|
|
||||||
特性:
|
|
||||||
- 按需加载:只加载被调用方法需要的关系
|
|
||||||
- 增量加载:已加载的关系不重复加载
|
|
||||||
- 原地更新:直接修改 self,无需替换实例
|
|
||||||
- 导入时验证:字符串关系名在类创建时验证
|
|
||||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs) -> None:
|
|
||||||
"""类创建时验证所有 @requires_relations 声明"""
|
|
||||||
super().__init_subclass__(**kwargs)
|
|
||||||
|
|
||||||
# 收集类及其父类的所有注解(包含普通字段)
|
|
||||||
all_annotations: set[str] = set()
|
|
||||||
for klass in cls.__mro__:
|
|
||||||
if hasattr(klass, '__annotations__'):
|
|
||||||
all_annotations.update(klass.__annotations__.keys())
|
|
||||||
|
|
||||||
# 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__)
|
|
||||||
sqlmodel_relationships: set[str] = set()
|
|
||||||
for klass in cls.__mro__:
|
|
||||||
if hasattr(klass, '__sqlmodel_relationships__'):
|
|
||||||
sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys())
|
|
||||||
|
|
||||||
# 合并所有可用的属性名
|
|
||||||
all_available_names = all_annotations | sqlmodel_relationships
|
|
||||||
|
|
||||||
for method_name in dir(cls):
|
|
||||||
if method_name.startswith('__'):
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
method = getattr(cls, method_name, None)
|
|
||||||
except AttributeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if method is None or not hasattr(method, '_required_relations'):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 验证字符串格式的关系名
|
|
||||||
for spec in method._required_relations:
|
|
||||||
if isinstance(spec, str):
|
|
||||||
# 检查注解、Relationship 或已有属性
|
|
||||||
if spec not in all_available_names and not hasattr(cls, spec):
|
|
||||||
raise AttributeError(
|
|
||||||
f"{cls.__name__}.{method_name} 声明了关系 '{spec}',"
|
|
||||||
f"但 {cls.__name__} 没有此属性"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _is_relation_loaded(self, rel_name: str) -> bool:
|
|
||||||
"""
|
|
||||||
检查关系是否真正已加载(基于 SQLAlchemy inspect)
|
|
||||||
|
|
||||||
使用 SQLAlchemy 的 inspect 检测真实加载状态,
|
|
||||||
自动处理 commit 导致的 expire 问题。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rel_name: 关系属性名
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True 如果关系已加载,False 如果未加载或已过期
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
state = sa_inspect(self)
|
|
||||||
# unloaded 包含未加载的关系属性名
|
|
||||||
return rel_name not in state.unloaded
|
|
||||||
except Exception:
|
|
||||||
# 对象可能未被 SQLAlchemy 管理
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _ensure_relations_loaded(
|
|
||||||
self,
|
|
||||||
session: AsyncSession,
|
|
||||||
relations: tuple[str | RelationshipInfo, ...],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
确保指定关系已加载,只加载未加载的部分
|
|
||||||
|
|
||||||
基于 SQLAlchemy inspect 检测真实状态,自动处理:
|
|
||||||
- 首次访问的关系
|
|
||||||
- commit 后 expire 的关系
|
|
||||||
- 嵌套关系(如 KlingO1Generator.kling_o1)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: 数据库会话
|
|
||||||
relations: 需要的关系规格
|
|
||||||
"""
|
|
||||||
# 找出真正未加载的关系(基于 SQLAlchemy inspect)
|
|
||||||
to_load: list[str | RelationshipInfo] = []
|
|
||||||
# 区分直接关系和嵌套关系的 key
|
|
||||||
direct_keys: set[str] = set() # 本类的直接关系属性名
|
|
||||||
nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名
|
|
||||||
|
|
||||||
for rel in relations:
|
|
||||||
if isinstance(rel, str):
|
|
||||||
# 直接关系:检查本类的关系是否已加载
|
|
||||||
if not self._is_relation_loaded(rel):
|
|
||||||
to_load.append(rel)
|
|
||||||
direct_keys.add(rel)
|
|
||||||
else:
|
|
||||||
# 嵌套关系(InstrumentedAttribute):如 KlingO1Generator.kling_o1
|
|
||||||
# 1. 查找指向父类的关系属性
|
|
||||||
parent_class = rel.parent.class_
|
|
||||||
parent_attr = _find_relation_to_class(self.__class__, parent_class)
|
|
||||||
|
|
||||||
if parent_attr is None:
|
|
||||||
# 找不到路径,可能是配置错误,但仍尝试加载
|
|
||||||
l.warning(
|
|
||||||
f"无法找到从 {self.__class__.__name__} 到 {parent_class.__name__} 的关系路径,"
|
|
||||||
f"无法检查 {rel.key} 是否已加载"
|
|
||||||
)
|
|
||||||
to_load.append(rel)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 2. 检查父对象是否已加载
|
|
||||||
if not self._is_relation_loaded(parent_attr):
|
|
||||||
# 父对象未加载,需要同时加载父对象和嵌套关系
|
|
||||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
|
||||||
to_load.append(parent_attr)
|
|
||||||
nested_parent_keys.add(parent_attr)
|
|
||||||
to_load.append(rel)
|
|
||||||
else:
|
|
||||||
# 3. 父对象已加载,检查嵌套关系是否已加载
|
|
||||||
parent_obj = getattr(self, parent_attr)
|
|
||||||
if not _is_obj_relation_loaded(parent_obj, rel.key):
|
|
||||||
# 嵌套关系未加载:需要同时传递父关系和嵌套关系
|
|
||||||
# 因为 _build_load_chains 需要完整的链来构建 selectinload
|
|
||||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
|
||||||
to_load.append(parent_attr)
|
|
||||||
nested_parent_keys.add(parent_attr)
|
|
||||||
to_load.append(rel)
|
|
||||||
|
|
||||||
if not to_load:
|
|
||||||
return # 全部已加载,跳过
|
|
||||||
|
|
||||||
# 构建 load 参数
|
|
||||||
load_options = self._specs_to_load_options(to_load)
|
|
||||||
if not load_options:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 安全地获取主键值(避免触发懒加载)
|
|
||||||
state = sa_inspect(self)
|
|
||||||
pk_tuple = state.key[1] if state.key else None
|
|
||||||
if pk_tuple is None:
|
|
||||||
l.warning(f"无法获取 {self.__class__.__name__} 的主键值")
|
|
||||||
return
|
|
||||||
# 主键是元组,取第一个值(假设单列主键)
|
|
||||||
pk_value = pk_tuple[0]
|
|
||||||
|
|
||||||
# 单次查询加载缺失的关系
|
|
||||||
fresh = await self.__class__.get(
|
|
||||||
session,
|
|
||||||
self.__class__.id == pk_value,
|
|
||||||
load=load_options,
|
|
||||||
)
|
|
||||||
|
|
||||||
if fresh is None:
|
|
||||||
l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 原地复制到 self(只复制直接关系,嵌套关系通过父关系自动可访问)
|
|
||||||
all_direct_keys = direct_keys | nested_parent_keys
|
|
||||||
for key in all_direct_keys:
|
|
||||||
value = getattr(fresh, key, None)
|
|
||||||
object.__setattr__(self, key, value)
|
|
||||||
|
|
||||||
def _specs_to_load_options(
|
|
||||||
self,
|
|
||||||
specs: list[str | RelationshipInfo],
|
|
||||||
) -> list[RelationshipInfo]:
|
|
||||||
"""
|
|
||||||
将关系规格转换为 load 参数
|
|
||||||
|
|
||||||
- 字符串 → cls.{name}
|
|
||||||
- RelationshipInfo → 直接使用
|
|
||||||
"""
|
|
||||||
result: list[RelationshipInfo] = []
|
|
||||||
|
|
||||||
for spec in specs:
|
|
||||||
if isinstance(spec, str):
|
|
||||||
rel = getattr(self.__class__, spec, None)
|
|
||||||
if rel is not None:
|
|
||||||
result.append(rel)
|
|
||||||
else:
|
|
||||||
l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在")
|
|
||||||
else:
|
|
||||||
result.append(spec)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
# ==================== 可选的手动预加载 API ====================
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]:
|
|
||||||
"""
|
|
||||||
获取指定方法声明的关系(用于外部预加载场景)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
method_name: 方法名
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
RelationshipInfo 列表
|
|
||||||
"""
|
|
||||||
method = getattr(cls, method_name, None)
|
|
||||||
if method is None or not hasattr(method, '_required_relations'):
|
|
||||||
return []
|
|
||||||
|
|
||||||
result: list[RelationshipInfo] = []
|
|
||||||
for spec in method._required_relations:
|
|
||||||
if isinstance(spec, str):
|
|
||||||
rel = getattr(cls, spec, None)
|
|
||||||
if rel:
|
|
||||||
result.append(rel)
|
|
||||||
else:
|
|
||||||
result.append(spec)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]:
|
|
||||||
"""
|
|
||||||
获取多个方法的关系并去重(用于批量预加载场景)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
method_names: 方法名列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
去重后的 RelationshipInfo 列表
|
|
||||||
"""
|
|
||||||
seen: set[str] = set()
|
|
||||||
result: list[RelationshipInfo] = []
|
|
||||||
|
|
||||||
for method_name in method_names:
|
|
||||||
for rel in cls.get_relations_for_method(method_name):
|
|
||||||
key = rel.key
|
|
||||||
if key not in seen:
|
|
||||||
seen.add(key)
|
|
||||||
result.append(rel)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin':
|
|
||||||
"""
|
|
||||||
手动预加载指定方法的关系(可选优化 API)
|
|
||||||
|
|
||||||
当需要确保在调用方法前完成所有加载时使用。
|
|
||||||
通常情况下不需要调用此方法,装饰器会自动处理。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: 数据库会话
|
|
||||||
method_names: 方法名列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
self(支持链式调用)
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# 可选:显式预加载(通常不需要)
|
|
||||||
tool = await tool.preload_for(session, 'cost', '_call')
|
|
||||||
"""
|
|
||||||
all_relations: list[str | RelationshipInfo] = []
|
|
||||||
|
|
||||||
for method_name in method_names:
|
|
||||||
method = getattr(self.__class__, method_name, None)
|
|
||||||
if method and hasattr(method, '_required_relations'):
|
|
||||||
all_relations.extend(method._required_relations)
|
|
||||||
|
|
||||||
if all_relations:
|
|
||||||
await self._ensure_relations_loaded(session, tuple(all_relations))
|
|
||||||
|
|
||||||
return self
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@ from enum import StrEnum
|
|||||||
|
|
||||||
from sqlmodel import Field
|
from sqlmodel import Field
|
||||||
|
|
||||||
from .base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase
|
||||||
|
|
||||||
|
|
||||||
class ResponseBase(SQLModelBase):
|
class ResponseBase(SQLModelBase):
|
||||||
|
|||||||
@@ -3,8 +3,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from sqlmodel import Field, Relationship, text, Index
|
from sqlmodel import Field, Relationship, text, Index
|
||||||
|
|
||||||
from .base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
||||||
from .mixin import TableBaseMixin
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .download import Download
|
from .download import Download
|
||||||
@@ -29,13 +28,13 @@ class NodeType(StrEnum):
|
|||||||
class Aria2ConfigurationBase(SQLModelBase):
|
class Aria2ConfigurationBase(SQLModelBase):
|
||||||
"""Aria2配置基础模型"""
|
"""Aria2配置基础模型"""
|
||||||
|
|
||||||
rpc_url: str | None = Field(default=None, max_length=255)
|
rpc_url: Str255 | None = None
|
||||||
"""RPC地址"""
|
"""RPC地址"""
|
||||||
|
|
||||||
rpc_secret: str | None = None
|
rpc_secret: str | None = None
|
||||||
"""RPC密钥"""
|
"""RPC密钥"""
|
||||||
|
|
||||||
temp_path: str | None = Field(default=None, max_length=255)
|
temp_path: Str255 | None = None
|
||||||
"""临时下载路径"""
|
"""临时下载路径"""
|
||||||
|
|
||||||
max_concurrent: int = Field(default=5, ge=1, le=50)
|
max_concurrent: int = Field(default=5, ge=1, le=50)
|
||||||
@@ -71,19 +70,19 @@ class Node(SQLModelBase, TableBaseMixin):
|
|||||||
status: NodeStatus = Field(default=NodeStatus.ONLINE)
|
status: NodeStatus = Field(default=NodeStatus.ONLINE)
|
||||||
"""节点状态"""
|
"""节点状态"""
|
||||||
|
|
||||||
name: str = Field(max_length=255, unique=True)
|
name: Str255 = Field(unique=True)
|
||||||
"""节点名称"""
|
"""节点名称"""
|
||||||
|
|
||||||
type: NodeType
|
type: NodeType
|
||||||
"""节点类型"""
|
"""节点类型"""
|
||||||
|
|
||||||
server: str = Field(max_length=255)
|
server: Str255
|
||||||
"""节点地址(IP或域名)"""
|
"""节点地址(IP或域名)"""
|
||||||
|
|
||||||
slave_key: str | None = Field(default=None, max_length=255)
|
slave_key: Str255 | None = None
|
||||||
"""从机通讯密钥"""
|
"""从机通讯密钥"""
|
||||||
|
|
||||||
master_key: str | None = Field(default=None, max_length=255)
|
master_key: Str255 | None = None
|
||||||
"""主机通讯密钥"""
|
"""主机通讯密钥"""
|
||||||
|
|
||||||
aria2_enabled: bool = False
|
aria2_enabled: bool = False
|
||||||
|
|||||||
@@ -5,10 +5,11 @@ from uuid import UUID
|
|||||||
|
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from sqlalchemy import BigInteger
|
from sqlalchemy import BigInteger
|
||||||
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index, text
|
from sqlmodel import Field, Relationship, CheckConstraint, Index, text
|
||||||
|
|
||||||
from .base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255, Str256
|
||||||
from .mixin import UUIDTableBaseMixin
|
|
||||||
|
from .policy import PolicyType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
@@ -17,6 +18,7 @@ if TYPE_CHECKING:
|
|||||||
from .share import Share
|
from .share import Share
|
||||||
from .physical_file import PhysicalFile
|
from .physical_file import PhysicalFile
|
||||||
from .uri import DiskNextURI
|
from .uri import DiskNextURI
|
||||||
|
from .object_metadata import ObjectMetadata
|
||||||
|
|
||||||
|
|
||||||
class ObjectType(StrEnum):
|
class ObjectType(StrEnum):
|
||||||
@@ -24,42 +26,13 @@ class ObjectType(StrEnum):
|
|||||||
FILE = "file"
|
FILE = "file"
|
||||||
FOLDER = "folder"
|
FOLDER = "folder"
|
||||||
|
|
||||||
class StorageType(StrEnum):
|
|
||||||
"""存储类型枚举"""
|
|
||||||
LOCAL = "local"
|
|
||||||
QINIU = "qiniu"
|
|
||||||
TENCENT = "tencent"
|
|
||||||
ALIYUN = "aliyun"
|
|
||||||
ONEDRIVE = "onedrive"
|
|
||||||
GOOGLE_DRIVE = "google_drive"
|
|
||||||
DROPBOX = "dropbox"
|
|
||||||
WEBDAV = "webdav"
|
|
||||||
REMOTE = "remote"
|
|
||||||
|
|
||||||
|
class FileCategory(StrEnum):
|
||||||
class FileMetadataBase(SQLModelBase):
|
"""文件类型分类枚举,用于按类别筛选文件"""
|
||||||
"""文件元数据基础模型"""
|
IMAGE = "image"
|
||||||
|
VIDEO = "video"
|
||||||
width: int | None = Field(default=None)
|
AUDIO = "audio"
|
||||||
"""图片宽度(像素)"""
|
DOCUMENT = "document"
|
||||||
|
|
||||||
height: int | None = Field(default=None)
|
|
||||||
"""图片高度(像素)"""
|
|
||||||
|
|
||||||
duration: float | None = Field(default=None)
|
|
||||||
"""音视频时长(秒)"""
|
|
||||||
|
|
||||||
bitrate: int | None = Field(default=None)
|
|
||||||
"""比特率(kbps)"""
|
|
||||||
|
|
||||||
mime_type: str | None = Field(default=None, max_length=127)
|
|
||||||
"""MIME类型"""
|
|
||||||
|
|
||||||
checksum_md5: str | None = Field(default=None, max_length=32)
|
|
||||||
"""MD5校验和"""
|
|
||||||
|
|
||||||
checksum_sha256: str | None = Field(default=None, max_length=64)
|
|
||||||
"""SHA256校验和"""
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== Base 模型 ====================
|
# ==================== Base 模型 ====================
|
||||||
@@ -76,9 +49,32 @@ class ObjectBase(SQLModelBase):
|
|||||||
size: int | None = None
|
size: int | None = None
|
||||||
"""文件大小(字节),目录为 None"""
|
"""文件大小(字节),目录为 None"""
|
||||||
|
|
||||||
|
mime_type: str | None = Field(default=None, max_length=127)
|
||||||
|
"""MIME类型(仅文件有效)"""
|
||||||
|
|
||||||
|
|
||||||
# ==================== DTO 模型 ====================
|
# ==================== DTO 模型 ====================
|
||||||
|
|
||||||
|
class ObjectFileFinalize(SQLModelBase):
|
||||||
|
"""文件上传完成后更新 Object 的 DTO"""
|
||||||
|
|
||||||
|
size: int
|
||||||
|
"""文件大小(字节)"""
|
||||||
|
|
||||||
|
physical_file_id: UUID
|
||||||
|
"""关联的物理文件UUID"""
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectMoveUpdate(SQLModelBase):
|
||||||
|
"""移动/重命名 Object 的 DTO"""
|
||||||
|
|
||||||
|
parent_id: UUID
|
||||||
|
"""新的父目录UUID"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""新名称"""
|
||||||
|
|
||||||
|
|
||||||
class DirectoryCreateRequest(SQLModelBase):
|
class DirectoryCreateRequest(SQLModelBase):
|
||||||
"""创建目录请求 DTO"""
|
"""创建目录请求 DTO"""
|
||||||
|
|
||||||
@@ -137,7 +133,7 @@ class PolicyResponse(SQLModelBase):
|
|||||||
name: str
|
name: str
|
||||||
"""策略名称"""
|
"""策略名称"""
|
||||||
|
|
||||||
type: StorageType
|
type: PolicyType
|
||||||
"""存储类型"""
|
"""存储类型"""
|
||||||
|
|
||||||
max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
|
max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
|
||||||
@@ -165,22 +161,6 @@ class DirectoryResponse(SQLModelBase):
|
|||||||
|
|
||||||
# ==================== 数据库模型 ====================
|
# ==================== 数据库模型 ====================
|
||||||
|
|
||||||
class FileMetadata(FileMetadataBase, UUIDTableBaseMixin):
|
|
||||||
"""文件元数据模型(与Object一对一关联)"""
|
|
||||||
|
|
||||||
object_id: UUID = Field(
|
|
||||||
foreign_key="object.id",
|
|
||||||
unique=True,
|
|
||||||
index=True,
|
|
||||||
ondelete="CASCADE"
|
|
||||||
)
|
|
||||||
"""关联的对象UUID"""
|
|
||||||
|
|
||||||
# 反向关系
|
|
||||||
object: "Object" = Relationship(back_populates="file_metadata")
|
|
||||||
"""关联的对象"""
|
|
||||||
|
|
||||||
|
|
||||||
class Object(ObjectBase, UUIDTableBaseMixin):
|
class Object(ObjectBase, UUIDTableBaseMixin):
|
||||||
"""
|
"""
|
||||||
统一对象模型
|
统一对象模型
|
||||||
@@ -195,8 +175,13 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
# 同一父目录下名称唯一(包括 parent_id 为 NULL 的情况)
|
# 同一父目录下名称唯一(仅对未删除记录生效)
|
||||||
UniqueConstraint("owner_id", "parent_id", "name", name="uq_object_parent_name"),
|
Index(
|
||||||
|
"uq_object_parent_name_active",
|
||||||
|
"owner_id", "parent_id", "name",
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=text("deleted_at IS NULL"),
|
||||||
|
),
|
||||||
# 名称不能包含斜杠(根目录 parent_id IS NULL 除外,因为根目录 name="/")
|
# 名称不能包含斜杠(根目录 parent_id IS NULL 除外,因为根目录 name="/")
|
||||||
CheckConstraint(
|
CheckConstraint(
|
||||||
"parent_id IS NULL OR (name NOT LIKE '%/%' AND name NOT LIKE '%\\%')",
|
"parent_id IS NULL OR (name NOT LIKE '%/%' AND name NOT LIKE '%\\%')",
|
||||||
@@ -207,17 +192,19 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
Index("ix_object_parent_updated", "parent_id", "updated_at"),
|
Index("ix_object_parent_updated", "parent_id", "updated_at"),
|
||||||
Index("ix_object_owner_type", "owner_id", "type"),
|
Index("ix_object_owner_type", "owner_id", "type"),
|
||||||
Index("ix_object_owner_size", "owner_id", "size"),
|
Index("ix_object_owner_size", "owner_id", "size"),
|
||||||
|
# 回收站查询索引
|
||||||
|
Index("ix_object_owner_deleted", "owner_id", "deleted_at"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# ==================== 基础字段 ====================
|
# ==================== 基础字段 ====================
|
||||||
|
|
||||||
name: str = Field(max_length=255)
|
name: Str255
|
||||||
"""对象名称(文件名或目录名)"""
|
"""对象名称(文件名或目录名)"""
|
||||||
|
|
||||||
type: ObjectType
|
type: ObjectType
|
||||||
"""对象类型:file 或 folder"""
|
"""对象类型:file 或 folder"""
|
||||||
|
|
||||||
password: str | None = Field(default=None, max_length=255)
|
password: Str255 | None = None
|
||||||
"""对象独立密码(仅当用户为对象单独设置密码时有效)"""
|
"""对象独立密码(仅当用户为对象单独设置密码时有效)"""
|
||||||
|
|
||||||
# ==================== 文件专属字段 ====================
|
# ==================== 文件专属字段 ====================
|
||||||
@@ -225,7 +212,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
|
size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
|
||||||
"""文件大小(字节),目录为 0"""
|
"""文件大小(字节),目录为 0"""
|
||||||
|
|
||||||
upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
|
upload_session_id: Str255 | None = Field(default=None, unique=True, index=True)
|
||||||
"""分块上传会话ID(仅文件有效)"""
|
"""分块上传会话ID(仅文件有效)"""
|
||||||
|
|
||||||
physical_file_id: UUID | None = Field(
|
physical_file_id: UUID | None = Field(
|
||||||
@@ -280,6 +267,18 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
ban_reason: str | None = Field(default=None, max_length=500)
|
ban_reason: str | None = Field(default=None, max_length=500)
|
||||||
"""封禁原因"""
|
"""封禁原因"""
|
||||||
|
|
||||||
|
# ==================== 软删除相关字段 ====================
|
||||||
|
|
||||||
|
deleted_at: datetime | None = Field(default=None, index=True)
|
||||||
|
"""软删除时间戳,NULL 表示未删除"""
|
||||||
|
|
||||||
|
deleted_original_parent_id: UUID | None = Field(
|
||||||
|
default=None,
|
||||||
|
foreign_key="object.id",
|
||||||
|
ondelete="SET NULL",
|
||||||
|
)
|
||||||
|
"""软删除前的原始父目录UUID(恢复时用于还原位置)"""
|
||||||
|
|
||||||
# ==================== 关系 ====================
|
# ==================== 关系 ====================
|
||||||
|
|
||||||
owner: "User" = Relationship(
|
owner: "User" = Relationship(
|
||||||
@@ -299,22 +298,28 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
# 自引用关系
|
# 自引用关系
|
||||||
parent: "Object" = Relationship(
|
parent: "Object" = Relationship(
|
||||||
back_populates="children",
|
back_populates="children",
|
||||||
sa_relationship_kwargs={"remote_side": "Object.id"},
|
sa_relationship_kwargs={
|
||||||
|
"remote_side": "Object.id",
|
||||||
|
"foreign_keys": "[Object.parent_id]",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
"""父目录"""
|
"""父目录"""
|
||||||
|
|
||||||
children: list["Object"] = Relationship(
|
children: list["Object"] = Relationship(
|
||||||
back_populates="parent",
|
back_populates="parent",
|
||||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
sa_relationship_kwargs={
|
||||||
|
"cascade": "all, delete-orphan",
|
||||||
|
"foreign_keys": "[Object.parent_id]",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
"""子对象(文件和子目录)"""
|
"""子对象(文件和子目录)"""
|
||||||
|
|
||||||
# 仅文件有效的关系
|
# 仅文件有效的关系
|
||||||
file_metadata: FileMetadata | None = Relationship(
|
metadata_entries: list["ObjectMetadata"] = Relationship(
|
||||||
back_populates="object",
|
back_populates="object",
|
||||||
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
|
sa_relationship_kwargs={"cascade": "all, delete-orphan"},
|
||||||
)
|
)
|
||||||
"""文件元数据(仅文件有效)"""
|
"""元数据键值对列表"""
|
||||||
|
|
||||||
source_links: list["SourceLink"] = Relationship(
|
source_links: list["SourceLink"] = Relationship(
|
||||||
back_populates="object",
|
back_populates="object",
|
||||||
@@ -367,7 +372,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
"""
|
"""
|
||||||
return await cls.get(
|
return await cls.get(
|
||||||
session,
|
session,
|
||||||
(cls.owner_id == user_id) & (cls.parent_id == None)
|
(cls.owner_id == user_id) & (cls.parent_id == None) & (cls.deleted_at == None)
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -416,7 +421,8 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
session,
|
session,
|
||||||
(cls.owner_id == user_id) &
|
(cls.owner_id == user_id) &
|
||||||
(cls.parent_id == current.id) &
|
(cls.parent_id == current.id) &
|
||||||
(cls.name == part)
|
(cls.name == part) &
|
||||||
|
(cls.deleted_at == None)
|
||||||
)
|
)
|
||||||
|
|
||||||
return current
|
return current
|
||||||
@@ -424,7 +430,23 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
@classmethod
|
@classmethod
|
||||||
async def get_children(cls, session, user_id: UUID, parent_id: UUID) -> list["Object"]:
|
async def get_children(cls, session, user_id: UUID, parent_id: UUID) -> list["Object"]:
|
||||||
"""
|
"""
|
||||||
获取目录下的所有子对象
|
获取目录下的所有子对象(不包含已软删除的)
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:param parent_id: 父目录UUID
|
||||||
|
:return: 子对象列表
|
||||||
|
"""
|
||||||
|
return await cls.get(
|
||||||
|
session,
|
||||||
|
(cls.owner_id == user_id) & (cls.parent_id == parent_id) & (cls.deleted_at == None),
|
||||||
|
fetch_mode="all"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_all_children(cls, session, user_id: UUID, parent_id: UUID) -> list["Object"]:
|
||||||
|
"""
|
||||||
|
获取目录下的所有子对象(包含已软删除的,用于永久删除场景)
|
||||||
|
|
||||||
:param session: 数据库会话
|
:param session: 数据库会话
|
||||||
:param user_id: 用户UUID
|
:param user_id: 用户UUID
|
||||||
@@ -437,6 +459,55 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
fetch_mode="all"
|
fetch_mode="all"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_trash_items(cls, session, user_id: UUID) -> list["Object"]:
|
||||||
|
"""
|
||||||
|
获取用户回收站中的顶层对象
|
||||||
|
|
||||||
|
只返回被直接软删除的顶层对象(deleted_at 非 NULL),
|
||||||
|
不返回其子对象(子对象的 deleted_at 为 NULL,通过 parent 关系间接处于回收站中)。
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:return: 回收站顶层对象列表
|
||||||
|
"""
|
||||||
|
return await cls.get(
|
||||||
|
session,
|
||||||
|
(cls.owner_id == user_id) & (cls.deleted_at != None),
|
||||||
|
fetch_mode="all"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_by_category(
|
||||||
|
cls,
|
||||||
|
session: 'AsyncSession',
|
||||||
|
user_id: UUID,
|
||||||
|
extensions: list[str],
|
||||||
|
table_view: 'TableViewRequest | None' = None,
|
||||||
|
) -> 'ListResponse[Object]':
|
||||||
|
"""
|
||||||
|
按扩展名列表查询用户的所有文件(跨目录)
|
||||||
|
|
||||||
|
只查询未删除、未封禁的文件对象,使用 ILIKE 匹配文件名后缀。
|
||||||
|
|
||||||
|
:param session: 数据库会话
|
||||||
|
:param user_id: 用户UUID
|
||||||
|
:param extensions: 扩展名列表(不含点号)
|
||||||
|
:param table_view: 分页排序参数
|
||||||
|
:return: 分页文件列表
|
||||||
|
"""
|
||||||
|
from sqlalchemy import or_
|
||||||
|
|
||||||
|
ext_conditions = [cls.name.ilike(f"%.{ext}") for ext in extensions]
|
||||||
|
condition = (
|
||||||
|
(cls.owner_id == user_id) &
|
||||||
|
(cls.type == ObjectType.FILE) &
|
||||||
|
(cls.deleted_at == None) &
|
||||||
|
(cls.is_banned == False) &
|
||||||
|
or_(*ext_conditions)
|
||||||
|
)
|
||||||
|
return await cls.get_with_count(session, condition, table_view=table_view)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def resolve_uri(
|
async def resolve_uri(
|
||||||
cls,
|
cls,
|
||||||
@@ -514,7 +585,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
|||||||
class UploadSessionBase(SQLModelBase):
|
class UploadSessionBase(SQLModelBase):
|
||||||
"""上传会话基础字段"""
|
"""上传会话基础字段"""
|
||||||
|
|
||||||
file_name: str = Field(max_length=255)
|
file_name: Str255
|
||||||
"""原始文件名"""
|
"""原始文件名"""
|
||||||
|
|
||||||
file_size: int = Field(ge=0, sa_type=BigInteger)
|
file_size: int = Field(ge=0, sa_type=BigInteger)
|
||||||
@@ -545,6 +616,12 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
|
|||||||
storage_path: str | None = Field(default=None, max_length=512)
|
storage_path: str | None = Field(default=None, max_length=512)
|
||||||
"""文件存储路径"""
|
"""文件存储路径"""
|
||||||
|
|
||||||
|
s3_upload_id: Str256 | None = None
|
||||||
|
"""S3 Multipart Upload ID(仅 S3 策略使用)"""
|
||||||
|
|
||||||
|
s3_part_etags: str | None = None
|
||||||
|
"""S3 已上传分片的 ETag 列表,JSON 格式 [[1,"etag1"],[2,"etag2"]](仅 S3 策略使用)"""
|
||||||
|
|
||||||
expires_at: datetime
|
expires_at: datetime
|
||||||
"""会话过期时间"""
|
"""会话过期时间"""
|
||||||
|
|
||||||
@@ -586,7 +663,7 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
|
|||||||
class CreateUploadSessionRequest(SQLModelBase):
|
class CreateUploadSessionRequest(SQLModelBase):
|
||||||
"""创建上传会话请求 DTO"""
|
"""创建上传会话请求 DTO"""
|
||||||
|
|
||||||
file_name: str = Field(max_length=255)
|
file_name: Str255
|
||||||
"""文件名"""
|
"""文件名"""
|
||||||
|
|
||||||
file_size: int = Field(ge=0)
|
file_size: int = Field(ge=0)
|
||||||
@@ -643,7 +720,7 @@ class UploadChunkResponse(SQLModelBase):
|
|||||||
class CreateFileRequest(SQLModelBase):
|
class CreateFileRequest(SQLModelBase):
|
||||||
"""创建空白文件请求 DTO"""
|
"""创建空白文件请求 DTO"""
|
||||||
|
|
||||||
name: str = Field(max_length=255)
|
name: Str255
|
||||||
"""文件名"""
|
"""文件名"""
|
||||||
|
|
||||||
parent_id: UUID
|
parent_id: UUID
|
||||||
@@ -653,6 +730,16 @@ class CreateFileRequest(SQLModelBase):
|
|||||||
"""存储策略UUID,不指定则使用父目录的策略"""
|
"""存储策略UUID,不指定则使用父目录的策略"""
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectSwitchPolicyRequest(SQLModelBase):
|
||||||
|
"""切换对象存储策略请求"""
|
||||||
|
|
||||||
|
policy_id: UUID
|
||||||
|
"""目标存储策略UUID"""
|
||||||
|
|
||||||
|
is_migrate_existing: bool = False
|
||||||
|
"""(仅目录)是否迁移已有文件,默认 false 只影响新文件"""
|
||||||
|
|
||||||
|
|
||||||
# ==================== 对象操作相关 DTO ====================
|
# ==================== 对象操作相关 DTO ====================
|
||||||
|
|
||||||
class ObjectCopyRequest(SQLModelBase):
|
class ObjectCopyRequest(SQLModelBase):
|
||||||
@@ -671,7 +758,7 @@ class ObjectRenameRequest(SQLModelBase):
|
|||||||
id: UUID
|
id: UUID
|
||||||
"""对象UUID"""
|
"""对象UUID"""
|
||||||
|
|
||||||
new_name: str = Field(max_length=255)
|
new_name: Str255
|
||||||
"""新名称"""
|
"""新名称"""
|
||||||
|
|
||||||
|
|
||||||
@@ -690,6 +777,9 @@ class ObjectPropertyResponse(SQLModelBase):
|
|||||||
size: int
|
size: int
|
||||||
"""文件大小(字节)"""
|
"""文件大小(字节)"""
|
||||||
|
|
||||||
|
mime_type: str | None = None
|
||||||
|
"""MIME类型"""
|
||||||
|
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
"""创建时间"""
|
"""创建时间"""
|
||||||
|
|
||||||
@@ -703,22 +793,13 @@ class ObjectPropertyResponse(SQLModelBase):
|
|||||||
class ObjectPropertyDetailResponse(ObjectPropertyResponse):
|
class ObjectPropertyDetailResponse(ObjectPropertyResponse):
|
||||||
"""对象详细属性响应 DTO(继承基本属性)"""
|
"""对象详细属性响应 DTO(继承基本属性)"""
|
||||||
|
|
||||||
# 元数据信息
|
# 校验和(从 PhysicalFile 读取)
|
||||||
mime_type: str | None = None
|
|
||||||
"""MIME类型"""
|
|
||||||
|
|
||||||
width: int | None = None
|
|
||||||
"""图片宽度(像素)"""
|
|
||||||
|
|
||||||
height: int | None = None
|
|
||||||
"""图片高度(像素)"""
|
|
||||||
|
|
||||||
duration: float | None = None
|
|
||||||
"""音视频时长(秒)"""
|
|
||||||
|
|
||||||
checksum_md5: str | None = None
|
checksum_md5: str | None = None
|
||||||
"""MD5校验和"""
|
"""MD5校验和"""
|
||||||
|
|
||||||
|
checksum_sha256: str | None = None
|
||||||
|
"""SHA256校验和"""
|
||||||
|
|
||||||
# 分享统计
|
# 分享统计
|
||||||
share_count: int = 0
|
share_count: int = 0
|
||||||
"""分享次数"""
|
"""分享次数"""
|
||||||
@@ -736,6 +817,10 @@ class ObjectPropertyDetailResponse(ObjectPropertyResponse):
|
|||||||
reference_count: int = 1
|
reference_count: int = 1
|
||||||
"""物理文件引用计数(仅文件有效)"""
|
"""物理文件引用计数(仅文件有效)"""
|
||||||
|
|
||||||
|
# 元数据(KV 格式)
|
||||||
|
metadatas: dict[str, str] = {}
|
||||||
|
"""所有元数据条目(键名 → 值)"""
|
||||||
|
|
||||||
|
|
||||||
# ==================== 管理员文件管理 DTO ====================
|
# ==================== 管理员文件管理 DTO ====================
|
||||||
|
|
||||||
@@ -805,3 +890,41 @@ class AdminFileListResponse(SQLModelBase):
|
|||||||
|
|
||||||
total: int = 0
|
total: int = 0
|
||||||
"""总数"""
|
"""总数"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 回收站相关 DTO ====================
|
||||||
|
|
||||||
|
class TrashItemResponse(SQLModelBase):
|
||||||
|
"""回收站对象响应 DTO"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
"""对象UUID"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""对象名称"""
|
||||||
|
|
||||||
|
type: ObjectType
|
||||||
|
"""对象类型"""
|
||||||
|
|
||||||
|
size: int
|
||||||
|
"""文件大小(字节)"""
|
||||||
|
|
||||||
|
deleted_at: datetime
|
||||||
|
"""删除时间"""
|
||||||
|
|
||||||
|
original_parent_id: UUID | None
|
||||||
|
"""原始父目录UUID"""
|
||||||
|
|
||||||
|
|
||||||
|
class TrashRestoreRequest(SQLModelBase):
|
||||||
|
"""恢复对象请求 DTO"""
|
||||||
|
|
||||||
|
ids: list[UUID]
|
||||||
|
"""待恢复对象UUID列表"""
|
||||||
|
|
||||||
|
|
||||||
|
class TrashDeleteRequest(SQLModelBase):
|
||||||
|
"""永久删除对象请求 DTO"""
|
||||||
|
|
||||||
|
ids: list[UUID]
|
||||||
|
"""待永久删除对象UUID列表"""
|
||||||
|
|||||||
127
sqlmodels/object_metadata.py
Normal file
127
sqlmodels/object_metadata.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""
|
||||||
|
对象元数据 KV 模型
|
||||||
|
|
||||||
|
以键值对形式存储文件的扩展元数据。键名使用命名空间前缀分类,
|
||||||
|
如 exif:width, stream:duration, music:artist 等。
|
||||||
|
|
||||||
|
架构:
|
||||||
|
ObjectMetadata (KV 表,与 Object 一对多关系)
|
||||||
|
└── 每个 Object 可以有多条元数据记录
|
||||||
|
└── (object_id, name) 组合唯一索引
|
||||||
|
|
||||||
|
命名空间:
|
||||||
|
- exif: 图片 EXIF 信息(尺寸、相机参数、拍摄时间等)
|
||||||
|
- stream: 音视频流信息(时长、比特率、视频尺寸、编解码等)
|
||||||
|
- music: 音乐标签(标题、艺术家、专辑等)
|
||||||
|
- geo: 地理位置(经纬度、地址)
|
||||||
|
- apk: Android 安装包信息
|
||||||
|
- custom: 用户自定义属性
|
||||||
|
- sys: 系统内部元数据
|
||||||
|
- thumb: 缩略图信息
|
||||||
|
"""
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlmodel import Field, UniqueConstraint, Index, Relationship
|
||||||
|
|
||||||
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .object import Object
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 枚举 ====================
|
||||||
|
|
||||||
|
class MetadataNamespace(StrEnum):
|
||||||
|
"""元数据命名空间枚举"""
|
||||||
|
EXIF = "exif"
|
||||||
|
"""图片 EXIF 信息(含尺寸、相机参数、拍摄时间等)"""
|
||||||
|
MUSIC = "music"
|
||||||
|
"""音乐标签(title/artist/album/genre 等)"""
|
||||||
|
STREAM = "stream"
|
||||||
|
"""音视频流信息(codec/duration/bitrate/resolution 等)"""
|
||||||
|
GEO = "geo"
|
||||||
|
"""地理位置(latitude/longitude/address)"""
|
||||||
|
APK = "apk"
|
||||||
|
"""Android 安装包信息(package_name/version 等)"""
|
||||||
|
THUMB = "thumb"
|
||||||
|
"""缩略图信息(内部使用)"""
|
||||||
|
SYS = "sys"
|
||||||
|
"""系统元数据(内部使用)"""
|
||||||
|
CUSTOM = "custom"
|
||||||
|
"""用户自定义属性"""
|
||||||
|
|
||||||
|
|
||||||
|
# 对外不可见的命名空间(API 不返回给普通用户)
|
||||||
|
INTERNAL_NAMESPACES: set[str] = {MetadataNamespace.SYS, MetadataNamespace.THUMB}
|
||||||
|
|
||||||
|
# 用户可写的命名空间
|
||||||
|
USER_WRITABLE_NAMESPACES: set[str] = {MetadataNamespace.CUSTOM}
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Base 模型 ====================
|
||||||
|
|
||||||
|
class ObjectMetadataBase(SQLModelBase):
|
||||||
|
"""对象元数据 KV 基础模型"""
|
||||||
|
|
||||||
|
name: Str255
|
||||||
|
"""元数据键名,格式:namespace:key(如 exif:width, stream:duration)"""
|
||||||
|
|
||||||
|
value: str
|
||||||
|
"""元数据值(统一为字符串存储)"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 数据库模型 ====================
|
||||||
|
|
||||||
|
class ObjectMetadata(ObjectMetadataBase, UUIDTableBaseMixin):
|
||||||
|
"""
|
||||||
|
对象元数据 KV 模型
|
||||||
|
|
||||||
|
以键值对形式存储文件的扩展元数据。键名使用命名空间前缀分类,
|
||||||
|
每个对象的每个键名唯一(通过唯一索引保证)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("object_id", "name", name="uq_object_metadata_object_name"),
|
||||||
|
Index("ix_object_metadata_object_id", "object_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
object_id: UUID = Field(
|
||||||
|
foreign_key="object.id",
|
||||||
|
ondelete="CASCADE",
|
||||||
|
)
|
||||||
|
"""关联的对象UUID"""
|
||||||
|
|
||||||
|
is_public: bool = False
|
||||||
|
"""是否对分享页面公开"""
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
object: "Object" = Relationship(back_populates="metadata_entries")
|
||||||
|
"""关联的对象"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== DTO 模型 ====================
|
||||||
|
|
||||||
|
class MetadataResponse(SQLModelBase):
|
||||||
|
"""元数据查询响应 DTO"""
|
||||||
|
|
||||||
|
metadatas: dict[str, str]
|
||||||
|
"""元数据字典(键名 → 值)"""
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataPatchItem(SQLModelBase):
|
||||||
|
"""单条元数据补丁 DTO"""
|
||||||
|
|
||||||
|
key: Str255
|
||||||
|
"""元数据键名"""
|
||||||
|
|
||||||
|
value: str | None = None
|
||||||
|
"""值,None 表示删除此条目"""
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataPatchRequest(SQLModelBase):
|
||||||
|
"""元数据批量更新请求 DTO"""
|
||||||
|
|
||||||
|
patches: list[MetadataPatchItem]
|
||||||
|
"""补丁列表"""
|
||||||
@@ -1,55 +1,118 @@
|
|||||||
|
from decimal import Decimal
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import Numeric
|
||||||
from sqlmodel import Field, Relationship
|
from sqlmodel import Field, Relationship
|
||||||
|
|
||||||
from .base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
||||||
from .mixin import TableBaseMixin
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from .product import Product
|
||||||
from .user import User
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
class OrderStatus(StrEnum):
|
class OrderStatus(StrEnum):
|
||||||
"""订单状态枚举"""
|
"""订单状态枚举"""
|
||||||
|
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
"""待支付"""
|
"""待支付"""
|
||||||
|
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
"""已完成"""
|
"""已完成"""
|
||||||
|
|
||||||
CANCELLED = "cancelled"
|
CANCELLED = "cancelled"
|
||||||
"""已取消"""
|
"""已取消"""
|
||||||
|
|
||||||
|
|
||||||
class OrderType(StrEnum):
|
class OrderType(StrEnum):
|
||||||
"""订单类型枚举"""
|
"""订单类型枚举"""
|
||||||
# [TODO] 补充具体订单类型
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
STORAGE_PACK = "storage_pack"
|
||||||
|
"""容量包"""
|
||||||
|
|
||||||
|
GROUP_TIME = "group_time"
|
||||||
|
"""用户组时长"""
|
||||||
|
|
||||||
|
SCORE = "score"
|
||||||
|
"""积分充值"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== DTO 模型 ====================
|
||||||
|
|
||||||
|
class CreateOrderRequest(SQLModelBase):
|
||||||
|
"""创建订单请求 DTO"""
|
||||||
|
|
||||||
|
product_id: UUID
|
||||||
|
"""商品UUID"""
|
||||||
|
|
||||||
|
num: int = Field(default=1, ge=1)
|
||||||
|
"""购买数量"""
|
||||||
|
|
||||||
|
method: str
|
||||||
|
"""支付方式"""
|
||||||
|
|
||||||
|
|
||||||
|
class OrderResponse(SQLModelBase):
|
||||||
|
"""订单响应 DTO"""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
"""订单ID"""
|
||||||
|
|
||||||
|
order_no: str
|
||||||
|
"""订单号"""
|
||||||
|
|
||||||
|
type: OrderType
|
||||||
|
"""订单类型"""
|
||||||
|
|
||||||
|
method: str | None = None
|
||||||
|
"""支付方式"""
|
||||||
|
|
||||||
|
product_id: UUID | None = None
|
||||||
|
"""商品UUID"""
|
||||||
|
|
||||||
|
num: int
|
||||||
|
"""购买数量"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""商品名称"""
|
||||||
|
|
||||||
|
price: float
|
||||||
|
"""订单价格(元)"""
|
||||||
|
|
||||||
|
status: OrderStatus
|
||||||
|
"""订单状态"""
|
||||||
|
|
||||||
|
user_id: UUID
|
||||||
|
"""所属用户UUID"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 数据库模型 ====================
|
||||||
|
|
||||||
class Order(SQLModelBase, TableBaseMixin):
|
class Order(SQLModelBase, TableBaseMixin):
|
||||||
"""订单模型"""
|
"""订单模型"""
|
||||||
|
|
||||||
order_no: str = Field(max_length=255, unique=True, index=True)
|
order_no: Str255 = Field(unique=True, index=True)
|
||||||
"""订单号,唯一"""
|
"""订单号,唯一"""
|
||||||
|
|
||||||
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
type: OrderType
|
||||||
"""订单类型 [TODO] 待定义枚举"""
|
"""订单类型"""
|
||||||
|
|
||||||
method: str | None = Field(default=None, max_length=255)
|
method: Str255 | None = None
|
||||||
"""支付方式"""
|
"""支付方式"""
|
||||||
|
|
||||||
product_id: int | None = Field(default=None)
|
product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
|
||||||
"""商品ID"""
|
"""关联商品UUID"""
|
||||||
|
|
||||||
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
|
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
|
||||||
"""购买数量"""
|
"""购买数量"""
|
||||||
|
|
||||||
name: str = Field(max_length=255)
|
name: Str255
|
||||||
"""商品名称"""
|
"""商品名称"""
|
||||||
|
|
||||||
price: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
price: Decimal = Field(sa_type=Numeric(12, 2), default=Decimal("0.00"))
|
||||||
"""订单价格(分)"""
|
"""订单价格(元)"""
|
||||||
|
|
||||||
status: OrderStatus = Field(default=OrderStatus.PENDING)
|
status: OrderStatus = Field(default=OrderStatus.PENDING)
|
||||||
"""订单状态"""
|
"""订单状态"""
|
||||||
@@ -64,3 +127,19 @@ class Order(SQLModelBase, TableBaseMixin):
|
|||||||
|
|
||||||
# 关系
|
# 关系
|
||||||
user: "User" = Relationship(back_populates="orders")
|
user: "User" = Relationship(back_populates="orders")
|
||||||
|
product: "Product" = Relationship(back_populates="orders")
|
||||||
|
|
||||||
|
def to_response(self) -> OrderResponse:
|
||||||
|
"""转换为响应 DTO"""
|
||||||
|
return OrderResponse(
|
||||||
|
id=self.id,
|
||||||
|
order_no=self.order_no,
|
||||||
|
type=self.type,
|
||||||
|
method=self.method,
|
||||||
|
product_id=self.product_id,
|
||||||
|
num=self.num,
|
||||||
|
name=self.name,
|
||||||
|
price=float(self.price),
|
||||||
|
status=self.status,
|
||||||
|
user_id=self.user_id,
|
||||||
|
)
|
||||||
|
|||||||
@@ -15,8 +15,7 @@ from uuid import UUID
|
|||||||
from sqlalchemy import BigInteger
|
from sqlalchemy import BigInteger
|
||||||
from sqlmodel import Field, Relationship, Index
|
from sqlmodel import Field, Relationship, Index
|
||||||
|
|
||||||
from .base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str32, Str64
|
||||||
from .mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .object import Object
|
from .object import Object
|
||||||
@@ -32,9 +31,12 @@ class PhysicalFileBase(SQLModelBase):
|
|||||||
size: int = Field(default=0, sa_type=BigInteger)
|
size: int = Field(default=0, sa_type=BigInteger)
|
||||||
"""文件大小(字节)"""
|
"""文件大小(字节)"""
|
||||||
|
|
||||||
checksum_md5: str | None = Field(default=None, max_length=32)
|
checksum_md5: Str32 | None = None
|
||||||
"""MD5校验和(用于文件去重和完整性校验)"""
|
"""MD5校验和(用于文件去重和完整性校验)"""
|
||||||
|
|
||||||
|
checksum_sha256: Str64 | None = None
|
||||||
|
"""SHA256校验和"""
|
||||||
|
|
||||||
|
|
||||||
class PhysicalFile(PhysicalFileBase, UUIDTableBaseMixin):
|
class PhysicalFile(PhysicalFileBase, UUIDTableBaseMixin):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ from uuid import UUID
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from sqlmodel import Field, Relationship, text
|
from sqlmodel import Field, Relationship, text
|
||||||
|
|
||||||
from .base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str64, Str255
|
||||||
from .mixin import UUIDTableBaseMixin
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .object import Object
|
from .object import Object
|
||||||
@@ -38,22 +37,22 @@ class PolicyType(StrEnum):
|
|||||||
class PolicyBase(SQLModelBase):
|
class PolicyBase(SQLModelBase):
|
||||||
"""存储策略基础字段,供 DTO 和数据库模型共享"""
|
"""存储策略基础字段,供 DTO 和数据库模型共享"""
|
||||||
|
|
||||||
name: str = Field(max_length=255)
|
name: Str255
|
||||||
"""策略名称"""
|
"""策略名称"""
|
||||||
|
|
||||||
type: PolicyType
|
type: PolicyType
|
||||||
"""存储策略类型"""
|
"""存储策略类型"""
|
||||||
|
|
||||||
server: str | None = Field(default=None, max_length=255)
|
server: Str255 | None = None
|
||||||
"""服务器地址(本地策略为绝对路径)"""
|
"""服务器地址(本地策略为绝对路径)"""
|
||||||
|
|
||||||
bucket_name: str | None = Field(default=None, max_length=255)
|
bucket_name: Str255 | None = None
|
||||||
"""存储桶名称"""
|
"""存储桶名称"""
|
||||||
|
|
||||||
is_private: bool = True
|
is_private: bool = True
|
||||||
"""是否为私有空间"""
|
"""是否为私有空间"""
|
||||||
|
|
||||||
base_url: str | None = Field(default=None, max_length=255)
|
base_url: Str255 | None = None
|
||||||
"""访问文件的基础URL"""
|
"""访问文件的基础URL"""
|
||||||
|
|
||||||
access_key: str | None = None
|
access_key: str | None = None
|
||||||
@@ -68,10 +67,10 @@ class PolicyBase(SQLModelBase):
|
|||||||
auto_rename: bool = False
|
auto_rename: bool = False
|
||||||
"""是否自动重命名"""
|
"""是否自动重命名"""
|
||||||
|
|
||||||
dir_name_rule: str | None = Field(default=None, max_length=255)
|
dir_name_rule: Str255 | None = None
|
||||||
"""目录命名规则"""
|
"""目录命名规则"""
|
||||||
|
|
||||||
file_name_rule: str | None = Field(default=None, max_length=255)
|
file_name_rule: Str255 | None = None
|
||||||
"""文件命名规则"""
|
"""文件命名规则"""
|
||||||
|
|
||||||
is_origin_link_enable: bool = False
|
is_origin_link_enable: bool = False
|
||||||
@@ -103,6 +102,94 @@ class PolicySummary(SQLModelBase):
|
|||||||
"""是否私有"""
|
"""是否私有"""
|
||||||
|
|
||||||
|
|
||||||
|
class PolicyCreateRequest(PolicyBase):
|
||||||
|
"""创建存储策略请求 DTO,包含 PolicyOptions 扁平字段"""
|
||||||
|
|
||||||
|
# PolicyOptions 字段(平铺到请求体中,与 GroupCreateRequest 模式一致)
|
||||||
|
token: str | None = None
|
||||||
|
"""访问令牌"""
|
||||||
|
|
||||||
|
file_type: str | None = None
|
||||||
|
"""允许的文件类型"""
|
||||||
|
|
||||||
|
mimetype: str | None = Field(default=None, max_length=127)
|
||||||
|
"""MIME类型"""
|
||||||
|
|
||||||
|
od_redirect: Str255 | None = None
|
||||||
|
"""OneDrive重定向地址"""
|
||||||
|
|
||||||
|
chunk_size: int = Field(default=52428800, ge=1)
|
||||||
|
"""分片上传大小(字节),默认50MB"""
|
||||||
|
|
||||||
|
s3_path_style: bool = False
|
||||||
|
"""是否使用S3路径风格"""
|
||||||
|
|
||||||
|
s3_region: Str64 = 'us-east-1'
|
||||||
|
"""S3 区域(如 us-east-1、ap-southeast-1),仅 S3 策略使用"""
|
||||||
|
|
||||||
|
|
||||||
|
class PolicyUpdateRequest(SQLModelBase):
|
||||||
|
"""更新存储策略请求 DTO(所有字段可选)"""
|
||||||
|
|
||||||
|
name: Str255 | None = None
|
||||||
|
"""策略名称"""
|
||||||
|
|
||||||
|
server: Str255 | None = None
|
||||||
|
"""服务器地址"""
|
||||||
|
|
||||||
|
bucket_name: Str255 | None = None
|
||||||
|
"""存储桶名称"""
|
||||||
|
|
||||||
|
is_private: bool | None = None
|
||||||
|
"""是否为私有空间"""
|
||||||
|
|
||||||
|
base_url: Str255 | None = None
|
||||||
|
"""访问文件的基础URL"""
|
||||||
|
|
||||||
|
access_key: str | None = None
|
||||||
|
"""Access Key"""
|
||||||
|
|
||||||
|
secret_key: str | None = None
|
||||||
|
"""Secret Key"""
|
||||||
|
|
||||||
|
max_size: int | None = Field(default=None, ge=0)
|
||||||
|
"""允许上传的最大文件尺寸(字节)"""
|
||||||
|
|
||||||
|
auto_rename: bool | None = None
|
||||||
|
"""是否自动重命名"""
|
||||||
|
|
||||||
|
dir_name_rule: Str255 | None = None
|
||||||
|
"""目录命名规则"""
|
||||||
|
|
||||||
|
file_name_rule: Str255 | None = None
|
||||||
|
"""文件命名规则"""
|
||||||
|
|
||||||
|
is_origin_link_enable: bool | None = None
|
||||||
|
"""是否开启源链接访问"""
|
||||||
|
|
||||||
|
# PolicyOptions 字段
|
||||||
|
token: str | None = None
|
||||||
|
"""访问令牌"""
|
||||||
|
|
||||||
|
file_type: str | None = None
|
||||||
|
"""允许的文件类型"""
|
||||||
|
|
||||||
|
mimetype: str | None = Field(default=None, max_length=127)
|
||||||
|
"""MIME类型"""
|
||||||
|
|
||||||
|
od_redirect: Str255 | None = None
|
||||||
|
"""OneDrive重定向地址"""
|
||||||
|
|
||||||
|
chunk_size: int | None = Field(default=None, ge=1)
|
||||||
|
"""分片上传大小(字节)"""
|
||||||
|
|
||||||
|
s3_path_style: bool | None = None
|
||||||
|
"""是否使用S3路径风格"""
|
||||||
|
|
||||||
|
s3_region: Str64 | None = None
|
||||||
|
"""S3 区域"""
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库模型 ====================
|
# ==================== 数据库模型 ====================
|
||||||
|
|
||||||
|
|
||||||
@@ -118,7 +205,7 @@ class PolicyOptionsBase(SQLModelBase):
|
|||||||
mimetype: str | None = Field(default=None, max_length=127)
|
mimetype: str | None = Field(default=None, max_length=127)
|
||||||
"""MIME类型"""
|
"""MIME类型"""
|
||||||
|
|
||||||
od_redirect: str | None = Field(default=None, max_length=255)
|
od_redirect: Str255 | None = None
|
||||||
"""OneDrive重定向地址"""
|
"""OneDrive重定向地址"""
|
||||||
|
|
||||||
chunk_size: int = Field(default=52428800, sa_column_kwargs={"server_default": "52428800"})
|
chunk_size: int = Field(default=52428800, sa_column_kwargs={"server_default": "52428800"})
|
||||||
@@ -127,6 +214,9 @@ class PolicyOptionsBase(SQLModelBase):
|
|||||||
s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||||
"""是否使用S3路径风格"""
|
"""是否使用S3路径风格"""
|
||||||
|
|
||||||
|
s3_region: Str64 = Field(default='us-east-1', sa_column_kwargs={"server_default": "'us-east-1'"})
|
||||||
|
"""S3 区域(如 us-east-1、ap-southeast-1),仅 S3 策略使用"""
|
||||||
|
|
||||||
|
|
||||||
class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin):
|
class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin):
|
||||||
"""存储策略选项模型(与Policy一对一关联)"""
|
"""存储策略选项模型(与Policy一对一关联)"""
|
||||||
@@ -147,7 +237,7 @@ class Policy(PolicyBase, UUIDTableBaseMixin):
|
|||||||
"""存储策略模型"""
|
"""存储策略模型"""
|
||||||
|
|
||||||
# 覆盖基类字段以添加数据库专有配置
|
# 覆盖基类字段以添加数据库专有配置
|
||||||
name: str = Field(max_length=255, unique=True)
|
name: Str255 = Field(unique=True)
|
||||||
"""策略名称"""
|
"""策略名称"""
|
||||||
|
|
||||||
is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
||||||
|
|||||||
206
sqlmodels/product.py
Normal file
206
sqlmodels/product.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
from decimal import Decimal
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import Numeric, BigInteger
|
||||||
|
from sqlmodel import Field, Relationship, text
|
||||||
|
|
||||||
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .order import Order
|
||||||
|
from .redeem import Redeem
|
||||||
|
|
||||||
|
|
||||||
|
class ProductType(StrEnum):
|
||||||
|
"""商品类型枚举"""
|
||||||
|
|
||||||
|
STORAGE_PACK = "storage_pack"
|
||||||
|
"""容量包"""
|
||||||
|
|
||||||
|
GROUP_TIME = "group_time"
|
||||||
|
"""用户组时长"""
|
||||||
|
|
||||||
|
SCORE = "score"
|
||||||
|
"""积分充值"""
|
||||||
|
|
||||||
|
|
||||||
|
class PaymentMethod(StrEnum):
|
||||||
|
"""支付方式枚举"""
|
||||||
|
|
||||||
|
ALIPAY = "alipay"
|
||||||
|
"""支付宝"""
|
||||||
|
|
||||||
|
WECHAT = "wechat"
|
||||||
|
"""微信支付"""
|
||||||
|
|
||||||
|
STRIPE = "stripe"
|
||||||
|
"""Stripe"""
|
||||||
|
|
||||||
|
EASYPAY = "easypay"
|
||||||
|
"""易支付"""
|
||||||
|
|
||||||
|
CUSTOM = "custom"
|
||||||
|
"""自定义支付"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== DTO 模型 ====================
|
||||||
|
|
||||||
|
class ProductBase(SQLModelBase):
|
||||||
|
"""商品基础字段"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""商品名称"""
|
||||||
|
|
||||||
|
type: ProductType
|
||||||
|
"""商品类型"""
|
||||||
|
|
||||||
|
description: str | None = None
|
||||||
|
"""商品描述"""
|
||||||
|
|
||||||
|
|
||||||
|
class ProductCreateRequest(ProductBase):
|
||||||
|
"""创建商品请求 DTO"""
|
||||||
|
|
||||||
|
name: Str255
|
||||||
|
"""商品名称"""
|
||||||
|
|
||||||
|
price: Decimal = Field(ge=0, decimal_places=2)
|
||||||
|
"""商品价格(元)"""
|
||||||
|
|
||||||
|
is_active: bool = True
|
||||||
|
"""是否上架"""
|
||||||
|
|
||||||
|
sort_order: int = Field(default=0, ge=0)
|
||||||
|
"""排序权重(越大越靠前)"""
|
||||||
|
|
||||||
|
# storage_pack 专用
|
||||||
|
size: int | None = Field(default=None, ge=0)
|
||||||
|
"""容量大小(字节),type=storage_pack 时必填"""
|
||||||
|
|
||||||
|
duration_days: int | None = Field(default=None, ge=1)
|
||||||
|
"""有效天数,type=storage_pack/group_time 时必填"""
|
||||||
|
|
||||||
|
# group_time 专用
|
||||||
|
group_id: UUID | None = None
|
||||||
|
"""目标用户组UUID,type=group_time 时必填"""
|
||||||
|
|
||||||
|
# score 专用
|
||||||
|
score_amount: int | None = Field(default=None, ge=1)
|
||||||
|
"""积分数量,type=score 时必填"""
|
||||||
|
|
||||||
|
|
||||||
|
class ProductUpdateRequest(SQLModelBase):
|
||||||
|
"""更新商品请求 DTO(所有字段可选)"""
|
||||||
|
|
||||||
|
name: Str255 | None = None
|
||||||
|
"""商品名称"""
|
||||||
|
|
||||||
|
description: str | None = None
|
||||||
|
"""商品描述"""
|
||||||
|
|
||||||
|
price: Decimal | None = Field(default=None, ge=0, decimal_places=2)
|
||||||
|
"""商品价格(元)"""
|
||||||
|
|
||||||
|
is_active: bool | None = None
|
||||||
|
"""是否上架"""
|
||||||
|
|
||||||
|
sort_order: int | None = Field(default=None, ge=0)
|
||||||
|
"""排序权重"""
|
||||||
|
|
||||||
|
size: int | None = Field(default=None, ge=0)
|
||||||
|
"""容量大小(字节)"""
|
||||||
|
|
||||||
|
duration_days: int | None = Field(default=None, ge=1)
|
||||||
|
"""有效天数"""
|
||||||
|
|
||||||
|
group_id: UUID | None = None
|
||||||
|
"""目标用户组UUID"""
|
||||||
|
|
||||||
|
score_amount: int | None = Field(default=None, ge=1)
|
||||||
|
"""积分数量"""
|
||||||
|
|
||||||
|
|
||||||
|
class ProductResponse(ProductBase):
|
||||||
|
"""商品响应 DTO"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
"""商品UUID"""
|
||||||
|
|
||||||
|
price: float
|
||||||
|
"""商品价格(元)"""
|
||||||
|
|
||||||
|
is_active: bool
|
||||||
|
"""是否上架"""
|
||||||
|
|
||||||
|
sort_order: int
|
||||||
|
"""排序权重"""
|
||||||
|
|
||||||
|
size: int | None = None
|
||||||
|
"""容量大小(字节)"""
|
||||||
|
|
||||||
|
duration_days: int | None = None
|
||||||
|
"""有效天数"""
|
||||||
|
|
||||||
|
group_id: UUID | None = None
|
||||||
|
"""目标用户组UUID"""
|
||||||
|
|
||||||
|
score_amount: int | None = None
|
||||||
|
"""积分数量"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 数据库模型 ====================
|
||||||
|
|
||||||
|
class Product(ProductBase, UUIDTableBaseMixin):
|
||||||
|
"""商品模型"""
|
||||||
|
|
||||||
|
name: Str255
|
||||||
|
"""商品名称"""
|
||||||
|
|
||||||
|
price: Decimal = Field(sa_type=Numeric(12, 2), default=Decimal("0.00"))
|
||||||
|
"""商品价格(元)"""
|
||||||
|
|
||||||
|
is_active: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
||||||
|
"""是否上架"""
|
||||||
|
|
||||||
|
sort_order: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||||
|
"""排序权重(越大越靠前)"""
|
||||||
|
|
||||||
|
# storage_pack 专用
|
||||||
|
size: int | None = Field(default=None, sa_type=BigInteger)
|
||||||
|
"""容量大小(字节),type=storage_pack 时必填"""
|
||||||
|
|
||||||
|
duration_days: int | None = None
|
||||||
|
"""有效天数,type=storage_pack/group_time 时必填"""
|
||||||
|
|
||||||
|
# group_time 专用
|
||||||
|
group_id: UUID | None = Field(default=None, foreign_key="group.id", ondelete="SET NULL")
|
||||||
|
"""目标用户组UUID,type=group_time 时必填"""
|
||||||
|
|
||||||
|
# score 专用
|
||||||
|
score_amount: int | None = None
|
||||||
|
"""积分数量,type=score 时必填"""
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
orders: list["Order"] = Relationship(back_populates="product")
|
||||||
|
"""关联的订单列表"""
|
||||||
|
|
||||||
|
redeems: list["Redeem"] = Relationship(back_populates="product")
|
||||||
|
"""关联的兑换码列表"""
|
||||||
|
|
||||||
|
def to_response(self) -> ProductResponse:
|
||||||
|
"""转换为响应 DTO"""
|
||||||
|
return ProductResponse(
|
||||||
|
id=self.id,
|
||||||
|
name=self.name,
|
||||||
|
type=self.type,
|
||||||
|
description=self.description,
|
||||||
|
price=float(self.price),
|
||||||
|
is_active=self.is_active,
|
||||||
|
sort_order=self.sort_order,
|
||||||
|
size=self.size,
|
||||||
|
duration_days=self.duration_days,
|
||||||
|
group_id=self.group_id,
|
||||||
|
score_amount=self.score_amount,
|
||||||
|
)
|
||||||
@@ -1,23 +1,141 @@
|
|||||||
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlmodel import Field, text
|
from sqlmodel import Field, Relationship, text
|
||||||
|
|
||||||
from .base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||||
from .mixin import TableBaseMixin
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .product import Product
|
||||||
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
class RedeemType(StrEnum):
|
class RedeemType(StrEnum):
|
||||||
"""兑换码类型枚举"""
|
"""兑换码类型枚举"""
|
||||||
# [TODO] 补充具体兑换码类型
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
STORAGE_PACK = "storage_pack"
|
||||||
|
"""容量包"""
|
||||||
|
|
||||||
|
GROUP_TIME = "group_time"
|
||||||
|
"""用户组时长"""
|
||||||
|
|
||||||
|
SCORE = "score"
|
||||||
|
"""积分充值"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== DTO 模型 ====================
|
||||||
|
|
||||||
|
class RedeemCreateRequest(SQLModelBase):
|
||||||
|
"""批量生成兑换码请求 DTO"""
|
||||||
|
|
||||||
|
product_id: UUID
|
||||||
|
"""关联商品UUID"""
|
||||||
|
|
||||||
|
count: int = Field(default=1, ge=1, le=100)
|
||||||
|
"""生成数量"""
|
||||||
|
|
||||||
|
|
||||||
|
class RedeemUseRequest(SQLModelBase):
|
||||||
|
"""使用兑换码请求 DTO"""
|
||||||
|
|
||||||
|
code: str
|
||||||
|
"""兑换码"""
|
||||||
|
|
||||||
|
|
||||||
|
class RedeemInfoResponse(SQLModelBase):
|
||||||
|
"""兑换码信息响应 DTO(用户侧)"""
|
||||||
|
|
||||||
|
type: RedeemType
|
||||||
|
"""兑换码类型"""
|
||||||
|
|
||||||
|
product_name: str | None = None
|
||||||
|
"""关联商品名称"""
|
||||||
|
|
||||||
|
num: int
|
||||||
|
"""可兑换数量"""
|
||||||
|
|
||||||
|
is_used: bool
|
||||||
|
"""是否已使用"""
|
||||||
|
|
||||||
|
|
||||||
|
class RedeemAdminResponse(SQLModelBase):
|
||||||
|
"""兑换码管理响应 DTO(管理侧)"""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
"""兑换码ID"""
|
||||||
|
|
||||||
|
type: RedeemType
|
||||||
|
"""兑换码类型"""
|
||||||
|
|
||||||
|
product_id: UUID | None = None
|
||||||
|
"""关联商品UUID"""
|
||||||
|
|
||||||
|
num: int
|
||||||
|
"""可兑换数量"""
|
||||||
|
|
||||||
|
code: str
|
||||||
|
"""兑换码"""
|
||||||
|
|
||||||
|
is_used: bool
|
||||||
|
"""是否已使用"""
|
||||||
|
|
||||||
|
used_at: datetime | None = None
|
||||||
|
"""使用时间"""
|
||||||
|
|
||||||
|
used_by: UUID | None = None
|
||||||
|
"""使用者UUID"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 数据库模型 ====================
|
||||||
|
|
||||||
class Redeem(SQLModelBase, TableBaseMixin):
|
class Redeem(SQLModelBase, TableBaseMixin):
|
||||||
"""兑换码模型"""
|
"""兑换码模型"""
|
||||||
|
|
||||||
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
type: RedeemType
|
||||||
"""兑换码类型 [TODO] 待定义枚举"""
|
"""兑换码类型"""
|
||||||
product_id: int | None = Field(default=None, description="关联的商品/权益ID")
|
|
||||||
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="可兑换数量/时长等")
|
product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
|
||||||
code: str = Field(unique=True, index=True, description="兑换码,唯一")
|
"""关联商品UUID"""
|
||||||
used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否已使用")
|
|
||||||
|
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
|
||||||
|
"""可兑换数量/时长等"""
|
||||||
|
|
||||||
|
code: str = Field(unique=True, index=True)
|
||||||
|
"""兑换码,唯一"""
|
||||||
|
|
||||||
|
is_used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||||
|
"""是否已使用"""
|
||||||
|
|
||||||
|
used_at: datetime | None = None
|
||||||
|
"""使用时间"""
|
||||||
|
|
||||||
|
used_by: UUID | None = Field(default=None, foreign_key="user.id", ondelete="SET NULL")
|
||||||
|
"""使用者UUID"""
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
product: "Product" = Relationship(back_populates="redeems")
|
||||||
|
user: "User" = Relationship(back_populates="redeems")
|
||||||
|
|
||||||
|
def to_admin_response(self) -> RedeemAdminResponse:
|
||||||
|
"""转换为管理侧响应 DTO"""
|
||||||
|
return RedeemAdminResponse(
|
||||||
|
id=self.id,
|
||||||
|
type=self.type,
|
||||||
|
product_id=self.product_id,
|
||||||
|
num=self.num,
|
||||||
|
code=self.code,
|
||||||
|
is_used=self.is_used,
|
||||||
|
used_at=self.used_at,
|
||||||
|
used_by=self.used_by,
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_info_response(self, product_name: str | None = None) -> RedeemInfoResponse:
|
||||||
|
"""转换为用户侧响应 DTO"""
|
||||||
|
return RedeemInfoResponse(
|
||||||
|
type=self.type,
|
||||||
|
product_name=product_name,
|
||||||
|
num=self.num,
|
||||||
|
is_used=self.is_used,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlmodel import Field, Relationship
|
from sqlmodel import Field, Relationship
|
||||||
|
|
||||||
from .base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
|
||||||
from .mixin import TableBaseMixin
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .share import Share
|
from .share import Share
|
||||||
@@ -21,10 +21,10 @@ class Report(SQLModelBase, TableBaseMixin):
|
|||||||
|
|
||||||
reason: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
reason: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||||
"""举报原因 [TODO] 待定义枚举"""
|
"""举报原因 [TODO] 待定义枚举"""
|
||||||
description: str | None = Field(default=None, max_length=255, description="补充描述")
|
description: Str255 | None = Field(default=None, description="补充描述")
|
||||||
|
|
||||||
# 外键
|
# 外键
|
||||||
share_id: int = Field(
|
share_id: UUID = Field(
|
||||||
foreign_key="share.id",
|
foreign_key="share.id",
|
||||||
index=True,
|
index=True,
|
||||||
ondelete="CASCADE"
|
ondelete="CASCADE"
|
||||||
|
|||||||
@@ -2,8 +2,9 @@ from enum import StrEnum
|
|||||||
|
|
||||||
from sqlmodel import UniqueConstraint
|
from sqlmodel import UniqueConstraint
|
||||||
|
|
||||||
from .base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase, TableBaseMixin
|
||||||
from .mixin import TableBaseMixin
|
|
||||||
|
from .auth_identity import AuthProviderType
|
||||||
from .user import UserResponse
|
from .user import UserResponse
|
||||||
|
|
||||||
class CaptchaType(StrEnum):
|
class CaptchaType(StrEnum):
|
||||||
@@ -12,6 +13,19 @@ class CaptchaType(StrEnum):
|
|||||||
GCAPTCHA = "gcaptcha"
|
GCAPTCHA = "gcaptcha"
|
||||||
CLOUD_FLARE_TURNSTILE = "cloudflare turnstile"
|
CLOUD_FLARE_TURNSTILE = "cloudflare turnstile"
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Auth 配置 DTO ====================
|
||||||
|
|
||||||
|
class AuthMethodConfig(SQLModelBase):
|
||||||
|
"""认证方式配置 DTO"""
|
||||||
|
|
||||||
|
provider: AuthProviderType
|
||||||
|
"""提供者类型"""
|
||||||
|
|
||||||
|
is_enabled: bool
|
||||||
|
"""是否启用"""
|
||||||
|
|
||||||
|
|
||||||
# ==================== DTO 模型 ====================
|
# ==================== DTO 模型 ====================
|
||||||
|
|
||||||
class SiteConfigResponse(SQLModelBase):
|
class SiteConfigResponse(SQLModelBase):
|
||||||
@@ -50,6 +64,30 @@ class SiteConfigResponse(SQLModelBase):
|
|||||||
captcha_key: str | None = None
|
captcha_key: str | None = None
|
||||||
"""验证码 public key(DEFAULT 类型时为 None)"""
|
"""验证码 public key(DEFAULT 类型时为 None)"""
|
||||||
|
|
||||||
|
auth_methods: list[AuthMethodConfig] = []
|
||||||
|
"""可用的登录方式列表"""
|
||||||
|
|
||||||
|
password_required: bool = True
|
||||||
|
"""注册时是否必须设置密码"""
|
||||||
|
|
||||||
|
phone_binding_required: bool = False
|
||||||
|
"""是否强制绑定手机号"""
|
||||||
|
|
||||||
|
email_binding_required: bool = True
|
||||||
|
"""是否强制绑定邮箱"""
|
||||||
|
|
||||||
|
avatar_max_size: int = 2097152
|
||||||
|
"""头像文件最大字节数(默认 2MB)"""
|
||||||
|
|
||||||
|
footer_code: str | None = None
|
||||||
|
"""自定义页脚代码"""
|
||||||
|
|
||||||
|
tos_url: str | None = None
|
||||||
|
"""服务条款 URL"""
|
||||||
|
|
||||||
|
privacy_url: str | None = None
|
||||||
|
"""隐私政策 URL"""
|
||||||
|
|
||||||
|
|
||||||
# ==================== 管理员设置 DTO ====================
|
# ==================== 管理员设置 DTO ====================
|
||||||
|
|
||||||
@@ -125,6 +163,7 @@ class SettingsType(StrEnum):
|
|||||||
VERSION = "version"
|
VERSION = "version"
|
||||||
VIEW = "view"
|
VIEW = "view"
|
||||||
WOPI = "wopi"
|
WOPI = "wopi"
|
||||||
|
FILE_CATEGORY = "file_category"
|
||||||
|
|
||||||
# 数据库模型
|
# 数据库模型
|
||||||
class Setting(SettingItem, TableBaseMixin):
|
class Setting(SettingItem, TableBaseMixin):
|
||||||
|
|||||||
@@ -5,8 +5,10 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
|
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
|
||||||
|
|
||||||
from .base import SQLModelBase
|
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str64, Str255
|
||||||
from .mixin import TableBaseMixin
|
|
||||||
|
from .model_base import ResponseBase
|
||||||
|
from .object import ObjectType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
@@ -34,13 +36,13 @@ class ShareBase(SQLModelBase):
|
|||||||
preview_enabled: bool = True
|
preview_enabled: bool = True
|
||||||
"""是否允许预览"""
|
"""是否允许预览"""
|
||||||
|
|
||||||
score: int = 0
|
score: int = Field(default=0, ge=0)
|
||||||
"""兑换此分享所需的积分"""
|
"""兑换此分享所需的积分"""
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库模型 ====================
|
# ==================== 数据库模型 ====================
|
||||||
|
|
||||||
class Share(SQLModelBase, TableBaseMixin):
|
class Share(SQLModelBase, UUIDTableBaseMixin):
|
||||||
"""分享模型"""
|
"""分享模型"""
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
@@ -50,10 +52,10 @@ class Share(SQLModelBase, TableBaseMixin):
|
|||||||
Index("ix_share_object", "object_id"),
|
Index("ix_share_object", "object_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
code: str = Field(max_length=64, nullable=False, index=True)
|
code: Str64 = Field(nullable=False, index=True)
|
||||||
"""分享码"""
|
"""分享码"""
|
||||||
|
|
||||||
password: str | None = Field(default=None, max_length=255)
|
password: Str255 | None = None
|
||||||
"""分享密码(加密后)"""
|
"""分享密码(加密后)"""
|
||||||
|
|
||||||
object_id: UUID = Field(
|
object_id: UUID = Field(
|
||||||
@@ -78,10 +80,10 @@ class Share(SQLModelBase, TableBaseMixin):
|
|||||||
preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
||||||
"""是否允许预览"""
|
"""是否允许预览"""
|
||||||
|
|
||||||
source_name: str | None = Field(default=None, max_length=255)
|
source_name: Str255 | None = None
|
||||||
"""源名称(冗余字段,便于展示)"""
|
"""源名称(冗余字段,便于展示)"""
|
||||||
|
|
||||||
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
score: int = Field(default=0, ge=0)
|
||||||
"""兑换此分享所需的积分"""
|
"""兑换此分享所需的积分"""
|
||||||
|
|
||||||
# 外键
|
# 外键
|
||||||
@@ -119,10 +121,17 @@ class ShareCreateRequest(ShareBase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ShareResponse(SQLModelBase):
|
class CreateShareResponse(ResponseBase):
|
||||||
"""分享响应 DTO"""
|
"""创建分享响应 DTO"""
|
||||||
|
|
||||||
id: int
|
share_id: UUID
|
||||||
|
"""新创建的分享记录 ID"""
|
||||||
|
|
||||||
|
|
||||||
|
class ShareResponse(SQLModelBase):
|
||||||
|
"""查看分享响应 DTO"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
"""分享ID"""
|
"""分享ID"""
|
||||||
|
|
||||||
code: str
|
code: str
|
||||||
@@ -162,10 +171,67 @@ class ShareResponse(SQLModelBase):
|
|||||||
"""是否有密码"""
|
"""是否有密码"""
|
||||||
|
|
||||||
|
|
||||||
|
class ShareOwnerInfo(SQLModelBase):
|
||||||
|
"""分享者公开信息 DTO"""
|
||||||
|
|
||||||
|
nickname: str | None
|
||||||
|
"""昵称"""
|
||||||
|
|
||||||
|
avatar: str
|
||||||
|
"""头像"""
|
||||||
|
|
||||||
|
|
||||||
|
class ShareObjectItem(SQLModelBase):
|
||||||
|
"""分享中的文件/文件夹信息 DTO"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
"""对象UUID"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""名称"""
|
||||||
|
|
||||||
|
type: ObjectType
|
||||||
|
"""类型:file 或 folder"""
|
||||||
|
|
||||||
|
size: int
|
||||||
|
"""文件大小(字节),目录为 0"""
|
||||||
|
|
||||||
|
created_at: datetime
|
||||||
|
"""创建时间"""
|
||||||
|
|
||||||
|
updated_at: datetime
|
||||||
|
"""修改时间"""
|
||||||
|
|
||||||
|
|
||||||
|
class ShareDetailResponse(SQLModelBase):
|
||||||
|
"""获取分享详情响应 DTO(面向访客,隐藏内部统计数据)"""
|
||||||
|
|
||||||
|
expires: datetime | None
|
||||||
|
"""过期时间"""
|
||||||
|
|
||||||
|
preview_enabled: bool
|
||||||
|
"""是否允许预览"""
|
||||||
|
|
||||||
|
score: int
|
||||||
|
"""积分"""
|
||||||
|
|
||||||
|
created_at: datetime
|
||||||
|
"""创建时间"""
|
||||||
|
|
||||||
|
owner: ShareOwnerInfo
|
||||||
|
"""分享者信息"""
|
||||||
|
|
||||||
|
object: ShareObjectItem
|
||||||
|
"""分享的根对象"""
|
||||||
|
|
||||||
|
children: list[ShareObjectItem]
|
||||||
|
"""子文件/文件夹列表(仅目录分享有内容)"""
|
||||||
|
|
||||||
|
|
||||||
class ShareListItemBase(SQLModelBase):
|
class ShareListItemBase(SQLModelBase):
|
||||||
"""分享列表项基础字段"""
|
"""分享列表项基础字段"""
|
||||||
|
|
||||||
id: int
|
id: UUID
|
||||||
"""分享ID"""
|
"""分享ID"""
|
||||||
|
|
||||||
code: str
|
code: str
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user